Link Prediction


This tutorial will show how to train a multi-layer GraphSAGE model for link prediction on CoraGraphDataset. The dataset contains 2708 nodes and 10556 edges.

By the end of this tutorial, you will be able to

  • Train a GNN model for link prediction on the target device with DGL's neighbor sampling components.

Install DGL package

[1]:
# Install required packages.
import os
import torch
os.environ['TORCH'] = torch.__version__
os.environ['DGLBACKEND'] = "pytorch"

# Install the CPU version by default. If you want to install the CUDA version,
# please refer to https://dgl.ac.cn/pages/start.html and change the runtime type
# accordingly.
device = torch.device("cpu")
!pip install --pre dgl -f https://data.dgl.ai/wheels-test/repo.html

try:
    import dgl
    import dgl.graphbolt as gb
    installed = True
except ImportError as error:
    installed = False
    print(error)
print("DGL installed!" if installed else "DGL not found!")
Looking in links: https://data.dgl.ai/wheels-test/repo.html
Collecting dgl
  Downloading https://data.dgl.ai/wheels-test/dgl-2.2a240410-cp310-cp310-manylinux1_x86_64.whl (221.8 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 221.8/221.8 MB 19.2 MB/s eta 0:00:00
Requirement already satisfied: numpy>=1.14.0 in /opt/conda/envs/dgl-dev-cpu/lib/python3.10/site-packages (from dgl) (1.26.4)
Requirement already satisfied: scipy>=1.1.0 in /opt/conda/envs/dgl-dev-cpu/lib/python3.10/site-packages (from dgl) (1.14.1)
Requirement already satisfied: networkx>=2.1 in /opt/conda/envs/dgl-dev-cpu/lib/python3.10/site-packages (from dgl) (3.4.2)
Requirement already satisfied: requests>=2.19.0 in /opt/conda/envs/dgl-dev-cpu/lib/python3.10/site-packages (from dgl) (2.32.3)
Requirement already satisfied: tqdm in /opt/conda/envs/dgl-dev-cpu/lib/python3.10/site-packages (from dgl) (4.66.6)
Requirement already satisfied: psutil>=5.8.0 in /opt/conda/envs/dgl-dev-cpu/lib/python3.10/site-packages (from dgl) (6.1.0)
Collecting torchdata>=0.5.0 (from dgl)
  Downloading torchdata-0.9.0-cp310-cp310-manylinux1_x86_64.whl.metadata (5.5 kB)
Requirement already satisfied: pandas in /opt/conda/envs/dgl-dev-cpu/lib/python3.10/site-packages (from dgl) (2.2.3)
Requirement already satisfied: charset-normalizer<4,>=2 in /opt/conda/envs/dgl-dev-cpu/lib/python3.10/site-packages (from requests>=2.19.0->dgl) (3.4.0)
Requirement already satisfied: idna<4,>=2.5 in /opt/conda/envs/dgl-dev-cpu/lib/python3.10/site-packages (from requests>=2.19.0->dgl) (3.10)
Requirement already satisfied: urllib3<3,>=1.21.1 in /opt/conda/envs/dgl-dev-cpu/lib/python3.10/site-packages (from requests>=2.19.0->dgl) (2.2.3)
Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/envs/dgl-dev-cpu/lib/python3.10/site-packages (from requests>=2.19.0->dgl) (2024.8.30)
Requirement already satisfied: torch>=2 in /opt/conda/envs/dgl-dev-cpu/lib/python3.10/site-packages (from torchdata>=0.5.0->dgl) (2.1.0+cpu)
Requirement already satisfied: python-dateutil>=2.8.2 in /opt/conda/envs/dgl-dev-cpu/lib/python3.10/site-packages (from pandas->dgl) (2.9.0.post0)
Requirement already satisfied: pytz>=2020.1 in /opt/conda/envs/dgl-dev-cpu/lib/python3.10/site-packages (from pandas->dgl) (2024.2)
Requirement already satisfied: tzdata>=2022.7 in /opt/conda/envs/dgl-dev-cpu/lib/python3.10/site-packages (from pandas->dgl) (2024.2)
Requirement already satisfied: six>=1.5 in /opt/conda/envs/dgl-dev-cpu/lib/python3.10/site-packages (from python-dateutil>=2.8.2->pandas->dgl) (1.16.0)
Requirement already satisfied: filelock in /opt/conda/envs/dgl-dev-cpu/lib/python3.10/site-packages (from torch>=2->torchdata>=0.5.0->dgl) (3.16.1)
Requirement already satisfied: typing-extensions in /opt/conda/envs/dgl-dev-cpu/lib/python3.10/site-packages (from torch>=2->torchdata>=0.5.0->dgl) (4.12.2)
Requirement already satisfied: sympy in /opt/conda/envs/dgl-dev-cpu/lib/python3.10/site-packages (from torch>=2->torchdata>=0.5.0->dgl) (1.13.3)
Requirement already satisfied: jinja2 in /opt/conda/envs/dgl-dev-cpu/lib/python3.10/site-packages (from torch>=2->torchdata>=0.5.0->dgl) (3.1.4)
Requirement already satisfied: fsspec in /opt/conda/envs/dgl-dev-cpu/lib/python3.10/site-packages (from torch>=2->torchdata>=0.5.0->dgl) (2024.10.0)
Requirement already satisfied: MarkupSafe>=2.0 in /opt/conda/envs/dgl-dev-cpu/lib/python3.10/site-packages (from jinja2->torch>=2->torchdata>=0.5.0->dgl) (3.0.2)
Requirement already satisfied: mpmath<1.4,>=1.1.0 in /opt/conda/envs/dgl-dev-cpu/lib/python3.10/site-packages (from sympy->torch>=2->torchdata>=0.5.0->dgl) (1.3.0)
Downloading torchdata-0.9.0-cp310-cp310-manylinux1_x86_64.whl (2.7 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.7/2.7 MB 26.0 MB/s eta 0:00:00
Installing collected packages: torchdata, dgl
Successfully installed dgl-2.2a240410 torchdata-0.9.0
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable. It is recommended to use a virtual environment instead: https://pip.pythonlang.cn/warnings/venv. Use the --root-user-action option if you know what you are doing and want to suppress this warning.
DGL installed!

Loading Dataset

cora is already prepared as BuiltinDataset in GraphBolt.

[2]:
dataset = gb.BuiltinDataset("cora-seeds").load()
Downloading datasets/cora-seeds.zip from https://data.dgl.ai/dataset/graphbolt/cora-seeds.zip...
Extracting file to datasets
Start to preprocess the on-disk dataset.
Finish preprocessing the on-disk dataset.

The dataset consists of graph, feature, and tasks. You can get the training, validation, and test sets from the tasks. Seed nodes and their corresponding labels are already stored in each of those sets. This dataset contains two tasks, one for node classification and the other for link prediction. We will use the link prediction task.
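
The link prediction task happens to sit at index 1 here, so the cell below simply uses dataset.tasks[1]. A more robust alternative is to look the task up by name; a minimal sketch, assuming every task exposes metadata["name"] the way the next cell does:

# Minimal sketch: select the link prediction task by its metadata name
# instead of hardcoding its position in dataset.tasks.
lp_task = next(
    t for t in dataset.tasks if t.metadata["name"] == "link_prediction"
)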

[3]:
graph = dataset.graph.to(device)
feature = dataset.feature.to(device)
train_set = dataset.tasks[1].train_set
test_set = dataset.tasks[1].test_set
task_name = dataset.tasks[1].metadata["name"]
print(f"Task: {task_name}.")
Task: link_prediction.

Defining Neighbor Sampler and Data Loader in DGL

Different from the full-graph link prediction tutorial, a common practice for training a GNN on a large graph is to iterate over the edges in minibatches, since computing the probability of all edges is usually impossible. For each minibatch of edges, you compute the output representations of their incident nodes using neighbor sampling and a GNN, in a fashion similar to the one introduced in the node classification tutorial.

To perform link prediction, you need to specify a negative sampler. DGL provides builtin negative samplers such as dgl.graphbolt.UniformNegativeSampler. In this tutorial, it uniformly draws 5 negative examples per positive example.
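
As a rough mental model (a sketch only, not GraphBolt's internal implementation), uniform negative sampling keeps the source of each positive pair and corrupts its destination with nodes drawn uniformly at random:

# Conceptual sketch of uniform negative sampling; uniform_negatives is a
# hypothetical helper, not part of the GraphBolt API.
def uniform_negatives(pos_pairs, num_nodes, k=5):
    # pos_pairs: (N, 2) tensor of positive (src, dst) node pairs.
    src = pos_pairs[:, 0].repeat_interleave(k)
    dst = torch.randint(0, num_nodes, (pos_pairs.shape[0] * k,))
    return torch.stack([src, dst], dim=1)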

Except for the negative sampler, the rest of the code is identical to the node classification tutorial.

[4]:
from functools import partial

# Sample seed edges in minibatches of 256.
datapipe = gb.ItemSampler(train_set, batch_size=256, shuffle=True)
datapipe = datapipe.copy_to(device)
# Draw 5 negative examples per positive edge, then sample 5 neighbors
# per node for each of the 2 GNN layers.
datapipe = datapipe.sample_uniform_negative(graph, 5)
datapipe = datapipe.sample_neighbor(graph, [5, 5])
# Exclude seed edges (and their reverses) from the sampled subgraphs.
datapipe = datapipe.transform(partial(gb.exclude_seed_edges, include_reverse_edges=True))
datapipe = datapipe.fetch_feature(feature, node_feature_keys=["feat"])
train_dataloader = gb.DataLoader(datapipe)

You can peek one minibatch from train_dataloader and see what it provides.

[5]:
data = next(iter(train_dataloader))
print(f"MiniBatch: {data}")
MiniBatch: MiniBatch(seeds=tensor([[2630, 2699],
                        [  49, 2034],
                        [ 415, 1677],
                        ...,
                        [ 454, 2217],
                        [ 454,  469],
                        [ 454,  276]], dtype=torch.int32),
          sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([   0,    1,    1,  ..., 7000, 7001, 7004], dtype=torch.int32),
                                                                         indices=tensor([1040, 1304, 1306,  ..., 2440, 1303,  870], dtype=torch.int32),
                                                           ),
                                               original_row_node_ids=tensor([2630, 2699,   49,  ..., 2559, 1277, 2653], dtype=torch.int32),
                                               original_edge_ids=tensor([10386,   183,  8408,  ...,  8952,  8953,  8954], dtype=torch.int32),
                                               original_column_node_ids=tensor([2630, 2699,   49,  ..., 1748,   96, 2164], dtype=torch.int32),
                            ),
                            SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([   0,    1,    1,  ..., 3765, 3766, 3771], dtype=torch.int32),
                                                                         indices=tensor([1040, 1304,   48,  ..., 1825,  268, 2251], dtype=torch.int32),
                                                           ),
                                               original_row_node_ids=tensor([2630, 2699,   49,  ..., 1748,   96, 2164], dtype=torch.int32),
                                               original_edge_ids=tensor([10386,   183,  8410,  ...,  9165,  9166,  9167], dtype=torch.int32),
                                               original_column_node_ids=tensor([2630, 2699,   49,  ..., 1886,  241, 2217], dtype=torch.int32),
                            )],
          node_features={'feat': tensor([[0., 0., 0.,  ..., 0., 0., 0.],
                                [0., 0., 0.,  ..., 0., 0., 0.],
                                [0., 0., 0.,  ..., 0., 0., 0.],
                                ...,
                                [0., 0., 0.,  ..., 0., 0., 0.],
                                [0., 0., 0.,  ..., 0., 0., 0.],
                                [0., 0., 0.,  ..., 0., 0., 0.]])},
          labels=tensor([1., 1., 1.,  ..., 0., 0., 0.]),
          input_nodes=tensor([2630, 2699,   49,  ..., 2559, 1277, 2653], dtype=torch.int32),
          indexes=tensor([  0,   1,   2,  ..., 255, 255, 255]),
          edge_features=[{},
                        {}],
          compacted_seeds=tensor([[   0,    1],
                                  [   2,    3],
                                  [   4,    5],
                                  ...,
                                  [ 187, 1303],
                                  [ 187, 1033],
                                  [ 187,  695]], dtype=torch.int32),
          blocks=[Block(num_src_nodes=2523, num_dst_nodes=2252, num_edges=7004),
                 Block(num_src_nodes=2252, num_dst_nodes=1304, num_edges=3771)],
       )
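
A few fields are worth reading off the printout above: seeds holds both positive and negative node pairs, labels marks positives with 1. and negatives with 0., indexes maps each pair back to the positive example it belongs to, and compacted_seeds holds the same pairs re-indexed into the sampled subgraph. A quick sanity check (the expected values in the comments are assumptions based on batch_size=256 with 5 negatives per positive, not guaranteed output):

# Sanity-check the minibatch fields.
print(data.seeds.shape)            # expected: torch.Size([1536, 2]), i.e. 256 * (1 + 5) pairs
print(data.labels.unique())        # expected: tensor([0., 1.])
print(data.compacted_seeds.shape)  # same shape as data.seeds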

Defining Model for Node Representation

Let's consider training a 2-layer GraphSAGE with neighbor sampling. The model can be written as follows.

[6]:
import dgl.nn as dglnn
import torch.nn as nn
import torch.nn.functional as F


class SAGE(nn.Module):
    def __init__(self, in_size, hidden_size):
        super().__init__()
        self.layers = nn.ModuleList()
        self.layers.append(dglnn.SAGEConv(in_size, hidden_size, "mean"))
        self.layers.append(dglnn.SAGEConv(hidden_size, hidden_size, "mean"))
        self.hidden_size = hidden_size
        self.predictor = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1),
        )

    def forward(self, blocks, x):
        hidden_x = x
        for layer_idx, (layer, block) in enumerate(zip(self.layers, blocks)):
            hidden_x = layer(block, hidden_x)
            is_last_layer = layer_idx == len(self.layers) - 1
            if not is_last_layer:
                hidden_x = F.relu(hidden_x)
        return hidden_x
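
Note that the predictor head scores a candidate pair from the elementwise product of its two endpoint embeddings. A minimal sketch with hypothetical random embeddings standing in for the output of forward() (128 matches the hidden size used below; 1433 is Cora's input feature dimension):

# Sketch of pair scoring: elementwise product of two endpoint embeddings
# fed through the MLP head. h_u and h_v are hypothetical stand-ins for
# rows of the embedding matrix produced by forward().
h_u = torch.randn(4, 128)
h_v = torch.randn(4, 128)
sketch_model = SAGE(in_size=1433, hidden_size=128)
logits_sketch = sketch_model.predictor(h_u * h_v).squeeze()  # raw logits
probs = torch.sigmoid(logits_sketch)  # probabilities, if desired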

Defining Training Loop

The following initializes the model and defines the optimizer.

[7]:
in_size = feature.size("node", None, "feat")[0]
model = SAGE(in_size, 128).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

The following is the training loop for link prediction and evaluation.

[8]:
from tqdm.auto import tqdm
for epoch in range(3):
    model.train()
    total_loss = 0
    for step, data in tqdm(enumerate(train_dataloader)):
        # Get node pairs with labels for loss calculation.
        compacted_seeds = data.compacted_seeds.T
        labels = data.labels
        node_feature = data.node_features["feat"]
        # Convert sampled subgraphs to DGL blocks.
        blocks = data.blocks

        # Get the embeddings of the input nodes.
        y = model(blocks, node_feature)
        logits = model.predictor(
            y[compacted_seeds[0]] * y[compacted_seeds[1]]
        ).squeeze()

        # Compute loss.
        loss = F.binary_cross_entropy_with_logits(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch:03d} | Loss {total_loss / (step + 1):.3f}")
Epoch 000 | Loss 0.562
Epoch 001 | Loss 0.449
Epoch 002 | Loss 0.445

Evaluating Performance with Link Prediction

[9]:
model.eval()

datapipe = gb.ItemSampler(test_set, batch_size=256, shuffle=False)
datapipe = datapipe.copy_to(device)
# Since we need to use all neighborhoods for evaluation, we set the fanout
# to -1.
datapipe = datapipe.sample_neighbor(graph, [-1, -1])
datapipe = datapipe.fetch_feature(feature, node_feature_keys=["feat"])
eval_dataloader = gb.DataLoader(datapipe, num_workers=0)

logits = []
labels = []
for step, data in tqdm(enumerate(eval_dataloader)):
    # Get node pairs with labels for loss calculation.
    compacted_seeds = data.compacted_seeds.T
    label = data.labels

    # The features of sampled nodes.
    x = data.node_features["feat"]

    # Forward.
    y = model(data.blocks, x)
    logit = (
        model.predictor(y[compacted_seeds[0]] * y[compacted_seeds[1]])
        .squeeze()
        .detach()
    )

    logits.append(logit)
    labels.append(label)

logits = torch.cat(logits, dim=0)
labels = torch.cat(labels, dim=0)


# Compute the AUROC score.
from sklearn.metrics import roc_auc_score

auc = roc_auc_score(labels.cpu(), logits.cpu())
print("Link Prediction AUC:", auc)
Link Prediction AUC: 0.6895649244176906

Conclusion

In this tutorial, you have learned how to train a multi-layer GraphSAGE for link prediction with neighbor sampling.
