SquirrelDataset

class dgl.data.SquirrelDataset(raw_dir=None, force_reload=False, verbose=True, transform=None)[源码]

基类: GeomGCNDataset

来自 Multi-scale Attributed Node Embedding 的关于松鼠的维基百科页面-页面网络,后经 Geom-GCN: Geometric Graph Convolutional Networks 修改。

节点表示英文维基百科中的文章,边反映它们之间的相互链接。节点特征表示文章中特定名词的存在性。节点根据其平均月流量分为 5 类。

统计

  • 节点数: 5201

  • 边数: 217073

  • 类别数: 5

  • 10 个训练/验证/测试划分

    • 训练集: 2496

    • 验证集: 1664

    • 测试集: 1041

参数:
  • raw_dir (str, 可选) – 存储处理后的数据的原始文件目录。默认值: ~/.dgl/

  • force_reload (bool, 可选) – 是否重新下载数据源。默认值: False

  • verbose (bool, 可选) – 是否打印进度信息。默认值: True

  • transform (callable, 可选) – 一个转换函数,接受一个 DGLGraph 对象并返回一个转换后的版本。DGLGraph 对象将在每次访问前进行转换。默认值: None

num_classes

节点类别数

类型:

int

注意事项

图的边不包含双向连接。

示例

>>> from dgl.data import SquirrelDataset
>>> dataset = SquirrelDataset()
>>> g = dataset[0]
>>> num_classes = dataset.num_classes
>>> # get node features
>>> feat = g.ndata["feat"]
>>> # get data split
>>> train_mask = g.ndata["train_mask"]
>>> val_mask = g.ndata["val_mask"]
>>> test_mask = g.ndata["test_mask"]
>>> # get labels
>>> label = g.ndata['label']
__getitem__(idx)

获取索引处的数据对象。

__len__()

数据集中示例的数量。