dgl.DGLGraph.push

DGLGraph.push(u, message_func, reduce_func, apply_node_func=None, etype=None)[source]

沿指定边类型将消息从指定节点发送到其后继节点,并更新其节点特征。

参数:
  • v (节点 ID) –

    节点 ID。允许的格式有

    • int: 单个节点。

    • 整数张量 (Int Tensor): 每个元素是一个节点 ID。张量必须与图具有相同的设备类型和 ID 数据类型。

    • 可迭代对象 [int]: 每个元素是一个节点 ID。

  • message_func (dgl.function.BuiltinFunction 可调用对象 callable) – 沿边生成消息的消息函数。它必须是 DGL 内建函数用户定义函数

  • reduce_func (dgl.function.BuiltinFunction 可调用对象 callable) – 聚合消息的归约函数。它必须是 DGL 内建函数用户定义函数

  • apply_node_func (可调用对象 callable, 可选) – 在消息归约后进一步更新节点特征的可选应用函数。它必须是一个 用户定义函数

  • etype (str(str, str, str), 可选) –

    边类型名称。允许的类型名称格式有

    • (str, str, str) 用于指定源节点类型、边类型和目标节点类型。

    • 或者一个 str 边类型名称,如果该名称能在图中唯一标识一个三元组格式。

    如果图只有一种边类型,则可以省略。

注意事项

DGL 建议对 message_funcreduce_func 参数使用 DGL 的内建函数,因为在这种情况下,DGL 会调用高效的内核,避免将节点特征复制到边特征。

示例

>>> import dgl
>>> import dgl.function as fn
>>> import torch

同构图

>>> g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 4]))
>>> g.ndata['x'] = torch.ones(5, 2)
>>> g.push([0, 1], fn.copy_u('x', 'm'), fn.sum('m', 'h'))
>>> g.ndata['h']
tensor([[0., 0.],
        [1., 1.],
        [1., 1.],
        [0., 0.],
        [0., 0.]])

异构图

>>> g = dgl.heterograph({('user', 'follows', 'user'): ([0, 0], [1, 2])})
>>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]])

Push 操作。

>>> g['follows'].push(0, fn.copy_u('h', 'm'), fn.sum('m', 'h'), etype='follows')
>>> g.nodes['user'].data['h']
tensor([[0.],
        [0.],
        [0.]])