dgl.DGLGraph.prop_nodes
- DGLGraph.prop_nodes(nodes_generator, message_func, reduce_func, apply_node_func=None, etype=None)[源码]
通过按顺序触发节点上的
pull()
操作,使用图遍历来传播消息。遍历顺序由
nodes_generator
指定。它生成节点边界(node frontiers),即节点列表或张量。同一边界中的节点将一起被触发,而不同边界中的节点将按照生成顺序被触发。- 参数:
nodes_generator (iterable[节点 ID]) – 节点边界的生成器。每个边界是一组存储在 Tensor 或 python 可迭代对象中的节点 ID。它指定了在每一步中哪些节点执行
pull()
操作。message_func (dgl.function.BuiltinFunction 或 callable) – 用于沿边生成消息的消息函数。它必须是 DGL 内置函数 或 用户自定义函数。
reduce_func (dgl.function.BuiltinFunction 或 callable) – 用于聚合消息的归约函数。它必须是 DGL 内置函数 或 用户自定义函数。
apply_node_func (callable, 可选) – 一个可选的 apply 函数,用于在消息归约后进一步更新节点特征。它必须是 用户自定义函数。
etype (str 或 (str, str, str), 可选) –
边的类型名称。允许的类型名称格式为
(str, str, str)
,表示源节点类型、边类型和目标节点类型。或者一个
str
边类型名称,如果该名称可以在图中唯一标识一个三元组格式。
如果图只有一种边类型,则可以省略。
示例
>>> import torch >>> import dgl >>> import dgl.function as fn
实例化异构图并执行多轮消息传递。
>>> g = dgl.heterograph({('user', 'follows', 'user'): ([0, 1, 2, 3], [2, 3, 4, 4])}) >>> g.nodes['user'].data['h'] = torch.tensor([[1.], [2.], [3.], [4.], [5.]]) >>> g['follows'].prop_nodes([[2, 3], [4]], fn.copy_u('h', 'm'), ... fn.sum('m', 'h'), etype='follows') tensor([[1.], [2.], [1.], [2.], [3.]])
另请参阅