TreeCycleDataset

class dgl.data.TreeCycleDataset(tree_height=8, num_motifs=60, cycle_size=6, perturb_ratio=0.01, seed=None, raw_dir=None, force_reload=False, verbose=True, transform=None)[source]

基类:DGLBuiltinDataset

TREE-CYCLES 数据集,来源于 GNNExplainer: Generating Explanations for Graph Neural Networks

这是一个用于节点分类的合成数据集。它按照以下步骤顺序生成。

  • 构建一个平衡二叉树作为基础图。

  • 构建一组环状结构。

  • 将这些结构附加到基础图的随机选择的节点上。

  • 通过添加随机边来扰乱图。

  • 为所有节点生成常数特征,其值为 1。

  • 树中的节点属于类别 0,环中的节点属于类别 1。

参数:
  • tree_height (int, 可选) – 平衡二叉树的高度。默认为:8

  • num_motifs (int, 可选) – 使用的环状结构数量。默认为:60

  • cycle_size (int, 可选) – 环状结构中的节点数量。默认为:6

  • perturb_ratio (float, 可选) – 扰动中添加的随机边数除以图中原始边数的比例。默认为:0.01

  • seed (integer, random_stateNone, 可选) – 随机数生成状态的指示符。默认为:None

  • raw_dir (str, 可选) – 存储处理后的数据的原始文件目录。默认为:~/.dgl/

  • force_reload (bool, 可选) – 是否总是从头开始生成数据,而不是加载缓存版本。默认为:False

  • verbose (bool, 可选) – 是否打印进度信息。默认为:True

  • transform (callable, 可选) – 一个转换函数,它接受一个 DGLGraph 对象并返回一个转换后的版本。每次访问 DGLGraph 对象时,它都会被转换。默认为:None

num_classes

节点类别数量

类型:

int

示例

>>> from dgl.data import TreeCycleDataset
>>> dataset = TreeCycleDataset()
>>> dataset.num_classes
2
>>> g = dataset[0]
>>> label = g.ndata['label']
>>> feat = g.ndata['feat']
__getitem__(idx)[source]

获取指定索引的数据对象。

__len__()[source]

数据集中示例的数量。