RedditDataset
- class dgl.data.RedditDataset(self_loop=False, raw_dir=None, force_reload=False, verbose=False, transform=None)[source]
基类:
DGLBuiltinDataset
用于社区检测(节点分类)的 Reddit 数据集
这是从 Reddit 帖子(创建于 2014 年 9 月)构建的图数据集。在这种情况下,节点标签是帖子所属的社区,或称为“subreddit”。作者采样了 50 个大型社区,构建了一个帖子到帖子的图,如果同一用户评论了两个帖子,则连接它们。该数据集总共包含 232,965 个帖子,平均度为 492。我们使用前 20 天的数据进行训练,剩余天数用于测试(其中 30% 用于验证)。
参考: http://snap.stanford.edu/graphsage/
统计信息
节点数: 232,965
边数: 114,615,892
节点特征大小: 602
训练样本数: 153,431
验证样本数: 23,831
测试样本数: 55,703
- 参数:
示例
>>> data = RedditDataset() >>> g = data[0] >>> num_classes = data.num_classes >>> >>> # get node feature >>> feat = g.ndata['feat'] >>> >>> # get data split >>> train_mask = g.ndata['train_mask'] >>> val_mask = g.ndata['val_mask'] >>> test_mask = g.ndata['test_mask'] >>> >>> # get labels >>> label = g.ndata['label'] >>> >>> # Train, Validation and Test