GraphDataLoader
- class dgl.dataloading.GraphDataLoader(dataset, collate_fn=None, use_ddp=False, ddp_seed=0, **kwargs)[source]
基类:
DataLoader
批量图数据加载器。
用于批量迭代一组图的 PyTorch 数据加载器,生成所述小批量图的批量图和相应的标签张量(如果提供)。
- 参数:
dataset (torch.utils.data.Dataset) – 从中加载图的数据集。
collate_fn (Function, default is None) – 自定义 collate 函数。如果未给出,将使用默认的 collate 函数。
use_ddp (boolean, optional) –
如果为 True,则告知 DataLoader 使用
torch.utils.data.distributed.DistributedSampler
为每个参与进程适当地拆分训练集。覆盖
torch.utils.data.DataLoader
的sampler
参数。ddp_seed (int, optional) –
在
torch.utils.data.distributed.DistributedSampler
中用于打乱数据集的种子。仅当
use_ddp
为 True 时有效。kwargs (dict) –
要传递给父 PyTorch
torch.utils.data.DataLoader
类的关键字参数。常用参数有:batch_size
(int): 每个批次中的索引数量。drop_last
(bool): 是否丢弃最后一个不完整的批次。shuffle
(bool): 每个 epoch 是否随机打乱索引。
示例
要在
dataset
中的一组图上训练用于图分类的 GNN>>> dataloader = dgl.dataloading.GraphDataLoader( ... dataset, batch_size=1024, shuffle=True, drop_last=False, num_workers=4) >>> for batched_graph, labels in dataloader: ... train_on(batched_graph, labels)
使用分布式数据并行 (Distributed Data Parallel)
如果您正在使用 PyTorch 的分布式训练(例如,使用
torch.nn.parallel.DistributedDataParallel
时),您可以通过开启use_ddp
选项来训练模型>>> dataloader = dgl.dataloading.GraphDataLoader( ... dataset, use_ddp=True, batch_size=1024, shuffle=True, drop_last=False, num_workers=4) >>> for epoch in range(start_epoch, n_epochs): ... dataloader.set_epoch(epoch) ... for batched_graph, labels in dataloader: ... train_on(batched_graph, labels)