6.7 在大型图上进行精确离线推理

(中文版)

子图采样和邻居采样都是为了减少在使用 GPU 训练 GNN 时的内存和时间消耗。在执行推理时,通常最好真正地聚合所有邻居的信息,以消除采样带来的随机性。然而,由于内存限制,全图前向传播通常在 GPU 上不可行,而由于计算速度慢,在 CPU 上则很慢。本节介绍了如何通过 mini-batch 和邻居采样,在有限的 GPU 内存下进行全图前向传播的方法。

推理算法与训练算法不同,因为所有节点的表示应该从第一层开始,逐层计算。具体来说,对于特定的层,我们需要在 mini-batch 中计算该 GNN 层的输出表示。其结果是推理算法会有一个遍历层的外循环,以及一个遍历节点 mini-batch 的内循环。相比之下,训练算法有一个遍历节点 mini-batch 的外循环,以及一个用于邻居采样和消息传递的遍历层的内循环。

下面的动画展示了计算过程(请注意,对于每一层,仅绘制了前三个 mini-batch)。

Imgur

实现离线推理

考虑我们在第 6.1 节 调整模型以进行 Mini-batch 训练 中提到的两层 GCN。实现离线推理的方法仍然涉及使用 NeighborSampler,但每次只对一层进行采样。

datapipe = gb.ItemSampler(all_nodes_set, batch_size=1024, shuffle=True)
datapipe = datapipe.sample_neighbor(g, [-1]) # 1 layers.
datapipe = datapipe.fetch_feature(feature, node_feature_keys=["feat"])
datapipe = datapipe.copy_to(device)
dataloader = gb.DataLoader(datapipe)

请注意,离线推理是作为 GNN 模块的一个方法实现的,因为单层计算也取决于消息如何聚合和组合。

class SAGE(nn.Module):
    def __init__(self, in_size, hidden_size, out_size):
        super().__init__()
        self.layers = nn.ModuleList()
        # Three-layer GraphSAGE-mean.
        self.layers.append(dglnn.SAGEConv(in_size, hidden_size, "mean"))
        self.layers.append(dglnn.SAGEConv(hidden_size, hidden_size, "mean"))
        self.layers.append(dglnn.SAGEConv(hidden_size, out_size, "mean"))
        self.dropout = nn.Dropout(0.5)
        self.hidden_size = hidden_size
        self.out_size = out_size

    def forward(self, blocks, x):
        hidden_x = x
        for layer_idx, (layer, block) in enumerate(zip(self.layers, blocks)):
            hidden_x = layer(block, hidden_x)
            is_last_layer = layer_idx == len(self.layers) - 1
            if not is_last_layer:
                hidden_x = F.relu(hidden_x)
                hidden_x = self.dropout(hidden_x)
        return hidden_x

    def inference(self, graph, features, dataloader, device):
        """
        Offline inference with this module
        """
        feature = features.read("node", None, "feat")

        # Compute representations layer by layer
        for layer_idx, layer in enumerate(self.layers):
            is_last_layer = layer_idx == len(self.layers) - 1

            y = torch.empty(
                graph.total_num_nodes,
                self.out_size if is_last_layer else self.hidden_size,
                dtype=torch.float32,
                device=buffer_device,
                pin_memory=pin_memory,
            )
            feature = feature.to(device)

            for step, data in tqdm(enumerate(dataloader)):
                x = feature[data.input_nodes]
                hidden_x = layer(data.blocks[0], x)  # len(blocks) = 1
                if not is_last_layer:
                    hidden_x = F.relu(hidden_x)
                    hidden_x = self.dropout(hidden_x)
                # By design, our output nodes are contiguous.
                y[
                    data.seeds[0] : data.seeds[-1] + 1
                ] = hidden_x.to(device)
            feature = y

        return y

请注意,为了计算验证集上的评估指标进行模型选择,我们通常不必计算精确的离线推理。原因是我们需要计算每一层上每个节点的表示,这通常非常昂贵,特别是在有很多未标记数据的半监督场景中。邻居采样对于模型选择和验证来说效果很好。

可以查看 GraphSAGERGCN 以获取离线推理的示例。