TUDataset
- class dgl.data.TUDataset(name, raw_dir=None, force_reload=False, verbose=False, transform=None)[源码]
基类:
DGLBuiltinDataset
TUDataset 包含许多用于图分类的图核数据集。
- 参数:
name (str) – 数据集名称,例如
ENZYMES
,DD
,COLLAB
,MUTAG
,可以是 https://chrsmrrs.github.io/datasets/docs/datasets/ 上的数据集名称。transform (callable, optional) – 一个转换函数,它接收一个
DGLGraph
对象并返回其转换后的版本。每次访问前都会对DGLGraph
对象进行转换。
注意事项
重要提示: 某些数据集的图中存在重复边,例如
IMDB-BINARY
中的边都是重复的。DGL 忠实地保留了原始数据中的重复边。PyTorch Geometric 等其他框架默认会移除重复边。您可以使用dgl.to_simple()
来移除重复边。图可能包含节点标签、节点属性、边标签和边属性,具体取决于不同的数据集。
标签被映射到 \(\lbrace 0,\cdots,n-1 \rbrace\),其中 \(n\) 是标签的数量(某些数据集的原始标签是 \(\lbrace -1, 1 \rbrace\),将被映射到 \(\lbrace 0, 1 \rbrace\))。在之前的版本中,会加上最小标签值,因此 \(\lbrace -1, 1 \rbrace\) 会被映射到 \(\lbrace 0, 2 \rbrace\)。
数据集按标签对图进行排序。在手动划分训练/验证集之前,最好进行洗牌。
示例
>>> data = TUDataset('DD')
数据集实例是可迭代的
>>> len(data) 1178 >>> g, label = data[1024] >>> g Graph(num_nodes=88, num_edges=410, ndata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64), 'node_labels': Scheme(shape=(1,), dtype=torch.int64)} edata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)}) >>> label tensor([1])
将图和标签批量处理以进行小批量训练
>>> graphs, labels = zip(*[data[i] for i in range(16)]) >>> batched_graphs = dgl.batch(graphs) >>> batched_labels = torch.tensor(labels) >>> batched_graphs Graph(num_nodes=9539, num_edges=47382, ndata_schemes={'node_labels': Scheme(shape=(1,), dtype=torch.int64), '_ID': Scheme(shape=(), dtype=torch.int64)} edata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)})
- __getitem__(idx)[源码]
获取第 idx 个样本。
- 参数:
idx (int) – 样本索引。
- 返回:
图,其节点特征存储在
feat
字段中,节点标签(如果可用)存储在node_labels
中。以及其标签。- 返回类型:
(
dgl.DGLGraph
, Tensor)