TUDataset

class dgl.data.TUDataset(name, raw_dir=None, force_reload=False, verbose=False, transform=None)[源码]

基类: DGLBuiltinDataset

TUDataset 包含许多用于图分类的图核数据集。

参数:
max_num_node

最大节点数

类型:

int

num_classes

类别数

类型:

int

num_labels

(已废弃,请改用 num_classes) 类别数

类型:

int

注意事项

重要提示: 某些数据集的图中存在重复边,例如 IMDB-BINARY 中的边都是重复的。DGL 忠实地保留了原始数据中的重复边。PyTorch Geometric 等其他框架默认会移除重复边。您可以使用 dgl.to_simple() 来移除重复边。

图可能包含节点标签、节点属性、边标签和边属性,具体取决于不同的数据集。

标签被映射到 \(\lbrace 0,\cdots,n-1 \rbrace\),其中 \(n\) 是标签的数量(某些数据集的原始标签是 \(\lbrace -1, 1 \rbrace\),将被映射到 \(\lbrace 0, 1 \rbrace\))。在之前的版本中,会加上最小标签值,因此 \(\lbrace -1, 1 \rbrace\) 会被映射到 \(\lbrace 0, 2 \rbrace\)

数据集按标签对图进行排序。在手动划分训练/验证集之前,最好进行洗牌。

示例

>>> data = TUDataset('DD')

数据集实例是可迭代的

>>> len(data)
1178
>>> g, label = data[1024]
>>> g
Graph(num_nodes=88, num_edges=410,
      ndata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64), 'node_labels': Scheme(shape=(1,), dtype=torch.int64)}
      edata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)})
>>> label
tensor([1])

将图和标签批量处理以进行小批量训练

>>> graphs, labels = zip(*[data[i] for i in range(16)])
>>> batched_graphs = dgl.batch(graphs)
>>> batched_labels = torch.tensor(labels)
>>> batched_graphs
Graph(num_nodes=9539, num_edges=47382,
      ndata_schemes={'node_labels': Scheme(shape=(1,), dtype=torch.int64), '_ID': Scheme(shape=(), dtype=torch.int64)}
      edata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)})
__getitem__(idx)[源码]

获取第 idx 个样本。

参数:

idx (int) – 样本索引。

返回:

图,其节点特征存储在 feat 字段中,节点标签(如果可用)存储在 node_labels 中。以及其标签。

返回类型:

(dgl.DGLGraph, Tensor)

__len__()[源码]

返回数据集中图的数量。