6.1 使用邻域采样训练用于节点分类的 GNN

(中文版)

要使您的模型进行随机训练,您需要执行以下步骤

  • 定义邻域采样器。

  • 使您的模型适应小批量训练。

  • 修改您的训练循环。

以下小节将逐一介绍这些步骤。

定义邻域采样器和数据加载器

DGL 提供了几个邻域采样器类,它们可以根据我们要计算的节点生成每一层所需的计算依赖关系。

最简单的邻域采样器是 NeighborSampler 或等效的函数式接口 sample_neighbor(),它使节点能够从其邻居收集消息。

要使用 DGL 提供的采样器,还需要将其与 DataLoader 结合使用,DataLoader 会按小批量迭代一组索引(在本例中为节点)。

例如,以下代码创建一个 DataLoader,该 DataLoader 按批次迭代 ogbn-arxiv 的训练节点 ID 集合,并将生成的 MFG 列表放到 GPU 上。

import dgl
import dgl.graphbolt as gb
import dgl.nn as dglnn
import torch
import torch.nn as nn
import torch.nn.functional as F

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dataset = gb.BuiltinDataset("ogbn-arxiv").load()
g = dataset.graph
feature = dataset.feature
train_set = dataset.tasks[0].train_set
datapipe = gb.ItemSampler(train_set, batch_size=1024, shuffle=True)
datapipe = datapipe.sample_neighbor(g, [10, 10]) # 2 layers.
# Or equivalently:
# datapipe = gb.NeighborSampler(datapipe, g, [10, 10])
datapipe = datapipe.fetch_feature(feature, node_feature_keys=["feat"])
datapipe = datapipe.copy_to(device)
dataloader = gb.DataLoader(datapipe)

迭代 DataLoader 将产生 MiniBatch,其中包含一个特殊创建的图列表,代表每层的计算依赖关系。为了使用 DGL 进行训练,您可以通过调用 mini_batch.blocks 访问消息流图 (MFGs)。

mini_batch = next(iter(dataloader))
print(mini_batch.blocks)

注意

请参阅随机训练教程了解消息流图的概念。

如果您希望开发自己的邻域采样器或想更详细地了解 MFG 的概念,请参阅6.4 实现自定义图采样器

使您的模型适应小批量训练

如果您的消息传递模块都由 DGL 提供,则使您的模型适应小批量训练所需的更改是最小的。以多层 GCN 为例。如果您在完整图上的模型实现如下

class TwoLayerGCN(nn.Module):
    def __init__(self, in_features, hidden_features, out_features):
        super().__init__()
        self.conv1 = dglnn.GraphConv(in_features, hidden_features)
        self.conv2 = dglnn.GraphConv(hidden_features, out_features)

    def forward(self, g, x):
        x = F.relu(self.conv1(g, x))
        x = F.relu(self.conv2(g, x))
        return x

那么您只需将 g 替换为上面生成的 blocks

class StochasticTwoLayerGCN(nn.Module):
    def __init__(self, in_features, hidden_features, out_features):
        super().__init__()
        self.conv1 = dgl.nn.GraphConv(in_features, hidden_features)
        self.conv2 = dgl.nn.GraphConv(hidden_features, out_features)

    def forward(self, blocks, x):
        x = F.relu(self.conv1(blocks[0], x))
        x = F.relu(self.conv2(blocks[1], x))
        return x

上面的 DGL GraphConv 模块接受数据加载器生成的 blocks 中的一个元素作为参数。

每个 NN 模块的 API 参考会告诉您它是否支持接受 MFG 作为参数。

如果您希望使用自己的消息传递模块,请参阅6.6 实现用于 Mini-batch 训练的自定义 GNN 模块

训练循环

训练循环仅包含使用自定义批处理迭代器迭代数据集。在每次迭代产生 MiniBatch 时,我们

  1. 通过 data.node_features["feat"] 访问与输入节点对应的节点特征。这些特征已经由数据加载器移动到目标设备(CPU 或 GPU)。

  2. 通过 data.labels 访问与输出节点对应的节点标签。这些标签已经由数据加载器移动到目标设备(CPU 或 GPU)。

  3. 将 MFG 列表和输入节点特征馈送到多层 GNN 并获取输出。

  4. 计算损失并进行反向传播。

model = StochasticTwoLayerGCN(in_features, hidden_features, out_features)
model = model.to(device)
opt = torch.optim.Adam(model.parameters())

for data in dataloader:
    input_features = data.node_features["feat"]
    output_labels = data.labels
    output_predictions = model(data.blocks, input_features)
    loss = compute_loss(output_labels, output_predictions)
    opt.zero_grad()
    loss.backward()
    opt.step()

DGL 提供了一个端到端的随机训练示例GraphSAGE 实现

对于异构图

在异构图上训练用于节点分类的图神经网络是类似的。

例如,我们之前已经看过如何在完整图上训练一个 2 层 RGCN。在小批量训练上实现 RGCN 的代码看起来非常相似(为简单起见,去除了自环、非线性和基分解)

class StochasticTwoLayerRGCN(nn.Module):
    def __init__(self, in_feat, hidden_feat, out_feat, rel_names):
        super().__init__()
        self.conv1 = dglnn.HeteroGraphConv({
                rel : dglnn.GraphConv(in_feat, hidden_feat, norm='right')
                for rel in rel_names
            })
        self.conv2 = dglnn.HeteroGraphConv({
                rel : dglnn.GraphConv(hidden_feat, out_feat, norm='right')
                for rel in rel_names
            })

    def forward(self, blocks, x):
        x = self.conv1(blocks[0], x)
        x = self.conv2(blocks[1], x)
        return x

DGL 提供的采样器也支持异构图。例如,仍然可以使用提供的 NeighborSampler 类和 DataLoader 类进行随机训练。唯一的区别是 itemset 现在是 HeteroItemSet 的一个实例,它是一个从节点类型到节点 ID 的字典。

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dataset = gb.BuiltinDataset("ogbn-mag").load()
g = dataset.graph
feature = dataset.feature
train_set = dataset.tasks[0].train_set
datapipe = gb.ItemSampler(train_set, batch_size=1024, shuffle=True)
datapipe = datapipe.sample_neighbor(g, [10, 10]) # 2 layers.
# Or equivalently:
# datapipe = gb.NeighborSampler(datapipe, g, [10, 10])
# For heterogeneous graphs, we need to specify the node feature keys
# for each node type.
datapipe = datapipe.fetch_feature(
    feature, node_feature_keys={"author": ["feat"], "paper": ["feat"]}
)
datapipe = datapipe.copy_to(device)
dataloader = gb.DataLoader(datapipe)

训练循环与同构图的训练循环几乎相同,除了 compute_loss 的实现,这里它将接收两个字典,分别是节点类型和预测结果。

model = StochasticTwoLayerRGCN(in_features, hidden_features, out_features, etypes)
model = model.to(device)
opt = torch.optim.Adam(model.parameters())

for data in dataloader:
    # For heterogeneous graphs, we need to specify the node types and
    # feature name when accessing the node features. So does the labels.
    input_features = {
        "author": data.node_features[("author", "feat")],
        "paper": data.node_features[("paper", "feat")]
    }
    output_labels = data.labels["paper"]
    output_predictions = model(data.blocks, input_features)
    loss = compute_loss(output_labels, output_predictions)
    opt.zero_grad()
    loss.backward()
    opt.step()

DGL 提供了一个端到端的随机训练示例RGCN 实现