6.6 实现用于小批量训练的自定义 GNN 模块

(中文版)

注意

本教程的内容与同质图情况下的本节内容相似。

如果您熟悉如何为同质图或异质图编写更新整个图的自定义 GNN 模块(参见第 3 章:构建 GNN 模块),那么在 MFG 上进行计算的代码是相似的,只是节点被分为输入节点和输出节点。

例如,请考虑以下自定义图卷积模块代码。请注意,它不一定是最高效的实现之一——它们仅用于示例,展示自定义 GNN 模块可能的样子。

class CustomGraphConv(nn.Module):
    def __init__(self, in_feats, out_feats):
        super().__init__()
        self.W = nn.Linear(in_feats * 2, out_feats)

    def forward(self, g, h):
        with g.local_scope():
            g.ndata['h'] = h
            g.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h_neigh'))
            return self.W(torch.cat([g.ndata['h'], g.ndata['h_neigh']], 1))

如果您有一个用于整图的自定义消息传递 NN 模块,并且希望使其适用于 MFG,则只需按如下方式重写 forward 函数即可。请注意,整图实现中的相应语句已被注释;您可以比较原始语句和新语句。

class CustomGraphConv(nn.Module):
    def __init__(self, in_feats, out_feats):
        super().__init__()
        self.W = nn.Linear(in_feats * 2, out_feats)

    # h is now a pair of feature tensors for input and output nodes, instead of
    # a single feature tensor.
    # def forward(self, g, h):
    def forward(self, block, h):
        # with g.local_scope():
        with block.local_scope():
            # g.ndata['h'] = h
            h_src = h
            h_dst = h[:block.number_of_dst_nodes()]
            block.srcdata['h'] = h_src
            block.dstdata['h'] = h_dst

            # g.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h_neigh'))
            block.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h_neigh'))

            # return self.W(torch.cat([g.ndata['h'], g.ndata['h_neigh']], 1))
            return self.W(torch.cat(
                [block.dstdata['h'], block.dstdata['h_neigh']], 1))

通常,您需要执行以下操作才能使您的 NN 模块适用于 MFG。

异质图

对于异质图,编写自定义 GNN 模块的方式类似。例如,考虑以下适用于整图的模块。

class CustomHeteroGraphConv(nn.Module):
    def __init__(self, g, in_feats, out_feats):
        super().__init__()
        self.Ws = nn.ModuleDict()
        for etype in g.canonical_etypes:
            utype, _, vtype = etype
            self.Ws[etype] = nn.Linear(in_feats[utype], out_feats[vtype])
        for ntype in g.ntypes:
            self.Vs[ntype] = nn.Linear(in_feats[ntype], out_feats[ntype])

    def forward(self, g, h):
        with g.local_scope():
            for ntype in g.ntypes:
                g.nodes[ntype].data['h_dst'] = self.Vs[ntype](h[ntype])
                g.nodes[ntype].data['h_src'] = h[ntype]
            for etype in g.canonical_etypes:
                utype, _, vtype = etype
                g.update_all(
                    fn.copy_u('h_src', 'm'), fn.mean('m', 'h_neigh'),
                    etype=etype)
                g.nodes[vtype].data['h_dst'] = g.nodes[vtype].data['h_dst'] + \
                    self.Ws[etype](g.nodes[vtype].data['h_neigh'])
            return {ntype: g.nodes[ntype].data['h_dst'] for ntype in g.ntypes}

对于 CustomHeteroGraphConv,原则是根据特征是用于输入还是输出,将 g.nodes 替换为 g.srcnodesg.dstnodes

class CustomHeteroGraphConv(nn.Module):
    def __init__(self, g, in_feats, out_feats):
        super().__init__()
        self.Ws = nn.ModuleDict()
        for etype in g.canonical_etypes:
            utype, _, vtype = etype
            self.Ws[etype] = nn.Linear(in_feats[utype], out_feats[vtype])
        for ntype in g.ntypes:
            self.Vs[ntype] = nn.Linear(in_feats[ntype], out_feats[ntype])

    def forward(self, g, h):
        with g.local_scope():
            for ntype in g.ntypes:
                h_src, h_dst = h[ntype]
                g.dstnodes[ntype].data['h_dst'] = self.Vs[ntype](h[ntype])
                g.srcnodes[ntype].data['h_src'] = h[ntype]
            for etype in g.canonical_etypes:
                utype, _, vtype = etype
                g.update_all(
                    fn.copy_u('h_src', 'm'), fn.mean('m', 'h_neigh'),
                    etype=etype)
                g.dstnodes[vtype].data['h_dst'] = \
                    g.dstnodes[vtype].data['h_dst'] + \
                    self.Ws[etype](g.dstnodes[vtype].data['h_neigh'])
            return {ntype: g.dstnodes[ntype].data['h_dst']
                    for ntype in g.ntypes}

编写适用于同质图、二部图和 MFG 的模块

DGL 中的所有消息传递模块都适用于同质图、单向二部图(具有两种节点类型和一种边类型)以及具有一种边类型的 MFG。本质上,内置 DGL 神经网络模块的输入图和特征必须满足以下任一情况。

  • 如果输入特征是一对张量,则输入图必须是单向二部图。

  • 如果输入特征是单个张量且输入图是 MFG,DGL 将自动将输出节点的特征设置为输入节点特征的前几行。

  • 如果输入特征必须是单个张量且输入图不是 MFG,则输入图必须是同质图。

例如,以下是 dgl.nn.pytorch.SAGEConv(也在 MXNet 和 Tensorflow 中可用)的 PyTorch 实现的简化版本(去除了归一化,仅处理均值聚合等)。

import dgl.function as fn
class SAGEConv(nn.Module):
    def __init__(self, in_feats, out_feats):
        super().__init__()
        self.W = nn.Linear(in_feats * 2, out_feats)

    def forward(self, g, h):
        if isinstance(h, tuple):
            h_src, h_dst = h
        elif g.is_block:
            h_src = h
            h_dst = h[:g.number_of_dst_nodes()]
        else:
            h_src = h_dst = h

        g.srcdata['h'] = h_src
        g.dstdata['h'] = h_dst
        g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h_neigh'))
        return F.relu(
            self.W(torch.cat([g.dstdata['h'], g.dstdata['h_neigh']], 1)))

第 3 章:构建 GNN 模块 还提供了 dgl.nn.pytorch.SAGEConv 的详细介绍,该模块适用于单向二部图、同质图和 MFG。