CiteseerGraphDataset
- class dgl.data.CiteseerGraphDataset(raw_dir=None, force_reload=False, verbose=True, reverse_edge=True, transform=None, reorder=False)[source]
基类:
CitationGraphDataset
Citeseer 引文网络数据集。
节点代表科学出版物,边代表引文关系。每个节点都有一个预定义的特征,维度为 3703。该数据集设计用于节点分类任务。任务是预测特定出版物的类别。
统计信息
节点数: 3327
边数: 9228
类别数: 6
标签划分
训练集: 120
验证集: 500
测试集: 1000
- 参数:
raw_dir (str) – 用于下载/包含输入数据目录的原始文件目录。默认值: ~/.dgl/
force_reload (bool) – 是否重新加载数据集。默认值: False
verbose (bool) – 是否打印进度信息。默认值: True。
reverse_edge (bool) – 是否在图中添加反向边。默认值: True。
transform (callable, optional) – 一个转换函数,输入一个
DGLGraph
对象并返回一个转换后的版本。每次访问DGLGraph
对象之前都会对其进行转换。reorder (bool) – 是否使用
reorder_graph()
对图进行重新排序。默认值: False。
说明
节点特征是按行归一化的。
在 citeseer 数据集中,图中有一些孤立节点。这些孤立节点作为零向量被添加到正确的位置。
示例
>>> dataset = CiteseerGraphDataset() >>> g = dataset[0] >>> num_class = dataset.num_classes >>> >>> # get node feature >>> feat = g.ndata['feat'] >>> >>> # get data split >>> train_mask = g.ndata['train_mask'] >>> val_mask = g.ndata['val_mask'] >>> test_mask = g.ndata['test_mask'] >>> >>> # get labels >>> label = g.ndata['label']