5.1 节点分类/回归
节点分类是图神经网络最流行和广泛采用的任务之一,其中训练/验证/测试集中的每个节点都被分配一个预定义类别集中的真实类别。节点回归类似,其中训练/验证/测试集中的每个节点都被分配一个真实数值。
概述
为了对节点进行分类,图神经网络执行 第 2 章:消息传递 中讨论的消息传递,以利用节点自身的特征,以及其邻居节点和边的特征。消息传递可以重复多轮,以整合来自更广泛邻域的信息。
编写神经网络模型
DGL 提供了一些内置的图卷积模块,可以执行一轮消息传递。在本指南中,我们选择 dgl.nn.pytorch.SAGEConv
(也可在 MXNet 和 Tensorflow 中使用),这是 GraphSAGE 的图卷积模块。
通常对于图上的深度学习模型,我们需要一个多层图神经网络,其中我们进行多轮消息传递。这可以通过堆叠图卷积模块来实现,如下所示。
# Contruct a two-layer GNN model
import dgl.nn as dglnn
import torch.nn as nn
import torch.nn.functional as F
class SAGE(nn.Module):
def __init__(self, in_feats, hid_feats, out_feats):
super().__init__()
self.conv1 = dglnn.SAGEConv(
in_feats=in_feats, out_feats=hid_feats, aggregator_type='mean')
self.conv2 = dglnn.SAGEConv(
in_feats=hid_feats, out_feats=out_feats, aggregator_type='mean')
def forward(self, graph, inputs):
# inputs are features of nodes
h = self.conv1(graph, inputs)
h = F.relu(h)
h = self.conv2(graph, h)
return h
请注意,上面的模型不仅可以用于节点分类,还可以用于获取隐藏节点表示,以便进行其他下游任务,例如 5.2 边分类/回归、5.3 链接预测 或 5.4 图分类。
有关内置图卷积模块的完整列表,请参阅 apinn。
有关 DGL 神经网络模块如何工作以及如何编写包含消息传递的自定义神经网络模块的更多详细信息,请参阅 第 3 章:构建 GNN 模块 中的示例。
训练循环
在完整图上进行训练只需对上面定义的模型进行前向传播,并通过比较预测结果与训练节点上的真实标签来计算损失。
本节使用 DGL 内置数据集 dgl.data.CiteseerGraphDataset
来展示训练循环。节点特征和标签存储在其图实例上,训练-验证-测试划分也作为布尔掩码存储在图上。这与您在 第 4 章:图数据管道 中看到的情况类似。
node_features = graph.ndata['feat']
node_labels = graph.ndata['label']
train_mask = graph.ndata['train_mask']
valid_mask = graph.ndata['val_mask']
test_mask = graph.ndata['test_mask']
n_features = node_features.shape[1]
n_labels = int(node_labels.max().item() + 1)
以下是按准确率评估模型的示例。
def evaluate(model, graph, features, labels, mask):
model.eval()
with torch.no_grad():
logits = model(graph, features)
logits = logits[mask]
labels = labels[mask]
_, indices = torch.max(logits, dim=1)
correct = torch.sum(indices == labels)
return correct.item() * 1.0 / len(labels)
然后,您可以编写训练循环,如下所示。
model = SAGE(in_feats=n_features, hid_feats=100, out_feats=n_labels)
opt = torch.optim.Adam(model.parameters())
for epoch in range(10):
model.train()
# forward propagation by using all nodes
logits = model(graph, node_features)
# compute loss
loss = F.cross_entropy(logits[train_mask], node_labels[train_mask])
# compute validation accuracy
acc = evaluate(model, graph, node_features, node_labels, valid_mask)
# backward propagation
opt.zero_grad()
loss.backward()
opt.step()
print(loss.item())
# Save model if necessary. Omitted in this example.
GraphSAGE 提供了一个端到端的同质图节点分类示例。您可以在示例中的 GraphSAGE
类中看到相应的模型实现,该实现具有可调节的层数、dropout 概率以及可定制的聚合函数和非线性函数。
异构图
如果您的图是异构的,您可能希望沿着所有边类型从邻居收集消息。您可以使用模块 dgl.nn.pytorch.HeteroGraphConv
(也可在 MXNet 和 Tensorflow 中使用)在所有边类型上执行消息传递,然后为每种边类型组合不同的图卷积模块。
以下代码将定义一个异构图卷积模块,该模块首先在每种边类型上执行单独的图卷积,然后将每种边类型上的消息聚合结果相加作为所有节点类型的最终结果。
# Define a Heterograph Conv model
class RGCN(nn.Module):
def __init__(self, in_feats, hid_feats, out_feats, rel_names):
super().__init__()
self.conv1 = dglnn.HeteroGraphConv({
rel: dglnn.GraphConv(in_feats, hid_feats)
for rel in rel_names}, aggregate='sum')
self.conv2 = dglnn.HeteroGraphConv({
rel: dglnn.GraphConv(hid_feats, out_feats)
for rel in rel_names}, aggregate='sum')
def forward(self, graph, inputs):
# inputs are features of nodes
h = self.conv1(graph, inputs)
h = {k: F.relu(v) for k, v in h.items()}
h = self.conv2(graph, h)
return h
dgl.nn.HeteroGraphConv
接受一个包含节点类型和节点特征张量的字典作为输入,并返回另一个包含节点类型和节点特征的字典。
因此,考虑到我们在 异构图示例 中有用户和物品特征。
model = RGCN(n_hetero_features, 20, n_user_classes, hetero_graph.etypes)
user_feats = hetero_graph.nodes['user'].data['feature']
item_feats = hetero_graph.nodes['item'].data['feature']
labels = hetero_graph.nodes['user'].data['label']
train_mask = hetero_graph.nodes['user'].data['train_mask']
可以简单地进行前向传播,如下所示
node_features = {'user': user_feats, 'item': item_feats}
h_dict = model(hetero_graph, {'user': user_feats, 'item': item_feats})
h_user = h_dict['user']
h_item = h_dict['item']
训练循环与同质图的训练循环相同,不同之处在于现在您有一个节点表示字典,您可以从中计算预测结果。例如,如果您只预测 user
节点,您只需从返回的字典中提取 user
节点嵌入。
opt = torch.optim.Adam(model.parameters())
for epoch in range(5):
model.train()
# forward propagation by using all nodes and extracting the user embeddings
logits = model(hetero_graph, node_features)['user']
# compute loss
loss = F.cross_entropy(logits[train_mask], labels[train_mask])
# Compute validation accuracy. Omitted in this example.
# backward propagation
opt.zero_grad()
loss.backward()
opt.step()
print(loss.item())
# Save model if necessary. Omitted in the example.
DGL 提供了一个端到端的 RGCN 节点分类示例。您可以在 模型实现文件 的 RelGraphConvLayer
中看到异构图卷积的定义。