PATTERNDataset

class dgl.data.PATTERNDataset(mode='train', raw_dir=None, force_reload=False, verbose=False, transform=None)[源码]

基类: DGLBuiltinDataset

用于图模式识别任务的 PATTERN 数据集。

每个图 G 包含 5 个社区,大小在 [5, 35] 之间随机选择。每个社区的 SBM 参数为 p = 0.5, q = 0.35。G 的节点特征使用大小为 3 的词汇表(即 {0, 1, 2})进行均匀随机分布生成。然后随机生成 100 个由 20 个节点组成的模式 P,其内部连接概率为 \(p_P\) = 0.5,外部连接概率为 \(q_P\) = 0.5(即 P 中 50% 的节点连接到 G)。P 的节点特征也生成为值 {0, 1, 2} 的随机信号。图的大小范围为 44-188 个节点。输出节点标签的值为 1 如果节点属于 P,值为 0 如果节点在 G 中。

参考 https://arxiv.org/pdf/2003.00982.pdf

统计信息

  • 训练样本数: 10,000

  • 验证样本数: 2,000

  • 测试样本数: 2,000

  • 每个节点的类别数: 2

参数:
  • mode (str) – 必须是 ('train', 'valid', 'test') 之一。默认值: 'train'

  • raw_dir (str) – 下载/包含输入数据目录的原始文件目录。默认值: ~/.dgl/

  • force_reload (bool) – 是否重新加载数据集。默认值: False

  • verbose (bool) – 是否打印进度信息。默认值: False

  • transform (可调用对象, 可选) – 一个转换函数,接受一个 DGLGraph 对象并返回一个转换后的版本。每次访问前,DGLGraph 对象都会被转换。

num_classes

每个节点的类别数。

类型:

int

示例

>>> from dgl.data import PATTERNDataset
>>> data = PATTERNDataset(mode='train')
>>> data.num_classes
2
>>> len(trainset)
10000
>>> data[0]
Graph(num_nodes=108, num_edges=4884, ndata_schemes={'feat': Scheme(shape=(), dtype=torch.int64), 'label': Scheme(shape=(), dtype=torch.int16)}
edata_schemes={'feat': Scheme(shape=(1,), dtype=torch.float32)})
__getitem__(idx)[源码]

获取第 idx 个样本。

参数:

idx (int) – 样本索引。

返回:

图结构、节点特征、节点标签和边特征。

  • ndata['feat']: 节点特征

  • ndata['label']: 节点标签

  • edata['feat']: 边特征

返回类型:

dgl.DGLGraph

__len__()[源码]

数据集中的样本数量。