ClusterGCNSampler
- class dgl.dataloading.ClusterGCNSampler(g, k, cache_path='cluster_gcn.pkl', balance_ntypes=None, balkance_edges=False, mode='k-way', prefetch_ndata=None, prefetch_edata=None, output_device=None)[source]
基类:
Sampler
来自 Cluster-GCN: An Efficient Algorithm for Training Deep and Large Graph Convolutional Networks 的 Cluster 采样器
此采样器首先使用 METIS 分割对图进行划分,然后将每个分区的节点缓存到给定缓存目录下的文件中。
然后,采样器根据提供的分区 ID 选择图分区,获取这些分区中所有节点的并集,并在其
sample
方法中返回一个诱导子图。- 参数:
g (DGLGraph) – 原始图。必须是同构图且在 CPU 上。
k (int) – 分区数量。
cache_path (str) – 用于存储分区结果的缓存目录路径。
balance_ntypes – 传递给
dgl.metis_partition_assignment()
。balkance_edges – 传递给
dgl.metis_partition_assignment()
。mode – 传递给
dgl.metis_partition_assignment()
。prefetch_ndata (list[str], 可选) –
为子图预取的节点数据。
有关预取的详细说明,请参阅 guide-minibatch-prefetching。
prefetch_edata (list[str], 可选) –
为子图预取的边数据。
有关预取的详细说明,请参阅 guide-minibatch-prefetching。
output_device (device, 可选) – 输出子图或 MFG 的设备。默认为与分区索引 minibatch 相同的设备。
示例
节点分类
使用此采样器,数据加载器将接受分区 ID 列表作为要迭代的索引。例如,以下代码首先使用 METIS 将图分割成 1000 个分区,然后在每次迭代时获取一个由 20 个随机选择的分区所覆盖的节点诱导的子图。
>>> num_parts = 1000 >>> sampler = dgl.dataloading.ClusterGCNSampler(g, num_parts) >>> dataloader = dgl.dataloading.DataLoader( ... g, torch.arange(num_parts), sampler, ... batch_size=20, shuffle=True, drop_last=False, num_workers=4) >>> for subg in dataloader: ... train_on(subg)