超图神经网络

本教程介绍什么是超图以及如何使用 DGL 的稀疏矩阵 API 构建超图神经网络。

Open In Colab GitHub

[ ]:
# Install required packages.
import os
import torch
os.environ['TORCH'] = torch.__version__
os.environ['DGLBACKEND'] = "pytorch"

# Uncomment below to install required packages. If the CUDA version is not 11.8,
# check the https://dgl.ac.cn/pages/start.html to find the supported CUDA
# version and corresponding command to install DGL.
#!pip install dgl -f https://data.dgl.ai/wheels/cu118/repo.html > /dev/null
#!pip install torchmetrics > /dev/null

try:
    import dgl
    installed = True
except ImportError:
    installed = False
print("DGL installed!" if installed else "Failed to install DGL!")

超图

一个 超图节点超边组成。与图中的边不同,一个超边可以连接任意数量的节点。例如,下图显示了一个包含 11 个节点和 5 条不同颜色绘制的超边的超图。 图像1

当数据集中的数据点之间的关系不是二元的时,超图特别有用。例如,在电子商务系统中,两个以上的产品可以一起被共同购买,因此共同购买的关系是 \(n\) 元的而不是二元的,因此最好将其描述为超图而不是普通图。

超图通常用其关联矩阵 \(H\) 来表征,其中行代表节点,列代表超边。如果超边 \(j\) 包含节点 \(i\),则条目 \(H_{ij}\) 为 1,否则为 0。例如,上图中的超图可以用一个 \(11 \times 5\) 的矩阵来表示,如下所示

\[\begin{split}H = \begin{bmatrix} 1 & 0 & 0 & 0 & 0 \\ 1 & 0 & 0 & 0 & 0 \\ 1 & 1 & 0 & 1 & 1 \\ 0 & 0 & 1 & 0 & 0 \\ 0 & 1 & 0 & 0 & 0 \\ 1 & 0 & 1 & 1 & 1 \\ 0 & 0 & 1 & 0 & 0 \\ 0 & 1 & 0 & 1 & 0 \\ 0 & 1 & 0 & 1 & 0 \\ 0 & 0 & 1 & 0 & 1 \\ 0 & 0 & 0 & 0 & 1 \\ \end{bmatrix}\end{split}\]

可以通过指定两个张量 nodeshyperedges 来构建超图关联矩阵,其中对于所有的 i,节点 ID nodes[i] 属于超边 ID hyperedges[i]。在上述情况下,关联矩阵可以按如下方式构建。

[ ]:
import dgl.sparse as dglsp
import torch

H = dglsp.spmatrix(
    torch.LongTensor([[0, 1, 2, 2, 2, 2, 3, 4, 5, 5, 5, 5, 6, 7, 7, 8, 8, 9, 9, 10],
                      [0, 0, 0, 1, 3, 4, 2, 1, 0, 2, 3, 4, 2, 1, 3, 1, 3, 2, 4, 4]])
)

print(H.to_dense())
tensor([[1., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0.],
        [1., 1., 0., 1., 1.],
        [0., 0., 1., 0., 0.],
        [0., 1., 0., 0., 0.],
        [1., 0., 1., 1., 1.],
        [0., 0., 1., 0., 0.],
        [0., 1., 0., 1., 0.],
        [0., 1., 0., 1., 0.],
        [0., 0., 1., 0., 1.],
        [0., 0., 0., 0., 1.]])

超图中节点的度定义为包含该节点的超边数量。类似地,超图中超边的度定义为该超边包含的节点数量。在上述示例中,超边的度可以通过行向量的和计算(即全部为 4),而节点的度可以通过列向量的和计算。

[ ]:
node_degrees = H.sum(1)
print("Node degrees", node_degrees)

hyperedge_degrees = H.sum(0)
print("Hyperedge degrees", hyperedge_degrees)
Node degrees tensor([1., 1., 4., 1., 1., 4., 1., 2., 2., 2., 1.])
Hyperedge degrees tensor([4., 4., 4., 4., 4.])

超图神经网络 (HGNN) 层

HGNN 层的定义如下:

\[f(X^{(l)}, H; W^{(l)}) = \sigma(L X^{(l)} W^{(l)})\]

\[L = D_v^{-1/2} H B D_e^{-1} H^\top D_v^{-1/2}\]

其中

  • \(H \in \mathbb{R}^{N \times M}\) 是具有 \(N\) 个节点和 \(M\) 条超边的超图的关联矩阵。

  • \(D_v \in \mathbb{R}^{N \times N}\) 是表示节点度的对角矩阵,其第 \(i\) 个对角线元素是 \(\sum_{j=1}^M H_{ij}\)

  • \(D_e \in \mathbb{R}^{M \times M}\) 是表示超边度的对角矩阵,其第 \(j\) 个对角线元素是 \(\sum_{i=1}^N H_{ij}\)

  • \(B \in \mathbb{R}^{M \times M}\) 是表示超边权重的对角矩阵,其第 \(j\) 个对角线元素是第 \(j\) 条超边的权重。在我们的示例中,\(B\) 是一个单位矩阵。

