dgl.sampling.global_uniform_negative_sampling

dgl.sampling.global_uniform_negative_sampling(g, num_samples, exclude_self_loops=True, replace=False, etype=None, redundancy=None)[source]

执行负采样,生成源-目标对,使得指定类型的边不存在。

具体来说,此函数接受一个边类型和采样数量。它返回两个张量 srcdst,前者范围为 [0, num_src),后者范围为 [0, num_dst),其中 num_srcnum_dst 分别表示源节点类型和目标节点类型的节点数量。它保证具有源节点类型的 src 和具有目标节点类型的 dst 的对应对之间不存在边。

注意

此负采样器会尝试生成尽可能多的负样本,但极少情况下可能会返回少于 num_samples 的负样本。当图非常小或密集,不存在许多唯一的负样本时,这种情况更容易发生。

参数:
  • g (DGLGraph) – 图。

  • num_samples (int) – 希望生成的负样本数量。

  • exclude_self_loops (bool, 可选) –

    是否从负样本中排除自环。仅影响源节点类型和目标节点类型相同的边类型。

    默认值: True。

  • replace (bool, 可选) – 是否进行有放回采样。设置为 True 会更快。(默认值: False)

  • etype (strstrtuple, 可选) – 边类型。如果图只有一个边类型,则可以省略。

  • redundancy (float, 可选) –

    指示在拒绝采样期间,实际生成多少额外的负样本,以便找到唯一的对。

    增加此值会增加获得 num_samples 负样本的可能性,但也会花费更多时间和内存。

    (默认值: 由图的密度自动确定)

返回:

源节点和目标节点对。

返回类型:

tuple[Tensor, Tensor]

示例

>>> g = dgl.graph(([0, 1, 2], [1, 2, 3]))
>>> dgl.sampling.global_uniform_negative_sampling(g, 3)
(tensor([0, 1, 3]), tensor([2, 0, 2]))