Write your own GNN module

Sometimes, your model goes beyond simply stacking existing GNN modules. For example, you may want to invent a new way of aggregating neighbor information by considering node importance or edge weights.

By the end of this tutorial you will be able to

  • Understand DGL's message passing APIs.

  • Implement the GraphSAGE convolution module by yourself.

This tutorial assumes that you already know the basics of training a GNN for node classification.

(Time estimate: 10 minutes)

import os

os.environ["DGLBACKEND"] = "pytorch"
import dgl
import dgl.function as fn
import torch
import torch.nn as nn
import torch.nn.functional as F

Message passing and GNNs

DGL follows the message passing paradigm inspired by the Message Passing Neural Network proposed by Gilmer et al. Essentially, they found that many GNN models can fit into the following framework:

\[m_{u\to v}^{(l)} = M^{(l)}\left(h_v^{(l-1)}, h_u^{(l-1)}, e_{u\to v}^{(l-1)}\right)\]
\[m_{v}^{(l)} = \sum_{u\in\mathcal{N}(v)}m_{u\to v}^{(l)}\]
\[h_v^{(l)} = U^{(l)}\left(h_v^{(l-1)}, m_v^{(l)}\right)\]

where DGL calls \(M^{(l)}\) the message function, \(\sum\) the reduce function, and \(U^{(l)}\) the update function. Note that \(\sum\) here can represent any function and is not necessarily a summation.

For example, the GraphSAGE convolution (Hamilton et al., 2017) takes the following mathematical form:

\[h_{\mathcal{N}(v)}^k\leftarrow \text{Average}\{h_u^{k-1},\forall u\in\mathcal{N}(v)\}\]
\[h_v^k\leftarrow \text{ReLU}\left(W^k\cdot \text{CONCAT}(h_v^{k-1}, h_{\mathcal{N}(v)}^k) \right)\]

You can see that message passing is directional: the message sent from a node \(u\) to a node \(v\) is not necessarily the same as the message sent in the opposite direction from node \(v\) to node \(u\).
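
To make this directionality concrete, here is a minimal sketch (the toy graph and feature values below are assumptions for illustration, not part of the tutorial's dataset): with a single directed edge from node 0 to node 1, only node 1 receives a message. If you want messages to flow in both directions, you need to add reverse edges, for instance with dgl.add_reverse_edges.

import dgl
import dgl.function as fn
import torch

# Hypothetical toy graph: a single directed edge 0 -> 1.
g = dgl.graph(([0], [1]))
g.ndata["h"] = torch.tensor([[1.0], [2.0]])
g.update_all(fn.copy_u("h", "m"), fn.sum("m", "h_N"))
print(g.ndata["h_N"])  # only node 1 receives a message: [[0.], [1.]]

# Adding reverse edges lets messages flow both ways.
bg = dgl.add_reverse_edges(g)
bg.update_all(fn.copy_u("h", "m"), fn.sum("m", "h_N"))
print(bg.ndata["h_N"])  # node 0 now also receives node 1's feature: [[2.], [1.]]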

Although DGL has built-in support for GraphSAGE via dgl.nn.SAGEConv, here is how you can implement the GraphSAGE convolution in DGL by yourself.

class SAGEConv(nn.Module):
    """Graph convolution module used by the GraphSAGE model.

    Parameters
    ----------
    in_feat : int
        Input feature size.
    out_feat : int
        Output feature size.
    """

    def __init__(self, in_feat, out_feat):
        super(SAGEConv, self).__init__()
        # A linear submodule for projecting the input and neighbor feature to the output.
        self.linear = nn.Linear(in_feat * 2, out_feat)

    def forward(self, g, h):
        """Forward computation

        Parameters
        ----------
        g : Graph
            The input graph.
        h : Tensor
            The input node feature.
        """
        with g.local_scope():
            g.ndata["h"] = h
            # update_all is a message passing API.
            g.update_all(
                message_func=fn.copy_u("h", "m"),
                reduce_func=fn.mean("m", "h_N"),
            )
            h_N = g.ndata["h_N"]
            h_total = torch.cat([h, h_N], dim=1)
            return self.linear(h_total)

The central piece in this code is the g.update_all function, which gathers and averages the neighbor features. There are three concepts here; a tiny runnable sketch after this list shows them on a toy graph:

  • The message function fn.copy_u('h', 'm') copies the node feature under name 'h' as messages with name 'm' sent to neighbors.

  • The reduce function fn.mean('m', 'h_N') averages all the received messages under name 'm' and saves the result as a new node feature 'h_N'.

  • update_all tells DGL to trigger the message and reduce functions for all the nodes and edges.
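
As a quick illustration (the tiny graph and feature values here are made up for this sketch), you can run the same copy_u / mean combination on its own and inspect the result:

import dgl
import dgl.function as fn
import torch

# Hypothetical toy graph: nodes 0 and 1 both point to node 2.
g = dgl.graph(([0, 1], [2, 2]))
g.ndata["h"] = torch.tensor([[1.0], [3.0], [5.0]])
g.update_all(fn.copy_u("h", "m"), fn.mean("m", "h_N"))
print(g.ndata["h_N"])  # node 2 averages its neighbors: (1 + 3) / 2 = 2; nodes 0 and 1 have no in-edges and get 0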

Afterwards, you can stack your own GraphSAGE convolution layers to form a multi-layer GraphSAGE network.

class Model(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(Model, self).__init__()
        self.conv1 = SAGEConv(in_feats, h_feats)
        self.conv2 = SAGEConv(h_feats, num_classes)

    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        h = self.conv2(g, h)
        return h

Training loop

The following code for data loading and the training loop is copied directly from the introduction tutorial.

import dgl.data

dataset = dgl.data.CoraGraphDataset()
g = dataset[0]


def train(g, model):
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    all_logits = []
    best_val_acc = 0
    best_test_acc = 0

    features = g.ndata["feat"]
    labels = g.ndata["label"]
    train_mask = g.ndata["train_mask"]
    val_mask = g.ndata["val_mask"]
    test_mask = g.ndata["test_mask"]
    for e in range(200):
        # Forward
        logits = model(g, features)

        # Compute prediction
        pred = logits.argmax(1)

        # Compute loss
        # Note that we should only compute the losses of the nodes in the training set,
        # i.e. with train_mask 1.
        loss = F.cross_entropy(logits[train_mask], labels[train_mask])

        # Compute accuracy on training/validation/test
        train_acc = (pred[train_mask] == labels[train_mask]).float().mean()
        val_acc = (pred[val_mask] == labels[val_mask]).float().mean()
        test_acc = (pred[test_mask] == labels[test_mask]).float().mean()

        # Save the best validation accuracy and the corresponding test accuracy.
        if best_val_acc < val_acc:
            best_val_acc = val_acc
            best_test_acc = test_acc

        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        all_logits.append(logits.detach())

        if e % 5 == 0:
            print(
                "In epoch {}, loss: {:.3f}, val acc: {:.3f} (best {:.3f}), test acc: {:.3f} (best {:.3f})".format(
                    e, loss, val_acc, best_val_acc, test_acc, best_test_acc
                )
            )


model = Model(g.ndata["feat"].shape[1], 16, dataset.num_classes)
train(g, model)
  NumNodes: 2708
  NumEdges: 10556
  NumFeats: 1433
  NumClasses: 7
  NumTrainingSamples: 140
  NumValidationSamples: 500
  NumTestSamples: 1000
Done loading data from cached files.
In epoch 0, loss: 1.948, val acc: 0.072 (best 0.072), test acc: 0.091 (best 0.091)
In epoch 5, loss: 1.881, val acc: 0.368 (best 0.368), test acc: 0.376 (best 0.376)
In epoch 10, loss: 1.752, val acc: 0.616 (best 0.616), test acc: 0.601 (best 0.601)
In epoch 15, loss: 1.555, val acc: 0.650 (best 0.650), test acc: 0.618 (best 0.618)
In epoch 20, loss: 1.299, val acc: 0.650 (best 0.658), test acc: 0.633 (best 0.618)
In epoch 25, loss: 1.011, val acc: 0.660 (best 0.662), test acc: 0.644 (best 0.634)
In epoch 30, loss: 0.731, val acc: 0.700 (best 0.700), test acc: 0.688 (best 0.688)
In epoch 35, loss: 0.492, val acc: 0.716 (best 0.716), test acc: 0.726 (best 0.726)
In epoch 40, loss: 0.313, val acc: 0.732 (best 0.732), test acc: 0.747 (best 0.739)
In epoch 45, loss: 0.194, val acc: 0.734 (best 0.736), test acc: 0.749 (best 0.747)
In epoch 50, loss: 0.121, val acc: 0.740 (best 0.740), test acc: 0.752 (best 0.749)
In epoch 55, loss: 0.078, val acc: 0.744 (best 0.744), test acc: 0.759 (best 0.759)
In epoch 60, loss: 0.052, val acc: 0.744 (best 0.744), test acc: 0.757 (best 0.759)
In epoch 65, loss: 0.038, val acc: 0.750 (best 0.750), test acc: 0.758 (best 0.758)
In epoch 70, loss: 0.029, val acc: 0.748 (best 0.750), test acc: 0.757 (best 0.758)
In epoch 75, loss: 0.023, val acc: 0.744 (best 0.750), test acc: 0.761 (best 0.758)
In epoch 80, loss: 0.019, val acc: 0.744 (best 0.750), test acc: 0.759 (best 0.758)
In epoch 85, loss: 0.016, val acc: 0.742 (best 0.750), test acc: 0.759 (best 0.758)
In epoch 90, loss: 0.014, val acc: 0.742 (best 0.750), test acc: 0.759 (best 0.758)
In epoch 95, loss: 0.012, val acc: 0.742 (best 0.750), test acc: 0.760 (best 0.758)
In epoch 100, loss: 0.011, val acc: 0.742 (best 0.750), test acc: 0.760 (best 0.758)
In epoch 105, loss: 0.010, val acc: 0.740 (best 0.750), test acc: 0.759 (best 0.758)
In epoch 110, loss: 0.009, val acc: 0.742 (best 0.750), test acc: 0.759 (best 0.758)
In epoch 115, loss: 0.008, val acc: 0.742 (best 0.750), test acc: 0.758 (best 0.758)
In epoch 120, loss: 0.008, val acc: 0.742 (best 0.750), test acc: 0.756 (best 0.758)
In epoch 125, loss: 0.007, val acc: 0.742 (best 0.750), test acc: 0.755 (best 0.758)
In epoch 130, loss: 0.007, val acc: 0.742 (best 0.750), test acc: 0.758 (best 0.758)
In epoch 135, loss: 0.006, val acc: 0.744 (best 0.750), test acc: 0.758 (best 0.758)
In epoch 140, loss: 0.006, val acc: 0.746 (best 0.750), test acc: 0.758 (best 0.758)
In epoch 145, loss: 0.006, val acc: 0.744 (best 0.750), test acc: 0.758 (best 0.758)
In epoch 150, loss: 0.005, val acc: 0.744 (best 0.750), test acc: 0.758 (best 0.758)
In epoch 155, loss: 0.005, val acc: 0.744 (best 0.750), test acc: 0.757 (best 0.758)
In epoch 160, loss: 0.005, val acc: 0.744 (best 0.750), test acc: 0.757 (best 0.758)
In epoch 165, loss: 0.004, val acc: 0.744 (best 0.750), test acc: 0.758 (best 0.758)
In epoch 170, loss: 0.004, val acc: 0.744 (best 0.750), test acc: 0.759 (best 0.758)
In epoch 175, loss: 0.004, val acc: 0.744 (best 0.750), test acc: 0.758 (best 0.758)
In epoch 180, loss: 0.004, val acc: 0.744 (best 0.750), test acc: 0.758 (best 0.758)
In epoch 185, loss: 0.004, val acc: 0.744 (best 0.750), test acc: 0.758 (best 0.758)
In epoch 190, loss: 0.004, val acc: 0.744 (best 0.750), test acc: 0.758 (best 0.758)
In epoch 195, loss: 0.003, val acc: 0.746 (best 0.750), test acc: 0.758 (best 0.758)

More customization

In DGL, we provide many built-in message and reduce functions under the dgl.function package. You can find more details in the API doc.
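
As a small sketch of a couple of other built-ins (the toy graph below is an assumption for illustration), fn.u_add_v builds each message from both endpoint features and fn.max aggregates incoming messages with an element-wise maximum:

import dgl
import dgl.function as fn
import torch

# Hypothetical toy graph for illustration.
g = dgl.graph(([0, 1, 2], [1, 2, 0]))
g.ndata["h"] = torch.randn(3, 4)
# Message: source feature plus destination feature; reduce: element-wise max.
g.update_all(fn.u_add_v("h", "h", "m"), fn.max("m", "h_max"))
print(g.ndata["h_max"].shape)  # torch.Size([3, 4])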

These APIs allow quick implementation of new graph convolution modules. For example, the following implements a new SAGEConv that aggregates neighbor representations using a weighted average. Note that the edata member can hold edge features, which can also take part in message passing.

class WeightedSAGEConv(nn.Module):
    """Graph convolution module used by the GraphSAGE model with edge weights.

    Parameters
    ----------
    in_feat : int
        Input feature size.
    out_feat : int
        Output feature size.
    """

    def __init__(self, in_feat, out_feat):
        super(WeightedSAGEConv, self).__init__()
        # A linear submodule for projecting the input and neighbor feature to the output.
        self.linear = nn.Linear(in_feat * 2, out_feat)

    def forward(self, g, h, w):
        """Forward computation

        Parameters
        ----------
        g : Graph
            The input graph.
        h : Tensor
            The input node feature.
        w : Tensor
            The edge weight.
        """
        with g.local_scope():
            g.ndata["h"] = h
            g.edata["w"] = w
            g.update_all(
                message_func=fn.u_mul_e("h", "w", "m"),
                reduce_func=fn.mean("m", "h_N"),
            )
            h_N = g.ndata["h_N"]
            h_total = torch.cat([h, h_N], dim=1)
            return self.linear(h_total)

Because the graph in this dataset has no edge weights, we manually assign all edge weights to one in the forward() function of the model. You can replace them with your own edge weights.

class Model(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(Model, self).__init__()
        self.conv1 = WeightedSAGEConv(in_feats, h_feats)
        self.conv2 = WeightedSAGEConv(h_feats, num_classes)

    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat, torch.ones(g.num_edges(), 1).to(g.device))
        h = F.relu(h)
        h = self.conv2(g, h, torch.ones(g.num_edges(), 1).to(g.device))
        return h


model = Model(g.ndata["feat"].shape[1], 16, dataset.num_classes)
train(g, model)
In epoch 0, loss: 1.950, val acc: 0.316 (best 0.316), test acc: 0.319 (best 0.319)
In epoch 5, loss: 1.892, val acc: 0.366 (best 0.366), test acc: 0.362 (best 0.362)
In epoch 10, loss: 1.771, val acc: 0.480 (best 0.500), test acc: 0.481 (best 0.481)
In epoch 15, loss: 1.579, val acc: 0.594 (best 0.594), test acc: 0.580 (best 0.580)
In epoch 20, loss: 1.321, val acc: 0.658 (best 0.658), test acc: 0.656 (best 0.656)
In epoch 25, loss: 1.022, val acc: 0.678 (best 0.678), test acc: 0.703 (best 0.703)
In epoch 30, loss: 0.724, val acc: 0.698 (best 0.698), test acc: 0.714 (best 0.712)
In epoch 35, loss: 0.472, val acc: 0.714 (best 0.714), test acc: 0.725 (best 0.725)
In epoch 40, loss: 0.289, val acc: 0.720 (best 0.720), test acc: 0.733 (best 0.733)
In epoch 45, loss: 0.173, val acc: 0.722 (best 0.724), test acc: 0.737 (best 0.732)
In epoch 50, loss: 0.105, val acc: 0.722 (best 0.724), test acc: 0.739 (best 0.732)
In epoch 55, loss: 0.067, val acc: 0.718 (best 0.724), test acc: 0.746 (best 0.732)
In epoch 60, loss: 0.045, val acc: 0.728 (best 0.728), test acc: 0.745 (best 0.745)
In epoch 65, loss: 0.032, val acc: 0.726 (best 0.728), test acc: 0.748 (best 0.745)
In epoch 70, loss: 0.025, val acc: 0.724 (best 0.728), test acc: 0.752 (best 0.745)
In epoch 75, loss: 0.020, val acc: 0.722 (best 0.728), test acc: 0.751 (best 0.745)
In epoch 80, loss: 0.016, val acc: 0.722 (best 0.728), test acc: 0.752 (best 0.745)
In epoch 85, loss: 0.014, val acc: 0.722 (best 0.728), test acc: 0.751 (best 0.745)
In epoch 90, loss: 0.012, val acc: 0.726 (best 0.728), test acc: 0.752 (best 0.745)
In epoch 95, loss: 0.011, val acc: 0.730 (best 0.730), test acc: 0.752 (best 0.752)
In epoch 100, loss: 0.010, val acc: 0.728 (best 0.730), test acc: 0.750 (best 0.752)
In epoch 105, loss: 0.009, val acc: 0.726 (best 0.730), test acc: 0.751 (best 0.752)
In epoch 110, loss: 0.008, val acc: 0.726 (best 0.730), test acc: 0.750 (best 0.752)
In epoch 115, loss: 0.008, val acc: 0.726 (best 0.730), test acc: 0.750 (best 0.752)
In epoch 120, loss: 0.007, val acc: 0.726 (best 0.730), test acc: 0.750 (best 0.752)
In epoch 125, loss: 0.007, val acc: 0.726 (best 0.730), test acc: 0.751 (best 0.752)
In epoch 130, loss: 0.006, val acc: 0.726 (best 0.730), test acc: 0.752 (best 0.752)
In epoch 135, loss: 0.006, val acc: 0.726 (best 0.730), test acc: 0.752 (best 0.752)
In epoch 140, loss: 0.005, val acc: 0.724 (best 0.730), test acc: 0.754 (best 0.752)
In epoch 145, loss: 0.005, val acc: 0.724 (best 0.730), test acc: 0.755 (best 0.752)
In epoch 150, loss: 0.005, val acc: 0.724 (best 0.730), test acc: 0.755 (best 0.752)
In epoch 155, loss: 0.005, val acc: 0.726 (best 0.730), test acc: 0.754 (best 0.752)
In epoch 160, loss: 0.004, val acc: 0.726 (best 0.730), test acc: 0.754 (best 0.752)
In epoch 165, loss: 0.004, val acc: 0.724 (best 0.730), test acc: 0.754 (best 0.752)
In epoch 170, loss: 0.004, val acc: 0.724 (best 0.730), test acc: 0.754 (best 0.752)
In epoch 175, loss: 0.004, val acc: 0.724 (best 0.730), test acc: 0.754 (best 0.752)
In epoch 180, loss: 0.003, val acc: 0.722 (best 0.730), test acc: 0.754 (best 0.752)
In epoch 185, loss: 0.003, val acc: 0.722 (best 0.730), test acc: 0.754 (best 0.752)
In epoch 190, loss: 0.003, val acc: 0.724 (best 0.730), test acc: 0.754 (best 0.752)
In epoch 195, loss: 0.003, val acc: 0.724 (best 0.730), test acc: 0.754 (best 0.752)

Even more customization by user-defined functions

DGL allows user-defined message and reduce functions for maximal expressiveness. Here is a user-defined message function that is equivalent to fn.u_mul_e('h', 'w', 'm').

def u_mul_e_udf(edges):
    return {"m": edges.src["h"] * edges.data["w"]}

edges has three members: src, data and dst, representing the source node features, edge features, and destination node features of all edges, respectively.

You can also write your own reduce function. For example, the following is equivalent to the built-in fn.mean('m', 'h_N') function that averages the incoming messages:

def mean_udf(nodes):
    return {"h_N": nodes.mailbox["m"].mean(1)}

In short, DGL groups the nodes by their in-degrees, and for each group it stacks the incoming messages along the second dimension. You can then perform a reduction along the second dimension to aggregate the messages.
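
As a sanity check (a minimal sketch; the random toy graph below is an assumption for illustration), you could verify that the user-defined functions above reproduce the built-in result:

import dgl
import dgl.function as fn
import torch

# Hypothetical toy graph with random node features and edge weights.
g = dgl.graph(([0, 1, 2], [1, 2, 0]))
g.ndata["h"] = torch.randn(3, 4)
g.edata["w"] = torch.rand(3, 1)

with g.local_scope():
    g.update_all(fn.u_mul_e("h", "w", "m"), fn.mean("m", "h_N"))
    built_in = g.ndata["h_N"]

with g.local_scope():
    g.update_all(u_mul_e_udf, mean_udf)  # the user-defined functions from above
    user_defined = g.ndata["h_N"]

print(torch.allclose(built_in, user_defined))  # True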

For more details on customizing message and reduce functions with user-defined functions, please refer to the API reference.

Best practices of writing custom GNN modules

DGL recommends the following practices, ranked by preference:

  • Use dgl.nn modules.

  • Use dgl.nn.functional functions, which contain lower-level complex operations such as computing a softmax over the incoming edges of each node (see the short sketch after this list).

  • Use update_all with built-in message and reduce functions.

  • Use user-defined message or reduce functions.
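
For the second item in the list, here is a minimal sketch of dgl.nn.functional.edge_softmax (the toy graph and scores are assumptions for illustration): it normalizes one score per edge over each destination node's incoming edges.

import dgl
import torch
from dgl.nn.functional import edge_softmax

# Hypothetical toy graph: three edges all pointing to node 2.
g = dgl.graph(([0, 1, 2], [2, 2, 2]))
scores = torch.tensor([[1.0], [2.0], [3.0]])   # one unnormalized score per edge
attn = edge_softmax(g, scores)                 # softmax over each node's incoming edges
print(attn.sum())                              # the three scores into node 2 sum to 1.0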

What's next?


Total running time of the script: (0 minutes 7.084 seconds)
