PerSourceUniform

class dgl.dataloading.negative_sampler.PerSourceUniform(k)[源码]

基类: _BaseNegativeSampler

负采样器,根据均匀分布为每个源节点随机选择负目标节点。

对于类型为 (srctype, etype, dsttype) 的每条边 (u, v),DGL 会生成 k 对负边 (u, v'),其中 v' 是从类型为 dsttype 的所有节点中均匀选择的。生成的边也将具有类型 (srctype, etype, dsttype)

参数:

k (int) – 每条边的负样本数量。

示例

>>> g = dgl.graph(([0, 1, 2], [1, 2, 3]))
>>> neg_sampler = dgl.dataloading.negative_sampler.PerSourceUniform(2)
>>> neg_sampler(g, torch.tensor([0, 1]))
(tensor([0, 0, 1, 1]), tensor([1, 0, 2, 3]))