dgl.dataloading.BlockSampler

class dgl.dataloading.BlockSampler(prefetch_node_feats=None, prefetch_labels=None, prefetch_edge_feats=None, output_device=None)[source]

用于以消息传递流图 (MFG) 形式采样 mini-batch 的基类。

它提供预取选项,用于获取第一个 MFG 的 srcdata 的节点特征、最后一个 MFG 的 dstdata 的节点标签以及所有 MFG 的 edata 的边特征。

参数:
  • prefetch_node_feats (list[str] or dict[str, list[str]], 可选的) –

    为第一个 MFG 预取的节点数据。

    DGL 将使用原始图中给定名称的节点数据填充第一层 MFG 的 srcnodessrcdata

  • prefetch_labels (list[str] or dict[str, list[str]], 可选的) –

    为最后一个 MFG 预取的节点数据。

    DGL 将使用原始图中给定名称的节点数据填充最后一层 MFG 的 dstnodesdstdata

  • prefetch_edge_feats (list[str] or dict[etype, list[str]], 可选的) –

    为所有 MFG 预取的边数据名称。

    DGL 将使用原始图中给定名称的边数据填充每个 MFG 的 edgesedata

  • output_device (device, 可选的) – 输出子图或 MFG 的设备。默认为与种子节点的 minibatch 相同的设备。

__init__(prefetch_node_feats=None, prefetch_labels=None, prefetch_edge_feats=None, output_device=None)[source]

方法

__init__([prefetch_node_feats, ...])

assign_lazy_features(result)

为预取分配延迟加载的特征。

sample(g, seed_nodes[, exclude_eids])

从给定的种子节点采样一个 block 列表。

sample_blocks(g, seed_nodes[, exclude_eids])

从给定的种子节点生成一个 block 列表。