6.1 使用邻域采样训练用于节点分类的 GNN
要使您的模型进行随机训练,您需要执行以下步骤
定义邻域采样器。
使您的模型适应小批量训练。
修改您的训练循环。
以下小节将逐一介绍这些步骤。
定义邻域采样器和数据加载器
DGL 提供了几个邻域采样器类,它们可以根据我们要计算的节点生成每一层所需的计算依赖关系。
最简单的邻域采样器是 NeighborSampler
或等效的函数式接口 sample_neighbor()
,它使节点能够从其邻居收集消息。
要使用 DGL 提供的采样器,还需要将其与 DataLoader
结合使用,DataLoader 会按小批量迭代一组索引(在本例中为节点)。
例如,以下代码创建一个 DataLoader,该 DataLoader 按批次迭代 ogbn-arxiv
的训练节点 ID 集合,并将生成的 MFG 列表放到 GPU 上。
import dgl
import dgl.graphbolt as gb
import dgl.nn as dglnn
import torch
import torch.nn as nn
import torch.nn.functional as F
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dataset = gb.BuiltinDataset("ogbn-arxiv").load()
g = dataset.graph
feature = dataset.feature
train_set = dataset.tasks[0].train_set
datapipe = gb.ItemSampler(train_set, batch_size=1024, shuffle=True)
datapipe = datapipe.sample_neighbor(g, [10, 10]) # 2 layers.
# Or equivalently:
# datapipe = gb.NeighborSampler(datapipe, g, [10, 10])
datapipe = datapipe.fetch_feature(feature, node_feature_keys=["feat"])
datapipe = datapipe.copy_to(device)
dataloader = gb.DataLoader(datapipe)
迭代 DataLoader 将产生 MiniBatch
,其中包含一个特殊创建的图列表,代表每层的计算依赖关系。为了使用 DGL 进行训练,您可以通过调用 mini_batch.blocks 访问消息流图 (MFGs)。
mini_batch = next(iter(dataloader))
print(mini_batch.blocks)
使您的模型适应小批量训练
如果您的消息传递模块都由 DGL 提供,则使您的模型适应小批量训练所需的更改是最小的。以多层 GCN 为例。如果您在完整图上的模型实现如下
class TwoLayerGCN(nn.Module):
def __init__(self, in_features, hidden_features, out_features):
super().__init__()
self.conv1 = dglnn.GraphConv(in_features, hidden_features)
self.conv2 = dglnn.GraphConv(hidden_features, out_features)
def forward(self, g, x):
x = F.relu(self.conv1(g, x))
x = F.relu(self.conv2(g, x))
return x
那么您只需将 g
替换为上面生成的 blocks
。
class StochasticTwoLayerGCN(nn.Module):
def __init__(self, in_features, hidden_features, out_features):
super().__init__()
self.conv1 = dgl.nn.GraphConv(in_features, hidden_features)
self.conv2 = dgl.nn.GraphConv(hidden_features, out_features)
def forward(self, blocks, x):
x = F.relu(self.conv1(blocks[0], x))
x = F.relu(self.conv2(blocks[1], x))
return x
上面的 DGL GraphConv
模块接受数据加载器生成的 blocks
中的一个元素作为参数。
每个 NN 模块的 API 参考会告诉您它是否支持接受 MFG 作为参数。
如果您希望使用自己的消息传递模块,请参阅6.6 实现用于 Mini-batch 训练的自定义 GNN 模块。
训练循环
训练循环仅包含使用自定义批处理迭代器迭代数据集。在每次迭代产生 MiniBatch
时,我们
通过
data.node_features["feat"]
访问与输入节点对应的节点特征。这些特征已经由数据加载器移动到目标设备(CPU 或 GPU)。通过
data.labels
访问与输出节点对应的节点标签。这些标签已经由数据加载器移动到目标设备(CPU 或 GPU)。将 MFG 列表和输入节点特征馈送到多层 GNN 并获取输出。
计算损失并进行反向传播。
model = StochasticTwoLayerGCN(in_features, hidden_features, out_features)
model = model.to(device)
opt = torch.optim.Adam(model.parameters())
for data in dataloader:
input_features = data.node_features["feat"]
output_labels = data.labels
output_predictions = model(data.blocks, input_features)
loss = compute_loss(output_labels, output_predictions)
opt.zero_grad()
loss.backward()
opt.step()
DGL 提供了一个端到端的随机训练示例GraphSAGE 实现。
对于异构图
在异构图上训练用于节点分类的图神经网络是类似的。
例如,我们之前已经看过如何在完整图上训练一个 2 层 RGCN。在小批量训练上实现 RGCN 的代码看起来非常相似(为简单起见,去除了自环、非线性和基分解)
class StochasticTwoLayerRGCN(nn.Module):
def __init__(self, in_feat, hidden_feat, out_feat, rel_names):
super().__init__()
self.conv1 = dglnn.HeteroGraphConv({
rel : dglnn.GraphConv(in_feat, hidden_feat, norm='right')
for rel in rel_names
})
self.conv2 = dglnn.HeteroGraphConv({
rel : dglnn.GraphConv(hidden_feat, out_feat, norm='right')
for rel in rel_names
})
def forward(self, blocks, x):
x = self.conv1(blocks[0], x)
x = self.conv2(blocks[1], x)
return x
DGL 提供的采样器也支持异构图。例如,仍然可以使用提供的 NeighborSampler
类和 DataLoader
类进行随机训练。唯一的区别是 itemset 现在是 HeteroItemSet
的一个实例,它是一个从节点类型到节点 ID 的字典。
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dataset = gb.BuiltinDataset("ogbn-mag").load()
g = dataset.graph
feature = dataset.feature
train_set = dataset.tasks[0].train_set
datapipe = gb.ItemSampler(train_set, batch_size=1024, shuffle=True)
datapipe = datapipe.sample_neighbor(g, [10, 10]) # 2 layers.
# Or equivalently:
# datapipe = gb.NeighborSampler(datapipe, g, [10, 10])
# For heterogeneous graphs, we need to specify the node feature keys
# for each node type.
datapipe = datapipe.fetch_feature(
feature, node_feature_keys={"author": ["feat"], "paper": ["feat"]}
)
datapipe = datapipe.copy_to(device)
dataloader = gb.DataLoader(datapipe)
训练循环与同构图的训练循环几乎相同,除了 compute_loss
的实现,这里它将接收两个字典,分别是节点类型和预测结果。
model = StochasticTwoLayerRGCN(in_features, hidden_features, out_features, etypes)
model = model.to(device)
opt = torch.optim.Adam(model.parameters())
for data in dataloader:
# For heterogeneous graphs, we need to specify the node types and
# feature name when accessing the node features. So does the labels.
input_features = {
"author": data.node_features[("author", "feat")],
"paper": data.node_features[("paper", "feat")]
}
output_labels = data.labels["paper"]
output_predictions = model(data.blocks, input_features)
loss = compute_loss(output_labels, output_predictions)
opt.zero_grad()
loss.backward()
opt.step()
DGL 提供了一个端到端的随机训练示例RGCN 实现。