注意
要下载完整的示例代码,请转到末尾。
使用 GNN 进行图分类训练
通过本教程,您将能够
加载 DGL 提供的图分类数据集。
理解 readout 函数的作用。
理解如何创建和使用图的 mini-batch。
构建基于 GNN 的图分类模型。
在 DGL 提供的数集上训练和评估模型。
(预计时间:18 分钟)
import os
os.environ["DGLBACKEND"] = "pytorch"
import dgl
import dgl.data
import torch
import torch.nn as nn
import torch.nn.functional as F
使用 GNN 进行图分类概述
图分类或回归任务要求模型根据给定图的节点和边特征,预测该图的某些图级属性。分子属性预测是其中的一个具体应用。
本教程展示了如何使用论文 How Powerful Are Graph Neural Networks 中的一个小数据集来训练图分类模型。
加载数据
# Generate a synthetic dataset with 10000 graphs, ranging from 10 to 500 nodes.
dataset = dgl.data.GINDataset("PROTEINS", self_loop=True)
Downloading /root/.dgl/GINDataset.zip from https://raw.githubusercontent.com/weihua916/powerful-gnns/master/dataset.zip...
/root/.dgl/GINDataset.zip: 0%| | 0.00/33.4M [00:00<?, ?B/s]
/root/.dgl/GINDataset.zip: 21%|██ | 6.95M/33.4M [00:00<00:00, 69.5MB/s]
/root/.dgl/GINDataset.zip: 49%|████▊ | 16.3M/33.4M [00:00<00:00, 83.5MB/s]
/root/.dgl/GINDataset.zip: 76%|███████▌ | 25.4M/33.4M [00:00<00:00, 87.2MB/s]
/root/.dgl/GINDataset.zip: 100%|██████████| 33.4M/33.4M [00:00<00:00, 87.2MB/s]
Extracting file to /root/.dgl/GINDataset
该数据集是一组图,每个图都有节点特征和单个标签。可以在 GINDataset
对象的 dim_nfeats
和 gclasses
属性中查看节点特征的维度和可能的图类别数量。
print("Node feature dimensionality:", dataset.dim_nfeats)
print("Number of graph categories:", dataset.gclasses)
from dgl.dataloading import GraphDataLoader
Node feature dimensionality: 3
Number of graph categories: 2
定义数据加载器
图分类数据集通常包含两种类型的元素:一组图及其图级标签。与图像分类任务类似,当数据集足够大时,我们需要使用 mini-batch 进行训练。当您训练图像分类或语言建模模型时,会使用 DataLoader
来遍历数据集。在 DGL 中,您可以使用 GraphDataLoader
。
您还可以使用 torch.utils.data.sampler 中提供的各种数据集采样器。例如,本教程创建了一个训练 GraphDataLoader
和一个测试 GraphDataLoader
,使用 SubsetRandomSampler
来指示 PyTorch 仅从数据集的一个子集中进行采样。
from torch.utils.data.sampler import SubsetRandomSampler
num_examples = len(dataset)
num_train = int(num_examples * 0.8)
train_sampler = SubsetRandomSampler(torch.arange(num_train))
test_sampler = SubsetRandomSampler(torch.arange(num_train, num_examples))
train_dataloader = GraphDataLoader(
dataset, sampler=train_sampler, batch_size=5, drop_last=False
)
test_dataloader = GraphDataLoader(
dataset, sampler=test_sampler, batch_size=5, drop_last=False
)
您可以尝试迭代创建的 GraphDataLoader
并查看其输出
[Graph(num_nodes=236, num_edges=1068,
ndata_schemes={'label': Scheme(shape=(), dtype=torch.int64), 'attr': Scheme(shape=(3,), dtype=torch.float32)}
edata_schemes={}), tensor([0, 1, 0, 0, 0])]
由于 dataset
中的每个元素都包含一个图和一个标签,因此 GraphDataLoader
将在每次迭代时返回两个对象。第一个元素是批处理图,第二个元素是简单的标签向量,表示 mini-batch 中每个图的类别。接下来,我们将讨论批处理图。
DGL 中的批处理图
在每个 mini-batch 中,通过 dgl.batch
将采样到的图合并成一个更大的批处理图。这个更大的批处理图将所有原始图合并为独立的连通分量,并连接它们的节点和边特征。这个更大的图也是一个 DGLGraph
实例(因此您仍然可以像这里一样将其视为一个普通的 DGLGraph
对象)。然而,它包含了恢复原始图所需的信息,例如每个图元素的节点数和边数。
batched_graph, labels = batch
print(
"Number of nodes for each graph element in the batch:",
batched_graph.batch_num_nodes(),
)
print(
"Number of edges for each graph element in the batch:",
batched_graph.batch_num_edges(),
)
# Recover the original graph elements from the minibatch
graphs = dgl.unbatch(batched_graph)
print("The original graphs in the minibatch:")
print(graphs)
Number of nodes for each graph element in the batch: tensor([ 17, 15, 38, 146, 20])
Number of edges for each graph element in the batch: tensor([ 81, 71, 198, 628, 90])
The original graphs in the minibatch:
[Graph(num_nodes=17, num_edges=81,
ndata_schemes={'label': Scheme(shape=(), dtype=torch.int64), 'attr': Scheme(shape=(3,), dtype=torch.float32)}
edata_schemes={}), Graph(num_nodes=15, num_edges=71,
ndata_schemes={'label': Scheme(shape=(), dtype=torch.int64), 'attr': Scheme(shape=(3,), dtype=torch.float32)}
edata_schemes={}), Graph(num_nodes=38, num_edges=198,
ndata_schemes={'label': Scheme(shape=(), dtype=torch.int64), 'attr': Scheme(shape=(3,), dtype=torch.float32)}
edata_schemes={}), Graph(num_nodes=146, num_edges=628,
ndata_schemes={'label': Scheme(shape=(), dtype=torch.int64), 'attr': Scheme(shape=(3,), dtype=torch.float32)}
edata_schemes={}), Graph(num_nodes=20, num_edges=90,
ndata_schemes={'label': Scheme(shape=(), dtype=torch.int64), 'attr': Scheme(shape=(3,), dtype=torch.float32)}
edata_schemes={})]
定义模型
本教程将构建一个两层的 图卷积网络 (GCN)。其每一层通过聚合邻居信息来计算新的节点表示。如果您已经阅读了入门教程,您会注意到两个区别:
由于任务是预测 整个图 的单个类别,而不是每个节点的类别,因此您需要聚合所有节点(可能还有边)的表示,以形成图级表示。这个过程通常被称为 readout。一个简单的选择是使用
dgl.mean_nodes()
对图的节点特征进行平均。模型的输入图将是由
GraphDataLoader
生成的批处理图。DGL 提供的 readout 函数可以处理批处理图,以便为每个 mini-batch 元素返回一个表示。
from dgl.nn import GraphConv
class GCN(nn.Module):
def __init__(self, in_feats, h_feats, num_classes):
super(GCN, self).__init__()
self.conv1 = GraphConv(in_feats, h_feats)
self.conv2 = GraphConv(h_feats, num_classes)
def forward(self, g, in_feat):
h = self.conv1(g, in_feat)
h = F.relu(h)
h = self.conv2(g, h)
g.ndata["h"] = h
return dgl.mean_nodes(g, "h")
训练循环
训练循环使用 GraphDataLoader
对象遍历训练集并计算梯度,就像图像分类或语言建模一样。
# Create the model with given dimensions
model = GCN(dataset.dim_nfeats, 16, dataset.gclasses)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
for epoch in range(20):
for batched_graph, labels in train_dataloader:
pred = model(batched_graph, batched_graph.ndata["attr"].float())
loss = F.cross_entropy(pred, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
num_correct = 0
num_tests = 0
for batched_graph, labels in test_dataloader:
pred = model(batched_graph, batched_graph.ndata["attr"].float())
num_correct += (pred.argmax(1) == labels).sum().item()
num_tests += len(labels)
print("Test accuracy:", num_correct / num_tests)
Test accuracy: 0.05829596412556054
下一步
有关端到端图分类模型的示例,请参见 GIN 示例。
# Thumbnail credits: DGL
# sphinx_gallery_thumbnail_path = '_static/blitz_5_graph_classification.png'
脚本总运行时间: (0 分 55.250 秒)