# Mirror of https://github.com/microsoft/qlib.git
# (synced 2026-06-06 05:51:17 +08:00; 92 lines, 2.6 KiB, Python)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import zipfile
from pathlib import Path
from typing import Union

import fire
import requests
from loguru import logger
from tqdm import tqdm


class GetData:
    """Download pre-packaged qlib datasets from ``REMOTE_URL`` and unzip them locally."""

    REMOTE_URL = "http://fintech.msra.cn/stock_data/downloads"

    def __init__(self, delete_zip_file=False):
        """
        Parameters
        ----------
        delete_zip_file : bool, optional
            Whether to delete the downloaded zip file after it has been unzipped, by default False
        """
        self.delete_zip_file = delete_zip_file

    def _download_data(self, file_name: str, target_dir: Union[Path, str]):
        """Download ``file_name`` from ``REMOTE_URL`` into ``target_dir`` and unzip it there.

        Parameters
        ----------
        file_name : str
            Name of the archive under ``REMOTE_URL``; also used as the local file name.
        target_dir : Path or str
            Destination directory; ``~`` is expanded and missing parent directories are created.

        Raises
        ------
        requests.exceptions.HTTPError
            If the server answers with any status code other than 200.
        """
        target_dir = Path(target_dir).expanduser()
        target_dir.mkdir(exist_ok=True, parents=True)

        url = f"{self.REMOTE_URL}/{file_name}"
        target_path = target_dir.joinpath(file_name)

        chunk_size = 1024
        # Use the response as a context manager so the streamed connection is
        # released on every path, including the error raise and a failed write.
        with requests.get(url, stream=True) as resp:
            if resp.status_code != 200:
                raise requests.exceptions.HTTPError(
                    f"failed to download {url}: HTTP {resp.status_code}", response=resp
                )

            logger.info(f"{file_name} downloading......")
            with tqdm(total=int(resp.headers.get("Content-Length", 0))) as p_bar:
                with target_path.open("wb") as fp:
                    for chunk in resp.iter_content(chunk_size=chunk_size):
                        fp.write(chunk)
                        # Advance by the bytes actually received; the last chunk is
                        # usually shorter than chunk_size.
                        p_bar.update(len(chunk))

        self._unzip(target_path, target_dir)
        if self.delete_zip_file:
            target_path.unlink()

    @staticmethod
    def _unzip(file_path: Path, target_dir: Path):
        """Extract every member of the zip archive ``file_path`` into ``target_dir``.

        Members are extracted one at a time so that ``tqdm`` can report progress.
        """
        logger.info(f"{file_path} unzipping......")
        # The destination does not change per member; resolve it once.
        extract_dir = str(target_dir.resolve())
        with zipfile.ZipFile(str(file_path.resolve()), "r") as zp:
            for _file in tqdm(zp.namelist()):
                zp.extract(_file, extract_dir)

    def qlib_data_cn(self, target_dir="~/.qlib/qlib_data/cn_data"):
        """Download the cn qlib data archive (``qlib_data_cn.zip``) from remote and unzip it.

        Parameters
        ----------
        target_dir : str
            Directory the archive is saved to and extracted into.

        Examples
        --------
        python get_data.py qlib_data_cn --target_dir ~/.qlib/qlib_data/cn_data
        """
        file_name = "qlib_data_cn.zip"
        self._download_data(file_name, target_dir)

    def csv_data_cn(self, target_dir="~/.qlib/csv_data/cn_data"):
        """Download the cn csv data archive (``csv_data_cn.zip``) from remote and unzip it.

        Parameters
        ----------
        target_dir : str
            Directory the archive is saved to and extracted into.

        Examples
        --------
        python get_data.py csv_data_cn --target_dir ~/.qlib/csv_data/cn_data
        """
        file_name = "csv_data_cn.zip"
        self._download_data(file_name, target_dir)


if __name__ == "__main__":
    # Hand the GetData class to fire so its public methods (e.g. qlib_data_cn,
    # csv_data_cn) become command-line subcommands.
    fire.Fire(GetData)