dgl.DGLGraph.multi_update_all
- DGLGraph.multi_update_all(etype_dict, cross_reducer, apply_node_func=None)[源代码]
沿所有边发送消息,首先按类型进行聚合,然后跨不同类型进行聚合,接着更新所有节点的节点特征。
- 参数:
etype_dict (dict) –
按边类型进行消息传递的参数。键是边类型,值是消息传递参数。
允许的键格式为
(str, str, str)
,表示源节点类型、边类型和目标节点类型。或一个
str
类型的边类型名称,如果该名称可以在图中唯一标识一个三元组格式。
值必须是一个元组
(message_func, reduce_func, [apply_node_func])
,其中cross_reducer (str 或 可调用函数) – 跨类型聚合器。可以是
"sum"
,"min"
,"max"
,"mean"
,"stack"
之一,或一个可调用函数。如果提供一个可调用函数,输入参数必须是一个包含来自每种边类型的聚合结果的张量列表,并且函数的输出必须是一个张量。apply_node_func (可调用对象, 可选的) – 在消息按类型和跨不同类型聚合后执行的可选应用函数。它必须是一个 用户自定义函数。
备注
DGL 建议在按类型进行消息传递参数中使用 DGL 的内置 message_func 和 reduce_func,因为在这种情况下 DGL 会调用高效的内核,避免将节点特征复制到边特征。
示例
>>> import dgl >>> import dgl.function as fn >>> import torch
实例化一个异构图。
>>> g = dgl.heterograph({ ... ('user', 'follows', 'user'): ([0, 1], [1, 1]), ... ('game', 'attracts', 'user'): ([0], [1]) ... }) >>> g.nodes['user'].data['h'] = torch.tensor([[1.], [2.]]) >>> g.nodes['game'].data['h'] = torch.tensor([[1.]])
执行多类型聚合更新。
>>> g.multi_update_all( ... {'follows': (fn.copy_u('h', 'm'), fn.sum('m', 'h')), ... 'attracts': (fn.copy_u('h', 'm'), fn.sum('m', 'h'))}, ... "sum") >>> g.nodes['user'].data['h'] tensor([[0.], [4.]])
用户自定义的跨类型聚合器,等价于 “sum”。
>>> def cross_sum(flist): ... return torch.sum(torch.stack(flist, dim=0), dim=0) if len(flist) > 1 else flist[0]
使用用户自定义的跨类型聚合器。
>>> g.multi_update_all( ... {'follows': (fn.copy_u('h', 'm'), fn.sum('m', 'h')), ... 'attracts': (fn.copy_u('h', 'm'), fn.sum('m', 'h'))}, ... cross_sum)