PPIDataset

class dgl.data.PPIDataset(mode='train', raw_dir=None, force_reload=False, verbose=False, transform=None)[source]

基类:DGLBuiltinDataset

用于归纳节点分类的蛋白质-蛋白质相互作用数据集

一个用于蛋白质-蛋白质相互作用网络的玩具数据集。该数据集包含 24 个图。每个图的平均节点数为 2372。每个节点有 50 个特征和 121 个标签。其中 20 个图用于训练,2 个用于验证,2 个用于测试。

参考:http://snap.stanford.edu/graphsage/

统计信息

  • 训练样本数:20

  • 验证样本数:2

  • 测试样本数:2

参数:
  • mode (str) – 必须是 (‘train’, ‘valid’, ‘test’) 之一。默认为 ‘train’

  • raw_dir (str) – 下载/包含输入数据目录的原始文件目录。默认为 ~/.dgl/

  • force_reload (bool) – 是否重新加载数据集。默认为 False

  • verbose (bool) – 是否打印进度信息。默认为 True。

  • transform (callable, optional) – 一个接受 DGLGraph 对象并返回转换后版本的数据转换。该 DGLGraph 对象将在每次访问前进行转换。

num_labels

每个节点的标签数

类型:

int

labels

节点标签

类型:

Tensor

features

节点特征

类型:

Tensor

示例

>>> dataset = PPIDataset(mode='valid')
>>> num_classes = dataset.num_classes
>>> for g in dataset:
....    feat = g.ndata['feat']
....    label = g.ndata['label']
....    # your code here
>>>
__getitem__(item)[source]

获取第 item 个样本。

参数:

item (int) – 样本索引。

返回:

图结构、节点特征和节点标签。

  • ndata['feat']: 节点特征

  • ndata['label']: 节点标签

返回类型:

dgl.DGLGraph

__len__()[source]

返回此数据集中的样本数量。