SpatialEncoder
- class dgl.nn.pytorch.gt.SpatialEncoder(max_dist, num_heads=1)[source]
Bases: Module
Spatial Encoder, as introduced in Do Transformers Really Perform Bad for Graph Representation?
This module is a learnable spatial embedding module that encodes the shortest distance between each pair of nodes to produce an attention bias.
- Parameters:
    - max_dist (int) – Upper bound of the shortest path distance between each pair of nodes. Distances are clamped into the range [0, max_dist].
    - num_heads (int, optional) – Number of attention heads for which biases are generated. Default: 1.
Example
>>> import torch as th
>>> import dgl
>>> from dgl.nn import SpatialEncoder
>>> from dgl import shortest_dist
>>> g1 = dgl.graph(([0,0,0,1,1,2,3,3], [1,2,3,0,3,0,0,1]))
>>> g2 = dgl.graph(([0,1], [1,0]))
>>> n1, n2 = g1.num_nodes(), g2.num_nodes()
>>> # use -1 padding since shortest_dist returns -1 for unreachable node pairs
>>> dist = -th.ones((2, 4, 4), dtype=th.long)
>>> dist[0, :n1, :n1] = shortest_dist(g1, root=None, return_paths=False)
>>> dist[1, :n2, :n2] = shortest_dist(g2, root=None, return_paths=False)
>>> spatial_encoder = SpatialEncoder(max_dist=2, num_heads=8)
>>> out = spatial_encoder(dist)
>>> print(out.shape)
torch.Size([2, 4, 4, 8])
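To make the mechanism concrete, below is a minimal, hypothetical re-implementation of the spatial-encoding idea in plain PyTorch (not DGL's actual code): distances are clamped into [-1, max_dist], shifted by one so that -1 (unreachable) maps to a zero padding index, and then looked up in a learnable per-head embedding table. The class name `SimpleSpatialEncoder` is an illustrative assumption.

```python
import torch as th
import torch.nn as nn


class SimpleSpatialEncoder(nn.Module):
    """Minimal sketch of a spatial encoder: maps each (clamped) shortest-path
    distance to a learnable bias per attention head. Not DGL's implementation."""

    def __init__(self, max_dist, num_heads=1):
        super().__init__()
        self.max_dist = max_dist
        # Index 0 is reserved for unreachable pairs (distance -1);
        # distances 0..max_dist map to indices 1..max_dist+1.
        self.table = nn.Embedding(max_dist + 2, num_heads, padding_idx=0)

    def forward(self, dist):
        # Clamp distances into [-1, max_dist], then shift so -1 -> index 0.
        idx = th.clamp(dist, min=-1, max=self.max_dist) + 1
        # Output shape: (*dist.shape, num_heads)
        return self.table(idx)


enc = SimpleSpatialEncoder(max_dist=2, num_heads=8)
dist = -th.ones((2, 4, 4), dtype=th.long)   # -1 padding for unreachable pairs
dist[0, :3, :3] = th.tensor([[0, 1, 1], [1, 0, 2], [1, 2, 0]])
bias = enc(dist)
print(bias.shape)  # torch.Size([2, 4, 4, 8])
```

Because the unreachable index is a `padding_idx`, its embedding stays at zero, so unreachable node pairs contribute no attention bias.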