dgl.DGLGraph.edge_ids
- DGLGraph.edge_ids(u, v, return_uv=False, etype=None)[source]
根据边的两个端点返回边 ID。
- 参数:
u (节点 ID) –
边的源节点 ID。允许的格式为
int
: 单个节点。Int Tensor: 每个元素是一个节点 ID。该 Tensor 必须与图具有相同的设备类型和 ID 数据类型。
iterable[int]: 每个元素是一个节点 ID。
v (节点 ID) –
边的目标节点 ID。允许的格式为
int
: 单个节点。Int Tensor: 每个元素是一个节点 ID。该 Tensor 必须与图具有相同的设备类型和 ID 数据类型。
iterable[int]: 每个元素是一个节点 ID。
return_uv (bool, 可选) – 是否同时返回边以及源节点和目标节点 ID。如果为 False(默认),则假定图是简单图,并且从一个节点到另一个节点只有一条边。如果为 True,则可能存在从一个节点到另一个节点的多个边。
etype (str 或 (str, str, str), 可选) –
边的类型名称。允许的类型名称格式为
(str, str, str)
,分别表示源节点类型、边类型和目标节点类型。或一个
str
边类型名称,如果该名称可以在图中唯一标识一个三元组格式。
如果图只有一种边类型,则可以省略。
- 返回值:
如果
return_uv=False
,它返回一个 tensor 形式的边 ID,其中第 i 个元素是边(u[i], v[i])
的 ID。如果
return_uv=True
,它返回一个包含三个 1D tensor 的元组(eu, ev, e)
。e[i]
是从eu[i]
到ev[i]
的边的 ID。在这种情况下,它返回从eu[i]
到ev[i]
的所有边(包括并行边)。
- 返回类型:
Tensor,或 (Tensor, Tensor, Tensor)
备注
如果图是简单图,
return_uv=False
,并且在某些节点对之间没有边,则会引发错误。如果图是多重图,
return_uv=False
,并且在某些节点对之间有多条边,则返回其中任意一条边。示例
以下示例使用 PyTorch 后端。
>>> import dgl >>> import torch
创建一个同构图。
>>> g = dgl.graph((torch.tensor([0, 0, 1, 1, 1]), torch.tensor([1, 0, 2, 3, 2])))
查询边。
>>> g.edge_ids(0, 0) 1 >>> g.edge_ids(torch.tensor([1, 0]), torch.tensor([3, 1])) tensor([3, 0])
获取节点对之间的所有边。
>>> g.edge_ids(torch.tensor([1, 0]), torch.tensor([3, 1]), return_uv=True) (tensor([1, 0]), tensor([3, 1]), tensor([3, 0]))
如果图有多种边类型,则需要指定边类型。
>>> g = dgl.heterograph({ ... ('user', 'follows', 'user'): (torch.tensor([0, 1]), torch.tensor([1, 2])), ... ('user', 'follows', 'game'): (torch.tensor([0, 1, 2]), torch.tensor([1, 2, 3])), ... ('user', 'plays', 'game'): (torch.tensor([1, 3]), torch.tensor([2, 3])) ... }) >>> g.edge_ids(torch.tensor([1]), torch.tensor([2]), etype='plays') tensor([0])
当边类型存在歧义时,请改用规范边类型。
>>> g.edge_ids(torch.tensor([0, 1]), torch.tensor([1, 2]), ... etype=('user', 'follows', 'user')) tensor([0, 1]) >>> g.edge_ids(torch.tensor([1, 2]), torch.tensor([2, 3]), ... etype=('user', 'follows', 'game')) tensor([1, 2])