From ae300592a0914c7b2c5483f732fd9b5989cafadb Mon Sep 17 00:00:00 2001 From: zhupr Date: Mon, 16 Nov 2020 16:21:59 +0800 Subject: [PATCH] refactor yahoo_collector && support US and 1m --- scripts/data_collector/utils.py | 163 +++++++- scripts/data_collector/yahoo/README.md | 24 +- scripts/data_collector/yahoo/collector.py | 458 ++++++++++++++++------ 3 files changed, 493 insertions(+), 152 deletions(-) diff --git a/scripts/data_collector/utils.py b/scripts/data_collector/utils.py index d2b3835c1..855569642 100644 --- a/scripts/data_collector/utils.py +++ b/scripts/data_collector/utils.py @@ -3,56 +3,69 @@ import re import time +import bisect import pickle import requests +import functools from pathlib import Path import pandas as pd from lxml import etree +from loguru import logger +from yahooquery import Ticker -SYMBOLS_URL = "http://app.finance.ifeng.com/hq/list.php?type=stock_a&class={s_type}" -CSI300_BENCH_URL = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid=1.000300&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg=19900101&end=20220101" -SH600000_BENCH_URL = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid=1.600000&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg=19900101&end=20220101" +HS_SYMBOLS_URL = "http://app.finance.ifeng.com/hq/list.php?type=stock_a&class={s_type}" -CALENDAR_URL_BASE = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid=1.{bench_code}&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg=19900101&end=20220101" +CALENDAR_URL_BASE = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid={market}.{bench_code}&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg=19900101&end=20991231" CALENDAR_BENCH_URL_MAP = { - "CSI300": CALENDAR_URL_BASE.format(bench_code="000300"), - "CSI100": 
CALENDAR_URL_BASE.format(bench_code="000903"), + "CSI300": CALENDAR_URL_BASE.format(market=1, bench_code="000300"), + "CSI100": CALENDAR_URL_BASE.format(market=1, bench_code="000903"), # NOTE: Use the time series of SH600000 as the sequence of all stocks - "ALL": CALENDAR_URL_BASE.format(bench_code="600000"), + "ALL": CALENDAR_URL_BASE.format(market=1, bench_code="000905"), + # NOTE: Use the time series of ^GSPC(SP500) as the sequence of all stocks + "US_ALL": "^GSPC", } + _BENCH_CALENDAR_LIST = None _ALL_CALENDAR_LIST = None _HS_SYMBOLS = None +_US_SYMBOLS = None _CALENDAR_MAP = {} # NOTE: Until 2020-10-20 20:00:00 MINIMUM_SYMBOLS_NUM = 3900 -def get_hs_calendar_list(bench_code="CSI300") -> list: +def get_calendar_list(bench_code="CSI300") -> list: """get SH/SZ history calendar list Parameters ---------- bench_code: str - value from ["CSI300", "CSI500", "ALL"] + value from ["CSI300", "CSI500", "ALL", "US_ALL"] Returns ------- history calendar list """ + logger.info(f"get calendar list: {bench_code}......") + def _get_calendar(url): _value_list = requests.get(url).json()["data"]["klines"] return sorted(map(lambda x: pd.Timestamp(x.split(",")[0]), _value_list)) calendar = _CALENDAR_MAP.get(bench_code, None) if calendar is None: - calendar = _get_calendar(CALENDAR_BENCH_URL_MAP[bench_code]) + if bench_code.startswith("US_"): + df = Ticker(CALENDAR_BENCH_URL_MAP[bench_code]).history(interval="1d", period="max") + calendar = df.index.get_level_values(level="date").map(pd.Timestamp).unique().tolist() + else: + calendar = _get_calendar(CALENDAR_BENCH_URL_MAP[bench_code]) _CALENDAR_MAP[bench_code] = calendar + logger.info(f"end of get calendar list: {bench_code}.") return calendar @@ -68,13 +81,14 @@ def get_hs_stock_symbols() -> list: def _get_symbol(): _res = set() for _k, _v in (("ha", "ss"), ("sa", "sz"), ("gem", "sz")): - resp = requests.get(SYMBOLS_URL.format(s_type=_k)) + resp = requests.get(HS_SYMBOLS_URL.format(s_type=_k)) _res |= set( map( lambda x: 
"{}.{}".format(re.findall(r"\d+", x)[0], _v), etree.HTML(resp.text).xpath("//div[@class='result']/ul//li/a/text()"), ) ) + time.sleep(3) return _res if _HS_SYMBOLS is None: @@ -99,6 +113,84 @@ def get_hs_stock_symbols() -> list: return _HS_SYMBOLS +def get_us_stock_symbols(qlib_data_path: [str, Path] = None) -> list: + """get US stock symbols + + Returns + ------- + stock symbols + """ + global _US_SYMBOLS + + @deco_retry + def _get_eastmoney(): + url = "http://4.push2.eastmoney.com/api/qt/clist/get?pn=1&pz=10000&fs=m:105,m:106,m:107&fields=f12" + resp = requests.get(url) + if resp.status_code != 200: + raise ValueError("request error") + try: + _symbols = [_v["f12"].replace("_", "-P") for _v in resp.json()["data"]["diff"].values()] + except Exception as e: + logger.warning(f"request error: {e}") + raise + if len(_symbols) < 8000: + raise ValueError("request error") + return _symbols + + @deco_retry + def _get_nasdaq(): + _res_symbols = [] + for _name in ["otherlisted", "nasdaqtraded"]: + url = f"ftp://ftp.nasdaqtrader.com/SymbolDirectory/{_name}.txt" + df = pd.read_csv(url, sep="|") + df = df.rename(columns={"ACT Symbol": "Symbol"}) + _symbols = df["Symbol"].dropna() + _symbols = _symbols.str.replace("$", "-P", regex=False) + _symbols = _symbols.str.replace(".W", "-WT", regex=False) + _symbols = _symbols.str.replace(".U", "-UN", regex=False) + _symbols = _symbols.str.replace(".R", "-RI", regex=False) + _symbols = _symbols.str.replace(".", "-", regex=False) + _res_symbols += _symbols.unique().tolist() + return _res_symbols + + @deco_retry + def _get_nyse(): + url = "https://www.nyse.com/api/quotes/filter" + _parms = { + "instrumentType": "EQUITY", + "pageNumber": 1, + "sortColumn": "NORMALIZED_TICKER", + "sortOrder": "ASC", + "maxResultsPerPage": 10000, + "filterToken": "", + } + resp = requests.post(url, json=_parms) + if resp.status_code != 200: + raise ValueError("request error") + try: + _symbols = [_v["symbolTicker"].replace("-", "-P") for _v in resp.json()] + 
except Exception as e: + logger.warning(f"request error: {e}") + _symbols = [] + return _symbols + + if _US_SYMBOLS is None: + _all_symbols = _get_eastmoney() + _get_nasdaq() + _get_nyse() + if qlib_data_path is not None: + for _index in ["nasdaq100", "sp500"]: + ins_df = pd.read_csv( + Path(qlib_data_path).joinpath(f"instruments/{_index}.txt"), + sep="\t", + names=["symbol", "start_date", "end_date"], + ) + _all_symbols += ins_df["symbol"].unique().tolist() + _US_SYMBOLS = sorted( + set(map(lambda x: x.replace(".", "-"), filter(lambda x: len(x) < 8 and not x.endswith("WS"), _all_symbols))) + ) + + return _US_SYMBOLS + + def symbol_suffix_to_prefix(symbol: str, capital: bool = True) -> str: """symbol suffix to prefix @@ -137,5 +229,52 @@ def symbol_prefix_to_sufix(symbol: str, capital: bool = True) -> str: return res.upper() if capital else res.lower() -if __name__ == '__main__': +def deco_retry(retry: int = 5, retry_sleep: int = 3): + def deco_func(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + _retry = 5 if callable(retry) else retry + _result = None + for _i in range(1, _retry + 1): + try: + _result = func(*args, **kwargs) + break + except Exception as e: + logger.warning(f"{func.__name__}: {_i} :{e}") + if _i == _retry: + raise + time.sleep(retry_sleep) + return _result + + return wrapper + + return deco_func(retry) if callable(retry) else deco_func + + +def get_trading_date_by_shift(trading_list: list, trading_date: pd.Timestamp, shift: int = 1): + """get trading date by shift + + Parameters + ---------- + trading_list: list + trading calendar list + shift : int + shift, default is 1 + + trading_date : pd.Timestamp + trading date + Returns + ------- + + """ + trading_date = pd.Timestamp(trading_date) + left_index = bisect.bisect_left(trading_list, trading_date) + try: + res = trading_list[left_index + shift] + except IndexError: + res = trading_date + return res + + +if __name__ == "__main__": assert len(get_hs_stock_symbols()) >= 
MINIMUM_SYMBOLS_NUM diff --git a/scripts/data_collector/yahoo/README.md b/scripts/data_collector/yahoo/README.md index 4f1f4c650..1e65aeaed 100644 --- a/scripts/data_collector/yahoo/README.md +++ b/scripts/data_collector/yahoo/README.md @@ -18,31 +18,29 @@ pip install -r requirements.txt ## Collector Data -### Download data -> Normalize data -> Dump data +### Download data and Normalize data ```bash -python collector.py collector_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize_dir --qlib_dir ~/.qlib/stock_data/qlib_data +python collector.py collector_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d ``` -### Download Data From Yahoo Finance +### Download Data ```bash -python collector.py download_data --source_dir ~/.qlib/stock_data/source +python collector.py download_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d ``` -### Normalize Yahoo Finance Data +### Normalize Data ```bash -python collector.py normalize_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize +python collector.py normalize_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region CN ``` -### Manual Ajust Yahoo Finance Data - +### Help ```bash -python collector.py manual_adj_data --normalize_dir ~/.qlib/stock_data/normalize +python collector.py collector_data --help ``` -### Dump Yahoo Finance Data +## Parameters -```bash -python collector.py dump_data --normalize_dir ~/.qlib/stock_data/normalize_dir --qlib_dir ~/.qlib/stock_data/qlib_data -``` +- interval: 1m or 1d +- region: CN or US diff --git a/scripts/data_collector/yahoo/collector.py b/scripts/data_collector/yahoo/collector.py index 9456c6bc3..f374e5fb8 100644 --- a/scripts/data_collector/yahoo/collector.py +++ b/scripts/data_collector/yahoo/collector.py @@ -1,8 +1,12 @@ # Copyright (c) 
Microsoft Corporation. # Licensed under the MIT License. +import abc import sys +import copy import time +import datetime +import importlib from pathlib import Path from concurrent.futures import ThreadPoolExecutor, as_completed @@ -16,30 +20,81 @@ from yahooquery import Ticker CUR_DIR = Path(__file__).resolve().parent sys.path.append(str(CUR_DIR.parent.parent)) -from dump_bin import DumpData -from data_collector.utils import get_hs_calendar_list as get_calendar_list, get_hs_stock_symbols +from data_collector.utils import get_calendar_list, get_hs_stock_symbols, get_us_stock_symbols -INDEX_BENCH_URL = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid=1.{index_code}&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg=19900101&end=20220101" -MIN_NUMBERS_TRADING = 252 / 4 +INDEX_BENCH_URL = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid=1.{index_code}&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg={begin}&end={end}" +REGION_CN = "CN" +REGION_US = "US" class YahooCollector: - def __init__(self, save_dir: [str, Path], max_workers=4, asynchronous=False, max_collector_count=5, delay=0): + START_DATETIME = pd.Timestamp("2000-01-01") + HIGH_FREQ_START_DATETIME = pd.Timestamp(datetime.datetime.now() - pd.Timedelta(days=5 * 5)) + END_DATETIME = pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1)) + def __init__( + self, + save_dir: [str, Path], + start=None, + end=None, + interval="1d", + max_workers=4, + max_collector_count=5, + delay=0, + check_data_length: bool = False, + ): + """ + + Parameters + ---------- + save_dir: str + stock save dir + max_workers: int + workers, default 4 + max_collector_count: int + default 5 + delay: float + time.sleep(delay), default 0 + interval: str + freq, value from [1m, 1d], default 1m + start: str + start datetime, default None + end: str + end datetime, default None + check_data_length: bool + check 
data length, by default False + """ self.save_dir = Path(save_dir).expanduser().resolve() self.save_dir.mkdir(parents=True, exist_ok=True) self._delay = delay - self._stock_list = None + self.stock_list = sorted(set(self.get_stock_list())) self.max_workers = max_workers - self._asynchronous = asynchronous self._max_collector_count = max_collector_count self._mini_symbol_map = {} + self._interval = interval + self._check_small_data = check_data_length + self._start_datetime = pd.Timestamp(start) if start else self.START_DATETIME + self._end_datetime = pd.Timestamp(end) if end else self.END_DATETIME + if self._interval == "1m": + self._start_datetime = max(self._start_datetime, self.HIGH_FREQ_START_DATETIME) + elif self._interval == "1d": + self._start_datetime = max(self._start_datetime, self.START_DATETIME) + else: + raise ValueError(f"interval error: {self._interval}") + + self._end_datetime = min(self._end_datetime, self.END_DATETIME) @property - def stock_list(self): - if self._stock_list is None: - self._stock_list = get_hs_stock_symbols() - return self._stock_list + @abc.abstractmethod + def min_numbers_trading(self): + # daily, one year: 252 / 4 + # us 1min, a week: 6.5 * 60 * 5 + # cn 1min, a week: 4 * 60 * 5 + raise NotImplementedError("") + + @abc.abstractmethod + def get_stock_list(self): + raise NotImplementedError("") def _sleep(self): time.sleep(self._delay) @@ -57,63 +112,85 @@ class YahooCollector: if df.empty: raise ValueError("df is empty") - symbol_s = symbol.split(".") - symbol = f"sh{symbol_s[0]}" if symbol_s[-1] == "ss" else f"sz{symbol_s[0]}" + symbol = self.normailze_symbol(symbol) stock_path = self.save_dir.joinpath(f"{symbol}.csv") df["symbol"] = symbol df.to_csv(stock_path, index=False) - def _temp_save_small_data(self, symbol, df): - if len(df) <= MIN_NUMBERS_TRADING: - logger.warning(f"the number of trading days of {symbol} is less than {MIN_NUMBERS_TRADING}!") + def _save_small_data(self, symbol, df): + if len(df) <= 
self.min_numbers_trading: + logger.warning(f"the number of trading days of {symbol} is less than {self.min_numbers_trading}!") _temp = self._mini_symbol_map.setdefault(symbol, []) _temp.append(df.copy()) + return symbol else: if symbol in self._mini_symbol_map: self._mini_symbol_map.pop(symbol) + return None + + def _get_from_remote(self, symbol): + if self._interval == "1d": + self._sleep() + try: + resp = Ticker(symbol, asynchronous=False).history( + interval=self._interval, start=self._start_datetime, end=self._end_datetime + ) + except Exception as e: + logger.warning(f"{symbol}-{self._interval}-{self._start_datetime}-{self._end_datetime}:{e}") + resp = None + yield resp + elif self._interval == "1m": + _res = [] + for _start in pd.date_range(self._start_datetime, self._end_datetime + pd.Timedelta(days=-1)): + _end = _start + pd.Timedelta(days=1) + self._sleep() + try: + resp = Ticker(symbol, asynchronous=False).history(interval=self._interval, start=_start, end=_end) + if isinstance(resp, pd.DataFrame): + _res.append(resp) + except Exception as e: + logger.warning(f"{symbol}-{self._interval}-{_start}-{_end}:{e}") + if _res: + yield pd.concat(_res, sort=False).sort_values(["symbol", "date"]) + else: + yield None + else: + raise ValueError(f"cannot support {self._interval}") def _collector(self, stock_list): error_symbol = [] with ThreadPoolExecutor(max_workers=self.max_workers) as worker: futures = {} - p_bar = tqdm(total=len(stock_list)) - for symbols in [stock_list[i : i + self.max_workers] for i in range(0, len(stock_list), self.max_workers)]: - self._sleep() - resp = Ticker(symbols, asynchronous=self._asynchronous, max_workers=self.max_workers).history( - period="max" - ) - if isinstance(resp, dict): - for symbol, df in resp.items(): - if isinstance(df, pd.DataFrame): - self._temp_save_small_data(self, df) - futures[ - worker.submit( - self.save_stock, symbol, df.reset_index().rename(columns={"index": "date"}) - ) - ] = symbol - else: - 
error_symbol.append(symbol) - else: - for symbol, df in resp.reset_index().groupby("symbol"): - self._temp_save_small_data(self, df) - futures[worker.submit(self.save_stock, symbol, df)] = symbol - p_bar.update(self.max_workers) - p_bar.close() - - with tqdm(total=len(futures.values())) as p_bar: - for future in as_completed(futures): - try: - future.result() - except Exception as e: - logger.error(e) - error_symbol.append(futures[future]) - p_bar.update() + for _symbol in tqdm(stock_list): + for _resp in self._get_from_remote(_symbol): + if isinstance(_resp, pd.DataFrame): + df = _resp.reset_index() + if self._check_small_data: + if self._save_small_data(_symbol, df) is not None: + error_symbol.append(_symbol) + futures[worker.submit(self.save_stock, _symbol, df)] = _symbol + elif isinstance(_resp, dict): + if "timestamp" in _resp[_symbol]: + logger.warning(_resp[_symbol]) + error_symbol.append(_symbol) + elif _resp is None: + error_symbol.append(_symbol) + else: + if not (("1m data not available for" in _resp) or ("Data doesn't exist for" in _resp)): + error_symbol.append(_symbol) + logger.info("save stock data......") + for future in tqdm(as_completed(futures)): + try: + future.result() + except Exception as e: + logger.error(e) + error_symbol.append(futures[future]) print(error_symbol) logger.info(f"error symbol nums: {len(error_symbol)}") logger.info(f"current get symbol nums: {len(stock_list)}") error_symbol.extend(self._mini_symbol_map.keys()) - return error_symbol + return sorted(set(error_symbol)) def collector_data(self): """collector data""" @@ -126,20 +203,51 @@ class YahooCollector: stock_list = self._collector(stock_list) logger.info(f"{i+1} finish.") for _symbol, _df_list in self._mini_symbol_map.items(): - self.save_stock(_symbol, max(_df_list, key=len)) + self.save_stock(_symbol, pd.concat(_df_list, sort=False).drop_duplicates(["date"]).sort_values(["date"])) - logger.warning(f"less than {MIN_NUMBERS_TRADING} stock list: 
{list(self._mini_symbol_map.keys())}") + logger.warning(f"less than {self.min_numbers_trading} stock list: {list(self._mini_symbol_map.keys())}") self.download_index_data() + @abc.abstractmethod + def download_index_data(self): + """download index data""" + raise NotImplementedError("rewrite download_index_data") + + @abc.abstractmethod + def normailze_symbol(self, symbol: str): + """normalize symbol""" + raise NotImplementedError("rewrite normalize_symbol") + + +class YahooCollectorCN(YahooCollector): + @property + def min_numbers_trading(self): + if self._interval == "1m": + return 60 * 4 * 5 + elif self._interval == "1d": + return 252 / 4 + + def get_stock_list(self): + logger.info("get HS stock symbos......") + symbols = get_hs_stock_symbols() + logger.info(f"get {len(symbols)} symbols.") + return symbols + def download_index_data(self): # TODO: from MSN + # FIXME: 1m + _format = "%Y%m%d" + _begin = self._start_datetime.strftime(_format) + _end = (self._end_datetime + pd.Timedelta(days=-1)).strftime(_format) for _index_name, _index_code in {"csi300": "000300", "csi100": "000903"}.items(): logger.info(f"get bench data: {_index_name}({_index_code})......") df = pd.DataFrame( map( lambda x: x.split(","), - requests.get(INDEX_BENCH_URL.format(index_code=_index_code)).json()["data"]["klines"], + requests.get(INDEX_BENCH_URL.format(index_code=_index_code, begin=_begin, end=_end)).json()["data"][ + "klines" + ], ) ) df.columns = ["date", "open", "close", "high", "low", "volume", "money", "change"] @@ -148,59 +256,71 @@ class YahooCollector: df["adjclose"] = df["close"] df.to_csv(self.save_dir.joinpath(f"sh{_index_code}.csv"), index=False) + def normailze_symbol(self, symbol): + symbol_s = symbol.split(".") + symbol = f"sh{symbol_s[0]}" if symbol_s[-1] == "ss" else f"sz{symbol_s[0]}" + return symbol -class Run: - def __init__(self, source_dir=None, normalize_dir=None, qlib_dir=None, max_workers=4): + +class YahooCollectorUS(YahooCollector): + @property + def 
min_numbers_trading(self): + if self._interval == "1m": + return 60 * 6.5 * 5 + elif self._interval == "1d": + return 252 / 4 + + def get_stock_list(self): + logger.info("get US stock symbols......") + symbols = get_us_stock_symbols(qlib_data_path="/data1/data/yahoo_staock_data/backup/us_1d_qlib") + [ + "^GSPC", + "^NDX", + "^DJI", + ] + logger.info(f"get {len(symbols)} symbols.") + return symbols + + def download_index_data(self): + pass + + def normailze_symbol(self, symbol): + return symbol.upper() + + +class YahooNormalize: + COLUMNS = ["open", "close", "high", "low", "volume"] + + def __init__(self, source_dir: [str, Path], target_dir: [str, Path], max_workers: int = 16): """ Parameters ---------- - source_dir: str - The directory where the raw data collected from the Internet is saved, default "Path(__file__).parent/source" - normalize_dir: str - Directory for normalize data, default "Path(__file__).parent/normalize" - qlib_dir: str - qlib data dir; usage of provider_uri, default "Path(__file__).parent/qlib_data" + source_dir: str or Path + The directory where the raw data collected from the Internet is saved + target_dir: str or Path + Directory for normalize data max_workers: int - Concurrent number, default is 4 + Concurrent number, default is 16 """ - if source_dir is None: - source_dir = CUR_DIR.joinpath("source") - self.source_dir = Path(source_dir).expanduser().resolve() - self.source_dir.mkdir(parents=True, exist_ok=True) - - if normalize_dir is None: - normalize_dir = CUR_DIR.joinpath("normalize") - self.normalize_dir = Path(normalize_dir).expanduser().resolve() - self.normalize_dir.mkdir(parents=True, exist_ok=True) - - if qlib_dir is None: - qlib_dir = CUR_DIR.joinpath("qlib_data") - self.qlib_dir = Path(qlib_dir).expanduser().resolve() - self.qlib_dir.mkdir(parents=True, exist_ok=True) - - self.max_workers = max_workers + if not (source_dir and target_dir): + raise ValueError("source_dir and target_dir cannot be None") + self._source_dir = 
Path(source_dir).expanduser() + self._target_dir = Path(target_dir).expanduser() + self._max_workers = max_workers + self._calendar_list = self._get_calendar_list() def normalize_data(self): - """normalize data + logger.info("normalize data......") - Examples - --------- - $ python collector.py normalize_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize - - """ - - def _normalize(file_path: Path): - columns = ["open", "close", "high", "low", "volume"] - df = pd.read_csv(file_path) + def _normalize(source_path: Path): + columns = copy.deepcopy(self.COLUMNS) + df = pd.read_csv(source_path) df.set_index("date", inplace=True) df.index = pd.to_datetime(df.index) df = df[~df.index.duplicated(keep="first")] - - # using China stock market data calendar - df = df.reindex(pd.Index(get_calendar_list("ALL"))) + if self._calendar_list is not None: + df = df.reindex(pd.DataFrame(index=self._calendar_list).loc[df.index.min() : df.index.max()].index) df.sort_index(inplace=True) - df.loc[(df["volume"] <= 0) | np.isnan(df["volume"]), set(df.columns) - {"symbol"}] = np.nan df["factor"] = df["adjclose"] / df["close"] for _col in columns: @@ -213,22 +333,17 @@ class Run: columns += ["change", "factor"] df.loc[(df["volume"] <= 0) | np.isnan(df["volume"]), columns] = np.nan df.index.names = ["date"] - df.loc[:, columns].to_csv(self.normalize_dir.joinpath(file_path.name)) + df.loc[:, columns].to_csv(self._target_dir.joinpath(source_path.name)) - with ThreadPoolExecutor(max_workers=self.max_workers) as worker: - file_list = list(self.source_dir.glob("*.csv")) + with ThreadPoolExecutor(max_workers=self._max_workers) as worker: + file_list = list(self._source_dir.glob("*.csv")) with tqdm(total=len(file_list)) as p_bar: for _ in worker.map(_normalize, file_list): p_bar.update() def manual_adj_data(self): - """manual adjust data - - Examples - -------- - $ python collector.py manual_adj_data --normalize_dir ~/.qlib/stock_data/normalize - - """ + """adjust 
data""" + logger.info("manual adjust data......") def _adj(file_path: Path): df = pd.read_csv(file_path) @@ -244,59 +359,148 @@ class Run: df[_col] = df[_col] / _close else: pass - df.reset_index().to_csv(self.normalize_dir.joinpath(file_path.name), index=False) + df.reset_index().to_csv(self._target_dir.joinpath(file_path.name), index=False) - with ThreadPoolExecutor(max_workers=self.max_workers) as worker: - file_list = list(self.normalize_dir.glob("*.csv")) + with ThreadPoolExecutor(max_workers=self._max_workers) as worker: + file_list = list(self._target_dir.glob("*.csv")) with tqdm(total=len(file_list)) as p_bar: for _ in worker.map(_adj, file_list): p_bar.update() - def dump_data(self): - """dump yahoo data + def normalize(self): + self.normalize_data() + self.manual_adj_data() - Examples - --------- - $ python collector.py dump_data --normalize_dir ~/.qlib/stock_data/normalize_dir --qlib_dir ~/.qlib/stock_data/qlib_data + @abc.abstractmethod + def _get_calendar_list(self): + """Get benchmark calendar""" + raise NotImplementedError("") + +class YahooNormalizeUS(YahooNormalize): + def _get_calendar_list(self): + # TODO: from MSN + return get_calendar_list("US_ALL") + + +class YahooNormalizeCN(YahooNormalize): + def _get_calendar_list(self): + # TODO: from MSN + return get_calendar_list("ALL") + + +class Run: + def __init__(self, source_dir=None, normalize_dir=None, max_workers=4, region=REGION_CN): """ - DumpData(csv_path=self.normalize_dir, qlib_dir=self.qlib_dir, works=self.max_workers).dump( - include_fields="close,open,high,low,volume,change,factor" - ) - def download_data(self, asynchronous=False, max_collector_count=5, delay=0): + Parameters + ---------- + source_dir: str + The directory where the raw data collected from the Internet is saved, default "Path(__file__).parent/source" + normalize_dir: str + Directory for normalize data, default "Path(__file__).parent/normalize" + max_workers: int + Concurrent number, default is 4 + region: str + region, 
value from ["CN", "US"], default "CN" + """ + if source_dir is None: + source_dir = CUR_DIR.joinpath("source") + self.source_dir = Path(source_dir).expanduser().resolve() + self.source_dir.mkdir(parents=True, exist_ok=True) + + if normalize_dir is None: + normalize_dir = CUR_DIR.joinpath("normalize") + self.normalize_dir = Path(normalize_dir).expanduser().resolve() + self.normalize_dir.mkdir(parents=True, exist_ok=True) + + self._cur_module = importlib.import_module("collector") + self.max_workers = max_workers + self.region = region + + def download_data( + self, max_collector_count=5, delay=0, start=None, end=None, interval="1d", check_data_length=True + ): """download data from Internet + Parameters + ---------- + max_collector_count: int + default 5 + delay: float + time.sleep(delay), default 0 + interval: str + freq, value from [1m, 1d], default 1m + start: str + start datetime, default "2000-01-01" + end: str + end datetime, default ``pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))`` + check_data_length: bool + check data length, by default True + Examples --------- - $ python collector.py download_data --source_dir ~/.qlib/stock_data/source - + # get daily data + $ python collector.py download_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d + # get 1m data + $ python collector.py download_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1m """ - YahooCollector( + + _class = getattr(self._cur_module, f"YahooCollector{self.region.upper()}") + _class( self.source_dir, max_workers=self.max_workers, - asynchronous=asynchronous, max_collector_count=max_collector_count, delay=delay, + start=start, + end=end, + interval=interval, + check_data_length=check_data_length, ).collector_data() - def download_index_data(self): - YahooCollector(self.source_dir).download_index_data() + def normalize_data(self): + """normalize 
data - def download_bench_data(self): - """download bench stock data(SH000300)""" + Examples + --------- + $ python collector.py normalize_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region CN + """ + _class = getattr(self._cur_module, f"YahooNormalize{self.region.upper()}") + _class(self.source_dir, self.normalize_dir, self.max_workers).normalize() - def collector_data(self): - """download -> normalize -> dump data + def collector_data( + self, max_collector_count=5, delay=0, start=None, end=None, interval="1d", check_data_length=False + ): + """download -> normalize + + Parameters + ---------- + max_collector_count: int + default 5 + delay: float + time.sleep(delay), default 0 + interval: str + freq, value from [1m, 1d], default 1m + start: str + start datetime, default "2000-01-01" + end: str + end datetime, default ``pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))`` + check_data_length: bool + check data length, by default False Examples ------- - $ python collector.py collector_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize_dir --qlib_dir ~/.qlib/stock_data/qlib_data + python collector.py collector_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d """ - self.download_data() + self.download_data( + max_collector_count=max_collector_count, + delay=delay, + start=start, + end=end, + interval=interval, + check_data_length=check_data_length, + ) self.normalize_data() - self.manual_adj_data() - self.dump_data() if __name__ == "__main__":