dgl.sampling.sample_neighbors

dgl.sampling.sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False, copy_ndata=True, copy_edata=True, _dist_training=False, exclude_edges=None, output_device=None)[源码]

对给定节点的邻居边进行采样并返回诱导子图。

对于每个节点,将随机选择一定数量的入边(或出边,当 edge_dir == 'out' 时)。返回的图将包含原始图中的所有节点,但仅包含采样的边。

节点/边特征不保留。采样的边的原始 ID 将作为 dgl.EID 特征存储在返回的图中。

此函数支持 GPU 采样。有关详细信息,请参阅6.8 使用 GPU 进行邻居采样

参数:
  • g (DGLGraph) – 图。可以在 CPU 或 GPU 上。

  • nodes (tensordict) –

    要从中采样邻居的节点 ID。

    此参数可以接受单个 ID tensor 或节点类型和 ID tensor 的字典。如果给定单个 tensor,则图必须只有一种节点类型。

  • fanout (intdict[etype, int]) –

    对于每种边类型,每个节点要采样的边数。

    此参数可以接受单个 int 或边类型和 int 的字典。如果给定单个 int,DGL 将为每种边类型上的每个节点采样此数量的边。

    如果为单个边类型指定 -1,将选择该边类型所有具有非零概率的邻居边。

  • edge_dir (str, 可选) –

    确定是采样入边还是出边。

    可以是 in 表示入边,或 out 表示出边。

  • prob (str, 可选) –

    用作与节点的每个邻居边相关的(未归一化)概率的特征名称。该特征对每条边必须只有一个元素。

    该特征必须是非负浮点数或布尔值。否则,结果将未定义。

  • exclude_edges (tensordict) –

    为种子节点采样邻居时要排除的边 ID。

    此参数可以接受单个 ID tensor 或边类型和 ID tensor 的字典。如果给定单个 tensor,则图必须只有一种节点类型。

  • replace (bool, 可选) – 如果为 True,则有放回采样。

  • copy_ndata (bool, 可选) –

    如果为 True,新图的节点特征会从原始图复制。如果为 False,新图将没有任何节点特征。

    (默认值: True)

  • copy_edata (bool, 可选) –

    如果为 True,新图的边特征会从原始图复制。如果为 False,新图将没有任何边特征。

    (默认值: True)

  • _dist_training (bool, 可选) –

    内部参数。请勿使用。

    (默认值: False)

  • output_device (框架特定的设备上下文对象, 可选) – 输出设备。默认为与输入图相同。

返回:

一个采样得到的子图,只包含采样的邻居边。

返回类型:

DGLGraph

注意

如果 copy_ndatacopy_edata 为 True,原始图和新图的节点或边特征使用相同的 tensor。因此,用户应避免在新图的节点特征上执行原地(in-place)操作,以避免特征损坏。

示例

假设您有以下图

>>> g = dgl.graph(([0, 0, 1, 1, 2, 2], [1, 2, 0, 1, 2, 0]))

和权重

>>> g.edata['prob'] = torch.FloatTensor([0., 1., 0., 1., 0., 1.])

为节点 0 和节点 1 各采样一条入边

>>> sg = dgl.sampling.sample_neighbors(g, [0, 1], 1)
>>> sg.edges(order='eid')
(tensor([1, 0]), tensor([0, 1]))
>>> sg.edata[dgl.EID]
tensor([2, 0])

为节点 0 和节点 1 各采样一条入边,使用边特征 prob 中的概率

>>> sg = dgl.sampling.sample_neighbors(g, [0, 1], 1, prob='prob')
>>> sg.edges(order='eid')
(tensor([2, 1]), tensor([0, 1]))

fanout 大于实际邻居数且不放回采样时,DGL 将选择所有邻居

>>> sg = dgl.sampling.sample_neighbors(g, [0, 1], 3)
>>> sg.edges(order='eid')
(tensor([1, 2, 0, 1]), tensor([0, 0, 1, 1]))

为种子节点采样时排除某些 EID

>>> g = dgl.graph(([0, 0, 1, 1, 2, 2], [1, 2, 0, 1, 2, 0]))
>>> g_edges = g.all_edges(form='all')``
(tensor([0, 0, 1, 1, 2, 2]), tensor([1, 2, 0, 1, 2, 0]), tensor([0, 1, 2, 3, 4, 5]))
>>> sg = dgl.sampling.sample_neighbors(g, [0, 1], 3, exclude_edges=[0, 1, 2])
>>> sg.all_edges(form='all')
(tensor([2, 1]), tensor([0, 1]), tensor([0, 1]))
>>> sg.has_edges_between(g_edges[0][:3],g_edges[1][:3])
tensor([False, False, False])
>>> g = dgl.heterograph({
...   ('drug', 'interacts', 'drug'): ([0, 0, 1, 1, 3, 2], [1, 2, 0, 1, 2, 0]),
...   ('drug', 'interacts', 'gene'): ([0, 0, 1, 1, 2, 2], [1, 2, 0, 1, 2, 0]),
...   ('drug', 'treats', 'disease'): ([0, 0, 1, 1, 2, 2], [1, 2, 0, 1, 2, 0])})
>>> g_edges = g.all_edges(form='all', etype=('drug', 'interacts', 'drug'))
(tensor([0, 0, 1, 1, 3, 2]), tensor([1, 2, 0, 1, 2, 0]), tensor([0, 1, 2, 3, 4, 5]))
>>> excluded_edges  = {('drug', 'interacts', 'drug'): g_edges[2][:3]}
>>> sg = dgl.sampling.sample_neighbors(g, {'drug':[0, 1]}, 3, exclude_edges=excluded_edges)
>>> sg.all_edges(form='all', etype=('drug', 'interacts', 'drug'))
(tensor([2, 1]), tensor([0, 1]), tensor([0, 1]))
>>> sg.has_edges_between(g_edges[0][:3],g_edges[1][:3],etype=('drug', 'interacts', 'drug'))
tensor([False, False, False])