GraphDataLoader

class dgl.dataloading.GraphDataLoader(dataset, collate_fn=None, use_ddp=False, ddp_seed=0, **kwargs)[source]

基类: DataLoader

批量图数据加载器。

用于批量迭代一组图的 PyTorch 数据加载器,生成所述小批量图的批量图和相应的标签张量(如果提供)。

参数:
  • dataset (torch.utils.data.Dataset) – 从中加载图的数据集。

  • collate_fn (Function, default is None) – 自定义 collate 函数。如果未给出,将使用默认的 collate 函数。

  • use_ddp (boolean, optional) –

    如果为 True,则告知 DataLoader 使用 torch.utils.data.distributed.DistributedSampler 为每个参与进程适当地拆分训练集。

    覆盖 torch.utils.data.DataLoadersampler 参数。

  • ddp_seed (int, optional) –

    torch.utils.data.distributed.DistributedSampler 中用于打乱数据集的种子。

    仅当 use_ddp 为 True 时有效。

  • kwargs (dict) –

    要传递给父 PyTorch torch.utils.data.DataLoader 类的关键字参数。常用参数有:

    • batch_size (int): 每个批次中的索引数量。

    • drop_last (bool): 是否丢弃最后一个不完整的批次。

    • shuffle (bool): 每个 epoch 是否随机打乱索引。

示例

要在 dataset 中的一组图上训练用于图分类的 GNN

>>> dataloader = dgl.dataloading.GraphDataLoader(
...     dataset, batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for batched_graph, labels in dataloader:
...     train_on(batched_graph, labels)

使用分布式数据并行 (Distributed Data Parallel)

如果您正在使用 PyTorch 的分布式训练(例如,使用 torch.nn.parallel.DistributedDataParallel 时),您可以通过开启 use_ddp 选项来训练模型

>>> dataloader = dgl.dataloading.GraphDataLoader(
...     dataset, use_ddp=True, batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for epoch in range(start_epoch, n_epochs):
...     dataloader.set_epoch(epoch)
...     for batched_graph, labels in dataloader:
...         train_on(batched_graph, labels)