dgl.DGLGraph.batch_num_edges

DGLGraph.batch_num_edges(etype=None)[源码]

返回批处理中每个图指定边类型的边数。

参数:

etype (str or tuple of str, optional) – 查询的边类型,可以是边类型 (str) 或规范边类型 (str 的 3 元组)。当边类型出现在多个规范边类型中时,必须使用规范边类型。如果图有多种边类型,则必须指定此参数。否则,可以省略。

返回值:

返回批处理中每个图指定类型的边数。其第 i 个元素是第 i 个图指定类型的边数。如果图不是批处理图,将返回一个长度为 1 的列表,其中包含图中边数。

返回类型:

Tensor

示例

以下示例使用 PyTorch 后端。

>>> import dgl
>>> import torch

查询同构图。

>>> g1 = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 3])))
>>> g1.batch_num_edges()
tensor([3])
>>> g2 = dgl.graph((torch.tensor([0, 0, 0, 1]), torch.tensor([0, 1, 2, 0])))
>>> bg = dgl.batch([g1, g2])
>>> bg.batch_num_edges()
tensor([3, 4])

查询异构图。

>>> hg1 = dgl.heterograph({
...       ('user', 'plays', 'game') : (torch.tensor([0, 1]), torch.tensor([0, 0]))})
>>> hg2 = dgl.heterograph({
...       ('user', 'plays', 'game') : (torch.tensor([0, 0]), torch.tensor([1, 0]))})
>>> bg = dgl.batch([hg1, hg2])
>>> bg.batch_num_edges('plays')
tensor([2, 2])