add_nodepred_split

class dgl.data.utils.add_nodepred_split(dataset, ratio, ntype=None)[源]

基类

将给定数据集划分为训练集、验证集和测试集,用于直推式节点预测任务。

它会向数据集中的每个图添加三个节点掩码数组 'train_mask', 'val_mask''test_mask'。因此,数据集中的每个样本都必须是 DGLGraph 实例。

固定 NumPy 的随机种子以使结果确定

numpy.random.seed(42)
参数:
  • dataset (DGLDataset) – 要修改的数据集。

  • ratio ((float, float, float)) – 训练集、验证集和测试集的分割比例。总和必须为一。

  • ntype (str, 可选) – 要添加掩码的节点类型。

示例

>>> dataset = dgl.data.AmazonCoBuyComputerDataset()
>>> print('train_mask' in dataset[0].ndata)
False
>>> dgl.data.utils.add_nodepred_split(dataset, [0.8, 0.1, 0.1])
>>> print('train_mask' in dataset[0].ndata)
True