6.9 数据加载并行性
在 GNN 的小批量训练中,我们通常需要经历几个阶段来生成一个小的训练批次,包括
迭代项目集合并按批次大小生成小批量种子。
从图中为每个种子采样负样本项。
从图中为每个种子采样邻居。
从采样的子图中排除种子边。
获取采样子图的节点和边特征。
将小批量数据复制到目标设备。
datapipe = gb.ItemSampler(itemset, batch_size=1024, shuffle=True)
datapipe = datapipe.sample_uniform_negative(g, 5)
datapipe = datapipe.sample_neighbor(g, [10, 10]) # 2 layers.
datapipe = datapipe.transform(gb.exclude_seed_edges)
datapipe = datapipe.fetch_feature(feature, node_feature_keys=["feat"])
datapipe = datapipe.copy_to(device)
dataloader = gb.DataLoader(datapipe)
所有这些阶段都通过独立的 IterableDataPipe 实现,并与 PyTorch DataLoader 组合在一起。这种设计使得我们可以通过将不同的数据管道串联起来轻松定制数据加载过程。例如,如果我们要从图中为每个种子采样负样本项,只需将 NegativeSampler
串联在 ItemSampler
之后即可。
但简单地将数据管道串联在一起会产生性能开销,因为不同阶段会利用 CPU、GPU、PCIe 等各种硬件资源。因此,对数据加载机制进行了优化,以最小化开销并实现最佳性能。
具体来说,GraphBolt 在 fetch_feature
之前使用多进程包装数据管道,从而使多个进程可以并行运行。对于 fetch_feature
数据管道,我们让它在主进程中运行,以避免进程间的数据移动开销。
此外,为了重叠数据移动和模型计算,我们在 copy_to
之前使用 torchdata.datapipes.iter.Prefetcher 包装数据管道,它会从前一个数据管道预取元素并放入缓冲区。这种预取对用户完全透明,无需额外代码。它为 GNN 的小批量训练带来了显著的性能提升。
有关更多详细信息,请参阅 DataLoader
的源代码。