dgl.DGLGraph.edge_ids

DGLGraph.edge_ids(u, v, return_uv=False, etype=None)[source]

根据边的两个端点返回边 ID。

参数:
  • u (节点 ID) –

    边的源节点 ID。允许的格式为

    • int: 单个节点。

    • Int Tensor: 每个元素是一个节点 ID。该 Tensor 必须与图具有相同的设备类型和 ID 数据类型。

    • iterable[int]: 每个元素是一个节点 ID。

  • v (节点 ID) –

    边的目标节点 ID。允许的格式为

    • int: 单个节点。

    • Int Tensor: 每个元素是一个节点 ID。该 Tensor 必须与图具有相同的设备类型和 ID 数据类型。

    • iterable[int]: 每个元素是一个节点 ID。

  • return_uv (bool, 可选) – 是否同时返回边以及源节点和目标节点 ID。如果为 False(默认),则假定图是简单图,并且从一个节点到另一个节点只有一条边。如果为 True,则可能存在从一个节点到另一个节点的多个边。

  • etype (str(str, str, str), 可选) –

    边的类型名称。允许的类型名称格式为

    • (str, str, str),分别表示源节点类型、边类型和目标节点类型。

    • 或一个 str 边类型名称,如果该名称可以在图中唯一标识一个三元组格式。

    如果图只有一种边类型,则可以省略。

返回值:

  • 如果 return_uv=False,它返回一个 tensor 形式的边 ID,其中第 i 个元素是边 (u[i], v[i]) 的 ID。

  • 如果 return_uv=True,它返回一个包含三个 1D tensor 的元组 (eu, ev, e)e[i] 是从 eu[i]ev[i] 的边的 ID。在这种情况下,它返回从 eu[i]ev[i] 的所有边(包括并行边)。

返回类型:

Tensor,或 (Tensor, Tensor, Tensor)

备注

如果图是简单图,return_uv=False,并且在某些节点对之间没有边,则会引发错误。

如果图是多重图,return_uv=False,并且在某些节点对之间有多条边,则返回其中任意一条边。

示例

以下示例使用 PyTorch 后端。

>>> import dgl
>>> import torch

创建一个同构图。

>>> g = dgl.graph((torch.tensor([0, 0, 1, 1, 1]), torch.tensor([1, 0, 2, 3, 2])))

查询边。

>>> g.edge_ids(0, 0)
1
>>> g.edge_ids(torch.tensor([1, 0]), torch.tensor([3, 1]))
tensor([3, 0])

获取节点对之间的所有边。

>>> g.edge_ids(torch.tensor([1, 0]), torch.tensor([3, 1]), return_uv=True)
(tensor([1, 0]), tensor([3, 1]), tensor([3, 0]))

如果图有多种边类型,则需要指定边类型。

>>> g = dgl.heterograph({
...     ('user', 'follows', 'user'): (torch.tensor([0, 1]), torch.tensor([1, 2])),
...     ('user', 'follows', 'game'): (torch.tensor([0, 1, 2]), torch.tensor([1, 2, 3])),
...     ('user', 'plays', 'game'): (torch.tensor([1, 3]), torch.tensor([2, 3]))
... })
>>> g.edge_ids(torch.tensor([1]), torch.tensor([2]), etype='plays')
tensor([0])

当边类型存在歧义时,请改用规范边类型。

>>> g.edge_ids(torch.tensor([0, 1]), torch.tensor([1, 2]),
...            etype=('user', 'follows', 'user'))
tensor([0, 1])
>>> g.edge_ids(torch.tensor([1, 2]), torch.tensor([2, 3]),
...            etype=('user', 'follows', 'game'))
tensor([1, 2])