AsNodePredDataset

class dgl.data.AsNodePredDataset(dataset, split_ratio=None, target_ntype=None, **kwargs)[源码]

基类: DGLDataset

将数据集重新用于标准的半监督转导节点预测任务。

该类将给定的数据集转换为新的数据集对象,使其满足以下条件:

  • 仅包含一个图,可通过 dataset[0] 访问。

  • 图存储有

    • 节点标签存储在 g.ndata['label'] 中。

    • 训练/验证/测试掩码分别存储在 g.ndata['train_mask'], g.ndata['val_mask']g.ndata['test_mask'] 中。

  • 此外,该数据集还包含以下属性

    • num_classes,要预测的类别数量。

    • train_idx, val_idx, test_idx,训练/验证/测试索引。

如果输入数据集包含异构图,用户需要指定 target_ntype 参数来指示对哪种节点类型进行预测。在这种情况下

  • 节点标签存储在 g.nodes[target_ntype].data['label'] 中。

  • 训练掩码存储在 g.nodes[target_ntype].data['train_mask'] 中。验证和测试掩码也同样。

该类将保留提供的 数据集中的第一个图,并根据给定的分割比例生成训练/验证/测试掩码。生成的掩码将被缓存到磁盘,以便快速重新加载。如果提供的分割比例与缓存的比例不同,它将重新处理数据集。

参数:
  • dataset (DGLDataset) – 要转换的数据集。

  • split_ratio ((float, float, float), optional) – 训练集、验证集和测试集的分割比例。它们必须总和为一。

  • target_ntype (str, optional) – 要添加分割掩码的节点类型。

num_classes

要预测的类别数量。

类型:

int

train_idx

一个包含训练节点 ID 的一维整数张量。

类型:

张量

val_idx

一个包含验证节点 ID 的一维整数张量。

类型:

张量

test_idx

一个包含测试节点 ID 的一维整数张量。

类型:

张量

示例

>>> ds = dgl.data.AmazonCoBuyComputerDataset()
>>> print(ds)
Dataset("amazon_co_buy_computer", num_graphs=1, save_path=...)
>>> new_ds = dgl.data.AsNodePredDataset(ds, [0.8, 0.1, 0.1])
>>> print(new_ds)
Dataset("amazon_co_buy_computer-as-nodepred", num_graphs=1, save_path=...)
>>> print('train_mask' in new_ds[0].ndata)
True
__getitem__(idx)[源码]

获取索引处的数据对象。

__len__()[源码]

数据集中的样本数量。