SSTDataset
- class dgl.data.SSTDataset(mode='train', glove_embed_file=None, vocab_file=None, raw_dir=None, force_reload=False, verbose=False, transform=None)[source]
基类:
DGLBuiltinDataset
Stanford 情感树库数据集。
每个样本是句子的成分树。叶节点表示词语。词语的整数值存储在
x
特征字段中。非叶节点在x
字段中有一个特殊值PAD_WORD
。每个节点也有情感标注:5个类别(非常消极、消极、中性、积极和非常积极)。情感标签的整数值存储在y
特征字段中。官方网站:http://nlp.stanford.edu/sentiment/index.html统计信息
训练样本数:8,544
开发样本数:1,101
测试样本数:2,210
每个节点的类别数:5
- 参数:
mode (str, optional) – 应为 ['train', 'dev', 'test', 'tiny'] 中的一个。默认值:train
glove_embed_file (str, optional) – 预训练 glove 嵌入文件的路径。默认值:None
vocab_file (str, optional) – 可选的词汇表文件。如果未提供,则使用默认词汇表文件。默认值:None
raw_dir (str) – 下载/包含输入数据目录的原始文件目录。默认值:~/.dgl/
force_reload (bool) – 是否重新加载数据集。默认值:False
verbose (bool) – 是否打印进度信息。默认值:True。
transform (callable, optional) – 一个转换函数,接受一个
DGLGraph
对象并返回一个转换后的版本。DGLGraph
对象将在每次访问之前进行转换。
- vocab
数据集的词汇表
- 类型:
OrderedDict
- pretrained_emb
对应词汇表的预训练 glove 嵌入。
- 类型:
Tensor
注意事项
所有样本将首先加载并在内存中进行预处理。
示例
>>> # get dataset >>> train_data = SSTDataset() >>> dev_data = SSTDataset(mode='dev') >>> test_data = SSTDataset(mode='test') >>> tiny_data = SSTDataset(mode='tiny') >>> >>> len(train_data) 8544 >>> train_data.num_classes 5 >>> glove_embed = train_data.pretrained_emb >>> train_data.vocab_size 19536 >>> train_data[0] Graph(num_nodes=71, num_edges=70, ndata_schemes={'x': Scheme(shape=(), dtype=torch.int64), 'y': Scheme(shape=(), dtype=torch.int64), 'mask': Scheme(shape=(), dtype=torch.int64)} edata_schemes={}) >>> for tree in train_data: ... input_ids = tree.ndata['x'] ... labels = tree.ndata['y'] ... mask = tree.ndata['mask'] ... # your code here