Node Classification

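This tutorial shows how to train a multi-layer GraphSAGE for node classification on a single GPU using DGL's neighbor sampling components. We train on the ogbn-arxiv dataset provided by the Open Graph Benchmark (OGB). The dataset contains around 170 thousand nodes and 1 million edges.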


By the end of this tutorial, you will be able to

  • Train a GNN model for node classification on a single GPU with DGL's neighbor sampling components.

Install DGL package

[1]:
# Install required packages.
import os
import torch
import numpy as np
os.environ['TORCH'] = torch.__version__
os.environ['DGLBACKEND'] = "pytorch"

# Install the CPU version by default. To install the CUDA version instead,
# refer to https://dgl.ac.cn/pages/start.html and change the runtime type
# accordingly.
device = torch.device("cpu")
!pip install --pre dgl -f https://data.dgl.ai/wheels-test/repo.html

try:
    import dgl
    import dgl.graphbolt as gb
    installed = True
except ImportError as error:
    installed = False
    print(error)
print("DGL installed!" if installed else "DGL not found!")
Looking in links: https://data.dgl.ai/wheels-test/repo.html
Requirement already satisfied: dgl in /opt/conda/envs/dgl-dev-cpu/lib/python3.10/site-packages (2.2a240410)
Requirement already satisfied: numpy>=1.14.0 in /opt/conda/envs/dgl-dev-cpu/lib/python3.10/site-packages (from dgl) (1.26.4)
Requirement already satisfied: scipy>=1.1.0 in /opt/conda/envs/dgl-dev-cpu/lib/python3.10/site-packages (from dgl) (1.14.1)
Requirement already satisfied: networkx>=2.1 in /opt/conda/envs/dgl-dev-cpu/lib/python3.10/site-packages (from dgl) (3.4.2)
Requirement already satisfied: requests>=2.19.0 in /opt/conda/envs/dgl-dev-cpu/lib/python3.10/site-packages (from dgl) (2.32.3)
Requirement already satisfied: tqdm in /opt/conda/envs/dgl-dev-cpu/lib/python3.10/site-packages (from dgl) (4.66.6)
Requirement already satisfied: psutil>=5.8.0 in /opt/conda/envs/dgl-dev-cpu/lib/python3.10/site-packages (from dgl) (6.1.0)
Requirement already satisfied: torchdata>=0.5.0 in /opt/conda/envs/dgl-dev-cpu/lib/python3.10/site-packages (from dgl) (0.9.0)
Requirement already satisfied: pandas in /opt/conda/envs/dgl-dev-cpu/lib/python3.10/site-packages (from dgl) (2.2.3)
Requirement already satisfied: charset-normalizer<4,>=2 in /opt/conda/envs/dgl-dev-cpu/lib/python3.10/site-packages (from requests>=2.19.0->dgl) (3.4.0)
Requirement already satisfied: idna<4,>=2.5 in /opt/conda/envs/dgl-dev-cpu/lib/python3.10/site-packages (from requests>=2.19.0->dgl) (3.10)
Requirement already satisfied: urllib3<3,>=1.21.1 in /opt/conda/envs/dgl-dev-cpu/lib/python3.10/site-packages (from requests>=2.19.0->dgl) (2.2.3)
Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/envs/dgl-dev-cpu/lib/python3.10/site-packages (from requests>=2.19.0->dgl) (2024.8.30)
Requirement already satisfied: torch>=2 in /opt/conda/envs/dgl-dev-cpu/lib/python3.10/site-packages (from torchdata>=0.5.0->dgl) (2.1.0+cpu)
Requirement already satisfied: python-dateutil>=2.8.2 in /opt/conda/envs/dgl-dev-cpu/lib/python3.10/site-packages (from pandas->dgl) (2.9.0.post0)
Requirement already satisfied: pytz>=2020.1 in /opt/conda/envs/dgl-dev-cpu/lib/python3.10/site-packages (from pandas->dgl) (2024.2)
Requirement already satisfied: tzdata>=2022.7 in /opt/conda/envs/dgl-dev-cpu/lib/python3.10/site-packages (from pandas->dgl) (2024.2)
Requirement already satisfied: six>=1.5 in /opt/conda/envs/dgl-dev-cpu/lib/python3.10/site-packages (from python-dateutil>=2.8.2->pandas->dgl) (1.16.0)
Requirement already satisfied: filelock in /opt/conda/envs/dgl-dev-cpu/lib/python3.10/site-packages (from torch>=2->torchdata>=0.5.0->dgl) (3.16.1)
Requirement already satisfied: typing-extensions in /opt/conda/envs/dgl-dev-cpu/lib/python3.10/site-packages (from torch>=2->torchdata>=0.5.0->dgl) (4.12.2)
Requirement already satisfied: sympy in /opt/conda/envs/dgl-dev-cpu/lib/python3.10/site-packages (from torch>=2->torchdata>=0.5.0->dgl) (1.13.3)
Requirement already satisfied: jinja2 in /opt/conda/envs/dgl-dev-cpu/lib/python3.10/site-packages (from torch>=2->torchdata>=0.5.0->dgl) (3.1.4)
Requirement already satisfied: fsspec in /opt/conda/envs/dgl-dev-cpu/lib/python3.10/site-packages (from torch>=2->torchdata>=0.5.0->dgl) (2024.10.0)
Requirement already satisfied: MarkupSafe>=2.0 in /opt/conda/envs/dgl-dev-cpu/lib/python3.10/site-packages (from jinja2->torch>=2->torchdata>=0.5.0->dgl) (3.0.2)
Requirement already satisfied: mpmath<1.4,>=1.1.0 in /opt/conda/envs/dgl-dev-cpu/lib/python3.10/site-packages (from sympy->torch>=2->torchdata>=0.5.0->dgl) (1.3.0)
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable.It is recommended to use a virtual environment instead: https://pip.pythonlang.cn/warnings/venv. Use the --root-user-action option if you know what you are doing and want to suppress this warning.
DGL installed!

Loading Dataset

ogbn-arxiv is already prepared as a BuiltinDataset in GraphBolt.

[2]:
dataset = gb.BuiltinDataset("ogbn-arxiv-seeds").load()
Downloading datasets/ogbn-arxiv-seeds.zip from https://data.dgl.ai/dataset/graphbolt/ogbn-arxiv-seeds.zip...
Extracting file to datasets
The dataset is already preprocessed.

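The dataset consists of a graph, features, and tasks. You can get the training, validation, and test sets from a task. The seed nodes and their corresponding labels are already stored in each of the training, validation, and test sets. Other metadata, such as the number of classes, is also stored in the task. This dataset has only one task: node classification.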

[3]:
graph = dataset.graph.to(device)
feature = dataset.feature.to(device)
train_set = dataset.tasks[0].train_set
valid_set = dataset.tasks[0].validation_set
test_set = dataset.tasks[0].test_set
task_name = dataset.tasks[0].metadata["name"]
num_classes = dataset.tasks[0].metadata["num_classes"]
print(f"Task: {task_name}. Number of classes: {num_classes}")
Task: node_classification. Number of classes: 40

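How DGL Handles Computation Dependency

The computation dependency for message passing of a single node can be described as a series of message flow graphs (MFGs).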

[Figure: DGL Computation]
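
To make the idea of a computation dependency concrete, here is a minimal pure-Python sketch that does not use DGL. The toy graph and the helper name `message_flow_frontiers` are made up for illustration. It computes, for a 2-layer model, which nodes must be available at each layer to produce the output of one seed node, assuming every node needs its own previous-layer state plus that of all its in-neighbors.

```python
# Toy directed graph as an in-neighbor map: node -> nodes that send it messages.
# The graph and node IDs are hypothetical and only serve as an illustration.
in_neighbors = {
    0: [1, 2],
    1: [3],
    2: [3, 4],
    3: [5],
    4: [],
    5: [],
}


def message_flow_frontiers(seeds, num_layers):
    """Return the node sets per layer, input side first.

    frontiers[-1] is the seed set; frontiers[i] holds every node whose
    representation is needed to compute the nodes in frontiers[i + 1].
    """
    frontiers = [set(seeds)]
    for _ in range(num_layers):
        current = frontiers[0]
        # A node needs its own previous-layer state as well as its neighbors'.
        needed = set(current)
        for node in current:
            needed.update(in_neighbors[node])
        frontiers.insert(0, needed)
    return frontiers


frontiers = message_flow_frontiers(seeds=[0], num_layers=2)
for depth, nodes in enumerate(frontiers):
    print(f"frontier {depth}: {sorted(nodes)}")
```

Each pair of consecutive frontiers corresponds loosely to one MFG: the larger set plays the role of the source nodes and the smaller set the role of the destination nodes.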

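Defining Neighbor Sampler and Data Loader in DGL

DGL provides tools to iterate over the dataset in mini-batches while generating the computation dependencies needed to compute their outputs, in the form of the MFGs above. For node classification, you can use dgl.graphbolt.DataLoader to iterate over the dataset. It accepts a data pipe that generates mini-batches of nodes and their labels, samples neighbors for each node, and generates the computation dependencies in the form of MFGs. It also supports feature fetching, block creation, and copying to the target device. All these operations are split into separate stages in the data pipe, so that you can customize the data pipe by inserting your own operations.

Suppose that each node gathers messages from up to 4 sampled neighbors on each layer. The code defining the data loader and the neighbor sampler looks like the following.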

[4]:
def create_dataloader(itemset, shuffle):
    datapipe = gb.ItemSampler(itemset, batch_size=1024, shuffle=shuffle)
    datapipe = datapipe.copy_to(device)
    datapipe = datapipe.sample_neighbor(graph, [4, 4])
    datapipe = datapipe.fetch_feature(feature, node_feature_keys=["feat"])
    return gb.DataLoader(datapipe)
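
The staged structure of such a data pipe can be mimicked with plain Python generators. The following is only a conceptual sketch under stated assumptions, not how GraphBolt is implemented: the toy graph, the features, and the stage functions `batch_items`, `sample_neighbors`, and `fetch_features` are all hypothetical. It uses a single fanout for every layer.

```python
import random

# Hypothetical toy graph: node -> list of in-neighbors (message senders).
IN_NEIGHBORS = {0: [1, 2, 3], 1: [0], 2: [0, 1, 3, 4, 5], 3: [4], 4: [5], 5: []}
# Hypothetical 2-dimensional node features.
FEATURES = {node: [float(node), node * 0.5] for node in IN_NEIGHBORS}


def batch_items(items, batch_size):
    """Stage 1: split the seed nodes into mini-batches."""
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


def sample_neighbors(batches, fanout, num_layers, rng):
    """Stage 2: sample at most `fanout` in-neighbors per node, per layer."""
    for seeds in batches:
        layers = []
        frontier = list(seeds)
        for _ in range(num_layers):
            edges = []
            for dst in frontier:
                candidates = IN_NEIGHBORS[dst]
                picked = rng.sample(candidates, min(fanout, len(candidates)))
                edges.extend((src, dst) for src in picked)
            # Source nodes start with the destination nodes, then add new ones.
            src_nodes = list(frontier)
            for src, _ in edges:
                if src not in src_nodes:
                    src_nodes.append(src)
            # Layers are built from the output inwards, so prepend each one.
            layers.insert(0, {"src": src_nodes, "dst": frontier, "edges": edges})
            frontier = src_nodes
        yield {"seeds": seeds, "layers": layers, "input_nodes": frontier}


def fetch_features(minibatches):
    """Stage 3: attach the features of the input nodes only."""
    for minibatch in minibatches:
        minibatch["feat"] = [FEATURES[n] for n in minibatch["input_nodes"]]
        yield minibatch


rng = random.Random(0)
pipe = batch_items([0, 2, 4], batch_size=2)
pipe = sample_neighbors(pipe, fanout=2, num_layers=2, rng=rng)
pipe = fetch_features(pipe)
minibatches = list(pipe)
for minibatch in minibatches:
    print(minibatch["seeds"], "->", minibatch["input_nodes"])
```

Because each stage only consumes and yields mini-batches, an extra stage can be placed between any two of them without touching the others, which is the customization point the data pipe design is meant to offer.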

You can iterate over the data loader; each iteration yields a MiniBatch object.

[5]:
data = next(iter(create_dataloader(train_set, shuffle=True)))
print(data)
MiniBatch(seeds=tensor([ 65183, 163561,  76677,  ...,  12099,  16843,   4191]),
          sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([    0,     4,     8,  ..., 14635, 14639, 14643], dtype=torch.int32),
                                                                         indices=tensor([ 1502,  1024,  3847,  ...,  3846, 11707, 11708], dtype=torch.int32),
                                                           ),
                                               original_row_node_ids=tensor([ 65183, 163561,  76677,  ...,  82248,  65213,   9690]),
                                               original_edge_ids=tensor([ 962702,  978322, 1499230,  ..., 2465385,  894690,  135709]),
                                               original_column_node_ids=tensor([ 65183, 163561,  76677,  ...,  29265,   6099, 149787]),
                            ),
                            SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([   0,    4,    8,  ..., 3655, 3658, 3662], dtype=torch.int32),
                                                                         indices=tensor([1024,    0, 1025,  ..., 3845, 3846, 1023], dtype=torch.int32),
                                                           ),
                                               original_row_node_ids=tensor([ 65183, 163561,  76677,  ...,  29265,   6099, 149787]),
                                               original_edge_ids=tensor([ 978322, 2380781, 2044949,  ...,   87067, 2052817, 2319789]),
                                               original_column_node_ids=tensor([ 65183, 163561,  76677,  ...,  12099,  16843,   4191]),
                            )],
          node_features={'feat': tensor([[-0.0422, -0.0014, -0.1854,  ...,  0.0134, -0.0261, -0.0950],
                                [-0.0507,  0.0820, -0.2800,  ...,  0.0044, -0.0816,  0.1165],
                                [-0.2002,  0.0058, -0.3861,  ...,  0.2154,  0.1007, -0.0761],
                                ...,
                                [-0.0414, -0.0854, -0.1887,  ...,  0.1158, -0.0962, -0.0719],
                                [-0.0779,  0.0138, -0.1698,  ...,  0.0613, -0.0654, -0.1182],
                                [-0.0545, -0.0352, -0.0440,  ...,  0.3083, -0.0255, -0.2265]])},
          labels=tensor([19, 34, 27,  ..., 28,  2, 28]),
          input_nodes=tensor([ 65183, 163561,  76677,  ...,  82248,  65213,   9690]),
          indexes=None,
          edge_features=[{},
                        {}],
          compacted_seeds=None,
          blocks=[Block(num_src_nodes=11709, num_dst_nodes=3847, num_edges=14643),
                 Block(num_src_nodes=3847, num_dst_nodes=1024, num_edges=3662)],
       )
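
As a sanity check, the block sizes printed above can be compared with the upper bounds implied by a batch size of 1024 and a fanout of 4. The sketch below only does arithmetic on numbers copied from that output. The bound "destination nodes plus sampled edges" for the number of source nodes is an assumption of this sketch; it relies on the source nodes of a block including its destination nodes.

```python
batch_size = 1024
fanout = 4

# Block sizes copied from the MiniBatch printed above (input-side block first).
blocks = [
    {"num_src_nodes": 11709, "num_dst_nodes": 3847, "num_edges": 14643},
    {"num_src_nodes": 3847, "num_dst_nodes": 1024, "num_edges": 3662},
]

# The output-side block has exactly one destination node per seed.
assert blocks[-1]["num_dst_nodes"] == batch_size
# Destination nodes of the input-side block are the source nodes of the next.
assert blocks[0]["num_dst_nodes"] == blocks[1]["num_src_nodes"]

report = []
for block in blocks:
    # Each destination node receives at most `fanout` sampled edges.
    max_edges = block["num_dst_nodes"] * fanout
    # Sources are the destinations themselves plus at most one new node per edge.
    max_src = block["num_dst_nodes"] + max_edges
    assert block["num_edges"] <= max_edges
    assert block["num_src_nodes"] <= max_src
    report.append(
        (block["num_edges"], max_edges, block["num_src_nodes"], max_src)
    )
    print(
        f"edges {block['num_edges']} <= {max_edges}, "
        f"src nodes {block['num_src_nodes']} <= {max_src}"
    )
```

The observed counts stay below the bounds, presumably because some nodes have fewer than 4 neighbors and because different destination nodes can share the same sampled neighbor.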

You can get the input node IDs from the MFGs.

[6]:
mfgs = data.blocks
input_nodes = mfgs[0].srcdata[dgl.NID]
print(f"Input nodes: {input_nodes}.")
Input nodes: tensor([ 65183, 163561,  76677,  ...,  82248,  65213,   9690]).

Defining Model

Let's consider training a 2-layer GraphSAGE with neighbor sampling. The model can be written as follows:

[7]:
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import SAGEConv


class Model(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(Model, self).__init__()
        self.conv1 = SAGEConv(in_feats, h_feats, aggregator_type="mean")
        self.conv2 = SAGEConv(h_feats, num_classes, aggregator_type="mean")
        self.h_feats = h_feats

    def forward(self, mfgs, x):
        h = self.conv1(mfgs[0], x)
        h = F.relu(h)
        h = self.conv2(mfgs[1], h)
        return h


in_size = feature.size("node", None, "feat")[0]
model = Model(in_size, 64, num_classes).to(device)
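
For intuition about what one layer does on an MFG, here is a rough pure-Python sketch of mean aggregation in the GraphSAGE style. It is an illustration and not DGL's SAGEConv implementation: it omits the bias term and any normalization, and the helper names `matvec` and `sage_mean_layer`, the toy features, and the identity weights are all hypothetical. It assumes the first `num_dst` source nodes are the destination nodes, and that a node without sampled neighbors gets a zero neighbor term.

```python
def matvec(weight, vector):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * x for w, x in zip(row, vector)) for row in weight]


def sage_mean_layer(src_feats, num_dst, edges, w_self, w_neigh):
    """Compute W_self * h_dst + W_neigh * mean(h_neighbors) per destination.

    `edges` holds (src_index, dst_index) pairs local to this block.
    """
    dim = len(src_feats[0])
    outputs = []
    for dst in range(num_dst):
        neighbors = [src_feats[s] for s, d in edges if d == dst]
        if neighbors:
            mean = [sum(column) / len(neighbors) for column in zip(*neighbors)]
        else:
            # No sampled neighbors: the neighbor term contributes nothing.
            mean = [0.0] * dim
        self_part = matvec(w_self, src_feats[dst])
        neigh_part = matvec(w_neigh, mean)
        outputs.append([a + b for a, b in zip(self_part, neigh_part)])
    return outputs


# Three source nodes; the first two are also the destination nodes.
src_feats = [[1.0, 0.0], [0.0, 2.0], [4.0, 2.0]]
edges = [(1, 0), (2, 0), (2, 1)]
identity = [[1.0, 0.0], [0.0, 1.0]]
out = sage_mean_layer(
    src_feats, num_dst=2, edges=edges, w_self=identity, w_neigh=identity
)
print(out)
```

The output has one row per destination node, which is why the second layer of the model above can consume it directly together with the next MFG.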

Defining Training Loop

The following defines the optimizer for the model initialized above.

[8]:
opt = torch.optim.Adam(model.parameters())

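When computing validation scores for model selection, you can usually also use neighbor sampling. We can reuse the create_dataloader function to create two separate data loaders, one for training and one for validation.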

[9]:
train_dataloader = create_dataloader(train_set, shuffle=True)
valid_dataloader = create_dataloader(valid_set, shuffle=False)

import sklearn.metrics

The following training loop performs validation at the end of every epoch and prints the validation accuracy.

[10]:
from tqdm.auto import tqdm

for epoch in range(10):
    model.train()

    with tqdm(train_dataloader) as tq:
        for step, data in enumerate(tq):
            x = data.node_features["feat"]
            labels = data.labels

            predictions = model(data.blocks, x)

            loss = F.cross_entropy(predictions, labels)
            opt.zero_grad()
            loss.backward()
            opt.step()

            accuracy = sklearn.metrics.accuracy_score(
                labels.cpu().numpy(),
                predictions.argmax(1).detach().cpu().numpy(),
            )

            tq.set_postfix(
                {"loss": "%.03f" % loss.item(), "acc": "%.03f" % accuracy},
                refresh=False,
            )

    model.eval()

    predictions = []
    labels = []
    with tqdm(valid_dataloader) as tq, torch.no_grad():
        for data in tq:
            x = data.node_features["feat"]
            labels.append(data.labels.cpu().numpy())
            predictions.append(model(data.blocks, x).argmax(1).cpu().numpy())
        predictions = np.concatenate(predictions)
        labels = np.concatenate(labels)
        accuracy = sklearn.metrics.accuracy_score(labels, predictions)
        print("Epoch {} Validation Accuracy {}".format(epoch, accuracy))
Epoch 0 Validation Accuracy 0.47280110070807746
Epoch 1 Validation Accuracy 0.5744152488338535
Epoch 2 Validation Accuracy 0.6083090036578409
Epoch 3 Validation Accuracy 0.6244169267425081
Epoch 4 Validation Accuracy 0.633813215208564
Epoch 5 Validation Accuracy 0.6389476156918017
Epoch 6 Validation Accuracy 0.6423705493472935
Epoch 7 Validation Accuracy 0.6425718983858518
Epoch 8 Validation Accuracy 0.6542501426222357
Epoch 9 Validation Accuracy 0.6531091647370717
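
The loop above only prints the validation accuracy. One way to also keep the model with the best validation accuracy is to track the best score and trigger a save whenever it improves. The sketch below shows just that bookkeeping. The class `BestCheckpoint` and its `save_fn` callback are hypothetical names; in a PyTorch setting the callback would typically wrap something like `torch.save(model.state_dict(), path)`. The accuracies are the values printed above, rounded to four decimals.

```python
class BestCheckpoint:
    """Track the best validation accuracy and call save_fn on improvement."""

    def __init__(self, save_fn):
        self.save_fn = save_fn  # hypothetical callback, e.g. wraps torch.save
        self.best_accuracy = float("-inf")
        self.best_epoch = None

    def update(self, epoch, accuracy):
        """Return True and save if `accuracy` beats the best seen so far."""
        if accuracy > self.best_accuracy:
            self.best_accuracy = accuracy
            self.best_epoch = epoch
            self.save_fn(epoch)
            return True
        return False


# Validation accuracies copied from the run above (rounded).
accuracies = [
    0.4728, 0.5744, 0.6083, 0.6244, 0.6338,
    0.6389, 0.6424, 0.6426, 0.6543, 0.6531,
]

saved_epochs = []
tracker = BestCheckpoint(save_fn=saved_epochs.append)
for epoch, accuracy in enumerate(accuracies):
    tracker.update(epoch, accuracy)

print(f"best epoch: {tracker.best_epoch}, saved at epochs: {saved_epochs}")
```

With these numbers the last save happens at epoch 8, since epoch 9 scores slightly lower. Inside the training loop, the `update` call would go right after the validation accuracy is computed.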

Conclusion

In this tutorial, you have learned how to train a multi-layer GraphSAGE with neighbor sampling.