CLUSTERDataset
- class dgl.data.CLUSTERDataset(mode='train', raw_dir=None, force_reload=False, verbose=False, transform=None)[源代码]
基类:
DGLBuiltinDataset
用于半监督聚类任务的 CLUSTER 数据集。
每个图包含 6 个 SBM 簇,大小在 [5, 35] 之间随机选择,概率分别为 p = 0.55, q = 0.25。图的大小在 40 到 190 个节点之间。每个节点可以取 {0, 1, 2, …, 6} 中的一个输入特征值,其中值 1~6 分别对应于类别 0~5,而值 0 表示节点类别未知。每个社区仅随机分配一个带标签的节点,并且大多数节点特征被设置为 0。
参考 https://arxiv.org/pdf/2003.00982.pdf
统计
训练样本: 10,000
验证样本: 1,000
测试样本: 1,000
每个节点的类别数: 6
- 参数:
示例
>>> from dgl.data import CLUSTERDataset >>> >>> trainset = CLUSTERDataset(mode='train') >>> >>> trainset.num_classes 6 >>> len(trainset) 10000 >>> trainset[0] Graph(num_nodes=117, num_edges=4104, ndata_schemes={'label': Scheme(shape=(), dtype=torch.int16), 'feat': Scheme(shape=(), dtype=torch.int64)} edata_schemes={'feat': Scheme(shape=(1,), dtype=torch.float32)})