GlobalAttentionPooling

class dgl.nn.pytorch.glob.GlobalAttentionPooling(gate_nn, feat_nn=None)[源代码]

基类: Module

来自 Gated Graph Sequence Neural Networks 的全局注意力池化。

\[r^{(i)} = \sum_{k=1}^{N_i}\mathrm{softmax}\left(f_{gate} \left(x^{(i)}_k\right)\right) f_{feat}\left(x^{(i)}_k\right)\]
参数:
  • gate_nn (torch.nn.Module) – 用于计算每个特征的注意力分数的神经网络。

  • feat_nn (torch.nn.Module, 可选) – 在将每个特征与注意力分数结合之前应用于它们的神经网络。

示例

以下示例使用 PyTorch 后端。

>>> import dgl
>>> import torch as th
>>> from dgl.nn import GlobalAttentionPooling
>>>
>>> g1 = dgl.rand_graph(3, 4)  # g1 is a random graph with 3 nodes and 4 edges
>>> g1_node_feats = th.rand(3, 5)  # feature size is 5
>>> g1_node_feats
tensor([[0.8948, 0.0699, 0.9137, 0.7567, 0.3637],
        [0.8137, 0.8938, 0.8377, 0.4249, 0.6118],
        [0.5197, 0.9030, 0.6825, 0.5725, 0.4755]])
>>>
>>> g2 = dgl.rand_graph(4, 6)  # g2 is a random graph with 4 nodes and 6 edges
>>> g2_node_feats = th.rand(4, 5)  # feature size is 5
>>> g2_node_feats
tensor([[0.2053, 0.2426, 0.4111, 0.9028, 0.5658],
        [0.5278, 0.6365, 0.9990, 0.2351, 0.8945],
        [0.3134, 0.0580, 0.4349, 0.7949, 0.3891],
        [0.0142, 0.2709, 0.3330, 0.8521, 0.6925]])
>>>
>>> gate_nn = th.nn.Linear(5, 1)  # the gate layer that maps node feature to scalar
>>> gap = GlobalAttentionPooling(gate_nn)  # create a Global Attention Pooling layer

情况 1:输入单个图

>>> gap(g1, g1_node_feats)
tensor([[0.7410, 0.6032, 0.8111, 0.5942, 0.4762]],
       grad_fn=<SegmentReduceBackward>)

情况 2:输入一批图

构建一批 DGL 图并将所有图的节点特征连接到一个张量中。

>>> batch_g = dgl.batch([g1, g2])
>>> batch_f = th.cat([g1_node_feats, g2_node_feats], 0)
>>>
>>> gap(batch_g, batch_f)
tensor([[0.7410, 0.6032, 0.8111, 0.5942, 0.4762],
        [0.2417, 0.2743, 0.5054, 0.7356, 0.6146]],
       grad_fn=<SegmentReduceBackward>)

注意

请参阅我们的 GGNN 示例,了解如何使用 GatedGraphConv 和 GlobalAttentionPooling 层构建可以解决数独的图神经网络。

forward(graph, feat, get_attention=False)[源代码]

计算全局注意力池化。

参数:
  • graph (DGLGraph) – 一个 DGLGraph 或一批 DGLGraph。

  • feat (torch.Tensor) – 输入节点特征,形状为 \((N, D)\),其中 \(N\) 是图中的节点数,\(D\) 是特征的大小。

  • get_attention (bool, 可选) – 是否返回来自 gate_nn 的注意力值。默认为 False。

返回:

  • torch.Tensor – 输出特征,形状为 \((B, D)\),其中 \(B\) 指批量大小。

  • torch.Tensor, 可选 – 注意力值,形状为 \((N, 1)\),其中 \(N\) 是图中的节点数。仅当 get_attentionTrue 时返回此值。