FlickrDataset

class dgl.data.FlickrDataset(raw_dir=None, force_reload=False, verbose=False, transform=None, reorder=False)[source]

基类:DGLBuiltinDataset

用于节点分类的 Flickr 数据集,来自 GraphSAINT: Graph Sampling Based Inductive Learning Method

该数据集的任务是根据在线图像的描述和共同属性对图像类型进行分类。

Flickr 数据集统计信息

  • 节点数:89,250

  • 边数:899,756

  • 类别数:7

  • 节点特征维度:500

参数:
  • raw_dir (str) – 用于下载/包含输入数据目录的原始文件目录。默认为:~/.dgl/

  • force_reload (bool) – 是否重新加载数据集。默认为:False

  • verbose (bool) – 是否打印进度信息。默认为:False

  • transform (callable, optional) – 一个转换函数,接受一个 DGLGraph 对象并返回转换后的版本。每次访问前都会对 DGLGraph 对象进行转换。

  • reorder (bool) – 是否使用 reorder_graph() 对图进行重新排序。默认为:False。

num_classes

节点类别数

类型:

int

示例

>>> from dgl.data import FlickrDataset
>>> dataset = FlickrDataset()
>>> dataset.num_classes
7
>>> g = dataset[0]
>>> # get node feature
>>> feat = g.ndata['feat']
>>> # get node labels
>>> labels = g.ndata['label']
>>> # get data split
>>> train_mask = g.ndata['train_mask']
>>> val_mask = g.ndata['val_mask']
>>> test_mask = g.ndata['test_mask']
__getitem__(idx)[source]

获取图对象

参数:

idx (int) – 元素索引,FlickrDataset 只有一个图对象

返回:

图对象包含

  • ndata['label']: 节点标签

  • ndata['feat']: 节点特征

  • ndata['train_mask']: 训练节点集的掩码

  • ndata['val_mask']: 验证节点集的掩码

  • ndata['test_mask']: 测试节点集的掩码

返回类型:

dgl.DGLGraph

__len__()[source]

数据集中的图数量。