EGTLayer
- class dgl.nn.pytorch.gt.EGTLayer(feat_size, edge_feat_size, num_heads, num_virtual_nodes, dropout=0, attn_dropout=0, activation=ELU(alpha=1.0), edge_update=True)[source]
Bases: Module
EGTLayer for the Edge-augmented Graph Transformer (EGT), as introduced in `Global Self-Attention as a Replacement for Graph Convolution <https://arxiv.org/pdf/2108.03348.pdf>`_.
- Parameters:
feat_size (int) – Node feature size.
edge_feat_size (int) – Edge feature size.
num_heads (int) – Number of attention heads, by which feat_size must be divisible.
num_virtual_nodes (int) – Number of virtual nodes.
dropout (float, optional) – Dropout probability. Default: 0.0.
attn_dropout (float, optional) – Attention dropout probability. Default: 0.0.
activation (callable activation layer, optional) – Activation function. Default: nn.ELU().
edge_update (bool, optional) – Whether to update the edge embeddings. Default: True.
Example
>>> import torch as th
>>> from dgl.nn import EGTLayer

>>> batch_size = 16
>>> num_nodes = 100
>>> feat_size, edge_feat_size = 128, 32
>>> nfeat = th.rand(batch_size, num_nodes, feat_size)
>>> efeat = th.rand(batch_size, num_nodes, num_nodes, edge_feat_size)
>>> net = EGTLayer(
...     feat_size=feat_size,
...     edge_feat_size=edge_feat_size,
...     num_heads=8,
...     num_virtual_nodes=4,
... )
>>> out = net(nfeat, efeat)
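Since edge_update defaults to True, out above holds both node and edge embeddings. A minimal sketch of unpacking it, and of the node-only variant, reusing the tensors above (the variable names are illustrative):

>>> nfeat_out, efeat_out = out  # edge_update=True: node and edge embeddings
>>> net_node_only = EGTLayer(
...     feat_size=feat_size,
...     edge_feat_size=edge_feat_size,
...     num_heads=8,
...     num_virtual_nodes=4,
...     edge_update=False,
... )
>>> nfeat_out = net_node_only(nfeat, efeat)  # only node embeddings returned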
- forward(nfeat, efeat, mask=None)[source]
Forward computation. Note: If num_virtual_nodes > 0, nfeat and efeat should be padded with the embeddings of the virtual nodes, and the entries of mask corresponding to the virtual nodes should be set to 0. The padding should be placed at the beginning, as shown in the sketch after the return values below.
- Parameters:
nfeat (torch.Tensor) – A 3D input tensor. Shape: (batch_size, N, feat_size), where N is the sum of the maximum number of nodes and the number of virtual nodes.
efeat (torch.Tensor) – Edge embeddings used in attention computation and self update. Shape: (batch_size, N, N, edge_feat_size).
mask (torch.Tensor, optional) – The attention mask used to avoid computation at invalid positions, where valid positions are indicated by 0 and invalid positions by -inf. Shape: (batch_size, N, N). Default: None.
- Returns:
nfeat (torch.Tensor) – The output node embeddings. Shape: (batch_size, N, feat_size).
efeat (torch.Tensor, optional) – The output edge embeddings. Shape: (batch_size, N, N, edge_feat_size). Only returned when edge_update is True.
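To complement the note above, here is a minimal sketch of preparing the virtual-node padding and the attention mask. The random virtual-node embeddings and the choice of which slots count as padding are illustrative assumptions, not part of the API:

>>> import torch as th
>>> from dgl.nn import EGTLayer
>>> batch_size, num_nodes, num_virtual_nodes = 16, 100, 2
>>> feat_size, edge_feat_size = 128, 32
>>> N = num_virtual_nodes + num_nodes  # total size after padding
>>> # Prepend the virtual-node embeddings: the padding goes at the beginning.
>>> virtual_nfeat = th.rand(batch_size, num_virtual_nodes, feat_size)
>>> real_nfeat = th.rand(batch_size, num_nodes, feat_size)
>>> nfeat = th.cat([virtual_nfeat, real_nfeat], dim=1)  # (batch_size, N, feat_size)
>>> efeat = th.rand(batch_size, N, N, edge_feat_size)
>>> # Mask: 0 marks valid positions (including virtual nodes), -inf invalid ones.
>>> mask = th.zeros(batch_size, N, N)
>>> mask[:, :, -10:] = float("-inf")  # e.g. the last 10 node slots are padding
>>> net = EGTLayer(
...     feat_size=feat_size,
...     edge_feat_size=edge_feat_size,
...     num_heads=8,
...     num_virtual_nodes=num_virtual_nodes,
... )
>>> nfeat_out, efeat_out = net(nfeat, efeat, mask=mask)

Only the key dimension is masked in this sketch, so that no attention row consists entirely of -inf entries.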