4.3 处理数据
可以在函数 process()
中实现数据处理代码,它假设原始数据已经位于 self.raw_dir
中。图机器学习中通常有三种类型的任务:图分类、节点分类和链接预测。本节将展示如何处理与这些任务相关的数据集。
本节重点介绍处理图、特征和掩码的标准方法。它将使用内置数据集作为示例,并跳过从文件构建图的实现,但会添加指向详细实现的链接。请参阅 1.4 从外部源创建图 以查看如何从外部源构建图的完整指南。
处理图分类数据集
图分类数据集与典型机器学习任务中的大多数数据集几乎相同,其中使用了 mini-batch 训练。因此,可以将原始数据处理为 dgl.DGLGraph
对象的列表和标签张量的列表。此外,如果原始数据被分割成多个文件,可以添加一个参数 split
来加载数据的特定部分。
以 QM7bDataset
为例
from dgl.data import DGLDataset
class QM7bDataset(DGLDataset):
_url = 'http://deepchem.io.s3-website-us-west-1.amazonaws.com/' \
'datasets/qm7b.mat'
_sha1_str = '4102c744bb9d6fd7b40ac67a300e49cd87e28392'
def __init__(self, raw_dir=None, force_reload=False, verbose=False):
super(QM7bDataset, self).__init__(name='qm7b',
url=self._url,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose)
def process(self):
mat_path = self.raw_path + '.mat'
# process data to a list of graphs and a list of labels
self.graphs, self.label = self._load_graph(mat_path)
def __getitem__(self, idx):
""" Get graph and label by index
Parameters
----------
idx : int
Item index
Returns
-------
(dgl.DGLGraph, Tensor)
"""
return self.graphs[idx], self.label[idx]
def __len__(self):
"""Number of graphs in the dataset"""
return len(self.graphs)
在 process()
中,原始数据被处理成图的列表和标签的列表。必须实现 __getitem__(idx)
和 __len__()
以进行迭代。DGL 建议让 __getitem__(idx)
返回如上所示的元组 (graph, label)
。有关 self._load_graph()
和 __getitem__
的详细信息,请查看 QM7bDataset 源代码。
还可以在类中添加属性以指示数据集的一些有用信息。在 QM7bDataset
中,可以添加一个属性 num_tasks
来指示此多任务数据集中的总预测任务数
@property
def num_tasks(self):
"""Number of labels for each graph, i.e. number of prediction tasks."""
return 14
完成所有这些编码后,最终可以如下使用 QM7bDataset
import dgl
import torch
from dgl.dataloading import GraphDataLoader
# load data
dataset = QM7bDataset()
num_tasks = dataset.num_tasks
# create dataloaders
dataloader = GraphDataLoader(dataset, batch_size=1, shuffle=True)
# training
for epoch in range(100):
for g, labels in dataloader:
# your training code here
pass
有关训练图分类模型的完整指南,请参见 5.4 图分类。
有关图分类数据集的更多示例,请参阅 DGL 的内置图分类数据集
gindataset
minigcdataset
qm7bdata
tudata
处理节点分类数据集
与图分类不同,节点分类通常在一个单一图上进行。因此,数据集的分割是在图的节点上。DGL 建议使用节点掩码来指定分割。本节使用内置数据集 CitationGraphDataset 作为示例
此外,DGL 建议重新排列节点和边,以便彼此靠近的节点具有相近的 ID。此过程可以提高访问节点邻居的局部性,这可能有利于后续在图上进行的计算和分析。DGL 为此目的提供了一个 API,称为 dgl.reorder_graph()
。有关更多详细信息,请参阅以下示例中的 process()
部分。
from dgl.data import DGLBuiltinDataset
from dgl.data.utils import _get_dgl_url
class CitationGraphDataset(DGLBuiltinDataset):
_urls = {
'cora_v2' : 'dataset/cora_v2.zip',
'citeseer' : 'dataset/citeseer.zip',
'pubmed' : 'dataset/pubmed.zip',
}
def __init__(self, name, raw_dir=None, force_reload=False, verbose=True):
assert name.lower() in ['cora', 'citeseer', 'pubmed']
if name.lower() == 'cora':
name = 'cora_v2'
url = _get_dgl_url(self._urls[name])
super(CitationGraphDataset, self).__init__(name,
url=url,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose)
def process(self):
# Skip some processing code
# === data processing skipped ===
# build graph
g = dgl.graph(graph)
# splitting masks
g.ndata['train_mask'] = train_mask
g.ndata['val_mask'] = val_mask
g.ndata['test_mask'] = test_mask
# node labels
g.ndata['label'] = torch.tensor(labels)
# node features
g.ndata['feat'] = torch.tensor(_preprocess_features(features),
dtype=F.data_type_dict['float32'])
self._num_tasks = onehot_labels.shape[1]
self._labels = labels
# reorder graph to obtain better locality.
self._g = dgl.reorder_graph(g)
def __getitem__(self, idx):
assert idx == 0, "This dataset has only one graph"
return self._g
def __len__(self):
return 1
为简洁起见,本节跳过 process()
中的一些代码,以突出处理节点分类数据集的关键部分:分割掩码。节点特征和节点标签存储在 g.ndata
中。有关详细实现,请参阅 CitationGraphDataset 源代码。
请注意,__getitem__(idx)
和 __len__()
的实现也已更改,因为对于节点分类任务通常只有一个图。掩码在 PyTorch 和 TensorFlow 中是 bool tensors
,在 MXNet 中是 float tensors
。
本节使用 CitationGraphDataset
的子类 dgl.data.CiteseerGraphDataset
来展示其用法
# load data
dataset = CiteseerGraphDataset(raw_dir='')
graph = dataset[0]
# get split masks
train_mask = graph.ndata['train_mask']
val_mask = graph.ndata['val_mask']
test_mask = graph.ndata['test_mask']
# get node features
feats = graph.ndata['feat']
# get labels
labels = graph.ndata['label']
有关训练节点分类模型的完整指南,请参见 5.1 节点分类/回归。
有关节点分类数据集的更多示例,请参阅 DGL 的内置数据集
citationdata
corafulldata
amazoncobuydata
coauthordata
karateclubdata
ppidata
redditdata
sbmdata
sstdata
rdfdata
处理链接预测数据集
链接预测数据集的处理类似于节点分类,数据集中通常只有一个图。
本节使用内置数据集 KnowledgeGraphDataset 作为示例,并仍然跳过详细的数据处理代码,以突出处理链接预测数据集的关键部分
# Example for creating Link Prediction datasets
class KnowledgeGraphDataset(DGLBuiltinDataset):
def __init__(self, name, reverse=True, raw_dir=None, force_reload=False, verbose=True):
self._name = name
self.reverse = reverse
url = _get_dgl_url('dataset/') + '{}.tgz'.format(name)
super(KnowledgeGraphDataset, self).__init__(name,
url=url,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose)
def process(self):
# Skip some processing code
# === data processing skipped ===
# splitting mask
g.edata['train_mask'] = train_mask
g.edata['val_mask'] = val_mask
g.edata['test_mask'] = test_mask
# edge type
g.edata['etype'] = etype
# node type
g.ndata['ntype'] = ntype
self._g = g
def __getitem__(self, idx):
assert idx == 0, "This dataset has only one graph"
return self._g
def __len__(self):
return 1
如代码所示,它将分割掩码添加到图的 edata
字段中。请查看 KnowledgeGraphDataset 源代码 以查看完整代码。以下代码使用 KnowledgeGraphDataset
的子类 dgl.data.FB15k237Dataset
来展示其用法
from dgl.data import FB15k237Dataset
# load data
dataset = FB15k237Dataset()
graph = dataset[0]
# get training mask
train_mask = graph.edata['train_mask']
train_idx = torch.nonzero(train_mask, as_tuple=False).squeeze()
src, dst = graph.edges(train_idx)
# get edge types in training set
rel = graph.edata['etype'][train_idx]
有关训练链接预测模型的完整指南,请参见 5.3 链接预测。
有关链接预测数据集的更多示例,请参阅 DGL 的内置数据集
kgdata
bitcoinotcdata