mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-01 18:11:18 +08:00
init commit
This commit is contained in:
91
scripts/get_data.py
Normal file
91
scripts/get_data.py
Normal file
@@ -0,0 +1,91 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import fire
|
||||
import zipfile
|
||||
import requests
|
||||
from tqdm import tqdm
|
||||
from pathlib import Path
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class GetData:
|
||||
REMOTE_URL = "http://fintech.msra.cn/stock_data/downloads"
|
||||
|
||||
def __init__(self, delete_zip_file=False):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
delete_zip_file : bool, optional
|
||||
Whether to delete the zip file, value from True or False, by default False
|
||||
"""
|
||||
self.delete_zip_file = delete_zip_file
|
||||
|
||||
def _download_data(self, file_name: str, target_dir: [Path, str]):
|
||||
target_dir = Path(target_dir).expanduser()
|
||||
target_dir.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
url = f"{self.REMOTE_URL}/{file_name}"
|
||||
target_path = target_dir.joinpath(file_name)
|
||||
|
||||
resp = requests.get(url, stream=True)
|
||||
if resp.status_code != 200:
|
||||
raise requests.exceptions.HTTPError()
|
||||
|
||||
chuck_size = 1024
|
||||
logger.info(f"{file_name} downloading......")
|
||||
with tqdm(total=int(resp.headers.get("Content-Length", 0))) as p_bar:
|
||||
with target_path.open("wb") as fp:
|
||||
for chuck in resp.iter_content(chunk_size=chuck_size):
|
||||
fp.write(chuck)
|
||||
p_bar.update(chuck_size)
|
||||
|
||||
self._unzip(target_path, target_dir)
|
||||
if self.delete_zip_file:
|
||||
target_path.unlike()
|
||||
|
||||
@staticmethod
|
||||
def _unzip(file_path: Path, target_dir: Path):
|
||||
logger.info(f"{file_path} unzipping......")
|
||||
with zipfile.ZipFile(str(file_path.resolve()), "r") as zp:
|
||||
for _file in tqdm(zp.namelist()):
|
||||
zp.extract(_file, str(target_dir.resolve()))
|
||||
|
||||
def qlib_data_cn(self, target_dir="~/.qlib/qlib_data/cn_data"):
|
||||
"""download cn qlib data from remote
|
||||
|
||||
Parameters
|
||||
----------
|
||||
target_dir: str
|
||||
data save directory
|
||||
|
||||
Examples
|
||||
---------
|
||||
python get_data.py qlib_data_cn --target_dir ~/.qlib/qlib_data/cn_data
|
||||
-------
|
||||
|
||||
"""
|
||||
file_name = "qlib_data_cn.zip"
|
||||
self._download_data(file_name, target_dir)
|
||||
|
||||
def csv_data_cn(self, target_dir="~/.qlib/csv_data/cn_data"):
|
||||
"""download cn csv data from remote
|
||||
|
||||
Parameters
|
||||
----------
|
||||
target_dir: str
|
||||
data save directory
|
||||
|
||||
Examples
|
||||
---------
|
||||
python get_data.py csv_data_cn --target_dir ~/.qlib/csv_data/cn_data
|
||||
-------
|
||||
|
||||
"""
|
||||
file_name = "csv_data_cn.zip"
|
||||
self._download_data(file_name, target_dir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(GetData)
|
||||
Reference in New Issue
Block a user