PATTERNDataset
- class dgl.data.PATTERNDataset(mode='train', raw_dir=None, force_reload=False, verbose=False, transform=None)[源码]
基类:
DGLBuiltinDataset
用于图模式识别任务的 PATTERN 数据集。
每个图 G 包含 5 个社区,大小在 [5, 35] 之间随机选择。每个社区的 SBM 参数为 p = 0.5, q = 0.35。G 的节点特征使用大小为 3 的词汇表(即 {0, 1, 2})进行均匀随机分布生成。然后随机生成 100 个由 20 个节点组成的模式 P,其内部连接概率为 \(p_P\) = 0.5,外部连接概率为 \(q_P\) = 0.5(即 P 中 50% 的节点连接到 G)。P 的节点特征也生成为值 {0, 1, 2} 的随机信号。图的大小范围为 44-188 个节点。输出节点标签的值为 1 如果节点属于 P,值为 0 如果节点在 G 中。
参考 https://arxiv.org/pdf/2003.00982.pdf
统计信息
训练样本数: 10,000
验证样本数: 2,000
测试样本数: 2,000
每个节点的类别数: 2
- 参数:
示例
>>> from dgl.data import PATTERNDataset >>> data = PATTERNDataset(mode='train') >>> data.num_classes 2 >>> len(trainset) 10000 >>> data[0] Graph(num_nodes=108, num_edges=4884, ndata_schemes={'feat': Scheme(shape=(), dtype=torch.int64), 'label': Scheme(shape=(), dtype=torch.int16)} edata_schemes={'feat': Scheme(shape=(1,), dtype=torch.float32)})