PubmedGraphDataset

class dgl.data.PubmedGraphDataset(raw_dir=None, force_reload=False, verbose=True, reverse_edge=True, transform=None, reorder=False)[源文件]

基类:CitationGraphDataset

Pubmed 引文网络数据集。

节点代表科学出版物,边代表引文关系。每个节点都有一个预定义的 500 维特征。该数据集专为节点分类任务设计。任务是预测特定出版物的类别。

统计信息

  • 节点数:19717

  • 边数:88651

  • 类别数:3

  • 标签划分

    • 训练集:60

    • 验证集:500

    • 测试集:1000

参数:
  • raw_dir (str) – 用于下载/包含输入数据目录的原始文件目录。默认值:~/.dgl/

  • force_reload (bool) – 是否重新加载数据集。默认值:False

  • verbose (bool) – 是否打印进度信息。默认值:True。

  • reverse_edge (bool) – 是否在图中添加反向边。默认值:True。

  • transform (callable, optional) – 一个转换函数,接收一个 DGLGraph 对象并返回一个转换后的版本。在每次访问之前,DGLGraph 对象都会被转换。

  • reorder (bool) – 是否使用 reorder_graph() 对图进行重新排序。默认值:False。

num_classes

标签类别数

类型:

int

备注

节点特征是按行归一化的。

示例

>>> dataset = PubmedGraphDataset()
>>> g = dataset[0]
>>> num_class = dataset.num_of_class
>>>
>>> # get node feature
>>> feat = g.ndata['feat']
>>>
>>> # get data split
>>> train_mask = g.ndata['train_mask']
>>> val_mask = g.ndata['val_mask']
>>> test_mask = g.ndata['test_mask']
>>>
>>> # get labels
>>> label = g.ndata['label']
__getitem__(idx)[源文件]

获取图对象

参数:

idx (int) – 项目索引,PubmedGraphDataset 只包含一个图对象

返回值:

图结构、节点特征和标签。

  • ndata['train_mask']:训练节点集的掩码

  • ndata['val_mask']:验证节点集的掩码

  • ndata['test_mask']:测试节点集的掩码

  • ndata['feat']:节点特征

  • ndata['label']:真实标签

返回类型:

dgl.DGLGraph

__len__()[源文件]

数据集中的图数量。