4.3 处理数据

(中文版)

可以在函数 process() 中实现数据处理代码,它假设原始数据已经位于 self.raw_dir 中。图机器学习中通常有三种类型的任务:图分类、节点分类和链接预测。本节将展示如何处理与这些任务相关的数据集。

本节重点介绍处理图、特征和掩码的标准方法。它将使用内置数据集作为示例,并跳过从文件构建图的实现,但会添加指向详细实现的链接。请参阅 1.4 从外部源创建图 以查看如何从外部源构建图的完整指南。

处理图分类数据集

图分类数据集与典型机器学习任务中的大多数数据集几乎相同,其中使用了 mini-batch 训练。因此,可以将原始数据处理为 dgl.DGLGraph 对象的列表和标签张量的列表。此外,如果原始数据被分割成多个文件,可以添加一个参数 split 来加载数据的特定部分。

QM7bDataset 为例

from dgl.data import DGLDataset

class QM7bDataset(DGLDataset):
    _url = 'http://deepchem.io.s3-website-us-west-1.amazonaws.com/' \
           'datasets/qm7b.mat'
    _sha1_str = '4102c744bb9d6fd7b40ac67a300e49cd87e28392'

    def __init__(self, raw_dir=None, force_reload=False, verbose=False):
        super(QM7bDataset, self).__init__(name='qm7b',
                                          url=self._url,
                                          raw_dir=raw_dir,
                                          force_reload=force_reload,
                                          verbose=verbose)

    def process(self):
        mat_path = self.raw_path + '.mat'
        # process data to a list of graphs and a list of labels
        self.graphs, self.label = self._load_graph(mat_path)

    def __getitem__(self, idx):
        """ Get graph and label by index

        Parameters
        ----------
        idx : int
            Item index

        Returns
        -------
        (dgl.DGLGraph, Tensor)
        """
        return self.graphs[idx], self.label[idx]

    def __len__(self):
        """Number of graphs in the dataset"""
        return len(self.graphs)

process() 中,原始数据被处理成图的列表和标签的列表。必须实现 __getitem__(idx)__len__() 以进行迭代。DGL 建议让 __getitem__(idx) 返回如上所示的元组 (graph, label)。有关 self._load_graph()__getitem__ 的详细信息,请查看 QM7bDataset 源代码

还可以在类中添加属性以指示数据集的一些有用信息。在 QM7bDataset 中,可以添加一个属性 num_tasks 来指示此多任务数据集中的总预测任务数

@property
def num_tasks(self):
    """Number of labels for each graph, i.e. number of prediction tasks."""
    return 14

完成所有这些编码后,最终可以如下使用 QM7bDataset

import dgl
import torch

from dgl.dataloading import GraphDataLoader

# load data
dataset = QM7bDataset()
num_tasks = dataset.num_tasks

# create dataloaders
dataloader = GraphDataLoader(dataset, batch_size=1, shuffle=True)

# training
for epoch in range(100):
    for g, labels in dataloader:
        # your training code here
        pass

有关训练图分类模型的完整指南,请参见 5.4 图分类

有关图分类数据集的更多示例,请参阅 DGL 的内置图分类数据集

  • gindataset

  • minigcdataset

  • qm7bdata

  • tudata

处理节点分类数据集

与图分类不同,节点分类通常在一个单一图上进行。因此,数据集的分割是在图的节点上。DGL 建议使用节点掩码来指定分割。本节使用内置数据集 CitationGraphDataset 作为示例

此外,DGL 建议重新排列节点和边,以便彼此靠近的节点具有相近的 ID。此过程可以提高访问节点邻居的局部性,这可能有利于后续在图上进行的计算和分析。DGL 为此目的提供了一个 API,称为 dgl.reorder_graph()。有关更多详细信息,请参阅以下示例中的 process() 部分。

from dgl.data import DGLBuiltinDataset
from dgl.data.utils import _get_dgl_url

class CitationGraphDataset(DGLBuiltinDataset):
    _urls = {
        'cora_v2' : 'dataset/cora_v2.zip',
        'citeseer' : 'dataset/citeseer.zip',
        'pubmed' : 'dataset/pubmed.zip',
    }

    def __init__(self, name, raw_dir=None, force_reload=False, verbose=True):
        assert name.lower() in ['cora', 'citeseer', 'pubmed']
        if name.lower() == 'cora':
            name = 'cora_v2'
        url = _get_dgl_url(self._urls[name])
        super(CitationGraphDataset, self).__init__(name,
                                                   url=url,
                                                   raw_dir=raw_dir,
                                                   force_reload=force_reload,
                                                   verbose=verbose)

    def process(self):
        # Skip some processing code
        # === data processing skipped ===

        # build graph
        g = dgl.graph(graph)
        # splitting masks
        g.ndata['train_mask'] = train_mask
        g.ndata['val_mask'] = val_mask
        g.ndata['test_mask'] = test_mask
        # node labels
        g.ndata['label'] = torch.tensor(labels)
        # node features
        g.ndata['feat'] = torch.tensor(_preprocess_features(features),
                                       dtype=F.data_type_dict['float32'])
        self._num_tasks = onehot_labels.shape[1]
        self._labels = labels
        # reorder graph to obtain better locality.
        self._g = dgl.reorder_graph(g)

    def __getitem__(self, idx):
        assert idx == 0, "This dataset has only one graph"
        return self._g

    def __len__(self):
        return 1

为简洁起见,本节跳过 process() 中的一些代码,以突出处理节点分类数据集的关键部分:分割掩码。节点特征和节点标签存储在 g.ndata 中。有关详细实现,请参阅 CitationGraphDataset 源代码

请注意,__getitem__(idx)__len__() 的实现也已更改,因为对于节点分类任务通常只有一个图。掩码在 PyTorch 和 TensorFlow 中是 bool tensors,在 MXNet 中是 float tensors

本节使用 CitationGraphDataset 的子类 dgl.data.CiteseerGraphDataset 来展示其用法

# load data
dataset = CiteseerGraphDataset(raw_dir='')
graph = dataset[0]

# get split masks
train_mask = graph.ndata['train_mask']
val_mask = graph.ndata['val_mask']
test_mask = graph.ndata['test_mask']

# get node features
feats = graph.ndata['feat']

# get labels
labels = graph.ndata['label']

有关训练节点分类模型的完整指南,请参见 5.1 节点分类/回归

有关节点分类数据集的更多示例,请参阅 DGL 的内置数据集

  • citationdata

  • corafulldata

  • amazoncobuydata

  • coauthordata

  • karateclubdata

  • ppidata

  • redditdata

  • sbmdata

  • sstdata

  • rdfdata