NNConv

class dgl.nn.pytorch.conv.NNConv(in_feats, out_feats, edge_func, aggregator_type='mean', residual=False, bias=True)[source]

Bases: Module

来自 Neural Message Passing for Quantum Chemistry 的图卷积层

\[h_{i}^{l+1} = h_{i}^{l} + \mathrm{aggregate}\left(\left\{ f_\Theta (e_{ij}) \cdot h_j^{l}, j\in \mathcal{N}(i) \right\}\right)\]

其中 \(e_{ij}\) 是边特征,\(f_\Theta\) 是一个具有可学习参数的函数。

参数:
  • in_feats (int) – 输入特征大小;即 \(h_j^{(l)}\) 的维数。NNConv 可以应用于同构图和单向二部图。如果该层应用于单向二部图,in_feats 指定源节点和目标节点的输入特征大小。如果给定一个标量,则源节点和目标节点的特征大小将取相同的值。

  • out_feats (int) – 输出特征大小;即 \(h_i^{(l+1)}\) 的维数。

  • edge_func (callable activation function/layer) – 将每个边特征映射到形状为 (in_feats * out_feats) 的向量作为权重来计算消息。也是公式中的 \(f_\Theta\)

  • aggregator_type (str) – 要使用的聚合器类型(sum, meanmax)。

  • residual (bool, optional) – 如果为 True,则使用残差连接。默认为:False

  • bias (bool, optional) – 如果为 True,则向输出添加可学习的偏置。默认为:True

示例

>>> import dgl
>>> import numpy as np
>>> import torch as th
>>> from dgl.nn import NNConv
>>> # Case 1: Homogeneous graph
>>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
>>> g = dgl.add_self_loop(g)
>>> feat = th.ones(6, 10)
>>> lin = th.nn.Linear(5, 20)
>>> def edge_func(efeat):
...     return lin(efeat)
>>> efeat = th.ones(6+6, 5)
>>> conv = NNConv(10, 2, edge_func, 'mean')
>>> res = conv(g, feat, efeat)
>>> res
tensor([[-1.5243, -0.2719],
        [-1.5243, -0.2719],
        [-1.5243, -0.2719],
        [-1.5243, -0.2719],
        [-1.5243, -0.2719],
        [-1.5243, -0.2719]], grad_fn=<AddBackward0>)
>>> # Case 2: Unidirectional bipartite graph
>>> u = [0, 1, 0, 0, 1]
>>> v = [0, 1, 2, 3, 2]
>>> g = dgl.heterograph({('_N', '_E', '_N'):(u, v)})
>>> u_feat = th.tensor(np.random.rand(2, 10).astype(np.float32))
>>> v_feat = th.tensor(np.random.rand(4, 10).astype(np.float32))
>>> conv = NNConv(10, 2, edge_func, 'mean')
>>> efeat = th.ones(5, 5)
>>> res = conv(g, (u_feat, v_feat), efeat)
>>> res
tensor([[-0.6568,  0.5042],
        [ 0.9089, -0.5352],
        [ 0.1261, -0.0155],
        [-0.6568,  0.5042]], grad_fn=<AddBackward0>)
forward(graph, feat, efeat)[source]

计算 MPNN 图卷积层。

参数:
  • graph (DGLGraph) – 图。

  • feat (torch.Tensor or pair of torch.Tensor) – 输入特征,形状为 \((N, D_{in})\),其中 \(N\) 是图中的节点数,\(D_{in}\) 是输入特征大小。

  • efeat (torch.Tensor) – 边特征,形状为 \((E, *)\),应符合 edge_func 的输入形状要求。\(E\) 是图中的边数。

返回:

输出特征,形状为 \((N, D_{out})\),其中 \(D_{out}\) 是输出特征大小。

返回类型:

torch.Tensor

reset_parameters()[source]

描述

重新初始化可学习参数。

注意

模型参数使用 Glorot 均匀初始化,偏置初始化为零。