用户定义函数

用户定义函数(UDF)允许在消息传递(参见 第 2 章:消息传递)和使用 apply_edges() 更新边特征时进行任意计算。当 dgl.function 无法实现所需的计算时,它们带来了更大的灵活性。

边上的用户定义函数

可以使用边上的用户定义函数作为消息传递中的消息函数,或作为在 apply_edges() 中应用的函数。它将一批边作为输入,并为每条边返回消息(在消息传递中)或特征(在 apply_edges() 中)。该函数可以在计算中结合边的特征及其端点节点的特征。

形式上,它采用以下形式

def edge_udf(edges):
    """
    Parameters
    ----------
    edges : EdgeBatch
        A batch of edges.

    Returns
    -------
    dict[str, tensor]
        The messages or edge features generated. It maps a message/feature name to the
        corresponding messages/features of all edges in the batch. The order of the
        messages/features is the same as the order of the edges in the input argument.
    """

DGL 内部生成 EdgeBatch 实例,这些实例暴露了以下接口用于定义 edge_udf

EdgeBatch.src

返回批处理中边的源节点特征的视图。

EdgeBatch.dst

返回批处理中边的目标节点特征的视图。

EdgeBatch.data

返回批处理中边的边特征的视图。

EdgeBatch.edges()

返回批处理中的边。

EdgeBatch.batch_size()

返回批处理中的边数。

节点上的用户定义函数

可以使用节点上的用户定义函数作为消息传递中的归约函数。它将一批节点作为输入,并为每个节点返回更新后的特征。它可以结合当前节点特征和节点接收到的消息。形式上,它采用以下形式

def node_udf(nodes):
    """
    Parameters
    ----------
    nodes : NodeBatch
        A batch of nodes.

    Returns
    -------
    dict[str, tensor]
        The updated node features. It maps a feature name to the corresponding features of
        all nodes in the batch. The order of the nodes is the same as the order of the nodes
        in the input argument.
    """

DGL 内部生成 NodeBatch 实例,这些实例暴露了以下接口用于定义 node_udf

NodeBatch.data

返回批处理中节点的节点特征视图。

NodeBatch.mailbox

返回收到的消息视图。

NodeBatch.nodes()

返回批处理中的节点。

NodeBatch.batch_size()

返回批处理中的节点数。

使用用户定义函数进行消息传递的度分桶

DGL 采用度分桶机制来使用 UDF 进行消息传递。它将具有相同入度的节点分组,并为每组节点调用消息传递。因此,不应对 NodeBatch 实例的批处理大小做任何假设。

对于一批节点,DGL 沿着第二个维度堆叠每个节点的传入消息,并按边 ID 排序。示例如下

>>> import dgl
>>> import torch
>>> import dgl.function as fn
>>> g = dgl.graph(([1, 3, 5, 0, 4, 2, 3, 3, 4, 5], [1, 1, 0, 0, 1, 2, 2, 0, 3, 3]))
>>> g.edata['eid'] = torch.arange(10)
>>> def reducer(nodes):
...     print(nodes.mailbox['eid'])
...     return {'n': nodes.mailbox['eid'].sum(1)}
>>> g.update_all(fn.copy_e('eid', 'eid'), reducer)
tensor([[5, 6],
        [8, 9]])
tensor([[3, 7, 2],
        [0, 1, 4]])

本质上,节点 #2 和节点 #3 被分组到入度为 2 的桶中,而节点 #0 和节点 #1 被分组到入度为 3 的桶中。在每个桶内,边按每个节点的边 ID 进行排序。