From fb54d08236db58a48fc4638b792b71cad40dd9bf Mon Sep 17 00:00:00 2001 From: Linlang Date: Thu, 23 May 2024 16:01:57 +0800 Subject: [PATCH] optimize get_data code --- qlib/tests/data.py | 44 +++++++++++++++++++++++++++++--------------- 1 file changed, 29 insertions(+), 15 deletions(-) diff --git a/qlib/tests/data.py b/qlib/tests/data.py index ed99e7cdf..d58d58bb0 100644 --- a/qlib/tests/data.py +++ b/qlib/tests/data.py @@ -31,6 +31,34 @@ class GetData: def merge_remote_url(self, file_name: str): return f"{self.REMOTE_URL}/{file_name}" if "/" in file_name else f"{self.REMOTE_URL}/v0/{file_name}" + def download(self, url: str, target_path: [Path, str]): + """ + Download a file from the specified url. + + Parameters + ---------- + url: str + The url of the data. + target_path: str + The location where the data is saved, including the file name. + """ + file_name = str(target_path).split("/")[-1] + resp = requests.get(url, stream=True, timeout=60) + resp.raise_for_status() + if resp.status_code != 200: + raise requests.exceptions.HTTPError() + + chunk_size = 1024 + logger.warning( + f"The data for the example is collected from Yahoo Finance. Please be aware that the quality of the data might not be perfect. (You can refer to the original data source: https://finance.yahoo.com/lookup.)" + ) + logger.info(f"{os.path.basename(file_name)} downloading......") + with tqdm(total=int(resp.headers.get("Content-Length", 0))) as p_bar: + with target_path.open("wb") as fp: + for chunk in resp.iter_content(chunk_size=chunk_size): + fp.write(chunk) + p_bar.update(chunk_size) + def download_data(self, file_name: str, target_dir: [Path, str], delete_old: bool = True): """ Download the specified file to the target folder. @@ -64,21 +92,7 @@ class GetData: target_path = target_dir.joinpath(_target_file_name) url = self.merge_remote_url(file_name) - resp = requests.get(url, stream=True, timeout=60) - resp.raise_for_status() - if resp.status_code != 200: - raise requests.exceptions.HTTPError() - - chunk_size = 1024 - logger.warning( - f"The data for the example is collected from Yahoo Finance. Please be aware that the quality of the data might not be perfect. (You can refer to the original data source: https://finance.yahoo.com/lookup.)" - ) - logger.info(f"{os.path.basename(file_name)} downloading......") - with tqdm(total=int(resp.headers.get("Content-Length", 0))) as p_bar: - with target_path.open("wb") as fp: - for chunk in resp.iter_content(chunk_size=chunk_size): - fp.write(chunk) - p_bar.update(chunk_size) + self.download(url=url, target_path=target_path) self._unzip(target_path, target_dir, delete_old) if self.delete_zip_file: