WikiCSDataset
- class dgl.data.WikiCSDataset(raw_dir=None, force_reload=False, verbose=False, transform=None)
Bases: DGLBuiltinDataset
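Wiki-CS is a Wikipedia-based node classification dataset introduced in "Wiki-CS: A Wikipedia-Based Benchmark for Graph Neural Networks".
Its nodes correspond to Computer Science articles, its edges are derived from hyperlinks between those articles, and its 10 classes represent different branches of the field.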
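WikiCS dataset statistics:
- Nodes: 11,701
- Edges: 431,726 (the original dataset has 216,123 edges; DGL adds reverse edges and removes duplicate edges, so the count differs)
- Number of classes: 10
- Node feature dimension: 300
- Number of different train/validation/stopping splits: 20
- Number of test splits: 1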
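Why the edge count is not simply twice the original can be shown with a small sketch: adding the reverse of every edge and then dropping duplicates yields fewer than 2x edges whenever some edges already appear in both directions or appear more than once. The helper `symmetrize_edges` and the toy edge list are hypothetical and for illustration only; this is not DGL's actual preprocessing code (DGL does provide `dgl.to_bidirected`, which builds a similar simple bidirected graph).

```python
# Hypothetical helper (illustration only, not DGL's implementation):
# add every reverse edge, then keep each (src, dst) pair once.
def symmetrize_edges(edges):
    """Return the sorted unique (src, dst) pairs after adding all reverse edges."""
    both_directions = set()
    for src, dst in edges:
        both_directions.add((src, dst))  # keep the original direction
        both_directions.add((dst, src))  # add the reverse direction
    return sorted(both_directions)


# Toy edge list: (0, 1) and (1, 0) are already reciprocal, and (1, 2) is listed twice.
raw_edges = [(0, 1), (1, 0), (1, 2), (1, 2), (2, 3)]
sym_edges = symmetrize_edges(raw_edges)
print(len(raw_edges), len(sym_edges))  # 5 6
print(sym_edges)
```

In the toy input, naive doubling would give 10 edges; the reciprocal pair (0, 1)/(1, 0) and the repeated (1, 2) collapse that to 6 unique directed edges.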
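- Parameters:
  - raw_dir (str): Directory that stores, or will receive, the downloaded input data. Default: None, in which case DGL uses ~/.dgl/
  - force_reload (bool): Whether to reload the dataset. Default: False
  - verbose (bool): Whether to print progress information. Default: False
  - transform (callable, optional): A transform that takes in a DGLGraph object and returns a transformed version. The DGLGraph object is transformed before every access.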
Examples
```python
>>> from dgl.data import WikiCSDataset
>>> dataset = WikiCSDataset()
>>> dataset.num_classes
10
>>> g = dataset[0]
>>> # get node features
>>> feat = g.ndata['feat']
>>> # get node labels
>>> labels = g.ndata['label']
>>> # get data splits
>>> train_mask = g.ndata['train_mask']
>>> val_mask = g.ndata['val_mask']
>>> stopping_mask = g.ndata['stopping_mask']
>>> test_mask = g.ndata['test_mask']
>>> # The train, validation and stopping masks have shape (num_nodes, num_splits),
>>> # where num_splits is the number of different train/validation/stopping splits.
>>> # Because there is only one test split, the test mask has shape (num_nodes,).
>>> print(train_mask.shape, val_mask.shape, stopping_mask.shape)
(11701, 20) (11701, 20) (11701, 20)
>>> print(test_mask.shape)
(11701,)
```
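Because the train, validation and stopping masks carry one column per split, an experiment usually selects one column at a time, while the single 1-D test mask is shared by every split. The sketch below uses plain Python lists as a stand-in for the mask tensors so that it runs without downloading the dataset; `make_toy_mask` and `split_indices` are hypothetical helpers, and with the real DGL graph the same column selection would be written as `g.ndata['train_mask'][:, split]`.

```python
# Toy stand-in (hypothetical) for a (num_nodes, num_splits) boolean mask,
# modelled as one row of booleans per node.
NUM_NODES = 6
NUM_SPLITS = 3


def make_toy_mask(num_nodes, num_splits):
    """Build a toy 2-D mask: node n belongs to split s when (n + s) % 3 == 0."""
    return [[(n + s) % 3 == 0 for s in range(num_splits)] for n in range(num_nodes)]


def split_indices(mask_2d, split):
    """Return the node IDs selected by column `split` of a 2-D mask.

    With the real dataset this mirrors g.ndata['train_mask'][:, split].
    """
    return [node for node, row in enumerate(mask_2d) if row[split]]


train_mask = make_toy_mask(NUM_NODES, NUM_SPLITS)
for split in range(NUM_SPLITS):
    # One model would be trained and evaluated per split inside this loop.
    print(split, split_indices(train_mask, split))
```

Running it prints `0 [0, 3]`, `1 [2, 5]` and `2 [1, 4]`, one line per split column.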