EdgePredictor

class dgl.nn.pytorch.link.EdgePredictor(op, in_feats=None, out_feats=None, bias=False)[source]

基类:Module

节点表示对的预测器/评分函数

给定一对节点表示 \(h_i\)\(h_j\),它通过以下方式组合它们:

点积

\[h_i^{T} h_j\]

余弦相似度

\[\frac{h_i^{T} h_j}{{\| h_i \|}_2 \cdot {\| h_j \|}_2}\]

元素乘积

\[h_i \odot h_j\]

拼接

\[h_i \Vert h_j\]

可选地,它将组合结果传递给一个线性层进行最终预测。

参数:
  • op (str) – 应用的操作。它可以是 'dot', 'cos', 'ele', 或 'cat',依次对应上面公式中的操作。

  • in_feats (int, optional) – 输入特征 \(h_i\)\(h_j\) 的大小。仅在应用线性层时需要。

  • out_feats (int, optional) – 输出特征的大小。仅在应用线性层时需要。

  • bias (bool, optional) – 如果应用线性层,是否使用偏置项。

示例

>>> import dgl
>>> import torch as th
>>> from dgl.nn import EdgePredictor
>>> num_nodes = 2
>>> num_edges = 3
>>> in_feats = 4
>>> g = dgl.rand_graph(num_nodes=num_nodes, num_edges=num_edges)
>>> h = th.randn(num_nodes, in_feats)
>>> src, dst = g.edges()
>>> h_src = h[src]
>>> h_dst = h[dst]

示例 1:点积

>>> predictor = EdgePredictor('dot')
>>> predictor(h_src, h_dst).shape
torch.Size([3, 1])
>>> predictor = EdgePredictor('dot', in_feats, out_feats=3)
>>> predictor.reset_parameters()
>>> predictor(h_src, h_dst).shape
torch.Size([3, 3])

示例 2:余弦相似度

>>> predictor = EdgePredictor('cos')
>>> predictor(h_src, h_dst).shape
torch.Size([3, 1])
>>> predictor = EdgePredictor('cos', in_feats, out_feats=3)
>>> predictor.reset_parameters()
>>> predictor(h_src, h_dst).shape
torch.Size([3, 3])

示例 3:元素乘积

>>> predictor = EdgePredictor('ele')
>>> predictor(h_src, h_dst).shape
torch.Size([3, 4])
>>> predictor = EdgePredictor('ele', in_feats, out_feats=3)
>>> predictor.reset_parameters()
>>> predictor(h_src, h_dst).shape
torch.Size([3, 3])

示例 4:拼接

>>> predictor = EdgePredictor('cat')
>>> predictor(h_src, h_dst).shape
torch.Size([3, 8])
>>> predictor = EdgePredictor('cat', in_feats, out_feats=3)
>>> predictor.reset_parameters()
>>> predictor(h_src, h_dst).shape
torch.Size([3, 3])
forward(h_src, h_dst)[source]

描述

对节点表示对进行预测。

参数 h_src:

源节点特征。张量的形状为 \((E, D_{in})\),其中 \(E\) 是边/节点对的数量,\(D_{in}\) 是输入特征大小。

类型 h_src:

torch.Tensor

参数 h_dst:

目标节点特征。张量的形状为 \((E, D_{in})\),其中 \(E\) 是边/节点对的数量,\(D_{in}\) 是输入特征大小。

类型 h_dst:

torch.Tensor

返回值:

输出特征。

返回类型:

torch.Tensor

reset_parameters()[source]

描述

重新初始化可学习参数。