dgl.DGLGraph.update_all

DGLGraph.update_all(message_func, reduce_func, apply_node_func=None, etype=None)[source]

沿着指定类型的所有边发送消息,并更新所有相应目标类型的节点。

对于关系类型数量大于 1 的异构图,沿着所有边发送消息,按类型进行规约,同时也可以跨不同类型进行规约。然后,更新所有节点的节点特征。

参数:
  • message_func (dgl.function.BuiltinFunctioncallable) – 用于沿着边生成消息的消息函数。它必须是 DGL 内置函数用户定义函数

  • reduce_func (dgl.function.BuiltinFunctioncallable) – 用于聚合消息的规约函数。它必须是 DGL 内置函数用户定义函数

  • apply_node_func (callable, 可选) – 一个可选的 apply 函数,用于在消息规约后进一步更新节点特征。它必须是 用户定义函数

  • etype (str(str, str, str), 可选) – 边类型的名称。允许的类型名称格式为

    (str, str, str),分别代表源节点类型、边类型和目标节点类型。

    • 或者一个 str 边类型名称,如果该名称在图中能唯一标识一个三元组格式。

    • 如果图只有一种边类型,则可以省略此参数。

    如果图中的某些节点没有入边,DGL 不会为这些节点调用消息函数和规约函数,并将其聚合消息填充为零。用户可以通过 set_n_initializer() 控制填充值。如果提供了 apply_node_func,DGL 仍然会调用它。

注意事项

  • DGL 建议为 message_funcreduce_func 参数使用 DGL 的内置函数,因为在这种情况下,DGL 会调用高效的内核,避免将节点特征复制到边特征。

  • 示例

同构图

>>> import dgl
>>> import dgl.function as fn
>>> import torch

异构图

>>> g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 4]))
>>> g.ndata['x'] = torch.ones(5, 2)
>>> g.update_all(fn.copy_u('x', 'm'), fn.sum('m', 'h'))
>>> g.ndata['h']
tensor([[0., 0.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.]])

全量更新。

>>> g = dgl.heterograph({('user', 'follows', 'user'): ([0, 1, 2], [1, 2, 2])})

异构图(关系类型数量 > 1)

>>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]])
>>> g['follows'].update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'), etype='follows')
>>> g.nodes['user'].data['h']
tensor([[0.],
        [0.],
        [3.]])

>>> g = dgl.heterograph({
...     ('user', 'follows', 'user'): ([0, 1], [1, 1]),
...     ('game', 'attracts', 'user'): ([0], [1])
... })

异构图(关系类型数量 > 1)

>>> g.nodes['user'].data['h'] = torch.tensor([[1.], [2.]])
>>> g.nodes['game'].data['h'] = torch.tensor([[1.]])
>>> g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
>>> g.nodes['user'].data['h']
tensor([[0.],
        [4.]])