6.5 使用 DGL sparse 训练 GNN

本教程演示如何使用 DGL sparse 库在图上进行采样并训练模型。它使用 sparse 的 sample 和 compact 算子从整个邻接矩阵中采样子矩阵,并以此训练和测试 GraphSAGE 模型。

使用 DGL sparse 训练 GNN 与 6.1 节“使用邻居采样训练用于节点分类的 GNN”非常相似,主要区别在于自定义的采样器以及用于表示图的矩阵。

我们在 6.4 节“实现自定义图采样器”中已经自定义过一个采样器。在本教程中,我们将使用 DGL sparse 库自定义另一个采样器,如下所示。

@functional_datapipe("sample_sparse_neighbor")
class SparseNeighborSampler(SubgraphSampler):
    """GraphBolt subgraph sampler built on DGL sparse ``sample``/``compact``.

    Registered as the ``sample_sparse_neighbor`` datapipe, so it can be
    chained as ``datapipe.sample_sparse_neighbor(matrix, fanouts)``.

    Args:
        datapipe: Upstream datapipe producing minibatches of seed nodes.
        matrix: DGL ``SparseMatrix`` whose columns are sampled for the
            current frontier nodes (presumably rows = source nodes and
            columns = destination nodes; confirm against how it is built).
        fanouts: One fanout per layer, ordered from the first (input)
            layer to the last (output) layer. Non-tensor entries are
            converted to a one-element ``LongTensor``.
    """

    def __init__(self, datapipe, matrix, fanouts):
        super().__init__(datapipe)
        self.matrix = matrix
        # Sampling starts at the seeds (output layer) and walks toward the
        # input layer, so the fanouts are stored in reverse layer order.
        self.fanouts = []
        for fanout in fanouts:
            if not isinstance(fanout, torch.Tensor):
                fanout = torch.LongTensor([int(fanout)])
            self.fanouts.insert(0, fanout)

    def sample_subgraphs(self, seeds, seeds_timestamp=None):
        """Sample one compacted sparse matrix per layer around ``seeds``.

        Args:
            seeds: Node IDs of the current minibatch.
            seeds_timestamp: Ignored. It is accepted so the method also
                matches GraphBolt versions that pass seed timestamps to
                ``sample_subgraphs`` (TODO confirm against installed version).

        Returns:
            ``(input_nodes, sampled_matrices)``. ``input_nodes`` are the
            original IDs of the rows of the outermost sampled matrix.
            ``sampled_matrices`` is ordered from the input layer to the
            output layer. In every matrix, columns are the frontier nodes
            of that hop and rows are the compacted sampled neighbors, with
            the frontier nodes themselves placed first.
        """
        sampled_matrices = []
        src = seeds

        #####################################################################
        # (HIGHLIGHT) Use the sparse sample operator to perform random
        # sampling on the neighbors of the seed nodes. The sparse compact
        # operator then compacts and relabels the sampled matrix, yielding
        # the compacted matrix and the relabel index.
        #####################################################################
        for fanout in self.fanouts:
            # Sample neighbors: up to `fanout` non-zeros in each column
            # listed in `src`.
            sampled_matrix = self.matrix.sample(1, fanout, ids=src).coalesce()
            # Compact the rows. Passing `src` as leading_indices keeps the
            # current frontier nodes at rows 0..len(src)-1. The model slices
            # destination features as `feat[:A.shape[1]]`, so without this
            # the rows would keep ascending-ID order and that slice would
            # pick the wrong nodes.
            compacted_mat, row_ids = sampled_matrix.compact(
                0, leading_indices=src
            )
            sampled_matrices.insert(0, compacted_mat)
            src = row_ids

        return src, sampled_matrices

另一个主要区别在于用于表示图的矩阵。之前我们使用 FusedCSCSamplingGraph 进行采样;在本教程中,我们改用 DGL sparse 的 SparseMatrix 来表示图,它由 GraphBolt 图的 CSC 结构直接构建而成。

# Load the ogbn-products dataset through GraphBolt.
dataset = gb.BuiltinDataset("ogbn-products").load()
g = dataset.graph
# Wrap the graph's CSC arrays (column pointers + row indices) in an N x N
# DGL SparseMatrix. This matrix, not the GraphBolt graph object, is what the
# sparse sampler draws from. Presumably entry (i, j) follows GraphBolt's CSC
# convention of an edge i -> j (source -> destination); confirm.
N = g.num_nodes
A = dglsp.from_csc(g.csc_indptr, g.indices, shape=(N, N))

其余代码与节点分类教程几乎相同。

将此采样器与 DataLoader 一起使用

# GraphBolt input pipeline: batch the seed IDs, sample neighborhoods with
# the sparse-matrix sampler registered as `sample_sparse_neighbor`, fetch
# the "feat" node features with GraphBolt, then move each minibatch to
# `device`.
# NOTE: `ids`, `fanouts`, `features` and `device` must already be defined at
# this point. For example, `features = dataset.feature` is only assigned in
# the training snippet further down.
datapipe = (
    gb.ItemSampler(ids, batch_size=1024)
    .sample_sparse_neighbor(A, fanouts)
    .fetch_feature(features, node_feature_keys=["feat"])
    .copy_to(device)
)
dataloader = gb.DataLoader(datapipe)

