SAINTSampler
- class dgl.dataloading.SAINTSampler(mode, budget, cache=True, prefetch_ndata=None, prefetch_edata=None, output_device='cpu')[源代码]
基类:
Sampler
来自 GraphSAINT: Graph Sampling Based Inductive Learning Method 的随机节点/边/游走采样器
对于每一次调用,该采样器会采样一个节点子集,然后返回一个由这些节点诱导的子图。采样节点子集有三个选项
对于
'node'
采样器,采样一个节点的概率与其出度成正比。'edge'
采样器首先采样一个边子集,然后使用这些边的端点节点。'walk'
采样器使用随机游走访问到的节点。它均匀地选择一些根节点,然后从每个根节点执行固定长度的随机游走。
- 参数:
mode (str) – 要使用的采样器,可以是
'node'
、'edge'
或'walk'
。采样器配置。
对于
'node'
采样器,budget 指定每个采样子图中的节点数。对于
'edge'
采样器,budget 指定用于诱导子图的边数。对于
'walk'
采样器,budget 是一个 tuple。budget[0] 指定生成随机游走的根节点数。budget[1] 指定随机游走的长度。
cache (bool, 可选) – 如果为 False,则不会缓存用于采样的概率数组。如果您想在不同的图中使用该采样器,则必须将其设置为 False。
prefetch_ndata (list[str], 可选) –
为子图预取的节点数据。
有关预取的详细说明,请参阅 guide-minibatch-prefetching。
prefetch_edata (list[str], 可选) –
为子图预取的边数据。
有关预取的详细说明,请参阅 guide-minibatch-prefetching。
output_device (device, 可选) – 输出子图的设备。
示例
>>> import torch >>> from dgl.dataloading import SAINTSampler, DataLoader >>> num_iters = 1000 >>> sampler = SAINTSampler(mode='node', budget=6000) >>> # Assume g.ndata['feat'] and g.ndata['label'] hold node features and labels >>> dataloader = DataLoader(g, torch.arange(num_iters), sampler, num_workers=4) >>> for subg in dataloader: ... train_on(subg)