PathEncoder

class dgl.nn.pytorch.gt.PathEncoder(max_len, feat_dim, num_heads=1)[source]

基类: Module

路径编码器,如论文 Do Transformers Really Perform Bad for Graph Representation? 中 Edge Encoding 部分介绍的。

此模块是一个可学习的路径嵌入模块,将每对节点之间的最短路径编码为注意力偏置。

参数:
  • max_len (int) – 要编码的每条路径中的最大边数。超出部分将被截断,即截断序列号不小于 max_len 的边。

  • feat_dim (int) – 输入图中边特征的维度。

  • num_heads (int, 可选) – 如果应用多头注意力机制,注意力头的数量。默认值:1。

示例

>>> import torch as th
>>> import dgl
>>> from dgl.nn import PathEncoder
>>> from dgl import shortest_dist
>>> g = dgl.graph(([0,0,0,1,1,2,3,3], [1,2,3,0,3,0,0,1]))
>>> edata = th.rand(8, 16)
>>> # Since shortest_dist returns -1 for unreachable node pairs,
>>> # edata[-1] should be filled with zero padding.
>>> edata = th.cat(
        (edata, th.zeros(1, 16)), dim=0
    )
>>> dist, path = shortest_dist(g, root=None, return_paths=True)
>>> path_data = edata[path[:, :, :2]]
>>> path_encoder = PathEncoder(2, 16, num_heads=8)
>>> out = path_encoder(dist.unsqueeze(0), path_data.unsqueeze(0))
>>> print(out.shape)
torch.Size([1, 4, 4, 8])
forward(dist, path_data)[source]
参数:
  • dist (Tensor) – 带零填充的批处理图的最短路径距离矩阵,形状为 \((B, N, N)\),其中 \(B\) 是批处理图的批大小,\(N\) 是节点的最大数量。

  • path_data (Tensor) – 带零填充的最短路径上的边特征,形状为 \((B, N, N, L, d)\),其中 \(L\) 是最短路径的最大长度,\(d\)feat_dim

返回值:

返回作为路径编码的注意力偏置,形状为 \((B, N, N, H)\),其中 \(B\) 是输入图的批大小,\(N\) 是节点的最大数量,\(H\)num_heads

返回类型:

torch.Tensor