ShaDowKHopSampler

class dgl.dataloading.ShaDowKHopSampler(fanouts, replace=False, prob=None, prefetch_node_feats=None, prefetch_edge_feats=None, output_device=None)[源代码]

基类:Sampler

摘自 Deep Graph Neural Networks with Shallow Subgraph Samplers 的 K 跳子图采样器。

它执行节点维度的邻居采样,并返回由所有采样节点导出的子图。从中采样邻居的种子节点将首先出现在子图的导出节点中。

参数:
  • fanouts (list[int] 或 list[dict[etype, int]]) –

    针对每个 GNN 层,每种边类型要采样的邻居数量列表,其中第 i 个元素是第 i 个 GNN 层的 fanout。

    如果只提供一个整数,DGL 会假定每种边类型都有相同的 fanout。

    如果在某一层中对某种边类型提供了 -1,则将包含该边类型的所有入边。

  • replace (bool, 默认为 True) – 是否进行有放回采样

  • prob (str, 可选) – 如果给定,每个邻居被采样的概率与其在 g.edata 中给定名称的边特征值成正比。该特征在每条边上必须是标量。

示例

节点分类

对于同构图上的一组节点 train_nid,训练一个 3 层 GNN 进行节点分类,其中每个节点在第一、第二、第三层分别从 5、10、15 个邻居接收消息(假设后端是 PyTorch)

>>> g = dgl.data.CoraFullDataset()[0]
>>> sampler = dgl.dataloading.ShaDowKHopSampler([5, 10, 15])
>>> dataloader = dgl.dataloading.DataLoader(
...     g, torch.arange(g.num_nodes()), sampler,
...     batch_size=5, shuffle=True, drop_last=False, num_workers=4)
>>> for input_nodes, output_nodes, subgraph in dataloader:
...     print(subgraph)
...     assert torch.equal(input_nodes, subgraph.ndata[dgl.NID])
...     assert torch.equal(input_nodes[:output_nodes.shape[0]], output_nodes)
...     break
Graph(num_nodes=529, num_edges=3796,
      ndata_schemes={'label': Scheme(shape=(), dtype=torch.int64),
                     'feat': Scheme(shape=(8710,), dtype=torch.float32),
                     '_ID': Scheme(shape=(), dtype=torch.int64)}
      edata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)})

如果在异构图上训练,并且希望每种边类型有不同数量的邻居,则应改为提供一个字典列表。每个字典指定每种边类型要选择的邻居数量。

>>> sampler = dgl.dataloading.ShaDowKHopSampler([
...     {('user', 'follows', 'user'): 5,
...      ('user', 'plays', 'game'): 4,
...      ('game', 'played-by', 'user'): 3}] * 3)

如果你想要非均匀的邻居采样

>>> g.edata['p'] = torch.rand(g.num_edges())   # any non-negative 1D vector works
>>> sampler = dgl.dataloading.ShaDowKHopSampler([5, 10, 15], prob='p')