BAShapeDataset
- class dgl.data.BAShapeDataset(num_base_nodes=300, num_base_edges_per_node=5, num_motifs=80, perturb_ratio=0.01, seed=None, raw_dir=None, force_reload=False, verbose=True, transform=None)[源码]
基类:
DGLBuiltinDataset
BA-SHAPES 数据集,来自 GNNExplainer: Generating Explanations for Graph Neural Networks
这是一个用于节点分类的合成数据集。它通过按以下步骤生成:
构建一个基础的 Barabási–Albert (BA) 图。
构建一组五节点房屋结构的子图 (motifs)。
将这些子图附加到基础图的随机选择的节点上。
通过添加随机边来扰动图。
节点被分配到 4 个类别。标签为 0 的节点属于基础 BA 图。标签为 1、2、3 的节点分别位于房屋结构的中间、底部或顶部。
为所有节点生成恒定特征,其值为 1。
- 参数:
num_base_nodes (int, optional) – 基础 BA 图中的节点数量。默认值: 300
num_base_edges_per_node (int, optional) – 在构建基础 BA 图时,从新节点附加到现有节点的边数。默认值: 5
num_motifs (int, optional) – 使用的房屋结构子图数量。默认值: 80
perturb_ratio (float, optional) – 扰动中添加的随机边数除以原始图中的边数。默认值: 0.01
seed (integer, random_state, or None, optional) – 随机数生成状态的指示器。默认值: None
raw_dir (str, optional) – 存储处理后数据的原始文件目录。默认值: ~/.dgl/
force_reload (bool, optional) – 是否总是从头开始生成数据,而不是加载缓存版本。默认值: False
verbose (bool, optional) – 是否打印进度信息。默认值: True
transform (callable, optional) – 一个转换函数,它接收一个
DGLGraph
对象并返回一个转换后的版本。每次访问DGLGraph
对象之前都会对其进行转换。默认值: None
示例
>>> from dgl.data import BAShapeDataset >>> dataset = BAShapeDataset() >>> dataset.num_classes 4 >>> g = dataset[0] >>> label = g.ndata['label'] >>> feat = g.ndata['feat']