dgl.DGLGraph.set_batch_num_edges

DGLGraph.set_batch_num_edges(val)[source]

手动设置批处理中每个图的指定边类型的边数。

参数:

val (Tensor or Mapping[str, Tensor]) – 字典,存储批处理中所有边类型下每个图的边数。如果图只有一种边类型,val 也可以是一个单独的数组,表示批处理中每个图的边数。

备注

此 API 通常与 set_batch_num_nodes 一起使用来指定图的批处理信息,它不检查图结构和批处理信息之间的一致性,用户必须保证批处理中不会出现跨图的边。

示例

以下示例使用 PyTorch 后端。

>>> import dgl
>>> import torch

创建一个同构图。

>>> g = dgl.graph(([0, 1, 2, 3, 4, 5], [1, 2, 0, 4, 5, 3]))

手动设置批处理信息

>>> g.set_batch_num_nodes(torch.tensor([3, 3]))
>>> g.set_batch_num_edges(torch.tensor([3, 3]))

取消图的批处理。

>>> dgl.unbatch(g)
[Graph(num_nodes=3, num_edges=3,
      ndata_schemes={}
      edata_schemes={}), Graph(num_nodes=3, num_edges=3,
      ndata_schemes={}
      edata_schemes={})]

创建一个异构图。

>>> hg = dgl.heterograph({
...      ('user', 'plays', 'game') : ([0, 1, 2, 3, 4, 5], [0, 1, 1, 3, 3, 2]),
...      ('developer', 'develops', 'game') : ([0, 1, 2, 3], [1, 0, 3, 2])})

手动设置批处理信息。

>>> hg.set_batch_num_nodes({
...     'user': torch.tensor([3, 3]),
...     'game': torch.tensor([2, 2]),
...     'developer': torch.tensor([2, 2])})
>>> hg.set_batch_num_edges(
...     {('user', 'plays', 'game'): torch.tensor([3, 3]),
...     ('developer', 'develops', 'game'): torch.tensor([2, 2])})

取消图的批处理。

>>> g1, g2 = dgl.unbatch(hg)
>>> g1
Graph(num_nodes={'developer': 2, 'game': 2, 'user': 3},
      num_edges={('developer', 'develops', 'game'): 2, ('user', 'plays', 'game'): 3},
      metagraph=[('developer', 'game', 'develops'), ('user', 'game', 'plays')])
>>> g2
Graph(num_nodes={'developer': 2, 'game': 2, 'user': 3},
      num_edges={('developer', 'develops', 'game'): 2, ('user', 'plays', 'game'): 3},
      metagraph=[('developer', 'game', 'develops'), ('user', 'game', 'plays')])