dgl.sampling.pack_traces
- dgl.sampling.pack_traces(traces, types)[源代码]
将由
random_walk()
返回的填充(padded)游走路径打包到一个连接的数组中。填充值 (-1) 会被移除,返回每个游走路径的长度和偏移量,以及连接后的节点 ID 和节点类型数组。- 参数:
traces (Tensor) – 一个二维的节点 ID tensor。必须在 CPU 上,且类型必须是
int32
或int64
。types (Tensor) – 一个一维的节点类型 ID tensor。必须在 CPU 上,且类型必须是
int32
或int64
。
- 返回:
concat_vids (Tensor) – 所有节点 ID 连接后的数组,已移除填充值。
concat_types (Tensor) – 与
concat_vids
中每个节点对应的节点类型数组。长度与concat_vids
相同。lengths (Tensor) – 原始 traces tensor 中每个游走路径的长度。
offsets (Tensor) – 原始 traces tensor 中每个游走路径在新连接 tensor 中的偏移量。
注意事项
返回的 tensor 位于 CPU 上。
示例
>>> g2 = dgl.heterograph({ ... ('user', 'follow', 'user'): ([0, 1, 1, 2, 3], [1, 2, 3, 0, 0]), ... ('user', 'view', 'item'): ([0, 0, 1, 2, 3, 3], [0, 1, 1, 2, 2, 1]), ... ('item', 'viewed-by', 'user'): ([0, 1, 1, 2, 2, 1], [0, 0, 1, 2, 3, 3]) >>> traces, types = dgl.sampling.random_walk( ... g2, [0, 0], metapath=['follow', 'view', 'viewed-by'] * 2, ... restart_prob=torch.FloatTensor([0, 0.5, 0, 0, 0.5, 0])) >>> traces, types (tensor([[ 0, 1, -1, -1, -1, -1, -1], [ 0, 1, 1, 3, 0, 0, 0]]), tensor([0, 0, 1, 0, 0, 1, 0])) >>> concat_vids, concat_types, lengths, offsets = dgl.sampling.pack_traces(traces, types) >>> concat_vids tensor([0, 1, 0, 1, 1, 3, 0, 0, 0]) >>> concat_types tensor([0, 0, 0, 0, 1, 0, 0, 1, 0]) >>> lengths tensor([2, 7]) >>> offsets tensor([0, 2]))
第一个 tensor
concat_vids
是所有路径的连接,即traces
的展平数组,但不包含所有填充值 (-1)。第二个 tensor
concat_types
表示第一个 tensor 中所有对应节点的节点类型 ID。第三个和第四个 tensor 表示每条路径的长度和偏移量。利用这些 tensor,可以很容易地通过以下方式获取第 i 个随机游走路径:
>>> vids = concat_vids.split(lengths.tolist()) >>> vtypes = concat_vtypes.split(lengths.tolist()) >>> vids[1], vtypes[1] (tensor([0, 1, 1, 3, 0, 0, 0]), tensor([0, 0, 1, 0, 0, 1, 0]))