dgl.DGLGraph.multi_update_all

DGLGraph.multi_update_all(etype_dict, cross_reducer, apply_node_func=None)[源代码]

沿所有边发送消息,首先按类型进行聚合,然后跨不同类型进行聚合,接着更新所有节点的节点特征。

参数:
  • etype_dict (dict) –

    按边类型进行消息传递的参数。键是边类型,值是消息传递参数。

    允许的键格式为

    • (str, str, str),表示源节点类型、边类型和目标节点类型。

    • 或一个 str 类型的边类型名称,如果该名称可以在图中唯一标识一个三元组格式。

    值必须是一个元组 (message_func, reduce_func, [apply_node_func]),其中

    • message_funcdgl.function.BuiltinFunction 或 可调用对象

      用于沿边生成消息的消息函数。它必须是 DGL 内置函数用户自定义函数

    • reduce_funcdgl.function.BuiltinFunction 或 可调用对象

      用于聚合消息的聚合函数。它必须是 DGL 内置函数用户自定义函数

    • apply_node_func可调用对象,可选的

      一个可选的应用函数,用于在消息聚合后进一步更新节点特征。它必须是一个 用户自定义函数

  • cross_reducer (str可调用函数) – 跨类型聚合器。可以是 "sum", "min", "max", "mean", "stack" 之一,或一个可调用函数。如果提供一个可调用函数,输入参数必须是一个包含来自每种边类型的聚合结果的张量列表,并且函数的输出必须是一个张量。

  • apply_node_func (可调用对象, 可选的) – 在消息按类型和跨不同类型聚合后执行的可选应用函数。它必须是一个 用户自定义函数

备注

DGL 建议在按类型进行消息传递参数中使用 DGL 的内置 message_func 和 reduce_func,因为在这种情况下 DGL 会调用高效的内核,避免将节点特征复制到边特征。

示例

>>> import dgl
>>> import dgl.function as fn
>>> import torch

实例化一个异构图。

>>> g = dgl.heterograph({
...     ('user', 'follows', 'user'): ([0, 1], [1, 1]),
...     ('game', 'attracts', 'user'): ([0], [1])
... })
>>> g.nodes['user'].data['h'] = torch.tensor([[1.], [2.]])
>>> g.nodes['game'].data['h'] = torch.tensor([[1.]])

执行多类型聚合更新。

>>> g.multi_update_all(
...     {'follows': (fn.copy_u('h', 'm'), fn.sum('m', 'h')),
...      'attracts': (fn.copy_u('h', 'm'), fn.sum('m', 'h'))},
... "sum")
>>> g.nodes['user'].data['h']
tensor([[0.],
        [4.]])

用户自定义的跨类型聚合器,等价于 “sum”。

>>> def cross_sum(flist):
...     return torch.sum(torch.stack(flist, dim=0), dim=0) if len(flist) > 1 else flist[0]

使用用户自定义的跨类型聚合器。

>>> g.multi_update_all(
...     {'follows': (fn.copy_u('h', 'm'), fn.sum('m', 'h')),
...      'attracts': (fn.copy_u('h', 'm'), fn.sum('m', 'h'))},
... cross_sum)