SAINTSampler

class dgl.dataloading.SAINTSampler(mode, budget, cache=True, prefetch_ndata=None, prefetch_edata=None, output_device='cpu')[源代码]

基类: Sampler

来自 GraphSAINT: Graph Sampling Based Inductive Learning Method 的随机节点/边/游走采样器

对于每一次调用,该采样器会采样一个节点子集,然后返回一个由这些节点诱导的子图。采样节点子集有三个选项

  • 对于 'node' 采样器,采样一个节点的概率与其出度成正比。

  • 'edge' 采样器首先采样一个边子集,然后使用这些边的端点节点。

  • 'walk' 采样器使用随机游走访问到的节点。它均匀地选择一些根节点,然后从每个根节点执行固定长度的随机游走。

参数:
  • mode (str) – 要使用的采样器,可以是 'node''edge''walk'

  • budget (inttuple[int]) –

    采样器配置。

    • 对于 'node' 采样器,budget 指定每个采样子图中的节点数。

    • 对于 'edge' 采样器,budget 指定用于诱导子图的边数。

    • 对于 'walk' 采样器,budget 是一个 tuple。budget[0] 指定生成随机游走的根节点数。budget[1] 指定随机游走的长度。

  • cache (bool, 可选) – 如果为 False,则不会缓存用于采样的概率数组。如果您想在不同的图中使用该采样器,则必须将其设置为 False。

  • prefetch_ndata (list[str], 可选) –

    为子图预取的节点数据。

    有关预取的详细说明,请参阅 guide-minibatch-prefetching

  • prefetch_edata (list[str], 可选) –

    为子图预取的边数据。

    有关预取的详细说明,请参阅 guide-minibatch-prefetching

  • output_device (device, 可选) – 输出子图的设备。

示例

>>> import torch
>>> from dgl.dataloading import SAINTSampler, DataLoader
>>> num_iters = 1000
>>> sampler = SAINTSampler(mode='node', budget=6000)
>>> # Assume g.ndata['feat'] and g.ndata['label'] hold node features and labels
>>> dataloader = DataLoader(g, torch.arange(num_iters), sampler, num_workers=4)
>>> for subg in dataloader:
...     train_on(subg)