MinesweeperDataset

class dgl.data.MinesweeperDataset(raw_dir=None, force_reload=False, verbose=True, transform=None)[源代码]

基类:HeterophilousGraphDataset

来自论文《对异质性下 GNN 评估的批判性审视:我们真的在进步吗?https://arxiv.org/abs/2302.11640》的 Minesweeper 数据集。

此数据集灵感来源于扫雷游戏。图是一个规则的 100x100 网格,其中每个节点(单元格)连接到八个相邻节点(网格边缘的节点除外,它们有较少的邻居)。20% 的节点被随机选为地雷。任务是预测哪些节点是地雷。节点特征是相邻地雷数量的独热编码。然而,对于随机选取的 50% 节点,其特征是未知的,这由一个单独的二元特征指示。

统计数据

  • 节点数:10000

  • 边数:78804

  • 类别数:2

  • 节点特征数:7

  • 10 个训练/验证/测试分割

参数
  • raw_dir (str, 可选) – 用于存储处理后的数据的原始文件目录。默认值:~/.dgl/

  • force_reload (bool, 可选) – 是否重新下载数据源。默认值:False

  • verbose (bool, 可选) – 是否打印进度信息。默认值:True

  • transform (callable, 可选) – 一个转换函数,接受一个 DGLGraph 对象并返回一个转换后的版本。DGLGraph 对象将在每次访问前被转换。默认值:None

num_classes

节点类别数

类型

int

示例

>>> from dgl.data import MinesweeperDataset
>>> dataset = MinesweeperDataset()
>>> g = dataset[0]
>>> num_classes = dataset.num_classes
>>> # get node features
>>> feat = g.ndata["feat"]
>>> # get the first data split
>>> train_mask = g.ndata["train_mask"][:, 0]
>>> val_mask = g.ndata["val_mask"][:, 0]
>>> test_mask = g.ndata["test_mask"][:, 0]
>>> # get labels
>>> label = g.ndata['label']
__getitem__(idx)

获取索引处的数据对象。

__len__()

数据集中的示例数量。