6.5 使用 DGL sparse 训练 GNN
本教程演示如何使用 DGL sparse 库在图上进行采样并训练模型。它训练并测试一个 GraphSAGE 模型,并使用 sparse 的 sample 和 compact 算子从整个矩阵中采样出子矩阵。
使用 DGL sparse 训练 GNN 与 6.1 使用邻居采样训练用于节点分类的 GNN 非常相似。主要区别在于自定义采样器和表示图的矩阵。
我们在 6.4 实现自定义图采样器 中自定义了一个采样器。在本教程中,我们将使用 DGL sparse 库自定义另一个采样器,如下所示。
@functional_datapipe("sample_sparse_neighbor")
class SparseNeighborSampler(SubgraphSampler):
    """Neighbor sampler that draws sub-matrices from a sparse matrix.

    The class is registered under the functional datapipe name
    ``sample_sparse_neighbor``.

    Args:
        datapipe: Upstream datapipe, forwarded to ``SubgraphSampler``.
        matrix: Matrix to sample from. It must provide the ``sample`` method
            used by ``sample_subgraphs``.
        fanouts: Iterable with one fanout per layer. Entries that are not
            already a ``torch.Tensor`` are passed through ``int()`` and
            wrapped in a one-element ``torch.LongTensor``.
    """

    def __init__(self, datapipe, matrix, fanouts):
        super().__init__(datapipe)
        self.matrix = matrix
        # Convert fanouts to a list of tensors.
        converted = []
        for fanout in fanouts:
            if not isinstance(fanout, torch.Tensor):
                fanout = torch.LongTensor([int(fanout)])
            converted.append(fanout)
        # Keep the fanouts in reverse order of the argument, because
        # sample_subgraphs walks them starting from the last one given.
        converted.reverse()
        self.fanouts = converted

    def sample_subgraphs(self, seeds):
        """Sample one compacted matrix per fanout, starting from ``seeds``.

        Args:
            seeds: Ids passed as ``ids`` to the first ``sample`` call.

        Returns:
            A tuple ``(src, sampled_matrices)``. ``src`` is the relabel index
            returned by the last ``compact`` call (or ``seeds`` itself when
            there are no fanouts). ``sampled_matrices`` holds the compacted
            matrices with the most recently sampled one first.
        """
        sampled_matrices = []
        src = seeds

        #####################################################################
        # (HIGHLIGHT) Using the sparse sample operator to perform random
        # sampling on the neighboring nodes of the seed nodes. The sparse
        # compact operator is then employed to compact and relabel the sampled
        # matrix, resulting in the sampled matrix and the relabel index.
        #####################################################################
        for fanout in self.fanouts:
            # Sample neighbors.
            sampled_matrix = self.matrix.sample(1, fanout, ids=src).coalesce()
            # Compact the sampled matrix.
            # NOTE(review): SAGEConv.forward slices `feat[: A.shape[1]]`,
            # which presumes the current seeds occupy the first rows after
            # compaction -- confirm that `compact(0)` guarantees this order.
            compacted_mat, row_ids = sampled_matrix.compact(0)
            sampled_matrices.append(compacted_mat)
            # The relabel index becomes the seed set of the next hop.
            src = row_ids

        # Matrices were collected in sampling order; callers receive them
        # with the last sampled matrix first.
        sampled_matrices.reverse()
        return src, sampled_matrices
