OnDiskDataset
- class dgl.graphbolt.OnDiskDataset(path: str, include_original_edge_id: bool = False, force_preprocess: bool | None = None, auto_cast_to_optimal_dtype: bool = True)[source]
Bases:
Dataset
An on-disk dataset which reads the graph topology, feature data, and train/validation/test sets from disk.
Due to limited resources, data that is too large to fit into RAM is kept on disk, while other data resides in RAM once OnDiskDataset is initialized. Users can control this behavior via the in_memory field in the YAML file. All paths in the YAML file are relative to the dataset directory. A full example of the YAML file is as follows:
dataset_name: graphbolt_test
graph:
  nodes:
    - type: paper # could be omitted for homogeneous graph.
      num: 1000
    - type: author
      num: 1000
  edges:
    - type: author:writes:paper # could be omitted for homogeneous graph.
      format: csv # Can be csv only.
      path: edge_data/author-writes-paper.csv
    - type: paper:cites:paper
      format: csv
      path: edge_data/paper-cites-paper.csv
feature_data:
  - domain: node
    type: paper # could be omitted for homogeneous graph.
    name: feat
    format: numpy
    in_memory: false # If not specified, default to true.
    path: node_data/paper-feat.npy
  - domain: edge
    type: "author:writes:paper"
    name: feat
    format: numpy
    in_memory: false
    path: edge_data/author-writes-paper-feat.npy
tasks:
  - name: "edge_classification"
    num_classes: 10
    train_set:
      - type: paper # could be omitted for homogeneous graph.
        data: # multiple data sources could be specified.
          - name: seeds
            format: numpy # Can be numpy or torch.
            in_memory: true # If not specified, default to true.
            path: set/paper-train-seeds.npy
          - name: labels
            format: numpy
            path: set/paper-train-labels.npy
    validation_set:
      - type: paper
        data:
          - name: seeds
            format: numpy
            path: set/paper-validation-seeds.npy
          - name: labels
            format: numpy
            path: set/paper-validation-labels.npy
    test_set:
      - type: paper
        data:
          - name: seeds
            format: numpy
            path: set/paper-test-seeds.npy
          - name: labels
            format: numpy
            path: set/paper-test-labels.npy
- Parameters:
  - path (str) – The path to the dataset directory.
  - include_original_edge_id (bool) – Whether to include the original edge IDs in the preprocessed graph. Default: False.
  - force_preprocess (bool, optional) – Whether to force the dataset to be preprocessed again. If None, preprocessing is triggered only when needed. Default: None.
  - auto_cast_to_optimal_dtype (bool) – Whether to automatically cast tensors to the smallest dtypes that can hold their values, to reduce memory consumption. Default: True.
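The on-disk layout described by the YAML schema above can be materialized with plain Python. The following is a minimal sketch for a small homogeneous dataset; the directory, file names, and values (`toy_homogeneous`, `cites.csv`, the node counts, and so on) are illustrative assumptions, not names required by GraphBolt.

```python
# Hypothetical sketch: build a minimal dataset directory in the layout
# OnDiskDataset expects. All names and values below are illustrative.
import csv
import os
import tempfile
import textwrap

import numpy as np

base_dir = tempfile.mkdtemp()
os.makedirs(os.path.join(base_dir, "edge_data"))
os.makedirs(os.path.join(base_dir, "node_data"))
os.makedirs(os.path.join(base_dir, "set"))

# Edge list as CSV (the only supported graph format): one (src, dst) row per edge.
with open(os.path.join(base_dir, "edge_data", "cites.csv"), "w", newline="") as f:
    csv.writer(f).writerows([(0, 1), (1, 2), (2, 0)])

# Node features and the train seeds/labels as .npy files.
np.save(os.path.join(base_dir, "node_data", "feat.npy"), np.random.rand(3, 4))
np.save(os.path.join(base_dir, "set", "train-seeds.npy"), np.array([0, 1]))
np.save(os.path.join(base_dir, "set", "train-labels.npy"), np.array([0, 1]))

# metadata.yaml ties the files together; all paths are relative to base_dir.
metadata = textwrap.dedent("""\
    dataset_name: toy_homogeneous
    graph:
      nodes:
        - num: 3
      edges:
        - format: csv
          path: edge_data/cites.csv
    feature_data:
      - domain: node
        name: feat
        format: numpy
        path: node_data/feat.npy
    tasks:
      - name: node_classification
        num_classes: 2
        train_set:
          - data:
              - name: seeds
                format: numpy
                path: set/train-seeds.npy
              - name: labels
                format: numpy
                path: set/train-labels.npy
    """)
with open(os.path.join(base_dir, "metadata.yaml"), "w") as f:
    f.write(metadata)

# base_dir can now be handed to gb.OnDiskDataset(base_dir).load().
```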
- load(tasks: List[str] | None = None)[source]
Load the dataset.
- Parameters:
  tasks (List[str] = None) – The names of the tasks to be loaded. For a single task, tasks can be a string or a List[str]. For multiple tasks, only List[str] is accepted.
Examples
1. Load by a single task name "node_classification".
>>> dataset = gb.OnDiskDataset(base_dir).load(
...     tasks="node_classification")
>>> len(dataset.tasks)
1
>>> dataset.tasks[0].metadata["name"]
"node_classification"
2. Load by a task name list ["node_classification"].
>>> dataset = gb.OnDiskDataset(base_dir).load(
...     tasks=["node_classification"])
>>> len(dataset.tasks)
1
>>> dataset.tasks[0].metadata["name"]
"node_classification"
3. Load by multiple task names ["node_classification", "link_prediction"].
>>> dataset = gb.OnDiskDataset(base_dir).load(
...     tasks=["node_classification", "link_prediction"])
>>> len(dataset.tasks)
2
>>> dataset.tasks[0].metadata["name"]
"node_classification"
>>> dataset.tasks[1].metadata["name"]
"link_prediction"
- property all_nodes_set: ItemSet | HeteroItemSet
Return an ItemSet containing all nodes.
- property feature: TorchBasedFeatureStore
Return the feature store.
- property graph: SamplingGraph
Return the graph.
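For feature files marked in_memory: false in the YAML, the data stays on disk and is fetched on access rather than being copied into RAM up front. The effect can be approximated with NumPy memory mapping; the snippet below is only an analogy for the idea, not GraphBolt's actual implementation.

```python
# Analogy for in_memory: false using NumPy memory mapping.
# This is NOT GraphBolt internals, just an illustration of the idea.
import os
import tempfile

import numpy as np

path = os.path.join(tempfile.mkdtemp(), "feat.npy")
np.save(path, np.arange(12, dtype=np.float32).reshape(3, 4))

# mmap_mode="r" maps the file into the address space: pages are read
# from disk on access instead of loading the whole array into RAM.
feat = np.load(path, mmap_mode="r")
row = np.asarray(feat[1])  # only the touched pages are materialized
```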