ZINCDataset

class dgl.data.ZINCDataset(mode='train', raw_dir=None, force_reload=False, verbose=False, transform=None)[source]

基类:DGLBuiltinDataset

用于图回归任务的 ZINC 数据集。

使用 ZINC 分子图数据集(250K)的一个子集(12K)来回归称为约束溶解度的分子属性。对于每个分子图,节点特征是重原子的类型,边特征是键的类型。每个图包含 9-37 个节点和 16-84 条边。

参考 https://arxiv.org/pdf/2003.00982.pdf

统计数据

训练样本:10,000 验证样本:1,000 测试样本:1,000 平均节点数:23.16 平均边数:39.83 原子类型数:28 键类型数:4

参数:
  • mode (str, 可选) – 必须从 [“train”, “valid”, “test”] 中选择。默认为 “train”。

  • raw_dir (str) – 用于下载或包含输入数据文件的原始文件目录。默认为 “~/.dgl/”。

  • force_reload (bool) – 是否重新加载数据集。默认为 False。

  • verbose (bool) – 是否打印进度信息。默认为 False。

  • transform (callable, 可选) – 一个转换函数,接收一个 DGLGraph 对象并返回一个转换后的版本。在每次访问 DGLGraph 对象之前,都会对其进行转换。

num_atom_types

原子类型数。

类型:

int

num_bond_types

键类型数。

类型:

int

示例

>>> from dgl.data import ZINCDataset
>>> training_set = ZINCDataset(mode="train")
>>> training_set.num_atom_types
28
>>> len(training_set)
10000
>>> graph, label = training_set[0]
>>> graph
Graph(num_nodes=29, num_edges=64,
    ndata_schemes={'feat': Scheme(shape=(), dtype=torch.int64)}
    edata_schemes={'feat': Scheme(shape=(), dtype=torch.int64)})
__getitem__(idx)[source]

按索引获取一个样本。

参数:

idx (int) – 样本索引。

返回值:

  • dgl.DGLGraph – 每个图包含

    • ndata['feat']:作为节点特征的重原子类型

    • edata['feat']:作为边特征的键类型

  • Tensor – 作为图标签的约束溶解度

__len__()[source]

数据集中的样本数量。