Single-Machine Multi-GPU Mini-batch Node Classification

In this tutorial, you will learn how to use multiple GPUs to train a graph neural network (GNN) for node classification.

This tutorial assumes that you have read the tutorial on stochastic GNN training for node classification with DGL. It also assumes that you know the basics of training general models with multiple GPUs using DistributedDataParallel.

Note

Please refer to this tutorial from PyTorch for general multi-GPU training with DistributedDataParallel. Also, see the first section of the multi-GPU graph classification tutorial for an overview of how to use DistributedDataParallel with DGL.
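
For reference, below is a minimal, self-contained sketch of the general pattern referred to above. It is illustrative only and not part of the tutorial code; it uses the CPU-only gloo backend and an arbitrary port so it also runs on machines without GPUs.

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def _worker(rank, world_size):
    # Each spawned process joins the same process group...
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29501"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    # ...wraps its model replica in DDP so gradients are synchronized...
    model = DDP(nn.Linear(4, 2))
    model(torch.randn(8, 4)).sum().backward()
    # ...and tears the process group down when done.
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(_worker, args=(2,), nprocs=2, join=True)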

Importing Packages

We use torch.distributed to initialize the distributed training environment and torch.multiprocessing to spawn multiple processes, one for each GPU.

import os

os.environ["DGLBACKEND"] = "pytorch"
import time

import dgl.graphbolt as gb
import dgl.nn as dglnn
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics.functional as MF
from torch.distributed.algorithms.join import Join
from torch.nn.parallel import DistributedDataParallel as DDP
from tqdm.auto import tqdm

Defining the Model

The model will again be the same as in the tutorial on stochastic GNN training for node classification with DGL.

class SAGE(nn.Module):
    def __init__(self, in_size, hidden_size, out_size):
        super().__init__()
        self.layers = nn.ModuleList()
        # Three-layer GraphSAGE-mean.
        self.layers.append(dglnn.SAGEConv(in_size, hidden_size, "mean"))
        self.layers.append(dglnn.SAGEConv(hidden_size, hidden_size, "mean"))
        self.layers.append(dglnn.SAGEConv(hidden_size, out_size, "mean"))
        self.dropout = nn.Dropout(0.5)
        self.hidden_size = hidden_size
        self.out_size = out_size
        # Set the dtype for the layers manually.
        self.float()

    def forward(self, blocks, x):
        hidden_x = x
        for layer_idx, (layer, block) in enumerate(zip(self.layers, blocks)):
            hidden_x = layer(block, hidden_x)
            is_last_layer = layer_idx == len(self.layers) - 1
            if not is_last_layer:
                hidden_x = F.relu(hidden_x)
                hidden_x = self.dropout(hidden_x)
        return hidden_x
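
As a quick sanity check, a model replica can be instantiated with hypothetical sizes (ogbn-arxiv, the dataset used below, has 128-dimensional node features and 40 classes); in the actual training code the sizes are read from the dataset.

# Hypothetical sizes for illustration; `run()` below derives them from the dataset.
model = SAGE(in_size=128, hidden_size=256, out_size=40)
print(model)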

Mini-batch Data Loading

The major difference from the previous tutorial is that we use DistributedItemSampler instead of ItemSampler to sample mini-batches of nodes. DistributedItemSampler is the distributed version of ItemSampler and works together with DistributedDataParallel. It is implemented as a wrapper around ItemSampler and partitions the items across all replicas, so each replica samples from its own disjoint subset. It also supports dropping the last non-full mini-batch to avoid the need for padding. In addition, the drop_uneven_inputs option drops the excess batches on replicas that would otherwise receive more batches than the others, so that every rank runs the same number of iterations.

def create_dataloader(
    graph,
    features,
    itemset,
    device,
    is_train,
):
    datapipe = gb.DistributedItemSampler(
        item_set=itemset,
        batch_size=1024,
        drop_last=is_train,
        shuffle=is_train,
        drop_uneven_inputs=is_train,
    )
    datapipe = datapipe.copy_to(device)
    # Now that we have moved to device, sample_neighbor and fetch_feature steps
    # will be executed on GPUs.
    datapipe = datapipe.sample_neighbor(graph, [10, 10, 10])
    datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"])
    return gb.DataLoader(datapipe)
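
The snippet below is a minimal illustration of what each mini-batch contains. It assumes the distributed process group is already initialized and that graph, features, train_set, and device are set up as in run() further down.

dataloader = create_dataloader(graph, features, train_set, device, is_train=True)
minibatch = next(iter(dataloader))
print(len(minibatch.blocks))                  # one block (MFG) per GNN layer, i.e. 3
print(minibatch.node_features["feat"].shape)  # input features of the source nodes
print(minibatch.labels.shape)                 # labels of the seed nodes

Next, we define a helper that aggregates metrics across processes.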


def weighted_reduce(tensor, weight, dst=0):
    ########################################################################
    # (HIGHLIGHT) Collect accuracy and loss values from sub-processes and
    # obtain overall average values.
    #
    # `torch.distributed.reduce` is used to reduce tensors from all the
    # sub-processes to a specified process, ReduceOp.SUM is used by default.
    #
    # Because the GPUs may have differing numbers of processed items, we
    # perform a weighted mean to calculate the exact loss and accuracy.
    ########################################################################
    dist.reduce(tensor=tensor, dst=dst)
    weight = torch.tensor(weight, device=tensor.device)
    dist.reduce(tensor=weight, dst=dst)
    return tensor / weight
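
For example, if rank 0 accumulates a loss sum of 120.0 over 100 items and rank 1 accumulates 90.0 over 60 items, the destination rank ends up with (120.0 + 90.0) / (100 + 60) = 1.3125, the exact mean over all 160 items. Note that torch.distributed.reduce only guarantees the final result on the destination rank (rank 0 by default here), which is why the metrics are printed on rank 0 only.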

Evaluation Loop

The evaluation loop is almost identical to the one in the previous tutorial.

@torch.no_grad()
def evaluate(rank, model, graph, features, itemset, num_classes, device):
    model.eval()
    y = []
    y_hats = []
    dataloader = create_dataloader(
        graph,
        features,
        itemset,
        device,
        is_train=False,
    )

    for data in tqdm(dataloader) if rank == 0 else dataloader:
        blocks = data.blocks
        x = data.node_features["feat"]
        y.append(data.labels)
        y_hats.append(model.module(blocks, x))

    res = MF.accuracy(
        torch.cat(y_hats),
        torch.cat(y),
        task="multiclass",
        num_classes=num_classes,
    )

    return res.to(device), sum(y_i.size(0) for y_i in y)

Training Loop

The training loop is also almost identical to the previous tutorial, except that we use the Join context manager to handle uneven inputs. The mechanics of distributed data parallel (DDP) training in PyTorch require the number of inputs to be the same on all ranks; otherwise the program may error out or hang. To solve this, PyTorch provides the Join context manager. Please refer to this tutorial for details.
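
The following is a minimal, self-contained sketch (illustrative only, not part of the tutorial code, and using the CPU-only gloo backend) of the problem Join solves: one rank receives more batches than the other, and Join shadows the missing collective calls so that neither rank hangs.

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.distributed.algorithms.join import Join
from torch.nn.parallel import DistributedDataParallel as DDP


def _uneven_demo(rank, world_size):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29502"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    model = DDP(nn.Linear(4, 1))
    # Rank 0 gets one more batch than rank 1; without Join, rank 1 would hang
    # waiting for rank 0's gradient all-reduce on the extra batch.
    num_batches = 3 if rank == 0 else 2
    with Join([model]):
        for _ in range(num_batches):
            model(torch.randn(8, 4)).sum().backward()
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(_uneven_demo, args=(2,), nprocs=2, join=True)

With Join in place, the training loop for this tutorial looks as follows.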

def train(
    rank,
    graph,
    features,
    train_set,
    valid_set,
    num_classes,
    model,
    device,
):
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    # Create training data loader.
    dataloader = create_dataloader(
        graph,
        features,
        train_set,
        device,
        is_train=True,
    )

    for epoch in range(5):
        epoch_start = time.time()

        model.train()
        total_loss = torch.tensor(0, dtype=torch.float, device=device)
        num_train_items = 0
        with Join([model]):
            for data in tqdm(dataloader) if rank == 0 else dataloader:
                # The input features are from the source nodes in the first
                # layer's computation graph.
                x = data.node_features["feat"]

                # The ground truth labels are from the destination nodes
                # in the last layer's computation graph.
                y = data.labels

                blocks = data.blocks

                y_hat = model(blocks, x)

                # Compute loss.
                loss = F.cross_entropy(y_hat, y)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                total_loss += loss.detach() * y.size(0)
                num_train_items += y.size(0)

        # Evaluate the model.
        if rank == 0:
            print("Validating...")
        acc, num_val_items = evaluate(
            rank,
            model,
            graph,
            features,
            valid_set,
            num_classes,
            device,
        )
        total_loss = weighted_reduce(total_loss, num_train_items)
        acc = weighted_reduce(acc * num_val_items, num_val_items)

        # We synchronize before measuring the epoch time.
        torch.cuda.synchronize()
        epoch_end = time.time()
        if rank == 0:
            print(
                f"Epoch {epoch:05d} | "
                f"Average Loss {total_loss.item():.4f} | "
                f"Accuracy {acc.item():.4f} | "
                f"Time {epoch_end - epoch_start:.4f}"
            )

Defining Training and Evaluation Procedures

The following code defines the main function for each process. It is similar to the previous tutorial, except that we need to initialize a distributed training environment with torch.distributed and wrap the model with torch.nn.parallel.DistributedDataParallel.

def run(rank, world_size, devices, dataset):
    # Set up multiprocessing environment.
    device = devices[rank]
    torch.cuda.set_device(device)
    dist.init_process_group(
        backend="nccl",  # Use NCCL backend for distributed GPU training
        init_method="tcp://127.0.0.1:12345",
        world_size=world_size,
        rank=rank,
    )

    # Pin the graph and features in-place to enable GPU access.
    graph = dataset.graph.pin_memory_()
    features = dataset.feature.pin_memory_()
    train_set = dataset.tasks[0].train_set
    valid_set = dataset.tasks[0].validation_set
    num_classes = dataset.tasks[0].metadata["num_classes"]

    in_size = features.size("node", None, "feat")[0]
    hidden_size = 256
    out_size = num_classes

    # Create GraphSAGE model. It should be copied onto a GPU as a replica.
    model = SAGE(in_size, hidden_size, out_size).to(device)
    model = DDP(model)

    # Model training.
    if rank == 0:
        print("Training...")
    train(
        rank,
        graph,
        features,
        train_set,
        valid_set,
        num_classes,
        model,
        device,
    )

    # Test the model.
    if rank == 0:
        print("Testing...")
    test_set = dataset.tasks[0].test_set
    test_acc, num_test_items = evaluate(
        rank,
        model,
        graph,
        features,
        itemset=test_set,
        num_classes=num_classes,
        device=device,
    )
    test_acc = weighted_reduce(test_acc * num_test_items, num_test_items)

    if rank == 0:
        print(f"Test Accuracy {test_acc.item():.4f}")

Spawning Trainer Processes

The following code spawns one process per GPU and calls the run function defined above.

def main():
    if not torch.cuda.is_available():
        print("No GPU found!")
        return

    devices = [
        torch.device(f"cuda:{i}") for i in range(torch.cuda.device_count())
    ]
    world_size = len(devices)

    print(f"Training with {world_size} gpus.")

    # Load and preprocess dataset.
    dataset = gb.BuiltinDataset("ogbn-arxiv").load()

    # Thread limiting to avoid resource competition.
    os.environ["OMP_NUM_THREADS"] = str(mp.cpu_count() // 2 // world_size)

    mp.set_sharing_strategy("file_system")
    mp.spawn(
        run,
        args=(world_size, devices, dataset),
        nprocs=world_size,
        join=True,
    )


if __name__ == "__main__":
    main()
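
To launch training, simply run the script with Python on a machine with one or more GPUs (for example, python node_classification_multigpu.py, where the filename is whatever you saved this script as); it reports the number of GPUs found and spawns one trainer process per device.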