dgl.DGLGraph.batch_num_nodes
- DGLGraph.batch_num_nodes(ntype=None)[源]
返回批处理中每个图具有指定节点类型的节点数量。
- 参数:
ntype (str, 可选) – 要查询的节点类型。如果图有多种节点类型,必须指定此参数。否则,可以省略。如果图不是批处理图,它将返回一个长度为 1 的列表,其中包含图中节点的数量。
- 返回:
批处理中每个图具有指定类型的节点数量。其第 i 个元素是批处理中第 i 个图具有指定类型的节点数量。
- 返回类型:
Tensor
示例
以下示例使用 PyTorch 后端。
>>> import dgl >>> import torch
查询同构图。
>>> g1 = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 3]))) >>> g1.batch_num_nodes() tensor([4]) >>> g2 = dgl.graph((torch.tensor([0, 0, 0, 1]), torch.tensor([0, 1, 2, 0]))) >>> bg = dgl.batch([g1, g2]) >>> bg.batch_num_nodes() tensor([4, 3])
查询异构图。
>>> hg1 = dgl.heterograph({ ... ('user', 'plays', 'game') : (torch.tensor([0, 1]), torch.tensor([0, 0]))}) >>> hg2 = dgl.heterograph({ ... ('user', 'plays', 'game') : (torch.tensor([0, 0]), torch.tensor([1, 0]))}) >>> bg = dgl.batch([hg1, hg2]) >>> bg.batch_num_nodes('user') tensor([2, 1])