SSTDataset

class dgl.data.SSTDataset(mode='train', glove_embed_file=None, vocab_file=None, raw_dir=None, force_reload=False, verbose=False, transform=None)[source]

基类: DGLBuiltinDataset

Stanford 情感树库数据集。

每个样本是句子的成分树。叶节点表示词语。词语的整数值存储在x特征字段中。非叶节点在x字段中有一个特殊值PAD_WORD。每个节点也有情感标注:5个类别(非常消极、消极、中性、积极和非常积极)。情感标签的整数值存储在y特征字段中。官方网站:http://nlp.stanford.edu/sentiment/index.html

统计信息

  • 训练样本数:8,544

  • 开发样本数:1,101

  • 测试样本数:2,210

  • 每个节点的类别数:5

参数:
  • mode (str, optional) – 应为 ['train', 'dev', 'test', 'tiny'] 中的一个。默认值:train

  • glove_embed_file (str, optional) – 预训练 glove 嵌入文件的路径。默认值:None

  • vocab_file (str, optional) – 可选的词汇表文件。如果未提供,则使用默认词汇表文件。默认值:None

  • raw_dir (str) – 下载/包含输入数据目录的原始文件目录。默认值:~/.dgl/

  • force_reload (bool) – 是否重新加载数据集。默认值:False

  • verbose (bool) – 是否打印进度信息。默认值:True。

  • transform (callable, optional) – 一个转换函数,接受一个 DGLGraph 对象并返回一个转换后的版本。DGLGraph 对象将在每次访问之前进行转换。

vocab

数据集的词汇表

类型:

OrderedDict

num_classes

每个节点的类别数

类型:

int

pretrained_emb

对应词汇表的预训练 glove 嵌入。

类型:

Tensor

vocab_size

词汇表大小

类型:

int

注意事项

所有样本将首先加载并在内存中进行预处理。

示例

>>> # get dataset
>>> train_data = SSTDataset()
>>> dev_data = SSTDataset(mode='dev')
>>> test_data = SSTDataset(mode='test')
>>> tiny_data = SSTDataset(mode='tiny')
>>>
>>> len(train_data)
8544
>>> train_data.num_classes
5
>>> glove_embed = train_data.pretrained_emb
>>> train_data.vocab_size
19536
>>> train_data[0]
Graph(num_nodes=71, num_edges=70,
  ndata_schemes={'x': Scheme(shape=(), dtype=torch.int64), 'y': Scheme(shape=(), dtype=torch.int64), 'mask': Scheme(shape=(), dtype=torch.int64)}
  edata_schemes={})
>>> for tree in train_data:
...     input_ids = tree.ndata['x']
...     labels = tree.ndata['y']
...     mask = tree.ndata['mask']
...     # your code here
__getitem__(idx)[source]

按索引获取图

参数:

idx (int)

返回:

图结构、每个节点的词语 ID、节点标签和掩码。

  • ndata['x']:节点的词语 ID

  • ndata['y']: 节点的标签

  • ndata['mask']:如果节点是叶节点则为 1,否则为 0

返回类型:

dgl.DGLGraph

__len__()[source]

数据集中图的数量。