DegreeEncoder
- class dgl.nn.pytorch.gt.DegreeEncoder(max_degree, embedding_dim, direction='both')
Bases: Module
Degree Encoder, as introduced in Do Transformers Really Perform Bad for Graph Representation?
This module is a learnable degree embedding module.
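The idea of a learnable degree embedding can be made concrete with a minimal sketch. It is a hypothetical re-implementation, not DGL's source code: the class name SketchDegreeEncoder is invented here, and it assumes that degrees are clamped into [0, max_degree], looked up in nn.Embedding tables whose row 0 is reserved for padding, and that direction='both' sums one table per direction. The degree values used are those of the two graphs in the example below.

```python
import torch
import torch.nn as nn


class SketchDegreeEncoder(nn.Module):
    """Hypothetical illustration of a degree encoder (not DGL's implementation)."""

    def __init__(self, max_degree, embedding_dim, direction="both"):
        super().__init__()
        if direction not in ("in", "out", "both"):
            raise ValueError("direction must be 'in', 'out' or 'both'")
        self.max_degree = max_degree
        self.direction = direction
        # One row per degree value 0..max_degree; padding_idx=0 keeps row 0 at zero.
        n_rows = max_degree + 1
        if direction == "both":
            self.in_table = nn.Embedding(n_rows, embedding_dim, padding_idx=0)
            self.out_table = nn.Embedding(n_rows, embedding_dim, padding_idx=0)
        else:
            self.table = nn.Embedding(n_rows, embedding_dim, padding_idx=0)

    def forward(self, degrees):
        # Degrees above max_degree share the embedding of max_degree.
        degrees = torch.clamp(degrees, min=0, max=self.max_degree)
        if self.direction == "both":
            # Expected shape (2, B, N): stacked (in_degree, out_degree).
            assert degrees.dim() == 3 and degrees.shape[0] == 2
            return self.in_table(degrees[0]) + self.out_table(degrees[1])
        # Expected shape (B, N): degrees of a single direction.
        assert degrees.dim() == 2
        return self.table(degrees)


# Two zero-padded graphs (B=2, N=4); the trailing zeros in row 1 are padding.
in_deg = torch.tensor([[3, 2, 1, 2], [1, 1, 0, 0]])
out_deg = torch.tensor([[3, 2, 1, 2], [1, 1, 0, 0]])
enc = SketchDegreeEncoder(max_degree=5, embedding_dim=16)
emb = enc(torch.stack((in_deg, out_deg)))
print(emb.shape)  # torch.Size([2, 4, 16])
```

Reserving row 0 for padding keeps padded node slots at the zero vector; in this sketch it also maps genuinely isolated (degree-0) nodes to zero, which is a design trade-off of zero padding.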
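- Parameters:
  - max_degree (int): Upper bound of the degrees to be encoded. Each degree is clamped into the range [0, max_degree].
  - embedding_dim (int): Output dimension of the embedding vectors.
  - direction (str, optional): Which degrees to encode, one of 'in', 'out', and 'both'. 'both' encodes the in-degrees and out-degrees separately and returns their sum. Default: 'both'.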
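To see what input each direction setting expects, the zero-padded degree tensors from the example below can also be built without DGL. This sketch assumes, as that example shows for direction='both', a stacked tensor of shape (2, B, N), and assumes a single (B, N) tensor for 'in' or 'out'. The helper degrees_from_edges is hypothetical; it uses torch.bincount to count edge endpoints per node.

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def degrees_from_edges(src, dst, num_nodes):
    """Hypothetical helper: return (in_degree, out_degree) of one graph."""
    src = torch.tensor(src)
    dst = torch.tensor(dst)
    # bincount counts how many edges end (or start) at each node id.
    in_deg = torch.bincount(dst, minlength=num_nodes)
    out_deg = torch.bincount(src, minlength=num_nodes)
    return in_deg, out_deg


# The same two graphs as in the example: g1 has 4 nodes, g2 has 2 nodes.
g1_in, g1_out = degrees_from_edges([0, 0, 0, 1, 1, 2, 3, 3], [1, 2, 3, 0, 3, 0, 0, 1], 4)
g2_in, g2_out = degrees_from_edges([0, 1], [1, 0], 2)

# Zero-pad to the largest graph, giving shape (B, N) = (2, 4).
in_degree = pad_sequence([g1_in, g2_in], batch_first=True)
out_degree = pad_sequence([g1_out, g2_out], batch_first=True)

# Assumed: 'in' or 'out' takes one (B, N) tensor; 'both' takes (2, B, N).
both_input = torch.stack((in_degree, out_degree))
print(in_degree.tolist())  # [[3, 2, 1, 2], [1, 1, 0, 0]]
print(both_input.shape)    # torch.Size([2, 2, 4])
```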
Example
>>> import dgl
>>> from dgl.nn import DegreeEncoder
>>> import torch as th
>>> from torch.nn.utils.rnn import pad_sequence

>>> g1 = dgl.graph(([0,0,0,1,1,2,3,3], [1,2,3,0,3,0,0,1]))
>>> g2 = dgl.graph(([0,1], [1,0]))
>>> in_degree = pad_sequence([g1.in_degrees(), g2.in_degrees()], batch_first=True)
>>> out_degree = pad_sequence([g1.out_degrees(), g2.out_degrees()], batch_first=True)
>>> print(in_degree.shape)
torch.Size([2, 4])
>>> degree_encoder = DegreeEncoder(5, 16)
>>> degree_embedding = degree_encoder(th.stack((in_degree, out_degree)))
>>> print(degree_embedding.shape)
torch.Size([2, 4, 16])