EdgePredictor
- class dgl.nn.pytorch.link.EdgePredictor(op, in_feats=None, out_feats=None, bias=False)
Bases: torch.nn.Module
Predictor/score function for pairs of node representations.
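Given a pair of node representations \(h_i\) and \(h_j\), it combines them in one of the following ways:
- Dot product
\[h_i^{T} h_j\]
- Cosine similarity
\[\frac{h_i^{T} h_j}{{\| h_i \|}_2 \cdot {\| h_j \|}_2}\]
- Elementwise product
\[h_i \odot h_j\]
- Concatenation
\[h_i \Vert h_j\]
Optionally, it passes the combined result to a linear layer to produce the final prediction.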
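To make the four combinations concrete, here is a small worked example in plain Python (standard library only) on two illustrative 3-dimensional vectors. The helper names `dot`, `cos`, `ele`, and `cat` are hypothetical and simply mirror the formulas above; they are not part of DGL. The example shows that the dot product and cosine similarity each reduce a pair to a single number, the elementwise product keeps the input length, and concatenation doubles it. This matches the output widths in the examples that follow (1, 1, `in_feats`, and `2 * in_feats`) when no linear layer is applied.

```python
import math

# Two illustrative 3-dimensional node representations.
h_i = [1.0, 2.0, 2.0]
h_j = [2.0, 0.0, 1.0]

def dot(a, b):
    # h_i^T h_j: sum of pairwise products -> a single number
    return sum(x * y for x, y in zip(a, b))

def cos(a, b):
    # dot product divided by the product of the two L2 norms
    return dot(a, b) / (math.sqrt(dot(a, a)) * math.sqrt(dot(b, b)))

def ele(a, b):
    # h_i (elementwise *) h_j: same length as the inputs
    return [x * y for x, y in zip(a, b)]

def cat(a, b):
    # h_i || h_j: concatenation, twice the input length
    return a + b

print(dot(h_i, h_j))            # 4.0
print(round(cos(h_i, h_j), 4))  # 0.5963  (= 4 / (3 * sqrt(5)))
print(ele(h_i, h_j))            # [2.0, 0.0, 2.0]
print(cat(h_i, h_j))            # [1.0, 2.0, 2.0, 2.0, 0.0, 1.0]
```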
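- Parameters:
  - op (str): The operation to apply. One of 'dot', 'cos', 'ele', or 'cat', corresponding to the dot product, cosine similarity, elementwise product, and concatenation above, in that order.
  - in_feats (int, optional): The input feature size of \(h_i\) and \(h_j\). Required only if a linear layer is to be applied.
  - out_feats (int, optional): The output feature size. Required only if a linear layer is to be applied.
  - bias (bool, optional): Whether the linear layer, if applied, uses a bias term.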
Example
>>> import dgl
>>> import torch as th
>>> from dgl.nn import EdgePredictor
>>> num_nodes = 2
>>> num_edges = 3
>>> in_feats = 4
>>> g = dgl.rand_graph(num_nodes=num_nodes, num_edges=num_edges)
>>> h = th.randn(num_nodes, in_feats)
>>> src, dst = g.edges()
>>> h_src = h[src]
>>> h_dst = h[dst]
Example 1: Dot product
>>> predictor = EdgePredictor('dot')
>>> predictor(h_src, h_dst).shape
torch.Size([3, 1])
>>> predictor = EdgePredictor('dot', in_feats, out_feats=3)
>>> predictor.reset_parameters()
>>> predictor(h_src, h_dst).shape
torch.Size([3, 3])
Example 2: Cosine similarity
>>> predictor = EdgePredictor('cos')
>>> predictor(h_src, h_dst).shape
torch.Size([3, 1])
>>> predictor = EdgePredictor('cos', in_feats, out_feats=3)
>>> predictor.reset_parameters()
>>> predictor(h_src, h_dst).shape
torch.Size([3, 3])
Example 3: Elementwise product
>>> predictor = EdgePredictor('ele')
>>> predictor(h_src, h_dst).shape
torch.Size([3, 4])
>>> predictor = EdgePredictor('ele', in_feats, out_feats=3)
>>> predictor.reset_parameters()
>>> predictor(h_src, h_dst).shape
torch.Size([3, 3])
Example 4: Concatenation
>>> predictor = EdgePredictor('cat')
>>> predictor(h_src, h_dst).shape
torch.Size([3, 8])
>>> predictor = EdgePredictor('cat', in_feats, out_feats=3)
>>> predictor.reset_parameters()
>>> predictor(h_src, h_dst).shape
torch.Size([3, 3])
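The shapes in the four examples can be reproduced without DGL. The following is a minimal plain-PyTorch sketch of a module that behaves as described above. The class name `PairScorer` is hypothetical, and this is not DGL's source code. Two parts of the design are assumptions inferred from the signature and the example output shapes. First, the optional linear layer is created only when both `in_feats` and `out_feats` are given. Second, its input width is 1 for 'dot' and 'cos', `in_feats` for 'ele', and `2 * in_feats` for 'cat'.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class PairScorer(nn.Module):
    """Hypothetical stand-in for EdgePredictor (a sketch, not DGL's code)."""

    def __init__(self, op, in_feats=None, out_feats=None, bias=False):
        super().__init__()
        if op not in ("dot", "cos", "ele", "cat"):
            raise ValueError(f"unsupported op: {op!r}")
        self.op = op
        self.linear = None
        # Assumption: the linear layer exists only if both sizes are given.
        if in_feats is not None and out_feats is not None:
            # Width of the combined vector fed into the linear layer.
            width = {"dot": 1, "cos": 1, "ele": in_feats, "cat": 2 * in_feats}[op]
            self.linear = nn.Linear(width, out_feats, bias=bias)

    def forward(self, h_src, h_dst):
        if self.op == "dot":
            h = (h_src * h_dst).sum(dim=-1, keepdim=True)                 # (E, 1)
        elif self.op == "cos":
            h = F.cosine_similarity(h_src, h_dst, dim=-1).unsqueeze(-1)   # (E, 1)
        elif self.op == "ele":
            h = h_src * h_dst                                              # (E, D)
        else:
            h = torch.cat([h_src, h_dst], dim=-1)                          # (E, 2D)
        return h if self.linear is None else self.linear(h)


# Mirror the examples: 3 node pairs with 4-dimensional features.
h_src, h_dst = torch.randn(3, 4), torch.randn(3, 4)
shapes = {}
for op in ("dot", "cos", "ele", "cat"):
    without_linear = tuple(PairScorer(op)(h_src, h_dst).shape)
    with_linear = tuple(PairScorer(op, 4, out_feats=3)(h_src, h_dst).shape)
    shapes[op] = (without_linear, with_linear)
print(shapes)
# {'dot': ((3, 1), (3, 3)), 'cos': ((3, 1), (3, 3)),
#  'ele': ((3, 4), (3, 3)), 'cat': ((3, 8), (3, 3))}
```

Keeping the `(E, 1)` column for 'dot' and 'cos' lets every op feed the same `nn.Linear` interface, which expects a trailing feature dimension.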