Introduction to Graph Transformer
The Transformer (Vaswani et al. 2017) has proven to be an effective learning architecture in natural language processing and computer vision. Recently, researchers have turned to exploring its application to graph learning, with initial success on many practical tasks such as graph property prediction. Dwivedi et al. (2020) first generalized the Transformer neural architecture to graph-structured data. Here, we show how to build such a Graph Transformer with DGL's sparse matrix APIs.
[ ]:
# Install required packages.
import os
import torch
os.environ['TORCH'] = torch.__version__
os.environ['DGLBACKEND'] = "pytorch"
# Uncomment below to install required packages. If the CUDA version is not 11.8,
# check https://dgl.ac.cn/pages/start.html to find the supported CUDA
# versions and the corresponding command to install DGL.
#!pip install dgl -f https://data.dgl.ai/wheels/cu118/repo.html > /dev/null
#!pip install ogb >/dev/null
try:
import dgl
installed = True
except ImportError:
installed = False
print("DGL installed!" if installed else "Failed to install DGL!")
Sparse Multi-head Attention
Recall the fully connected scaled dot-product attention in the original Transformer:
\[\text{Attn}=\text{softmax}\left(\dfrac{QK^T}{\sqrt{d}}\right)V\]
The Graph Transformer (GT) model employs a sparse multi-head attention module instead:
\[\text{Attn}=\text{softmax}\left(\dfrac{(QK^T)\circ A}{\sqrt{d}}\right)V\]
where \(Q, K, V \in \mathbb{R}^{N\times d}\) are the query, key, and value features, respectively, and \(A\in[0,1]^{N\times N}\) is the adjacency matrix of the input graph. \((QK^T)\circ A\) means that the dense query-key product is multiplied element-wise (Hadamard product) with the sparse adjacency matrix.
In essence, attention scores are only computed for connected node pairs, following the sparsity of \(A\). This operation is also known as Sampled Dense-Dense Matrix Multiplication (SDDMM).
With DGL's batched SDDMM API, we can compute multiple attention heads (i.e., different representation subspaces) in parallel.
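To make SDDMM concrete, here is a minimal sketch with a toy sparse matrix and made-up feature shapes (not part of the original example); it computes the dense product \(X_1 X_2\) only at the nonzero positions of \(A\) via dgl.sparse.sddmm:
[ ]:
# Minimal SDDMM sketch (toy data for illustration only).
import torch
import dgl.sparse as dglsp
indices = torch.tensor([[0, 1, 1], [1, 0, 2]])  # three directed edges
A = dglsp.spmatrix(indices, shape=(3, 3))       # 3 x 3 sparse pattern
X1 = torch.randn(3, 4)                          # e.g. query features
X2 = torch.randn(4, 3)                          # e.g. transposed key features
out = dglsp.sddmm(A, X1, X2)                    # sparse result with A's sparsity pattern
print(out.shape, out.nnz)                       # (3, 3), 3 stored values
The batched variant dgl.sparse.bsddmm used in the module below treats the trailing dimension as the batch (head) dimension, which is what allows all attention heads to be computed in a single call.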
[ ]:
import dgl
import dgl.nn as dglnn
import dgl.sparse as dglsp
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from dgl.data import AsGraphPredDataset
from dgl.dataloading import GraphDataLoader
from ogb.graphproppred import collate_dgl, DglGraphPropPredDataset, Evaluator
from ogb.graphproppred.mol_encoder import AtomEncoder
from tqdm import tqdm
class SparseMHA(nn.Module):
"""Sparse Multi-head Attention Module"""
def __init__(self, hidden_size=80, num_heads=8):
super().__init__()
self.hidden_size = hidden_size
self.num_heads = num_heads
self.head_dim = hidden_size // num_heads
self.scaling = self.head_dim**-0.5
self.q_proj = nn.Linear(hidden_size, hidden_size)
self.k_proj = nn.Linear(hidden_size, hidden_size)
self.v_proj = nn.Linear(hidden_size, hidden_size)
self.out_proj = nn.Linear(hidden_size, hidden_size)
def forward(self, A, h):
N = len(h)
# [N, dh, nh]
q = self.q_proj(h).reshape(N, self.head_dim, self.num_heads)
q *= self.scaling
# [N, dh, nh]
k = self.k_proj(h).reshape(N, self.head_dim, self.num_heads)
# [N, dh, nh]
v = self.v_proj(h).reshape(N, self.head_dim, self.num_heads)
######################################################################
# (HIGHLIGHT) Compute the multi-head attention with Sparse Matrix API
######################################################################
attn = dglsp.bsddmm(A, q, k.transpose(1, 0)) # (sparse) [N, N, nh]
# Sparse softmax by default applies on the last sparse dimension.
attn = attn.softmax() # (sparse) [N, N, nh]
out = dglsp.bspmm(attn, v) # [N, dh, nh]
return self.out_proj(out.reshape(N, -1))
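As a quick sanity check, the module can be run on a toy graph; the sizes below are made up for illustration (any hidden size divisible by the number of heads works):
[ ]:
# Toy usage sketch for SparseMHA (illustrative sizes only).
indices = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])  # a directed 4-cycle
A = dglsp.spmatrix(indices, shape=(4, 4))
h = torch.randn(4, 16)                                # 4 nodes, hidden size 16
mha = SparseMHA(hidden_size=16, num_heads=2)
print(mha(A, h).shape)                                # torch.Size([4, 16])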
Graph Transformer Layer
A GT layer is composed of multi-head attention, batch normalization, and a feed-forward network, wired together with residual connections as in the original Transformer.
[ ]:
class GTLayer(nn.Module):
"""Graph Transformer Layer"""
def __init__(self, hidden_size=80, num_heads=8):
super().__init__()
self.MHA = SparseMHA(hidden_size=hidden_size, num_heads=num_heads)
self.batchnorm1 = nn.BatchNorm1d(hidden_size)
self.batchnorm2 = nn.BatchNorm1d(hidden_size)
self.FFN1 = nn.Linear(hidden_size, hidden_size * 2)
self.FFN2 = nn.Linear(hidden_size * 2, hidden_size)
def forward(self, A, h):
h1 = h
h = self.MHA(A, h)
h = self.batchnorm1(h + h1)
h2 = h
h = self.FFN2(F.relu(self.FFN1(h)))
h = h2 + h
return self.batchnorm2(h)
Graph Transformer Model
The GT model is built by stacking GT layers. The input positional encoding of the original Transformer is replaced with the Laplacian positional encoding (Dwivedi et al. 2020). For graph-level prediction tasks, an extra pooler is stacked on top of the GT layers to aggregate the node features of the same graph.
[ ]:
class GTModel(nn.Module):
def __init__(
self,
out_size,
hidden_size=80,
pos_enc_size=2,
num_layers=8,
num_heads=8,
):
super().__init__()
self.atom_encoder = AtomEncoder(hidden_size)
self.pos_linear = nn.Linear(pos_enc_size, hidden_size)
self.layers = nn.ModuleList(
[GTLayer(hidden_size, num_heads) for _ in range(num_layers)]
)
self.pooler = dglnn.SumPooling()
self.predictor = nn.Sequential(
nn.Linear(hidden_size, hidden_size // 2),
nn.ReLU(),
nn.Linear(hidden_size // 2, hidden_size // 4),
nn.ReLU(),
nn.Linear(hidden_size // 4, out_size),
)
def forward(self, g, X, pos_enc):
indices = torch.stack(g.edges())
N = g.num_nodes()
A = dglsp.spmatrix(indices, shape=(N, N))
h = self.atom_encoder(X) + self.pos_linear(pos_enc)
for layer in self.layers:
h = layer(A, h)
h = self.pooler(g, h)
return self.predictor(h)
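Before moving on to training, here is a toy end-to-end sketch; the inputs are made up for illustration (X mimics OGB's nine categorical atom features with all-zero indices), while the real node features and positional encodings come from the OGB dataset in the next section:
[ ]:
# Toy forward pass through GTModel (illustrative inputs only).
g = dgl.graph(([0, 1, 1, 2], [1, 0, 2, 1]))  # a small bidirected path graph
X = torch.zeros(3, 9, dtype=torch.long)      # 9 categorical atom features per node
pos_enc = torch.randn(3, 2)                  # matches the default pos_enc_size=2
model = GTModel(out_size=1)
print(model(g, X, pos_enc).shape)            # torch.Size([1, 1])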
Training
We train the GT model on the ogbg-molhiv benchmark. The Laplacian positional encoding of each graph is pre-computed (with the dgl.laplacian_pe API) and fed as part of the model input.
Note that we downsample the dataset to make this demo run faster. See the example script for the performance on the full dataset.
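The snippet below is a minimal sketch of that pre-computation on a toy graph (made up for illustration); the training code further down applies the same dgl.laplacian_pe call to every graph in the dataset:
[ ]:
# Laplacian positional encoding on a toy graph (illustrative only).
g = dgl.graph(([0, 1, 1, 2, 2, 3, 3, 0], [1, 0, 2, 1, 3, 2, 0, 3]))  # bidirected 4-cycle
g.ndata["PE"] = dgl.laplacian_pe(g, k=2, padding=True)  # one 2-dim encoding per node
print(g.ndata["PE"].shape)                              # torch.Size([4, 2])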
[ ]:
@torch.no_grad()
def evaluate(model, dataloader, evaluator, device):
model.eval()
y_true = []
y_pred = []
for batched_g, labels in dataloader:
batched_g, labels = batched_g.to(device), labels.to(device)
y_hat = model(batched_g, batched_g.ndata["feat"], batched_g.ndata["PE"])
y_true.append(labels.view(y_hat.shape).detach().cpu())
y_pred.append(y_hat.detach().cpu())
y_true = torch.cat(y_true, dim=0).numpy()
y_pred = torch.cat(y_pred, dim=0).numpy()
input_dict = {"y_true": y_true, "y_pred": y_pred}
return evaluator.eval(input_dict)["rocauc"]
def train(model, dataset, evaluator, device):
train_dataloader = GraphDataLoader(
dataset[dataset.train_idx],
batch_size=256,
shuffle=True,
collate_fn=collate_dgl,
)
valid_dataloader = GraphDataLoader(
dataset[dataset.val_idx], batch_size=256, collate_fn=collate_dgl
)
test_dataloader = GraphDataLoader(
dataset[dataset.test_idx], batch_size=256, collate_fn=collate_dgl
)
optimizer = optim.Adam(model.parameters(), lr=0.001)
num_epochs = 20
scheduler = optim.lr_scheduler.StepLR(
optimizer, step_size=num_epochs, gamma=0.5
)
loss_fcn = nn.BCEWithLogitsLoss()
for epoch in range(num_epochs):
model.train()
total_loss = 0.0
for batched_g, labels in train_dataloader:
batched_g, labels = batched_g.to(device), labels.to(device)
logits = model(
batched_g, batched_g.ndata["feat"], batched_g.ndata["PE"]
)
loss = loss_fcn(logits, labels.float())
total_loss += loss.item()
optimizer.zero_grad()
loss.backward()
optimizer.step()
scheduler.step()
avg_loss = total_loss / len(train_dataloader)
val_metric = evaluate(model, valid_dataloader, evaluator, device)
test_metric = evaluate(model, test_dataloader, evaluator, device)
print(
f"Epoch: {epoch:03d}, Loss: {avg_loss:.4f}, "
f"Val: {val_metric:.4f}, Test: {test_metric:.4f}"
)
# Training device.
dev = torch.device("cpu")
# Uncomment the code below to train on GPU. Be sure to install DGL with CUDA support.
#dev = torch.device("cuda:0")
# Load dataset.
pos_enc_size = 8
dataset = AsGraphPredDataset(
DglGraphPropPredDataset("ogbg-molhiv", "./data/OGB")
)
evaluator = Evaluator("ogbg-molhiv")
# Downsample the dataset to make the tutorial run faster.
import random
random.seed(42)
train_size = len(dataset.train_idx)
val_size = len(dataset.val_idx)
test_size = len(dataset.test_idx)
dataset.train_idx = dataset.train_idx[
torch.LongTensor(random.sample(range(train_size), 2000))
]
dataset.val_idx = dataset.val_idx[
torch.LongTensor(random.sample(range(val_size), 1000))
]
dataset.test_idx = dataset.test_idx[
torch.LongTensor(random.sample(range(test_size), 1000))
]
# Laplacian positional encoding.
indices = torch.cat([dataset.train_idx, dataset.val_idx, dataset.test_idx])
for idx in tqdm(indices, desc="Computing Laplacian PE"):
g, _ = dataset[idx]
g.ndata["PE"] = dgl.laplacian_pe(g, k=pos_enc_size, padding=True)
# Create model.
out_size = dataset.num_tasks
model = GTModel(out_size=out_size, pos_enc_size=pos_enc_size).to(dev)
# Kick off training.
train(model, dataset, evaluator, dev)
Computing Laplacian PE: 1%| | 25/4000 [00:00<00:16, 244.77it/s]/usr/local/lib/python3.8/dist-packages/dgl/backend/pytorch/tensor.py:52: UserWarning: Casting complex values to real discards the imaginary part (Triggered internally at ../aten/src/ATen/native/Copy.cpp:250.)
return th.as_tensor(data, dtype=dtype)
Computing Laplacian PE: 100%|██████████| 4000/4000 [00:13<00:00, 296.04it/s]
Epoch: 000, Loss: 0.2486, Val: 0.3082, Test: 0.3068
Epoch: 001, Loss: 0.1695, Val: 0.4684, Test: 0.4572
Epoch: 002, Loss: 0.1428, Val: 0.5887, Test: 0.4721
Epoch: 003, Loss: 0.1237, Val: 0.6375, Test: 0.5010
Epoch: 004, Loss: 0.1127, Val: 0.6628, Test: 0.4854
Epoch: 005, Loss: 0.1047, Val: 0.6811, Test: 0.4983
Epoch: 006, Loss: 0.0949, Val: 0.6751, Test: 0.5409
Epoch: 007, Loss: 0.0901, Val: 0.6340, Test: 0.5357
Epoch: 008, Loss: 0.0811, Val: 0.6717, Test: 0.5543
Epoch: 009, Loss: 0.0643, Val: 0.7861, Test: 0.5628
Epoch: 010, Loss: 0.0489, Val: 0.7319, Test: 0.5341
Epoch: 011, Loss: 0.0340, Val: 0.7884, Test: 0.5299
Epoch: 012, Loss: 0.0285, Val: 0.5887, Test: 0.4293
Epoch: 013, Loss: 0.0361, Val: 0.5514, Test: 0.3419
Epoch: 014, Loss: 0.0451, Val: 0.6795, Test: 0.4964
Epoch: 015, Loss: 0.0429, Val: 0.7405, Test: 0.5527
Epoch: 016, Loss: 0.0331, Val: 0.7859, Test: 0.4994
Epoch: 017, Loss: 0.0177, Val: 0.6544, Test: 0.4457
Epoch: 018, Loss: 0.0201, Val: 0.8250, Test: 0.6073
Epoch: 019, Loss: 0.0093, Val: 0.7356, Test: 0.5561