模型定义如下所示

class SAGEConv(nn.Module):
    r"""Mean-aggregator GraphSAGE layer on a DGL sparse matrix.

    Based on `Inductive Representation Learning on Large Graphs
    <https://arxiv.org/pdf/1706.02216.pdf>`__. The layer output is
    ``fc_self(h_dst) + mean_{neighbors}(fc_neigh(h_src))``.

    Args:
        in_feats: Width of the input features. It is shared by the source
            and destination branches.
        out_feats: Width of the output features.
    """

    def __init__(
        self,
        in_feats,
        out_feats,
    ):
        super().__init__()
        # Source and destination inputs have the same width in this layer.
        self._in_src_feats = in_feats
        self._in_dst_feats = in_feats
        self._out_feats = out_feats

        # Neighbor projection without bias; the self branch carries the bias.
        self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=False)
        # Projection of each destination node's own features.
        self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=True)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw both projection weights with ReLU-gain Xavier-uniform init.

        Biases are left as created by ``nn.Linear``.
        """
        relu_gain = nn.init.calculate_gain("relu")
        for linear in (self.fc_self, self.fc_neigh):
            nn.init.xavier_uniform_(linear.weight, gain=relu_gain)

    def forward(self, A, feat):
        """Apply one GraphSAGE mean-aggregation step.

        Args:
            A: DGL sparse matrix of shape ``(num_src, num_dst)``. Its rows
                line up with the rows of ``feat``.
            feat: Source-node features, one row per row of ``A``. The first
                ``A.shape[1]`` rows are used as the destination nodes' own
                features. This assumes the sampler orders rows that way;
                confirm.

        Returns:
            Tensor with ``A.shape[1]`` rows and ``out_feats`` columns.
        """
        num_dst = A.shape[1]
        # Projecting before averaging is equivalent to averaging before
        # projecting, because fc_neigh is linear (no bias).
        neigh_proj = self.fc_neigh(feat)
        # Scale each column by 1 / (column sum) to turn a sum into a mean.
        inv_degree = dglsp.diag(A.sum(0)) ** -1
        neigh_mean = (A @ inv_degree).T @ neigh_proj
        return self.fc_self(feat[:num_dst]) + neigh_mean


class SAGE(nn.Module):
    """Three-layer GraphSAGE model built from ``SAGEConv`` layers.

    Layer widths are ``in_size -> hid_size -> hid_size -> out_size``. ReLU
    followed by dropout (p=0.5) is applied after every layer except the last.
    """

    def __init__(self, in_size, hid_size, out_size):
        super().__init__()
        self.layers = nn.ModuleList()
        # Consecutive width pairs give the (in, out) size of each layer,
        # created in input-to-output order.
        widths = [in_size, hid_size, hid_size, out_size]
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            self.layers.append(SAGEConv(fan_in, fan_out))
        self.dropout = nn.Dropout(0.5)
        self.hid_size = hid_size
        self.out_size = out_size

    def forward(self, sampled_matrices, x):
        """Run the layers over the per-layer sampled matrices.

        Args:
            sampled_matrices: One sparse matrix per layer. Entry ``i`` is fed
                to ``self.layers[i]``, so the first entry pairs with ``x``.
            x: Input node features for the first layer.

        Returns:
            Output of the last applied layer.
        """
        final_idx = len(self.layers) - 1
        h = x
        for idx, (conv, mat) in enumerate(zip(self.layers, sampled_matrices)):
            h = conv(mat, h)
            if idx == final_idx:
                # The last layer's output is returned without activation.
                continue
            h = self.dropout(F.relu(h))
        return h

启动训练

features = dataset.feature
# Create GraphSAGE model.
# features.size(...)[0] is presumably the width of the "feat" node feature
# (GraphBolt reports sizes without the leading node dimension); confirm.
in_size = features.size("node", None, "feat")[0]
num_classes = dataset.tasks[0].metadata["num_classes"]
out_size = num_classes
model = SAGE(in_size, 256, out_size).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=5e-4)

for epoch in range(10):
    model.train()
    total_loss = 0.0
    num_batches = 0
    for data in dataloader:
        node_feature = data.node_features["feat"].float()
        # The per-layer sparse matrices produced by the custom sampler.
        blocks = data.sampled_subgraphs
        y = data.labels
        y_hat = model(blocks, node_feature)
        loss = F.cross_entropy(y_hat, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1
    # Report the mean minibatch loss for the epoch. max(..., 1) avoids a
    # ZeroDivisionError if the dataloader yields no batches.
    print(f"Epoch {epoch:03d} | Loss {total_loss / max(num_batches, 1):.4f}")

更多详情,请参阅完整示例