WikiCSDataset

class dgl.data.WikiCSDataset(raw_dir=None, force_reload=False, verbose=False, transform=None)[source]

基类:DGLBuiltinDataset

Wiki-CS 是一个基于维基百科的节点分类数据集,来源于 Wiki-CS: Graph Neural Networks 的维基百科基准

该数据集包含对应于计算机科学文章的节点,边基于超链接,以及表示该领域不同分支的 10 个类别。

WikiCS 数据集统计信息

  • 节点数:11,701

  • 边数:431,726(请注意,原始数据集有 216,123 条边,但 DGL 添加了反向边并移除了重复边,因此数量不同)

  • 类别数:10

  • 节点特征维度:300

  • 不同的训练、验证、停止划分数量:20

  • 测试划分数量:1

参数
  • raw_dir (str) – 下载/包含输入数据目录的原始文件目录。默认值:~/.dgl/

  • force_reload (bool) – 是否重新加载数据集。默认值:False

  • verbose (bool) – 是否打印进度信息。默认值:False

  • transform (callable, optional) – 一个转换函数,它接收一个 DGLGraph 对象并返回一个转换后的版本。每次访问 DGLGraph 对象之前都会对其进行转换。

num_classes

节点类别数

类型

int

示例

>>> from dgl.data import WikiCSDataset
>>> dataset = WikiCSDataset()
>>> dataset.num_classes
10
>>> g = dataset[0]
>>> # get node feature
>>> feat = g.ndata['feat']
>>> # get node labels
>>> labels = g.ndata['label']
>>> # get data split
>>> train_mask = g.ndata['train_mask']
>>> val_mask = g.ndata['val_mask']
>>> stopping_mask = g.ndata['stopping_mask']
>>> test_mask = g.ndata['test_mask']
>>> # The shape of train, val and stopping masks are (num_nodes, num_splits).
>>> # The num_splits is the number of different train, validation, stopping splits.
>>> # Due to the number of test spilt is 1, the shape of test mask is (num_nodes,).
>>> print(train_mask.shape, val_mask.shape, stopping_mask.shape)
(11701, 20) (11701, 20) (11701, 20)
>>> print(test_mask.shape)
(11701,)
__getitem__(idx)[source]

获取图对象

参数

idx (int) – 项目索引,WikiCSDataset 只包含一个图对象

返回

图对象包含

  • ndata['feat']:节点特征

  • ndata['label']:节点标签

  • ndata['train_mask']:训练掩码,用于获取训练节点。

  • ndata['val_mask']:验证掩码,用于获取进行超参数调优的节点。

  • ndata['stopping_mask']:停止掩码,用于获取进行早停判断的节点。

  • ndata['test_mask']:测试掩码,用于获取测试节点。

返回类型

dgl.DGLGraph

__len__()[source]

数据集中图的数量。