TransE

class dgl.nn.pytorch.link.TransE(num_rels, feats, p=1)[source]

基类: Module

基于论文 Translating Embeddings for Modeling Multi-relational Data 的相似度度量

数学上定义如下:

\[- {\| h + r - t \|}_p\]

其中 \(h\) 是头实体嵌入,\(r\) 是关系嵌入,\(t\) 是尾实体嵌入。

参数:
  • num_rels (int) – 关系类型数量。

  • feats (int) – 嵌入大小。

  • p (int, 可选) – 用于 Lp 范数的 p 值,可以是 1 或 2。

rel_emb

可学习的关系类型嵌入。

类型:

torch.nn.Embedding

示例

>>> import dgl
>>> import torch as th
>>> from dgl.nn import TransE
>>> # input features
>>> num_nodes = 10
>>> num_edges = 30
>>> num_rels = 3
>>> feats = 4
>>> scorer = TransE(num_rels=num_rels, feats=feats)
>>> g = dgl.rand_graph(num_nodes=num_nodes, num_edges=num_edges)
>>> src, dst = g.edges()
>>> h = th.randn(num_nodes, feats)
>>> h_head = h[src]
>>> h_tail = h[dst]
>>> # Randomly initialize edge relation types for demonstration
>>> rels = th.randint(low=0, high=num_rels, size=(num_edges,))
>>> scorer(h_head, h_tail, rels).shape
torch.Size([30])
forward(h_head, h_tail, rels)[source]

说明

计算三元组得分。

参数 h_head:

头实体特征。张量的形状为 \((E, D)\),其中 \(E\) 是三元组数量,\(D\) 是特征大小。

类型 h_head:

torch.Tensor

参数 h_tail:

尾实体特征。张量的形状为 \((E, D)\),其中 \(E\) 是三元组数量,\(D\) 是特征大小。

类型 h_tail:

torch.Tensor

参数 rels:

关系类型。它是一个形状为 \((E)\) 的 LongTensor,其中 \(E\) 是三元组数量。

类型 rels:

torch.Tensor

返回:

三元组得分。张量的形状为 \((E)\)

返回类型:

torch.Tensor

reset_parameters()[source]

说明

重新初始化可学习参数。