TransR
- class dgl.nn.pytorch.link.TransR(num_rels, rfeats, nfeats, p=1)[source]
基类:
Module
相似性度量,出自论文 Learning entity and relation embeddings for knowledge graph completion
数学上,它定义如下:
\[- {\| M_r h + r - M_r t \|}_p\]其中 \(M_r\) 是关系特定的投影矩阵,\(h\) 是头实体嵌入,\(r\) 是关系嵌入,\(t\) 是尾实体嵌入。
- 参数:
- rel_emb
可学习的关系类型嵌入。
- 类型:
torch.nn.Embedding
- rel_project
可学习的关系类型特定投影。
- 类型:
torch.nn.Embedding
示例
>>> import dgl >>> import torch as th >>> from dgl.nn import TransR
>>> # input features >>> num_nodes = 10 >>> num_edges = 30 >>> num_rels = 3 >>> feats = 4
>>> scorer = TransR(num_rels=num_rels, rfeats=2, nfeats=feats) >>> g = dgl.rand_graph(num_nodes=num_nodes, num_edges=num_edges) >>> src, dst = g.edges() >>> h = th.randn(num_nodes, feats) >>> h_head = h[src] >>> h_tail = h[dst] >>> # Randomly initialize edge relation types for demonstration >>> rels = th.randint(low=0, high=num_rels, size=(num_edges,)) >>> scorer(h_head, h_tail, rels).shape torch.Size([30])
- forward(h_head, h_tail, rels)[source]
对三元组评分。
- 参数:
h_head (torch.Tensor) – 头实体特征。张量形状为 \((E, D)\),其中 \(E\) 是三元组数量,\(D\) 是特征维度。
h_tail (torch.Tensor) – 尾实体特征。张量形状为 \((E, D)\),其中 \(E\) 是三元组数量,\(D\) 是特征维度。
rels (torch.Tensor) – 关系类型。它是形状为 \((E)\) 的 LongTensor,其中 \(E\) 是三元组数量。
- 返回值:
三元组得分。张量形状为 \((E)\)。
- 返回类型:
torch.Tensor