4.2 下载原始数据 (可选)
如果数据集已在本地磁盘中,请确保它位于目录 raw_dir
中。如果想在任何地方运行代码而无需手动下载数据并移动到正确目录,可以通过实现函数 download()
来自动完成。
如果数据集是一个 zip 文件,让 MyDataset
继承自 dgl.data.DGLBuiltinDataset
类,该类为我们处理 zip 文件提取。否则,需要像 QM7bDataset
那样实现 download()
。
import os
from dgl.data.utils import download
def download(self):
# path to store the file
file_path = os.path.join(self.raw_dir, self.name + '.mat')
# download file
download(self.url, path=file_path)
上面的代码将一个 .mat 文件下载到目录 self.raw_dir
中。如果文件是 .gz, .tar, .tar.gz 或 .tgz 文件,使用 extract_archive()
函数进行提取。以下代码展示了如何在 BitcoinOTCDataset
中下载一个 .gz 文件。
from dgl.data.utils import download, check_sha1
def download(self):
# path to store the file
# make sure to use the same suffix as the original file name's
gz_file_path = os.path.join(self.raw_dir, self.name + '.csv.gz')
# download file
download(self.url, path=gz_file_path)
# check SHA-1
if not check_sha1(gz_file_path, self._sha1_str):
raise UserWarning('File {} is downloaded but the content hash does not match.'
'The repo may be outdated or download may be incomplete. '
'Otherwise you can create an issue for it.'.format(self.name + '.csv.gz'))
# extract file to directory `self.name` under `self.raw_dir`
self._extract_gz(gz_file_path, self.raw_path)
上面的代码将文件提取到 self.raw_dir
下的目录 self.name
中。如果类继承自 dgl.data.DGLBuiltinDataset
来处理 zip 文件,它也会将文件提取到目录 self.name
中。
(可选)可以检查下载文件的 SHA-1 字符串,如上面的示例所示,以防作者某天更改了远程服务器上的文件。