dgl.to_block
- dgl.to_block(g, dst_nodes=None, include_dst_in_src=True, src_nodes=None)[source]
将图转换为二部结构的 块 用于消息传递。
块 是一个由两组节点组成的图: 源 节点和 目标 节点。源节点和目标节点可以有多种节点类型。所有边都从源节点连接到目标节点。
具体来说,源节点和目标节点将具有与原始图中相同的节点类型。DGL 将原始图中的每条边
(u, v)
及其边类型(utype, etype, vtype)
映射到源端类型为utype
的节点IDu
连接到目标端类型为vtype
的节点IDv
的类型为etype
的边。对于
to_block()
返回的块,块的目标节点将只包含至少有一条入边的任何类型的节点。块的源节点将只包含出现在目标节点中的节点,以及至少有一条出边连接到目标节点之一的节点。如果
dst_nodes
参数不为 None,则目标节点由该参数指定。- 参数:
graph (DGLGraph) – 图。可以在 CPU 或 GPU 上。
dst_nodes (Tensor or dict[str, Tensor], optional) –
目标节点列表。
如果给定的是 Tensor,则图必须只有一种节点类型。
如果给定,它必须是所有至少有一条入边的节点的超集。否则将引发错误。
include_dst_in_src (bool) –
如果为 False,则不将目标节点包含在源节点中。
(默认值: True)
src_nodes (Tensor or disct[str, Tensor], optional) –
源节点列表(如果 include_dst_in_src 为 True,则包含前缀为目标节点的节点)。
如果给定的是 Tensor,则图必须只有一种节点类型。
- 返回:
描述该块的新图。
块两侧每种类型的诱导节点ID将存储在特征
dgl.NID
中。每种类型的诱导边ID将存储在特征
dgl.EID
中。- 返回类型:
DGLBlock
- 引发:
DGLError – 如果指定了
dst_nodes
但它不是所有至少有一条入边的节点的超集。如果dst_nodes
不为 None,并且g
和dst_nodes
不在同一个上下文(context)中。
注意
to_block()
最常用于为大型图的随机训练定制邻居采样。请参考用户指南 第 6 章:大型图的随机训练 以获取关于随机训练方法的更详细讨论。另请参阅
create_block()
,以实现更灵活的块构建。示例
将同构图转换为上述描述的块
>>> g = dgl.graph(([1, 2], [2, 3])) >>> block = dgl.to_block(g, torch.LongTensor([3, 2]))
目标节点将与给定的一样:[3, 2]。
>>> induced_dst = block.dstdata[dgl.NID] >>> induced_dst tensor([3, 2])
最初的几个源节点也将与给定的一样。其余的节点是消息传递到节点 3, 2 所必需的节点。这意味着节点 1 将被包括在内。
>>> induced_src = block.srcdata[dgl.NID] >>> induced_src tensor([3, 2, 1])
你可以注意到前两个节点与给定节点以及目标节点相同。
诱导的边也可以通过以下方式获得
>>> block.edata[dgl.EID] tensor([2, 1])
这表明结果图中包含了边 (2, 3) 和 (1, 2)。你可以验证块中的第一条边确实映射到边 (2, 3),块中的第二条边确实映射到边 (1, 2)
>>> src, dst = block.edges(order='eid') >>> induced_src[src], induced_dst[dst] (tensor([2, 1]), tensor([3, 2]))
指定的目标节点必须是所有连接到它们的节点的超集。例如,以下代码将引发错误,因为目标节点不包含节点 3,而节点 3 有一条边连接到它。
>>> g = dgl.graph(([1, 2], [2, 3])) >>> dgl.to_block(g, torch.LongTensor([2])) # error
将异构图转换为块类似,只是在指定目标节点时,必须提供一个 dict
>>> g = dgl.heterograph({('A', '_E', 'B'): ([1, 2], [2, 3])})
如果在目标端没有指定任何类型 A 的节点,则块中类型
A
的节点在目标端将没有节点。>>> block = dgl.to_block(g, {'B': torch.LongTensor([3, 2])}) >>> block.number_of_dst_nodes('A') 0 >>> block.number_of_dst_nodes('B') 2 >>> block.dstnodes['B'].data[dgl.NID] tensor([3, 2])
源端将包含目标端的所有节点
>>> block.srcnodes['B'].data[dgl.NID] tensor([3, 2])
以及所有连接到目标端节点的节点
>>> block.srcnodes['A'].data[dgl.NID] tensor([2, 1])
另请参阅