dgl.DGLGraph.set_batch_num_nodes
- DGLGraph.set_batch_num_nodes(val)[source]
为批处理中的每个图手动设置指定节点类型的节点数。
- 参数:
val (Tensor 或 Mapping[str, Tensor]) – 存储批处理中所有节点类型下每个图节点数的字典。如果图仅有一种节点类型,
val
也可以是单个数组,指示批处理中每个图的节点数。
备注
此 API 总是与
set_batch_num_edges
一起使用来指定图的批处理信息,它也不检查图结构和批处理信息之间的对应关系,用户必须保证批处理中不会有跨图边。示例
以下示例使用 PyTorch 后端。
>>> import dgl >>> import torch
创建同构图。
>>> g = dgl.graph(([0, 1, 2, 3, 4, 5], [1, 2, 0, 4, 5, 3]))
手动设置批处理信息
>>> g.set_batch_num_nodes(torch.tensor([3, 3])) >>> g.set_batch_num_edges(torch.tensor([3, 3]))
解批处理图。
>>> dgl.unbatch(g) [Graph(num_nodes=3, num_edges=3, ndata_schemes={} edata_schemes={}), Graph(num_nodes=3, num_edges=3, ndata_schemes={} edata_schemes={})]
创建异构图。
>>> hg = dgl.heterograph({ ... ('user', 'plays', 'game') : ([0, 1, 2, 3, 4, 5], [0, 1, 1, 3, 3, 2]), ... ('developer', 'develops', 'game') : ([0, 1, 2, 3], [1, 0, 3, 2])})
手动设置批处理信息。
>>> hg.set_batch_num_nodes({ ... 'user': torch.tensor([3, 3]), ... 'game': torch.tensor([2, 2]), ... 'developer': torch.tensor([2, 2])}) >>> hg.set_batch_num_edges({ ... ('user', 'plays', 'game'): torch.tensor([3, 3]), ... ('developer', 'develops', 'game'): torch.tensor([2, 2])})
解批处理图。
>>> g1, g2 = dgl.unbatch(hg) >>> g1 Graph(num_nodes={'developer': 2, 'game': 2, 'user': 3}, num_edges={('developer', 'develops', 'game'): 2, ('user', 'plays', 'game'): 3}, metagraph=[('developer', 'game', 'develops'), ('user', 'game', 'plays')]) >>> g2 Graph(num_nodes={'developer': 2, 'game': 2, 'user': 3}, num_edges={('developer', 'develops', 'game'): 2, ('user', 'plays', 'game'): 3}, metagraph=[('developer', 'game', 'develops'), ('user', 'game', 'plays')])
另请参阅