dgl.unbatch

dgl.unbatch(g, node_split=None, edge_split=None)[source]

通过将给定图分割成小图列表来恢复批处理操作。

这是 :func:dgl.batch 的逆操作。如果未给定 node_splitedge_split,则调用输入图的 DGLGraph.batch_num_nodes()DGLGraph.batch_num_edges() 以获取信息。

如果给定 node_splitedge_split 参数,它将根据给定段划分图。必须确保划分是有效的——第 i 个图的边仅连接属于第 i 个图的节点。否则,DGL 将抛出错误。

该函数支持异构图输入,在这种情况下,两个分割段参数应为字典类型——类似于异构图的 DGLGraph.batch_num_nodes()DGLGraph.batch_num_edges() 属性。

参数:
  • g (DGLGraph) – 输入的待解除批处理图。

  • node_split (Tensor, dict[str, Tensor], optional) – 每个结果图的节点数。

  • edge_split (Tensor, dict[str, Tensor], optional) – 每个结果图的边数。

返回值:

解除批处理后的图列表。

返回类型:

list[DGLGraph]

示例

解除批处理一个已批处理的图

>>> import dgl
>>> import torch as th
>>> # 4 nodes, 3 edges
>>> g1 = dgl.graph((th.tensor([0, 1, 2]), th.tensor([1, 2, 3])))
>>> # 3 nodes, 4 edges
>>> g2 = dgl.graph((th.tensor([0, 0, 0, 1]), th.tensor([0, 1, 2, 0])))
>>> # add features
>>> g1.ndata['x'] = th.zeros(g1.num_nodes(), 3)
>>> g1.edata['w'] = th.ones(g1.num_edges(), 2)
>>> g2.ndata['x'] = th.ones(g2.num_nodes(), 3)
>>> g2.edata['w'] = th.zeros(g2.num_edges(), 2)
>>> bg = dgl.batch([g1, g2])
>>> f1, f2 = dgl.unbatch(bg)
>>> f1
Graph(num_nodes=4, num_edges=3,
      ndata_schemes={‘x’ : Scheme(shape=(3,), dtype=torch.float32)}
      edata_schemes={‘w’ : Scheme(shape=(2,), dtype=torch.float32)})
>>> f2
Graph(num_nodes=3, num_edges=4,
      ndata_schemes={‘x’ : Scheme(shape=(3,), dtype=torch.float32)}
      edata_schemes={‘w’ : Scheme(shape=(2,), dtype=torch.float32)})

提供分割参数

>>> g1 = dgl.graph((th.tensor([0, 1, 2]), th.tensor([1, 2, 3])))
>>> g2 = dgl.graph((th.tensor([0, 0, 0, 1]), th.tensor([0, 1, 2, 0])))
>>> g3 = dgl.graph((th.tensor([0]), th.tensor([1])))
>>> bg = dgl.batch([g1, g2, g3])
>>> bg.batch_num_nodes()
tensor([4, 3, 2])
>>> bg.batch_num_edges()
tensor([3, 4, 1])
>>> # unbatch but merge g2 and g3
>>> f1, f2 = dgl.unbatch(bg, th.tensor([4, 5]), th.tensor([3, 5]))
>>> f1
Graph(num_nodes=4, num_edges=3,
      ndata_schemes={}
      edata_schemes={})
>>> f2
Graph(num_nodes=5, num_edges=5,
      ndata_schemes={}
      edata_schemes={})

异构图输入

>>> hg1 = dgl.heterograph({
...     ('user', 'plays', 'game') : (th.tensor([0, 1]), th.tensor([0, 0]))})
>>> hg2 = dgl.heterograph({
...     ('user', 'plays', 'game') : (th.tensor([0, 0, 0]), th.tensor([1, 0, 2]))})
>>> bhg = dgl.batch([hg1, hg2])
>>> f1, f2 = dgl.unbatch(bhg)
>>> f1
Graph(num_nodes={'user': 2, 'game': 1},
      num_edges={('user', 'plays', 'game'): 2},
      metagraph=[('drug', 'game')])
>>> f2
Graph(num_nodes={'user': 1, 'game': 3},
      num_edges={('user', 'plays', 'game'): 3},
      metagraph=[('drug', 'game')])

另请参阅

batch