dgl.dataloading.as_edge_prediction_sampler
- dgl.dataloading.as_edge_prediction_sampler(sampler, exclude=None, reverse_eids=None, reverse_etypes=None, negative_sampler=None, prefetch_labels=None)[source]
从节点采样器创建一个边采样器。
对于每一批次边,采样器会将其应用的节点采样器作用于这些边的源节点和目标节点,以提取子图。如果提供了负采样器,它还会生成负边,并提取这些负边关联节点的子图。
对于每次迭代,采样器将生成
计算边表示所需的输入节点张量,或者包含节点类型名称及对应张量的字典。
一个子图,只包含迷你批次中的边及其关联节点。注意,该图具有与原始图相同的元图结构。
如果提供了负采样器,则会生成另一个图,其中包含“负边”,连接由给定负采样器生成的源节点和目标节点。
由提供的节点采样器返回的子图或 MFG,这些子图/MFG 是从迷你批次中的边(以及适用时的负边)的关联节点生成的。
- 参数:
sampler (Sampler) – 节点采样器对象。此外,它要求
sample
方法必须有一个可选的第三个参数exclude_eids
,表示要从邻域中排除的边 ID。对于同构图,此参数是一个张量;对于异构图,此参数是一个包含边类型和对应张量的字典。exclude (Union[str, callable], optional) –
是否以及如何排除迷你批次中采样边相关的依赖关系。可能的值有
None,不排除任何边。
self
,排除当前迷你批次中的边。reverse_id
,不仅排除当前迷你批次中的边,还根据参数reverse_eids
中的 ID 映射排除它们的反向边。reverse_types
,不仅排除当前迷你批次中的边,还根据参数reverse_etypes
中存储在另一种类型中的反向边。用户自定义排除规则。它是一个可调用对象,以当前迷你批次中的边作为单个参数,并应返回要排除的边。
reverse_eids (Tensor 或 dict[etype, Tensor], optional) –
一个反向边 ID 映射的张量。第 i 个元素表示第 i 条边的反向边的 ID。
如果图是异构的,此参数需要一个包含边类型和反向边 ID 映射张量的字典。
reverse_etypes (dict[etype, etype], optional) – 从原始边类型到其反向边类型的映射。
negative_sampler (callable, optional) – 负采样器。
prefetch_labels (list[str] 或 dict[etype, list[str]], optional) –
为返回的正对图预取的边标签。
有关预取的详细说明,请参阅 guide-minibatch-prefetching。
示例
以下示例展示了如何在同构无向图上,针对边集
train_eid
进行边分类训练一个 3 层 GNN。每个节点接收来自所有邻居的消息。给定源节点 ID 数组
src
和目标节点 ID 数组dst
,以下代码创建一个双向图>>> g = dgl.graph((torch.cat([src, dst]), torch.cat([dst, src])))
在上面的图表中,边 \(i\) 的反向边是边 \(i + |E|\)。因此,我们可以通过以下方式创建反向边映射
reverse_eids
>>> E = len(src) >>> reverse_eids = torch.cat([torch.arange(E, 2 * E), torch.arange(0, E)])
通过将
reverse_eids
传递给边采样器,当前迷你批次中的边及其反向边将被从提取的子图中排除,以避免信息泄露。>>> sampler = dgl.dataloading.as_edge_prediction_sampler( ... dgl.dataloading.NeighborSampler([15, 10, 5]), ... exclude='reverse_id', reverse_eids=reverse_eids) >>> dataloader = dgl.dataloading.DataLoader( ... g, train_eid, sampler, ... batch_size=1024, shuffle=True, drop_last=False, num_workers=4) >>> for input_nodes, pair_graph, blocks in dataloader: ... train_on(input_nodes, pair_graph, blocks)
对于链接预测,可以提供一个负采样器来采样负边。以下代码使用 DGL 的
Uniform
为每条边生成 5 个负样本>>> neg_sampler = dgl.dataloading.negative_sampler.Uniform(5) >>> sampler = dgl.dataloading.as_edge_prediction_sampler( ... dgl.dataloading.NeighborSampler([15, 10, 5]), ... sampler, exclude='reverse_id', reverse_eids=reverse_eids, ... negative_sampler=neg_sampler) >>> dataloader = dgl.dataloading.DataLoader( ... g, train_eid, sampler, ... batch_size=1024, shuffle=True, drop_last=False, num_workers=4) >>> for input_nodes, pos_pair_graph, neg_pair_graph, blocks in dataloader: ... train_on(input_nodes, pair_graph, neg_pair_graph, blocks)
对于异构图,反向边可能属于不同的关系。例如,下图中的“user-click-item”和“item-click-by-user”关系是相互反向的。
>>> g = dgl.heterograph({ ... ('user', 'click', 'item'): (user, item), ... ('item', 'clicked-by', 'user'): (item, user)})
为了正确地从每个迷你批次中排除边,设置
exclude='reverse_types'
并将字典{'click': 'clicked-by', 'clicked-by': 'click'}
传递给reverse_etypes
参数。>>> sampler = dgl.dataloading.as_edge_prediction_sampler( ... dgl.dataloading.NeighborSampler([15, 10, 5]), ... exclude='reverse_types', ... reverse_etypes={'click': 'clicked-by', 'clicked-by': 'click'}) >>> dataloader = dgl.dataloading.DataLoader( ... g, {'click': train_eid}, sampler, ... batch_size=1024, shuffle=True, drop_last=False, num_workers=4) >>> for input_nodes, pair_graph, blocks in dataloader: ... train_on(input_nodes, pair_graph, blocks)
对于链接预测,提供一个负采样器来生成负样本
>>> neg_sampler = dgl.dataloading.negative_sampler.Uniform(5) >>> sampler = dgl.dataloading.as_edge_prediction_sampler( ... dgl.dataloading.NeighborSampler([15, 10, 5]), ... exclude='reverse_types', ... reverse_etypes={'click': 'clicked-by', 'clicked-by': 'click'}, ... negative_sampler=neg_sampler) >>> dataloader = dgl.dataloading.DataLoader( ... g, train_eid, sampler, ... batch_size=1024, shuffle=True, drop_last=False, num_workers=4) >>> for input_nodes, pos_pair_graph, neg_pair_graph, blocks in dataloader: ... train_on(input_nodes, pair_graph, neg_pair_graph, blocks)