3.2 DGL NN 模块 前向函数

(中文版)

在 NN 模块中,forward() 函数执行实际的消息传递和计算。与 PyTorch 的 NN 模块通常接收张量作为参数不同,DGL NN 模块额外接收一个参数 dgl.DGLGraphforward() 函数的工作可以分为三个部分

  • 图检查和图类型规范。

  • 消息传递。

  • 特征更新。

本节其余部分将深入探讨 SAGEConv 示例中的 forward() 函数。

图检查和图类型规范

def forward(self, graph, feat):
    with graph.local_scope():
        # Specify graph type then expand input feature according to graph type
        feat_src, feat_dst = expand_as_pair(feat, graph)

forward() 需要处理许多可能导致计算和消息传递中出现无效值的输入边界情况。像 GraphConv 这样的卷积模块中的一个典型检查是验证输入图没有入度为 0 的节点。当一个节点的入度为 0 时,其 mailbox 将为空,并且归约函数将产生全零值。这可能导致模型性能的静默下降。然而,在 SAGEConv 模块中,聚合表示将与原始节点特征连接,forward() 的输出不会全为零。在这种情况下不需要进行此类检查。

DGL NN 模块应可重用于不同类型的图输入,包括:同构图、异构图 (1.5 异构图)、子图块 (第 6 章: 在大型图上进行随机训练)。

SAGEConv 的数学公式如下:

\[h_{\mathcal{N}(dst)}^{(l+1)} = \mathrm{aggregate} \left(\{h_{src}^{l}, \forall src \in \mathcal{N}(dst) \}\right)\]
\[h_{dst}^{(l+1)} = \sigma \left(W \cdot \mathrm{concat} (h_{dst}^{l}, h_{\mathcal{N}(dst)}^{l+1}) + b \right)\]
\[h_{dst}^{(l+1)} = \mathrm{norm}(h_{dst}^{(l+1)})\]

需要根据图的类型指定源节点特征 feat_src 和目标节点特征 feat_dstexpand_as_pair() 是一个函数,用于指定图的类型并将 feat 扩展为 feat_srcfeat_dst。此函数的详细信息如下所示。

def expand_as_pair(input_, g=None):
    if isinstance(input_, tuple):
        # Bipartite graph case
        return input_
    elif g is not None and g.is_block:
        # Subgraph block case
        if isinstance(input_, Mapping):
            input_dst = {
                k: F.narrow_row(v, 0, g.number_of_dst_nodes(k))
                for k, v in input_.items()}
        else:
            input_dst = F.narrow_row(input_, 0, g.number_of_dst_nodes())
        return input_, input_dst
    else:
        # Homogeneous graph case
        return input_, input_

对于同构完整图训练,源节点和目标节点是相同的。它们是图中的所有节点。

对于异构情况,图可以被分割成几个二分图,每个关系一个二分图。关系表示为 (src_type, edge_type, dst_dtype)。当它识别出输入特征 feat 是一个元组时,它会将图视为二分的。元组的第一个元素将是源节点特征,第二个元素将是目标节点特征。

在 mini-batch 训练中,计算应用于基于一批目标节点采样的子图。在 DGL 中,该子图被称为 block。在 block 创建阶段,dst nodes 位于节点列表的前面。可以通过索引 [0:g.number_of_dst_nodes()] 找到 feat_dst

在确定 feat_srcfeat_dst 后,上述三种图类型的计算是相同的。

消息传递和归约

import dgl.function as fn
import torch.nn.functional as F
from dgl.utils import check_eq_shape

if self._aggre_type == 'mean':
    graph.srcdata['h'] = feat_src
    graph.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'neigh'))
    h_neigh = graph.dstdata['neigh']
elif self._aggre_type == 'gcn':
    check_eq_shape(feat)
    graph.srcdata['h'] = feat_src
    graph.dstdata['h'] = feat_dst
    graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'neigh'))
    # divide in_degrees
    degs = graph.in_degrees().to(feat_dst)
    h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.unsqueeze(-1) + 1)
elif self._aggre_type == 'pool':
    graph.srcdata['h'] = F.relu(self.fc_pool(feat_src))
    graph.update_all(fn.copy_u('h', 'm'), fn.max('m', 'neigh'))
    h_neigh = graph.dstdata['neigh']
else:
    raise KeyError('Aggregator type {} not recognized.'.format(self._aggre_type))

# GraphSAGE GCN does not require fc_self.
if self._aggre_type == 'gcn':
    rst = self.fc_neigh(h_neigh)
else:
    rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)

代码实际执行消息传递和归约计算。这部分代码因模块而异。请注意,上述代码中的所有消息传递都使用 update_all() API 和 内置 消息/归约函数实现,以充分利用 DGL 的性能优化,如 2.2 编写高效的消息传递代码 中所述。

归约后更新输出特征

# activation
if self.activation is not None:
    rst = self.activation(rst)
# normalization
if self.norm is not None:
    rst = self.norm(rst)
return rst

forward() 函数的最后一部分是在 归约函数 后更新特征。常见的更新操作是根据对象构造阶段设置的选项应用激活函数和归一化。