6.6 实现用于小批量训练的自定义 GNN 模块
注意
本教程的内容与同质图情况下的本节内容相似。
如果您熟悉如何为同质图或异质图编写更新整个图的自定义 GNN 模块(参见第 3 章:构建 GNN 模块),那么在 MFG 上进行计算的代码是相似的,只是节点被分为输入节点和输出节点。
例如,请考虑以下自定义图卷积模块代码。请注意,它不一定是最高效的实现之一——它们仅用于示例,展示自定义 GNN 模块可能的样子。
class CustomGraphConv(nn.Module):
def __init__(self, in_feats, out_feats):
super().__init__()
self.W = nn.Linear(in_feats * 2, out_feats)
def forward(self, g, h):
with g.local_scope():
g.ndata['h'] = h
g.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h_neigh'))
return self.W(torch.cat([g.ndata['h'], g.ndata['h_neigh']], 1))
如果您有一个用于整图的自定义消息传递 NN 模块,并且希望使其适用于 MFG,则只需按如下方式重写 forward 函数即可。请注意,整图实现中的相应语句已被注释;您可以比较原始语句和新语句。
class CustomGraphConv(nn.Module):
def __init__(self, in_feats, out_feats):
super().__init__()
self.W = nn.Linear(in_feats * 2, out_feats)
# h is now a pair of feature tensors for input and output nodes, instead of
# a single feature tensor.
# def forward(self, g, h):
def forward(self, block, h):
# with g.local_scope():
with block.local_scope():
# g.ndata['h'] = h
h_src = h
h_dst = h[:block.number_of_dst_nodes()]
block.srcdata['h'] = h_src
block.dstdata['h'] = h_dst
# g.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h_neigh'))
block.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h_neigh'))
# return self.W(torch.cat([g.ndata['h'], g.ndata['h_neigh']], 1))
return self.W(torch.cat(
[block.dstdata['h'], block.dstdata['h_neigh']], 1))
通常,您需要执行以下操作才能使您的 NN 模块适用于 MFG。
通过对输入特征进行切片来获取输出节点的特征,切片的前几行就是输出节点的特征。行数可以通过
block.number_of_dst_nodes
获取。如果原始图只有一种节点类型,则将
g.ndata
替换为block.srcdata
(用于输入节点特征)或block.dstdata
(用于输出节点特征)。如果原始图有多种节点类型,则将
g.nodes
替换为block.srcnodes
(用于输入节点特征)或block.dstnodes
(用于输出节点特征)。将
g.num_nodes
替换为block.number_of_src_nodes
或block.number_of_dst_nodes
,分别用于获取输入节点数或输出节点数。
异质图
对于异质图,编写自定义 GNN 模块的方式类似。例如,考虑以下适用于整图的模块。
class CustomHeteroGraphConv(nn.Module):
def __init__(self, g, in_feats, out_feats):
super().__init__()
self.Ws = nn.ModuleDict()
for etype in g.canonical_etypes:
utype, _, vtype = etype
self.Ws[etype] = nn.Linear(in_feats[utype], out_feats[vtype])
for ntype in g.ntypes:
self.Vs[ntype] = nn.Linear(in_feats[ntype], out_feats[ntype])
def forward(self, g, h):
with g.local_scope():
for ntype in g.ntypes:
g.nodes[ntype].data['h_dst'] = self.Vs[ntype](h[ntype])
g.nodes[ntype].data['h_src'] = h[ntype]
for etype in g.canonical_etypes:
utype, _, vtype = etype
g.update_all(
fn.copy_u('h_src', 'm'), fn.mean('m', 'h_neigh'),
etype=etype)
g.nodes[vtype].data['h_dst'] = g.nodes[vtype].data['h_dst'] + \
self.Ws[etype](g.nodes[vtype].data['h_neigh'])
return {ntype: g.nodes[ntype].data['h_dst'] for ntype in g.ntypes}
对于 CustomHeteroGraphConv
,原则是根据特征是用于输入还是输出,将 g.nodes
替换为 g.srcnodes
或 g.dstnodes
。
class CustomHeteroGraphConv(nn.Module):
def __init__(self, g, in_feats, out_feats):
super().__init__()
self.Ws = nn.ModuleDict()
for etype in g.canonical_etypes:
utype, _, vtype = etype
self.Ws[etype] = nn.Linear(in_feats[utype], out_feats[vtype])
for ntype in g.ntypes:
self.Vs[ntype] = nn.Linear(in_feats[ntype], out_feats[ntype])
def forward(self, g, h):
with g.local_scope():
for ntype in g.ntypes:
h_src, h_dst = h[ntype]
g.dstnodes[ntype].data['h_dst'] = self.Vs[ntype](h[ntype])
g.srcnodes[ntype].data['h_src'] = h[ntype]
for etype in g.canonical_etypes:
utype, _, vtype = etype
g.update_all(
fn.copy_u('h_src', 'm'), fn.mean('m', 'h_neigh'),
etype=etype)
g.dstnodes[vtype].data['h_dst'] = \
g.dstnodes[vtype].data['h_dst'] + \
self.Ws[etype](g.dstnodes[vtype].data['h_neigh'])
return {ntype: g.dstnodes[ntype].data['h_dst']
for ntype in g.ntypes}
编写适用于同质图、二部图和 MFG 的模块
DGL 中的所有消息传递模块都适用于同质图、单向二部图(具有两种节点类型和一种边类型)以及具有一种边类型的 MFG。本质上,内置 DGL 神经网络模块的输入图和特征必须满足以下任一情况。
如果输入特征是一对张量,则输入图必须是单向二部图。
如果输入特征是单个张量且输入图是 MFG,DGL 将自动将输出节点的特征设置为输入节点特征的前几行。
如果输入特征必须是单个张量且输入图不是 MFG,则输入图必须是同质图。
例如,以下是 dgl.nn.pytorch.SAGEConv
(也在 MXNet 和 Tensorflow 中可用)的 PyTorch 实现的简化版本(去除了归一化,仅处理均值聚合等)。
import dgl.function as fn
class SAGEConv(nn.Module):
def __init__(self, in_feats, out_feats):
super().__init__()
self.W = nn.Linear(in_feats * 2, out_feats)
def forward(self, g, h):
if isinstance(h, tuple):
h_src, h_dst = h
elif g.is_block:
h_src = h
h_dst = h[:g.number_of_dst_nodes()]
else:
h_src = h_dst = h
g.srcdata['h'] = h_src
g.dstdata['h'] = h_dst
g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h_neigh'))
return F.relu(
self.W(torch.cat([g.dstdata['h'], g.dstdata['h_neigh']], 1)))
第 3 章:构建 GNN 模块 还提供了 dgl.nn.pytorch.SAGEConv
的详细介绍,该模块适用于单向二部图、同质图和 MFG。