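Link Prediction
This tutorial shows how to train a multi-layer GraphSAGE model for link prediction on the CoraGraphDataset. The dataset contains 2708 nodes and 10556 edges.
By the end of this tutorial, you will be able to
Train a GNN model for link prediction on a target device using DGL's neighbor sampling components.
Install the DGL package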
[1]:
# Install required packages.
import os
import torch
os.environ['TORCH'] = torch.__version__
os.environ['DGLBACKEND'] = "pytorch"
# Install the CPU version by default. If you want to install the CUDA version,
# please refer to https://dgl.ac.cn/pages/start.html and change the runtime type
# accordingly.
device = torch.device("cpu")
!pip install --pre dgl -f https://data.dgl.ai/wheels-test/repo.html
try:
    import dgl
    import dgl.graphbolt as gb
    installed = True
except ImportError as error:
    installed = False
    print(error)
print("DGL installed!" if installed else "DGL not found!")
DGL installed!
Load the dataset
The cora dataset is already prepared as a BuiltinDataset in GraphBolt.
[2]:
dataset = gb.BuiltinDataset("cora-seeds").load()
Downloading datasets/cora-seeds.zip from https://data.dgl.ai/dataset/graphbolt/cora-seeds.zip...
Extracting file to datasets
Start to preprocess the on-disk dataset.
Finish preprocessing the on-disk dataset.
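The dataset consists of a graph, features, and tasks. You can get the training, validation, and test sets from the tasks. The seeds and their corresponding labels are already stored in each of the training, validation, and test sets. This dataset contains 2 tasks, one for node classification and one for link prediction. We will use the link prediction task.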
[3]:
graph = dataset.graph.to(device)
feature = dataset.feature.to(device)
train_set = dataset.tasks[1].train_set
test_set = dataset.tasks[1].test_set
task_name = dataset.tasks[1].metadata["name"]
print(f"Task: {task_name}.")
Task: link_prediction.
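Define the neighbor sampler and data loader in DGL
Unlike the link prediction tutorial for full graphs, the common practice for training a GNN on a large graph is to iterate over the edges in minibatches, because computing the probability of every edge is usually infeasible. For each minibatch of edges, you compute the output representations of their incident nodes with neighbor sampling and a GNN, in the same way as introduced in the node classification tutorial.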
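To make the idea of neighbor sampling concrete, here is a minimal pure-Python sketch. It is not how GraphBolt implements `sample_neighbor`; the function name `sample_neighbors`, the dictionary-based toy graph, and the return layout are assumptions made for illustration. It keeps at most `fanout` in-neighbors per node at each hop and treats a fanout of -1 as "keep every neighbor".

```python
import random


def sample_neighbors(in_neighbors, seeds, fanouts, rng):
    """Sample a multi-hop neighborhood around `seeds` (illustrative only).

    in_neighbors: dict mapping a node to the list of its in-neighbors.
    fanouts: one entry per hop, starting at the seeds; -1 keeps all neighbors.
    Returns one list of (src, dst) edges per hop.
    """
    layers = []
    frontier = list(dict.fromkeys(seeds))  # de-duplicate, keep order
    for fanout in fanouts:
        edges = []
        for dst in frontier:
            neighbors = in_neighbors.get(dst, [])
            if fanout != -1 and len(neighbors) > fanout:
                neighbors = rng.sample(neighbors, fanout)
            edges.extend((src, dst) for src in neighbors)
        layers.append(edges)
        # The next hop expands from the current nodes plus the new sources.
        frontier = list(dict.fromkeys(frontier + [src for src, _ in edges]))
    return layers


# Hypothetical toy graph: node -> in-neighbors.
toy_graph = {0: [1, 2, 3], 1: [0], 2: [0, 3], 3: [2]}
sampled = sample_neighbors(toy_graph, [0], [2, 2], random.Random(0))
full = sample_neighbors(toy_graph, [0], [-1, -1], random.Random(0))
print(len(sampled[0]), len(full[0]), len(full[1]))  # 2 3 7
```

With a fanout of 2, node 0 keeps only two of its three in-neighbors, which is what bounds the cost of each minibatch on a large graph.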
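To perform link prediction, you need to specify a negative sampler. DGL provides built-in negative samplers such as dgl.graphbolt.UniformNegativeSampler. In this tutorial, 5 negative examples are drawn uniformly for each positive example.
Except for the negative sampler, the rest of the code is identical to that in the node classification tutorial.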
[4]:
from functools import partial
datapipe = gb.ItemSampler(train_set, batch_size=256, shuffle=True)
datapipe = datapipe.copy_to(device)
datapipe = datapipe.sample_uniform_negative(graph, 5)
datapipe = datapipe.sample_neighbor(graph, [5, 5])
datapipe = datapipe.transform(partial(gb.exclude_seed_edges, include_reverse_edges=True))
datapipe = datapipe.fetch_feature(feature, node_feature_keys=["feat"])
train_dataloader = gb.DataLoader(datapipe)
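As a rough illustration of what uniform negative sampling does to a minibatch, the following pure-Python sketch appends 5 random destination nodes for each positive pair. It is not GraphBolt's implementation: the function name `sample_uniform_negatives` is hypothetical, and the sketch does not check whether a sampled pair happens to be a real edge. The layout (positives first, then the negatives, each reusing the source node of its positive) follows the minibatch printed later in this tutorial.

```python
import random


def sample_uniform_negatives(positive_pairs, num_nodes, negative_ratio, rng):
    """Return (pairs, labels) with positives first, then sampled negatives."""
    pairs = list(positive_pairs)
    labels = [1.0] * len(pairs)
    for src, _dst in positive_pairs:
        for _ in range(negative_ratio):
            # Keep the source node and draw the destination uniformly.
            pairs.append((src, rng.randrange(num_nodes)))
            labels.append(0.0)
    return pairs, labels


pairs, labels = sample_uniform_negatives(
    [(0, 1), (2, 3)], num_nodes=10, negative_ratio=5, rng=random.Random(0)
)
print(len(pairs))  # 12
```

By the same arithmetic, a full minibatch of 256 positive pairs with 5 negatives each carries 256 × (1 + 5) = 1536 node pairs.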
You can take one minibatch from train_dataloader to see what it provides.
[5]:
data = next(iter(train_dataloader))
print(f"MiniBatch: {data}")
MiniBatch: MiniBatch(seeds=tensor([[2630, 2699],
[ 49, 2034],
[ 415, 1677],
...,
[ 454, 2217],
[ 454, 469],
[ 454, 276]], dtype=torch.int32),
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([ 0, 1, 1, ..., 7000, 7001, 7004], dtype=torch.int32),
indices=tensor([1040, 1304, 1306, ..., 2440, 1303, 870], dtype=torch.int32),
),
original_row_node_ids=tensor([2630, 2699, 49, ..., 2559, 1277, 2653], dtype=torch.int32),
original_edge_ids=tensor([10386, 183, 8408, ..., 8952, 8953, 8954], dtype=torch.int32),
original_column_node_ids=tensor([2630, 2699, 49, ..., 1748, 96, 2164], dtype=torch.int32),
),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([ 0, 1, 1, ..., 3765, 3766, 3771], dtype=torch.int32),
indices=tensor([1040, 1304, 48, ..., 1825, 268, 2251], dtype=torch.int32),
),
original_row_node_ids=tensor([2630, 2699, 49, ..., 1748, 96, 2164], dtype=torch.int32),
original_edge_ids=tensor([10386, 183, 8410, ..., 9165, 9166, 9167], dtype=torch.int32),
original_column_node_ids=tensor([2630, 2699, 49, ..., 1886, 241, 2217], dtype=torch.int32),
)],
node_features={'feat': tensor([[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]])},
labels=tensor([1., 1., 1., ..., 0., 0., 0.]),
input_nodes=tensor([2630, 2699, 49, ..., 2559, 1277, 2653], dtype=torch.int32),
indexes=tensor([ 0, 1, 2, ..., 255, 255, 255]),
edge_features=[{},
{}],
compacted_seeds=tensor([[ 0, 1],
[ 2, 3],
[ 4, 5],
...,
[ 187, 1303],
[ 187, 1033],
[ 187, 695]], dtype=torch.int32),
blocks=[Block(num_src_nodes=2523, num_dst_nodes=2252, num_edges=7004),
Block(num_src_nodes=2252, num_dst_nodes=1304, num_edges=3771)],
)
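The `seeds` field above holds original node IDs, while `compacted_seeds` holds row indices into the per-batch embedding matrix. The sketch below shows one way such a relabeling can be produced, assigning local IDs in order of first appearance. The helper `compact_seeds` is hypothetical, and the exact ordering GraphBolt uses is an implementation detail that this sketch only assumes.

```python
def compact_seeds(seed_pairs):
    """Relabel original node IDs to consecutive local IDs (illustrative only)."""
    local_id = {}
    compacted = []
    for src, dst in seed_pairs:
        row = []
        for node in (src, dst):
            if node not in local_id:
                local_id[node] = len(local_id)
            row.append(local_id[node])
        compacted.append(tuple(row))
    # Position i of unique_nodes is the original ID of local node i.
    unique_nodes = list(local_id)
    return compacted, unique_nodes


seed_pairs = [(2630, 2699), (49, 2034), (2630, 49)]
compacted, unique_nodes = compact_seeds(seed_pairs)
print(compacted)     # [(0, 1), (2, 3), (0, 2)]
print(unique_nodes)  # [2630, 2699, 49, 2034]
```

Because the model returns one embedding row per local node, indexing the output with the compacted pairs selects the two endpoint embeddings of each candidate edge.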
Define the model for node representations
Let's consider training a 2-layer GraphSAGE model with neighbor sampling. The model can be written as follows:
[6]:
import dgl.nn as dglnn
import torch.nn as nn
import torch.nn.functional as F
class SAGE(nn.Module):
    def __init__(self, in_size, hidden_size):
        super().__init__()
        self.layers = nn.ModuleList()
        self.layers.append(dglnn.SAGEConv(in_size, hidden_size, "mean"))
        self.layers.append(dglnn.SAGEConv(hidden_size, hidden_size, "mean"))
        self.hidden_size = hidden_size
        self.predictor = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1),
        )

    def forward(self, blocks, x):
        hidden_x = x
        for layer_idx, (layer, block) in enumerate(zip(self.layers, blocks)):
            hidden_x = layer(block, hidden_x)
            is_last_layer = layer_idx == len(self.layers) - 1
            if not is_last_layer:
                hidden_x = F.relu(hidden_x)
        return hidden_x
Define the training loop
The following code initializes the model and defines the optimizer.
[7]:
in_size = feature.size("node", None, "feat")[0]
model = SAGE(in_size, 128).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
The following is the training loop for link prediction.
[8]:
from tqdm.auto import tqdm
for epoch in range(3):
    model.train()
    total_loss = 0
    for step, data in tqdm(enumerate(train_dataloader)):
        # Get node pairs with labels for loss calculation.
        compacted_seeds = data.compacted_seeds.T
        labels = data.labels
        node_feature = data.node_features["feat"]
        # Convert sampled subgraphs to DGL blocks.
        blocks = data.blocks
        # Get the embeddings of the input nodes.
        y = model(blocks, node_feature)
        logits = model.predictor(
            y[compacted_seeds[0]] * y[compacted_seeds[1]]
        ).squeeze()
        # Compute loss.
        loss = F.binary_cross_entropy_with_logits(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch:03d} | Loss {total_loss / (step + 1):.3f}")
Epoch 000 | Loss 0.562
Epoch 001 | Loss 0.449
Epoch 002 | Loss 0.445
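To help read these loss values, here is a small pure-Python sketch of the quantity that `F.binary_cross_entropy_with_logits` computes with its default mean reduction, written in the usual numerically stable form. The helper name `bce_with_logits` is an assumption for illustration. The usage example evaluates a predictor that ignores the graph and always outputs the class prior, which is 1/6 when every positive comes with 5 negatives.

```python
import math


def bce_with_logits(logits, labels):
    """Mean binary cross-entropy computed directly from raw logits."""
    total = 0.0
    for z, y in zip(logits, labels):
        # Stable form of -[y*log(sigmoid(z)) + (1-y)*log(1-sigmoid(z))].
        total += max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))
    return total / len(logits)


# One positive and five negatives, all scored with the prior p = 1/6,
# whose logit is log(p / (1 - p)) = log(1/5).
prior_logit = math.log(1 / 5)
baseline = bce_with_logits([prior_logit] * 6, [1.0] + [0.0] * 5)
print(f"{baseline:.3f}")  # 0.451
```

Under this 1:5 ratio, a loss near 0.45 is what a predictor using only the class ratio would reach, so the values printed above are best judged against that reference rather than against 0.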
Evaluate performance on link prediction
[9]:
model.eval()
datapipe = gb.ItemSampler(test_set, batch_size=256, shuffle=False)
datapipe = datapipe.copy_to(device)
# Since we need to use all neighborhoods for evaluation, we set the fanout
# to -1.
datapipe = datapipe.sample_neighbor(graph, [-1, -1])
datapipe = datapipe.fetch_feature(feature, node_feature_keys=["feat"])
eval_dataloader = gb.DataLoader(datapipe, num_workers=0)
logits = []
labels = []
for step, data in tqdm(enumerate(eval_dataloader)):
    # Get node pairs with labels for evaluation.
    compacted_seeds = data.compacted_seeds.T
    label = data.labels
    # The features of sampled nodes.
    x = data.node_features["feat"]
    # Forward.
    y = model(data.blocks, x)
    logit = (
        model.predictor(y[compacted_seeds[0]] * y[compacted_seeds[1]])
        .squeeze()
        .detach()
    )
    logits.append(logit)
    labels.append(label)
logits = torch.cat(logits, dim=0)
labels = torch.cat(labels, dim=0)
# Compute the AUROC score.
from sklearn.metrics import roc_auc_score
auc = roc_auc_score(labels.cpu(), logits.cpu())
print("Link Prediction AUC:", auc)
Link Prediction AUC: 0.6895649244176906
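For intuition about the reported score, AUROC can be read as the probability that a randomly chosen positive pair receives a higher score than a randomly chosen negative pair, with ties counted as one half. The sketch below computes that quantity by brute force. The function name `auroc` is hypothetical, and the quadratic loop is meant only for small examples, not as a replacement for `roc_auc_score`.

```python
def auroc(labels, scores):
    """Fraction of (positive, negative) pairs ranked correctly; ties count 0.5."""
    positives = [s for y, s in zip(labels, scores) if y == 1]
    negatives = [s for y, s in zip(labels, scores) if y == 0]
    wins = 0.0
    for p in positives:
        for n in negatives:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(positives) * len(negatives))


toy_labels = [1, 1, 0, 0, 0]
toy_scores = [0.9, 0.2, 0.4, 0.1, 0.3]
print(round(auroc(toy_labels, toy_scores), 4))  # 0.6667
```

A value of 0.5 corresponds to random ranking and 1.0 to a perfect ranking, so an AUC of about 0.69 means the model orders a positive above a negative in roughly 69% of such comparisons.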
Conclusion
In this tutorial, you learned how to train a multi-layer GraphSAGE model for link prediction with neighbor sampling.