RedditDataset

class dgl.data.RedditDataset(self_loop=False, raw_dir=None, force_reload=False, verbose=False, transform=None)[source]

基类: DGLBuiltinDataset

用于社区检测(节点分类)的 Reddit 数据集

这是从 Reddit 帖子(创建于 2014 年 9 月)构建的图数据集。在这种情况下,节点标签是帖子所属的社区,或称为“subreddit”。作者采样了 50 个大型社区,构建了一个帖子到帖子的图,如果同一用户评论了两个帖子,则连接它们。该数据集总共包含 232,965 个帖子,平均度为 492。我们使用前 20 天的数据进行训练,剩余天数用于测试(其中 30% 用于验证)。

参考: http://snap.stanford.edu/graphsage/

统计信息

  • 节点数: 232,965

  • 边数: 114,615,892

  • 节点特征大小: 602

  • 训练样本数: 153,431

  • 验证样本数: 23,831

  • 测试样本数: 55,703

参数:
  • self_loop (bool) – 是否加载带有自环连接的数据集。默认值: False

  • raw_dir (str) – 下载原始文件或包含输入数据目录的目录。默认值: ~/.dgl/

  • force_reload (bool) – 是否重新加载数据集。默认值: False

  • verbose (bool) – 是否打印进度信息。默认值: True。

  • transform (callable, optional) – 一个转换函数,接受一个 DGLGraph 对象并返回一个转换后的版本。每次访问 DGLGraph 对象之前都会对其进行转换。

num_classes

每个节点的类别数

类型:

int

示例

>>> data = RedditDataset()
>>> g = data[0]
>>> num_classes = data.num_classes
>>>
>>> # get node feature
>>> feat = g.ndata['feat']
>>>
>>> # get data split
>>> train_mask = g.ndata['train_mask']
>>> val_mask = g.ndata['val_mask']
>>> test_mask = g.ndata['test_mask']
>>>
>>> # get labels
>>> label = g.ndata['label']
>>>
>>> # Train, Validation and Test
__getitem__(idx)[source]

按索引获取图

参数:

idx (int) – 项目索引

返回:

图结构、节点标签、节点特征和分割掩码

  • ndata['label']: 节点标签

  • ndata['feat']: 节点特征

  • ndata['train_mask']: 训练节点集的掩码

  • ndata['val_mask']: 验证节点集的掩码

  • ndata['test_mask']: 测试节点集的掩码

返回类型:

dgl.DGLGraph

__len__()[source]

数据集中图的数量