另一个主要区别是表示图的矩阵。之前我们使用 FusedCSCSamplingGraph 进行采样。在本教程中,我们使用 SparseMatrix 来表示图。
# Load the built-in "ogbn-products" dataset.
dataset = gb.BuiltinDataset("ogbn-products").load()
g = dataset.graph
# Create the sparse matrix: an N x N matrix built from the graph's CSC
# index-pointer and index arrays, where N is the number of nodes.
N = g.num_nodes
A = dglsp.from_csc(g.csc_indptr, g.indices, shape=(N, N))
其余代码与节点分类教程几乎相同。
将此采样器与 DataLoader 一起使用
# Build the mini-batch pipeline as a single chain. `ids`, `fanouts`,
# `features` and `device` must already be defined when this runs.
datapipe = (
    gb.ItemSampler(ids, batch_size=1024)
    # Customize the graphbolt sampler with the sparse-based sampler.
    .sample_sparse_neighbor(A, fanouts)
    # Use graphbolt to fetch features.
    .fetch_feature(features, node_feature_keys=["feat"])
    .copy_to(device)
)
dataloader = gb.DataLoader(datapipe)
模型定义如下所示
class SAGEConv(nn.Module):
    r"""GraphSAGE layer from `Inductive Representation Learning on
    Large Graphs <https://arxiv.org/pdf/1706.02216.pdf>`__

    Args:
        in_feats: Input feature size, used for both the source and the
            destination projection.
        out_feats: Output feature size.
    """

    def __init__(
        self,
        in_feats,
        out_feats,
    ):
        super().__init__()
        self._in_src_feats, self._in_dst_feats = in_feats, in_feats
        self._out_feats = out_feats
        # The neighbor projection has no bias; the self projection carries it.
        self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=False)
        self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=True)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize both weight matrices with Xavier-uniform (ReLU gain)."""
        gain = nn.init.calculate_gain("relu")
        nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)
        nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)

    def forward(self, A, feat):
        """Apply the layer to one sampled matrix.

        Args:
            A: Sampled matrix. ``A.shape[1]`` is used as the number of
                destination nodes (presumably shaped (num_src, num_dst) --
                TODO confirm against the sampler).
            feat: Source node features. Its first ``A.shape[1]`` rows are
                taken as the destination node features.

        Returns:
            ``fc_self(feat_dst)`` plus the degree-normalized aggregation of
            the projected source features.
        """
        feat_src = feat
        feat_dst = feat[: A.shape[1]]

        # Aggregator type: mean.
        srcdata = self.fc_neigh(feat_src)
        # Divided by degree.
        D_hat = dglsp.diag(A.sum(0)) ** -1
        A_div = A @ D_hat
        # Conv neighbors.
        dstdata = A_div.T @ srcdata

        rst = self.fc_self(feat_dst) + dstdata
        return rst
class SAGE(nn.Module):
    """Three-layer GraphSAGE model stacked from ``SAGEConv`` layers.

    Args:
        in_size: Input feature size of the first layer.
        hid_size: Hidden feature size shared by the inner layers.
        out_size: Output feature size of the last layer.
    """

    def __init__(self, in_size, hid_size, out_size):
        super().__init__()
        self.layers = nn.ModuleList()
        # Three-layer GraphSAGE (SAGEConv uses the mean aggregator).
        self.layers.append(SAGEConv(in_size, hid_size))
        self.layers.append(SAGEConv(hid_size, hid_size))
        self.layers.append(SAGEConv(hid_size, out_size))
        self.dropout = nn.Dropout(0.5)
        self.hid_size = hid_size
        self.out_size = out_size

    def forward(self, sampled_matrices, x):
        """Run the layers over the sampled matrices.

        Args:
            sampled_matrices: One matrix per layer, paired with
                ``self.layers`` in order.
            x: Input features handed to the first layer.

        Returns:
            The output of the last layer. ReLU and dropout are applied after
            every layer except the last one.
        """
        last_idx = len(self.layers) - 1
        hidden_x = x
        for layer_idx, (layer, sampled_matrix) in enumerate(
            zip(self.layers, sampled_matrices)
        ):
            hidden_x = layer(sampled_matrix, hidden_x)
            # The final layer's output is returned without activation.
            if layer_idx != last_idx:
                hidden_x = F.relu(hidden_x)
                hidden_x = self.dropout(hidden_x)
        return hidden_x
启动训练
# NOTE: `features` is also read by the datapipe's fetch_feature stage, so
# this assignment has to run before that pipeline is built.
features = dataset.feature

# Create GraphSAGE model.
in_size = features.size("node", None, "feat")[0]
num_classes = dataset.tasks[0].metadata["num_classes"]
out_size = num_classes
model = SAGE(in_size, 256, out_size).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=5e-4)

for epoch in range(10):
    model.train()
    # Running sum of the per-batch losses for this epoch (accumulated here,
    # not reported in this excerpt).
    total_loss = 0
    for it, data in enumerate(dataloader):
        node_feature = data.node_features["feat"].float()
        blocks = data.sampled_subgraphs
        y = data.labels
        y_hat = model(blocks, node_feature)
        loss = F.cross_entropy(y_hat, y)
        # Standard step: clear old gradients, backpropagate, update weights.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
更多详情,请参阅完整示例。