NeighborSampler

class dgl.dataloading.NeighborSampler(fanouts, edge_dir='in', prob=None, mask=None, replace=False, prefetch_node_feats=None, prefetch_labels=None, prefetch_edge_feats=None, output_device=None, fused=True)[source]

基类:BlockSampler

一个采样器,通过邻居采样为多层 GNN 构建节点表示的计算依赖关系。

该采样器将使得每个节点从每种边类型的固定数量的邻居收集消息。邻居是均匀采样的。

参数:
  • fanouts (list[int] or list[dict[etype, int]]) –

    一个列表,指定每个 GNN 层每种边类型要采样的邻居数,其中第 i 个元素是第 i 个 GNN 层的 fanout(扇出)。

    如果只提供一个整数,DGL 假定所有边类型都具有相同的 fanout。

    如果某个层对某种边类型提供了 -1,则将包含该边类型的所有入边。

  • edge_dir (str, 默认 'in') – 可以是 'in' ``,邻居将根据入边采样,或者 ``'out',否则,与 dgl.sampling.sample_neighbors() 相同。

  • prob (str, 可选) –

    如果给定,每个邻居被采样的概率与 g.edata 中给定名称的边特征值成比例。该特征必须是每条边上的标量。

    此参数与 mask 互斥。如果想同时指定 mask 和 probability,可以考虑将 probability 乘以 mask。

  • mask (str, 可选) –

    如果给定,只有当 g.edata 中给定名称的边 mask 为 True 时,才能选择该邻居。该数据必须是每条边上的布尔值。

    此参数与 prob 互斥。如果想同时指定 mask 和 probability,可以考虑将 probability 乘以 mask。

  • replace (bool, 默认为 False) – 是否进行有放回采样

  • prefetch_node_feats (list[str] or dict[ntype, list[str]], 可选) – 为第一个 MFG 预取源节点数据,对应于第一个 GNN 层所需的输入节点特征。

  • prefetch_labels (list[str] or dict[ntype, list[str]], 可选) – 为最后一个 MFG 预取目标节点数据,对应于 minibatch 的节点标签。

  • prefetch_edge_feats (list[str] or dict[etype, list[str]], 可选) – 为所有 MFG 预取的边数据名称,对应于所有 GNN 层所需的边特征。

  • output_device (device, 可选) – 输出子图或 MFG 的设备。默认为与 seed 节点 minibatch 相同的设备。

  • fused (bool, 默认为 True) – 如果为 True 且设备为 CPU,则调用 fused 邻居采样。此版本要求 seed_nodes 是唯一的。

示例

节点分类

在同构图上对节点集 train_nid 进行节点分类,训练一个 3 层 GNN,其中每个节点分别从第一、第二、第三层的 5、10、15 个邻居获取消息(假设后端是 PyTorch)。

>>> sampler = dgl.dataloading.NeighborSampler([5, 10, 15])
>>> dataloader = dgl.dataloading.DataLoader(
...     g, train_nid, sampler,
...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for input_nodes, output_nodes, blocks in dataloader:
...     train_on(blocks)

如果在异构图上进行训练,并且希望每种边类型有不同的邻居数量,则应该提供一个字典列表。每个字典指定每种边类型要选择的邻居数量。

>>> sampler = dgl.dataloading.NeighborSampler([
...     {('user', 'follows', 'user'): 5,
...      ('user', 'plays', 'game'): 4,
...      ('game', 'played-by', 'user'): 3}] * 3)

如果想要非均匀的邻居采样

>>> g.edata['p'] = torch.rand(g.num_edges())   # any non-negative 1D vector works
>>> sampler = dgl.dataloading.NeighborSampler([5, 10, 15], prob='p')

或者在边掩码上采样

>>> g.edata['mask'] = torch.rand(g.num_edges()) < 0.2   # any 1D boolean mask works
>>> sampler = dgl.dataloading.NeighborSampler([5, 10, 15], prob='mask')

边分类和链接预测

此类还可以与 as_edge_prediction_sampler() 一起用于边分类和链接预测。

>>> sampler = dgl.dataloading.NeighborSampler([5, 10, 15])
>>> sampler = dgl.dataloading.as_edge_prediction_sampler(sampler)
>>> dataloader = dgl.dataloading.DataLoader(
...     g, train_eid, sampler,
...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)

更多详情请参阅 as_edge_prediction_sampler() 的文档。

注意

关于 MFG 的概念,请参考用户指南第 6 节Minibatch 训练教程