PubmedGraphDataset
- class dgl.data.PubmedGraphDataset(raw_dir=None, force_reload=False, verbose=True, reverse_edge=True, transform=None, reorder=False)[源文件]
基类:
CitationGraphDataset
Pubmed 引文网络数据集。
节点代表科学出版物,边代表引文关系。每个节点都有一个预定义的 500 维特征。该数据集专为节点分类任务设计。任务是预测特定出版物的类别。
统计信息
节点数:19717
边数:88651
类别数:3
标签划分
训练集:60
验证集:500
测试集:1000
- 参数:
raw_dir (str) – 用于下载/包含输入数据目录的原始文件目录。默认值:~/.dgl/
force_reload (bool) – 是否重新加载数据集。默认值:False
verbose (bool) – 是否打印进度信息。默认值:True。
reverse_edge (bool) – 是否在图中添加反向边。默认值:True。
transform (callable, optional) – 一个转换函数,接收一个
DGLGraph
对象并返回一个转换后的版本。在每次访问之前,DGLGraph
对象都会被转换。reorder (bool) – 是否使用
reorder_graph()
对图进行重新排序。默认值:False。
备注
节点特征是按行归一化的。
示例
>>> dataset = PubmedGraphDataset() >>> g = dataset[0] >>> num_class = dataset.num_of_class >>> >>> # get node feature >>> feat = g.ndata['feat'] >>> >>> # get data split >>> train_mask = g.ndata['train_mask'] >>> val_mask = g.ndata['val_mask'] >>> test_mask = g.ndata['test_mask'] >>> >>> # get labels >>> label = g.ndata['label']