TolokersDataset

class dgl.data.TolokersDataset(raw_dir=None, force_reload=False, verbose=True, transform=None)[源码]

基类: HeterophilousGraphDataset

Tolokers 数据集,来源于论文《异质性下 GNN 评估的批判性审视:我们真的在进步吗?https://arxiv.org/abs/2302.11640》__。

此数据集基于 Toloka 众包平台的数据。节点表示 Tolokers (工人)。如果两个 Tolokers 在同一个任务上工作过,则它们之间存在一条边。目标是预测哪些 Tolokers 在某个项目中被禁止。节点特征基于工人的资料信息和任务表现统计数据。

统计信息

  • 节点数: 11758

  • 边数: 1038000

  • 类别数: 2

  • 节点特征维度: 10

  • 10 个训练/验证/测试划分

参数:
  • raw_dir (str, 可选) – 原始文件目录,用于存储处理后的数据。默认为 ~/.dgl/

  • force_reload (bool, 可选) – 是否重新下载数据源。默认为 False

  • verbose (bool, 可选) – 是否打印进度信息。默认为 True

  • transform (callable, 可选) – 一个转换函数,接受一个 DGLGraph 对象并返回转换后的版本。每次访问前都会转换 DGLGraph 对象。默认为 None

num_classes

节点类别数

类型:

int

示例

>>> from dgl.data import TolokersDataset
>>> dataset = TolokersDataset()
>>> g = dataset[0]
>>> num_classes = dataset.num_classes
>>> # get node features
>>> feat = g.ndata["feat"]
>>> # get the first data split
>>> train_mask = g.ndata["train_mask"][:, 0]
>>> val_mask = g.ndata["val_mask"][:, 0]
>>> test_mask = g.ndata["test_mask"][:, 0]
>>> # get labels
>>> label = g.ndata['label']
__getitem__(idx)

获取指定索引处的数据对象。

__len__()

数据集中样本的数量。