MultiLayerFullNeighborSampler

class dgl.dataloading.MultiLayerFullNeighborSampler(num_layers, **kwargs)[source]

基类: NeighborSampler

通过从多层 GNN 的所有邻居接收消息来构建节点表示的计算依赖关系的采样器。

此采样器将使每个节点针对每种边类型从每个邻居收集消息。

参数:

示例

要在同构图上对节点集 train_nid 进行节点分类,训练一个 3 层 GNN,其中每个节点分别从第一、第二和第三层的全部邻居接收消息(假设后端是 PyTorch)

>>> sampler = dgl.dataloading.MultiLayerFullNeighborSampler(3)
>>> dataloader = dgl.dataloading.DataLoader(
...     g, train_nid, sampler,
...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for input_nodes, output_nodes, blocks in dataloader:
...     train_on(blocks)

注意事项

有关 MFG 的概念,请参阅用户指南第 6 节小批量训练教程