2.1 内建函数和消息传递API
在 DGL 中,消息函数接受一个参数 edges
,它是一个 EdgeBatch
实例。在消息传递过程中,DGL 内部生成它来表示一批边。它有三个成员 src
、dst
和 data
,分别用于访问源节点、目标节点和边的特征。
归约函数接受一个参数 nodes
,它是一个 NodeBatch
实例。在消息传递过程中,DGL 内部生成它来表示一批节点。它有成员 mailbox
用于访问该批节点收到的消息。一些最常见的归约操作包括 sum
、max
、min
等。
更新函数接受一个参数 nodes
,如上所述。此函数对 归约函数
的聚合结果进行操作,通常在最后一步将其与节点的原始特征结合,并将结果保存为节点特征。
DGL 在命名空间 dgl.function
中将常用的消息函数和归约函数实现为内建函数。通常,DGL 建议尽可能使用内建函数,因为它们经过大量优化,并自动处理维度广播。
如果你的消息传递函数无法使用内建函数实现,你可以实现用户自定义消息/归约函数(也称为 UDF)。
内建消息函数可以是一元的或二元的。DGL 支持一元函数 copy
。对于二元函数,DGL 支持 add
、sub
、mul
、div
、dot
。消息内建函数的命名约定是 u
代表 src
节点,v
代表 dst
节点,e
代表 edges
。这些函数的参数是字符串,表示对应节点和边的输入和输出字段名称。支持的内建函数列表可在 DGL 内建函数 中找到。例如,要将源节点的 hu
特征和目标节点的 hv
特征相加,然后将结果保存在边的 he
字段中,可以使用内建函数 dgl.function.u_add_v('hu', 'hv', 'he')
。这等同于以下消息 UDF:
def message_func(edges):
return {'he': edges.src['hu'] + edges.dst['hv']}
内建归约函数支持 sum
、max
、min
和 mean
操作。归约函数通常有两个参数,一个用于 mailbox
中的字段名,一个用于节点特征中的字段名,两者都是字符串。例如,dgl.function.sum('m', 'h')
等同于以下归约 UDF,它对消息 m
进行求和:
import torch
def reduce_func(nodes):
return {'h': torch.sum(nodes.mailbox['m'], dim=1)}
有关 UDF 的进阶用法,请参阅 用户自定义函数。
也可以仅通过 apply_edges()
调用逐边计算,而不触发消息传递。apply_edges()
将消息函数作为参数,默认更新所有边的特征。例如:
import dgl.function as fn
graph.apply_edges(fn.u_add_v('el', 'er', 'e'))
对于消息传递,update_all()
是一个高级 API,它在一个调用中合并了消息生成、消息聚合和节点更新,这为整体优化留下了空间。
应用于 update_all()
的参数是一个消息函数、一个归约函数和一个更新函数。你可以在 update_all
之外调用更新函数,并在调用 update_all()
时不指定它。DGL 推荐这种方法,因为更新函数通常可以写成纯粹的张量操作,使代码更简洁。例如:
def update_all_example(graph):
# store the result in graph.ndata['ft']
graph.update_all(fn.u_mul_e('ft', 'a', 'm'),
fn.sum('m', 'ft'))
# Call update function outside of update_all
final_ft = graph.ndata['ft'] * 2
return final_ft
此调用将通过乘以源节点特征 ft
和边特征 a
来生成消息 m
,对消息 m
求和以更新节点特征 ft
,最后将 ft
乘以 2 得到结果 final_ft
。调用后,DGL 将清除中间消息 m
。上述函数的数学公式是:
DGL 的内建函数支持浮点数据类型,即特征必须是 half
(float16
) /float
/double
张量。默认情况下禁用 float16
数据类型支持,因为它要求 GPU 计算能力至少达到 sm_53
(Pascal、Volta、Turing 和 Ampere 架构)。
用户可以通过从源代码编译 DGL 来启用 float16 进行混合精度训练(详细信息请参阅 混合精度训练 教程)。