5.3 链接预测
在某些其他场景中,您可能想要预测两个给定节点之间是否存在边。这样的任务称为链接预测任务。
概述
基于 GNN 的链接预测模型将两个节点 \(u\) 和 \(v\) 之间的连接可能性表示为它们由多层 GNN 计算出的节点表示 \(\boldsymbol{h}_u^{(L)}\) 和 \(\boldsymbol{h}_v^{(L)}\) 的函数。
在本节中,我们将节点 \(u\) 和节点 \(v\) 之间的 \(y_{u,v}\) 称为得分。
训练链接预测模型需要比较由边连接的节点对的得分与任意一对节点之间的得分。例如,给定一条连接 \(u\) 和 \(v\) 的边,我们鼓励节点 \(u\) 和 \(v\) 之间的得分高于节点 \(u\) 和从任意噪声分布 \(v' \sim P_n(v)\) 中采样的节点 \(v'\) 之间的得分。这种方法称为负采样。
如果最小化,有许多损失函数可以实现上述行为。非穷尽列表包括
交叉熵损失: \(\mathcal{L} = - \log \sigma (y_{u,v}) - \sum_{v_i \sim P_n(v), i=1,\dots,k}\log \left[ 1 - \sigma (y_{u,v_i})\right]\)
BPR 损失: \(\mathcal{L} = \sum_{v_i \sim P_n(v), i=1,\dots,k} - \log \sigma (y_{u,v} - y_{u,v_i})\)
间隔损失: \(\mathcal{L} = \sum_{v_i \sim P_n(v), i=1,\dots,k} \max(0, M - y_{u, v} + y_{u, v_i})\),其中 \(M\) 是一个常数超参数。
如果您了解什么是隐式反馈或噪声对比估计,您可能会觉得这个想法很熟悉。
用于计算 \(u\) 和 \(v\) 之间得分的神经网络模型与 上面 描述的边回归模型相同。
这是一个使用点积计算边得分的示例。
class DotProductPredictor(nn.Module):
def forward(self, graph, h):
# h contains the node representations computed from the GNN defined
# in the node classification section (Section 5.1).
with graph.local_scope():
graph.ndata['h'] = h
graph.apply_edges(fn.u_dot_v('h', 'h', 'score'))
return graph.edata['score']
训练循环
由于我们的得分预测模型在图上操作,我们需要将负例表示为另一个图。该图将包含所有负节点对作为边。
下面展示了将负例表示为图的示例。每条边 \((u,v)\) 会得到 \(k\) 个负例 \((u,v_i)\),其中 \(v_i\) 从均匀分布中采样得到。
def construct_negative_graph(graph, k):
src, dst = graph.edges()
neg_src = src.repeat_interleave(k)
neg_dst = torch.randint(0, graph.num_nodes(), (len(src) * k,))
return dgl.graph((neg_src, neg_dst), num_nodes=graph.num_nodes())
预测边得分的模型与边分类/回归模型相同。
class Model(nn.Module):
def __init__(self, in_features, hidden_features, out_features):
super().__init__()
self.sage = SAGE(in_features, hidden_features, out_features)
self.pred = DotProductPredictor()
def forward(self, g, neg_g, x):
h = self.sage(g, x)
return self.pred(g, h), self.pred(neg_g, h)
然后,训练循环会重复构建负图并计算损失。
def compute_loss(pos_score, neg_score):
# Margin loss
n_edges = pos_score.shape[0]
return (1 - pos_score + neg_score.view(n_edges, -1)).clamp(min=0).mean()
node_features = graph.ndata['feat']
n_features = node_features.shape[1]
k = 5
model = Model(n_features, 100, 100)
opt = torch.optim.Adam(model.parameters())
for epoch in range(10):
negative_graph = construct_negative_graph(graph, k)
pos_score, neg_score = model(graph, negative_graph, node_features)
loss = compute_loss(pos_score, neg_score)
opt.zero_grad()
loss.backward()
opt.step()
print(loss.item())
训练后,可以通过以下方式获取节点表示
node_embeddings = model.sage(graph, node_features)
使用节点嵌入有多种方式。示例包括训练下游分类器,或进行最近邻搜索或最大内积搜索以推荐相关实体。
异构图
异构图上的链接预测与同构图上的链接预测没有太大区别。以下假设我们正在预测一种边类型,并且很容易将其扩展到多种边类型。
例如,您可以重用 上面 的 HeteroDotProductPredictor
来计算某种边类型的链接预测得分。
class HeteroDotProductPredictor(nn.Module):
def forward(self, graph, h, etype):
# h contains the node representations for each node type computed from
# the GNN defined in the previous section (Section 5.1).
with graph.local_scope():
graph.ndata['h'] = h
graph.apply_edges(fn.u_dot_v('h', 'h', 'score'), etype=etype)
return graph.edges[etype].data['score']
为了执行负采样,您也可以为您正在进行链接预测的边类型构建一个负图。
def construct_negative_graph(graph, k, etype):
utype, _, vtype = etype
src, dst = graph.edges(etype=etype)
neg_src = src.repeat_interleave(k)
neg_dst = torch.randint(0, graph.num_nodes(vtype), (len(src) * k,))
return dgl.heterograph(
{etype: (neg_src, neg_dst)},
num_nodes_dict={ntype: graph.num_nodes(ntype) for ntype in graph.ntypes})
该模型与异构图上的边分类模型略有不同,因为您需要指定执行链接预测的边类型。
class Model(nn.Module):
def __init__(self, in_features, hidden_features, out_features, rel_names):
super().__init__()
self.sage = RGCN(in_features, hidden_features, out_features, rel_names)
self.pred = HeteroDotProductPredictor()
def forward(self, g, neg_g, x, etype):
h = self.sage(g, x)
return self.pred(g, h, etype), self.pred(neg_g, h, etype)
训练循环与同构图的训练循环类似。
def compute_loss(pos_score, neg_score):
# Margin loss
n_edges = pos_score.shape[0]
return (1 - pos_score + neg_score.view(n_edges, -1)).clamp(min=0).mean()
k = 5
model = Model(10, 20, 5, hetero_graph.etypes)
user_feats = hetero_graph.nodes['user'].data['feature']
item_feats = hetero_graph.nodes['item'].data['feature']
node_features = {'user': user_feats, 'item': item_feats}
opt = torch.optim.Adam(model.parameters())
for epoch in range(10):
negative_graph = construct_negative_graph(hetero_graph, k, ('user', 'click', 'item'))
pos_score, neg_score = model(hetero_graph, negative_graph, node_features, ('user', 'click', 'item'))
loss = compute_loss(pos_score, neg_score)
opt.zero_grad()
loss.backward()
opt.step()
print(loss.item())