注意
跳转到末尾 下载完整示例代码。
图卷积网络
作者: Qi Huang, Minjie Wang, Yu Gai, Quan Gan, Zheng Zhang
警告
本教程旨在深入理解论文,代码作为解释手段。因此,本实现并未针对运行效率进行优化。有关推荐的实现,请参考官方示例。
这是使用 DGL 实现图卷积网络(Kipf & Welling 等人的论文 Semi-Supervised Classification with Graph Convolutional Networks)的温和介绍。我们解释了 GraphConv
模块的内部原理。读者将学习如何使用 DGL 的消息传递 API 定义新的 GNN 层。
模型概述
从消息传递的角度看 GCN
我们从消息传递的角度描述图卷积神经网络的一层;数学原理可以在这里找到。对于每个节点 \(u\),它归结为以下步骤:
1) 聚合邻居的表示 \(h_{v}\) 以产生中间表示 \(\hat{h}_u\)。 2) 使用线性投影和非线性变换处理聚合表示 \(\hat{h}_{u}\): \(h_{u} = f(W_{u} \hat{h}_u)\)。
我们将使用 DGL 消息传递实现步骤 1,并使用 PyTorch nn.Module
实现步骤 2。
使用 DGL 实现 GCN
我们首先像往常一样定义消息函数和规约函数。由于节点 \(u\) 上的聚合只涉及对邻居表示 \(h_v\) 进行求和,我们可以简单地使用内置函数
import os
os.environ["DGLBACKEND"] = "pytorch"
import dgl
import dgl.function as fn
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from dgl import DGLGraph
gcn_msg = fn.copy_u(u="h", out="m")
gcn_reduce = fn.sum(msg="m", out="h")
然后我们开始定义 GCNLayer 模块。GCNLayer 本质上是对所有节点执行消息传递,然后应用一个全连接层。
注意
这展示了如何从头开始实现 GCN。DGL 提供了一个更高效的内置 GCN 层模块
。
class GCNLayer(nn.Module):
def __init__(self, in_feats, out_feats):
super(GCNLayer, self).__init__()
self.linear = nn.Linear(in_feats, out_feats)
def forward(self, g, feature):
# Creating a local scope so that all the stored ndata and edata
# (such as the `'h'` ndata below) are automatically popped out
# when the scope exits.
with g.local_scope():
g.ndata["h"] = feature
g.update_all(gcn_msg, gcn_reduce)
h = g.ndata["h"]
return self.linear(h)
前向函数本质上与 PyTorch 中任何常见的神经网络模型相同。我们可以像初始化任何 nn.Module
一样初始化 GCN。例如,让我们定义一个由两个 GCN 层组成的简单神经网络。假设我们正在为 cora 数据集训练分类器(输入特征大小为 1433,类别数为 7)。最后一个 GCN 层计算节点嵌入,因此最后一层通常不应用激活函数。
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.layer1 = GCNLayer(1433, 16)
self.layer2 = GCNLayer(16, 7)
def forward(self, g, features):
x = F.relu(self.layer1(g, features))
x = self.layer2(g, x)
return x
net = Net()
print(net)
Net(
(layer1): GCNLayer(
(linear): Linear(in_features=1433, out_features=16, bias=True)
)
(layer2): GCNLayer(
(linear): Linear(in_features=16, out_features=7, bias=True)
)
)
我们使用 DGL 的内置数据模块加载 cora 数据集。
from dgl.data import CoraGraphDataset
def load_cora_data():
dataset = CoraGraphDataset()
g = dataset[0]
features = g.ndata["feat"]
labels = g.ndata["label"]
train_mask = g.ndata["train_mask"]
test_mask = g.ndata["test_mask"]
return g, features, labels, train_mask, test_mask
模型训练完成后,我们可以使用以下方法评估模型在测试数据集上的性能
def evaluate(model, g, features, labels, mask):
model.eval()
with th.no_grad():
logits = model(g, features)
logits = logits[mask]
labels = labels[mask]
_, indices = th.max(logits, dim=1)
correct = th.sum(indices == labels)
return correct.item() * 1.0 / len(labels)
然后我们按照如下方式训练网络
import time
import numpy as np
g, features, labels, train_mask, test_mask = load_cora_data()
# Add edges between each node and itself to preserve old node representations
g.add_edges(g.nodes(), g.nodes())
optimizer = th.optim.Adam(net.parameters(), lr=1e-2)
dur = []
for epoch in range(50):
if epoch >= 3:
t0 = time.time()
net.train()
logits = net(g, features)
logp = F.log_softmax(logits, 1)
loss = F.nll_loss(logp[train_mask], labels[train_mask])
optimizer.zero_grad()
loss.backward()
optimizer.step()
if epoch >= 3:
dur.append(time.time() - t0)
acc = evaluate(net, g, features, labels, test_mask)
print(
"Epoch {:05d} | Loss {:.4f} | Test Acc {:.4f} | Time(s) {:.4f}".format(
epoch, loss.item(), acc, np.mean(dur)
)
)
NumNodes: 2708
NumEdges: 10556
NumFeats: 1433
NumClasses: 7
NumTrainingSamples: 140
NumValidationSamples: 500
NumTestSamples: 1000
Done loading data from cached files.
/opt/conda/envs/dgl-dev-cpu/lib/python3.10/site-packages/numpy/core/fromnumeric.py:3504: RuntimeWarning: Mean of empty slice.
return _methods._mean(a, axis=axis, dtype=dtype,
/opt/conda/envs/dgl-dev-cpu/lib/python3.10/site-packages/numpy/core/_methods.py:129: RuntimeWarning: invalid value encountered in scalar divide
ret = ret.dtype.type(ret / rcount)
Epoch 00000 | Loss 1.9446 | Test Acc 0.4060 | Time(s) nan
Epoch 00001 | Loss 1.7764 | Test Acc 0.5710 | Time(s) nan
Epoch 00002 | Loss 1.5769 | Test Acc 0.6260 | Time(s) nan
Epoch 00003 | Loss 1.4453 | Test Acc 0.6760 | Time(s) 0.0083
Epoch 00004 | Loss 1.3253 | Test Acc 0.7270 | Time(s) 0.0084
Epoch 00005 | Loss 1.2157 | Test Acc 0.7520 | Time(s) 0.0084
Epoch 00006 | Loss 1.1186 | Test Acc 0.7590 | Time(s) 0.0084
Epoch 00007 | Loss 1.0321 | Test Acc 0.7650 | Time(s) 0.0084
Epoch 00008 | Loss 0.9506 | Test Acc 0.7560 | Time(s) 0.0084
Epoch 00009 | Loss 0.8720 | Test Acc 0.7520 | Time(s) 0.0084
Epoch 00010 | Loss 0.7984 | Test Acc 0.7430 | Time(s) 0.0084
Epoch 00011 | Loss 0.7302 | Test Acc 0.7450 | Time(s) 0.0083
Epoch 00012 | Loss 0.6676 | Test Acc 0.7420 | Time(s) 0.0083
Epoch 00013 | Loss 0.6103 | Test Acc 0.7420 | Time(s) 0.0083
Epoch 00014 | Loss 0.5568 | Test Acc 0.7420 | Time(s) 0.0083
Epoch 00015 | Loss 0.5067 | Test Acc 0.7450 | Time(s) 0.0083
Epoch 00016 | Loss 0.4606 | Test Acc 0.7450 | Time(s) 0.0083
Epoch 00017 | Loss 0.4186 | Test Acc 0.7470 | Time(s) 0.0083
Epoch 00018 | Loss 0.3802 | Test Acc 0.7460 | Time(s) 0.0083
Epoch 00019 | Loss 0.3451 | Test Acc 0.7450 | Time(s) 0.0083
Epoch 00020 | Loss 0.3132 | Test Acc 0.7500 | Time(s) 0.0083
Epoch 00021 | Loss 0.2844 | Test Acc 0.7520 | Time(s) 0.0083
Epoch 00022 | Loss 0.2584 | Test Acc 0.7580 | Time(s) 0.0082
Epoch 00023 | Loss 0.2349 | Test Acc 0.7630 | Time(s) 0.0082
Epoch 00024 | Loss 0.2135 | Test Acc 0.7590 | Time(s) 0.0081
Epoch 00025 | Loss 0.1938 | Test Acc 0.7590 | Time(s) 0.0081
Epoch 00026 | Loss 0.1758 | Test Acc 0.7570 | Time(s) 0.0081
Epoch 00027 | Loss 0.1594 | Test Acc 0.7550 | Time(s) 0.0080
Epoch 00028 | Loss 0.1446 | Test Acc 0.7540 | Time(s) 0.0080
Epoch 00029 | Loss 0.1312 | Test Acc 0.7490 | Time(s) 0.0080
Epoch 00030 | Loss 0.1190 | Test Acc 0.7470 | Time(s) 0.0080
Epoch 00031 | Loss 0.1081 | Test Acc 0.7460 | Time(s) 0.0080
Epoch 00032 | Loss 0.0983 | Test Acc 0.7450 | Time(s) 0.0079
Epoch 00033 | Loss 0.0895 | Test Acc 0.7410 | Time(s) 0.0079
Epoch 00034 | Loss 0.0816 | Test Acc 0.7430 | Time(s) 0.0079
Epoch 00035 | Loss 0.0745 | Test Acc 0.7420 | Time(s) 0.0079
Epoch 00036 | Loss 0.0680 | Test Acc 0.7430 | Time(s) 0.0079
Epoch 00037 | Loss 0.0623 | Test Acc 0.7430 | Time(s) 0.0079
Epoch 00038 | Loss 0.0571 | Test Acc 0.7430 | Time(s) 0.0079
Epoch 00039 | Loss 0.0524 | Test Acc 0.7450 | Time(s) 0.0079
Epoch 00040 | Loss 0.0482 | Test Acc 0.7460 | Time(s) 0.0078
Epoch 00041 | Loss 0.0444 | Test Acc 0.7470 | Time(s) 0.0078
Epoch 00042 | Loss 0.0410 | Test Acc 0.7460 | Time(s) 0.0078
Epoch 00043 | Loss 0.0380 | Test Acc 0.7470 | Time(s) 0.0078
Epoch 00044 | Loss 0.0353 | Test Acc 0.7470 | Time(s) 0.0078
Epoch 00045 | Loss 0.0328 | Test Acc 0.7470 | Time(s) 0.0078
Epoch 00046 | Loss 0.0306 | Test Acc 0.7470 | Time(s) 0.0078
Epoch 00047 | Loss 0.0286 | Test Acc 0.7470 | Time(s) 0.0078
Epoch 00048 | Loss 0.0267 | Test Acc 0.7460 | Time(s) 0.0078
Epoch 00049 | Loss 0.0250 | Test Acc 0.7440 | Time(s) 0.0078
GCN 的一个公式
数学上,GCN 模型遵循以下公式
\(H^{(l+1)} = \sigma(\tilde{D}^{-\frac{1}{2}}\tilde{A}\tilde{D}^{-\frac{1}{2}}H^{(l)}W^{(l)})\)
这里,\(H^{(l)}\) 表示网络中的第 \(l\) 层,\(\sigma\) 是非线性激活函数,\(W\) 是该层的权重矩阵。\(\tilde{D}\) 和 \(\tilde{A}\) 分别是图的度矩阵和邻接矩阵。带有上标 ~ 的变体表示我们在每个节点与其自身之间添加了额外的边,以在图卷积中保留其旧的表示。输入 \(H^{(0)}\) 的形状是 \(N \times D\),其中 \(N\) 是节点数,\(D\) 是输入特征的数量。我们可以像这样连接多个层,以生成形状为 \(N \times F\) 的节点级表示输出,其中 \(F\) 是输出节点特征向量的维度。
该方程可以使用稀疏矩阵乘法核(例如 Kipf 的 pygcn 代码)高效实现。上面的 DGL 实现由于使用了内置函数,实际上已经使用了这个技巧。
请注意,本教程代码实现的是简化版 GCN,其中我们将 \(\tilde{D}^{-\frac{1}{2}}\tilde{A}\tilde{D}^{-\frac{1}{2}}\) 替换为 \(\tilde{A}\)。有关完整实现,请参阅我们的示例这里。
脚本总运行时间: (0 minutes 0.743 seconds)