TreeGridDataset
- class dgl.data.TreeGridDataset(tree_height=8, num_motifs=80, grid_size=3, perturb_ratio=0.1, seed=None, raw_dir=None, force_reload=False, verbose=True, transform=None)[source]
基类:
DGLBuiltinDataset
来自 GNNExplainer: Generating Explanations for Graph Neural Networks 的 TREE-GRIDS 数据集
这是一个用于节点分类的合成数据集。它通过按顺序执行以下步骤生成。
构建一个平衡二叉树作为基础图。
构建一组 n 乘 n 的网格图案(motif)。
将图案附加到基础图上随机选择的节点。
通过添加随机边来扰动图。
为所有节点生成常数特征,该特征值为 1。
树中的节点属于类别 0,网格中的节点属于类别 1。
- 参数:
tree_height (int, optional) – 平衡二叉树的高度。默认值:8
num_motifs (int, optional) – 使用的网格图案(motif)数量。默认值:80
grid_size (int, optional) – 网格图案中的节点数量为 grid_size ^ 2。默认值:3
perturb_ratio (float, optional) – 扰动中添加的随机边数量除以图中原始边的数量。默认值:0.1
seed (integer, random_state, or None, optional) – 随机数生成状态的指示符。默认值:None
raw_dir (str, optional) – 存储处理后数据的原始文件目录。默认值:~/.dgl/
force_reload (bool, optional) – 是否总是从头开始生成数据,而不是加载缓存的版本。默认值:False
verbose (bool, optional) – 是否打印进度信息。默认值:True
transform (callable, optional) – 一个接受
DGLGraph
对象并返回转换后版本的转换器。每次访问时,DGLGraph
对象都将被转换。默认值:None
示例
>>> from dgl.data import TreeGridDataset >>> dataset = TreeGridDataset() >>> dataset.num_classes 2 >>> g = dataset[0] >>> label = g.ndata['label'] >>> feat = g.ndata['feat']