dgl.DGLGraph.apply_nodes

DGLGraph.apply_nodes(func, v='__ALL__', ntype=None)[源码]

使用提供的函数更新指定节点的特征。

参数:
  • func (callable) – 用于更新节点特征的函数。它必须是 用户自定义函数

  • v (节点 ID) –

    节点 ID。允许的格式如下:

    • int: 单个节点。

    • Int Tensor: 每个元素是一个节点 ID。Tensor 的设备类型和 ID 数据类型必须与图的相同。

    • iterable[int]: 每个元素是一个节点 ID。

    如果未给出(默认),则使用图中的所有节点。

  • ntype (str, 可选) – 节点类型名称。如果图中只有一种节点类型,则可以省略。

示例

以下示例使用 PyTorch 后端。

>>> import dgl
>>> import torch

同构图

>>> g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 4]))
>>> g.ndata['h'] = torch.ones(5, 2)
>>> g.apply_nodes(lambda nodes: {'x' : nodes.data['h'] * 2})
>>> g.ndata['x']
tensor([[2., 2.],
        [2., 2.],
        [2., 2.],
        [2., 2.],
        [2., 2.]])

异构图

>>> g = dgl.heterograph({('user', 'follows', 'user'): ([0, 1], [1, 2])})
>>> g.nodes['user'].data['h'] = torch.ones(3, 5)
>>> g.apply_nodes(lambda nodes: {'h': nodes.data['h'] * 2}, ntype='user')
>>> g.nodes['user'].data['h']
tensor([[2., 2., 2., 2., 2.],
        [2., 2., 2., 2., 2.],
        [2., 2., 2., 2., 2.]])

另请参阅

apply_edges