Building a Graph Convolutional Network Using Sparse Matrices
This tutorial illustrates, step by step, how to write and train a Graph Convolutional Network (Kipf et al., 2017) using DGL's sparse matrix APIs.
[ ]:
# Install required packages.
import os
import torch

os.environ['TORCH'] = torch.__version__
os.environ['DGLBACKEND'] = "pytorch"

# Uncomment below to install required packages. If the CUDA version is not 11.8,
# check https://dgl.ac.cn/pages/start.html for the supported CUDA versions and
# the corresponding command to install DGL.
#!pip install dgl -f https://data.dgl.ai/wheels/cu118/repo.html > /dev/null

try:
    import dgl
    installed = True
except ImportError:
    installed = False
print("DGL installed!" if installed else "DGL not found!")
Graph Convolutional Layer
Mathematically, the graph convolutional layer is defined as
\[f(X^{(l)}, A) = \sigma(\bar{D}^{-\frac{1}{2}}\bar{A}\bar{D}^{-\frac{1}{2}}X^{(l)}W^{(l)})\]
where \(\bar{A} = A + I\), with \(A\) denoting the adjacency matrix and \(I\) the identity matrix; \(\bar{D}\) is the diagonal degree matrix of \(\bar{A}\); \(W^{(l)}\) is a trainable weight matrix; and \(\sigma\) is a non-linear activation function (e.g., ReLU).
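To make the formula concrete before moving to the sparse API, here is a minimal dense-tensor sketch that evaluates \(\bar{D}^{-\frac{1}{2}}\bar{A}\bar{D}^{-\frac{1}{2}}\) on a toy graph (the 3-node path graph is made up purely for illustration; the tutorial's actual implementation uses dgl.sparse below):
[ ]:
import torch

# Hypothetical toy graph for illustration: a 3-node path with edges 0-1 and 1-2.
A = torch.tensor([[0., 1., 0.],
                  [1., 0., 1.],
                  [0., 1., 0.]])
I = torch.eye(3)
A_hat = A + I                            # augmented adjacency matrix A + I
deg = A_hat.sum(0)                       # degrees of the augmented graph: [2., 3., 2.]
D_hat_invsqrt = torch.diag(deg ** -0.5)  # diagonal D^{-1/2}
# Symmetrically normalized convolution matrix from the formula above.
print(D_hat_invsqrt @ A_hat @ D_hat_invsqrt)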
The code below shows how to implement it using the dgl.sparse package. The core operations are:
- dgl.sparse.identity creates the identity matrix \(I\). The augmented adjacency matrix \(\bar{A}\) is computed by adding the identity matrix to the adjacency matrix \(A\).
- A_hat.sum(0) aggregates the augmented adjacency matrix \(\bar{A}\) along the first dimension, which gives the degree vector of the augmented graph. The diagonal degree matrix \(\bar{D}\) is then created by dgl.sparse.diag.
- D_hat ** -0.5 computes \(\bar{D}^{-\frac{1}{2}}\).
- D_hat_invsqrt @ A_hat @ D_hat_invsqrt computes the convolution matrix, which is then multiplied by the linearly transformed node features.
[ ]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import dgl.sparse as dglsp


class GCNLayer(nn.Module):
    def __init__(self, in_size, out_size):
        super(GCNLayer, self).__init__()
        self.W = nn.Linear(in_size, out_size)

    def forward(self, A, X):
        ########################################################################
        # (HIGHLIGHT) Compute the symmetrically normalized adjacency matrix with
        # Sparse Matrix API
        ########################################################################
        I = dglsp.identity(A.shape)
        A_hat = A + I
        D_hat = dglsp.diag(A_hat.sum(0))
        D_hat_invsqrt = D_hat ** -0.5
        return D_hat_invsqrt @ A_hat @ D_hat_invsqrt @ self.W(X)
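As a quick sanity check, the layer can be run on a toy graph (a minimal sketch; the 3-node graph, feature dimension, and output size are made up for illustration):
[ ]:
# Hypothetical toy graph: undirected edges 0-1 and 1-2 stored as directed pairs.
indices = torch.tensor([[0, 1, 1, 2],
                        [1, 0, 2, 1]])
A = dglsp.spmatrix(indices, shape=(3, 3))
X = torch.randn(3, 4)             # random 4-dimensional features for 3 nodes
layer = GCNLayer(in_size=4, out_size=2)
print(layer(A, X).shape)          # expected: torch.Size([3, 2])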
A Graph Convolutional Network is then defined by stacking this layer.
[ ]:
# Create a GCN with the GCN layer.
class GCN(nn.Module):
    def __init__(self, in_size, out_size, hidden_size):
        super(GCN, self).__init__()
        self.conv1 = GCNLayer(in_size, hidden_size)
        self.conv2 = GCNLayer(hidden_size, out_size)

    def forward(self, A, X):
        X = self.conv1(A, X)
        X = F.relu(X)
        return self.conv2(A, X)
Training the GCN
We then train the GCN model on the Cora dataset for node classification. Note that since the model expects the adjacency matrix as its first argument, we first construct the adjacency matrix from the graph using the dgl.sparse.spmatrix API, which returns a DGL SparseMatrix object.
[ ]:
def evaluate(g, pred):
    label = g.ndata["label"]
    val_mask = g.ndata["val_mask"]
    test_mask = g.ndata["test_mask"]

    # Compute accuracy on validation/test set.
    val_acc = (pred[val_mask] == label[val_mask]).float().mean()
    test_acc = (pred[test_mask] == label[test_mask]).float().mean()
    return val_acc, test_acc


def train(model, g):
    features = g.ndata["feat"]
    label = g.ndata["label"]
    train_mask = g.ndata["train_mask"]
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=5e-4)
    loss_fcn = nn.CrossEntropyLoss()

    # Preprocess to get the adjacency matrix of the graph.
    indices = torch.stack(g.edges())
    N = g.num_nodes()
    A = dglsp.spmatrix(indices, shape=(N, N))

    for epoch in range(100):
        model.train()

        # Forward.
        logits = model(A, features)

        # Compute loss with nodes in the training set.
        loss = loss_fcn(logits[train_mask], label[train_mask])

        # Backward.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Compute prediction.
        pred = logits.argmax(dim=1)

        # Evaluate the prediction.
        val_acc, test_acc = evaluate(g, pred)
        if epoch % 5 == 0:
            print(
                f"In epoch {epoch}, loss: {loss:.3f}, val acc: {val_acc:.3f}"
                f", test acc: {test_acc:.3f}"
            )
# Load graph from the existing dataset.
dataset = dgl.data.CoraGraphDataset()
g = dataset[0]
# Create model.
feature = g.ndata['feat']
in_size = feature.shape[1]
out_size = dataset.num_classes
gcn_model = GCN(in_size, out_size, 16)
# Kick off training.
train(gcn_model, g)
Downloading /root/.dgl/cora_v2.zip from https://data.dgl.ai/dataset/cora_v2.zip...
Extracting file to /root/.dgl/cora_v2
Finished data loading and preprocessing.
NumNodes: 2708
NumEdges: 10556
NumFeats: 1433
NumClasses: 7
NumTrainingSamples: 140
NumValidationSamples: 500
NumTestSamples: 1000
Done saving data into cached files.
In epoch 0, loss: 1.954, val acc: 0.114, test acc: 0.103
In epoch 5, loss: 1.921, val acc: 0.158, test acc: 0.147
In epoch 10, loss: 1.878, val acc: 0.288, test acc: 0.283
In epoch 15, loss: 1.822, val acc: 0.344, test acc: 0.353
In epoch 20, loss: 1.751, val acc: 0.388, test acc: 0.389
In epoch 25, loss: 1.663, val acc: 0.406, test acc: 0.410
In epoch 30, loss: 1.562, val acc: 0.472, test acc: 0.481
In epoch 35, loss: 1.450, val acc: 0.558, test acc: 0.573
In epoch 40, loss: 1.333, val acc: 0.636, test acc: 0.641
In epoch 45, loss: 1.216, val acc: 0.684, test acc: 0.683
In epoch 50, loss: 1.102, val acc: 0.726, test acc: 0.713
In epoch 55, loss: 0.996, val acc: 0.740, test acc: 0.740
In epoch 60, loss: 0.899, val acc: 0.754, test acc: 0.760
In epoch 65, loss: 0.813, val acc: 0.762, test acc: 0.771
In epoch 70, loss: 0.737, val acc: 0.768, test acc: 0.781
In epoch 75, loss: 0.671, val acc: 0.776, test acc: 0.786
In epoch 80, loss: 0.614, val acc: 0.784, test acc: 0.790
In epoch 85, loss: 0.566, val acc: 0.780, test acc: 0.788
In epoch 90, loss: 0.524, val acc: 0.780, test acc: 0.791
In epoch 95, loss: 0.489, val acc: 0.772, test acc: 0.795
Check out the full example script here.