注意
跳转至末尾 下载完整示例代码。
使用图神经网络进行链接预测
在入门教程中,您已经学习了使用 GNN 进行节点分类(即预测图中节点的类别)的基本工作流程。本教程将教您如何训练 GNN 进行链接预测,即预测图中任意两个节点之间是否存在边。
完成本教程后,您将能够
构建基于 GNN 的链接预测模型。
在 DGL 提供的小型数据集上训练和评估模型。
(预估时间:28 分钟)
import itertools
import os
os.environ["DGLBACKEND"] = "pytorch"
import dgl
import dgl.data
import numpy as np
import scipy.sparse as sp
import torch
import torch.nn as nn
import torch.nn.functional as F
GNN 链接预测概述
许多应用,例如社交推荐、物品推荐、知识图谱补全等,都可以表述为链接预测问题,即预测两个特定节点之间是否存在边。本教程展示了一个示例,预测在引文网络中两篇论文之间是否存在引文关系(引用或被引用)。
本教程将链接预测问题表述为如下二分类问题:
将图中的边视为正样本。
采样一些不存在的边(即之间没有边的节点对)作为负样本。
将正样本和负样本划分为训练集和测试集。
使用任何二分类指标(如 AUC(曲线下面积))评估模型。
注意
本实践源于 SEAL,尽管这里的模型未使用其节点标签的思想。
在某些领域,例如大规模推荐系统或信息检索,您可能偏好强调 Top-K 预测良好性能的指标。在这些情况下,您可能需要考虑其他指标,例如平均精度均值 (mean average precision),并使用其他负采样方法,这些超出了本教程的范围。
加载图和特征
遵循入门教程,本教程首先加载 Cora 数据集。
dataset = dgl.data.CoraGraphDataset()
g = dataset[0]
NumNodes: 2708
NumEdges: 10556
NumFeats: 1433
NumClasses: 7
NumTrainingSamples: 140
NumValidationSamples: 500
NumTestSamples: 1000
Done loading data from cached files.
准备训练集和测试集
本教程随机选取 10% 的边作为测试集中的正样本,其余作为训练集。然后,在两个集合中采样相同数量的边作为负样本。
# Split edge set for training and testing
u, v = g.edges()
eids = np.arange(g.num_edges())
eids = np.random.permutation(eids)
test_size = int(len(eids) * 0.1)
train_size = g.num_edges() - test_size
test_pos_u, test_pos_v = u[eids[:test_size]], v[eids[:test_size]]
train_pos_u, train_pos_v = u[eids[test_size:]], v[eids[test_size:]]
# Find all negative edges and split them for training and testing
adj = sp.coo_matrix((np.ones(len(u)), (u.numpy(), v.numpy())))
adj_neg = 1 - adj.todense() - np.eye(g.num_nodes())
neg_u, neg_v = np.where(adj_neg != 0)
neg_eids = np.random.choice(len(neg_u), g.num_edges())
test_neg_u, test_neg_v = (
neg_u[neg_eids[:test_size]],
neg_v[neg_eids[:test_size]],
)
train_neg_u, train_neg_v = (
neg_u[neg_eids[test_size:]],
neg_v[neg_eids[test_size:]],
)
训练时,您需要从原始图中移除测试集中的边。这可以通过 dgl.remove_edges
实现。
注意
dgl.remove_edges
的工作原理是从原始图创建子图,这会产生一个副本,因此对于大型图可能会很慢。如果出现这种情况,您可以将训练图和测试图保存到磁盘,就像进行预处理一样。
定义 GraphSAGE 模型
本教程构建了一个由两个 GraphSAGE 层组成的模型,每个层通过平均邻居信息计算新的节点表示。DGL 提供了 dgl.nn.SAGEConv
,可方便地创建 GraphSAGE 层。
from dgl.nn import SAGEConv
# ----------- 2. create model -------------- #
# build a two-layer GraphSAGE model
class GraphSAGE(nn.Module):
def __init__(self, in_feats, h_feats):
super(GraphSAGE, self).__init__()
self.conv1 = SAGEConv(in_feats, h_feats, "mean")
self.conv2 = SAGEConv(h_feats, h_feats, "mean")
def forward(self, g, in_feat):
h = self.conv1(g, in_feat)
h = F.relu(h)
h = self.conv2(g, h)
return h
然后,模型通过一个函数(例如 MLP 或点积)计算两个关联节点表示之间的分数,从而预测边存在的概率,这将在下一节中看到。
正样本图、负样本图和 apply_edges
在之前的教程中,您学习了如何使用 GNN 计算节点表示。然而,链接预测需要计算节点对的表示。
DGL 建议您将节点对视为另一个图,因为您可以用一条边描述一对节点。在链接预测中,您将有一个由所有正样本作为边组成的正样本图,以及一个由所有负样本组成的负样本图。正样本图和负样本图将包含与原始图相同的节点集。这使得在多个图之间传递节点特征进行计算更加容易。正如稍后您将看到的,您可以直接将整个图上计算的节点表示馈送到正样本图和负样本图,以计算成对分数。
以下代码分别构建训练集和测试集的正样本图和负样本图。
train_pos_g = dgl.graph((train_pos_u, train_pos_v), num_nodes=g.num_nodes())
train_neg_g = dgl.graph((train_neg_u, train_neg_v), num_nodes=g.num_nodes())
test_pos_g = dgl.graph((test_pos_u, test_pos_v), num_nodes=g.num_nodes())
test_neg_g = dgl.graph((test_neg_u, test_neg_v), num_nodes=g.num_nodes())
将节点对视为图的好处是您可以使用 DGLGraph.apply_edges
方法,该方法可以方便地根据关联节点的特征和原始边特征(如果适用)计算新的边特征。
DGL 提供了一系列优化的内置函数,用于根据原始节点/边特征计算新的边特征。例如,dgl.function.u_dot_v
计算每条边的关联节点表示的点积。
import dgl.function as fn
class DotPredictor(nn.Module):
def forward(self, g, h):
with g.local_scope():
g.ndata["h"] = h
# Compute a new edge feature named 'score' by a dot-product between the
# source node feature 'h' and destination node feature 'h'.
g.apply_edges(fn.u_dot_v("h", "h", "score"))
# u_dot_v returns a 1-element vector for each edge so you need to squeeze it.
return g.edata["score"][:, 0]
如果函数比较复杂,您也可以自己编写。例如,以下模块通过拼接关联节点的特征并将其传递给 MLP,在每条边上生成一个标量分数。
class MLPPredictor(nn.Module):
def __init__(self, h_feats):
super().__init__()
self.W1 = nn.Linear(h_feats * 2, h_feats)
self.W2 = nn.Linear(h_feats, 1)
def apply_edges(self, edges):
"""
Computes a scalar score for each edge of the given graph.
Parameters
----------
edges :
Has three members ``src``, ``dst`` and ``data``, each of
which is a dictionary representing the features of the
source nodes, the destination nodes, and the edges
themselves.
Returns
-------
dict
A dictionary of new edge features.
"""
h = torch.cat([edges.src["h"], edges.dst["h"]], 1)
return {"score": self.W2(F.relu(self.W1(h))).squeeze(1)}
def forward(self, g, h):
with g.local_scope():
g.ndata["h"] = h
g.apply_edges(self.apply_edges)
return g.edata["score"]
注意
内置函数在速度和内存方面都经过优化。我们建议尽可能使用内置函数。
注意
如果您阅读过消息传递教程,您会注意到 apply_edges
所接受的参数形式与 update_all
中的消息函数形式完全相同。
训练循环
定义了节点表示计算和边分数计算后,您可以继续定义整体模型、损失函数和评估指标。
损失函数即简单的二元交叉熵损失。
本教程中的评估指标是 AUC。
model = GraphSAGE(train_g.ndata["feat"].shape[1], 16)
# You can replace DotPredictor with MLPPredictor.
# pred = MLPPredictor(16)
pred = DotPredictor()
def compute_loss(pos_score, neg_score):
scores = torch.cat([pos_score, neg_score])
labels = torch.cat(
[torch.ones(pos_score.shape[0]), torch.zeros(neg_score.shape[0])]
)
return F.binary_cross_entropy_with_logits(scores, labels)
def compute_auc(pos_score, neg_score):
scores = torch.cat([pos_score, neg_score]).numpy()
labels = torch.cat(
[torch.ones(pos_score.shape[0]), torch.zeros(neg_score.shape[0])]
).numpy()
return roc_auc_score(labels, scores)
训练循环如下:
注意
本教程不包含在验证集上的评估。在实际操作中,您应该根据在验证集上的性能保存和评估最佳模型。
# ----------- 3. set up loss and optimizer -------------- #
# in this case, loss will in training loop
optimizer = torch.optim.Adam(
itertools.chain(model.parameters(), pred.parameters()), lr=0.01
)
# ----------- 4. training -------------------------------- #
all_logits = []
for e in range(100):
# forward
h = model(train_g, train_g.ndata["feat"])
pos_score = pred(train_pos_g, h)
neg_score = pred(train_neg_g, h)
loss = compute_loss(pos_score, neg_score)
# backward
optimizer.zero_grad()
loss.backward()
optimizer.step()
if e % 5 == 0:
print("In epoch {}, loss: {}".format(e, loss))
# ----------- 5. check results ------------------------ #
from sklearn.metrics import roc_auc_score
with torch.no_grad():
pos_score = pred(test_pos_g, h)
neg_score = pred(test_neg_g, h)
print("AUC", compute_auc(pos_score, neg_score))
# Thumbnail credits: Link Prediction with Neo4j, Mark Needham
# sphinx_gallery_thumbnail_path = '_static/blitz_4_link_predict.png'
In epoch 0, loss: 0.7139173746109009
In epoch 5, loss: 0.6922954320907593
In epoch 10, loss: 0.6852926015853882
In epoch 15, loss: 0.6639790534973145
In epoch 20, loss: 0.6166930794715881
In epoch 25, loss: 0.5746302008628845
In epoch 30, loss: 0.5502054691314697
In epoch 35, loss: 0.517124593257904
In epoch 40, loss: 0.49472111463546753
In epoch 45, loss: 0.4738965630531311
In epoch 50, loss: 0.44900473952293396
In epoch 55, loss: 0.4295833706855774
In epoch 60, loss: 0.4069865047931671
In epoch 65, loss: 0.38496825098991394
In epoch 70, loss: 0.3640557825565338
In epoch 75, loss: 0.34291937947273254
In epoch 80, loss: 0.3209986686706543
In epoch 85, loss: 0.29873526096343994
In epoch 90, loss: 0.27633118629455566
In epoch 95, loss: 0.2535451352596283
AUC 0.8701071404505738
脚本总运行时间: (0 分钟 1.418 秒)