5.2 边分类/回归
有时你希望预测图的边上的属性。在这种情况下,你可能想要一个边分类/回归模型。
这里我们生成一个随机图用于边预测作为演示。
src = np.random.randint(0, 100, 500)
dst = np.random.randint(0, 100, 500)
# make it symmetric
edge_pred_graph = dgl.graph((np.concatenate([src, dst]), np.concatenate([dst, src])))
# synthetic node and edge features, as well as edge labels
edge_pred_graph.ndata['feature'] = torch.randn(100, 10)
edge_pred_graph.edata['feature'] = torch.randn(1000, 10)
edge_pred_graph.edata['label'] = torch.randn(1000)
# synthetic train-validation-test splits
edge_pred_graph.edata['train_mask'] = torch.zeros(1000, dtype=torch.bool).bernoulli(0.6)
概述
在前一节中,你学习了如何使用多层 GNN 进行节点分类。同样的技术可以用于计算任何节点的隐藏表示。然后可以从边的关联节点的表示中导出对边的预测。
计算边上预测的最常见情况是将其表达为其关联节点表示以及可选地其自身边特征的参数化函数。
与节点分类的模型实现差异
假设你使用前一节的模型计算了节点表示,你只需要再编写一个组件,该组件使用 apply_edges()
方法计算边预测。
例如,如果你想为每条边计算一个得分进行边回归,以下代码计算了每条边上关联节点表示的点积。
import dgl.function as fn
class DotProductPredictor(nn.Module):
def forward(self, graph, h):
# h contains the node representations computed from the GNN defined
# in the node classification section (Section 5.1).
with graph.local_scope():
graph.ndata['h'] = h
graph.apply_edges(fn.u_dot_v('h', 'h', 'score'))
return graph.edata['score']
也可以编写一个预测函数,该函数使用 MLP 为每条边预测一个向量。此类向量可以用于进一步的下游任务,例如作为分类分布的 logits。
class MLPPredictor(nn.Module):
def __init__(self, in_features, out_classes):
super().__init__()
self.W = nn.Linear(in_features * 2, out_classes)
def apply_edges(self, edges):
h_u = edges.src['h']
h_v = edges.dst['h']
score = self.W(torch.cat([h_u, h_v], 1))
return {'score': score}
def forward(self, graph, h):
# h contains the node representations computed from the GNN defined
# in the node classification section (Section 5.1).
with graph.local_scope():
graph.ndata['h'] = h
graph.apply_edges(self.apply_edges)
return graph.edata['score']
训练循环
给定节点表示计算模型和边预测模型,我们可以轻松编写一个全图训练循环,在其中计算所有边上的预测。
以下示例以前一节中的 SAGE
作为节点表示计算模型,DotPredictor
作为边预测模型。
class Model(nn.Module):
def __init__(self, in_features, hidden_features, out_features):
super().__init__()
self.sage = SAGE(in_features, hidden_features, out_features)
self.pred = DotProductPredictor()
def forward(self, g, x):
h = self.sage(g, x)
return self.pred(g, h)
在此示例中,我们还假设训练/验证/测试边集是通过边上的布尔掩码标识的。此示例也不包括提前停止和模型保存。
node_features = edge_pred_graph.ndata['feature']
edge_label = edge_pred_graph.edata['label']
train_mask = edge_pred_graph.edata['train_mask']
model = Model(10, 20, 5)
opt = torch.optim.Adam(model.parameters())
for epoch in range(10):
pred = model(edge_pred_graph, node_features)
loss = ((pred[train_mask] - edge_label[train_mask]) ** 2).mean()
opt.zero_grad()
loss.backward()
opt.step()
print(loss.item())
异构图
异构图上的边分类与同构图上的边分类差别不大。如果你希望在某个边类型上进行边分类,你只需计算所有节点类型的节点表示,并使用 apply_edges()
方法在该边类型上进行预测。
例如,要使 DotProductPredictor
在异构图的一个边类型上工作,你只需在 apply_edges
方法中指定边类型。
class HeteroDotProductPredictor(nn.Module):
def forward(self, graph, h, etype):
# h contains the node representations for each edge type computed from
# the GNN for heterogeneous graphs defined in the node classification
# section (Section 5.1).
with graph.local_scope():
graph.ndata['h'] = h # assigns 'h' of all node types in one shot
graph.apply_edges(fn.u_dot_v('h', 'h', 'score'), etype=etype)
return graph.edges[etype].data['score']
你也可以类似地编写一个 HeteroMLPPredictor
。
class HeteroMLPPredictor(nn.Module):
def __init__(self, in_features, out_classes):
super().__init__()
self.W = nn.Linear(in_features * 2, out_classes)
def apply_edges(self, edges):
h_u = edges.src['h']
h_v = edges.dst['h']
score = self.W(torch.cat([h_u, h_v], 1))
return {'score': score}
def forward(self, graph, h, etype):
# h contains the node representations for each edge type computed from
# the GNN for heterogeneous graphs defined in the node classification
# section (Section 5.1).
with graph.local_scope():
graph.ndata['h'] = h # assigns 'h' of all node types in one shot
graph.apply_edges(self.apply_edges, etype=etype)
return graph.edges[etype].data['score']
对单个边类型上的每条边预测得分的端到端模型将如下所示
class Model(nn.Module):
def __init__(self, in_features, hidden_features, out_features, rel_names):
super().__init__()
self.sage = RGCN(in_features, hidden_features, out_features, rel_names)
self.pred = HeteroDotProductPredictor()
def forward(self, g, x, etype):
h = self.sage(g, x)
return self.pred(g, h, etype)
使用模型只需向模型提供一个包含节点类型和特征的字典。
model = Model(10, 20, 5, hetero_graph.etypes)
user_feats = hetero_graph.nodes['user'].data['feature']
item_feats = hetero_graph.nodes['item'].data['feature']
label = hetero_graph.edges['click'].data['label']
train_mask = hetero_graph.edges['click'].data['train_mask']
node_features = {'user': user_feats, 'item': item_feats}
然后训练循环与同构图中的几乎相同。例如,如果你希望预测边类型 click
上的边标签,那么你可以简单地这样做
opt = torch.optim.Adam(model.parameters())
for epoch in range(10):
pred = model(hetero_graph, node_features, 'click')
loss = ((pred[train_mask] - label[train_mask]) ** 2).mean()
opt.zero_grad()
loss.backward()
opt.step()
print(loss.item())
预测异构图上现有边的边类型
有时你可能想预测现有边属于哪种类型。
例如,给定异构图示例,你的任务是给定连接用户和项目的边,预测用户是否会 click
(点击) 或 dislike
(不喜欢) 一个项目。
这是评分预测的简化版本,在推荐文献中很常见。
你可以使用异构图卷积网络来获取节点表示。例如,你仍然可以使用之前定义的 RGCN 来实现此目的。
要预测边的类型,你可以简单地重新利用上面的 HeteroDotProductPredictor
,使其接受另一个图,该图仅包含一个将所有要预测的边类型“合并”的边类型,并为每条边输出每种类型的得分。
在此示例中,你需要一个包含两种节点类型 user
和 item> 的图,以及一个合并了
user
和 item
之间所有边类型(即 click
和 dislike
)的单一边类型。这可以使用以下语法方便地创建
dec_graph = hetero_graph['user', :, 'item']
它返回一个异构图,包含节点类型 user
和 item
,以及合并了它们之间所有边类型(即 click
和 dislike
)的单一边类型。
由于上述语句还返回原始边类型作为名为 dgl.ETYPE
的特征,我们可以将其用作标签。
edge_label = dec_graph.edata[dgl.ETYPE]
给定上述图作为边类型预测模块的输入,你可以将预测模块编写如下。
class HeteroMLPPredictor(nn.Module):
def __init__(self, in_dims, n_classes):
super().__init__()
self.W = nn.Linear(in_dims * 2, n_classes)
def apply_edges(self, edges):
x = torch.cat([edges.src['h'], edges.dst['h']], 1)
y = self.W(x)
return {'score': y}
def forward(self, graph, h):
# h contains the node representations for each edge type computed from
# the GNN for heterogeneous graphs defined in the node classification
# section (Section 5.1).
with graph.local_scope():
graph.ndata['h'] = h # assigns 'h' of all node types in one shot
graph.apply_edges(self.apply_edges)
return graph.edata['score']
结合节点表示模块和边类型预测模块的模型如下
class Model(nn.Module):
def __init__(self, in_features, hidden_features, out_features, rel_names):
super().__init__()
self.sage = RGCN(in_features, hidden_features, out_features, rel_names)
self.pred = HeteroMLPPredictor(out_features, len(rel_names))
def forward(self, g, x, dec_graph):
h = self.sage(g, x)
return self.pred(dec_graph, h)
训练循环就简单地如下所示
model = Model(10, 20, 5, hetero_graph.etypes)
user_feats = hetero_graph.nodes['user'].data['feature']
item_feats = hetero_graph.nodes['item'].data['feature']
node_features = {'user': user_feats, 'item': item_feats}
opt = torch.optim.Adam(model.parameters())
for epoch in range(10):
logits = model(hetero_graph, node_features, dec_graph)
loss = F.cross_entropy(logits, edge_label)
opt.zero_grad()
loss.backward()
opt.step()
print(loss.item())
DGL 提供了图卷积矩阵补全 (Graph Convolutional Matrix Completion) 作为评分预测的一个示例,它通过预测异构图上现有边的类型来表述。在模型实现文件中,节点表示模块称为 GCMCLayer
,边类型预测模块称为 BiDecoder
。它们都比这里描述的设置更复杂。