ICEWS18Dataset

class dgl.data.ICEWS18Dataset(mode='train', raw_dir=None, force_reload=False, verbose=False, transform=None)[source]

基类: DGLBuiltinDataset

用于时序图的 ICEWS18 数据集

综合危机预警系统 (ICEWS18)

事件数据包含社会政治参与者(即,个体、团体、部门和国家之间的合作或敌对行动)之间的编码交互。该数据集包含从 2018 年 1 月 1 日到 2018 年 10 月 31 日的事件(24 小时时间粒度)。

参考文献

统计信息:

  • 训练样本: 240

  • 验证样本: 30

  • 测试样本: 34

  • 每图节点数: 23033

参数:
  • mode (str) – 加载训练/验证/测试数据。必须是 [‘train’, ‘valid’, ‘test’] 之一

  • raw_dir (str) – 下载/包含输入数据目录的原始文件目录。默认值: ~/.dgl/

  • force_reload (bool) – 是否重新加载数据集。默认值: False

  • verbose (bool) – 是否打印进度信息。默认值: True。

  • transform (callable, optional) – 一个转换函数,接受一个 DGLGraph 对象并返回一个转换后的版本。DGLGraph 对象在每次访问前都会被转换。

is_temporal

数据集是否包含时序图

类型:

bool

示例

>>> # get train, valid, test set
>>> train_data = ICEWS18Dataset()
>>> valid_data = ICEWS18Dataset(mode='valid')
>>> test_data = ICEWS18Dataset(mode='test')
>>>
>>> train_size = len(train_data)
>>> for g in train_data:
....    e_feat = g.edata['rel_type']
....    # your code here
....
>>>
__getitem__(idx)[source]

按索引获取图

参数:

idx (int) – 项索引

返回:

包含的图

  • edata['rel_type']: 边类型

返回类型:

dgl.DGLGraph

__len__()[source]

数据集中图的数量。

返回类型:

int