dgl.slice_batch

dgl.slice_batch(g, gid, store_ids=False)[source]

从批量图中获取特定图。

参数:
  • g (DGLGraph) – 输入的批量图。

  • gid (int) – 要检索的图的 ID。

  • store_ids (bool) – 如果为 True,它将在结果图的 ndataedata 中分别以 dgl.NIDdgl.EID 为名存储提取的节点和边的原始 ID。

返回值:

检索到的图。

返回值类型:

DGLGraph

示例

以下示例使用 PyTorch 后端。

>>> import dgl
>>> import torch

创建一个批量图。

>>> g1 = dgl.graph(([0, 1], [2, 3]))
>>> g2 = dgl.graph(([1], [2]))
>>> bg = dgl.batch([g1, g2])

获取第二个组件图。

>>> g = dgl.slice_batch(bg, 1)
>>> print(g)
Graph(num_nodes=3, num_edges=1,
      ndata_schemes={}
      edata_schemes={})