dgl.DGLGraph.send_and_recv

DGLGraph.send_and_recv(edges, message_func, reduce_func, apply_node_func=None, etype=None)[source]

沿着指定边发送消息,并在目标节点上归约这些消息以更新其特征。

参数:
  • edges () –

    用于发送和接收消息的边。允许的输入格式有

    • int: 单个边 ID。

    • 整型 Tensor: 每个元素是一个边 ID。该 tensor 必须与图的设备类型和 ID 数据类型相同。

    • iterable[int]: 每个元素是一个边 ID。

    • (Tensor, Tensor): 节点-tensor 格式,其中两个 tensor 的第 i 个元素指定一条边。

    • (iterable[int], iterable[int]): 类似于节点-tensor 格式,但将边端点存储在 Python 可迭代对象中。

  • message_func (dgl.function.BuiltinFunctioncallable) – 沿边生成消息的消息函数。它必须是 DGL 内建函数用户自定义函数

  • reduce_func (dgl.function.BuiltinFunctioncallable) – 用于聚合消息的归约函数。它必须是 DGL 内建函数用户自定义函数

  • apply_node_func (callable, 可选) – 一个可选的应用函数,用于在消息归约后进一步更新节点特征。它必须是 用户自定义函数

  • etype (str(str, str, str), 可选) –

    边的类型名称。允许的类型名称格式有

    • (str, str, str) 表示源节点类型、边类型和目标节点类型。

    • 或一个 str 边类型名称,如果该名称可以唯一标识图中的三元组格式。

    如果图只有一种边类型,则可以省略。

注意事项

DGL 建议对 message_funcreduce_func 参数使用 DGL 的内建函数,因为在这种情况下 DGL 会调用高效的内核,避免将节点特征复制到边特征。

示例

>>> import dgl
>>> import dgl.function as fn
>>> import torch

同构图

>>> g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 4]))
>>> g.ndata['x'] = torch.ones(5, 2)
>>> # Specify edges using (Tensor, Tensor).
>>> g.send_and_recv(([1, 2], [2, 3]), fn.copy_u('x', 'm'), fn.sum('m', 'h'))
>>> g.ndata['h']
tensor([[0., 0.],
        [0., 0.],
        [1., 1.],
        [1., 1.],
        [0., 0.]])
>>> # Specify edges using IDs.
>>> g.send_and_recv([0, 2, 3], fn.copy_u('x', 'm'), fn.sum('m', 'h'))
>>> g.ndata['h']
tensor([[0., 0.],
        [1., 1.],
        [0., 0.],
        [1., 1.],
        [1., 1.]])

异构图

>>> g = dgl.heterograph({
...     ('user', 'follows', 'user'): ([0, 1], [1, 2]),
...     ('user', 'plays', 'game'): ([0, 1, 1, 2], [0, 0, 1, 1])
... })
>>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]])
>>> g.send_and_recv(g['follows'].edges(), fn.copy_u('h', 'm'),
...                 fn.sum('m', 'h'), etype='follows')
>>> g.nodes['user'].data['h']
tensor([[0.],
        [0.],
        [1.]])

使用用户自定义函数进行 ``send_and_recv``

>>> import torch as th
>>> g = dgl.graph(([0, 1], [1, 2]))
>>> g.ndata['x'] = th.tensor([[1.], [2.], [3.]])
>>> # Define the function for sending node features as messages.
>>> def send_source(edges):
...     return {'m': edges.src['x']}
>>> # Sum the messages received and use this to replace the original node feature.
>>> def simple_reduce(nodes):
...     return {'x': nodes.mailbox['m'].sum(1)}

发送和接收消息。

>>> g.send_and_recv(g.edges())
>>> g.ndata['x']
tensor([[1.],
        [1.],
        [2.]])

请注意,节点 0 的特征保持不变,因为它没有入边。