add_nodepred_split
- class dgl.data.utils.add_nodepred_split(dataset, ratio, ntype=None)[源]
基类
将给定数据集划分为训练集、验证集和测试集,用于直推式节点预测任务。
它会向数据集中的每个图添加三个节点掩码数组
'train_mask'
,'val_mask'
和'test_mask'
。因此,数据集中的每个样本都必须是DGLGraph
实例。固定 NumPy 的随机种子以使结果确定
numpy.random.seed(42)
- 参数:
dataset (DGLDataset) – 要修改的数据集。
ntype (str, 可选) – 要添加掩码的节点类型。
示例
>>> dataset = dgl.data.AmazonCoBuyComputerDataset() >>> print('train_mask' in dataset[0].ndata) False >>> dgl.data.utils.add_nodepred_split(dataset, [0.8, 0.1, 0.1]) >>> print('train_mask' in dataset[0].ndata) True