MNISTSuperPixelDataset
- class dgl.data.MNISTSuperPixelDataset(raw_dir=None, split='train', use_feature=False, force_reload=False, verbose=False, transform=None)[source]
基类:
SuperPixelDataset
用于图分类任务的 MNIST 超像素数据集。
benchmark-gnn 中的 DGL MNIST 和 CIFAR10 数据集,其中包含从原始 MNIST 和 CIFAR10 图像转换而来的图。
参考 http://arxiv.org/abs/2003.00982
统计信息
训练样本:60,000
测试样本:10,000
数据集图像尺寸:28
- 参数:
raw_dir (str) – 存储所有下载的原始数据集的目录。默认值:“~/.dgl/”。
split (str) – 应从 [“train”, “test”] 中选择。默认值:“train”。
use_feature (bool) –
True: 邻接矩阵由超像素位置 + 特征定义
False: 邻接矩阵仅由超像素位置定义
默认值:False。
force_reload (bool) – 是否重新加载数据集。默认值:False。
verbose (bool) – 是否打印进度信息。默认值:False。
transform (可调用对象, 可选) – 一个转换函数,接收
DGLGraph
对象并返回转换后的版本。该DGLGraph
对象在每次访问前都会被转换。
示例
>>> from dgl.data import MNISTSuperPixelDataset
>>> # MNIST dataset >>> train_dataset = MNISTSuperPixelDataset(split="train") >>> len(train_dataset) 60000 >>> graph, label = train_dataset[0] >>> graph Graph(num_nodes=71, num_edges=568, ndata_schemes={'feat': Scheme(shape=(3,), dtype=torch.float32)} edata_schemes={'feat': Scheme(shape=(1,), dtype=torch.float32)})
>>> # support tensor to be index when transform is None >>> # see details in __getitem__ function >>> import torch >>> idx = torch.tensor([0, 1, 2]) >>> train_dataset_subset = train_dataset[idx] >>> train_dataset_subset[0] Graph(num_nodes=71, num_edges=568, ndata_schemes={'feat': Scheme(shape=(3,), dtype=torch.float32)} edata_schemes={'feat': Scheme(shape=(1,), dtype=torch.float32)})
- __getitem__(idx)
获取第 idx 个样本。
- 参数:
idx (int 或 张量) – 样本索引。当 transform 为 None 时,允许使用 1 维张量作为 idx。
- 返回:
(
dgl.DGLGraph
, 张量) – 存储节点特征在feat
字段中的图及其标签。或
dgl.data.utils.Subset
– 指定索引处的数据集子集
- __len__()
数据集中的样本数量。