dgl.dataloading.as_edge_prediction_sampler

dgl.dataloading.as_edge_prediction_sampler(sampler, exclude=None, reverse_eids=None, reverse_etypes=None, negative_sampler=None, prefetch_labels=None)[source]

从节点采样器创建一个边采样器。

对于每一批次边,采样器会将其应用的节点采样器作用于这些边的源节点和目标节点,以提取子图。如果提供了负采样器,它还会生成负边,并提取这些负边关联节点的子图。

对于每次迭代,采样器将生成

  • 计算边表示所需的输入节点张量,或者包含节点类型名称及对应张量的字典。

  • 一个子图,只包含迷你批次中的边及其关联节点。注意,该图具有与原始图相同的元图结构。

  • 如果提供了负采样器,则会生成另一个图,其中包含“负边”,连接由给定负采样器生成的源节点和目标节点。

  • 由提供的节点采样器返回的子图或 MFG,这些子图/MFG 是从迷你批次中的边(以及适用时的负边)的关联节点生成的。

参数:
  • sampler (Sampler) – 节点采样器对象。此外,它要求 sample 方法必须有一个可选的第三个参数 exclude_eids,表示要从邻域中排除的边 ID。对于同构图,此参数是一个张量;对于异构图,此参数是一个包含边类型和对应张量的字典。

  • exclude (Union[str, callable], optional) –

    是否以及如何排除迷你批次中采样边相关的依赖关系。可能的值有

    • None,不排除任何边。

    • self,排除当前迷你批次中的边。

    • reverse_id,不仅排除当前迷你批次中的边,还根据参数 reverse_eids 中的 ID 映射排除它们的反向边。

    • reverse_types,不仅排除当前迷你批次中的边,还根据参数 reverse_etypes 中存储在另一种类型中的反向边。

    • 用户自定义排除规则。它是一个可调用对象,以当前迷你批次中的边作为单个参数,并应返回要排除的边。

  • reverse_eids (Tensordict[etype, Tensor], optional) –

    一个反向边 ID 映射的张量。第 i 个元素表示第 i 条边的反向边的 ID。

    如果图是异构的,此参数需要一个包含边类型和反向边 ID 映射张量的字典。

  • reverse_etypes (dict[etype, etype], optional) – 从原始边类型到其反向边类型的映射。

  • negative_sampler (callable, optional) – 负采样器。

  • prefetch_labels (list[str] 或 dict[etype, list[str]], optional) –

    为返回的正对图预取的边标签。

    有关预取的详细说明,请参阅 guide-minibatch-prefetching

示例

以下示例展示了如何在同构无向图上,针对边集 train_eid 进行边分类训练一个 3 层 GNN。每个节点接收来自所有邻居的消息。

给定源节点 ID 数组 src 和目标节点 ID 数组 dst,以下代码创建一个双向图

>>> g = dgl.graph((torch.cat([src, dst]), torch.cat([dst, src])))

在上面的图表中,边 \(i\) 的反向边是边 \(i + |E|\)。因此,我们可以通过以下方式创建反向边映射 reverse_eids

>>> E = len(src)
>>> reverse_eids = torch.cat([torch.arange(E, 2 * E), torch.arange(0, E)])

通过将 reverse_eids 传递给边采样器,当前迷你批次中的边及其反向边将被从提取的子图中排除,以避免信息泄露。

>>> sampler = dgl.dataloading.as_edge_prediction_sampler(
...     dgl.dataloading.NeighborSampler([15, 10, 5]),
...     exclude='reverse_id', reverse_eids=reverse_eids)
>>> dataloader = dgl.dataloading.DataLoader(
...     g, train_eid, sampler,
...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for input_nodes, pair_graph, blocks in dataloader:
...     train_on(input_nodes, pair_graph, blocks)

对于链接预测,可以提供一个负采样器来采样负边。以下代码使用 DGL 的 Uniform 为每条边生成 5 个负样本

>>> neg_sampler = dgl.dataloading.negative_sampler.Uniform(5)
>>> sampler = dgl.dataloading.as_edge_prediction_sampler(
...     dgl.dataloading.NeighborSampler([15, 10, 5]),
...     sampler, exclude='reverse_id', reverse_eids=reverse_eids,
...     negative_sampler=neg_sampler)
>>> dataloader = dgl.dataloading.DataLoader(
...     g, train_eid, sampler,
...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for input_nodes, pos_pair_graph, neg_pair_graph, blocks in dataloader:
...     train_on(input_nodes, pair_graph, neg_pair_graph, blocks)

对于异构图,反向边可能属于不同的关系。例如,下图中的“user-click-item”和“item-click-by-user”关系是相互反向的。

>>> g = dgl.heterograph({
...     ('user', 'click', 'item'): (user, item),
...     ('item', 'clicked-by', 'user'): (item, user)})

为了正确地从每个迷你批次中排除边,设置 exclude='reverse_types' 并将字典 {'click': 'clicked-by', 'clicked-by': 'click'} 传递给 reverse_etypes 参数。

>>> sampler = dgl.dataloading.as_edge_prediction_sampler(
...     dgl.dataloading.NeighborSampler([15, 10, 5]),
...     exclude='reverse_types',
...     reverse_etypes={'click': 'clicked-by', 'clicked-by': 'click'})
>>> dataloader = dgl.dataloading.DataLoader(
...     g, {'click': train_eid}, sampler,
...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for input_nodes, pair_graph, blocks in dataloader:
...     train_on(input_nodes, pair_graph, blocks)

对于链接预测,提供一个负采样器来生成负样本

>>> neg_sampler = dgl.dataloading.negative_sampler.Uniform(5)
>>> sampler = dgl.dataloading.as_edge_prediction_sampler(
...     dgl.dataloading.NeighborSampler([15, 10, 5]),
...     exclude='reverse_types',
...     reverse_etypes={'click': 'clicked-by', 'clicked-by': 'click'},
...     negative_sampler=neg_sampler)
>>> dataloader = dgl.dataloading.DataLoader(
...     g, train_eid, sampler,
...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for input_nodes, pos_pair_graph, neg_pair_graph, blocks in dataloader:
...     train_on(input_nodes, pair_graph, neg_pair_graph, blocks)