2.2 编写高效的消息传递代码
DGL 对消息传递的内存消耗和计算速度进行了优化。利用这些优化的一个常见做法是,将自己的消息传递功能构建为对 update_all()
函数的组合调用,并以内置函数作为参数。
此外,考虑到在某些图上边数远大于节点数,避免节点到边的冗余内存复制是有益的。对于 GATConv
等需要在边上保存消息的情况,需要调用带内置函数的 apply_edges()
。有时边上的消息维度很高,这会消耗大量内存。DGL 建议尽可能降低边特征的维度。
这里有一个例子,说明如何通过将边上的操作拆分到节点上来实现这一点。这种方法执行以下步骤:连接 src
特征和 dst
特征,然后应用一个线性层,即 \(W\times (u || v)\)。其中 src
和 dst
特征维度较高,而线性层输出维度较低。一个直接的实现方法是:
import torch
import torch.nn as nn
linear = nn.Parameter(torch.FloatTensor(size=(node_feat_dim * 2, out_dim)))
def concat_message_function(edges):
return {'cat_feat': torch.cat([edges.src['feat'], edges.dst['feat']], dim=1)}
g.apply_edges(concat_message_function)
g.edata['out'] = g.edata['cat_feat'] @ linear
建议的实现方法将线性操作拆分为两个,一个应用于 src
特征,另一个应用于 dst
特征。然后在最后阶段将这些线性操作的输出在边上相加,即执行 \(W_l\times u + W_r \times v\)。这是因为 \(W \times (u||v) = W_l \times u + W_r \times v\),其中 \(W_l\) 和 \(W_r\) 分别是矩阵 \(W\) 的左半部分和右半部分。
import dgl.function as fn
linear_src = nn.Parameter(torch.FloatTensor(size=(node_feat_dim, out_dim)))
linear_dst = nn.Parameter(torch.FloatTensor(size=(node_feat_dim, out_dim)))
out_src = g.ndata['feat'] @ linear_src
out_dst = g.ndata['feat'] @ linear_dst
g.srcdata.update({'out_src': out_src})
g.dstdata.update({'out_dst': out_dst})
g.apply_edges(fn.u_add_v('out_src', 'out_dst', 'out'))
上述两种实现方法在数学上是等价的。后一种方法更高效,因为它不需要在边上保存 feat_src 和 feat_dst,这在内存效率方面表现不佳。此外,加法可以通过 DGL 的内置函数 u_add_v()
进行优化,这进一步加快了计算速度并节省了内存占用。