dgl.unbatch
- dgl.unbatch(g, node_split=None, edge_split=None)[source]
通过将给定图分割成小图列表来恢复批处理操作。
这是 :func:
dgl.batch
的逆操作。如果未给定node_split
或edge_split
,则调用输入图的DGLGraph.batch_num_nodes()
和DGLGraph.batch_num_edges()
以获取信息。如果给定
node_split
或edge_split
参数,它将根据给定段划分图。必须确保划分是有效的——第 i 个图的边仅连接属于第 i 个图的节点。否则,DGL 将抛出错误。该函数支持异构图输入,在这种情况下,两个分割段参数应为字典类型——类似于异构图的
DGLGraph.batch_num_nodes()
和DGLGraph.batch_num_edges()
属性。- 参数:
- 返回值:
解除批处理后的图列表。
- 返回类型:
示例
解除批处理一个已批处理的图
>>> import dgl >>> import torch as th >>> # 4 nodes, 3 edges >>> g1 = dgl.graph((th.tensor([0, 1, 2]), th.tensor([1, 2, 3]))) >>> # 3 nodes, 4 edges >>> g2 = dgl.graph((th.tensor([0, 0, 0, 1]), th.tensor([0, 1, 2, 0]))) >>> # add features >>> g1.ndata['x'] = th.zeros(g1.num_nodes(), 3) >>> g1.edata['w'] = th.ones(g1.num_edges(), 2) >>> g2.ndata['x'] = th.ones(g2.num_nodes(), 3) >>> g2.edata['w'] = th.zeros(g2.num_edges(), 2) >>> bg = dgl.batch([g1, g2]) >>> f1, f2 = dgl.unbatch(bg) >>> f1 Graph(num_nodes=4, num_edges=3, ndata_schemes={‘x’ : Scheme(shape=(3,), dtype=torch.float32)} edata_schemes={‘w’ : Scheme(shape=(2,), dtype=torch.float32)}) >>> f2 Graph(num_nodes=3, num_edges=4, ndata_schemes={‘x’ : Scheme(shape=(3,), dtype=torch.float32)} edata_schemes={‘w’ : Scheme(shape=(2,), dtype=torch.float32)})
提供分割参数
>>> g1 = dgl.graph((th.tensor([0, 1, 2]), th.tensor([1, 2, 3]))) >>> g2 = dgl.graph((th.tensor([0, 0, 0, 1]), th.tensor([0, 1, 2, 0]))) >>> g3 = dgl.graph((th.tensor([0]), th.tensor([1]))) >>> bg = dgl.batch([g1, g2, g3]) >>> bg.batch_num_nodes() tensor([4, 3, 2]) >>> bg.batch_num_edges() tensor([3, 4, 1]) >>> # unbatch but merge g2 and g3 >>> f1, f2 = dgl.unbatch(bg, th.tensor([4, 5]), th.tensor([3, 5])) >>> f1 Graph(num_nodes=4, num_edges=3, ndata_schemes={} edata_schemes={}) >>> f2 Graph(num_nodes=5, num_edges=5, ndata_schemes={} edata_schemes={})
异构图输入
>>> hg1 = dgl.heterograph({ ... ('user', 'plays', 'game') : (th.tensor([0, 1]), th.tensor([0, 0]))}) >>> hg2 = dgl.heterograph({ ... ('user', 'plays', 'game') : (th.tensor([0, 0, 0]), th.tensor([1, 0, 2]))}) >>> bhg = dgl.batch([hg1, hg2]) >>> f1, f2 = dgl.unbatch(bhg) >>> f1 Graph(num_nodes={'user': 2, 'game': 1}, num_edges={('user', 'plays', 'game'): 2}, metagraph=[('drug', 'game')]) >>> f2 Graph(num_nodes={'user': 1, 'game': 3}, num_edges={('user', 'plays', 'game'): 3}, metagraph=[('drug', 'game')])
另请参阅