AmazonRatingsDataset

dgl.data.AmazonRatingsDataset(raw_dir=None, force_reload=False, verbose=True, transform=None)[source]

基类: HeterophilousGraphDataset

来自 'A Critical Look at the Evaluation of GNNs under Heterophily: Are We Really Making Progress? <https://arxiv.org/abs/2302.11640>'\_\_ 论文的 Amazon-ratings 数据集。

该数据集基于 Amazon 产品共同购买数据。节点是产品(书籍、音乐 CD、DVD、VHS 录像带),边连接经常一起购买的产品。任务是预测评论者对产品的平均评分。所有可能的评分值被分为五类。节点特征是产品描述中单词的词嵌入均值。

统计信息

  • 节点数: 24492

  • 边数: 186100

  • 类别数: 5

  • 节点特征维度: 300

  • 10 个训练/验证/测试划分

参数:
  • raw_dir (str, 可选) – 存储处理后数据的原始文件目录。默认值:~/.dgl/

  • force_reload (bool, 可选) – 是否重新下载数据源。默认值:False

  • verbose (bool, 可选) – 是否打印进度信息。默认值:True

  • transform (callable, 可选) – 一个转换函数,它接受一个 DGLGraph 对象并返回转换后的版本。在每次访问之前都会对 DGLGraph 对象进行转换。默认值:None

num_classes

节点类别数

类型:

int

示例

>>> from dgl.data import AmazonRatingsDataset
>>> dataset = AmazonRatingsDataset()
>>> g = dataset[0]
>>> num_classes = dataset.num_classes
>>> # get node features
>>> feat = g.ndata["feat"]
>>> # get the first data split
>>> train_mask = g.ndata["train_mask"][:, 0]
>>> val_mask = g.ndata["val_mask"][:, 0]
>>> test_mask = g.ndata["test_mask"][:, 0]
>>> # get labels
>>> label = g.ndata['label']
__getitem__(idx)

获取索引处的数据对象。

__len__()

数据集中的样本数量。