GroupRevRes
- class dgl.nn.pytorch.conv.GroupRevRes(gnn_module, groups=2)[源代码]
基类:
Module
用于 GNN 的分组可逆残差连接,如《Training Graph Neural Networks with 1000 Layers》中所介绍的
它将输入节点特征 \(X\) 沿通道维度均匀地划分为 \(C\) 组 \(X_1, X_2, \cdots, X_C\)。此外,它还创建了输入 GNN 模块 \(f_{w1}, \cdots, f_{wC}\) 的 \(C\) 个副本。在前向传播中,每个 GNN 模块仅处理相应的节点特征组。
输出节点表示 \(X^{'}\) 的计算如下。
\[ \begin{align}\begin{aligned}X_0^{'} = \sum_{i=2}^{C}X_i\\X_i^{'} = f_{wi}(X_{i-1}^{'}, g, U) + X_i, i\in\{1,\cdots,C\}\\X^{'} = X_1^{'} \, \Vert \, \ldots \, \Vert \, X_C^{'}\end{aligned}\end{align} \]其中 \(g\) 是输入图,\(U\) 是任意附加输入参数,如边特征,而 \(\, \Vert \,\) 表示连接。
- 参数:
gnn_module (nn.Module) – 用于消息传递的 GNN 模块。
GroupRevRes
将复制该模块groups
-1 次,总共产生groups
个副本。输入和输出节点表示的大小需要相同。其前向函数需要依次接受一个 DGLGraph 和相关的输入节点特征,可选择后跟额外的参数,如边特征。groups (int, 可选) – 组的数量。
示例
>>> import dgl >>> import torch >>> import torch.nn as nn >>> from dgl.nn import GraphConv, GroupRevRes
>>> class GNNLayer(nn.Module): ... def __init__(self, feats, dropout=0.2): ... super(GNNLayer, self).__init__() ... # Use BatchNorm and dropout to prevent gradient vanishing ... # In particular if you use a large number of GNN layers ... self.norm = nn.BatchNorm1d(feats) ... self.conv = GraphConv(feats, feats) ... self.dropout = nn.Dropout(dropout) ... ... def forward(self, g, x): ... x = self.norm(x) ... x = self.dropout(x) ... return self.conv(g, x)
>>> num_nodes = 5 >>> num_edges = 20 >>> feats = 32 >>> groups = 2 >>> g = dgl.rand_graph(num_nodes, num_edges) >>> x = torch.randn(num_nodes, feats) >>> conv = GNNLayer(feats // groups) >>> model = GroupRevRes(conv, groups) >>> out = model(g, x)