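Graph Diffusion in Graph Neural Networks
This tutorial first briefly introduces diffusion processes on graphs. It then shows how graph neural networks use this concept to improve their predictive power.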
[ ]:
# Install required packages.
import os
import torch
os.environ['TORCH'] = torch.__version__
os.environ['DGLBACKEND'] = "pytorch"
# Uncomment below to install required packages. If the CUDA version is not 11.8,
# check the https://dgl.ac.cn/pages/start.html to find the supported CUDA
# version and corresponding command to install DGL.
#!pip install dgl -f https://data.dgl.ai/wheels/cu118/repo.html > /dev/null
#!pip install --upgrade scipy networkx > /dev/null
try:
    import dgl
    installed = True
except ImportError:
    installed = False
print("DGL installed!" if installed else "Failed to install DGL!")
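Graph Diffusion
Diffusion describes the process of a substance moving from one region to another. In the context of graphs, the diffusing substance (e.g., a real-valued signal) travels along edges from node to node.
Mathematically, let \(\vec x\) be the vector of node signals. A graph diffusion operation can then be defined as
\[\vec{y} = \tilde{A} \vec{x},\]
where \(\tilde{A}\) is the diffusion matrix, typically derived from the adjacency matrix of the graph. Although the choice of diffusion matrix may vary, it is usually sparse, so \(\tilde{A} \vec{x}\) is a sparse-dense matrix multiplication.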
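To make the sparse-dense product \(\tilde{A} \vec{x}\) concrete, here is a minimal sketch in plain Python, independent of DGL. It stores a hypothetical diffusion matrix in coordinate (COO) form as `(row, col, value)` triples and applies one diffusion step. The helper name `diffuse`, the 3-node graph, and the edge weights are illustrative assumptions, not part of this tutorial's code.

```python
def diffuse(entries, x):
    """Apply one diffusion step y = A_tilde @ x.

    `entries` lists only the nonzero values of A_tilde as
    (row, col, value) triples, so the work is proportional to the
    number of nonzeros rather than to N * N.
    """
    y = [0.0] * len(x)
    for row, col, value in entries:
        y[row] += value * x[col]
    return y


# Hypothetical 3-node path graph 0 - 1 - 2 with made-up edge weights.
entries = [
    (0, 1, 0.5),
    (1, 0, 0.5), (1, 2, 0.5),
    (2, 1, 0.5),
]
x = [4.0, 0.0, 0.0]  # all of the signal starts on node 0

y = diffuse(entries, x)    # after one step the signal reaches node 1
y2 = diffuse(entries, y)   # after two steps it reaches nodes 0 and 2
print(y, y2)
```

Each step moves signal only along existing edges, which is why a node two hops away receives nothing until the second step.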
Let us build intuition with a simple example. First, we obtain the adjacency matrix of the famous Karate Club network.
[ ]:
import dgl
import dgl.sparse as dglsp
from dgl.data import KarateClubDataset
# Get the graph from DGL's builtin dataset.
dataset = KarateClubDataset()
dgl_g = dataset[0]
# Get its adjacency matrix.
indices = torch.stack(dgl_g.edges())
N = dgl_g.num_nodes()
A = dglsp.spmatrix(indices, shape=(N, N))
print(A.to_dense())
tensor([[0., 1., 1., ..., 1., 0., 0.],
[1., 0., 1., ..., 0., 0., 0.],
[1., 1., 0., ..., 0., 1., 0.],
...,
[1., 0., 0., ..., 0., 1., 1.],
[0., 0., 1., ..., 1., 0., 1.],
[0., 0., 0., ..., 1., 1., 0.]])
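In this example, we use the graph convolution matrix from Graph Convolutional Networks as the diffusion matrix. The graph convolution matrix is defined as
\[\tilde{A} = \bar{D}^{-\frac{1}{2}} \bar{A} \bar{D}^{-\frac{1}{2}},\]
where \(\bar{A} = A + I\), \(A\) denotes the adjacency matrix, \(I\) denotes the identity matrix, and \(\bar{D}\) is the diagonal node degree matrix of \(\bar{A}\).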
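As a small worked example, the sketch below computes \(\bar{D}^{-\frac{1}{2}} \bar{A} \bar{D}^{-\frac{1}{2}}\) by hand for a 3-node path graph, using dense Python lists instead of DGL sparse matrices. The function name `gcn_normalize` and the toy graph are assumptions made for illustration only.

```python
import math


def gcn_normalize(adj):
    """Return D_bar^{-1/2} (A + I) D_bar^{-1/2} for a dense 0/1 matrix."""
    n = len(adj)
    # A_bar = A + I: add a self-loop to every node.
    a_bar = [[adj[i][j] + (1.0 if i == j else 0.0) for j in range(n)]
             for i in range(n)]
    # Diagonal of D_bar: row sums of A_bar.
    deg = [sum(row) for row in a_bar]
    # Entry (i, j) is scaled by 1 / sqrt(deg_i * deg_j).
    return [[a_bar[i][j] / math.sqrt(deg[i] * deg[j]) for j in range(n)]
            for i in range(n)]


# Path graph 0 - 1 - 2, chosen only for illustration.
A = [
    [0.0, 1.0, 0.0],
    [1.0, 0.0, 1.0],
    [0.0, 1.0, 0.0],
]
A_tilde = gcn_normalize(A)
for row in A_tilde:
    print([round(v, 4) for v in row])
```

The result is symmetric, and each diagonal entry equals one over the node's degree including its self-loop. The same pattern is visible in the Karate Club output, where the first diagonal entry 0.0588 is approximately 1/17.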
[ ]:
# Compute graph convolution matrix.
I = dglsp.identity(A.shape)
A_hat = A + I
D_hat = dglsp.diag(A_hat.sum(dim=1))
D_hat_invsqrt = D_hat ** -0.5
A_tilde = D_hat_invsqrt @ A_hat @ D_hat_invsqrt
print(A_tilde.to_dense())
tensor([[0.0588, 0.0767, 0.0731, ..., 0.0917, 0.0000, 0.0000],
[0.0767, 0.1000, 0.0953, ..., 0.0000, 0.0000, 0.0000],
[0.0731, 0.0953, 0.0909, ..., 0.0000, 0.0836, 0.0000],
...,
[0.0917, 0.0000, 0.0000, ..., 0.1429, 0.1048, 0.0891],
[0.0000, 0.0000, 0.0836, ..., 0.1048, 0.0769, 0.0654],
[0.0000, 0.0000, 0.0000, ..., 0.0891, 0.0654, 0.0556]])
For the node signals, we set the signal of every node except one to zero.
[ ]:
# Initial node signals. All nodes except one are set to zero.
X = torch.zeros(N)
X[0] = 5.
# Number of diffusion steps.
r = 8
# Record the signals after each diffusion step.
results = [X]
for _ in range(r):
    X = A_tilde @ X
    results.append(X)
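The program below visualizes the diffusion process as an animation. To play the animation, click the "play" icon. You will see how the node features converge over time.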
[ ]:
import matplotlib.pyplot as plt
import networkx as nx
from IPython.display import HTML
from matplotlib import animation
nx_g = dgl_g.to_networkx().to_undirected()
pos = nx.spring_layout(nx_g)
fig, ax = plt.subplots()
plt.close()
def animate(i):
    ax.cla()
    # Color nodes based on their features.
    nodes = nx.draw_networkx_nodes(
        nx_g, pos, ax=ax, node_size=200,
        node_color=results[i].tolist(), cmap=plt.cm.Blues,
    )
    # Set boundary color of the nodes.
    nodes.set_edgecolor("#000000")
    nx.draw_networkx_edges(nx_g, pos, ax=ax)
ani = animation.FuncAnimation(fig, animate, frames=len(results), interval=1000)
HTML(ani.to_jshtml())
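The convergence seen in the animation can also be checked numerically. For a connected graph, repeatedly applying the graph convolution matrix drives the signal toward a vector proportional to \(\sqrt{\bar{d}_i}\), where \(\bar{d}_i\) is the degree of node \(i\) including its self-loop, because that vector is an eigenvector of \(\tilde{A}\) with eigenvalue 1. The sketch below illustrates this in plain Python on a small hypothetical star graph; the graph and the number of steps are arbitrary choices for illustration.

```python
import math

# Hypothetical star graph: node 0 is connected to nodes 1, 2, and 3.
edges = [(0, 1), (0, 2), (0, 3)]
n = 4

# Build A_bar = A + I as a dense matrix.
a_bar = [[1.0 if i == j else 0.0 for j in range(n)] for i in range(n)]
for u, v in edges:
    a_bar[u][v] = a_bar[v][u] = 1.0

deg = [sum(row) for row in a_bar]  # degrees including self-loops
a_tilde = [[a_bar[i][j] / math.sqrt(deg[i] * deg[j]) for j in range(n)]
           for i in range(n)]

# Start with all of the signal on node 0, as in the Karate Club example.
x = [5.0, 0.0, 0.0, 0.0]
for _ in range(50):
    x = [sum(a_tilde[i][j] * x[j] for j in range(n)) for i in range(n)]

# After many steps, x[i] / sqrt(deg[i]) is nearly identical for every node.
ratios = [x[i] / math.sqrt(deg[i]) for i in range(n)]
print([round(r, 4) for r in ratios])
```

In other words, the limit no longer reflects where the signal started, only how well connected each node is, which is one reason SIGN keeps the intermediate diffusion steps rather than only the last one.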
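Graph Diffusion in GNNs
Scalable Inception Graph Neural Networks (SIGN) uses several diffusion operators simultaneously. Formally, it is defined as
\[
\begin{aligned}
Z &= \sigma([X\Theta_0, A_1X\Theta_1, \cdots, A_rX\Theta_r]) \\
Y &= \xi(Z\Omega)
\end{aligned}
\]
where:
* \(\sigma\) and \(\xi\) are nonlinear activation functions.
* \([\cdot,\cdots,\cdot]\) is the concatenation operation.
* \(X\in\mathbb{R}^{n\times d}\) is the input node feature matrix, with \(n\) nodes and a \(d\)-dimensional feature vector per node.
* \(\Theta_0,\cdots,\Theta_r\in\mathbb{R}^{d\times d'}\) are learnable weight matrices.
* \(A_1,\cdots, A_r\in\mathbb{R}^{n\times n}\) are linear diffusion operators. In the example below, we take \(A_i\) to be \(A^i\), where \(A\) is the graph convolution matrix.
* \(\Omega\in\mathbb{R}^{d'(r+1)\times c}\) is a learnable weight matrix, where \(c\) is the number of classes.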
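As a concrete instance, consider the configuration used in the training section of this tutorial: \(r = 2\) with \(A_i = A^i\), the Cora graph with \(n = 2708\) nodes, \(d = 1433\) input features, and \(c = 7\) classes, and the default hidden size \(d' = 256\) of the SIGN module. With those values the definition unrolls as follows (a worked instance for orientation, not an additional result):

```latex
\begin{aligned}
Z &= \sigma\big([\,X\Theta_0,\; AX\Theta_1,\; A^2X\Theta_2\,]\big) \in \mathbb{R}^{2708 \times 768},
  &\Theta_i &\in \mathbb{R}^{1433 \times 256},\\
Y &= \xi(Z\Omega) \in \mathbb{R}^{2708 \times 7},
  &\Omega &\in \mathbb{R}^{768 \times 7}.
\end{aligned}
```

The width of \(Z\) is \(d'(r+1) = 256 \times 3 = 768\), since each of the three diffused feature matrices is projected to 256 dimensions before concatenation.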
The code below implements the diffusion function that computes \(A_1X, A_2X, \cdots, A_rX\), along with the module that combines all the diffused node features.
[ ]:
import torch
import torch.nn as nn
import torch.nn.functional as F
################################################################################
# (HIGHLIGHT) Take advantage of the DGL sparse APIs to implement the feature
# diffusion in SIGN concisely.
################################################################################
def sign_diffusion(A, X, r):
    # Perform the r-hop diffusion operation.
    X_sign = [X]
    for i in range(r):
        # A^i X
        X = A @ X
        X_sign.append(X)
    return X_sign


class SIGN(nn.Module):
    def __init__(self, in_size, out_size, r, hidden_size=256):
        super().__init__()
        # One linear layer (Theta_i) per diffused feature matrix.
        self.theta = nn.ModuleList(
            [nn.Linear(in_size, hidden_size) for _ in range(r + 1)]
        )
        # Output layer (Omega) applied to the concatenated features.
        self.omega = nn.Linear(hidden_size * (r + 1), out_size)

    def forward(self, X_sign):
        results = []
        for i in range(len(X_sign)):
            results.append(self.theta[i](X_sign[i]))
        Z = F.relu(torch.cat(results, dim=1))
        return self.omega(Z)
Training
We train the SIGN model on the Cora dataset. The node features are diffused in a preprocessing stage.
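Because the diffusion step has no learnable parameters, \(X, AX, \cdots, A^rX\) can be computed once before training and reused in every epoch. The plain-Python sketch below mirrors that idea on dense lists. `matmul` and `precompute_diffusion` are hypothetical helper names introduced here, and the toy matrices are illustrative; the tutorial's own `sign_diffusion` does the same job with DGL sparse matrices.

```python
def matmul(a, b):
    """Dense matrix product of two list-of-lists matrices."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))]
            for i in range(len(a))]


def precompute_diffusion(a, x, r):
    """Return [X, A X, A^2 X, ..., A^r X], computed once up front."""
    diffused = [x]
    for _ in range(r):
        x = matmul(a, x)  # one more hop of diffusion
        diffused.append(x)
    return diffused


# Toy 2-node diffusion matrix and one-hot features (illustrative values).
A = [[0.5, 0.5],
     [0.5, 0.5]]
X = [[1.0, 0.0],
     [0.0, 1.0]]

X_sign = precompute_diffusion(A, X, r=2)
for hop, feats in enumerate(X_sign):
    print(hop, feats)
```

The training loop then only ever touches the precomputed list, so no graph operation is needed inside an epoch.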
[ ]:
from dgl.data import CoraGraphDataset
from torch.optim import Adam
def evaluate(g, pred):
    label = g.ndata["label"]
    val_mask = g.ndata["val_mask"]
    test_mask = g.ndata["test_mask"]
    # Compute accuracy on validation/test set.
    val_acc = (pred[val_mask] == label[val_mask]).float().mean()
    test_acc = (pred[test_mask] == label[test_mask]).float().mean()
    return val_acc, test_acc


def train(model, g, X_sign):
    label = g.ndata["label"]
    train_mask = g.ndata["train_mask"]
    optimizer = Adam(model.parameters(), lr=3e-3)
    for epoch in range(10):
        # Switch the model to training mode.
        model.train()
        # Forward.
        logits = model(X_sign)
        # Compute loss with nodes in training set.
        loss = F.cross_entropy(logits[train_mask], label[train_mask])
        # Backward.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Switch the model to evaluating mode.
        model.eval()
        # Compute prediction.
        logits = model(X_sign)
        pred = logits.argmax(1)
        # Evaluate the prediction.
        val_acc, test_acc = evaluate(g, pred)
        print(
            f"In epoch {epoch}, loss: {loss:.3f}, val acc: {val_acc:.3f}, test"
            f" acc: {test_acc:.3f}"
        )


# If CUDA is available, use GPU to accelerate the training, use CPU
# otherwise.
dev = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Load graph from the existing dataset.
dataset = CoraGraphDataset()
g = dataset[0].to(dev)
# Create the sparse adjacency matrix A (note that W was used as the notation
# for adjacency matrix in the original paper).
indices = torch.stack(g.edges())
N = g.num_nodes()
A = dglsp.spmatrix(indices, shape=(N, N))
# Calculate the graph convolution matrix.
I = dglsp.identity(A.shape, device=dev)
A_hat = A + I
D_hat_invsqrt = dglsp.diag(A_hat.sum(dim=1)) ** -0.5
A_hat = D_hat_invsqrt @ A_hat @ D_hat_invsqrt
# 2-hop diffusion.
r = 2
X = g.ndata["feat"]
X_sign = sign_diffusion(A_hat, X, r)
# Create SIGN model.
in_size = X.shape[1]
out_size = dataset.num_classes
model = SIGN(in_size, out_size, r).to(dev)
# Kick off training.
train(model, g, X_sign)
Downloading /root/.dgl/cora_v2.zip from https://data.dgl.ai/dataset/cora_v2.zip...
Extracting file to /root/.dgl/cora_v2
Finished data loading and preprocessing.
NumNodes: 2708
NumEdges: 10556
NumFeats: 1433
NumClasses: 7
NumTrainingSamples: 140
NumValidationSamples: 500
NumTestSamples: 1000
Done saving data into cached files.
In epoch 0, loss: 1.946, val acc: 0.164, test acc: 0.200
In epoch 1, loss: 1.937, val acc: 0.712, test acc: 0.690
In epoch 2, loss: 1.926, val acc: 0.610, test acc: 0.595
In epoch 3, loss: 1.914, val acc: 0.656, test acc: 0.640
In epoch 4, loss: 1.898, val acc: 0.724, test acc: 0.726
In epoch 5, loss: 1.880, val acc: 0.734, test acc: 0.753
In epoch 6, loss: 1.859, val acc: 0.730, test acc: 0.746
In epoch 7, loss: 1.834, val acc: 0.732, test acc: 0.743
In epoch 8, loss: 1.807, val acc: 0.734, test acc: 0.746
In epoch 9, loss: 1.776, val acc: 0.734, test acc: 0.745
Check out the full example script here. Learn more about how graph diffusion is used in other GNN models: