dgl.sampling.PinSAGESampler

class dgl.sampling.PinSAGESampler(G, ntype, other_type, num_traversals, termination_prob, num_random_walks, num_neighbors, weight_column='weights')[source]

PinSAGE 式的邻居采样器。

此可调用对象适用于具有边类型 (ntype, fwtype, other_type)(other_type, bwtype, ntype) 的双向二分图(其中 ntypefwtypebwtypeother_type 可以是任意类型名称)。它将生成一个节点类型为 ntype 的同构图,其中每个给定节点的邻居是通过从该给定节点开始的多次随机游走访问同类型节点最频繁的那些节点。每次随机游走包含多次基于元路径的遍历,每次遍历后都有一定的终止概率。元路径始终是 [fwtype, bwtype],从节点类型 ntype 游走到节点类型 other_type 然后再回到 ntype

返回的同构图的边将从访问最频繁的节点连接到给定的节点,并带有一个特征指示访问次数。

此采样器支持 UVA 和 GPU 采样。有关更多详细信息,请参阅 6.8 使用 GPU 进行邻域采样

参数:
  • G (DGLGraph) –

    双向二分图。

    图应仅包含两种节点类型:ntypeother_type。图应仅包含两种边类型,一种从 ntype 连接到 other_type,另一种从 other_type 连接到 ntype

  • ntype (str) – 将构建图的节点类型。

  • other_type (str) – 另一种节点类型。

  • num_traversals (int) –

    单次随机游走的最大元路径遍历次数。

    通常被视为超参数。

  • termination_prob (int) –

    每次元路径遍历后的终止概率。

    通常被视为超参数。

  • num_random_walks (int) –

    对每个给定节点尝试的随机游走次数。

    通常被视为超参数。

  • num_neighbors (int) – 为每个给定节点选择的邻居数量(或访问最频繁的节点数量)。

  • weight_column (str, 默认为 "weights") – 存储在返回图中表示访问次数的边特征名称。

示例

生成一个包含 3000 个“A”节点和 5000 个“B”节点的随机双向二分图。

>>> g = scipy.sparse.random(3000, 5000, 0.003)
>>> G = dgl.heterograph({
...     ('A', 'AB', 'B'): g.nonzero(),
...     ('B', 'BA', 'A'): g.T.nonzero()})

然后我们创建一个 PinSage 邻居采样器,用于采样节点类型为“A”的图。每个节点将具有(最多)10个邻居。

>>> sampler = dgl.sampling.PinSAGESampler(G, 'A', 'B', 3, 0.5, 200, 10)

这是根据 PinSAGE 算法选择类型为“A”的节点 #0、#1 和 #2 的邻居的方法

>>> seeds = torch.LongTensor([0, 1, 2])
>>> frontier = sampler(seeds)
>>> frontier.all_edges(form='uv')
(tensor([ 230,    0,  802,   47,   50, 1639, 1533,  406, 2110, 2687, 2408, 2823,
            0,  972, 1230, 1658, 2373, 1289, 1745, 2918, 1818, 1951, 1191, 1089,
         1282,  566, 2541, 1505, 1022,  812]),
 tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2,
         2, 2, 2, 2, 2, 2]))

有关 PinSAGE 模型的端到端示例,包括在多层上采样和使用采样图进行计算,请参阅我们位于 examples/pytorch/pinsage 中的 PinSage 示例。

参考文献

Graph Convolutional Neural Networks for Web-Scale Recommender Systems

Ying et al., 2018, https://arxiv.org/abs/1806.01973

__init__(G, ntype, other_type, num_traversals, termination_prob, num_random_walks, num_neighbors, weight_column='weights')[source]

方法

__init__(G, ntype, other_type, ...[, ...])