MultiLayerFullNeighborSampler
- class dgl.dataloading.MultiLayerFullNeighborSampler(num_layers, **kwargs)[source]
基类:
NeighborSampler
通过从多层 GNN 的所有邻居接收消息来构建节点表示的计算依赖关系的采样器。
此采样器将使每个节点针对每种边类型从每个邻居收集消息。
- 参数:
num_layers (int) – 要采样的 GNN 层数。
kwargs – 传递给
dgl.dataloading.NeighborSampler
。
示例
要在同构图上对节点集
train_nid
进行节点分类,训练一个 3 层 GNN,其中每个节点分别从第一、第二和第三层的全部邻居接收消息(假设后端是 PyTorch)>>> sampler = dgl.dataloading.MultiLayerFullNeighborSampler(3) >>> dataloader = dgl.dataloading.DataLoader( ... g, train_nid, sampler, ... batch_size=1024, shuffle=True, drop_last=False, num_workers=4) >>> for input_nodes, output_nodes, blocks in dataloader: ... train_on(blocks)
注意事项
有关 MFG 的概念,请参阅用户指南第 6 节和小批量训练教程。