GroupRevRes

class dgl.nn.pytorch.conv.GroupRevRes(gnn_module, groups=2)[源代码]

基类:Module

用于 GNN 的分组可逆残差连接,如《Training Graph Neural Networks with 1000 Layers》中所介绍的

它将输入节点特征 \(X\) 沿通道维度均匀地划分为 \(C\)\(X_1, X_2, \cdots, X_C\)。此外,它还创建了输入 GNN 模块 \(f_{w1}, \cdots, f_{wC}\)\(C\) 个副本。在前向传播中,每个 GNN 模块仅处理相应的节点特征组。

输出节点表示 \(X^{'}\) 的计算如下。

\[ \begin{align}\begin{aligned}X_0^{'} = \sum_{i=2}^{C}X_i\\X_i^{'} = f_{wi}(X_{i-1}^{'}, g, U) + X_i, i\in\{1,\cdots,C\}\\X^{'} = X_1^{'} \, \Vert \, \ldots \, \Vert \, X_C^{'}\end{aligned}\end{align} \]

其中 \(g\) 是输入图,\(U\) 是任意附加输入参数,如边特征,而 \(\, \Vert \,\) 表示连接。

参数:
  • gnn_module (nn.Module) – 用于消息传递的 GNN 模块。GroupRevRes 将复制该模块 groups-1 次,总共产生 groups 个副本。输入和输出节点表示的大小需要相同。其前向函数需要依次接受一个 DGLGraph 和相关的输入节点特征,可选择后跟额外的参数,如边特征。

  • groups (int, 可选) – 组的数量。

示例

>>> import dgl
>>> import torch
>>> import torch.nn as nn
>>> from dgl.nn import GraphConv, GroupRevRes
>>> class GNNLayer(nn.Module):
...     def __init__(self, feats, dropout=0.2):
...         super(GNNLayer, self).__init__()
...         # Use BatchNorm and dropout to prevent gradient vanishing
...         # In particular if you use a large number of GNN layers
...         self.norm = nn.BatchNorm1d(feats)
...         self.conv = GraphConv(feats, feats)
...         self.dropout = nn.Dropout(dropout)
...
...     def forward(self, g, x):
...         x = self.norm(x)
...         x = self.dropout(x)
...         return self.conv(g, x)
>>> num_nodes = 5
>>> num_edges = 20
>>> feats = 32
>>> groups = 2
>>> g = dgl.rand_graph(num_nodes, num_edges)
>>> x = torch.randn(num_nodes, feats)
>>> conv = GNNLayer(feats // groups)
>>> model = GroupRevRes(conv, groups)
>>> out = model(g, x)
forward(g, x, *args)[源代码]

应用带有分组可逆残差连接的 GNN 模块。

参数:
  • g (DGLGraph) – 图。

  • x (torch.Tensor) – 输入特征,形状为 \((N, D_{in})\),其中 \(D_{in}\) 是输入特征的大小,\(N\) 是节点数量。

  • args – 传递给 gnn_module 的附加参数。

返回:

输出特征,形状为 \((N, D_{in})\)

返回类型:

torch.Tensor