mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 14:01:28 +08:00
95 lines
3.0 KiB
Python
95 lines
3.0 KiB
Python
# Copyright (c) Microsoft Corporation.
|
|
# Licensed under the MIT License.
|
|
|
|
import zipfile
from pathlib import Path
from typing import Union

import fire
import requests
from loguru import logger
from tqdm import tqdm
|
|
|
|
|
|
class GetData:
    """Download the example Qlib datasets from ``REMOTE_URL`` and extract them locally."""

    REMOTE_URL = "http://fintech.msra.cn/stock_data/downloads"
    # Number of bytes requested per iteration when streaming the archive to disk.
    CHUNK_SIZE = 1024

    def __init__(self, delete_zip_file=False):
        """
        Parameters
        ----------
        delete_zip_file : bool, optional
            Whether to delete the downloaded zip file after it has been extracted,
            by default False
        """
        self.delete_zip_file = delete_zip_file

    def _download_data(self, file_name: str, target_dir: Union[Path, str]):
        """Download ``{REMOTE_URL}/{file_name}`` into ``target_dir`` and extract it there.

        Parameters
        ----------
        file_name : str
            Name of the archive on the remote server; also used as the local file name.
        target_dir : Path or str
            Directory receiving both the archive and its extracted contents.
            ``~`` is expanded and missing parent directories are created.

        Raises
        ------
        requests.exceptions.HTTPError
            If the server does not answer with HTTP status 200.
        """
        target_dir = Path(target_dir).expanduser()
        target_dir.mkdir(exist_ok=True, parents=True)

        url = f"{self.REMOTE_URL}/{file_name}"
        target_path = target_dir.joinpath(file_name)

        # Using the response as a context manager releases the streamed connection
        # even if the status check or a disk write raises.
        with requests.get(url, stream=True) as resp:
            if resp.status_code != 200:
                raise requests.exceptions.HTTPError(
                    f"Failed to download {url}: HTTP {resp.status_code}", response=resp
                )

            logger.warning(
                "The data for the example is collected from Yahoo Finance. "
                "Please be aware that the quality of the data might not be perfect. "
                "(You can refer to the original data source: https://finance.yahoo.com/lookup.)"
            )
            logger.info(f"{file_name} downloading......")
            # Without a Content-Length header the total size is unknown: pass None explicitly.
            total = int(resp.headers.get("Content-Length", 0)) or None
            with tqdm(total=total, unit="B", unit_scale=True) as p_bar, target_path.open("wb") as fp:
                for chunk in resp.iter_content(chunk_size=self.CHUNK_SIZE):
                    fp.write(chunk)
                    # The last chunk is usually shorter than CHUNK_SIZE, so count real bytes.
                    p_bar.update(len(chunk))

        self._unzip(target_path, target_dir)
        if self.delete_zip_file:
            target_path.unlink()

    @staticmethod
    def _unzip(file_path: Path, target_dir: Path):
        """Extract every member of the zip archive ``file_path`` into ``target_dir``.

        Parameters
        ----------
        file_path : Path
            Path of the zip archive to read.
        target_dir : Path
            Directory the archive members are extracted into.
        """
        logger.info(f"{file_path} unzipping......")
        # Resolve the destination once instead of once per archive member.
        extract_dir = str(target_dir.resolve())
        with zipfile.ZipFile(str(file_path.resolve()), "r") as zp:
            # Extract member by member so tqdm can report per-file progress.
            for _file in tqdm(zp.namelist()):
                zp.extract(_file, extract_dir)

    def qlib_data_cn(self, target_dir="~/.qlib/qlib_data/cn_data", version="latest"):
        """Download ``qlib_data_cn_{version}.zip`` from the remote server and extract it.

        Parameters
        ----------
        target_dir: str
            Directory the data is saved and extracted into.
        version: str
            Data version, one of [v0, v1, ..., latest]; by default ``latest``.

        Examples
        --------
        python get_data.py qlib_data_cn --target_dir ~/.qlib/qlib_data/cn_data --version v1
        """
        file_name = f"qlib_data_cn_{version}.zip"
        self._download_data(file_name, target_dir)

    def csv_data_cn(self, target_dir="~/.qlib/csv_data/cn_data"):
        """Download ``csv_data_cn.zip`` from the remote server and extract it.

        Parameters
        ----------
        target_dir: str
            Directory the data is saved and extracted into.

        Examples
        --------
        python get_data.py csv_data_cn --target_dir ~/.qlib/csv_data/cn_data
        """
        file_name = "csv_data_cn.zip"
        self._download_data(file_name, target_dir)
|
|
|
|
|
|
if __name__ == "__main__":
    # Turn GetData into a command-line tool: each public method becomes a sub-command,
    # e.g. `python get_data.py qlib_data_cn --target_dir ... --version v1`.
    fire.Fire(component=GetData)
|