diff --git a/scripts/data_collector/yahoo/collector.py b/scripts/data_collector/yahoo/collector.py index f374e5fb8..f5b62ded1 100644 --- a/scripts/data_collector/yahoo/collector.py +++ b/scripts/data_collector/yahoo/collector.py @@ -17,6 +17,7 @@ import pandas as pd from tqdm import tqdm from loguru import logger from yahooquery import Ticker +from dateutil.tz import tzlocal CUR_DIR = Path(__file__).resolve().parent sys.path.append(str(CUR_DIR.parent.parent)) @@ -42,6 +43,7 @@ class YahooCollector: max_collector_count=5, delay=0, check_data_length: bool = False, + limit_nums: int = None, ): """ @@ -63,18 +65,25 @@ class YahooCollector: end datetime, default None check_data_length: bool check data length, by default False + limit_nums: int + using for debug, by default None """ self.save_dir = Path(save_dir).expanduser().resolve() self.save_dir.mkdir(parents=True, exist_ok=True) self._delay = delay self.stock_list = sorted(set(self.get_stock_list())) + if limit_nums is not None: + try: + self.stock_list = self.stock_list[: int(limit_nums)] + except Exception as e: + logger.warning(f"Cannot use limit_nums={limit_nums}, the parameter will be ignored") self.max_workers = max_workers self._max_collector_count = max_collector_count self._mini_symbol_map = {} self._interval = interval self._check_small_data = check_data_length - self._start_datetime = pd.Timestamp(start) if start else self.START_DATETIME - self._end_datetime = pd.Timestamp(end) if end else self.END_DATETIME + self._start_datetime = pd.Timestamp(str(start)) if start else self.START_DATETIME + self._end_datetime = pd.Timestamp(str(end)) if end else self.END_DATETIME if self._interval == "1m": self._start_datetime = max(self._start_datetime, self.HIGH_FREQ_START_DATETIME) elif self._interval == "1d": @@ -82,7 +91,8 @@ class YahooCollector: else: raise ValueError(f"interval error: {self._interval}") - self._end_datetime = min(self._end_datetime, self.END_DATETIME) + self._start_datetime = self.convert_datetime(self._start_datetime) + self._end_datetime = self.convert_datetime(min(self._end_datetime, self.END_DATETIME)) @property @abc.abstractmethod @@ -90,11 +100,20 @@ class YahooCollector: # daily, one year: 252 / 4 # us 1min, a week: 6.5 * 60 * 5 # cn 1min, a week: 4 * 60 * 5 - raise NotImplementedError("") + raise NotImplementedError("rewirte min_numbers_trading") @abc.abstractmethod def get_stock_list(self): - raise NotImplementedError("") + raise NotImplementedError("rewirte get_stock_list") + + @property + @abc.abstractclassmethod + def _timezone(self): + raise NotImplementedError("rewrite get_timezone") + + def convert_datetime(self, dt: pd.Timestamp): + dt = pd.Timestamp(dt, tz=self._timezone).timestamp() + return pd.Timestamp(dt, tz=tzlocal(), unit="s") def _sleep(self): time.sleep(self._delay) @@ -112,80 +131,90 @@ class YahooCollector: if df.empty: raise ValueError("df is empty") - symbol = self.normailze_symbol(symbol) + symbol = self.normalize_symbol(symbol) stock_path = self.save_dir.joinpath(f"{symbol}.csv") df["symbol"] = symbol - df.to_csv(stock_path, index=False) + if stock_path.exists(): + with stock_path.open("a") as fp: + df.to_csv(fp, index=False, header=None) + else: + with stock_path.open("w") as fp: + df.to_csv(fp, index=False) def _save_small_data(self, symbol, df): if len(df) <= self.min_numbers_trading: logger.warning(f"the number of trading days of {symbol} is less than {self.min_numbers_trading}!") _temp = self._mini_symbol_map.setdefault(symbol, []) _temp.append(df.copy()) - return symbol + return None else: if symbol in self._mini_symbol_map: self._mini_symbol_map.pop(symbol) - return None + return symbol def _get_from_remote(self, symbol): - if self._interval == "1d": + def _get_simple(start_, end_): self._sleep() try: - resp = Ticker(symbol, asynchronous=False).history( - interval=self._interval, start=self._start_datetime, end=self._end_datetime - ) + _resp = Ticker(symbol, asynchronous=False).history(interval=self._interval, start=start_, end=end_) + if isinstance(_resp, pd.DataFrame): + return _resp.reset_index() + else: + logger.warning(f"{symbol}-{self._interval}-{start_}-{end_}:{_resp}") except Exception as e: - logger.warning(f"{symbol}-{self._interval}-{self._start_datetime}-{self._end_datetime}:{e}") - resp = None - yield resp + logger.warning(f"{symbol}-{self._interval}-{start_}-{end_}:{e}") + + _result = None + if self._interval == "1d": + _result = _get_simple(self._start_datetime, self._end_datetime) elif self._interval == "1m": - _res = [] - for _start in pd.date_range(self._start_datetime, self._end_datetime + pd.Timedelta(days=-1)): - _end = _start + pd.Timedelta(days=1) - self._sleep() - try: - resp = Ticker(symbol, asynchronous=False).history(interval=self._interval, start=_start, end=_end) - if isinstance(resp, pd.DataFrame): - _res.append(resp) - except Exception as e: - logger.warning(f"{symbol}-{self._interval}-{_start}-{_end}:{e}") - if _res: - yield pd.concat(_res, sort=False).sort_values(["symbol", "date"]) + _start_date = self._start_datetime.date() + pd.Timedelta(days=1) + _end_date = self._end_datetime.date() + if _start_date >= _end_date: + _result = _get_simple(self._start_datetime, self._end_datetime) else: - yield None + _res = [] + + def _get_multi(start_, end_): + _resp = _get_simple(start_, end_) + if _resp is not None: + _res.append(_resp) + + for _s, _e in ((self._start_datetime, _start_date), (_end_date, self._end_datetime)): + _get_multi(_s, _e) + for _start in pd.date_range(_start_date, _end_date, closed="left"): + _end = _start + pd.Timedelta(days=1) + self._sleep() + _get_multi(_start, _end) + if _res: + _result = pd.concat(_res, sort=False).sort_values(["symbol", "date"]) else: raise ValueError(f"cannot support {self._interval}") + return _result + + def _get_data(self, symbol): + _result = None + df = self._get_from_remote(symbol) + if isinstance(df, pd.DataFrame): + if not df.empty: + if self._check_small_data: + if self._save_small_data(symbol, df) is not None: + _result = symbol + self.save_stock(symbol, df) + else: + _result = symbol + self.save_stock(symbol, df) + return _result def _collector(self, stock_list): error_symbol = [] - with ThreadPoolExecutor(max_workers=self.max_workers) as worker: - futures = {} - for _symbol in tqdm(stock_list): - for _resp in self._get_from_remote(_symbol): - if isinstance(_resp, pd.DataFrame): - df = _resp.reset_index() - if self._check_small_data: - if self._save_small_data(_symbol, df) is not None: - error_symbol.append(_symbol) - futures[worker.submit(self.save_stock, _symbol, df)] = _symbol - elif isinstance(_resp, dict): - if "timestamp" in _resp[_symbol]: - logger.warning(_resp[_symbol]) - error_symbol.append(_symbol) - elif _resp is None: + with ThreadPoolExecutor(max_workers=self.max_workers) as executor: + with tqdm(total=len(stock_list)) as p_bar: + for _symbol, _result in zip(stock_list, executor.map(self._get_data, stock_list)): + if _result is None: error_symbol.append(_symbol) - else: - if not (("1m data not available for" in _resp) or ("Data doesn't exist for" in _resp)): - error_symbol.append(_symbol) - logger.info("save stock data......") - for future in tqdm(as_completed(futures)): - try: - future.result() - except Exception as e: - logger.error(e) - error_symbol.append(futures[future]) + p_bar.update() print(error_symbol) logger.info(f"error symbol nums: {len(error_symbol)}") logger.info(f"current get symbol nums: {len(stock_list)}") @@ -204,8 +233,9 @@ class YahooCollector: logger.info(f"{i+1} finish.") for _symbol, _df_list in self._mini_symbol_map.items(): self.save_stock(_symbol, pd.concat(_df_list, sort=False).drop_duplicates(["date"]).sort_values(["date"])) - - logger.warning(f"less than {self.min_numbers_trading} stock list: {list(self._mini_symbol_map.keys())}") + if self._mini_symbol_map: + logger.warning(f"less than {self.min_numbers_trading} stock list: {list(self._mini_symbol_map.keys())}") + logger.info(f"total {len(self.stock_list)}, error: {len(set(stock_list))}") self.download_index_data() @@ -215,7 +245,7 @@ class YahooCollector: raise NotImplementedError("rewrite download_index_data") @abc.abstractmethod - def normailze_symbol(self, symbol: str): + def normalize_symbol(self, symbol: str): """normalize symbol""" raise NotImplementedError("rewrite normalize_symbol") @@ -237,30 +267,41 @@ class YahooCollectorCN(YahooCollector): def download_index_data(self): # TODO: from MSN # FIXME: 1m - _format = "%Y%m%d" - _begin = self._start_datetime.strftime(_format) - _end = (self._end_datetime + pd.Timedelta(days=-1)).strftime(_format) - for _index_name, _index_code in {"csi300": "000300", "csi100": "000903"}.items(): - logger.info(f"get bench data: {_index_name}({_index_code})......") - df = pd.DataFrame( - map( - lambda x: x.split(","), - requests.get(INDEX_BENCH_URL.format(index_code=_index_code, begin=_begin, end=_end)).json()["data"][ - "klines" - ], - ) - ) - df.columns = ["date", "open", "close", "high", "low", "volume", "money", "change"] - df["date"] = pd.to_datetime(df["date"]) - df = df.astype(float, errors="ignore") - df["adjclose"] = df["close"] - df.to_csv(self.save_dir.joinpath(f"sh{_index_code}.csv"), index=False) + if self._interval == "1d": + _format = "%Y%m%d" + _begin = self._start_datetime.strftime(_format) + _end = (self._end_datetime + pd.Timedelta(days=-1)).strftime(_format) + for _index_name, _index_code in {"csi300": "000300", "csi100": "000903"}.items(): + logger.info(f"get bench data: {_index_name}({_index_code})......") + try: + df = pd.DataFrame( + map( + lambda x: x.split(","), + requests.get(INDEX_BENCH_URL.format(index_code=_index_code, begin=_begin, end=_end)).json()[ + "data" + ]["klines"], + ) + ) + except Exception as e: + logger.warning(f"get {_index_name} error: {e}") + continue + df.columns = ["date", "open", "close", "high", "low", "volume", "money", "change"] + df["date"] = pd.to_datetime(df["date"]) + df = df.astype(float, errors="ignore") + df["adjclose"] = df["close"] + df.to_csv(self.save_dir.joinpath(f"sh{_index_code}.csv"), index=False) + else: + logger.warning(f"{self.__class__.__name__} {self._interval} does not support: downlaod_index_data") - def normailze_symbol(self, symbol): + def normalize_symbol(self, symbol): symbol_s = symbol.split(".") symbol = f"sh{symbol_s[0]}" if symbol_s[-1] == "ss" else f"sz{symbol_s[0]}" return symbol + @property + def _timezone(self): + return "Asia/Shanghai" + class YahooCollectorUS(YahooCollector): @property @@ -283,9 +324,13 @@ class YahooCollectorUS(YahooCollector): def download_index_data(self): pass - def normailze_symbol(self, symbol): + def normalize_symbol(self, symbol): return symbol.upper() + @property + def _timezone(self): + return "America/New_York" + class YahooNormalize: COLUMNS = ["open", "close", "high", "low", "volume"] @@ -419,7 +464,14 @@ class Run: self.region = region def download_data( - self, max_collector_count=5, delay=0, start=None, end=None, interval="1d", check_data_length=True + self, + max_collector_count=5, + delay=0, + start=None, + end=None, + interval="1d", + check_data_length=False, + limit_nums=None, ): """download data from Internet @@ -436,8 +488,9 @@ class Run: end: str end datetime, default ``pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))`` check_data_length: bool - check data length, by default True - + check data length, by default False + limit_nums: int + using for debug, by default None Examples --------- # get daily data @@ -456,6 +509,7 @@ class Run: end=end, interval=interval, check_data_length=check_data_length, + limit_nums=limit_nums, ).collector_data() def normalize_data(self): @@ -469,7 +523,14 @@ class Run: _class(self.source_dir, self.normalize_dir, self.max_workers).normalize() def collector_data( - self, max_collector_count=5, delay=0, start=None, end=None, interval="1d", check_data_length=False + self, + max_collector_count=5, + delay=0, + start=None, + end=None, + interval="1d", + check_data_length=False, + limit_nums=None, ): """download -> normalize @@ -487,7 +548,8 @@ class Run: end datetime, default ``pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))`` check_data_length: bool check data length, by default False - + limit_nums: int + using for debug, by default None Examples ------- python collector.py collector_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d @@ -499,6 +561,7 @@ class Run: end=end, interval=interval, check_data_length=check_data_length, + limit_nums=limit_nums, ) self.normalize_data()