SpatialEncoder

class dgl.nn.pytorch.gt.SpatialEncoder(max_dist, num_heads=1)[source]

Bases: Module

Spatial Encoder, as introduced in Do Transformers Really Perform Bad for Graph Representation?

This module is a learnable spatial embedding module that encodes the shortest distance between each node pair to produce an attention bias.

Parameters:
  • max_dist (int) – Upper bound of the shortest path distance between each node pair to be encoded. All distances will be clamped into the range [0, max_dist].

  • num_heads (int, optional) – Number of attention heads when multi-head attention is applied. Default: 1.
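
Conceptually, the encoder learns an embedding table holding one num_heads-sized bias vector per clamped distance, plus an extra slot for the -1 padding of unreachable node pairs. A minimal sketch of this idea (an illustrative toy module, not necessarily the exact DGL implementation):

>>> import torch as th
>>> import torch.nn as nn
>>> class ToySpatialEncoder(nn.Module):
...     def __init__(self, max_dist, num_heads=1):
...         super().__init__()
...         self.max_dist = max_dist
...         # index 0 is reserved for the -1 padding (unreachable pairs),
...         # indices 1..max_dist+1 encode distances 0..max_dist
...         self.table = nn.Embedding(max_dist + 2, num_heads)
...     def forward(self, dist):
...         # clamp to [-1, max_dist], then shift by +1 so -1 maps to index 0
...         idx = th.clamp(dist, min=-1, max=self.max_dist) + 1
...         return self.table(idx)  # (B, N, N, num_heads)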

Example

>>> import torch as th
>>> import dgl
>>> from dgl.nn import SpatialEncoder
>>> from dgl import shortest_dist
>>> g1 = dgl.graph(([0,0,0,1,1,2,3,3], [1,2,3,0,3,0,0,1]))
>>> g2 = dgl.graph(([0,1], [1,0]))
>>> n1, n2 = g1.num_nodes(), g2.num_nodes()
>>> # use -1 padding since shortest_dist returns -1 for unreachable node pairs
>>> dist = -th.ones((2, 4, 4), dtype=th.long)
>>> dist[0, :n1, :n1] = shortest_dist(g1, root=None, return_paths=False)
>>> dist[1, :n2, :n2] = shortest_dist(g2, root=None, return_paths=False)
>>> spatial_encoder = SpatialEncoder(max_dist=2, num_heads=8)
>>> out = spatial_encoder(dist)
>>> print(out.shape)
torch.Size([2, 4, 4, 8])

forward(dist)[source]
Parameters:

dist (Tensor) – Shortest path distance of the batched graph with -1 padding, a tensor of shape \((B, N, N)\), where \(B\) is the batch size of the batched graph and \(N\) is the maximum number of nodes.

Returns:

Return attention bias as spatial encoding, a tensor of shape \((B, N, N, H)\), where \(H\) is num_heads.

Return type:

torch.Tensor
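
In a Graphormer-style attention layer, the returned bias is typically added to the per-head attention logits before the softmax. A minimal sketch, assuming logits of shape \((B, H, N, N)\) (the random tensors below merely stand in for real attention scores and for the output of forward(); they are not part of the DGL API):

>>> import torch as th
>>> B, N, H = 2, 4, 8
>>> attn_scores = th.rand(B, H, N, N)   # stand-in per-head attention logits
>>> bias = th.rand(B, N, N, H)          # stand-in for the output of forward()
>>> attn = th.softmax(attn_scores + bias.permute(0, 3, 1, 2), dim=-1)
>>> print(attn.shape)
torch.Size([2, 8, 4, 4])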