CLUSTERDataset

class dgl.data.CLUSTERDataset(mode='train', raw_dir=None, force_reload=False, verbose=False, transform=None)[源代码]

基类: DGLBuiltinDataset

用于半监督聚类任务的 CLUSTER 数据集。

每个图包含 6 个 SBM 簇,大小在 [5, 35] 之间随机选择,概率分别为 p = 0.55, q = 0.25。图的大小在 40 到 190 个节点之间。每个节点可以取 {0, 1, 2, …, 6} 中的一个输入特征值,其中值 1~6 分别对应于类别 0~5,而值 0 表示节点类别未知。每个社区仅随机分配一个带标签的节点,并且大多数节点特征被设置为 0。

参考 https://arxiv.org/pdf/2003.00982.pdf

统计

  • 训练样本: 10,000

  • 验证样本: 1,000

  • 测试样本: 1,000

  • 每个节点的类别数: 6

参数:
  • mode (str) – 必须是 ('train', 'valid', 'test') 之一。默认值: 'train'

  • raw_dir (str) – 用于下载/包含输入数据目录的原始文件目录。默认值: ~/.dgl/

  • force_reload (bool) – 是否重新加载数据集。默认值: False

  • verbose (bool) – 是否打印进度信息。默认值: False

  • transform (callable, optional) – 一个转换函数,接收一个 DGLGraph 对象并返回一个转换后的版本。每次访问时都会对 DGLGraph 对象进行转换。

num_classes

每个节点的类别数。

类型:

int

示例

>>> from dgl.data import CLUSTERDataset
>>>
>>> trainset = CLUSTERDataset(mode='train')
>>>
>>> trainset.num_classes
6
>>> len(trainset)
10000
>>> trainset[0]
Graph(num_nodes=117, num_edges=4104,
      ndata_schemes={'label': Scheme(shape=(), dtype=torch.int16),
                     'feat': Scheme(shape=(), dtype=torch.int64)}
      edata_schemes={'feat': Scheme(shape=(1,), dtype=torch.float32)})
__getitem__(idx)[源代码]

获取第 idx 个样本。

参数:

idx (int) – 样本索引。

返回:

图结构、节点特征、节点标签和边特征。

  • ndata['feat']: 节点特征

  • ndata['label']: 节点标签

  • edata['feat']: 边特征

返回类型:

dgl.DGLGraph

__len__()[源代码]

数据集中的样本数。