dgl.sampling.sample_neighbors
- dgl.sampling.sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False, copy_ndata=True, copy_edata=True, _dist_training=False, exclude_edges=None, output_device=None)[源码]
对给定节点的邻居边进行采样并返回诱导子图。
对于每个节点,将随机选择一定数量的入边(或出边,当
edge_dir == 'out'
时)。返回的图将包含原始图中的所有节点,但仅包含采样的边。节点/边特征不保留。采样的边的原始 ID 将作为 dgl.EID 特征存储在返回的图中。
此函数支持 GPU 采样。有关详细信息,请参阅6.8 使用 GPU 进行邻居采样。
- 参数:
g (DGLGraph) – 图。可以在 CPU 或 GPU 上。
nodes (tensor 或 dict) –
要从中采样邻居的节点 ID。
此参数可以接受单个 ID tensor 或节点类型和 ID tensor 的字典。如果给定单个 tensor,则图必须只有一种节点类型。
fanout (int 或 dict[etype, int]) –
对于每种边类型,每个节点要采样的边数。
此参数可以接受单个 int 或边类型和 int 的字典。如果给定单个 int,DGL 将为每种边类型上的每个节点采样此数量的边。
如果为单个边类型指定 -1,将选择该边类型所有具有非零概率的邻居边。
edge_dir (str, 可选) –
确定是采样入边还是出边。
可以是
in
表示入边,或out
表示出边。prob (str, 可选) –
用作与节点的每个邻居边相关的(未归一化)概率的特征名称。该特征对每条边必须只有一个元素。
该特征必须是非负浮点数或布尔值。否则,结果将未定义。
exclude_edges (tensor 或 dict) –
为种子节点采样邻居时要排除的边 ID。
此参数可以接受单个 ID tensor 或边类型和 ID tensor 的字典。如果给定单个 tensor,则图必须只有一种节点类型。
replace (bool, 可选) – 如果为 True,则有放回采样。
copy_ndata (bool, 可选) –
如果为 True,新图的节点特征会从原始图复制。如果为 False,新图将没有任何节点特征。
(默认值: True)
copy_edata (bool, 可选) –
如果为 True,新图的边特征会从原始图复制。如果为 False,新图将没有任何边特征。
(默认值: True)
_dist_training (bool, 可选) –
内部参数。请勿使用。
(默认值: False)
output_device (框架特定的设备上下文对象, 可选) – 输出设备。默认为与输入图相同。
- 返回:
一个采样得到的子图,只包含采样的邻居边。
- 返回类型:
注意
如果
copy_ndata
或copy_edata
为 True,原始图和新图的节点或边特征使用相同的 tensor。因此,用户应避免在新图的节点特征上执行原地(in-place)操作,以避免特征损坏。示例
假设您有以下图
>>> g = dgl.graph(([0, 0, 1, 1, 2, 2], [1, 2, 0, 1, 2, 0]))
和权重
>>> g.edata['prob'] = torch.FloatTensor([0., 1., 0., 1., 0., 1.])
为节点 0 和节点 1 各采样一条入边
>>> sg = dgl.sampling.sample_neighbors(g, [0, 1], 1) >>> sg.edges(order='eid') (tensor([1, 0]), tensor([0, 1])) >>> sg.edata[dgl.EID] tensor([2, 0])
为节点 0 和节点 1 各采样一条入边,使用边特征
prob
中的概率>>> sg = dgl.sampling.sample_neighbors(g, [0, 1], 1, prob='prob') >>> sg.edges(order='eid') (tensor([2, 1]), tensor([0, 1]))
当
fanout
大于实际邻居数且不放回采样时,DGL 将选择所有邻居>>> sg = dgl.sampling.sample_neighbors(g, [0, 1], 3) >>> sg.edges(order='eid') (tensor([1, 2, 0, 1]), tensor([0, 0, 1, 1]))
为种子节点采样时排除某些 EID
>>> g = dgl.graph(([0, 0, 1, 1, 2, 2], [1, 2, 0, 1, 2, 0])) >>> g_edges = g.all_edges(form='all')`` (tensor([0, 0, 1, 1, 2, 2]), tensor([1, 2, 0, 1, 2, 0]), tensor([0, 1, 2, 3, 4, 5])) >>> sg = dgl.sampling.sample_neighbors(g, [0, 1], 3, exclude_edges=[0, 1, 2]) >>> sg.all_edges(form='all') (tensor([2, 1]), tensor([0, 1]), tensor([0, 1])) >>> sg.has_edges_between(g_edges[0][:3],g_edges[1][:3]) tensor([False, False, False]) >>> g = dgl.heterograph({ ... ('drug', 'interacts', 'drug'): ([0, 0, 1, 1, 3, 2], [1, 2, 0, 1, 2, 0]), ... ('drug', 'interacts', 'gene'): ([0, 0, 1, 1, 2, 2], [1, 2, 0, 1, 2, 0]), ... ('drug', 'treats', 'disease'): ([0, 0, 1, 1, 2, 2], [1, 2, 0, 1, 2, 0])}) >>> g_edges = g.all_edges(form='all', etype=('drug', 'interacts', 'drug')) (tensor([0, 0, 1, 1, 3, 2]), tensor([1, 2, 0, 1, 2, 0]), tensor([0, 1, 2, 3, 4, 5])) >>> excluded_edges = {('drug', 'interacts', 'drug'): g_edges[2][:3]} >>> sg = dgl.sampling.sample_neighbors(g, {'drug':[0, 1]}, 3, exclude_edges=excluded_edges) >>> sg.all_edges(form='all', etype=('drug', 'interacts', 'drug')) (tensor([2, 1]), tensor([0, 1]), tensor([0, 1])) >>> sg.has_edges_between(g_edges[0][:3],g_edges[1][:3],etype=('drug', 'interacts', 'drug')) tensor([False, False, False])