LapPosEncoder
- class dgl.nn.pytorch.gt.LapPosEncoder(model_type, num_layer, k, dim, n_head=1, batch_norm=False, num_post_layer=0)[source]
基类:
Module
拉普拉斯位置编码 (LPE),如 GraphGPS: General Powerful Scalable Graph Transformers 中所介绍的
此模块是使用 Transformer 或 DeepSet 实现的学习型拉普拉斯位置编码模块。
- 参数:
model_type (str) – LPE 的编码器模型类型,只能是 “Transformer” 或 “DeepSet”。
num_layer (int) – Transformer/DeepSet 编码器中的层数。
k (int) – 最小的非平凡特征向量的数量。
dim (int) – 最终拉普拉斯编码的输出大小。
n_head (int, 可选) – Transformer 编码器中的注意力头数。默认值 : 1。
batch_norm (bool, 可选) – 如果为 True,对原始拉普拉斯位置编码应用批量归一化。默认值 : False。
num_post_layer (int, 可选) – 如果 num_post_layer > 0,在池化后应用一个包含
num_post_layer
层的 MLP。默认值 : 0。
示例
>>> import dgl >>> from dgl import LapPE >>> from dgl.nn import LapPosEncoder
>>> transform = LapPE(k=5, feat_name='eigvec', eigval_name='eigval', padding=True) >>> g = dgl.graph(([0,1,2,3,4,2,3,1,4,0], [2,3,1,4,0,0,1,2,3,4])) >>> g = transform(g) >>> eigvals, eigvecs = g.ndata['eigval'], g.ndata['eigvec'] >>> transformer_encoder = LapPosEncoder( model_type="Transformer", num_layer=3, k=5, dim=16, n_head=4 ) >>> pos_encoding = transformer_encoder(eigvals, eigvecs) >>> deepset_encoder = LapPosEncoder( model_type="DeepSet", num_layer=3, k=5, dim=16, num_post_layer=2 ) >>> pos_encoding = deepset_encoder(eigvals, eigvecs)