AMDataset

dgl.data.AMDataset(print_every=10000, insert_reverse=True, raw_dir=None, force_reload=False, verbose=True, transform=None)[源代码]

基类: RDFGraphDataset

AM 数据集,用于节点分类任务

命名空间约定

  • 实例: http://purl.org/collections/nl/am/<type>-<id>

  • 关系: http://purl.org/collections/nl/am/<name>

我们在输出图中忽略了所有文字节点以及连接它们的关​​系。

AM 数据集统计信息

  • 节点数: 881680

  • 边数: 5668682 (包含反向边)

  • 目标类别: proxy

  • 类别数: 11

  • 标签划分

    • 训练集: 802

    • 测试集: 198

参数:
  • print_every (int) – 每处理 X 个三元组时打印预处理日志。默认值: 10000。

  • insert_reverse (bool) – 如果为 True,则在最终图中添加反向边和反向关系。默认值: True。

  • raw_dir (str) – 原始文件目录,用于下载或包含输入数据目录。默认值: ~/.dgl/

  • force_reload (bool) – 是否重新加载数据集。默认值: False

  • verbose (bool) – 是否打印进度信息。默认值: True。

  • transform (callable, optional) – 一个转换函数,接受一个 DGLGraph 对象并返回其转换后的版本。每次访问时,DGLGraph 对象都会被转换。

num_classes

要预测的类别数

类型:

int

predict_category

具有预测标签的实体类别(节点类型)

类型:

str

示例

>>> dataset = dgl.data.rdf.AMDataset()
>>> graph = dataset[0]
>>> category = dataset.predict_category
>>> num_classes = dataset.num_classes
>>>
>>> train_mask = g.nodes[category].data['train_mask']
>>> test_mask = g.nodes[category].data['test_mask']
>>> label = g.nodes[category].data['label']
__getitem__(idx)[源代码]

获取图对象

参数:

idx (int) – 项目索引,AMDataset 只有一个图对象

返回:

图包含

  • ndata['train_mask']: 训练节点集的掩码

  • ndata['test_mask']: 测试节点集的掩码

  • ndata['label']: 节点标签

返回类型:

dgl.DGLGraph

__len__()[源代码]

数据集中的图数量。

返回类型:

int