AsNodePredDataset
- class dgl.data.AsNodePredDataset(dataset, split_ratio=None, target_ntype=None, **kwargs)[源码]
基类:
DGLDataset
将数据集重新用于标准的半监督转导节点预测任务。
该类将给定的数据集转换为新的数据集对象,使其满足以下条件:
仅包含一个图,可通过
dataset[0]
访问。图存储有
节点标签存储在
g.ndata['label']
中。训练/验证/测试掩码分别存储在
g.ndata['train_mask']
,g.ndata['val_mask']
和g.ndata['test_mask']
中。
此外,该数据集还包含以下属性
num_classes
,要预测的类别数量。train_idx
,val_idx
,test_idx
,训练/验证/测试索引。
如果输入数据集包含异构图,用户需要指定
target_ntype
参数来指示对哪种节点类型进行预测。在这种情况下节点标签存储在
g.nodes[target_ntype].data['label']
中。训练掩码存储在
g.nodes[target_ntype].data['train_mask']
中。验证和测试掩码也同样。
该类将保留提供的 数据集中的第一个图,并根据给定的分割比例生成训练/验证/测试掩码。生成的掩码将被缓存到磁盘,以便快速重新加载。如果提供的分割比例与缓存的比例不同,它将重新处理数据集。
- 参数:
dataset (DGLDataset) – 要转换的数据集。
split_ratio ((float, float, float), optional) – 训练集、验证集和测试集的分割比例。它们必须总和为一。
target_ntype (str, optional) – 要添加分割掩码的节点类型。
- train_idx
一个包含训练节点 ID 的一维整数张量。
- 类型:
张量
- val_idx
一个包含验证节点 ID 的一维整数张量。
- 类型:
张量
- test_idx
一个包含测试节点 ID 的一维整数张量。
- 类型:
张量
示例
>>> ds = dgl.data.AmazonCoBuyComputerDataset() >>> print(ds) Dataset("amazon_co_buy_computer", num_graphs=1, save_path=...) >>> new_ds = dgl.data.AsNodePredDataset(ds, [0.8, 0.1, 0.1]) >>> print(new_ds) Dataset("amazon_co_buy_computer-as-nodepred", num_graphs=1, save_path=...) >>> print('train_mask' in new_ds[0].ndata) True