JumpingKnowledge
- class dgl.nn.pytorch.utils.JumpingKnowledge(mode='cat', in_feats=None, num_layers=None)[source]
Bases: Module
The Jumping Knowledge aggregation module from Representation Learning on Graphs with Jumping Knowledge Networks.
It aggregates the output representations of multiple GNN layers with
**concatenation**

\[h_i^{(1)} \, \Vert \, \ldots \, \Vert \, h_i^{(T)}\]

or **max pooling**

\[\max \left( h_i^{(1)}, \ldots, h_i^{(T)} \right)\]

or **LSTM**

\[\sum_{t=1}^T \alpha_i^{(t)} h_i^{(t)}\]

where the attention scores \(\alpha_i^{(t)}\) are obtained from a BiLSTM.
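The three modes boil down to simple tensor operations over the stacked per-layer representations. The snippet below is a rough, self-contained sketch of what each mode computes for two layer outputs of shape (3, 4); the BiLSTM hidden size and the scoring layer in the `'lstm'` branch are illustrative choices, not the exact configuration used inside this module.

>>> import torch as th
>>> feat_list = [th.zeros(3, 4), th.ones(3, 4)]        # T = 2 layer outputs, (N, D) each
>>> # 'cat': concatenate along the feature dimension
>>> th.cat(feat_list, dim=-1).shape
torch.Size([3, 8])
>>> # 'max': element-wise maximum across layers
>>> th.stack(feat_list, dim=0).max(dim=0)[0].shape
torch.Size([3, 4])
>>> # 'lstm': a BiLSTM scores each of the T representations per node; the softmax
>>> # of those scores gives the attention weights alpha in the sum above
>>> stacked = th.stack(feat_list, dim=1)               # (N, T, D)
>>> lstm = th.nn.LSTM(4, 2, bidirectional=True, batch_first=True)
>>> scores = th.nn.Linear(4, 1)(lstm(stacked)[0])      # (N, T, 1), one score per layer
>>> alpha = th.softmax(scores, dim=1)                  # normalize over the T layers
>>> (alpha * stacked).sum(dim=1).shape                 # attention-weighted sum
torch.Size([3, 4])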
- Parameters:
  - **mode** (str, optional) – The aggregation mode to apply: `'cat'`, `'max'`, or `'lstm'`. Default: `'cat'`.
  - **in_feats** (int, optional) – The output representation size of a single GNN layer. Required if `mode` is `'lstm'`.
  - **num_layers** (int, optional) – The number of GNN layers whose outputs are aggregated. Required if `mode` is `'lstm'`.
Examples
>>> import dgl
>>> import torch as th
>>> from dgl.nn import JumpingKnowledge

>>> # Output representations of two GNN layers
>>> num_nodes = 3
>>> in_feats = 4
>>> feat_list = [th.zeros(num_nodes, in_feats), th.ones(num_nodes, in_feats)]

>>> # Case1
>>> model = JumpingKnowledge()
>>> model(feat_list).shape
torch.Size([3, 8])

>>> # Case2
>>> model = JumpingKnowledge(mode='max')
>>> model(feat_list).shape
torch.Size([3, 4])
>>> # Case3
>>> model = JumpingKnowledge(mode='lstm', in_feats=in_feats, num_layers=len(feat_list))
>>> model(feat_list).shape
torch.Size([3, 4])
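As a usage sketch, the module is typically fed the per-layer outputs of a stacked GNN. The example below assumes a toy 3-node graph and two `GraphConv` layers; it is one possible way to wire the module up, not a prescribed pattern.

>>> import dgl
>>> import torch as th
>>> from dgl.nn import GraphConv, JumpingKnowledge
>>> g = dgl.add_self_loop(dgl.graph(([0, 1, 2], [1, 2, 0])))   # toy 3-node graph
>>> convs = [GraphConv(4, 4), GraphConv(4, 4)]                 # two stacked GNN layers
>>> jk = JumpingKnowledge(mode='max')
>>> h = th.randn(3, 4)
>>> feat_list = []
>>> for conv in convs:
...     h = th.relu(conv(g, h))
...     feat_list.append(h)                                    # collect every layer's output
>>> jk(feat_list).shape
torch.Size([3, 4])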