CoraGraphDataset
- class dgl.data.CoraGraphDataset(raw_dir=None, force_reload=False, verbose=True, reverse_edge=True, transform=None, reorder=False)[source]
基类:
CitationGraphDataset
Cora 引文网络数据集。
节点代表论文,边代表引文关系。每个节点都有一个预定义的特征,维度为 1433。此数据集设计用于节点分类任务。任务是预测特定论文的类别。
统计信息
节点数: 2708
边数: 10556
类别数: 7
标签划分
训练集: 140
验证集: 500
测试集: 1000
- 参数:
raw_dir (str) – 用于下载/包含输入数据目录的原始文件目录。默认值: ~/.dgl/
force_reload (bool) – 是否重新加载数据集。默认值: False
verbose (bool) – 是否打印进度信息。默认值: True。
reverse_edge (bool) – 是否在图中添加反向边。默认值: True。
transform (callable, optional) – 一个转换函数,接受
DGLGraph
对象并返回转换后的版本。每次访问DGLGraph
对象之前都会进行转换。reorder (bool) – 是否使用
reorder_graph()
重新排序图。默认值: False。
说明
节点特征按行归一化。
示例
>>> dataset = CoraGraphDataset() >>> g = dataset[0] >>> num_class = dataset.num_classes >>> >>> # get node feature >>> feat = g.ndata['feat'] >>> >>> # get data split >>> train_mask = g.ndata['train_mask'] >>> val_mask = g.ndata['val_mask'] >>> test_mask = g.ndata['test_mask'] >>> >>> # get labels >>> label = g.ndata['label']