GatedGraphConv

class dgl.nn.pytorch.conv.GatedGraphConv(in_feats, out_feats, n_steps, n_etypes, bias=True)[源代码]

基类: Module

来自 Gated Graph Sequence Neural Networks 的 Gated Graph Convolution 层

\[ \begin{align}\begin{aligned}h_{i}^{0} &= [ x_i \| \mathbf{0} ]\\a_{i}^{t} &= \sum_{j\in\mathcal{N}(i)} W_{e_{ij}} h_{j}^{t}\\h_{i}^{t+1} &= \mathrm{GRU}(a_{i}^{t}, h_{i}^{t})\end{aligned}\end{align} \]
参数:
  • in_feats (int) – 输入特征大小;即 \(x_i\) 的维度数量。

  • out_feats (int) – 输出特征大小;即 \(h_i^{(t+1)}\) 的维度数量。

  • n_steps (int) – 循环步数;即上述公式中的 \(t\)

  • n_etypes (int) – 边类型数量。

  • bias (bool) – 如果为 True,则向输出添加可学习的偏置。默认值: True

示例

>>> import dgl
>>> import numpy as np
>>> import torch as th
>>> from dgl.nn import GatedGraphConv
>>>
>>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
>>> feat = th.ones(6, 10)
>>> conv = GatedGraphConv(10, 10, 2, 3)
>>> etype = th.tensor([0,1,2,0,1,2])
>>> res = conv(g, feat, etype)
>>> res
tensor([[ 0.4652,  0.4458,  0.5169,  0.4126,  0.4847,  0.2303,  0.2757,  0.7721,
        0.0523,  0.0857],
        [ 0.0832,  0.1388, -0.5643,  0.7053, -0.2524, -0.3847,  0.7587,  0.8245,
        0.9315,  0.4063],
        [ 0.6340,  0.4096,  0.7692,  0.2125,  0.2106,  0.4542, -0.0580,  0.3364,
        -0.1376,  0.4948],
        [ 0.5551,  0.7946,  0.6220,  0.8058,  0.5711,  0.3063, -0.5454,  0.2272,
        -0.6931, -0.1607],
        [ 0.2644,  0.2469, -0.6143,  0.6008, -0.1516, -0.3781,  0.5878,  0.7993,
        0.9241,  0.1835],
        [ 0.6393,  0.3447,  0.3893,  0.4279,  0.3342,  0.3809,  0.0406,  0.5030,
        0.1342,  0.0425]], grad_fn=<AddBackward0>)
forward(graph, feat, etypes=None)[源代码]

描述

计算 Gated Graph Convolution 层。

参数 graph:

图。

类型 graph:

DGLGraph

参数 feat:

输入特征,形状为 \((N, D_{in})\),其中 \(N\) 是图的节点数,\(D_{in}\) 是输入特征大小。

类型 feat:

torch.Tensor

参数 etypes:

边类型张量,形状为 \((E,)\),其中 \(E\) 是图的边数。当只有一种边类型时,此参数可以省略。

类型 etypes:

torch.LongTensor 或 None

返回:

输出特征,形状为 \((N, D_{out})\),其中 \(D_{out}\) 是输出特征大小。

返回类型:

torch.Tensor

reset_parameters()[源代码]

描述

重新初始化可学习参数。

注意

模型参数使用 Glorot 均匀初始化,偏置初始化为零。