EGTLayer

class dgl.nn.pytorch.gt.EGTLayer(feat_size, edge_feat_size, num_heads, num_virtual_nodes, dropout=0, attn_dropout=0, activation=ELU(alpha=1.0), edge_update=True)[source]

基类: Module

边缘增强图 Transformer (EGT) 的 EGTLayer,如 `Global Self-Attention as a Replacement for Graph Convolution Reference `<https://arxiv.org/pdf/2108.03348.pdf>`_ 中介绍的那样

参数:
  • feat_size (int) – 节点特征大小。

  • edge_feat_size (int) – 边特征大小。

  • num_heads (int) – 注意力头的数量,feat_size 必须能被其整除。

  • num_virtual_nodes (int) – 虚拟节点的数量。

  • dropout (float, optional) – Dropout 概率。默认值:0.0。

  • attn_dropout (float, optional) – 注意力 dropout 概率。默认值:0.0。

  • activation (callable activation layer, optional) – 激活函数。默认值:nn.ELU()。

  • edge_update (bool, optional) – 是否更新边嵌入。默认值:True。

示例

>>> import torch as th
>>> from dgl.nn import EGTLayer
>>> batch_size = 16
>>> num_nodes = 100
>>> feat_size, edge_feat_size = 128, 32
>>> nfeat = th.rand(batch_size, num_nodes, feat_size)
>>> efeat = th.rand(batch_size, num_nodes, num_nodes, edge_feat_size)
>>> net = EGTLayer(
        feat_size=feat_size,
        edge_feat_size=edge_feat_size,
        num_heads=8,
        num_virtual_nodes=4,
    )
>>> out = net(nfeat, efeat)
forward(nfeat, efeat, mask=None)[source]

前向计算。注意:如果 num_virtual_nodes > 0,则 nfeatefeat 应该使用虚拟节点的嵌入进行填充,而 mask 中虚拟节点对应的值应该填充为 0。填充应放在最前面。

参数:
  • nfeat (torch.Tensor) – 一个 3D 输入张量。形状:(batch_size, N, feat_size),其中 N 是最大节点数和虚拟节点数的总和。

  • efeat (torch.Tensor) – 用于注意力计算和自更新的边嵌入。形状:(batch_size, N, N, edge_feat_size)。

  • mask (torch.Tensor, optional) – 用于避免在无效位置进行计算的注意力掩码,其中有效位置用 0 表示,无效位置用 -inf 表示。形状:(batch_size, N, N)。默认值:None。

返回值:

  • nfeat (torch.Tensor) – 输出节点嵌入。形状:(batch_size, N, feat_size)。

  • efeat (torch.Tensor, optional) – 输出边嵌入。形状:(batch_size, N, N, edge_feat_size)。仅当 edge_update 为 True 时返回。