dgl.sampling.select_topk

dgl.sampling.select_topk(g, k, weight, nodes=None, edge_dir='in', ascending=False, copy_ndata=True, copy_edata=True, output_device=None)[源代码]

选择给定节点的 k 个最大(或最小)权重的相邻边,并返回诱导子图。

对于每个节点,将选择具有最大(或最小,当 ascending == True 时)权重的入站(或出站,当 edge_dir == 'out' 时)边。返回的图将包含原始图中的所有节点,但仅包含采样的边。

节点/边特征不会保留。采样的边的原始 ID 将作为 dgl.EID 特征存储在返回的图中。

参数:
  • g (DGLGraph) – 图。必须在 CPU 上。

  • k (int or dict[etype, int]) –

    为每个节点在每种边类型上选择的边数量。

    此参数可以是一个单独的 int 值,也可以是一个边类型和 int 值的字典。如果给定一个单独的 int 值,DGL 将为每种边类型上的每个节点选择此数量的边。

    如果对于某个边类型给定 -1,则将选择该边类型的所有相邻边。

  • weight (str) – 与每条边关联的权重的特征名称。该特征对于每条边应只有一个元素。该特征可以是 int32/64 或 float32/64。

  • nodes (tensor or dict, optional) –

    要从中采样邻居的节点 ID。

    此参数可以是一个单独的 ID tensor,也可以是一个节点类型和 ID tensor 的字典。如果给定一个单独的 tensor,则图必须只有一种节点类型。

    如果为 None,DGL 将为所有节点选择边。

  • edge_dir (str, optional) –

    确定是采样入站边还是出站边。

    可以取 in 表示入站边,或 out 表示出站边。

  • ascending (bool, optional) – 如果为 True,DGL 将返回具有 k 个最小权重的边,而不是 k 个最大权重的边。

  • copy_ndata (bool, optional) –

    如果为 True,则新图的节点特征将从原始图复制。如果为 False,则新图将没有任何节点特征。

    (默认值: True)

  • copy_edata (bool, optional) –

    如果为 True,则新图的边特征将从原始图复制。如果为 False,则新图将没有任何边特征。

    (默认值: True)

  • output_device (Framework-specific device context object, optional) – 输出设备。默认与输入图相同。

返回:

一个仅包含采样的相邻边的采样子图。它在 CPU 上。

返回类型:

DGLGraph

注意

如果 copy_ndatacopy_edata 为 True,则新图的节点或边特征使用与原始图相同的 tensor。因此,用户应避免在新图的节点特征上执行原地操作,以避免特征损坏。

示例

>>> g = dgl.graph(([0, 0, 1, 1, 2, 2], [1, 2, 0, 1, 2, 0]))
>>> g.edata['weight'] = torch.FloatTensor([0, 1, 0, 1, 0, 1])
>>> sg = dgl.sampling.select_topk(g, 1, 'weight')
>>> sg.edges(order='eid')
(tensor([2, 1, 0]), tensor([0, 1, 2]))