ClusterGCNSampler

class dgl.dataloading.ClusterGCNSampler(g, k, cache_path='cluster_gcn.pkl', balance_ntypes=None, balkance_edges=False, mode='k-way', prefetch_ndata=None, prefetch_edata=None, output_device=None)[source]

基类: Sampler

来自 Cluster-GCN: An Efficient Algorithm for Training Deep and Large Graph Convolutional Networks 的 Cluster 采样器

此采样器首先使用 METIS 分割对图进行划分,然后将每个分区的节点缓存到给定缓存目录下的文件中。

然后,采样器根据提供的分区 ID 选择图分区,获取这些分区中所有节点的并集,并在其 sample 方法中返回一个诱导子图。

参数:
  • g (DGLGraph) – 原始图。必须是同构图且在 CPU 上。

  • k (int) – 分区数量。

  • cache_path (str) – 用于存储分区结果的缓存目录路径。

  • balance_ntypes – 传递给 dgl.metis_partition_assignment()

  • balkance_edges – 传递给 dgl.metis_partition_assignment()

  • mode – 传递给 dgl.metis_partition_assignment()

  • prefetch_ndata (list[str], 可选) –

    为子图预取的节点数据。

    有关预取的详细说明,请参阅 guide-minibatch-prefetching

  • prefetch_edata (list[str], 可选) –

    为子图预取的边数据。

    有关预取的详细说明,请参阅 guide-minibatch-prefetching

  • output_device (device, 可选) – 输出子图或 MFG 的设备。默认为与分区索引 minibatch 相同的设备。

示例

节点分类

使用此采样器,数据加载器将接受分区 ID 列表作为要迭代的索引。例如,以下代码首先使用 METIS 将图分割成 1000 个分区,然后在每次迭代时获取一个由 20 个随机选择的分区所覆盖的节点诱导的子图。

>>> num_parts = 1000
>>> sampler = dgl.dataloading.ClusterGCNSampler(g, num_parts)
>>> dataloader = dgl.dataloading.DataLoader(
...     g, torch.arange(num_parts), sampler,
...     batch_size=20, shuffle=True, drop_last=False, num_workers=4)
>>> for subg in dataloader:
...     train_on(subg)