dgl.DGLGraph.apply_edges

DGLGraph.apply_edges(func, edges='__ALL__', etype=None)[source]

使用提供的函数更新指定边的特征。

参数:
  • func (dgl.function.BuiltinFunctioncallable) – 用于生成新边特征的函数。它必须是 DGL 内置函数用户定义函数

  • edges (edges) –

    要更新特征的边。允许的输入格式有

    • int: 单个边 ID。

    • Int Tensor: 每个元素都是一个边 ID。张量必须与图具有相同的设备类型和 ID 数据类型。

    • iterable[int]: 每个元素都是一个边 ID。

    • (Tensor, Tensor): 节点-张量格式,其中两个张量的第 i 个元素指定一条边。

    • (iterable[int], iterable[int]): 类似于节点-张量格式,但将边端点存储在 Python 可迭代对象中。

    默认值指定图中的所有边。

  • etype (str(str, str, str), 可选) –

    边的类型名称。允许的类型名称格式有

    • (str, str, str) 分别表示源节点类型、边类型和目标节点类型。

    • 或者一个 str 边类型名称,如果该名称在图中可以唯一标识一个三元组格式。

    如果图只有一种边类型,则可以省略。

备注

DGL 建议对 func 参数使用 DGL 的内置函数,因为在这种情况下,DGL 会调用高效的内核,从而避免将节点特征复制到边特征。

示例

以下示例使用 PyTorch 后端。

>>> import dgl
>>> import torch

同构图

>>> g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 4]))
>>> g.ndata['h'] = torch.ones(5, 2)
>>> g.apply_edges(lambda edges: {'x' : edges.src['h'] + edges.dst['h']})
>>> g.edata['x']
tensor([[2., 2.],
        [2., 2.],
        [2., 2.],
        [2., 2.]])

使用内置函数

>>> import dgl.function as fn
>>> g.apply_edges(fn.u_add_v('h', 'h', 'x'))
>>> g.edata['x']
tensor([[2., 2.],
        [2., 2.],
        [2., 2.],
        [2., 2.]])

异构图

>>> g = dgl.heterograph({('user', 'plays', 'game'): ([0, 1, 1, 2], [0, 0, 2, 1])})
>>> g.edges[('user', 'plays', 'game')].data['h'] = torch.ones(4, 5)
>>> g.apply_edges(lambda edges: {'h': edges.data['h'] * 2})
>>> g.edges[('user', 'plays', 'game')].data['h']
tensor([[2., 2., 2., 2., 2.],
        [2., 2., 2., 2., 2.],
        [2., 2., 2., 2., 2.],
        [2., 2., 2., 2., 2.]])

另请参阅

apply_nodes