博客详情

主页 / 博客详情
blog

通过网络中网络架构改进图神经网络

随着图神经网络 (GNN) 变得越来越流行,设计更深的 GNN 架构引起了广泛兴趣。然而,深度 GNN 存在过平滑问题,即随着层数增加,学习到的节点表示会迅速变得难以区分。本篇博客介绍了一种简单而有效的技术,可以在不担心过平滑问题的情况下构建深度 GNN。这种新架构称为图神经网络中的网络 (NGNN),灵感来自计算机视觉中的网络中网络架构,已在多个 Open Graph Benchmark (OGB) 排行榜上展现出卓越性能。

NGNN 简介

从高层次来看,图神经网络 (MPGNN) 层可以写成一个非线性函数

其中表示输入节点特征,表示输入图,表示下游任务使用的最后一层节点嵌入,表示 GNN 层数。此外,函数由可学习参数决定以及是一个非线性激活函数。

NGNN 没有增加更多的 GNN 层,而是通过在每个 GNN 层内插入非线性前馈神经网络层来加深 GNN 模型。

ngnn

本质上,NGNN 只是对节点原始嵌入的非线性变换,位于层。尽管 NGNN 技术很简单,但它功能非常强大(稍后我们将详细介绍)。此外,它没有大的内存开销,并且可以与各种训练方法一起使用,例如邻居采样或子图采样。

其背后的直觉很简单。随着 GNN 层数和训练迭代次数的增加,同一连接分量内的节点表示将趋于收敛到相同的值。NGNN 在某些 GNN 层之后使用简单的 MLP 来解决所谓的过平滑问题。

在 Deep Graph Library (DGL) 中实现 NGNN

为了更好地理解这个技巧,让我们使用 DGL 来实现一个简单的 NGNN,并使用 GCN 层作为主干。

借助 DGL 内置的 GCN 层 dgl.nn.GraphConv,我们可以轻松实现一个最小化的 NGNN_GCN 层,它仅在 GCN 层后应用一个激活函数和一个线性变换。

from dgl.nn import GraphConv

class NGNN_GCNConv(torch.nn.Module):
    def __init__(self, input_channels, hidden_channels, output_channels):
        super(NGNN_GCNConv, self).__init__()
        self.conv = GraphConv(input_channels, hidden_channels)
        self.fc = Linear(hidden_channels, output_channels)

    def forward(self, g, x, edge_weight=None):
        x = self.conv(g, x, edge_weight)
        x = F.relu(x)
        x = self.fc(x)
        return x

之后,您可以简单地堆叠 dgl.nn.GraphConv 层和 NGNN_GCN 层,以构成一个多层的 NGNN_GCN 网络。

class NGNN_GCN(nn.Module):
    def __init__(self, input_channels, hidden_channels, output_channels):
        super(Model, self).__init__()
        self.conv1 = NGNN_GCNConv(input_channels, hidden_channels, hidden_channels)
        self.conv2 = GraphConv(hidden_channels, output_channels)

    def forward(self, g, input_channels):
        h = self.conv1(g, input_channels)
        h = F.relu(h)
        h = self.conv2(g, h)
        return h

在 NGNN 架构中,您可以将 dgl.nn.GraphConv 替换为任何其他图卷积层。DGL 提供了许多流行卷积层和实用模块的实现。您可以用一行代码轻松调用它们,并构建自己的 NGNN 模块。

模型性能

NGNN 可用于许多下游任务,例如节点分类/回归、边分类/回归、链接预测和图分类。总的来说,NGNN 在这些任务上的结果优于其主干 GNN。例如,NGNN+SEAL 在 ogbl-ppa 排行榜上取得了第一名,与原版 SEAL 相比,Hit@100 提高了 。下表显示了 NGNN 相对于各种原版 GNN 主干的性能提升。

数据集 指标 模型 性能
ogbn-proteins ROC-AUC(%) GraphSage+聚类采样 原版 67.45 ± 1.21
+NGNN 68.12 ± 0.96
ogbn-products 准确率(%) GraphSage 原版 78.27 ± 0.45
+NGNN 79.88 ± 0.34
GAT+邻居采样 原版 79.23 ± 0.16
+NGNN 79.67 ± 0.09
ogbl-collab hit@50(%) GCN 原版 49.52 ± 0.70
+NGNN 53.48 ± 0.40
GraphSage 原版 51.66 ± 0.35
+NGNN 53.59 ± 0.56
ogbl-ppa hit@100(%) SEAL-DGCNN 原版 48.80 ± 3.16
+NGNN 59.71 ± 2.45
GCN 原版 18.67 ± 1.32
+NGNN 36.83 ± 0.99

延伸阅读

关于作者: Yakun Song 是上海交通大学的一名本科生。这项工作是在 AWS 上海人工智能实验室实习期间完成的。