BiasedMHA

class dgl.nn.pytorch.gt.BiasedMHA(feat_size, num_heads, bias=True, attn_bias_type='add', attn_drop=0.1)[source]

基类:Module

带图注意力偏置的密集多头注意力模块。

计算节点之间的注意力,其中注意力偏置来自图结构,如论文 Do Transformers Really Perform Bad for Graph Representation? 中所介绍的。

\[\text{Attn}=\text{softmax}(\dfrac{QK^T}{\sqrt{d}} \circ b)\]

\(Q\)\(K\) 是节点的特征表示。\(d\) 是对应的 feat_size\(b\) 是注意力偏置,根据运算符 \(\circ\) 的不同,其可以是加性的或乘性的。

参数:
  • feat_size (int) – 特征大小。

  • num_heads (int) – 注意力头的数量,`feat_size` 必须能被此数整除。

  • bias (bool, optional) – 如果为 True,则在线性投影中使用偏置。默认值:True。

  • attn_bias_type (str, optional) –

    用于修改注意力的注意力偏置类型。可选择 ‘add’ 或 ‘mul’。默认值:‘add’。

    • ‘add’ 表示加法注意力偏置。

    • ‘mul’ 表示乘法注意力偏置。

  • attn_drop (float, optional) – 注意力权重的 Dropout 概率。默认值:0.1。

示例

>>> import torch as th
>>> from dgl.nn import BiasedMHA
>>> ndata = th.rand(16, 100, 512)
>>> bias = th.rand(16, 100, 100, 8)
>>> net = BiasedMHA(feat_size=512, num_heads=8)
>>> out = net(ndata, bias)
forward(ndata, attn_bias=None, attn_mask=None)[source]

前向计算。

参数:
  • ndata (torch.Tensor) – 一个 3D 输入张量。形状:(batch_size, N, feat_size),其中 N 是最大节点数。

  • attn_bias (torch.Tensor, optional) – 用于修改注意力的注意力偏置。形状:(batch_size, N, N, num_heads)。

  • attn_mask (torch.Tensor, optional) – 用于避免在无效位置进行计算的注意力掩码,无效位置由 `True` 值表示。形状:(batch_size, N, N)。注意:对于对应不存在节点的行,请确保至少一个条目设置为 `False`,以防止 softmax 得到 NaNs。

返回:

y – 输出张量。形状:(batch_size, N, feat_size)

返回类型:

torch.Tensor

reset_parameters()[source]

初始化投影矩阵的参数,设置与论文原始实现相同。