Multi-GPU Node Classification

This tutorial shows how to train a multi-layer GraphSAGE for node classification on the ogbn-products dataset provided by Open Graph Benchmark (OGB). The dataset contains around 2.4 million nodes and 62 million edges.


By the end of this tutorial, you will be able to

  • Train a GNN model for node classification on multiple GPUs with DGL's neighbor sampling components. After learning how to use multiple GPUs, you will be able to extend it to other scenarios such as link prediction.

Install DGL Package and Other Dependencies

[1]:
# Install required packages.
import os
import torch
os.environ['TORCH'] = torch.__version__
os.environ['DGLBACKEND'] = "pytorch"

# Install the CUDA version. If you want to install CPU version, please
# refer to https://www.dgl.ai/pages/start.html.
!pip install --pre dgl -f https://data.dgl.ai/wheels-test/cu121/repo.html
!pip install torchmetrics multiprocess

try:
    import dgl
    import dgl.graphbolt as gb
    installed = True
except ImportError as error:
    installed = False
    print(error)
print("DGL installed!" if installed else "DGL not found!")
Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Looking in links: https://data.dgl.ai/wheels-test/cu121/repo.html
Requirement already satisfied: dgl in /localscratch/dgl-3/python (2.1)
Requirement already satisfied: numpy>=1.14.0 in /usr/local/lib/python3.10/dist-packages (from dgl) (1.24.4)
Requirement already satisfied: scipy>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from dgl) (1.11.4)
Requirement already satisfied: networkx>=2.1 in /usr/local/lib/python3.10/dist-packages (from dgl) (2.6.3)
Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.10/dist-packages (from dgl) (2.31.0)
Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from dgl) (4.66.1)
Requirement already satisfied: psutil>=5.8.0 in /usr/local/lib/python3.10/dist-packages (from dgl) (5.9.4)
Requirement already satisfied: torchdata>=0.5.0 in /usr/local/lib/python3.10/dist-packages (from dgl) (0.7.0a0)
Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->dgl) (3.3.2)
Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->dgl) (3.6)
Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->dgl) (1.26.18)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->dgl) (2023.11.17)
Requirement already satisfied: torch>=2 in /usr/local/lib/python3.10/dist-packages (from torchdata>=0.5.0->dgl) (2.2.0a0+81ea7a4)
Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=2->torchdata>=0.5.0->dgl) (3.13.1)
Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch>=2->torchdata>=0.5.0->dgl) (4.8.0)
Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=2->torchdata>=0.5.0->dgl) (1.12)
Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=2->torchdata>=0.5.0->dgl) (3.1.2)
Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch>=2->torchdata>=0.5.0->dgl) (2023.12.0)
Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=2->torchdata>=0.5.0->dgl) (2.1.3)
Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=2->torchdata>=0.5.0->dgl) (1.3.0)
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv

[notice] A new release of pip is available: 23.3.1 -> 24.0
[notice] To update, run: python -m pip install --upgrade pip
Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Requirement already satisfied: torchmetrics in /usr/local/lib/python3.10/dist-packages (1.3.0.post0)
Requirement already satisfied: multiprocess in /usr/local/lib/python3.10/dist-packages (0.70.16)
Requirement already satisfied: numpy>1.20.0 in /usr/local/lib/python3.10/dist-packages (from torchmetrics) (1.24.4)
Requirement already satisfied: packaging>17.1 in /usr/local/lib/python3.10/dist-packages (from torchmetrics) (23.2)
Requirement already satisfied: torch>=1.10.0 in /usr/local/lib/python3.10/dist-packages (from torchmetrics) (2.2.0a0+81ea7a4)
Requirement already satisfied: lightning-utilities>=0.8.0 in /usr/local/lib/python3.10/dist-packages (from torchmetrics) (0.10.1)
Requirement already satisfied: dill>=0.3.8 in /usr/local/lib/python3.10/dist-packages (from multiprocess) (0.3.8)
Requirement already satisfied: setuptools in /usr/local/lib/python3.10/dist-packages (from lightning-utilities>=0.8.0->torchmetrics) (68.2.2)
Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from lightning-utilities>=0.8.0->torchmetrics) (4.8.0)
Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->torchmetrics) (3.13.1)
Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->torchmetrics) (1.12)
Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->torchmetrics) (2.6.3)
Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->torchmetrics) (3.1.2)
Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->torchmetrics) (2023.12.0)
Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.10.0->torchmetrics) (2.1.3)
Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.10.0->torchmetrics) (1.3.0)
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv

[notice] A new release of pip is available: 23.3.1 -> 24.0
[notice] To update, run: python -m pip install --upgrade pip
DGL installed!

Defining the Neighbor Sampler and Data Loader in DGL

The major difference from the previous tutorials is that we will use DistributedItemSampler instead of ItemSampler to sample mini-batches of nodes. DistributedItemSampler is the distributed version of ItemSampler and works together with DistributedDataParallel. It is implemented as a wrapper around ItemSampler, assigning each replica its own share of the items to sample from. It also supports dropping the last non-full mini-batch to avoid the need for padding, and dropping uneven numbers of batches across replicas via drop_uneven_inputs (the sketch after the next code cell illustrates why this matters).

[2]:
def create_dataloader(graph, features, itemset, device, is_train):
    datapipe = gb.DistributedItemSampler(
        item_set=itemset,
        batch_size=1024,
        drop_last=is_train,
        shuffle=is_train,
        drop_uneven_inputs=is_train,
    )
    datapipe = datapipe.copy_to(device)
    # Now that we have moved to device, sample_neighbor and fetch_feature steps
    # will be executed on GPUs.
    datapipe = datapipe.sample_neighbor(graph, [10, 10, 10])
    datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"])
    return gb.DataLoader(datapipe)
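
To see why drop_uneven_inputs matters, consider how the number of mini-batches can differ across replicas. Below is a minimal sketch of the arithmetic; the even item split is an illustrative assumption, not necessarily the exact partitioning scheme that DistributedItemSampler uses.

import math

def batches_per_replica(num_items, world_size, batch_size, drop_uneven):
    # Illustrative split: distribute items as evenly as possible, giving
    # the first (num_items % world_size) replicas one extra item.
    counts = [
        num_items // world_size + (1 if r < num_items % world_size else 0)
        for r in range(world_size)
    ]
    batches = [math.ceil(c / batch_size) for c in counts]
    if drop_uneven:
        # Truncate every replica to the smallest batch count so that all
        # replicas run the same number of iterations.
        batches = [min(batches)] * world_size
    return batches

print(batches_per_replica(10, 4, 2, drop_uneven=False))  # [2, 2, 1, 1]
print(batches_per_replica(10, 4, 2, drop_uneven=True))   # [1, 1, 1, 1]

Without the truncation, the replicas with two batches would block in DDP's gradient all-reduce while the replicas with one batch have already finished the epoch.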

Weighted Reduce Across GPUs

Since different GPUs may process different numbers of data points, we define a function that computes the exact average of values such as loss or accuracy in a weighted manner.

[3]:
import torch.distributed as dist

def weighted_reduce(tensor, weight, dst=0):
    ########################################################################
    # (HIGHLIGHT) Collect accuracy and loss values from sub-processes and
    # obtain overall average values.
    #
    # `torch.distributed.reduce` is used to reduce tensors from all the
    # sub-processes to a specified process, ReduceOp.SUM is used by default.
    #
    # Because the GPUs may have differing numbers of processed items, we
    # perform a weighted mean to calculate the exact loss and accuracy.
    ########################################################################
    dist.reduce(tensor=tensor, dst=dst)
    weight = torch.tensor(weight, device=tensor.device)
    dist.reduce(tensor=weight, dst=dst)
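    # Note: with the default ReduceOp.SUM, only the dst rank is guaranteed
    # to hold the fully reduced values; the quotient below should only be
    # used on that rank.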
    return tensor / weight
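
As a sanity check of the arithmetic (no process group needed), suppose rank 0 accumulated a loss sum of 120.0 over 100 items and rank 1 a sum of 330.0 over 300 items. Averaging the per-rank means would weight both ranks equally; the weighted mean is exact. The numbers below are made up for illustration.

# Simulate what weighted_reduce computes across two ranks, in plain Python.
loss_sums = [120.0, 330.0]  # per-rank sums of loss * number of items
counts = [100, 300]         # per-rank numbers of processed items

naive_mean = sum(s / c for s, c in zip(loss_sums, counts)) / len(counts)
weighted_mean = sum(loss_sums) / sum(counts)

print(naive_mean)     # 1.15  (biased: treats both ranks equally)
print(weighted_mean)  # 1.125 (exact mean over all 400 items)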

Defining the Model

Let us consider training a 3-layer GraphSAGE model with neighbor sampling. The model can be written as follows:

[4]:
from torch import nn
import torch.nn.functional as F
from dgl.nn import SAGEConv

class SAGE(nn.Module):
    def __init__(self, in_size, hidden_size, out_size):
        super().__init__()
        self.layers = nn.ModuleList()
        # Three-layer GraphSAGE-mean.
        self.layers.append(SAGEConv(in_size, hidden_size, "mean"))
        self.layers.append(SAGEConv(hidden_size, hidden_size, "mean"))
        self.layers.append(SAGEConv(hidden_size, out_size, "mean"))
        self.dropout = nn.Dropout(0.5)
        self.hidden_size = hidden_size
        self.out_size = out_size
        # Set the dtype for the layers manually.
        self.float()

    def forward(self, blocks, x):
        hidden_x = x
        for layer_idx, (layer, block) in enumerate(zip(self.layers, blocks)):
            hidden_x = layer(block, hidden_x)
            is_last_layer = layer_idx == len(self.layers) - 1
            if not is_last_layer:
                hidden_x = F.relu(hidden_x)
                hidden_x = self.dropout(hidden_x)
        return hidden_x
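
As a quick smoke test, the model can be run on a small random graph by passing the same graph in place of each sampled block (SAGEConv accepts a full DGLGraph wherever a block is expected). The sizes below are made up for illustration.

g = dgl.rand_graph(num_nodes=10, num_edges=40)
toy_model = SAGE(in_size=5, hidden_size=8, out_size=3)
x = torch.randn(10, 5)
logits = toy_model([g, g, g], x)  # one "graph" per layer
print(logits.shape)  # torch.Size([10, 3])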

Evaluation Function

The evaluation function can be used to compute the validation accuracy during training or the test accuracy at the end of training. The difference from the previous tutorial is that we also need to return the number of items each GPU has processed, so that the weighted average can be computed.

[5]:
import torchmetrics.functional as MF
import tqdm

@torch.no_grad()
def evaluate(rank, model, graph, features, itemset, num_classes, device):
    model.eval()
    y = []
    y_hats = []
    dataloader = create_dataloader(
        graph,
        features,
        itemset,
        device,
        is_train=False,
    )

    for data in tqdm.tqdm(dataloader) if rank == 0 else dataloader:
        blocks = data.blocks
        x = data.node_features["feat"]
        y.append(data.labels)
        y_hats.append(model.module(blocks, x))

    res = MF.accuracy(
        torch.cat(y_hats),
        torch.cat(y),
        task="multiclass",
        num_classes=num_classes,
    )

    return res.to(device), sum(y_i.size(0) for y_i in y)

Training Loop

The training loop is almost identical to the previous tutorial. In this tutorial we explicitly disable uneven inputs from the data loader via drop_uneven_inputs; alternatively, the Join Context Manager could be used to train with the incomplete batches that may occur at the end of an epoch (a sketch is shown after the training loop below). Please refer to this tutorial for more information.

[6]:
import time

def train(
    rank,
    graph,
    features,
    train_set,
    valid_set,
    num_classes,
    model,
    device,
):
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    # Create training data loader.
    dataloader = create_dataloader(
        graph,
        features,
        train_set,
        device,
        is_train=True,
    )

    for epoch in range(5):
        epoch_start = time.time()

        model.train()
        total_loss = torch.tensor(0, dtype=torch.float, device=device)
        num_train_items = 0
        for data in tqdm.tqdm(dataloader) if rank == 0 else dataloader:
            # The input features are from the source nodes in the first
            # layer's computation graph.
            x = data.node_features["feat"]

            # The ground truth labels are from the destination nodes
            # in the last layer's computation graph.
            y = data.labels

            blocks = data.blocks

            y_hat = model(blocks, x)

            # Compute loss.
            loss = F.cross_entropy(y_hat, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.detach() * y.size(0)
            num_train_items += y.size(0)

        # Evaluate the model.
        if rank == 0:
            print("Validating...")
        acc, num_val_items = evaluate(
            rank,
            model,
            graph,
            features,
            valid_set,
            num_classes,
            device,
        )
        total_loss = weighted_reduce(total_loss, num_train_items)
        acc = weighted_reduce(acc * num_val_items, num_val_items)

        # We synchronize before measuring the epoch time.
        torch.cuda.synchronize()
        epoch_end = time.time()
        if rank == 0:
            print(
                f"Epoch {epoch:05d} | "
                f"Average Loss {total_loss.item():.4f} | "
                f"Accuracy {acc.item():.4f} | "
                f"Time {epoch_end - epoch_start:.4f}"
            )
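
For reference, here is a minimal sketch of the Join alternative, using PyTorch's torch.distributed.algorithms.join.Join context manager. Because the DDP-wrapped model is a Joinable, ranks that run out of batches early shadow the collective communications of the others instead of hanging. The helper name train_one_epoch_with_join is hypothetical.

from torch.distributed.algorithms.join import Join

def train_one_epoch_with_join(model, dataloader, optimizer):
    # model must be the DistributedDataParallel-wrapped model.
    with Join([model]):
        for data in dataloader:
            y_hat = model(data.blocks, data.node_features["feat"])
            loss = F.cross_entropy(y_hat, data.labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()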

Defining the Training and Evaluation Procedures

The following code defines the main function for each process. It is similar to the previous tutorial, except that we need to initialize a distributed training context with torch.distributed and wrap the model with torch.nn.parallel.DistributedDataParallel.

[7]:
def run(rank, world_size, devices, dataset):
    # Set up multiprocessing environment.
    device = devices[rank]
    torch.cuda.set_device(device)
    dist.init_process_group(
        backend="nccl",  # Use NCCL backend for distributed GPU training
        init_method="tcp://127.0.0.1:12345",
        world_size=world_size,
        rank=rank,
    )

    # Pin the graph and features in-place to enable GPU access.
    graph = dataset.graph.pin_memory_()
    features = dataset.feature.pin_memory_()
    train_set = dataset.tasks[0].train_set
    valid_set = dataset.tasks[0].validation_set
    num_classes = dataset.tasks[0].metadata["num_classes"]

    in_size = features.size("node", None, "feat")[0]
    hidden_size = 256
    out_size = num_classes

    # Create GraphSAGE model. It should be copied onto a GPU as a replica.
    model = SAGE(in_size, hidden_size, out_size).to(device)
    model = nn.parallel.DistributedDataParallel(model)

    # Model training.
    if rank == 0:
        print("Training...")
    train(
        rank,
        graph,
        features,
        train_set,
        valid_set,
        num_classes,
        model,
        device,
    )

    # Test the model.
    if rank == 0:
        print("Testing...")
    test_set = dataset.tasks[0].test_set
    test_acc, num_test_items = evaluate(
        rank,
        model,
        graph,
        features,
        itemset=test_set,
        num_classes=num_classes,
        device=device,
    )
    test_acc = weighted_reduce(test_acc * num_test_items, num_test_items)

    if rank == 0:
        print(f"Test Accuracy {test_acc.item():.4f}")

Spawning Trainer Processes

The following code spawns a process for each GPU and calls the run function defined above.

[8]:
import torch.multiprocessing as mp

def main():
    if not torch.cuda.is_available():
        print("No GPU found!")
        return

    devices = [
        torch.device(f"cuda:{i}") for i in range(torch.cuda.device_count())
    ][:1]
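    # Note: the slice above keeps only the first GPU so that this notebook
    # runs on a single device; remove [:1] to train on all available GPUs.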
    world_size = len(devices)

    print(f"Training with {world_size} gpus.")

    # Load and preprocess dataset.
    dataset = gb.BuiltinDataset("ogbn-products").load()

    # Thread limiting to avoid resource competition.
    os.environ["OMP_NUM_THREADS"] = str(mp.cpu_count() // 2 // world_size)

    if world_size > 1:
        # The following launch method is not supported in a notebook.
        mp.set_sharing_strategy("file_system")
        mp.spawn(
            run,
            args=(world_size, devices, dataset),
            nprocs=world_size,
            join=True,
        )
    else:
        run(0, 1, devices, dataset)


if __name__ == "__main__":
    main()
Training with 1 gpus.
The dataset is already preprocessed.
Training...
192it [00:09, 21.32it/s]
Validating...
39it [00:00, 78.32it/s]
Epoch 00000 | Average Loss 1.2953 | Accuracy 0.8556 | Time 9.5520
192it [00:03, 61.08it/s]
Validating...
39it [00:00, 79.10it/s]
Epoch 00001 | Average Loss 0.5859 | Accuracy 0.8788 | Time 3.6609
192it [00:03, 62.82it/s]
Validating...
39it [00:00, 80.55it/s]
Epoch 00002 | Average Loss 0.4858 | Accuracy 0.8852 | Time 3.5646
192it [00:03, 60.34it/s]
Validating...
39it [00:00, 44.41it/s]
Epoch 00003 | Average Loss 0.4407 | Accuracy 0.8920 | Time 4.0852
192it [00:03, 58.87it/s]
Validating...
39it [00:00, 78.52it/s]
Epoch 00004 | Average Loss 0.4122 | Accuracy 0.8943 | Time 3.7938
Testing...
2162it [00:24, 89.75it/s]
Test Accuracy 0.7514
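
As a closing note, outside of a notebook the same training logic could also be launched with torchrun instead of mp.spawn. A minimal sketch, assuming the code lives in a standalone script (the file name train_dist.py is hypothetical):

# train_dist.py -- launch with: torchrun --nproc_per_node=<num_gpus> train_dist.py
import os
import torch
import torch.distributed as dist

if __name__ == "__main__":
    # torchrun sets RANK, LOCAL_RANK and WORLD_SIZE for every process, and
    # init_process_group defaults to the env:// rendezvous, so no explicit
    # mp.spawn call or TCP address is needed.
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    # ... load the dataset and run the training/evaluation logic here ...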