RomanEmpireDataset

class dgl.data.RomanEmpireDataset(raw_dir=None, force_reload=False, verbose=True, transform=None)[源码]

基类:HeterophilousGraphDataset

来自论文《A Critical Look at the Evaluation of GNNs under Heterophily: Are We Really Making Progress? <https://arxiv.org/abs/2302.11640>》的 Roman-empire 数据集。

该数据集基于英文维基百科中关于罗马帝国的文章,选择这篇文章是因为它是维基百科中最长的文章之一。图中的每个节点对应文本中的一个(非唯一)词。因此,图中的节点数等于文章的长度。当满足以下两个条件中至少一个时,两个词通过一条边连接:这些词在文本中相互紧随,或者这些词在句子的依存树中相连(一个词依赖于另一个词)。因此,该图是一个链式图,带有对应于词之间句法依存关系的其他快捷边。节点的类别是其句法角色(选择了 17 个最常见的角色作为唯一类别,所有其他角色被归为第 18 类)。节点特征是词嵌入。

统计信息

  • 节点数:22662

  • 边数:65854

  • 类别数:18

  • 节点特征:300

  • 10 个训练/验证/测试划分

参数:
  • raw_dir (str, 可选) – 存储处理后数据的原始文件目录。默认值:~/.dgl/

  • force_reload (bool, 可选) – 是否重新下载数据源。默认值:False

  • verbose (bool, 可选) – 是否打印进度信息。默认值:True

  • transform (callable, 可选) – 一个转换函数,接受一个 DGLGraph 对象并返回一个转换后的版本。每次访问 DGLGraph 对象之前都会对其进行转换。默认值:None

num_classes

节点类别数

类型:

int

示例

>>> from dgl.data import RomanEmpireDataset
>>> dataset = RomanEmpireDataset()
>>> g = dataset[0]
>>> num_classes = dataset.num_classes
>>> # get node features
>>> feat = g.ndata["feat"]
>>> # get the first data split
>>> train_mask = g.ndata["train_mask"][:, 0]
>>> val_mask = g.ndata["val_mask"][:, 0]
>>> test_mask = g.ndata["test_mask"][:, 0]
>>> # get labels
>>> label = g.ndata['label']
__getitem__(idx)

获取指定索引处的数据对象。

__len__()

数据集中样本的数量。