Questions 数据集

class dgl.data.QuestionsDataset(raw_dir=None, force_reload=False, verbose=True, transform=None)[源码]

基类: HeterophilousGraphDataset

来自论文《A Critical Look at the Evaluation of GNNs under Heterophily: Are We Really Making Progress? https://arxiv.org/abs/2302.11640》的 Questions 数据集。

此数据集基于问答网站 Yandex Q 的数据。节点是用户,如果一个用户回答了另一个用户的问题,则连接这两个节点。任务是预测哪些用户在网站上保持活跃(未被删除或阻止)。节点特征是用户描述中单词嵌入的平均值。没有描述的用户由一个单独的二进制特征表示。

统计信息

  • 节点数:48921

  • 边数:307080

  • 类别数:2

  • 节点特征维度:301

  • 10个训练/验证/测试划分

参数:
  • raw_dir (str, optional) – 存储处理后的原始文件目录。默认值:~/.dgl/

  • force_reload (bool, optional) – 是否重新下载数据源。默认值:False

  • verbose (bool, optional) – 是否打印进度信息。默认值:True

  • transform (callable, optional) – 一个转换函数,输入一个 DGLGraph 对象并返回一个转换后的版本。每次访问 DGLGraph 对象之前都会对其进行转换。默认值:None

num_classes

节点类别数

类型:

int

示例

>>> from dgl.data import QuestionsDataset
>>> dataset = QuestionsDataset()
>>> g = dataset[0]
>>> num_classes = dataset.num_classes
>>> # get node features
>>> feat = g.ndata["feat"]
>>> # get the first data split
>>> train_mask = g.ndata["train_mask"][:, 0]
>>> val_mask = g.ndata["val_mask"][:, 0]
>>> test_mask = g.ndata["test_mask"][:, 0]
>>> # get labels
>>> label = g.ndata['label']
__getitem__(idx)

获取指定索引的数据对象。

__len__()

数据集中的样本数量。