博客详情

主页 / 博客详情
blog

当内核融合遇见图神经网络

在去年12月发布的DGL第一个版本中,我们通过引入一套精心设计、易于使用的API来支持各种图神经网络模型的实现,从而专注于易用性。我们决定保持DGL的框架无关性,以便与来自不同平台(PyTorch,MXNet等)的用户互动。因此,在我们早期的版本中,我们很大程度上利用了这些框架提供的现有功能,并且根据用户提供的许多宝贵反馈,我们深知在处理稀疏和不规则图上的某些新模型方面仍有改进空间。

随着DGL第一个版本中的API受到好评并逐渐稳定,我们一直在全力提升其性能。我们的下一个主要版本(v0.3)将专注于显著提高训练速度、降低内存消耗和增强可扩展性。

即将发布的DGL v0.3在性能上有了实质性的提升。与当前版本相比,DGL v0.3的训练吞吐量提高了高达19倍,并且可以在单张GPU上训练8倍大的图。DGL的训练速度现在与Pytorch Geometric等其他框架具有竞争力,同时拥有更好的可扩展性。DGL允许在相当大的图上进行训练——高达5亿个节点和250亿条边。如需更具体的性能评估和比较,请查看我们的研讨会论文以获取更多详情。

在本文的其余部分,我们将深入技术细节,描述融合消息传递,这项实现这些性能改进的关键技术。我们将探讨以下问题:

  • 为什么基础消息传递无法扩展到大型图?
  • 融合消息传递如何提供帮助?
  • 如何在 DGL 中启用融合消息传递?

为什么基础消息传递无法扩展到大型图?

大多数GNN模型在图上执行消息传递风格的计算。这种计算包含两个主要的用户定义函数:

  • 消息函数 指定了如何沿一条边从一个节点发送消息给其邻居。这是边级别的计算,因为它针对所有(或一部分)边执行。
  • 归约函数 指定了如何聚合节点的传入消息并更新节点特征。这是节点级别的计算,因为它针对所有(或一部分)节点执行。

下图给出了一个示例。用户定义的消息函数表示为,用于在每条边上生成一条消息(黄色框)。为了创建边的消息,消息函数考虑边特征以及两个端点的节点特征,, 。在每个节点处,使用用户定义的归约函数聚合传入的消息,并使用另一个用户定义的函数更新节点特征。在DGL中,可以通过调用 sendrecv API 轻松实现这种消息传递(更多详情请参见我们的消息传递教程)。

实现消息传递的基本策略是直接的。首先,我们通过调用边级别的消息函数来 send 消息。然后,我们通过应用节点级别的计算来 recv 消息,根据它们的目标节点聚合消息(并更新节点特征)。在下面的DGL示例中,我们通过在lambda表达式中指定其消息和归约函数来实现了图卷积网络(GCN),DGL在底层使用了基础消息传递。

# A GCN example with user-defined message function.
# Using user-defined message function causes DGL to use 
# the basic message passing strategy.
G.update_all(lambda edges: {'m' : edges.src['h']}, 
             lambda nodes: {'h' : sum(nodes.mailbox['m'], axis=1)})

基础消息传递的问题在于消息被显式物化和存储,导致图变大时内存爆炸。举一个具体的例子,考虑GraphSAGE论文中介绍的reddit图数据集。它有23.2万个节点和1.14亿条边。如果我们训练一个GCN模型,其消息函数只是复制源节点特征,这将导致内存消耗比节点特征存储本身高出惊人的约500倍!更糟糕的是,频繁的内存访问可能会成为计算的瓶颈,导致GPU设备利用率不足。

融合消息传递 == 无显式消息

为了避免物化消息带来的开销,我们实现了融合消息传递,即将 sendrecv 组合成一个操作 send_and_recv (如下图所示)。在底层,融合的 send_and_recv 操作是使用CUDA内核实现的,其中每个线程将源节点特征加载到线程本地内存中,计算消息,并直接将其聚合到给定目标节点的缓冲区中,然后立即丢弃该消息。

要启用融合消息传递,需要解决两个挑战:

  • 如何融合用户定义的函数 在DGL中,我们提供了一组预定义函数,称为DGL内置函数,供用户选择。这限制了可用于融合消息传递的消息和归约函数,但我们提供了各种常见函数,因此大多数GNN模型都可以实现。UDF(用户定义函数)仍然允许使用,在这种情况下DGL将转而使用基础消息传递。
  • 如何在没有显式消息的情况下反向传播梯度? 技巧是在反向传播过程中重新计算消息,类似于训练非常深的神经网络模型中使用的技术。实际上,许多不需要显式消息来计算梯度(例如将节点特征复制为消息),我们的实现利用了这项优化。

