dgl.batch

dgl.batch(graphs, ndata='__ALL__', edata='__ALL__')[source]

DGLGraph 集合批量处理成一个图,以实现更高效的图计算。

每个输入图成为批量处理后的图的一个不相交的组件。节点和边被重新标记为不相交的段

原始节点 ID

0 ~ N_0

0 ~ N_1

0 ~ N_k

新节点 ID

0 ~ N_0

N_0 ~ N_0+N_1

sum_{i=0}^{k-1} N_i ~ sum_{i=0}^k N_i

因此,对批量处理后的图进行的许多计算与对每个图分别进行的计算相同,但由于可以轻松并行化而变得更加高效。这使得 dgl.batch 对于处理许多图样本的任务(例如图分类任务)非常有用。

对于异构图输入,它们必须共享相同的关系集(即节点类型和边类型),并且该函数将逐个对每种关系进行批量处理。因此,结果也是一个异构图,并且具有与输入相同的关系集。

输入图的节点数和边数可以通过结果图的 DGLGraph.batch_num_nodes()DGLGraph.batch_num_edges() 属性访问。对于同构图,它们是 1D 整数张量,每个元素是对应输入图的节点数/边数。对于异构图,它们是 1D 整数张量的字典,以节点类型或边类型作为键。

该函数支持对已批量处理的图进行批量处理。结果图的批大小是所有输入图的批大小之和。

默认情况下,通过连接所有输入图的特征张量来批量处理节点/边特征。因此,这要求同名的特征具有相同的数据类型和特征大小。可以将 None 传递给 ndataedata 参数以阻止特征批量处理,或者传递一个字符串列表以指定要批量处理的特征。

要将图反批量处理回列表,请使用 dgl.unbatch() 函数。

参数:
  • graphs (list[DGLGraph]) – 输入图。

  • ndata (list[str], None, optional) – 要批量处理的节点特征。

  • edata (list[str], None, optional) – 要批量处理的边特征。

返回值:

批量处理后的图。

返回类型:

DGLGraph

示例

批量处理同构图

>>> import dgl
>>> import torch as th
>>> # 4 nodes, 3 edges
>>> g1 = dgl.graph((th.tensor([0, 1, 2]), th.tensor([1, 2, 3])))
>>> # 3 nodes, 4 edges
>>> g2 = dgl.graph((th.tensor([0, 0, 0, 1]), th.tensor([0, 1, 2, 0])))
>>> bg = dgl.batch([g1, g2])
>>> bg
Graph(num_nodes=7, num_edges=7,
      ndata_schemes={}
      edata_schemes={})
>>> bg.batch_size
2
>>> bg.batch_num_nodes()
tensor([4, 3])
>>> bg.batch_num_edges()
tensor([3, 4])
>>> bg.edges()
(tensor([0, 1, 2, 4, 4, 4, 5], tensor([1, 2, 3, 4, 5, 6, 4]))

批量处理已批量处理的图

>>> bbg = dgl.batch([bg, bg])
>>> bbg.batch_size
4
>>> bbg.batch_num_nodes()
tensor([4, 3, 4, 3])
>>> bbg.batch_num_edges()
tensor([3, 4, 3, 4])

批量处理带有特征数据的图

>>> g1.ndata['x'] = th.zeros(g1.num_nodes(), 3)
>>> g1.edata['w'] = th.ones(g1.num_edges(), 2)
>>> g2.ndata['x'] = th.ones(g2.num_nodes(), 3)
>>> g2.edata['w'] = th.zeros(g2.num_edges(), 2)
>>> bg = dgl.batch([g1, g2])
>>> bg.ndata['x']
tensor([[0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [1, 1, 1],
        [1, 1, 1],
        [1, 1, 1]])
>>> bg.edata['w']
tensor([[1, 1],
        [1, 1],
        [1, 1],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0]])

批量处理异构图

>>> hg1 = dgl.heterograph({
...     ('user', 'plays', 'game') : (th.tensor([0, 1]), th.tensor([0, 0]))})
>>> hg2 = dgl.heterograph({
...     ('user', 'plays', 'game') : (th.tensor([0, 0, 0]), th.tensor([1, 0, 2]))})
>>> bhg = dgl.batch([hg1, hg2])
>>> bhg
Graph(num_nodes={'user': 3, 'game': 4},
      num_edges={('user', 'plays', 'game'): 5},
      metagraph=[('drug', 'game')])
>>> bhg.batch_size
2
>>> bhg.batch_num_nodes()
{'user' : tensor([2, 1]), 'game' : tensor([1, 3])}
>>> bhg.batch_num_edges()
{('user', 'plays', 'game') : tensor([2, 3])}

另请参阅

unbatch