dgl.to_block

dgl.to_block(g, dst_nodes=None, include_dst_in_src=True, src_nodes=None)[source]

将图转换为二部结构的 用于消息传递。

块 是一个由两组节点组成的图: 节点和 目标 节点。源节点和目标节点可以有多种节点类型。所有边都从源节点连接到目标节点。

具体来说,源节点和目标节点将具有与原始图中相同的节点类型。DGL 将原始图中的每条边 (u, v)及其边类型 (utype, etype, vtype) 映射到源端类型为 utype 的节点ID u 连接到目标端类型为 vtype 的节点ID v 的类型为 etype 的边。

对于 to_block() 返回的块,块的目标节点将只包含至少有一条入边的任何类型的节点。块的源节点将只包含出现在目标节点中的节点,以及至少有一条出边连接到目标节点之一的节点。

如果 dst_nodes 参数不为 None,则目标节点由该参数指定。

参数:
  • graph (DGLGraph) – 图。可以在 CPU 或 GPU 上。

  • dst_nodes (Tensor or dict[str, Tensor], optional) –

    目标节点列表。

    如果给定的是 Tensor,则图必须只有一种节点类型。

    如果给定,它必须是所有至少有一条入边的节点的超集。否则将引发错误。

  • include_dst_in_src (bool) –

    如果为 False,则不将目标节点包含在源节点中。

    (默认值: True)

  • src_nodes (Tensor or disct[str, Tensor], optional) –

    源节点列表(如果 include_dst_in_src 为 True,则包含前缀为目标节点的节点)。

    如果给定的是 Tensor,则图必须只有一种节点类型。

返回:

描述该块的新图。

块两侧每种类型的诱导节点ID将存储在特征 dgl.NID 中。

每种类型的诱导边ID将存储在特征 dgl.EID 中。

返回类型:

DGLBlock

引发:

DGLError – 如果指定了 dst_nodes 但它不是所有至少有一条入边的节点的超集。如果 dst_nodes 不为 None,并且 gdst_nodes 不在同一个上下文(context)中。

注意

to_block() 最常用于为大型图的随机训练定制邻居采样。请参考用户指南 第 6 章:大型图的随机训练 以获取关于随机训练方法的更详细讨论。

另请参阅 create_block(),以实现更灵活的块构建。

示例

将同构图转换为上述描述的块

>>> g = dgl.graph(([1, 2], [2, 3]))
>>> block = dgl.to_block(g, torch.LongTensor([3, 2]))

目标节点将与给定的一样:[3, 2]。

>>> induced_dst = block.dstdata[dgl.NID]
>>> induced_dst
tensor([3, 2])

最初的几个源节点也将与给定的一样。其余的节点是消息传递到节点 3, 2 所必需的节点。这意味着节点 1 将被包括在内。

>>> induced_src = block.srcdata[dgl.NID]
>>> induced_src
tensor([3, 2, 1])

你可以注意到前两个节点与给定节点以及目标节点相同。

诱导的边也可以通过以下方式获得

>>> block.edata[dgl.EID]
tensor([2, 1])

这表明结果图中包含了边 (2, 3) 和 (1, 2)。你可以验证块中的第一条边确实映射到边 (2, 3),块中的第二条边确实映射到边 (1, 2)

>>> src, dst = block.edges(order='eid')
>>> induced_src[src], induced_dst[dst]
(tensor([2, 1]), tensor([3, 2]))

指定的目标节点必须是所有连接到它们的节点的超集。例如,以下代码将引发错误,因为目标节点不包含节点 3,而节点 3 有一条边连接到它。

>>> g = dgl.graph(([1, 2], [2, 3]))
>>> dgl.to_block(g, torch.LongTensor([2]))     # error

将异构图转换为块类似,只是在指定目标节点时,必须提供一个 dict

>>> g = dgl.heterograph({('A', '_E', 'B'): ([1, 2], [2, 3])})

如果在目标端没有指定任何类型 A 的节点,则块中类型 A 的节点在目标端将没有节点。

>>> block = dgl.to_block(g, {'B': torch.LongTensor([3, 2])})
>>> block.number_of_dst_nodes('A')
0
>>> block.number_of_dst_nodes('B')
2
>>> block.dstnodes['B'].data[dgl.NID]
tensor([3, 2])

源端将包含目标端的所有节点

>>> block.srcnodes['B'].data[dgl.NID]
tensor([3, 2])

以及所有连接到目标端节点的节点

>>> block.srcnodes['A'].data[dgl.NID]
tensor([2, 1])

另请参阅

create_block