dgl.sampling.PinSAGESampler
- class dgl.sampling.PinSAGESampler(G, ntype, other_type, num_traversals, termination_prob, num_random_walks, num_neighbors, weight_column='weights')[source]
PinSAGE 式的邻居采样器。
此可调用对象适用于具有边类型
(ntype, fwtype, other_type)
和(other_type, bwtype, ntype)
的双向二分图(其中ntype
、fwtype
、bwtype
和other_type
可以是任意类型名称)。它将生成一个节点类型为ntype
的同构图,其中每个给定节点的邻居是通过从该给定节点开始的多次随机游走访问同类型节点最频繁的那些节点。每次随机游走包含多次基于元路径的遍历,每次遍历后都有一定的终止概率。元路径始终是[fwtype, bwtype]
,从节点类型ntype
游走到节点类型other_type
然后再回到ntype
。返回的同构图的边将从访问最频繁的节点连接到给定的节点,并带有一个特征指示访问次数。
此采样器支持 UVA 和 GPU 采样。有关更多详细信息,请参阅 6.8 使用 GPU 进行邻域采样。
- 参数:
G (DGLGraph) –
双向二分图。
图应仅包含两种节点类型:
ntype
和other_type
。图应仅包含两种边类型,一种从ntype
连接到other_type
,另一种从other_type
连接到ntype
。ntype (str) – 将构建图的节点类型。
other_type (str) – 另一种节点类型。
num_traversals (int) –
单次随机游走的最大元路径遍历次数。
通常被视为超参数。
termination_prob (int) –
每次元路径遍历后的终止概率。
通常被视为超参数。
num_random_walks (int) –
对每个给定节点尝试的随机游走次数。
通常被视为超参数。
num_neighbors (int) – 为每个给定节点选择的邻居数量(或访问最频繁的节点数量)。
weight_column (str, 默认为 "weights") – 存储在返回图中表示访问次数的边特征名称。
示例
生成一个包含 3000 个“A”节点和 5000 个“B”节点的随机双向二分图。
>>> g = scipy.sparse.random(3000, 5000, 0.003) >>> G = dgl.heterograph({ ... ('A', 'AB', 'B'): g.nonzero(), ... ('B', 'BA', 'A'): g.T.nonzero()})
然后我们创建一个 PinSage 邻居采样器,用于采样节点类型为“A”的图。每个节点将具有(最多)10个邻居。
>>> sampler = dgl.sampling.PinSAGESampler(G, 'A', 'B', 3, 0.5, 200, 10)
这是根据 PinSAGE 算法选择类型为“A”的节点 #0、#1 和 #2 的邻居的方法
>>> seeds = torch.LongTensor([0, 1, 2]) >>> frontier = sampler(seeds) >>> frontier.all_edges(form='uv') (tensor([ 230, 0, 802, 47, 50, 1639, 1533, 406, 2110, 2687, 2408, 2823, 0, 972, 1230, 1658, 2373, 1289, 1745, 2918, 1818, 1951, 1191, 1089, 1282, 566, 2541, 1505, 1022, 812]), tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]))
有关 PinSAGE 模型的端到端示例,包括在多层上采样和使用采样图进行计算,请参阅我们位于
examples/pytorch/pinsage
中的 PinSage 示例。参考文献
- Graph Convolutional Neural Networks for Web-Scale Recommender Systems
Ying et al., 2018, https://arxiv.org/abs/1806.01973
- __init__(G, ntype, other_type, num_traversals, termination_prob, num_random_walks, num_neighbors, weight_column='weights')[source]
方法
__init__
(G, ntype, other_type, ...[, ...])