以下代码构建了一个两层的 HGNN。

[ ]:
import dgl.sparse as dglsp
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm
from dgl.data import CoraGraphDataset
from torchmetrics.functional import accuracy


class HGNN(nn.Module):
    def __init__(self, H, in_size, out_size, hidden_dims=16):
        super().__init__()

        self.W1 = nn.Linear(in_size, hidden_dims)
        self.W2 = nn.Linear(hidden_dims, out_size)
        self.dropout = nn.Dropout(0.5)

        ###########################################################
        # (HIGHLIGHT) Compute the Laplacian with Sparse Matrix API
        ###########################################################
        # Compute node degree.
        d_V = H.sum(1)
        # Compute edge degree.
        d_E = H.sum(0)
        # Compute the inverse of the square root of the diagonal D_v.
        D_v_invsqrt = dglsp.diag(d_V**-0.5)
        # Compute the inverse of the diagonal D_e.
        D_e_inv = dglsp.diag(d_E**-1)
        # In our example, B is an identity matrix.
        n_edges = d_E.shape[0]
        B = dglsp.identity((n_edges, n_edges))
        # Compute Laplacian from the equation above.
        self.L = D_v_invsqrt @ H @ B @ D_e_inv @ H.T @ D_v_invsqrt

    def forward(self, X):
        X = self.L @ self.W1(self.dropout(X))
        X = F.relu(X)
        X = self.L @ self.W2(self.dropout(X))
        return X

加载数据

我们在示例中使用了 Cora 引文网络。但我们没有使用论文之间原始的“引用”关系,而是考虑论文之间的“共同引用”关系。我们从原始引文网络构建了一个超图,其中对于每一篇论文,我们构建了一条包含它引用的所有其他论文以及论文本身的超边。

image1

请注意,以这种方式构建的超图的关联矩阵与原始图的邻接矩阵完全相同(加上一个用于自环的单位矩阵)。这是因为每条超边都与每篇论文一一对应。因此,我们可以直接使用图的邻接矩阵并加上一个单位矩阵,然后将其用作超图的关联矩阵。

[ ]:
def load_data():
    dataset = CoraGraphDataset()

    graph = dataset[0]
    indices = torch.stack(graph.edges())
    H = dglsp.spmatrix(indices)
    H = H + dglsp.identity(H.shape)

    X = graph.ndata["feat"]
    Y = graph.ndata["label"]
    train_mask = graph.ndata["train_mask"]
    val_mask = graph.ndata["val_mask"]
    test_mask = graph.ndata["test_mask"]
    return H, X, Y, dataset.num_classes, train_mask, val_mask, test_mask

训练和评估

现在我们可以编写训练和评估函数,如下所示。

[ ]:
def train(model, optimizer, X, Y, train_mask):
    model.train()
    Y_hat = model(X)
    loss = F.cross_entropy(Y_hat[train_mask], Y[train_mask])
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


def evaluate(model, X, Y, val_mask, test_mask, num_classes):
    model.eval()
    Y_hat = model(X)
    val_acc = accuracy(
        Y_hat[val_mask], Y[val_mask], task="multiclass", num_classes=num_classes
    )
    test_acc = accuracy(
        Y_hat[test_mask],
        Y[test_mask],
        task="multiclass",
        num_classes=num_classes,
    )
    return val_acc, test_acc


H, X, Y, num_classes, train_mask, val_mask, test_mask = load_data()
model = HGNN(H, X.shape[1], num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

with tqdm.trange(500) as tq:
    for epoch in tq:
        train(model, optimizer, X, Y, train_mask)
        val_acc, test_acc = evaluate(
            model, X, Y, val_mask, test_mask, num_classes
        )
        tq.set_postfix(
            {
                "Val acc": f"{val_acc:.5f}",
                "Test acc": f"{test_acc:.5f}",
            },
            refresh=False,
        )

print(f"Test acc: {test_acc:.3f}")
Downloading /root/.dgl/cora_v2.zip from https://data.dgl.ai/dataset/cora_v2.zip...
Extracting file to /root/.dgl/cora_v2
Finished data loading and preprocessing.
  NumNodes: 2708
  NumEdges: 10556
  NumFeats: 1433
  NumClasses: 7
  NumTrainingSamples: 140
  NumValidationSamples: 500
  NumTestSamples: 1000
Done saving data into cached files.
100%|██████████| 500/500 [00:57<00:00,  8.70it/s, Val acc=0.77800, Test acc=0.78100]
Test acc: 0.781

关于完整的 HGNN 示例,请参考此处