例如,我们可以重写之前的GCN实现,使用内置的 copy_src 消息函数和 sum 归约函数,如下所示。

import dgl.function as fn
G = ... # some graph
# copy src feature 'h' as the message and sum it as new 'h'
G.update_all(fn.copy_src('h', 'm'), fn.sum('m', 'h'))

图注意力网络(GAT)也可以使用内置的 src_mul_edge 消息函数和 sum 归约函数来实现。

# the src features 'h' are weighted by the attention score 'e'
# on the edges and are summed as the new feature 'h'
G.update_all(fn.src_mul_edge('h', 'e', 'm'), fn.sum('m', 'h'))

DGL v0.3包含支持以下组合的内置函数:

  • 消息函数可以是 src、edge、dst 三个给定特征之间的加/减/乘/除中的任何一个。
  • 特征维度支持广播。例如,一个多头注意力模块,其中节点特征的形状为 (NUM_HEADS, NUM_FEATS),而注意力分数的形状为 (NUM_HEADS, 1)
  • 归约函数可以是 sum/max/min/prod 中的任何一个。

经验法则是尽可能使用内置函数,以便DGL可以使用融合消息传递以获得更好的性能。我们注意到这可能会带来一些编程负担,但回报绝对是值得的(评估结果请参见下一节)。

用数字来说服我

我们与v0.2版本进行比较,看看系统改进了多少。此外,我们还与PyG(Pytorch Geometric v1.2.0)进行了比较。PyG通过先 gather 节点特征作为边消息,然后 scatter 它们进行消息聚合来实现基础消息传递,这会生成显式消息。

我们按照其原始设置在几个流行数据集上对GCN和GAT模型进行了基准测试。测试平台是一台带有NVIDIA V100 GPU(16GB内存)的AWS p3.2xlarge实例。

数据集 #V #E 模型 DGL(v0.2) 时间(秒) PyG 时间(秒) DGL(v0.3) 时间(秒)
Cora 3K 11K GCN
GAT
0.685
9.727
0.482
1.248
0.619
1.389
Citeseer 3K 9K GCN
GAT
0.670
9.018
0.490
1.254
0.631
1.363
Pubmed 20K 889K GCN
GAT
0.694
26.186
0.485
1.509
0.603
1.381
Reddit 232K 114M GCN 内存不足 内存不足 25.30

即将发布的版本性能有了巨大的提升,尤其是在GAT模型上(得益于内核融合,速度提高了19倍)。与PyG相比,对于小型图(即Cora、Citeseer和Pubmed),计算和内存消耗保持不变,相对不受图大小影响。在这种情况下,图计算不是训练的瓶颈,DGL相对于PyG有一些轻微的、恒定的开销。然而,在对从Reddit中提取的更大图进行评估时,PyG内存不足,而DGL可以轻松容纳该图。

我们进一步使用合成图进行了一些消融研究。

我们首先改变图中的节点数量,保持图密度固定(0.0008),然后测试GCN和GAT的训练速度。DGL可以在节点数量多达50万的图上训练GCN,是PyG能处理的两倍大。在PyG能容纳的最大图上,PyG也比DGL慢3.4倍。

然后我们固定图中的节点数量(3.2万),但改变图的密度。结果清楚地显示了融合消息传递的优势。对于GCN和GAT,DGL可以在边数多达8倍的图上进行训练,并且在PyG能容纳的最大图上比PyG快7.5倍。

我们还改变隐藏层大小,并在中等大小的图(3.2万节点,密度0.0008)上比较性能。对于GCN,尽管PyG可以容纳我们测试的最大隐藏层大小,但它比DGL慢4倍。对于GAT,PyG无法训练隐藏层大小超过32的模型。

最后,我们突破极限,看看单台具有大CPU内存的机器(AWS x1.32xlarge实例,具有2TB内存)上可以训练多大的图。

#节点数 #边数 时间(秒) 内存(GB)
5M 250M 4.7 8
50M 2.5B 46 75
500M 25B 505 740

结果表明,DGL可以在节点数高达5亿、边数高达250亿的图上训练GCN模型。

下一步是什么?

DGL团队对未来计划中的一系列功能充满热情。事实上,其中许多自项目启动以来就已在我们心中。例如,内置函数一直是我们优先考虑的事项,但只有在内核融合的帮助下它们才能真正发光发热。为了阐明我们正在努力的方向:

  • 关于如何在大型CPU机器上重现本文中大型图实验的详细演示和教程。
  • 支持异构图。
  • 使用 GPU 加速图遍历和查询。

DGL始终与用户保持紧密联系,我们非常重视您的反馈!要尝试这项新功能,只需如此简单:克隆DGL仓库,切换到 kernel 分支并从源代码构建库。

请继续关注我们的下一个主要版本!更多精彩内容即将到来。