SVDPE
- class dgl.transforms.SVDPE(k, feat_name='svd_pe', padding=False, random_flip=True)[source]
继承自:
BaseTransform
基于 SVD 的位置编码,如论文 Global Self-Attention as a Replacement for Graph Convolution 中所介绍
此函数计算最大的 \(k\) 个奇异值及其对应的左奇异向量和右奇异向量,形成位置编码,可以存储在 ndata 中。
- 参数:
k (int) – 用于位置编码的最大的奇异值及其对应的奇异向量的数量。
feat_name (str, optional) – 在 ndata 中存储计算出的位置编码的名称。默认值:
svd_pe
padding (bool, optional) – 如果为 False,当 \(k > N\) 时(\(N\) 是图
g
中的节点数)会引发错误。如果为 True,当 \(k > N\) 时会在编码末尾添加零填充。默认值: False。random_flip (bool, optional) – 如果为 True,随机翻转编码向量的符号。建议在训练期间启用以获得更好的泛化能力。默认值: True。
示例
>>> import dgl >>> from dgl import SVDPE
>>> transform = SVDPE(k=2, feat_name="svd_pe") >>> g = dgl.graph(([0,1,2,3,4,2,3,1,4,0], [2,3,1,4,0,0,1,2,3,4])) >>> g_ = transform(g) >>> print(g_.ndata['svd_pe']) tensor([[-6.3246e-01, -1.1373e-07, -6.3246e-01, 0.0000e+00], [-6.3246e-01, 7.6512e-01, -6.3246e-01, -7.6512e-01], [ 6.3246e-01, 4.7287e-01, 6.3246e-01, -4.7287e-01], [-6.3246e-01, -7.6512e-01, -6.3246e-01, 7.6512e-01], [ 6.3246e-01, -4.7287e-01, 6.3246e-01, 4.7287e-01]])