3.1 DGL NN 模块构建函数

(中文版)

构建函数执行以下步骤

  1. 设置选项。

  2. 注册可学习参数或子模块。

  3. 重置参数。

import torch.nn as nn

from dgl.utils import expand_as_pair

class SAGEConv(nn.Module):
    def __init__(self,
                 in_feats,
                 out_feats,
                 aggregator_type,
                 bias=True,
                 norm=None,
                 activation=None):
        super(SAGEConv, self).__init__()

        self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
        self._out_feats = out_feats
        self._aggre_type = aggregator_type
        self.norm = norm
        self.activation = activation

在构建函数中,首先需要设置数据维度。对于一般的 PyTorch 模块,维度通常是输入维度、输出维度和隐藏层维度。对于图神经网络,输入维度可以分为源节点维度和目标节点维度。

除了数据维度,图神经网络的一个典型选项是聚合类型(self._aggre_type)。聚合类型决定了不同边上的消息如何为一个特定的目标节点进行聚合。常用的聚合类型包括 meansummaxmin。有些模块可能会应用更复杂的聚合,例如 lstm

这里的 norm 是一个用于特征归一化的可调用函数。在 SAGEConv 论文中,这种归一化可以是 L2 归一化:\(h_v = h_v / \lVert h_v \rVert_2\)

# aggregator type: mean, pool, lstm, gcn
if aggregator_type not in ['mean', 'pool', 'lstm', 'gcn']:
    raise KeyError('Aggregator type {} not supported.'.format(aggregator_type))
if aggregator_type == 'pool':
    self.fc_pool = nn.Linear(self._in_src_feats, self._in_src_feats)
if aggregator_type == 'lstm':
    self.lstm = nn.LSTM(self._in_src_feats, self._in_src_feats, batch_first=True)
if aggregator_type in ['mean', 'pool', 'lstm']:
    self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=bias)
self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=bias)
self.reset_parameters()

注册参数和子模块。在 SAGEConv 中,子模块根据聚合类型而异。这些模块是纯粹的 PyTorch nn 模块,例如 nn.Linearnn.LSTM 等。在构建函数的最后,通过调用 reset_parameters() 应用权重初始化。

def reset_parameters(self):
    """Reinitialize learnable parameters."""
    gain = nn.init.calculate_gain('relu')
    if self._aggre_type == 'pool':
        nn.init.xavier_uniform_(self.fc_pool.weight, gain=gain)
    if self._aggre_type == 'lstm':
        self.lstm.reset_parameters()
    if self._aggre_type != 'gcn':
        nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)
    nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)