4.4 保存和加载数据
DGL 建议实现保存和加载函数,以便将处理后的数据缓存到本地磁盘。在大多数情况下,这可以节省大量的处理时间。DGL 提供了四个函数来简化操作
dgl.save_graphs()
和dgl.load_graphs()
:将 DGLGraph 对象和标签保存/加载到本地磁盘。dgl.data.utils.save_info()
和dgl.data.utils.load_info()
:将数据集的有用信息(Pythondict
对象)保存/加载到本地磁盘。
以下示例展示了如何保存和加载图列表及数据集信息。
import os
from dgl import save_graphs, load_graphs
from dgl.data.utils import makedirs, save_info, load_info
def save(self):
# save graphs and labels
graph_path = os.path.join(self.save_path, self.mode + '_dgl_graph.bin')
save_graphs(graph_path, self.graphs, {'labels': self.labels})
# save other information in python dict
info_path = os.path.join(self.save_path, self.mode + '_info.pkl')
save_info(info_path, {'num_classes': self.num_classes})
def load(self):
# load processed data from directory `self.save_path`
graph_path = os.path.join(self.save_path, self.mode + '_dgl_graph.bin')
self.graphs, label_dict = load_graphs(graph_path)
self.labels = label_dict['labels']
info_path = os.path.join(self.save_path, self.mode + '_info.pkl')
self.num_classes = load_info(info_path)['num_classes']
def has_cache(self):
# check whether there are processed data in `self.save_path`
graph_path = os.path.join(self.save_path, self.mode + '_dgl_graph.bin')
info_path = os.path.join(self.save_path, self.mode + '_info.pkl')
return os.path.exists(graph_path) and os.path.exists(info_path)
请注意,有些情况下不适合保存处理后的数据。例如,在内置数据集 GDELTDataset
中,处理后的数据非常大,因此在 __getitem__(idx)
中处理每个数据示例会更有效率。