OnDiskDataset for Homogeneous Graph

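This tutorial shows how to create an OnDiskDataset for a homogeneous graph that can be used in the GraphBolt framework.

By the end of this tutorial, you will be able to

  • organize graph structure data.

  • organize feature data.

  • organize training/validation/test sets for specific tasks.

To create an OnDiskDataset object, you need to organize all the data (graph structure, feature data, and tasks) into a single directory. The directory must contain a metadata.yaml file that describes the metadata of the dataset.

Now let's generate the various pieces of data step by step and organize them so that we can finally instantiate an OnDiskDataset.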

Install DGL package

[1]:
# Install required packages.
import os
import torch
import numpy as np
os.environ['TORCH'] = torch.__version__
os.environ['DGLBACKEND'] = "pytorch"

# Install the CPU version.
device = torch.device("cpu")
!pip install --pre dgl -f https://data.dgl.ai/wheels-test/repo.html

try:
    import dgl
    import dgl.graphbolt as gb
    installed = True
except ImportError as error:
    installed = False
    print(error)
print("DGL installed!" if installed else "DGL not found!")
DGL installed!

Data preparation

To demonstrate how to organize the various kinds of data, we first create a base directory.

[2]:
base_dir = './ondisk_dataset_homograph'
os.makedirs(base_dir, exist_ok=True)
print(f"Created base directory: {base_dir}")
Created base directory: ./ondisk_dataset_homograph

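Generate graph structure data

For a homogeneous graph, we only need to save the edges (namely seeds) into a Numpy or CSV file.

Note:

  • When saving to Numpy, the array must have shape (2, N). This format is recommended because constructing the graph from it is much faster than from a CSV file.

  • When saving to a CSV file, do not save the index or the header.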

[3]:
import numpy as np
import pandas as pd
num_nodes = 1000
num_edges = 10 * num_nodes
edges_path = os.path.join(base_dir, "edges.csv")
edges = np.random.randint(0, num_nodes, size=(num_edges, 2))

print(f"Part of edges: {edges[:5, :]}")

df = pd.DataFrame(edges)
df.to_csv(edges_path, index=False, header=False)

print(f"Edges are saved into {edges_path}")
Part of edges: [[734 698]
 [492 101]
 [141 102]
 [255 293]
 [172 382]]
Edges are saved into ./ondisk_dataset_homograph/edges.csv
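The note above recommends the Numpy format with shape (2, N), while the cell saves CSV. The following is a minimal sketch of the Numpy variant. The file name edges.npy and the temporary output directory are assumptions made for illustration; in the tutorial the file would live under base_dir, and the edges entry in metadata.yaml would presumably use format: numpy, as the feature entries do.

```python
import os
import tempfile

import numpy as np

num_nodes = 1000
num_edges = 10 * num_nodes
# Same layout as the CSV edges in the tutorial: one (src, dst) pair per row.
edges = np.random.randint(0, num_nodes, size=(num_edges, 2))

# Transpose to (2, N): row 0 holds the sources, row 1 the destinations.
edges_2xn = np.ascontiguousarray(edges.T)

out_dir = tempfile.mkdtemp()  # hypothetical stand-in for base_dir
edges_npy_path = os.path.join(out_dir, "edges.npy")  # hypothetical file name
np.save(edges_npy_path, edges_2xn)

# Reload to confirm the on-disk shape is (2, N).
reloaded = np.load(edges_npy_path)
print(f"Saved edges with shape {reloaded.shape} to {edges_npy_path}")
```

Transposing back with reloaded.T recovers the row-per-edge layout used elsewhere in this tutorial.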

Generate feature data for graph

For feature data, numpy arrays and torch tensors are currently supported.

[4]:
# Generate node feature in numpy array.
node_feat_0_path = os.path.join(base_dir, "node-feat-0.npy")
node_feat_0 = np.random.rand(num_nodes, 5)
print(f"Part of node feature [feat_0]: {node_feat_0[:3, :]}")
np.save(node_feat_0_path, node_feat_0)
print(f"Node feature [feat_0] is saved to {node_feat_0_path}\n")

# Generate another node feature in torch tensor
node_feat_1_path = os.path.join(base_dir, "node-feat-1.pt")
node_feat_1 = torch.rand(num_nodes, 5)
print(f"Part of node feature [feat_1]: {node_feat_1[:3, :]}")
torch.save(node_feat_1, node_feat_1_path)
print(f"Node feature [feat_1] is saved to {node_feat_1_path}\n")

# Generate edge feature in numpy array.
edge_feat_0_path = os.path.join(base_dir, "edge-feat-0.npy")
edge_feat_0 = np.random.rand(num_edges, 5)
print(f"Part of edge feature [feat_0]: {edge_feat_0[:3, :]}")
np.save(edge_feat_0_path, edge_feat_0)
print(f"Edge feature [feat_0] is saved to {edge_feat_0_path}\n")

# Generate another edge feature in torch tensor
edge_feat_1_path = os.path.join(base_dir, "edge-feat-1.pt")
edge_feat_1 = torch.rand(num_edges, 5)
print(f"Part of edge feature [feat_1]: {edge_feat_1[:3, :]}")
torch.save(edge_feat_1, edge_feat_1_path)
print(f"Edge feature [feat_1] is saved to {edge_feat_1_path}\n")

Part of node feature [feat_0]: [[0.2675768  0.84555141 0.31953485 0.70518215 0.7384711 ]
 [0.17017616 0.67410909 0.49357539 0.17954053 0.51379857]
 [0.20808962 0.62090961 0.00869142 0.76270778 0.75740362]]
Node feature [feat_0] is saved to ./ondisk_dataset_homograph/node-feat-0.npy

Part of node feature [feat_1]: tensor([[0.3102, 0.6617, 0.3103, 0.1763, 0.4377],
        [0.3336, 0.4147, 0.4776, 0.6154, 0.4325],
        [0.9472, 0.4797, 0.4150, 0.9046, 0.7426]])
Node feature [feat_1] is saved to ./ondisk_dataset_homograph/node-feat-1.pt

Part of edge feature [feat_0]: [[0.48323927 0.16915343 0.64657681 0.95671693 0.67171557]
 [0.73523352 0.25524394 0.82357219 0.84688155 0.09598407]
 [0.03860003 0.93619916 0.81360089 0.47665546 0.93298402]]
Edge feature [feat_0] is saved to ./ondisk_dataset_homograph/edge-feat-0.npy

Part of edge feature [feat_1]: tensor([[0.5663, 0.9633, 0.1347, 0.3310, 0.9384],
        [0.8327, 0.9789, 0.8282, 0.2175, 0.5416],
        [0.0256, 0.3471, 0.4384, 0.0020, 0.7780]])
Edge feature [feat_1] is saved to ./ondisk_dataset_homograph/edge-feat-1.pt

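Generate tasks

OnDiskDataset supports multiple tasks. For each task, we need to prepare its own training/validation/test sets, which usually differ from task to task. In this tutorial, we create a node classification task and a link prediction task.

Node classification task

For a node classification task, we need the node IDs and the corresponding labels for each training/validation/test set. As with feature data, numpy arrays and torch tensors are supported for these sets.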

[5]:
num_trains = int(num_nodes * 0.6)
num_vals = int(num_nodes * 0.2)
num_tests = num_nodes - num_trains - num_vals

ids = np.arange(num_nodes)
np.random.shuffle(ids)

nc_train_ids_path = os.path.join(base_dir, "nc-train-ids.npy")
nc_train_ids = ids[:num_trains]
print(f"Part of train ids for node classification: {nc_train_ids[:3]}")
np.save(nc_train_ids_path, nc_train_ids)
print(f"NC train ids are saved to {nc_train_ids_path}\n")

nc_train_labels_path = os.path.join(base_dir, "nc-train-labels.pt")
nc_train_labels = torch.randint(0, 10, (num_trains,))
print(f"Part of train labels for node classification: {nc_train_labels[:3]}")
torch.save(nc_train_labels, nc_train_labels_path)
print(f"NC train labels are saved to {nc_train_labels_path}\n")

nc_val_ids_path = os.path.join(base_dir, "nc-val-ids.npy")
nc_val_ids = ids[num_trains:num_trains+num_vals]
print(f"Part of val ids for node classification: {nc_val_ids[:3]}")
np.save(nc_val_ids_path, nc_val_ids)
print(f"NC val ids are saved to {nc_val_ids_path}\n")

nc_val_labels_path = os.path.join(base_dir, "nc-val-labels.pt")
nc_val_labels = torch.randint(0, 10, (num_vals,))
print(f"Part of val labels for node classification: {nc_val_labels[:3]}")
torch.save(nc_val_labels, nc_val_labels_path)
print(f"NC val labels are saved to {nc_val_labels_path}\n")

nc_test_ids_path = os.path.join(base_dir, "nc-test-ids.npy")
nc_test_ids = ids[-num_tests:]
print(f"Part of test ids for node classification: {nc_test_ids[:3]}")
np.save(nc_test_ids_path, nc_test_ids)
print(f"NC test ids are saved to {nc_test_ids_path}\n")

nc_test_labels_path = os.path.join(base_dir, "nc-test-labels.pt")
nc_test_labels = torch.randint(0, 10, (num_tests,))
print(f"Part of test labels for node classification: {nc_test_labels[:3]}")
torch.save(nc_test_labels, nc_test_labels_path)
print(f"NC test labels are saved to {nc_test_labels_path}\n")
Part of train ids for node classification: [809 209 773]
NC train ids are saved to ./ondisk_dataset_homograph/nc-train-ids.npy

Part of train labels for node classification: tensor([1, 2, 6])
NC train labels are saved to ./ondisk_dataset_homograph/nc-train-labels.pt

Part of val ids for node classification: [156 777 233]
NC val ids are saved to ./ondisk_dataset_homograph/nc-val-ids.npy

Part of val labels for node classification: tensor([2, 8, 3])
NC val labels are saved to ./ondisk_dataset_homograph/nc-val-labels.pt

Part of test ids for node classification: [484 372  48]
NC test ids are saved to ./ondisk_dataset_homograph/nc-test-ids.npy

Part of test labels for node classification: tensor([8, 6, 8])
NC test labels are saved to ./ondisk_dataset_homograph/nc-test-labels.pt
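The metadata file written later refers to seeds, labels, and indexes files for the link prediction task (lp_train_seeds_path, lp_val_seeds_path, and so on). The following is a minimal sketch of how such arrays could be built, assuming 10 random negative destinations per positive validation/test edge and the same 60/20/20 split as above; the helper build_lp_eval_set is a hypothetical name, and this is not necessarily the exact code used to produce the outputs shown in this tutorial.

```python
import numpy as np


def build_lp_eval_set(pos_edges, num_nodes, num_negs=10, rng=None):
    """Append random negative edges to positive edges (hypothetical helper).

    Returns (seeds, labels, indexes): positives come first with label 1,
    followed by num_negs negatives per positive with label 0. indexes maps
    every row back to the positive edge it belongs to.
    """
    rng = np.random.default_rng() if rng is None else rng
    num_pos = pos_edges.shape[0]
    # Each negative keeps the source of its positive edge ...
    neg_srcs = np.repeat(pos_edges[:, 0], num_negs)
    # ... and gets a uniformly random destination node.
    neg_dsts = rng.integers(0, num_nodes, size=num_pos * num_negs)
    neg_edges = np.stack([neg_srcs, neg_dsts], axis=1)

    seeds = np.concatenate([pos_edges, neg_edges])
    labels = np.concatenate([np.ones(num_pos), np.zeros(num_pos * num_negs)])
    pos_indexes = np.arange(num_pos)
    indexes = np.concatenate([pos_indexes, np.repeat(pos_indexes, num_negs)])
    return seeds, labels, indexes


rng = np.random.default_rng(0)
num_nodes = 1000
num_edges = 10 * num_nodes
edges = rng.integers(0, num_nodes, size=(num_edges, 2))

num_trains = int(num_edges * 0.6)
num_vals = int(num_edges * 0.2)

# Training only needs positive edges as seeds.
lp_train_seeds = edges[:num_trains]
lp_val_seeds, lp_val_labels, lp_val_indexes = build_lp_eval_set(
    edges[num_trains:num_trains + num_vals], num_nodes, rng=rng
)
lp_test_seeds, lp_test_labels, lp_test_indexes = build_lp_eval_set(
    edges[num_trains + num_vals:], num_nodes, rng=rng
)
print(lp_train_seeds.shape, lp_val_seeds.shape, lp_val_indexes.shape)
```

Each array could then be written with np.save under base_dir, using one file per name that the metadata file references.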

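Organize data into YAML file

Now we need to create a metadata.yaml file that contains the paths and data types of the graph structure, the feature data, and the training/validation/test sets.

Notes:

  • All paths should be relative to metadata.yaml.

  • The following field is optional and is not specified in the example below.

  • in_memory: indicates whether to load the data into memory or to use mmap. The default value is True.

Please refer to the YAML specification for more details.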
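To illustrate the optional in_memory field mentioned in the notes, the sketch below spells it out for a single numpy-format node feature. It assumes the field is placed on the individual feature_data entry, which is how the YAML specification is understood here; the file name matches the node feature saved earlier, and the fragment is illustrative rather than a complete metadata file.

```python
# Hypothetical fragment: one feature entry with in_memory set explicitly.
# Assumption: in_memory sits at the same level as format and path.
feature_entry = """
    feature_data:
      - domain: node
        name: feat_0
        format: numpy
        in_memory: false
        path: node-feat-0.npy
"""
print(feature_entry)
```

Leaving the field out, as the example in this tutorial does, keeps the default of loading the data into memory.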

[7]:
yaml_content = f"""
    dataset_name: homogeneous_graph_nc_lp
    graph:
      nodes:
        - num: {num_nodes}
      edges:
        - format: csv
          path: {os.path.basename(edges_path)}
    feature_data:
      - domain: node
        name: feat_0
        format: numpy
        path: {os.path.basename(node_feat_0_path)}
      - domain: node
        name: feat_1
        format: torch
        path: {os.path.basename(node_feat_1_path)}
      - domain: edge
        name: feat_0
        format: numpy
        path: {os.path.basename(edge_feat_0_path)}
      - domain: edge
        name: feat_1
        format: torch
        path: {os.path.basename(edge_feat_1_path)}
    tasks:
      - name: node_classification
        num_classes: 10
        train_set:
          - data:
              - name: seeds
                format: numpy
                path: {os.path.basename(nc_train_ids_path)}
              - name: labels
                format: torch
                path: {os.path.basename(nc_train_labels_path)}
        validation_set:
          - data:
              - name: seeds
                format: numpy
                path: {os.path.basename(nc_val_ids_path)}
              - name: labels
                format: torch
                path: {os.path.basename(nc_val_labels_path)}
        test_set:
          - data:
              - name: seeds
                format: numpy
                path: {os.path.basename(nc_test_ids_path)}
              - name: labels
                format: torch
                path: {os.path.basename(nc_test_labels_path)}
      - name: link_prediction
        num_classes: 10
        train_set:
          - data:
              - name: seeds
                format: numpy
                path: {os.path.basename(lp_train_seeds_path)}
        validation_set:
          - data:
              - name: seeds
                format: numpy
                path: {os.path.basename(lp_val_seeds_path)}
              - name: labels
                format: numpy
                path: {os.path.basename(lp_val_labels_path)}
              - name: indexes
                format: numpy
                path: {os.path.basename(lp_val_indexes_path)}
        test_set:
          - data:
              - name: seeds
                format: numpy
                path: {os.path.basename(lp_test_seeds_path)}
              - name: labels
                format: numpy
                path: {os.path.basename(lp_test_labels_path)}
              - name: indexes
                format: numpy
                path: {os.path.basename(lp_test_indexes_path)}
"""
metadata_path = os.path.join(base_dir, "metadata.yaml")
with open(metadata_path, "w") as f:
  f.write(yaml_content)
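Because every path in metadata.yaml is resolved relative to that file, it can be useful to check that each referenced file exists before instantiating the dataset. The following is a minimal sketch of such a check; find_missing_paths is a hypothetical helper, not part of GraphBolt, and it uses a simple regular expression instead of a YAML parser, so it only recognizes plain `path: value` lines.

```python
import os
import re
import tempfile


def find_missing_paths(metadata_path):
    """Return the `path:` entries of a metadata file that do not exist on disk."""
    base = os.path.dirname(os.path.abspath(metadata_path))
    with open(metadata_path) as f:
        text = f.read()
    # Matches lines such as "path: edges.csv" or "- path: edges.csv".
    rel_paths = re.findall(
        r"^[ \t]*(?:-[ \t]*)?path:[ \t]*(\S+)[ \t]*$", text, flags=re.MULTILINE
    )
    return [p for p in rel_paths if not os.path.exists(os.path.join(base, p))]


# Small self-contained demo: one referenced file exists, the other does not.
demo_dir = tempfile.mkdtemp()
open(os.path.join(demo_dir, "edges.csv"), "w").close()
demo_metadata = os.path.join(demo_dir, "metadata.yaml")
with open(demo_metadata, "w") as f:
    f.write(
        "graph:\n"
        "  edges:\n"
        "    - format: csv\n"
        "      path: edges.csv\n"
        "feature_data:\n"
        "  - domain: node\n"
        "    path: node-feat-0.npy\n"
    )
missing = find_missing_paths(demo_metadata)
print(f"Missing files: {missing}")
```

In this tutorial, the call would be find_missing_paths(metadata_path); an empty list means every referenced file was found.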

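Instantiate OnDiskDataset

Now we can load the dataset via dgl.graphbolt.OnDiskDataset. To instantiate it, we only need to pass the base directory that contains the metadata.yaml file.

On first instantiation, GraphBolt preprocesses the raw data, for example by constructing a FusedCSCSamplingGraph from the edges. After preprocessing, all data (graph, feature data, and training/validation/test sets) is placed in the preprocessed directory. Any subsequent load of the dataset skips the preprocessing stage.

After preprocessing, load() must be called explicitly to load the graph, feature data, and tasks.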
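To give a rough idea of what building a CSC graph from an edge list involves, here is a small NumPy sketch. It is an illustration of the compressed sparse column layout only, not GraphBolt's actual preprocessing code, and edges_to_csc is a hypothetical helper. It assumes columns correspond to destination nodes, so that indices lists the source nodes of each node's incoming edges.

```python
import numpy as np


def edges_to_csc(edges, num_nodes):
    """Convert an (N, 2) array of (src, dst) pairs into CSC arrays.

    Returns (indptr, indices): the in-neighbors of node v are
    indices[indptr[v]:indptr[v + 1]].
    """
    src, dst = edges[:, 0], edges[:, 1]
    # Group edges by destination; a stable sort keeps the original edge order
    # within each destination.
    order = np.argsort(dst, kind="stable")
    indices = src[order]
    # Count incoming edges per node, then prefix-sum into column pointers.
    counts = np.bincount(dst, minlength=num_nodes)
    indptr = np.concatenate([[0], np.cumsum(counts)])
    return indptr, indices


edges = np.array([[0, 1], [2, 1], [1, 0]])
csc_indptr, indices = edges_to_csc(edges, num_nodes=3)
print(csc_indptr, indices)
```

The pointer array has num_nodes + 1 entries and its last value equals the number of edges, which is consistent with the csc_indptr printed for the loaded graph (it ends at 10000).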

[8]:
dataset = gb.OnDiskDataset(base_dir).load()
graph = dataset.graph
print(f"Loaded graph: {graph}\n")

feature = dataset.feature
print(f"Loaded feature store: {feature}\n")

tasks = dataset.tasks
nc_task = tasks[0]
print(f"Loaded node classification task: {nc_task}\n")
lp_task = tasks[1]
print(f"Loaded link prediction task: {lp_task}\n")
Start to preprocess the on-disk dataset.
Finish preprocessing the on-disk dataset.
Loaded graph: FusedCSCSamplingGraph(csc_indptr=tensor([    0,     7,    15,  ...,  9983,  9988, 10000], dtype=torch.int32),
                      indices=tensor([188, 589, 176,  ..., 294, 762, 730], dtype=torch.int32),
                      total_num_nodes=1000, num_edges=10000,)

Loaded feature store: TorchBasedFeatureStore(
    {(<OnDiskFeatureDataDomain.NODE: 'node'>, None, 'feat_0'): TorchBasedFeature(
        feature=tensor([[0.2676, 0.8456, 0.3195, 0.7052, 0.7385],
                        [0.1702, 0.6741, 0.4936, 0.1795, 0.5138],
                        [0.2081, 0.6209, 0.0087, 0.7627, 0.7574],
                        ...,
                        [0.9070, 0.4060, 0.9906, 0.6465, 0.1518],
                        [0.1824, 0.9145, 0.4194, 0.6864, 0.4178],
                        [0.9964, 0.0864, 0.5270, 0.4842, 0.0228]], dtype=torch.float64),
        metadata={},
    ), (<OnDiskFeatureDataDomain.NODE: 'node'>, None, 'feat_1'): TorchBasedFeature(
        feature=tensor([[0.3102, 0.6617, 0.3103, 0.1763, 0.4377],
                        [0.3336, 0.4147, 0.4776, 0.6154, 0.4325],
                        [0.9472, 0.4797, 0.4150, 0.9046, 0.7426],
                        ...,
                        [0.0891, 0.8304, 0.5157, 0.1804, 0.8821],
                        [0.5526, 0.8321, 0.5452, 0.4415, 0.0907],
                        [0.2525, 0.3944, 0.8356, 0.9236, 0.3284]]),
        metadata={},
    ), (<OnDiskFeatureDataDomain.EDGE: 'edge'>, None, 'feat_0'): TorchBasedFeature(
        feature=tensor([[0.4832, 0.1692, 0.6466, 0.9567, 0.6717],
                        [0.7352, 0.2552, 0.8236, 0.8469, 0.0960],
                        [0.0386, 0.9362, 0.8136, 0.4767, 0.9330],
                        ...,
                        [0.7652, 0.2556, 0.4112, 0.4190, 0.0296],
                        [0.9907, 0.8527, 0.5779, 0.3108, 0.8355],
                        [0.0917, 0.6557, 0.5226, 0.9362, 0.3608]], dtype=torch.float64),
        metadata={},
    ), (<OnDiskFeatureDataDomain.EDGE: 'edge'>, None, 'feat_1'): TorchBasedFeature(
        feature=tensor([[0.5663, 0.9633, 0.1347, 0.3310, 0.9384],
                        [0.8327, 0.9789, 0.8282, 0.2175, 0.5416],
                        [0.0256, 0.3471, 0.4384, 0.0020, 0.7780],
                        ...,
                        [0.3607, 0.8425, 0.5213, 0.5604, 0.7548],
                        [0.1914, 0.1043, 0.7555, 0.8857, 0.1084],
                        [0.3130, 0.3773, 0.9874, 0.2341, 0.6229]]),
        metadata={},
    )}
)

Loaded node classification task: OnDiskTask(validation_set=ItemSet(
               items=(tensor([156, 777, 233, 648, 371, 319, 351, 538, 537, 730,  83, 817, 401, 256,
                   717, 606, 719, 211, 389, 226, 742,  37, 950, 414, 282, 556, 289, 863,
                   293, 273,  86,  97, 529, 301, 492, 140, 268,  11, 477, 766, 654, 999,
                   132, 245, 181, 733, 977, 207, 491, 969, 334,  14, 995, 576, 167, 732,
                   302,  71, 363, 275, 426, 805, 320, 250, 483, 277, 958, 290,  17, 535,
                   296, 647, 108, 867, 815, 821, 270, 819, 223, 471, 584, 928, 877, 621,
                   406, 497, 878,  62, 478, 997, 675, 396, 595, 828,  45, 987, 994, 361,
                   804, 906,  65, 701,  91, 385,   9, 822, 533, 944, 862, 993, 707, 613,
                   961, 755, 429, 467, 769, 740, 340, 912, 684, 616, 510, 459, 131, 945,
                   625, 267, 940, 816, 231, 695, 489, 224, 383,  40, 966, 258,  58, 609,
                   506, 127, 443, 399, 790, 232, 378, 764, 503, 516, 589, 368, 272, 664,
                   807,  38, 579, 376, 704, 607, 318, 114, 978, 907, 726, 735, 982, 909,
                   676,  94, 324, 671, 590, 992, 709,  46, 550, 257, 188, 423, 913, 185,
                   360, 159, 262, 330, 718, 373, 469, 639, 155, 588, 517, 532, 153, 911,
                   472, 456, 806, 594], dtype=torch.int32), tensor([2, 8, 3, 3, 4, 0, 6, 7, 1, 4, 8, 4, 2, 6, 3, 3, 5, 2, 3, 2, 6, 5, 2, 3,
                   1, 0, 5, 8, 2, 0, 9, 8, 2, 7, 9, 3, 4, 3, 4, 3, 9, 2, 4, 6, 8, 4, 5, 8,
                   2, 8, 2, 4, 5, 0, 8, 2, 0, 0, 4, 6, 0, 9, 4, 7, 5, 1, 2, 7, 8, 5, 6, 6,
                   7, 7, 7, 4, 9, 6, 7, 8, 8, 9, 5, 6, 6, 6, 3, 7, 9, 1, 7, 7, 7, 5, 1, 2,
                   1, 7, 7, 3, 7, 1, 0, 0, 6, 5, 6, 2, 7, 6, 3, 0, 3, 2, 2, 3, 9, 3, 4, 7,
                   0, 8, 6, 0, 9, 7, 2, 9, 8, 5, 1, 2, 3, 7, 7, 3, 4, 8, 4, 1, 1, 3, 6, 5,
                   3, 0, 3, 9, 0, 1, 1, 1, 7, 6, 7, 6, 2, 5, 4, 4, 6, 9, 2, 0, 7, 4, 1, 4,
                   7, 9, 4, 5, 6, 1, 0, 7, 3, 6, 3, 7, 1, 9, 0, 9, 3, 1, 7, 8, 1, 1, 4, 0,
                   9, 8, 9, 1, 2, 3, 8, 5])),
               names=('seeds', 'labels'),
           ),
           train_set=ItemSet(
               items=(tensor([809, 209, 773, 871, 422, 244, 549, 736, 875, 100, 846,  32, 295, 380,
                   799, 768, 635, 593, 779,  34, 854, 905, 762, 455, 205, 888, 126, 493,
                   868, 173, 539, 279, 284, 326, 322, 678, 212,  19, 914, 162, 880, 367,
                   146, 306, 242,  67, 900, 112, 220, 448, 963, 241, 260, 409, 661, 724,
                   416, 592,   4, 782, 274, 941, 333, 923, 618, 842, 857,  72,  56, 699,
                   844,  79, 174, 960, 292,  10, 144, 541, 435, 398,  27, 335,  39, 175,
                   929, 237, 772, 681, 150, 234, 741, 366, 869, 344, 650, 783, 400, 449,
                   313, 836, 605, 352, 808, 219, 152, 697, 204, 610, 130, 308, 394,  21,
                   255, 190, 886, 408, 700, 357, 466,  80, 206, 339,  15, 519, 713, 976,
                   265, 792, 927, 710, 309, 460, 314, 187, 433, 747, 288, 527, 967, 261,
                    29, 221, 820, 899, 286, 354, 560, 343, 853, 603, 528, 874, 441, 329,
                   760, 705, 523,  93, 716, 620, 679, 837, 508,  35, 151, 786, 113,  81,
                     2, 147, 893, 739, 932, 850, 168, 673, 956, 788, 974, 641, 952, 655,
                   824, 567, 348, 104, 826, 498, 109, 608, 898, 771, 629, 847, 328, 198,
                   432, 775, 631, 457, 475, 947, 115, 813, 651, 298, 715, 128, 734, 542,
                   753, 500, 178, 891, 611, 586, 141, 636,   8, 217, 677, 545, 191, 192,
                   461,  73, 798, 851, 936, 511, 283, 397, 213, 388, 437, 525, 834, 557,
                   975,   6, 436, 759, 666, 561, 612, 957,  76, 125, 617, 794,   1,  90,
                   251, 350,  60,  16, 922, 942, 119, 754,  85, 990, 737, 507, 811,  77,
                   332,  18, 660, 186, 600, 464,  42, 890, 778, 543, 624, 381, 465, 575,
                     5, 515, 892, 810, 102, 935, 669, 643, 518, 571, 276, 197, 176, 138,
                   632,  98, 170, 751, 239, 910, 105, 349, 182, 218, 447, 667, 450, 193,
                   139, 522, 259, 881, 802, 418, 473, 215, 738, 889, 143, 745,  33, 728,
                   106, 540, 746, 883, 572, 403,  99, 670, 285, 749, 486, 225, 377, 427,
                   849, 362, 179, 393, 691, 553, 402, 797, 796, 479, 365, 711, 665, 902,
                   649, 916, 390, 526, 145, 514, 748, 370, 668, 565, 604, 658, 598, 421,
                   534, 840, 633, 386, 196, 312, 122, 474, 129, 243, 524, 110, 812, 919,
                   962, 327, 872, 948, 656, 955, 861, 791,  31, 568, 307, 870, 998, 725,
                    52, 438, 135, 194,  23, 107, 774, 744, 551, 253, 789,  87, 908, 342,
                   859, 420,  53, 829, 103, 712, 856, 470, 585, 509, 405, 622, 698, 504,
                   708, 379, 425, 263, 451, 827, 985, 573, 830, 703,  82, 743, 440, 512,
                   855, 934, 485, 930, 269,   7, 841, 965, 358, 415, 795, 172, 623, 404,
                   980, 723, 831, 180, 228, 702, 488, 583,  59, 391, 848, 336, 337, 321,
                   685, 252, 984,  25, 570, 501, 184, 858, 566, 694, 894, 495, 901, 305,
                   413, 645, 123, 756, 818, 552,  84,  89, 387, 480, 317, 246,  41, 227,
                   189, 727, 248, 287, 896, 885, 278, 294, 430,  30,  75,  50, 177,  12,
                   674, 353, 630, 291, 690,   3,   0, 428, 569, 160, 971, 882, 663, 266,
                   866, 770, 341, 991, 395, 563, 696,  55, 101, 845, 657,  78, 638, 672,
                   839, 972, 124, 776, 235, 599, 446, 384, 758,  61,  44, 949, 920, 513,
                    64, 979, 445,  70, 157, 210, 355,  47, 662, 731, 133,  63, 411, 574,
                   750,  36, 345, 558, 359, 202, 424, 752,  26, 439, 392, 946, 989, 765,
                   417, 434, 311, 463, 496, 926, 490, 720, 823, 158, 564, 304, 938, 596,
                   121, 407, 555, 904, 200, 973, 148, 149, 240, 959, 706, 614],
                  dtype=torch.int32), tensor([1, 2, 6, 7, 1, 2, 7, 5, 5, 4, 1, 5, 4, 1, 2, 6, 7, 4, 5, 0, 8, 9, 7, 3,
                   3, 8, 9, 9, 7, 4, 2, 2, 1, 4, 4, 4, 2, 4, 0, 9, 4, 9, 2, 4, 2, 9, 2, 3,
                   2, 9, 1, 8, 9, 5, 9, 8, 3, 3, 0, 3, 4, 8, 0, 2, 8, 8, 2, 4, 7, 2, 2, 6,
                   9, 7, 0, 4, 3, 9, 5, 6, 3, 4, 7, 7, 6, 6, 0, 4, 4, 7, 7, 1, 0, 0, 1, 6,
                   3, 9, 0, 3, 4, 9, 0, 2, 8, 1, 6, 0, 0, 4, 2, 9, 0, 0, 7, 9, 0, 0, 1, 7,
                   9, 4, 3, 0, 1, 2, 4, 9, 1, 6, 3, 5, 0, 9, 5, 8, 4, 3, 8, 1, 7, 7, 8, 3,
                   6, 3, 5, 0, 4, 9, 8, 5, 5, 9, 5, 9, 7, 0, 6, 9, 4, 1, 4, 6, 5, 4, 9, 0,
                   9, 2, 0, 1, 5, 5, 0, 1, 0, 7, 3, 2, 2, 9, 6, 4, 2, 2, 4, 9, 1, 9, 9, 9,
                   6, 9, 6, 5, 2, 8, 5, 9, 1, 8, 5, 7, 1, 1, 9, 5, 4, 1, 0, 4, 8, 1, 3, 4,
                   0, 5, 0, 9, 8, 4, 3, 5, 3, 1, 5, 5, 8, 1, 5, 3, 3, 9, 7, 1, 1, 2, 0, 1,
                   4, 8, 8, 0, 0, 2, 4, 6, 8, 8, 9, 9, 5, 1, 8, 4, 1, 6, 6, 4, 4, 2, 4, 6,
                   1, 3, 8, 7, 1, 7, 6, 7, 9, 2, 1, 7, 3, 3, 0, 6, 3, 0, 2, 2, 3, 5, 0, 7,
                   2, 3, 3, 2, 3, 1, 4, 9, 4, 2, 1, 0, 2, 0, 0, 8, 8, 1, 2, 9, 7, 5, 5, 8,
                   8, 9, 1, 1, 1, 9, 9, 2, 1, 0, 6, 0, 0, 0, 5, 6, 6, 2, 7, 6, 2, 5, 8, 3,
                   9, 2, 1, 5, 0, 5, 1, 8, 6, 2, 7, 0, 4, 0, 5, 8, 2, 7, 8, 9, 6, 9, 8, 9,
                   6, 9, 5, 7, 9, 4, 6, 4, 4, 7, 2, 7, 5, 7, 2, 1, 8, 5, 0, 7, 6, 4, 7, 5,
                   7, 7, 0, 7, 1, 6, 7, 9, 7, 2, 9, 9, 2, 8, 3, 9, 2, 8, 9, 7, 4, 0, 3, 1,
                   1, 2, 1, 9, 4, 1, 5, 4, 3, 9, 0, 4, 9, 0, 9, 2, 6, 5, 1, 2, 8, 0, 5, 3,
                   0, 7, 8, 7, 3, 2, 3, 7, 2, 5, 5, 4, 4, 6, 9, 9, 8, 4, 8, 1, 8, 3, 3, 5,
                   8, 4, 5, 5, 5, 5, 0, 7, 3, 6, 9, 1, 5, 8, 2, 0, 0, 0, 1, 5, 5, 9, 6, 3,
                   6, 9, 1, 7, 3, 5, 7, 4, 7, 7, 2, 6, 5, 2, 1, 4, 5, 6, 8, 4, 1, 5, 7, 0,
                   2, 9, 3, 5, 4, 7, 6, 5, 8, 9, 6, 6, 8, 2, 7, 9, 7, 9, 7, 8, 5, 8, 6, 1,
                   5, 4, 1, 8, 5, 1, 5, 3, 7, 0, 7, 7, 9, 9, 8, 3, 9, 4, 0, 4, 4, 6, 9, 1,
                   2, 5, 6, 7, 0, 6, 8, 3, 2, 1, 2, 1, 5, 6, 6, 6, 5, 2, 1, 8, 8, 4, 3, 1,
                   1, 9, 2, 1, 0, 6, 3, 4, 6, 0, 8, 2, 5, 7, 4, 4, 2, 9, 0, 9, 4, 0, 8, 2])),
               names=('seeds', 'labels'),
           ),
           test_set=ItemSet(
               items=(tensor([484, 372,  48, 254, 281, 626, 864, 986, 338,  66, 587, 865, 118, 452,
                   860,  92, 419, 833, 686, 356, 757, 375, 171, 201, 988, 887, 627, 931,
                   970, 876, 154, 458, 642, 236, 481, 601, 761, 951, 195, 116, 835, 693,
                   369, 136, 767, 852, 785, 722, 787, 937, 548, 238, 653,  54, 582, 547,
                   374, 580, 619, 300, 954, 310, 602, 442, 536, 996,  51, 546, 800, 921,
                   924, 382, 692, 781, 531, 784, 111, 142, 410, 918, 939, 364, 634, 578,
                   230, 562,  20, 165, 183, 968,  13, 615, 933, 137, 682, 134, 554, 468,
                   203, 780, 544, 803,  69, 494, 323, 412, 801, 521, 721,  49, 687, 299,
                   591, 271, 214,  24, 264, 482, 838, 347,  74, 689, 117, 925, 164, 964,
                   208, 943, 953, 897, 163, 199, 873, 843, 166, 431,  88,  22, 879, 520,
                    43, 597, 325, 659,  28, 303, 161, 915, 683, 222, 476, 462, 640,  96,
                   644, 903, 297, 917,  68, 502, 216, 453, 825, 637, 983,  57, 646, 249,
                   169, 652, 763, 331, 120, 559, 247, 577, 895, 814, 487, 729, 346, 444,
                   714, 884, 229, 280, 688, 680, 499, 793, 581, 316,  95, 454, 530, 832,
                   315, 981, 505, 628], dtype=torch.int32), tensor([8, 6, 8, 0, 9, 6, 1, 4, 6, 2, 6, 4, 1, 1, 4, 0, 2, 9, 0, 3, 4, 6, 1, 1,
                   8, 4, 3, 2, 7, 5, 9, 7, 8, 3, 3, 2, 9, 7, 5, 2, 7, 0, 4, 7, 1, 6, 2, 2,
                   9, 3, 0, 5, 4, 6, 6, 1, 6, 8, 0, 7, 4, 8, 5, 8, 3, 1, 8, 8, 5, 5, 8, 4,
                   5, 3, 5, 7, 4, 7, 3, 8, 1, 5, 0, 4, 7, 9, 2, 1, 1, 2, 0, 6, 1, 4, 3, 5,
                   8, 9, 9, 7, 0, 8, 5, 5, 2, 0, 3, 3, 5, 9, 5, 3, 5, 2, 2, 1, 6, 0, 3, 1,
                   6, 2, 9, 0, 4, 7, 4, 0, 1, 7, 8, 6, 4, 1, 2, 2, 8, 4, 7, 1, 3, 6, 7, 4,
                   3, 3, 9, 2, 8, 3, 3, 5, 7, 0, 7, 4, 8, 2, 5, 6, 6, 3, 6, 2, 4, 1, 6, 3,
                   1, 2, 3, 1, 0, 9, 2, 7, 7, 9, 6, 4, 1, 4, 2, 0, 9, 3, 6, 3, 6, 1, 5, 1,
                   1, 3, 5, 5, 2, 5, 6, 6])),
               names=('seeds', 'labels'),
           ),
           metadata={'name': 'node_classification', 'num_classes': 10},)

Loaded link prediction task: OnDiskTask(validation_set=ItemSet(
               items=(tensor([[771, 495],
                   [715,  87],
                   [590, 983],
                   ...,
                   [ 55,  17],
                   [ 55, 659],
                   [ 55, 904]], dtype=torch.int32), tensor([1., 1., 1.,  ..., 0., 0., 0.], dtype=torch.float64), tensor([   0,    1,    2,  ..., 1999, 1999, 1999])),
               names=('seeds', 'labels', 'indexes'),
           ),
           train_set=ItemSet(
               items=(tensor([[734, 698],
                   [492, 101],
                   [141, 102],
                   ...,
                   [447, 161],
                   [543, 184],
                   [346, 301]], dtype=torch.int32),),
               names=('seeds',),
           ),
           test_set=ItemSet(
               items=(tensor([[166, 289],
                   [697, 620],
                   [976, 534],
                   ...,
                   [841, 267],
                   [841, 373],
                   [841, 500]], dtype=torch.int32), tensor([1., 1., 1.,  ..., 0., 0., 0.], dtype=torch.float64), tensor([   0,    1,    2,  ..., 1999, 1999, 1999])),
               names=('seeds', 'labels', 'indexes'),
           ),
           metadata={'name': 'link_prediction', 'num_classes': 10},)

/dgl/python/dgl/graphbolt/impl/ondisk_dataset.py:463: GBWarning: Edge feature is stored, but edge IDs are not saved.
  gb_warning("Edge feature is stored, but edge IDs are not saved.")