From c825c99c2ceb4b0dd23ad121231d7a8399e6b9de Mon Sep 17 00:00:00 2001 From: zhupr Date: Sat, 26 Sep 2020 23:36:43 +0800 Subject: [PATCH] Fix data collector --- scripts/data_collector/yahoo/README.md | 8 ++++ scripts/data_collector/yahoo/collector.py | 47 +++++++++++++++++++++-- scripts/get_data.py | 4 +- tests/dataset_tests/test_dataset.py | 20 +++++++--- 4 files changed, 69 insertions(+), 10 deletions(-) diff --git a/scripts/data_collector/yahoo/README.md b/scripts/data_collector/yahoo/README.md index 958d74854..956a736ea 100644 --- a/scripts/data_collector/yahoo/README.md +++ b/scripts/data_collector/yahoo/README.md @@ -1,5 +1,13 @@ # Collect Data From Yahoo Finance +> *Please pay **ATTENTION** that the data is collected from [Yahoo Finance](https://finance.yahoo.com/lookup) and the data might not be perfect. We recommend users to prepare their own data if they have high-quality dataset. For more information, users can refer to the [related document](https://qlib.readthedocs.io/en/latest/component/data.html#converting-csv-format-into-qlib-format)* + + +> **Examples of abnormal data** + +- [SH600000](https://finance.yahoo.com/quote/600000.SS/history?period1=1147046400&period2=1147478400&interval=1d&filter=history&frequency=1d) +- [SH600018](https://finance.yahoo.com/quote/600018.SS/history?period1=1158883200&period2=1161907200&interval=1d&filter=history&frequency=1d) + ## Requirements ```bash diff --git a/scripts/data_collector/yahoo/collector.py b/scripts/data_collector/yahoo/collector.py index b652311a6..96ea8d632 100644 --- a/scripts/data_collector/yahoo/collector.py +++ b/scripts/data_collector/yahoo/collector.py @@ -2,6 +2,7 @@ # Licensed under the MIT License. import sys +import time from pathlib import Path from concurrent.futures import ThreadPoolExecutor, as_completed @@ -19,17 +20,20 @@ from dump_bin import DumpData from data_collector.utils import get_hs_calendar_list as get_calendar_list, get_hs_stock_symbols CSI300_BENCH_URL = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid=1.000300&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg=19900101&end=20220101" +MIN_NUMBERS_TRADING = 252 / 4 class YahooCollector: - def __init__(self, save_dir: [str, Path], max_workers=4, asynchronous=True, max_collector_count=3): + def __init__(self, save_dir: [str, Path], max_workers=4, asynchronous=False, max_collector_count=5, delay=0): self.save_dir = Path(save_dir).expanduser().resolve() self.save_dir.mkdir(parents=True, exist_ok=True) + self._delay = delay self._stock_list = None self.max_workers = max_workers self._asynchronous = asynchronous self._max_collector_count = max_collector_count + self._mini_symbol_map = {} @property def stock_list(self): @@ -37,6 +41,9 @@ class YahooCollector: self._stock_list = get_hs_stock_symbols() return self._stock_list + def _sleep(self): + time.sleep(self._delay) + def save_stock(self, symbol, df: pd.DataFrame): """save stock data to file @@ -56,6 +63,15 @@ class YahooCollector: df["symbol"] = symbol df.to_csv(stock_path, index=False) + def _temp_save_small_data(self, symbol, df): + if len(df) <= MIN_NUMBERS_TRADING: + logger.warning(f"the number of trading days of {symbol} is less than {MIN_NUMBERS_TRADING}!") + _temp = self._mini_symbol_map.setdefault(symbol, []) + _temp.append(df.copy()) + else: + if symbol in self._mini_symbol_map: + self._mini_symbol_map.pop(symbol) + def _collector(self, stock_list): error_symbol = [] @@ -63,12 +79,14 @@ class YahooCollector: futures = {} p_bar = tqdm(total=len(stock_list)) for symbols in [stock_list[i : i + self.max_workers] for i in range(0, len(stock_list), self.max_workers)]: + self._sleep() resp = Ticker(symbols, asynchronous=self._asynchronous, max_workers=self.max_workers).history( period="max" ) if isinstance(resp, dict): for symbol, df in resp.items(): if isinstance(df, pd.DataFrame): + self._temp_save_small_data(self, df) futures[ worker.submit( self.save_stock, symbol, df.reset_index().rename(columns={"index": "date"}) @@ -78,6 +96,7 @@ class YahooCollector: error_symbol.append(symbol) else: for symbol, df in resp.reset_index().groupby("symbol"): + self._temp_save_small_data(self, df) futures[worker.submit(self.save_stock, symbol, df)] = symbol p_bar.update(self.max_workers) p_bar.close() @@ -93,6 +112,7 @@ class YahooCollector: print(error_symbol) logger.info(f"error symbol nums: {len(error_symbol)}") logger.info(f"current get symbol nums: {len(stock_list)}") + error_symbol.extend(self._mini_symbol_map.keys()) return error_symbol def collector_data(self): @@ -107,7 +127,14 @@ class YahooCollector: logger.info(f"getting data: {i+1}") stock_list = self._collector(stock_list) logger.info(f"{i+1} finish.") + for _symbol, _df_list in self._mini_symbol_map.items(): + self.save_stock(_symbol, max(_df_list, key=len)) + logger.warning(f"less than {MIN_NUMBERS_TRADING} stock list: {list(self._mini_symbol_map.keys())}") + + self.download_csi300_data() + + def download_csi300_data(self): # TODO: from MSN logger.info(f"get bench data: csi300(SH000300)......") df = pd.DataFrame(map(lambda x: x.split(","), requests.get(CSI300_BENCH_URL).json()["data"]["klines"])) @@ -164,6 +191,7 @@ class Run: df = pd.read_csv(file_path) df.set_index("date", inplace=True) df.index = pd.to_datetime(df.index) + df = df[~df.index.duplicated(keep="first")] # using China stock market data calendar df = df.reindex(pd.Index(get_calendar_list())) @@ -232,7 +260,7 @@ class Run: include_fields="close,open,high,low,volume,change,factor" ) - def download_data(self): + def download_data(self, asynchronous=False, max_collector_count=5, delay=0): """download data from Internet Examples @@ -240,7 +268,20 @@ class Run: $ python collector.py download_data --source_dir ~/.qlib/stock_data/source """ - YahooCollector(self.source_dir, max_workers=self.max_workers).collector_data() + YahooCollector( + self.source_dir, + max_workers=self.max_workers, + asynchronous=asynchronous, + max_collector_count=max_collector_count, + delay=delay, + ).collector_data() + + def download_csi300_data(self): + YahooCollector(self.source_dir).download_csi300_data() + + def download_bench_data(self): + """download bench stock data(SH000300) + """ def collector_data(self): """download -> normalize -> dump data diff --git a/scripts/get_data.py b/scripts/get_data.py index d27953631..d20a251ed 100644 --- a/scripts/get_data.py +++ b/scripts/get_data.py @@ -53,7 +53,7 @@ class GetData: for _file in tqdm(zp.namelist()): zp.extract(_file, str(target_dir.resolve())) - def qlib_data_cn(self, target_dir="~/.qlib/qlib_data/cn_data", version="v1"): + def qlib_data_cn(self, target_dir="~/.qlib/qlib_data/cn_data", version="latest"): """download cn qlib data from remote Parameters @@ -61,7 +61,7 @@ class GetData: target_dir: str data save directory version: str - data version, value from [v0, v1], by default v1 + data version, value from [v0, v1, ..., latest], by default latest Examples --------- diff --git a/tests/dataset_tests/test_dataset.py b/tests/dataset_tests/test_dataset.py index 4a62ba79d..b22eb81ed 100644 --- a/tests/dataset_tests/test_dataset.py +++ b/tests/dataset_tests/test_dataset.py @@ -16,15 +16,25 @@ class TestDataset(unittest.TestCase): close_p = D.features(D.instruments('csi300'), ['$close']) size = close_p.groupby('datetime').size() cnt = close_p.groupby('datetime').count() + size_desc = size.describe(percentiles=np.arange(0.1, 0.9, 0.1)) + cnt_desc = cnt.describe(percentiles=np.arange(0.1, 0.9, 0.1)) - print(size.describe(percentiles=np.arange(0.1, 0.9, 0.1))) - print(cnt.describe(percentiles=np.arange(0.1, 0.9, 0.1))) - # TODO: assert + print(size_desc) + print(cnt_desc) + + self.assertLessEqual(size_desc.loc["max"][0], 305, "Excessive number of CSI300 constituent stocks") + self.assertLessEqual(size_desc.loc["80%"][0], 290, "Insufficient number of CSI300 constituent stocks") + + self.assertLessEqual(cnt_desc.loc["max"][0], 305, "Excessive number of CSI300 constituent stocks") + self.assertEqual(cnt_desc.loc["80%"][0], 300, "Insufficient number of CSI300 constituent stocks") def testClose(self): close_p = D.features(D.instruments('csi300'), ['Ref($close, 1)/$close - 1']) - print(close_p.describe(percentiles=np.arange(0.1, 0.9, 0.1))) - # TODO: assert + close_desc = close_p.describe(percentiles=np.arange(0.1, 0.9, 0.1)) + print(close_desc) + self.assertLessEqual(abs(close_desc.loc["80%"][0]), 0.1, "Close value is abnormal") + self.assertLessEqual(abs(close_desc.loc["max"][0]), 0.2, "Close value is abnormal") + self.assertGreaterEqual(abs(close_desc.loc["min"][0]), -0.2, "Close value is abnormal") if __name__ == '__main__':