PathEncoder
- class dgl.nn.pytorch.gt.PathEncoder(max_len, feat_dim, num_heads=1)[source]
基类:
Module
路径编码器,如论文 Do Transformers Really Perform Bad for Graph Representation? 中 Edge Encoding 部分介绍的。
此模块是一个可学习的路径嵌入模块,将每对节点之间的最短路径编码为注意力偏置。
- 参数:
示例
>>> import torch as th >>> import dgl >>> from dgl.nn import PathEncoder >>> from dgl import shortest_dist
>>> g = dgl.graph(([0,0,0,1,1,2,3,3], [1,2,3,0,3,0,0,1])) >>> edata = th.rand(8, 16) >>> # Since shortest_dist returns -1 for unreachable node pairs, >>> # edata[-1] should be filled with zero padding. >>> edata = th.cat( (edata, th.zeros(1, 16)), dim=0 ) >>> dist, path = shortest_dist(g, root=None, return_paths=True) >>> path_data = edata[path[:, :, :2]] >>> path_encoder = PathEncoder(2, 16, num_heads=8) >>> out = path_encoder(dist.unsqueeze(0), path_data.unsqueeze(0)) >>> print(out.shape) torch.Size([1, 4, 4, 8])
- forward(dist, path_data)[source]
- 参数:
dist (Tensor) – 带零填充的批处理图的最短路径距离矩阵,形状为 \((B, N, N)\),其中 \(B\) 是批处理图的批大小,\(N\) 是节点的最大数量。
path_data (Tensor) – 带零填充的最短路径上的边特征,形状为 \((B, N, N, L, d)\),其中 \(L\) 是最短路径的最大长度,\(d\) 是
feat_dim
。
- 返回值:
返回作为路径编码的注意力偏置,形状为 \((B, N, N, H)\),其中 \(B\) 是输入图的批大小,\(N\) 是节点的最大数量,\(H\) 是
num_heads
。- 返回类型:
torch.Tensor