dgl.DGLGraph.pull
- DGLGraph.pull(v, message_func, reduce_func, apply_node_func=None, etype=None)[源代码]
从指定节点的前驱节点沿指定边类型拉取消息,然后聚合这些消息以更新节点特征。
- 参数:
v (节点 ID) –
节点 ID。允许的格式有:
int
: 单个节点。整型张量:每个元素都是一个节点 ID。张量必须与图的设备类型和 ID 数据类型相同。
可迭代对象 [int]:每个元素都是一个节点 ID。
message_func (dgl.function.BuiltinFunction 或 callable) – 沿边生成消息的消息函数。它必须是 DGL 内置函数 或 用户定义函数。
reduce_func (dgl.function.BuiltinFunction 或 callable) – 聚合消息的归约函数。它必须是 DGL 内置函数 或 用户定义函数。
apply_node_func (callable, 可选) – 一个可选的应用函数,用于在消息归约后进一步更新节点特征。它必须是 用户定义函数。
etype (str 或 (str, str, str), 可选) –
边的类型名称。允许的类型名称格式有:
(str, str, str)
,分别表示源节点类型、边类型和目标节点类型。或者一个
str
边类型名称,如果该名称在图中能唯一标识一个三元组格式。
如果图只有一种边类型,则可以省略。
备注
如果给定的一些节点
v
没有入边,DGL 不会为这些节点调用消息和归约函数,并将其聚合后的消息填充为零。用户可以通过set_n_initializer()
控制填充值。如果提供了apply_node_func
,DGL 仍然会调用它。DGL 推荐对
message_func
和reduce_func
参数使用 DGL 内置函数,因为在这种情况下,DGL 会调用高效的内核,避免将节点特征复制到边特征。
示例
>>> import dgl >>> import dgl.function as fn >>> import torch
同构图
>>> g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 4])) >>> g.ndata['x'] = torch.ones(5, 2) >>> g.pull([0, 3, 4], fn.copy_u('x', 'm'), fn.sum('m', 'h')) >>> g.ndata['h'] tensor([[0., 0.], [0., 0.], [0., 0.], [1., 1.], [1., 1.]])
异构图
>>> g = dgl.heterograph({ ... ('user', 'follows', 'user'): ([0, 1], [1, 2]), ... ('user', 'plays', 'game'): ([0, 2], [0, 1]) ... }) >>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]])
拉取。
>>> g['follows'].pull(2, fn.copy_u('h', 'm'), fn.sum('m', 'h'), etype='follows') >>> g.nodes['user'].data['h'] tensor([[0.], [1.], [1.]])