dgl.batch
- dgl.batch(graphs, ndata='__ALL__', edata='__ALL__')[source]
将
DGLGraph
集合批量处理成一个图,以实现更高效的图计算。每个输入图成为批量处理后的图的一个不相交的组件。节点和边被重新标记为不相交的段
原始节点 ID
0 ~ N_0
0 ~ N_1
…
0 ~ N_k
新节点 ID
0 ~ N_0
N_0 ~ N_0+N_1
…
sum_{i=0}^{k-1} N_i ~ sum_{i=0}^k N_i
因此,对批量处理后的图进行的许多计算与对每个图分别进行的计算相同,但由于可以轻松并行化而变得更加高效。这使得
dgl.batch
对于处理许多图样本的任务(例如图分类任务)非常有用。对于异构图输入,它们必须共享相同的关系集(即节点类型和边类型),并且该函数将逐个对每种关系进行批量处理。因此,结果也是一个异构图,并且具有与输入相同的关系集。
输入图的节点数和边数可以通过结果图的
DGLGraph.batch_num_nodes()
和DGLGraph.batch_num_edges()
属性访问。对于同构图,它们是 1D 整数张量,每个元素是对应输入图的节点数/边数。对于异构图,它们是 1D 整数张量的字典,以节点类型或边类型作为键。该函数支持对已批量处理的图进行批量处理。结果图的批大小是所有输入图的批大小之和。
默认情况下,通过连接所有输入图的特征张量来批量处理节点/边特征。因此,这要求同名的特征具有相同的数据类型和特征大小。可以将
None
传递给ndata
或edata
参数以阻止特征批量处理,或者传递一个字符串列表以指定要批量处理的特征。要将图反批量处理回列表,请使用
dgl.unbatch()
函数。- 参数:
- 返回值:
批量处理后的图。
- 返回类型:
示例
批量处理同构图
>>> import dgl >>> import torch as th >>> # 4 nodes, 3 edges >>> g1 = dgl.graph((th.tensor([0, 1, 2]), th.tensor([1, 2, 3]))) >>> # 3 nodes, 4 edges >>> g2 = dgl.graph((th.tensor([0, 0, 0, 1]), th.tensor([0, 1, 2, 0]))) >>> bg = dgl.batch([g1, g2]) >>> bg Graph(num_nodes=7, num_edges=7, ndata_schemes={} edata_schemes={}) >>> bg.batch_size 2 >>> bg.batch_num_nodes() tensor([4, 3]) >>> bg.batch_num_edges() tensor([3, 4]) >>> bg.edges() (tensor([0, 1, 2, 4, 4, 4, 5], tensor([1, 2, 3, 4, 5, 6, 4]))
批量处理已批量处理的图
>>> bbg = dgl.batch([bg, bg]) >>> bbg.batch_size 4 >>> bbg.batch_num_nodes() tensor([4, 3, 4, 3]) >>> bbg.batch_num_edges() tensor([3, 4, 3, 4])
批量处理带有特征数据的图
>>> g1.ndata['x'] = th.zeros(g1.num_nodes(), 3) >>> g1.edata['w'] = th.ones(g1.num_edges(), 2) >>> g2.ndata['x'] = th.ones(g2.num_nodes(), 3) >>> g2.edata['w'] = th.zeros(g2.num_edges(), 2) >>> bg = dgl.batch([g1, g2]) >>> bg.ndata['x'] tensor([[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [1, 1, 1], [1, 1, 1], [1, 1, 1]]) >>> bg.edata['w'] tensor([[1, 1], [1, 1], [1, 1], [0, 0], [0, 0], [0, 0], [0, 0]])
批量处理异构图
>>> hg1 = dgl.heterograph({ ... ('user', 'plays', 'game') : (th.tensor([0, 1]), th.tensor([0, 0]))}) >>> hg2 = dgl.heterograph({ ... ('user', 'plays', 'game') : (th.tensor([0, 0, 0]), th.tensor([1, 0, 2]))}) >>> bhg = dgl.batch([hg1, hg2]) >>> bhg Graph(num_nodes={'user': 3, 'game': 4}, num_edges={('user', 'plays', 'game'): 5}, metagraph=[('drug', 'game')]) >>> bhg.batch_size 2 >>> bhg.batch_num_nodes() {'user' : tensor([2, 1]), 'game' : tensor([1, 3])} >>> bhg.batch_num_edges() {('user', 'plays', 'game') : tensor([2, 3])}
另请参阅