通过网络中网络架构改进图神经网络
随着图神经网络 (GNN) 变得越来越流行,设计更深的 GNN 架构引起了广泛兴趣。然而,深度 GNN 存在过平滑问题,即随着层数增加,学习到的节点表示会迅速变得难以区分。本篇博客介绍了一种简单而有效的技术,可以在不担心过平滑问题的情况下构建深度 GNN。这种新架构称为图神经网络中的网络 (NGNN),灵感来自计算机视觉中的网络中网络架构,已在多个 Open Graph Benchmark (OGB) 排行榜上展现出卓越性能。
NGNN 简介
从高层次来看,图神经网络 (MPGNN) 层可以写成一个非线性函数
其中表示输入节点特征,表示输入图,表示下游任务使用的最后一层节点嵌入,表示 GNN 层数。此外,函数由可学习参数决定以及是一个非线性激活函数。
NGNN 没有增加更多的 GNN 层,而是通过在每个 GNN 层内插入非线性前馈神经网络层来加深 GNN 模型。
本质上,NGNN 只是对节点原始嵌入的非线性变换,位于层。尽管 NGNN 技术很简单,但它功能非常强大(稍后我们将详细介绍)。此外,它没有大的内存开销,并且可以与各种训练方法一起使用,例如邻居采样或子图采样。
其背后的直觉很简单。随着 GNN 层数和训练迭代次数的增加,同一连接分量内的节点表示将趋于收敛到相同的值。NGNN 在某些 GNN 层之后使用简单的 MLP 来解决所谓的过平滑问题。
在 Deep Graph Library (DGL) 中实现 NGNN
为了更好地理解这个技巧,让我们使用 DGL 来实现一个简单的 NGNN,并使用 GCN 层作为主干。
借助 DGL 内置的 GCN 层 dgl.nn.GraphConv
,我们可以轻松实现一个最小化的 NGNN_GCN
层,它仅在 GCN 层后应用一个激活函数和一个线性变换。
from dgl.nn import GraphConv
class NGNN_GCNConv(torch.nn.Module):
def __init__(self, input_channels, hidden_channels, output_channels):
super(NGNN_GCNConv, self).__init__()
self.conv = GraphConv(input_channels, hidden_channels)
self.fc = Linear(hidden_channels, output_channels)
def forward(self, g, x, edge_weight=None):
x = self.conv(g, x, edge_weight)
x = F.relu(x)
x = self.fc(x)
return x
之后,您可以简单地堆叠 dgl.nn.GraphConv
层和 NGNN_GCN
层,以构成一个多层的 NGNN_GCN
网络。
class NGNN_GCN(nn.Module):
def __init__(self, input_channels, hidden_channels, output_channels):
super(Model, self).__init__()
self.conv1 = NGNN_GCNConv(input_channels, hidden_channels, hidden_channels)
self.conv2 = GraphConv(hidden_channels, output_channels)
def forward(self, g, input_channels):
h = self.conv1(g, input_channels)
h = F.relu(h)
h = self.conv2(g, h)
return h
在 NGNN 架构中,您可以将 dgl.nn.GraphConv
替换为任何其他图卷积层。DGL 提供了许多流行卷积层和实用模块的实现。您可以用一行代码轻松调用它们,并构建自己的 NGNN 模块。
模型性能
NGNN 可用于许多下游任务,例如节点分类/回归、边分类/回归、链接预测和图分类。总的来说,NGNN 在这些任务上的结果优于其主干 GNN。例如,NGNN+SEAL 在 ogbl-ppa 排行榜上取得了第一名,与原版 SEAL 相比,Hit@100 提高了 。下表显示了 NGNN 相对于各种原版 GNN 主干的性能提升。
数据集 | 指标 | 模型 | 性能 | |
---|---|---|---|---|
ogbn-proteins | ROC-AUC(%) | GraphSage+聚类采样 | 原版 | 67.45 ± 1.21 |
+NGNN | 68.12 ± 0.96 | |||
ogbn-products | 准确率(%) | GraphSage | 原版 | 78.27 ± 0.45 |
+NGNN | 79.88 ± 0.34 | |||
GAT+邻居采样 | 原版 | 79.23 ± 0.16 | ||
+NGNN | 79.67 ± 0.09 | |||
ogbl-collab | hit@50(%) | GCN | 原版 | 49.52 ± 0.70 |
+NGNN | 53.48 ± 0.40 | |||
GraphSage | 原版 | 51.66 ± 0.35 | ||
+NGNN | 53.59 ± 0.56 | |||
ogbl-ppa | hit@100(%) | SEAL-DGCNN | 原版 | 48.80 ± 3.16 |
+NGNN | 59.71 ± 2.45 | |||
GCN | 原版 | 18.67 ± 1.32 | ||
+NGNN | 36.83 ± 0.99 |
延伸阅读
- NGNN 论文:https://arxiv.org/abs/2111.11638
- NGNN+SEAL OGB 提交:https://github.com/dmlc/dgl/tree/master/examples/pytorch/ogb/ngnn_seal
- NGNN+GraphSAGE OGB 提交:https://github.com/dmlc/dgl/tree/master/examples/pytorch/ogb/ngnn
- DGL 内置 GNN 模块列表:https://docs.dgl.ai/api/python/nn-pytorch.html
关于作者: Yakun Song 是上海交通大学的一名本科生。这项工作是在 AWS 上海人工智能实验室实习期间完成的。
11月 28日
作者:Yakun Song,来自 博客