CIFAR10SuperPixelDataset

class dgl.data.CIFAR10SuperPixelDataset(raw_dir=None, split='train', use_feature=False, force_reload=False, verbose=False, transform=None)[源码]

基类: SuperPixelDataset

用于图分类任务的 CIFAR10 超像素数据集。

benchmark-gnn 中用于 CIFAR10 的 DGL 数据集,它包含从原始 CIFAR10 图像转换而来的图。

参考 http://arxiv.org/abs/2003.00982

统计信息

  • 训练样本数: 50,000

  • 测试样本数: 10,000

  • 数据集图像大小: 32

参数:
  • raw_dir (str) – 存储所有下载的原始数据集的目录。 默认值: “~/.dgl/”。

  • split (str) – 应从 [“train”, “test”] 中选择。 默认值: “train”。

  • use_feature (bool) –

    • True: 邻接矩阵由超像素位置 + 特征定义

    • False: 邻接矩阵仅由超像素位置定义

    默认值: False。

  • force_reload (bool) – 是否重新加载数据集。 默认值: False。

  • verbose (bool) – 是否打印进度信息。 默认值: False。

  • transform (callable, optional) – 一个转换函数,它接收一个 DGLGraph 对象并返回其转换后的版本。 该 DGLGraph 对象将在每次访问前被转换。

示例

>>> from dgl.data import CIFAR10SuperPixelDataset
>>> # CIFAR10 dataset
>>> train_dataset = CIFAR10SuperPixelDataset(split="train")
>>> len(train_dataset)
50000
>>> graph, label = train_dataset[0]
>>> graph
Graph(num_nodes=123, num_edges=984,
    ndata_schemes={'feat': Scheme(shape=(5,), dtype=torch.float32)}
    edata_schemes={'feat': Scheme(shape=(1,), dtype=torch.float32)}),
>>> # support tensor to be index when transform is None
>>> # see details in __getitem__ function
>>> import torch
>>> idx = torch.tensor([0, 1, 2])
>>> train_dataset_subset = train_dataset[idx]
>>> train_dataset_subset[0]
Graph(num_nodes=123, num_edges=984,
    ndata_schemes={'feat': Scheme(shape=(5,), dtype=torch.float32)}
    edata_schemes={'feat': Scheme(shape=(1,), dtype=torch.float32)}),
__getitem__(idx)

获取第 idx 个样本。

参数:

idx (int or tensor) – 样本索引。 当 transform 为 None 时,允许使用 1-D tensor 作为 idx

返回值:

__len__()

数据集中的样本数量。