GDELTDataset

class dgl.data.GDELTDataset(mode='train', raw_dir=None, force_reload=False, verbose=False, transform=None)[source]

基类:DGLBuiltinDataset

用于基于事件的时间图的 GDELT 数据集

全球事件、语言和调性数据库 (GDELT) 数据集。该数据集包含世界各地发生的事件(即将俄罗斯某一天任何地方举行的每一次抗议活动合并为一个条目)。该数据集包含 2018 年 1 月 1 日至 2018 年 1 月 31 日期间收集的事件(15 分钟时间粒度)。

参考

统计数据

  • 训练样本:2,304

  • 验证样本:288

  • 测试样本:384

参数:
  • mode (str) – 必须是 ('train', 'valid', 'test') 之一。默认值: ‘train’

  • raw_dir (str) – 下载/包含输入数据目录的原始文件目录。默认值: ~/.dgl/

  • force_reload (bool) – 是否重新加载数据集。默认值: False

  • verbose (bool) – 是否打印进度信息。默认值: True。

  • transform (callable, optional) – 一个转换函数,接受一个 DGLGraph 对象并返回一个转换后的版本。DGLGraph 对象在每次访问前都会被转换。

start_time

时间图的开始时间

类型:

int

end_time

时间图的结束时间

类型:

int

is_temporal

数据集是否包含时间图

类型:

bool

示例

>>> # get train, valid, test dataset
>>> train_data = GDELTDataset()
>>> valid_data = GDELTDataset(mode='valid')
>>> test_data = GDELTDataset(mode='test')
>>>
>>> # length of train set
>>> train_size = len(train_data)
>>>
>>> for g in train_data:
....    e_feat = g.edata['rel_type']
....    # your code here
....
>>>
__getitem__(t)[source]

获取包含时间 t + self.start_time 之前事件的图

参数:

t (int) – 时间,其值必须在 [0, self.end_time - self.start_time] 范围内

返回:

图包含

  • edata['rel_type']: 边类型

返回类型:

dgl.DGLGraph

__len__()[source]

数据集中的图数量。

返回类型:

int