Relational Graph Convolutional Network

Author: Lingfan Yu, Mufei Li, Zheng Zhang

Warning

This tutorial is intended to give insight into the paper through its code. The implementation here is therefore not optimized for running efficiency. For a recommended implementation, please refer to the official examples.

In this tutorial, you will learn how to implement a relational graph convolutional network (R-GCN). This type of network is one effort to generalize GCN to handle different relations between entities in a knowledge base. To learn more about the research behind R-GCN, see Modeling Relational Data with Graph Convolutional Networks.

A straightforward graph convolutional network (GCN) exploits the structural information of a dataset (that is, the graph connectivity) to improve the extraction of node representations. Graph edges are left untyped.

A knowledge graph is made up of a collection of triples of the form subject, relation, object. Edges thus encode important information and have their own embeddings to be learned. Furthermore, multiple edges may exist between any given pair of nodes.
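
For example, a handful of such triples can be represented as a DGL heterograph, with one edge type per relation. A minimal sketch with made-up entities and relations:

import dgl
import torch

# two toy triples sharing subject entity 0; names are illustrative only
triples = {
    ("person", "educated_at", "school"): (torch.tensor([0]), torch.tensor([0])),
    ("person", "lived_in", "country"): (torch.tensor([0]), torch.tensor([0])),
}
kg = dgl.heterograph(triples)
print(kg.canonical_etypes)  # one canonical type per kind of edge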

A brief introduction to R-GCN

In statistical relational learning (SRL), there are two fundamental tasks:

  • Entity classification - where you assign types and categorical properties to entities.

  • Link prediction - where you recover missing triples.

In both cases, the missing information is expected to be recovered from the neighborhood structure of the graph. For example, the R-GCN paper cited earlier provides the following example: knowing that Mikhail Baryshnikov was educated at the Vaganova Academy implies both that Mikhail Baryshnikov should have the label person, and that the triple (Mikhail Baryshnikov, lived in, Russia) must belong to the knowledge graph.

R-GCN solves both problems with a common graph convolutional network, extended with multi-edge encoding to compute the embeddings of the entities, but with different downstream processing:

  • Entity classification is done by attaching a softmax classifier to the final embedding of an entity (node). Training uses a standard cross-entropy loss.

  • Link prediction is done by reconstructing an edge with an autoencoder architecture, using a parameterized score function (see the sketch after this list). Training uses negative sampling.
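
The score function used in the paper is DistMult, which scores a triple \((s, r, o)\) as \(e_s^T \mathrm{diag}(e_r)\, e_o\). A minimal hedged sketch of such a decoder:

import torch

def distmult_score(s_emb, r_emb, o_emb):
    # f(s, r, o) = e_s^T diag(e_r) e_o, computed batch-wise
    return (s_emb * r_emb * o_emb).sum(dim=-1)

# toy batch of 4 triples with 8-dimensional embeddings
s, r, o = (torch.randn(4, 8) for _ in range(3))
print(distmult_score(s, r, o).shape)  # torch.Size([4])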

This tutorial focuses on the first task, entity classification, to show how to generate entity representations. Complete code for both tasks can be found in the DGL Github repository.

Key ideas of R-GCN

Recall that in GCN, the hidden representation of each node \(i\) at layer \((l+1)\) is computed as

\[h_i^{(l+1)} = \sigma\left(\sum_{j\in N_i}\frac{1}{c_i} W^{(l)} h_j^{(l)}\right) \qquad (1)\]

where \(c_i\) is a normalization constant.

The key difference between R-GCN and GCN is that in R-GCN, edges can represent different relations. In GCN, the weight \(W^{(l)}\) in equation \((1)\) is shared by all edges in layer \(l\). In contrast, in R-GCN, different edge types use different weights, and only edges of the same relation type \(r\) are associated with the same projection weight \(W_r^{(l)}\).

So the hidden representation of entities at layer \((l+1)\) in R-GCN can be formulated as the following equation:

\[h_i^{(l+1)} = \sigma\left(W_0^{(l)}h_i^{(l)}+\sum_{r\in R}\sum_{j\in N_i^r}\frac{1}{c_{i,r}}W_r^{(l)}h_j^{(l)}\right) \qquad (2)\]

where \(N_i^r\) denotes the set of neighbor indices of node \(i\) under relation \(r\in R\), and \(c_{i,r}\) is a normalization constant. For entity classification, the R-GCN paper uses \(c_{i,r}=|N_i^r|\).
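
To make equation (2) concrete, here is a naive dense sketch with toy tensors, assuming one adjacency matrix per relation (the DGL implementation later in this tutorial is sparse and message-passing based):

import torch

# Naive dense version of equation (2) with c_{i,r} = |N_i^r|.
# h: (N, d_in), W0: (d_in, d_out), W: (R, d_in, d_out),
# adj: (R, N, N), where adj[r, i, j] = 1 iff j is in N_i^r.
def rgcn_layer_dense(h, W0, W, adj):
    out = h @ W0  # self-connection term W_0^{(l)} h_i^{(l)}
    for r in range(W.shape[0]):
        deg = adj[r].sum(dim=1, keepdim=True).clamp(min=1)  # c_{i,r}
        out = out + (adj[r] @ (h @ W[r])) / deg
    return torch.relu(out)  # sigma = ReLU here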

直接应用上述方程的问题是参数数量的快速增长,特别是在高度多关系数据中。为了减少模型参数大小并防止过拟合,原论文提出使用基分解。

\[W_r^{(l)}=\sum\limits_{b=1}^B a_{rb}^{(l)}V_b^{(l)} \qquad (3)\]

Therefore, the weight \(W_r^{(l)}\) is a linear combination of basis transformations \(V_b^{(l)}\) with coefficients \(a_{rb}^{(l)}\). The number of bases \(B\) is much smaller than the number of relations in the knowledge base.
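
In code, equation (3) is a single tensor contraction over the basis dimension. A small sketch with toy sizes:

import torch

B, num_rels, in_feat, out_feat = 2, 4, 8, 8
V = torch.randn(B, in_feat, out_feat)  # basis transformations V_b
a = torch.randn(num_rels, B)           # coefficients a_rb
W = torch.einsum("rb,bio->rio", a, V)  # W_r = sum_b a_rb V_b
print(W.shape)                         # torch.Size([4, 8, 8])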

Note

Another weight regularization, block-decomposition, is implemented in the link prediction task.

Implement R-GCN in DGL

An R-GCN model is composed of several R-GCN layers. The first R-GCN layer also serves as the input layer: it takes in features associated with the node entities (for example, description texts) and projects them to a hidden space. In this tutorial, only the entity ID is used as an entity feature.
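
Using only entity IDs means the input is an implicit one-hot vector, so the input-layer projection reduces to a row lookup of the weight matrix, a trick the message function below exploits. A tiny check of this equivalence:

import torch

W = torch.randn(5, 4)               # (num_entities, hidden_dim)
x = torch.zeros(5)
x[2] = 1.0                          # one-hot encoding of entity ID 2
print(torch.allclose(x @ W, W[2]))  # True: matmul equals embedding lookup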

R-GCN layers

For each node, an R-GCN layer performs the following steps:

  • Compute an outgoing message using the node representation and the weight matrix associated with the edge type (message function)

  • Aggregate incoming messages and generate a new node representation (reduce and apply functions)

The following code is the definition of an R-GCN hidden layer.

Note

Each relation type is associated with a different weight. Therefore, the full weight matrix has three dimensions: relation, input feature, output feature.

Note

This section shows how to implement an R-GCN from scratch. DGL provides a more efficient builtin R-GCN layer module; a sketch of its usage follows.
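
For instance, the builtin layer can be used roughly as follows; this is a hedged sketch on a toy homogeneous graph, passing the per-edge relation IDs explicitly:

import dgl
import torch
from dgl.nn import RelGraphConv

g = dgl.graph(([0, 1, 2], [1, 2, 0]))  # toy graph with 3 nodes
etypes = torch.tensor([0, 1, 0])       # relation ID of each edge
conv = RelGraphConv(10, 16, num_rels=2, regularizer="basis", num_bases=2)
h = conv(g, torch.randn(3, 10), etypes)
print(h.shape)                         # torch.Size([3, 16])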

import os

os.environ["DGLBACKEND"] = "pytorch"
from functools import partial

import dgl
import dgl.function as fn
import torch
import torch.nn as nn
import torch.nn.functional as F


class RGCNLayer(nn.Module):
    def __init__(
        self,
        in_feat,
        out_feat,
        num_rels,
        num_bases=-1,
        bias=None,
        activation=None,
        is_input_layer=False,
    ):
        super(RGCNLayer, self).__init__()
        self.in_feat = in_feat
        self.out_feat = out_feat
        self.num_rels = num_rels
        self.num_bases = num_bases
        self.bias = bias
        self.activation = activation
        self.is_input_layer = is_input_layer

        # sanity check
        if self.num_bases <= 0 or self.num_bases > self.num_rels:
            self.num_bases = self.num_rels
        # weight bases in equation (3)
        self.weight = nn.Parameter(
            torch.Tensor(self.num_bases, self.in_feat, self.out_feat)
        )
        if self.num_bases < self.num_rels:
            # linear combination coefficients in equation (3)
            self.w_comp = nn.Parameter(
                torch.Tensor(self.num_rels, self.num_bases)
            )
        # add bias; keep it as a parameter or None, so the checks below
        # never call bool() on a multi-element tensor
        if self.bias:
            self.bias = nn.Parameter(torch.Tensor(out_feat))
        else:
            self.bias = None
        # init trainable parameters
        nn.init.xavier_uniform_(
            self.weight, gain=nn.init.calculate_gain("relu")
        )
        if self.num_bases < self.num_rels:
            nn.init.xavier_uniform_(
                self.w_comp, gain=nn.init.calculate_gain("relu")
            )
        if self.bias is not None:
            # xavier_uniform_ requires at least 2 dims; initialize the
            # 1-D bias to zeros instead
            nn.init.zeros_(self.bias)

    def forward(self, g):
        if self.num_bases < self.num_rels:
            # generate all weights from bases (equation (3)):
            # W_r = sum_b a_rb V_b, computed as a (num_rels, num_bases) x
            # (num_bases, in_feat * out_feat) matrix product
            weight = self.weight.view(
                self.num_bases, self.in_feat * self.out_feat
            )
            weight = torch.matmul(self.w_comp, weight).view(
                self.num_rels, self.in_feat, self.out_feat
            )
        else:
            weight = self.weight
        if self.is_input_layer:

            def message_func(edges):
                # for input layer, matrix multiply can be converted to be
                # an embedding lookup using source node id
                embed = weight.view(-1, self.out_feat)
                index = edges.data[dgl.ETYPE] * self.in_feat + edges.src["id"]
                return {"msg": embed[index] * edges.data["norm"]}

        else:

            def message_func(edges):
                w = weight[edges.data[dgl.ETYPE]]
                msg = torch.bmm(edges.src["h"].unsqueeze(1), w).squeeze(1)
                msg = msg * edges.data["norm"]
                return {"msg": msg}

        def apply_func(nodes):
            h = nodes.data["h"]
            if self.bias is not None:
                h = h + self.bias
            if self.activation:
                h = self.activation(h)
            return {"h": h}

        g.update_all(message_func, fn.sum(msg="msg", out="h"), apply_func)
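
As a quick sanity check (a hedged sketch with toy values), the layer can run on a small homogeneous graph; it expects relation IDs under edges.data[dgl.ETYPE], a per-edge "norm" field, and, for the input layer, node IDs under "id":

# toy graph: 3 nodes, 3 edges, 2 relation types
g = dgl.graph(([0, 1, 2], [1, 2, 0]))
g.edata[dgl.ETYPE] = torch.tensor([0, 1, 0])  # relation ID per edge
g.edata["norm"] = torch.ones(3, 1)            # 1 / c_{i,r}, all ones here
g.ndata["id"] = torch.arange(3)               # entity IDs as features
layer = RGCNLayer(3, 4, num_rels=2, is_input_layer=True)
layer(g)
print(g.ndata["h"].shape)                     # torch.Size([3, 4])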

Full R-GCN model definition

class Model(nn.Module):
    def __init__(
        self,
        num_nodes,
        h_dim,
        out_dim,
        num_rels,
        num_bases=-1,
        num_hidden_layers=1,
    ):
        super(Model, self).__init__()
        self.num_nodes = num_nodes
        self.h_dim = h_dim
        self.out_dim = out_dim
        self.num_rels = num_rels
        self.num_bases = num_bases
        self.num_hidden_layers = num_hidden_layers

        # create rgcn layers
        self.build_model()

        # create initial features
        self.features = self.create_features()

    def build_model(self):
        self.layers = nn.ModuleList()
        # input to hidden
        i2h = self.build_input_layer()
        self.layers.append(i2h)
        # hidden to hidden
        for _ in range(self.num_hidden_layers):
            h2h = self.build_hidden_layer()
            self.layers.append(h2h)
        # hidden to output
        h2o = self.build_output_layer()
        self.layers.append(h2o)

    # initialize feature for each node
    def create_features(self):
        features = torch.arange(self.num_nodes)
        return features

    def build_input_layer(self):
        return RGCNLayer(
            self.num_nodes,
            self.h_dim,
            self.num_rels,
            self.num_bases,
            activation=F.relu,
            is_input_layer=True,
        )

    def build_hidden_layer(self):
        return RGCNLayer(
            self.h_dim,
            self.h_dim,
            self.num_rels,
            self.num_bases,
            activation=F.relu,
        )

    def build_output_layer(self):
        return RGCNLayer(
            self.h_dim,
            self.out_dim,
            self.num_rels,
            self.num_bases,
            activation=partial(F.softmax, dim=1),
        )

    def forward(self, g):
        if self.features is not None:
            g.ndata["id"] = self.features
        for layer in self.layers:
            layer(g)
        return g.ndata.pop("h")

Handle the dataset

This tutorial uses the Institute for Applied Informatics and Formal Description Methods (AIFB) dataset from the R-GCN paper.

# load graph data
dataset = dgl.data.rdf.AIFBDataset()
g = dataset[0]
category = dataset.predict_category
train_mask = g.nodes[category].data.pop("train_mask")
test_mask = g.nodes[category].data.pop("test_mask")
train_idx = torch.nonzero(train_mask, as_tuple=False).squeeze()
test_idx = torch.nonzero(test_mask, as_tuple=False).squeeze()
labels = g.nodes[category].data.pop("label")
num_rels = len(g.canonical_etypes)
num_classes = dataset.num_classes
# normalization factor
for cetype in g.canonical_etypes:
    g.edges[cetype].data["norm"] = dgl.norm_by_dst(g, cetype).unsqueeze(1)
category_id = g.ntypes.index(category)
Downloading /root/.dgl/aifb-hetero.zip from https://data.dgl.ai/dataset/rdf/aifb-hetero.zip...

/root/.dgl/aifb-hetero.zip: 100%|██████████| 344k/344k [00:00<00:00, 14.7MB/s]
Extracting file to /root/.dgl/aifb-hetero_82d021d8
Parsing file aifbfixed_complete.n3 ...
Processed 0 tuples, found 0 valid tuples.
Processed 10000 tuples, found 8406 valid tuples.
Processed 20000 tuples, found 16622 valid tuples.
Adding reverse edges ...
Creating one whole graph ...
Total #nodes: 7262
Total #edges: 48810
Convert to heterograph ...
#Node types: 7
#Canonical edge types: 104
#Unique edge type names: 78
Load training/validation/testing split ...
Done saving data into cached files.
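
The "norm" field computed above is the \(1/c_{i,r}\) factor of equation (2): dgl.norm_by_dst assigns each edge one over the in-degree of its destination node under that relation. A hedged manual check for a single canonical edge type, assuming plain inverse in-degree:

# recompute the normalization for one relation by hand
cetype = g.canonical_etypes[0]
_, dst = g.edges(etype=cetype)
deg = g.in_degrees(etype=cetype).float()  # |N_i^r| per destination node
manual = (1.0 / deg[dst]).unsqueeze(1)
print(torch.allclose(manual, g.edges[cetype].data["norm"]))  # expect True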

Create the graph and model

# configurations
n_hidden = 16  # number of hidden units
n_bases = -1  # use number of relations as number of bases
n_hidden_layers = 0  # use 1 input layer, 1 output layer, no hidden layer
n_epochs = 25  # epochs to train
lr = 0.01  # learning rate
l2norm = 0  # L2 norm coefficient

# create graph
g = dgl.to_homogeneous(g, edata=["norm"])
node_ids = torch.arange(g.num_nodes())
target_idx = node_ids[g.ndata[dgl.NTYPE] == category_id]

# create model
model = Model(
    g.num_nodes(),
    n_hidden,
    num_classes,
    num_rels,
    num_bases=n_bases,
    num_hidden_layers=n_hidden_layers,
)
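
After dgl.to_homogeneous, the original type information survives as integer vectors: g.ndata[dgl.NTYPE] holds per-node type IDs (used above to locate the target nodes), and g.edata[dgl.ETYPE] holds per-edge relation IDs, which is exactly what the RGCNLayer message functions index the weight tensor with. A quick check:

print(g.ndata[dgl.NTYPE].shape)  # one type ID per node
print(g.edata[dgl.ETYPE].shape)  # one relation ID per edge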

Training loop

# optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=l2norm)

print("start training...")
model.train()
for epoch in range(n_epochs):
    optimizer.zero_grad()
    logits = model(g)
    logits = logits[target_idx]
    loss = F.cross_entropy(logits[train_idx], labels[train_idx])
    loss.backward()

    optimizer.step()

    train_acc = torch.sum(logits[train_idx].argmax(dim=1) == labels[train_idx])
    train_acc = train_acc.item() / len(train_idx)
    val_loss = F.cross_entropy(logits[test_idx], labels[test_idx])
    val_acc = torch.sum(logits[test_idx].argmax(dim=1) == labels[test_idx])
    val_acc = val_acc.item() / len(test_idx)
    print(
        "Epoch {:05d} | ".format(epoch)
        + "Train Accuracy: {:.4f} | Train Loss: {:.4f} | ".format(
            train_acc, loss.item()
        )
        + "Validation Accuracy: {:.4f} | Validation loss: {:.4f}".format(
            val_acc, val_loss.item()
        )
    )
start training...
Epoch 00000 | Train Accuracy: 0.2357 | Train Loss: 1.3862 | Validation Accuracy: 0.3333 | Validation loss: 1.3864
Epoch 00001 | Train Accuracy: 0.9286 | Train Loss: 1.3552 | Validation Accuracy: 0.9444 | Validation loss: 1.3623
Epoch 00002 | Train Accuracy: 0.9357 | Train Loss: 1.3076 | Validation Accuracy: 0.9444 | Validation loss: 1.3241
Epoch 00003 | Train Accuracy: 0.9429 | Train Loss: 1.2439 | Validation Accuracy: 0.9444 | Validation loss: 1.2723
Epoch 00004 | Train Accuracy: 0.9429 | Train Loss: 1.1725 | Validation Accuracy: 0.9444 | Validation loss: 1.2131
Epoch 00005 | Train Accuracy: 0.9429 | Train Loss: 1.1046 | Validation Accuracy: 0.9444 | Validation loss: 1.1560
Epoch 00006 | Train Accuracy: 0.9500 | Train Loss: 1.0454 | Validation Accuracy: 0.9444 | Validation loss: 1.1059
Epoch 00007 | Train Accuracy: 0.9500 | Train Loss: 0.9946 | Validation Accuracy: 0.9444 | Validation loss: 1.0614
Epoch 00008 | Train Accuracy: 0.9500 | Train Loss: 0.9517 | Validation Accuracy: 0.9444 | Validation loss: 1.0211
Epoch 00009 | Train Accuracy: 0.9500 | Train Loss: 0.9164 | Validation Accuracy: 0.9722 | Validation loss: 0.9849
Epoch 00010 | Train Accuracy: 0.9500 | Train Loss: 0.8883 | Validation Accuracy: 0.9722 | Validation loss: 0.9535
Epoch 00011 | Train Accuracy: 0.9500 | Train Loss: 0.8665 | Validation Accuracy: 0.9722 | Validation loss: 0.9272
Epoch 00012 | Train Accuracy: 0.9500 | Train Loss: 0.8498 | Validation Accuracy: 0.9722 | Validation loss: 0.9058
Epoch 00013 | Train Accuracy: 0.9500 | Train Loss: 0.8373 | Validation Accuracy: 0.9444 | Validation loss: 0.8887
Epoch 00014 | Train Accuracy: 0.9500 | Train Loss: 0.8281 | Validation Accuracy: 0.9444 | Validation loss: 0.8754
Epoch 00015 | Train Accuracy: 0.9500 | Train Loss: 0.8214 | Validation Accuracy: 0.9167 | Validation loss: 0.8651
Epoch 00016 | Train Accuracy: 0.9571 | Train Loss: 0.8166 | Validation Accuracy: 0.9167 | Validation loss: 0.8574
Epoch 00017 | Train Accuracy: 0.9571 | Train Loss: 0.8131 | Validation Accuracy: 0.9167 | Validation loss: 0.8515
Epoch 00018 | Train Accuracy: 0.9571 | Train Loss: 0.8104 | Validation Accuracy: 0.9167 | Validation loss: 0.8472
Epoch 00019 | Train Accuracy: 0.9571 | Train Loss: 0.8081 | Validation Accuracy: 0.9167 | Validation loss: 0.8438
Epoch 00020 | Train Accuracy: 0.9571 | Train Loss: 0.8060 | Validation Accuracy: 0.9167 | Validation loss: 0.8413
Epoch 00021 | Train Accuracy: 0.9571 | Train Loss: 0.8041 | Validation Accuracy: 0.9167 | Validation loss: 0.8394
Epoch 00022 | Train Accuracy: 0.9571 | Train Loss: 0.8022 | Validation Accuracy: 0.9167 | Validation loss: 0.8379
Epoch 00023 | Train Accuracy: 0.9571 | Train Loss: 0.8004 | Validation Accuracy: 0.9444 | Validation loss: 0.8368
Epoch 00024 | Train Accuracy: 0.9571 | Train Loss: 0.7984 | Validation Accuracy: 0.9444 | Validation loss: 0.8360
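
The tutorial ends with the training log. As a hedged follow-up, a final evaluation pass on the held-out split could look like this, reusing the tensors defined above:

# switch to eval mode and report accuracy on the test indices
model.eval()
with torch.no_grad():
    logits = model(g)[target_idx]
    test_acc = (logits[test_idx].argmax(dim=1) == labels[test_idx]).float().mean()
print("Test Accuracy: {:.4f}".format(test_acc.item()))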