From 6f150f3fd6f9d48a219e6e9045b285e7fb84c436 Mon Sep 17 00:00:00 2001 From: zhupr Date: Fri, 4 Jun 2021 22:28:42 +0800 Subject: [PATCH 01/44] Add YahooCollector support for extend data --- scripts/data_collector/base.py | 14 +- scripts/data_collector/cn_index/collector.py | 5 +- scripts/data_collector/yahoo/collector.py | 210 ++++++++++++++++++- 3 files changed, 211 insertions(+), 18 deletions(-) diff --git a/scripts/data_collector/base.py b/scripts/data_collector/base.py index 12983f6a5..cb51f9b22 100644 --- a/scripts/data_collector/base.py +++ b/scripts/data_collector/base.py @@ -226,11 +226,7 @@ class BaseCollector(abc.ABC): class BaseNormalize(abc.ABC): - def __init__( - self, - date_field_name: str = "date", - symbol_field_name: str = "symbol", - ): + def __init__(self, date_field_name: str = "date", symbol_field_name: str = "symbol", **kwargs): """ Parameters @@ -265,6 +261,7 @@ class Normalize: max_workers: int = 16, date_field_name: str = "date", symbol_field_name: str = "symbol", + **kwargs, ): """ @@ -291,7 +288,9 @@ class Normalize: self._max_workers = max_workers - self._normalize_obj = normalize_class(date_field_name=date_field_name, symbol_field_name=symbol_field_name) + self._normalize_obj = normalize_class( + date_field_name=date_field_name, symbol_field_name=symbol_field_name, **kwargs + ) def _executor(self, file_path: Path): file_path = Path(file_path) @@ -404,7 +403,7 @@ class BaseRun(abc.ABC): limit_nums=limit_nums, ).collector_data() - def normalize_data(self, date_field_name: str = "date", symbol_field_name: str = "symbol"): + def normalize_data(self, date_field_name: str = "date", symbol_field_name: str = "symbol", **kwargs): """normalize data Parameters @@ -426,5 +425,6 @@ class BaseRun(abc.ABC): max_workers=self.max_workers, date_field_name=date_field_name, symbol_field_name=symbol_field_name, + **kwargs, ) yc.normalize() diff --git a/scripts/data_collector/cn_index/collector.py b/scripts/data_collector/cn_index/collector.py index 5af9785ec..1f8434c58 100644 --- a/scripts/data_collector/cn_index/collector.py +++ b/scripts/data_collector/cn_index/collector.py @@ -24,7 +24,10 @@ from data_collector.utils import get_calendar_list, get_trading_date_by_shift NEW_COMPANIES_URL = "http://www.csindex.com.cn/uploads/file/autofile/cons/{index_code}cons.xls" -INDEX_CHANGES_URL = "http://www.csindex.com.cn/zh-CN/search/total?key=%E5%85%B3%E4%BA%8E%E8%B0%83%E6%95%B4%E6%B2%AA%E6%B7%B1300%E5%92%8C%E4%B8%AD%E8%AF%81%E9%A6%99%E6%B8%AF100%E7%AD%89%E6%8C%87%E6%95%B0%E6%A0%B7%E6%9C%AC%E8%82%A1%E7%9A%84%E5%85%AC%E5%91%8A" + +# INDEX_CHANGES_URL = "http://www.csindex.com.cn/zh-CN/search/total?key=%E5%85%B3%E4%BA%8E%E8%B0%83%E6%95%B4%E6%B2%AA%E6%B7%B1300%E5%92%8C%E4%B8%AD%E8%AF%81%E9%A6%99%E6%B8%AF100%E7%AD%89%E6%8C%87%E6%95%B0%E6%A0%B7%E6%9C%AC%E8%82%A1%E7%9A%84%E5%85%AC%E5%91%8A" +# 2020-11-27 Announcement title change +INDEX_CHANGES_URL = "http://www.csindex.com.cn/zh-CN/search/total?key=%E5%85%B3%E4%BA%8E%E8%B0%83%E6%95%B4%E6%B2%AA%E6%B7%B1300%E5%92%8C%E4%B8%AD%E8%AF%81%E9%A6%99%E6%B8%AF100%E7%AD%89" class CSIIndex(IndexBase): diff --git a/scripts/data_collector/yahoo/collector.py b/scripts/data_collector/yahoo/collector.py index 2cd080199..fcaa9ff92 100644 --- a/scripts/data_collector/yahoo/collector.py +++ b/scripts/data_collector/yahoo/collector.py @@ -23,7 +23,7 @@ from qlib.config import REG_CN as REGION_CN CUR_DIR = Path(__file__).resolve().parent sys.path.append(str(CUR_DIR.parent.parent)) -from data_collector.base import BaseCollector, BaseNormalize, BaseRun +from data_collector.base import BaseCollector, BaseNormalize, BaseRun, Normalize from data_collector.utils import ( get_calendar_list, get_hs_stock_symbols, @@ -297,6 +297,7 @@ class YahooNormalize(BaseNormalize): calendar_list: list = None, date_field_name: str = "date", symbol_field_name: str = "symbol", + last_close: float = None, ): if df.empty: return df @@ -318,7 +319,10 @@ class YahooNormalize(BaseNormalize): df.sort_index(inplace=True) df.loc[(df["volume"] <= 0) | np.isnan(df["volume"]), set(df.columns) - {symbol_field_name}] = np.nan _tmp_series = df["close"].fillna(method="ffill") - df["change"] = _tmp_series / _tmp_series.shift(1) - 1 + _tmp_shift_series = _tmp_series.shift(1) + if last_close is not None and isinstance(last_close, (int, float)): + _tmp_shift_series.iloc[0] = last_close + df["change"] = _tmp_series / _tmp_shift_series - 1 columns += ["change"] df.loc[(df["volume"] <= 0) | np.isnan(df["volume"]), columns] = np.nan @@ -367,6 +371,17 @@ class YahooNormalize1d(YahooNormalize, ABC): df = self._manual_adj_data(df) return df + def _get_first_close(self, df: pd.DataFrame) -> float: + """get first close value + + Notes + ----- + For incremental updates(append) to Yahoo 1D data, user need to use a close that is not 0 on the first trading day of the existing data + """ + df = df.loc[df["close"].first_valid_index() :] + _close = df["close"].iloc[0] + return _close + def _manual_adj_data(self, df: pd.DataFrame) -> pd.DataFrame: """manual adjust data: All fields (except change) are standardized according to the close of the first day""" if df.empty: @@ -374,8 +389,7 @@ class YahooNormalize1d(YahooNormalize, ABC): df = df.copy() df.sort_values(self._date_field_name, inplace=True) df = df.set_index(self._date_field_name) - df = df.loc[df["close"].first_valid_index() :] - _close = df["close"].iloc[0] + _close = self._get_first_close(df) for _col in df.columns: if _col == self._symbol_field_name: continue @@ -388,18 +402,97 @@ class YahooNormalize1d(YahooNormalize, ABC): return df.reset_index() +class YahooNormalize1dExtend(YahooNormalize1d): + def __init__( + self, old_qlib_data_dir: [str, Path], date_field_name: str = "date", symbol_field_name: str = "symbol", **kwargs + ): + """ + + Parameters + ---------- + old_qlib_data_dir: str, Path + the qlib data to be updated for yahoo, usually from: https://github.com/microsoft/qlib/tree/main/scripts#download-cn-data + date_field_name: str + date field name, default is date + symbol_field_name: str + symbol field name, default is symbol + """ + super(YahooNormalize1dExtend, self).__init__(date_field_name, symbol_field_name) + self._end_date, self._old_close = self._get_old_data(old_qlib_data_dir) + self._end_date = pd.Timestamp(self._end_date).strftime(self.DAILY_FORMAT) + + def _get_old_data(self, qlib_data_dir: [str, Path]): + import qlib + from qlib.data import D + + qlib_data_dir = str(Path(qlib_data_dir).expanduser().resolve()) + qlib.init(provider_uri=qlib_data_dir, expression_cache=None, dataset_cache=None) + df = D.features(D.instruments("all"), ["$close/$factor"]) + df.columns = ["close"] + return D.calendar()[-1], df + + def _get_first_close(self, df: pd.DataFrame) -> float: + _symbol = df.iloc[0][self._symbol_field_name] + try: + _df = self._old_close.loc(axis=0)[_symbol.upper()] + _close = _df.loc[_df.first_valid_index()]["close"] + except KeyError: + _close = super(YahooNormalize1dExtend, self)._get_first_close(df) + return _close + + def _get_last_close(self, df: pd.DataFrame) -> float: + _symbol = df.iloc[0][self._symbol_field_name] + try: + _df = self._old_close.loc(axis=0)[_symbol.upper()] + _close = _df.loc[_df.last_valid_index()]["close"] + except KeyError: + _close = None + return _close + + def _get_last_date(self, df: pd.DataFrame) -> pd.Timestamp: + _symbol = df.iloc[0][self._symbol_field_name] + try: + _df = self._old_close.loc(axis=0)[_symbol.upper()] + _date = _df.index.max() + except KeyError: + _date = None + return _date + + def normalize(self, df: pd.DataFrame) -> pd.DataFrame: + _last_close = self._get_last_close(df) + # reindex + _last_date = self._get_last_date(df) + if _last_date is not None: + df = df.set_index(self._date_field_name) + df.index = pd.to_datetime(df.index) + df = df[~df.index.duplicated(keep="first")] + _max_date = df.index.max() + df = df.reindex(self._calendar_list).loc[:_max_date].reset_index() + df = df[df[self._date_field_name] > _last_date] + _si = df["close"].first_valid_index() + if _si > df.index[0]: + logger.warning( + f"{df.iloc[0][self._symbol_field_name]} missing data: {df.loc[:_si][self._date_field_name]}" + ) + # normalize + df = self.normalize_yahoo( + df, self._calendar_list, self._date_field_name, self._symbol_field_name, last_close=_last_close + ) + # adjusted price + df = self.adjusted_price(df) + df = self._manual_adj_data(df) + return df + + class YahooNormalize1min(YahooNormalize, ABC): AM_RANGE = None # type: tuple # eg: ("09:30:00", "11:29:00") PM_RANGE = None # type: tuple # eg: ("13:00:00", "14:59:00") # Whether the trading day of 1min data is consistent with 1d CONSISTENT_1d = False + CALC_PAUSED_NUM = False - def __init__( - self, - date_field_name: str = "date", - symbol_field_name: str = "symbol", - ): + def __init__(self, date_field_name: str = "date", symbol_field_name: str = "symbol", **kwargs): """ Parameters @@ -478,6 +571,54 @@ class YahooNormalize1min(YahooNormalize, ABC): df[_col] = df[_col] / df["factor"] else: df[_col] = df[_col] * df["factor"] + + if self.CALC_PAUSED_NUM: + df = self.calc_paused_num(df) + return df + + def calc_paused_num(self, df: pd.DataFrame): + _symbol = df.iloc[0][self._symbol_field_name] + df = df.copy() + df["date"] = df[self._date_field_name].apply(lambda x: pd.Timestamp(x).date()) + # remove data that starts and ends with `np.nan` all day + all_data = [] + # Record the number of consecutive trading days where the whole day is nan, to remove the last trading day where the whole day is nan + all_nan_nums = 0 + # Record the number of consecutive occurrences of trading days that are not nan throughout the day + not_nan_nums = 0 + for _date, _df in df.groupby(level="date"): + _df["paused"] = 0 + if not _df.loc[_df["volume"] < 0].empty: + logger.warning(f"volume < 0, will fill np.nan: {_date} {_symbol}") + _df.loc[_df["volume"] < 0, "volume"] = np.nan + + check_fields = set(_df.columns) - { + "date", + "paused", + "factor", + self._date_field_name, + self._symbol_field_name, + } + if _df.loc[:, check_fields].isna().values.all() or (_df["volume"] == 0).all(): + all_nan_nums += 1 + not_nan_nums = 0 + _df["paused"] = 1 + if all_data: + _df["paused_num"] = not_nan_nums + all_data.append(_df) + else: + all_nan_nums = 0 + not_nan_nums += 1 + _df["paused_num"] = not_nan_nums + all_data.append(_df) + all_data = all_data[: len(all_data) - all_nan_nums] + if all_data: + df = pd.concat(all_data, sort=False) + else: + logger.warning(f"data is empty: {_symbol}") + df = pd.DataFrame() + return df + del df["date"] return df @abc.abstractmethod @@ -523,11 +664,16 @@ class YahooNormalizeCN1d(YahooNormalizeCN, YahooNormalize1d): pass +class YahooNormalizeCN1dExtend(YahooNormalizeCN, YahooNormalize1dExtend): + pass + + class YahooNormalizeCN1min(YahooNormalizeCN, YahooNormalize1min): AM_RANGE = ("09:30:00", "11:29:00") PM_RANGE = ("13:00:00", "14:59:00") CONSISTENT_1d = True + CALC_PAUSED_NUM = True def _get_calendar_list(self): return self.generate_1min_from_daily(self.calendar_list_1d) @@ -624,10 +770,54 @@ class Run(BaseRun): Examples --------- - $ python collector.py normalize_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region CN --interval 1d + $ python collector.py normalize_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region cn --interval 1d """ super(Run, self).normalize_data(date_field_name, symbol_field_name) + def normalize_data_1d_extend( + self, old_qlib_data_dir, date_field_name: str = "date", symbol_field_name: str = "symbol" + ): + """normalize data extend; extending yahoo qlib data(from: https://github.com/microsoft/qlib/tree/main/scripts#download-cn-data) + + Notes + ----- + Steps to extend yahoo qlib data: + + 1. download qlib data: https://github.com/microsoft/qlib/tree/main/scripts#download-cn-data; save to + + 2. collector source data: https://github.com/microsoft/qlib/tree/main/scripts/data_collector/yahoo#collector-data; save to + + 3. normalize new source data(from step 2): python scripts/data_collector/yahoo/collector.py normalize_data_1d_extend --old_qlib_dir --source_dir --normalize_dir --region CN --interval 1d + + 4. dump data: python scripts/dump_bin.py dump_update --csv_path --qlib_dir --freq day --date_field_name date --symbol_field_name symbol --exclude_fields symbol,date + + 5. update instrument(eg. csi300): python python scripts/data_collector/cn_index/collector.py --index_name CSI300 --qlib_dir --method parse_instruments + + Parameters + ---------- + old_qlib_data_dir: str + the qlib data to be updated for yahoo, usually from: https://github.com/microsoft/qlib/tree/main/scripts#download-cn-data + date_field_name: str + date field name, default date + symbol_field_name: str + symbol field name, default symbol + + Examples + --------- + $ python collector.py normalize_data_1d_extend --old_qlib_dir ~/.qlib/qlib_data/cn_1d --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region CN --interval 1d + """ + _class = getattr(self._cur_module, f"{self.normalize_class_name}Extend") + yc = Normalize( + source_dir=self.source_dir, + target_dir=self.normalize_dir, + normalize_class=_class, + max_workers=self.max_workers, + date_field_name=date_field_name, + symbol_field_name=symbol_field_name, + old_qlib_data_dir=old_qlib_data_dir, + ) + yc.normalize() + if __name__ == "__main__": fire.Fire(Run) From 554b9c78268c6439754e4a9f6fb4b6b35e48cdd4 Mon Sep 17 00:00:00 2001 From: zhupr Date: Sat, 5 Jun 2021 16:01:01 +0800 Subject: [PATCH 02/44] fix YahooCollector getting 1min data occasionally missing --- scripts/data_collector/base.py | 14 ++--- scripts/data_collector/yahoo/collector.py | 63 +++++++++++------------ 2 files changed, 36 insertions(+), 41 deletions(-) diff --git a/scripts/data_collector/base.py b/scripts/data_collector/base.py index cb51f9b22..d261f11cd 100644 --- a/scripts/data_collector/base.py +++ b/scripts/data_collector/base.py @@ -22,9 +22,9 @@ class BaseCollector(abc.ABC): NORMAL_FLAG = "NORMAL" DEFAULT_START_DATETIME_1D = pd.Timestamp("2000-01-01") - DEFAULT_START_DATETIME_1MIN = pd.Timestamp(datetime.datetime.now() - pd.Timedelta(days=5 * 6)) - DEFAULT_END_DATETIME_1D = pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1)) - DEFAULT_END_DATETIME_1MIN = pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1)) + DEFAULT_START_DATETIME_1MIN = pd.Timestamp(datetime.datetime.now() - pd.Timedelta(days=5 * 6 - 1)).date() + DEFAULT_END_DATETIME_1D = pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1)).date() + DEFAULT_END_DATETIME_1MIN = DEFAULT_END_DATETIME_1D INTERVAL_1min = "1min" INTERVAL_1d = "1d" @@ -35,7 +35,7 @@ class BaseCollector(abc.ABC): start=None, end=None, interval="1d", - max_workers=4, + max_workers=1, max_collector_count=2, delay=0, check_data_length: bool = False, @@ -48,7 +48,7 @@ class BaseCollector(abc.ABC): save_dir: str instrument save dir max_workers: int - workers, default 4 + workers, default 1; Concurrent number, default is 1; when collecting data, it is recommended that max_workers be set to 1 max_collector_count: int default 2 delay: float @@ -310,7 +310,7 @@ class Normalize: class BaseRun(abc.ABC): - def __init__(self, source_dir=None, normalize_dir=None, max_workers=4, interval="1d"): + def __init__(self, source_dir=None, normalize_dir=None, max_workers=1, interval="1d"): """ Parameters @@ -320,7 +320,7 @@ class BaseRun(abc.ABC): normalize_dir: str Directory for normalize data, default "Path(__file__).parent/normalize" max_workers: int - Concurrent number, default is 4 + Concurrent number, default is 1; Concurrent number, default is 1; when collecting data, it is recommended that max_workers be set to 1 interval: str freq, value from [1min, 1d], default 1d """ diff --git a/scripts/data_collector/yahoo/collector.py b/scripts/data_collector/yahoo/collector.py index fcaa9ff92..a48c5f16a 100644 --- a/scripts/data_collector/yahoo/collector.py +++ b/scripts/data_collector/yahoo/collector.py @@ -25,6 +25,7 @@ CUR_DIR = Path(__file__).resolve().parent sys.path.append(str(CUR_DIR.parent.parent)) from data_collector.base import BaseCollector, BaseNormalize, BaseRun, Normalize from data_collector.utils import ( + deco_retry, get_calendar_list, get_hs_stock_symbols, get_us_stock_symbols, @@ -92,10 +93,6 @@ class YahooCollector(BaseCollector): else: raise ValueError(f"interval error: {self.interval}") - # using for 1min - self._next_datetime = self.convert_datetime(self.start_datetime.date() + pd.Timedelta(days=1), self._timezone) - self._latest_datetime = self.convert_datetime(self.end_datetime.date(), self._timezone) - self.start_datetime = self.convert_datetime(self.start_datetime, self._timezone) self.end_datetime = self.convert_datetime(self.end_datetime, self._timezone) @@ -140,40 +137,36 @@ class YahooCollector(BaseCollector): def get_data( self, symbol: str, interval: str, start_datetime: pd.Timestamp, end_datetime: pd.Timestamp ) -> pd.DataFrame: + @deco_retry(retry_sleep=1) def _get_simple(start_, end_): self.sleep() _remote_interval = "1m" if interval == self.INTERVAL_1min else interval - return self.get_data_from_remote( + resp = self.get_data_from_remote( symbol, interval=_remote_interval, start=start_, end=end_, ) + if resp is None or resp.empty: + raise ValueError(f"get data error: {symbol}--{start_}--{end_}") + return resp _result = None if interval == self.INTERVAL_1d: _result = _get_simple(start_datetime, end_datetime) elif interval == self.INTERVAL_1min: - if self._next_datetime >= self._latest_datetime: - _result = _get_simple(start_datetime, end_datetime) - else: - _res = [] - - def _get_multi(start_, end_): - _resp = _get_simple(start_, end_) - if _resp is not None and not _resp.empty: - _res.append(_resp) - - for _s, _e in ( - (self.start_datetime, self._next_datetime), - (self._latest_datetime, self.end_datetime), - ): - _get_multi(_s, _e) - for _start in pd.date_range(self._next_datetime, self._latest_datetime, closed="left"): - _end = _start + pd.Timedelta(days=1) - _get_multi(_start, _end) - if _res: - _result = pd.concat(_res, sort=False).sort_values(["symbol", "date"]) + _res = [] + _start = self.start_datetime + while _start < self.end_datetime: + _tmp_end = min(_start + pd.Timedelta(days=7), self.end_datetime) + try: + _resp = _get_simple(_start, _tmp_end) + _res.append(_resp) + except ValueError as e: + pass + _start = _tmp_end + if _res: + _result = pd.concat(_res, sort=False).sort_values(["symbol", "date"]) else: raise ValueError(f"cannot support {self.interval}") return pd.DataFrame() if _result is None else _result @@ -520,6 +513,10 @@ class YahooNormalize1min(YahooNormalize, ABC): calendars, freq="1min", am_range=self.AM_RANGE, pm_range=self.PM_RANGE ) + def get_1d_data(self, symbol: str, start: str, end: str) -> pd.DataFrame: + data_1d = YahooCollector.get_data_from_remote(self.symbol_to_yahoo(symbol), interval="1d", start=start, end=end) + return data_1d + def adjusted_price(self, df: pd.DataFrame) -> pd.DataFrame: # TODO: using daily data factor if df.empty: @@ -529,9 +526,7 @@ class YahooNormalize1min(YahooNormalize, ABC): # get 1d data from yahoo _start = pd.Timestamp(df[self._date_field_name].min()).strftime(self.DAILY_FORMAT) _end = (pd.Timestamp(df[self._date_field_name].max()) + pd.Timedelta(days=1)).strftime(self.DAILY_FORMAT) - data_1d = YahooCollector.get_data_from_remote( - self.symbol_to_yahoo(symbol), interval="1d", start=_start, end=_end - ) + data_1d = self.get_1d_data(symbol, _start, _end) if data_1d is None or data_1d.empty: df["factor"] = 1 # TODO: np.nan or 1 or 0 @@ -579,21 +574,21 @@ class YahooNormalize1min(YahooNormalize, ABC): def calc_paused_num(self, df: pd.DataFrame): _symbol = df.iloc[0][self._symbol_field_name] df = df.copy() - df["date"] = df[self._date_field_name].apply(lambda x: pd.Timestamp(x).date()) + df["_tmp_date"] = df[self._date_field_name].apply(lambda x: pd.Timestamp(x).date()) # remove data that starts and ends with `np.nan` all day all_data = [] # Record the number of consecutive trading days where the whole day is nan, to remove the last trading day where the whole day is nan all_nan_nums = 0 # Record the number of consecutive occurrences of trading days that are not nan throughout the day not_nan_nums = 0 - for _date, _df in df.groupby(level="date"): + for _date, _df in df.groupby("_tmp_date"): _df["paused"] = 0 if not _df.loc[_df["volume"] < 0].empty: logger.warning(f"volume < 0, will fill np.nan: {_date} {_symbol}") _df.loc[_df["volume"] < 0, "volume"] = np.nan check_fields = set(_df.columns) - { - "date", + "_tmp_date", "paused", "factor", self._date_field_name, @@ -618,7 +613,7 @@ class YahooNormalize1min(YahooNormalize, ABC): logger.warning(f"data is empty: {_symbol}") df = pd.DataFrame() return df - del df["date"] + del df["_tmp_date"] return df @abc.abstractmethod @@ -690,7 +685,7 @@ class YahooNormalizeCN1min(YahooNormalizeCN, YahooNormalize1min): class Run(BaseRun): - def __init__(self, source_dir=None, normalize_dir=None, max_workers=4, interval="1d", region=REGION_CN): + def __init__(self, source_dir=None, normalize_dir=None, max_workers=1, interval="1d", region=REGION_CN): """ Parameters @@ -700,7 +695,7 @@ class Run(BaseRun): normalize_dir: str Directory for normalize data, default "Path(__file__).parent/normalize" max_workers: int - Concurrent number, default is 4 + Concurrent number, default is 1; when collecting data, it is recommended that max_workers be set to 1 interval: str freq, value from [1min, 1d], default 1d region: str From a845a2271b3352e410b99843628167c7dce095db Mon Sep 17 00:00:00 2001 From: zhupr Date: Tue, 8 Jun 2021 14:45:20 +0800 Subject: [PATCH 03/44] add normalize 1min to use local data && change the default parameters for collecting 1min --- scripts/data_collector/base.py | 52 ++--- .../contrib/fill_cn_1min_data/README.md | 23 ++ .../fill_cn_1min_data/fill_cn_1min_data.py | 98 +++++++++ .../fill_cn_1min_data/requirements.txt | 5 + .../README.md | 0 .../future_trading_date_collector.py | 2 +- .../requirements.txt | 0 scripts/data_collector/fund/collector.py | 22 +- scripts/data_collector/yahoo/collector.py | 196 +++++++++++++++--- 9 files changed, 328 insertions(+), 70 deletions(-) create mode 100644 scripts/data_collector/contrib/fill_cn_1min_data/README.md create mode 100644 scripts/data_collector/contrib/fill_cn_1min_data/fill_cn_1min_data.py create mode 100644 scripts/data_collector/contrib/fill_cn_1min_data/requirements.txt rename scripts/data_collector/contrib/{ => future_trading_date_collector}/README.md (100%) rename scripts/data_collector/contrib/{ => future_trading_date_collector}/future_trading_date_collector.py (98%) rename scripts/data_collector/contrib/{ => future_trading_date_collector}/requirements.txt (100%) diff --git a/scripts/data_collector/base.py b/scripts/data_collector/base.py index d261f11cd..08e1838a4 100644 --- a/scripts/data_collector/base.py +++ b/scripts/data_collector/base.py @@ -7,7 +7,7 @@ import time import datetime import importlib from pathlib import Path -from typing import Type +from typing import Type, Iterable from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor import pandas as pd @@ -38,7 +38,7 @@ class BaseCollector(abc.ABC): max_workers=1, max_collector_count=2, delay=0, - check_data_length: bool = False, + check_data_length: int = None, limit_nums: int = None, ): """ @@ -59,8 +59,8 @@ class BaseCollector(abc.ABC): start datetime, default None end: str end datetime, default None - check_data_length: bool - check data length, by default False + check_data_length: int + check data length, if not None and greater than 0, each symbol will be considered complete if its data length is greater than or equal to this value, otherwise it will be fetched again, the maximum number of fetches being (max_collector_count). By default None. limit_nums: int using for debug, by default None """ @@ -72,7 +72,7 @@ class BaseCollector(abc.ABC): self.max_collector_count = max_collector_count self.mini_symbol_map = {} self.interval = interval - self.check_small_data = check_data_length + self.check_data_length = max(int(check_data_length) if check_data_length is not None else 0, 0) self.start_datetime = self.normalize_start_datetime(start) self.end_datetime = self.normalize_end_datetime(end) @@ -99,14 +99,6 @@ class BaseCollector(abc.ABC): else getattr(self, f"DEFAULT_END_DATETIME_{self.interval.upper()}") ) - @property - @abc.abstractmethod - def min_numbers_trading(self): - # daily, one year: 252 / 4 - # us 1min, a week: 6.5 * 60 * 5 - # cn 1min, a week: 4 * 60 * 5 - raise NotImplementedError("rewrite min_numbers_trading") - @abc.abstractmethod def get_instrument_list(self): raise NotImplementedError("rewrite get_instrument_list") @@ -132,7 +124,7 @@ class BaseCollector(abc.ABC): Returns --------- - pd.DataFrame, "symbol" in pd.columns + pd.DataFrame, "symbol" and "date"in pd.columns """ raise NotImplementedError("rewrite get_timezone") @@ -151,7 +143,7 @@ class BaseCollector(abc.ABC): self.sleep() df = self.get_data(symbol, self.interval, self.start_datetime, self.end_datetime) _result = self.NORMAL_FLAG - if self.check_small_data: + if self.check_data_length > 0: _result = self.cache_small_data(symbol, df) if _result == self.NORMAL_FLAG: self.save_instrument(symbol, df) @@ -181,8 +173,8 @@ class BaseCollector(abc.ABC): df.to_csv(instrument_path, index=False) def cache_small_data(self, symbol, df): - if len(df) <= self.min_numbers_trading: - logger.warning(f"the number of trading days of {symbol} is less than {self.min_numbers_trading}!") + if len(df) < self.check_data_length: + logger.warning(f"the number of trading days of {symbol} is less than {self.check_data_length}!") _temp = self.mini_symbol_map.setdefault(symbol, []) _temp.append(df.copy()) return self.CACHE_FLAG @@ -194,9 +186,17 @@ class BaseCollector(abc.ABC): def _collector(self, instrument_list): error_symbol = [] - with ThreadPoolExecutor(max_workers=self.max_workers) as executor: - with tqdm(total=len(instrument_list)) as p_bar: - for _symbol, _result in zip(instrument_list, executor.map(self._simple_collector, instrument_list)): + with tqdm(total=len(instrument_list)) as p_bar: + if self.max_workers is not None and self.max_workers > 1: + logger.info(f"concurrent collector, max_workers: {self.max_workers}") + with ThreadPoolExecutor(max_workers=self.max_workers) as executor: + for _symbol, _result in zip(instrument_list, executor.map(self._simple_collector, instrument_list)): + if _result != self.NORMAL_FLAG: + error_symbol.append(_symbol) + p_bar.update() + else: + for _symbol in instrument_list: + _result = self._simple_collector(_symbol) if _result != self.NORMAL_FLAG: error_symbol.append(_symbol) p_bar.update() @@ -217,11 +217,11 @@ class BaseCollector(abc.ABC): instrument_list = self._collector(instrument_list) logger.info(f"{i+1} finish.") for _symbol, _df_list in self.mini_symbol_map.items(): - self.save_instrument( - _symbol, pd.concat(_df_list, sort=False).drop_duplicates(["date"]).sort_values(["date"]) - ) + _df = pd.concat(_df_list, sort=False) + if not _df.empty: + self.save_instrument(_symbol, _df.drop_duplicates(["date"]).sort_values(["date"])) if self.mini_symbol_map: - logger.warning(f"less than {self.min_numbers_trading} instrument list: {list(self.mini_symbol_map.keys())}") + logger.warning(f"less than {self.check_data_length} instrument list: {list(self.mini_symbol_map.keys())}") logger.info(f"total {len(self.instrument_list)}, error: {len(set(instrument_list))}") @@ -247,7 +247,7 @@ class BaseNormalize(abc.ABC): raise NotImplementedError("") @abc.abstractmethod - def _get_calendar_list(self): + def _get_calendar_list(self) -> Iterable[pd.Timestamp]: """Get benchmark calendar""" raise NotImplementedError("") @@ -296,7 +296,7 @@ class Normalize: file_path = Path(file_path) df = pd.read_csv(file_path) df = self._normalize_obj.normalize(df) - if not df.empty: + if df is not None and not df.empty: df.to_csv(self._target_dir.joinpath(file_path.name), index=False) def normalize(self): diff --git a/scripts/data_collector/contrib/fill_cn_1min_data/README.md b/scripts/data_collector/contrib/fill_cn_1min_data/README.md new file mode 100644 index 000000000..c9ff0629c --- /dev/null +++ b/scripts/data_collector/contrib/fill_cn_1min_data/README.md @@ -0,0 +1,23 @@ +# Use 1d data to fill in the missing symbols relative to 1min + + +## Requirements + +```bash +pip install -r requirements.txt +``` + +## fill 1min data + +```bash +python fill_1min_using_1d.py --data_1min_dir ~/.qlib/csv_data/cn_data_1min --qlib_data_1d_dir ~/.qlib/qlib_data/cn_data +``` + +## Parameters + +- ata_1min_dir: csv data +- qlib_data_1d_dir: qlib data directory +- max_workers: `ThreadPoolExecutor(max_workers=max_workers)`, by default *16* +- date_field_name: date field name, by default *date* +- symbol_field_name: symbol field name, by default *symbol* + diff --git a/scripts/data_collector/contrib/fill_cn_1min_data/fill_cn_1min_data.py b/scripts/data_collector/contrib/fill_cn_1min_data/fill_cn_1min_data.py new file mode 100644 index 000000000..4abca3361 --- /dev/null +++ b/scripts/data_collector/contrib/fill_cn_1min_data/fill_cn_1min_data.py @@ -0,0 +1,98 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import sys +from pathlib import Path +from concurrent.futures import ThreadPoolExecutor + +import fire +import qlib +import pandas as pd +from tqdm import tqdm +from qlib.data import D +from loguru import logger + +CUR_DIR = Path(__file__).resolve().parent +sys.path.append(str(CUR_DIR.parent.parent.parent)) +from data_collector.utils import generate_minutes_calendar_from_daily + + +def get_date_range(data_1min_dir: Path, max_workers: int = 16, date_field_name: str = "date"): + csv_files = list(data_1min_dir.glob("*.csv")) + min_date = None + max_date = None + with tqdm(total=len(csv_files)) as p_bar: + with ThreadPoolExecutor(max_workers=max_workers) as executor: + for _file, _result in zip(csv_files, executor.map(pd.read_csv, csv_files)): + if not _result.empty: + _dates = pd.to_datetime(_result[date_field_name]) + + _tmp_min = _dates.min() + min_date = min_date(min_date, _tmp_min) if min_date is not None else _tmp_min + + _tmp_max = _dates.max() + max_date = min_date(max_date, _tmp_max) if max_date is not None else _tmp_max + p_bar.update() + return min_date, max_date + + +def get_symbols(data_1min_dir: Path): + return list(map(lambda x: x.name[:-4].upper(), data_1min_dir.glob("*.csv"))) + + +def fill_1min_using_1d( + data_1min_dir: [str, Path], + qlib_data_1d_dir: [str, Path], + max_workers: int = 16, + date_field_name: str = "date", + symbol_field_name: str = "symbol", +): + """Use 1d data to fill in the missing symbols relative to 1min + + Parameters + ---------- + data_1min_dir: str + 1min data dir + qlib_data_1d_dir: str + 1d qlib data(bin data) dir, from: https://qlib.readthedocs.io/en/latest/component/data.html#converting-csv-format-into-qlib-format + max_workers: int + ThreadPoolExecutor(max_workers), by default 16 + date_field_name: str + date field name, by default date + symbol_field_name: str + symbol field name, by default symbol + + """ + data_1min_dir = Path(data_1min_dir).expanduser().resolve() + qlib_data_1d_dir = Path(qlib_data_1d_dir).expanduser().resolve() + + min_date, max_date = get_date_range(data_1min_dir, max_workers, date_field_name) + symbols_1min = get_symbols(data_1min_dir) + + qlib.init(provider_uri=str(qlib_data_1d_dir)) + data_1d = D.features(D.instruments("all"), ["$close"], min_date, max_date, freq="day") + + miss_symbols = set(data_1d.index.get_level_values(level="instrument").unique()) - set(symbols_1min) + if not miss_symbols: + logger.warning("More symbols in 1min than 1d, no padding required") + return + + logger.info(f"miss_symbols {len(miss_symbols)}: {miss_symbols}") + tmp_df = pd.read_csv(list(data_1min_dir.glob("*.csv"))[0]) + columns = tmp_df.columns + _si = tmp_df[symbol_field_name].first_valid_index() + is_lower = tmp_df.loc[tmp_df][symbol_field_name].islower() + for symbol in tqdm(miss_symbols): + if is_lower: + symbol = symbol.lower() + index_1d = data_1d.loc(axis=0)[symbol.upper()].index + index_1min = generate_minutes_calendar_from_daily(index_1d) + index_1min.name = date_field_name + _df = pd.DataFrame(columns=columns, index=index_1min) + _df.reset_index(inplace=True) + _df[symbol_field_name] = symbol + _df.to_csv(data_1min_dir.joinpath(f"{symbol}.csv"), index=False) + + +if __name__ == "__main__": + fire.Fire(fill_1min_using_1d) diff --git a/scripts/data_collector/contrib/fill_cn_1min_data/requirements.txt b/scripts/data_collector/contrib/fill_cn_1min_data/requirements.txt new file mode 100644 index 000000000..057683685 --- /dev/null +++ b/scripts/data_collector/contrib/fill_cn_1min_data/requirements.txt @@ -0,0 +1,5 @@ +fire +pandas +loguru +tqdm +pyqlib \ No newline at end of file diff --git a/scripts/data_collector/contrib/README.md b/scripts/data_collector/contrib/future_trading_date_collector/README.md similarity index 100% rename from scripts/data_collector/contrib/README.md rename to scripts/data_collector/contrib/future_trading_date_collector/README.md diff --git a/scripts/data_collector/contrib/future_trading_date_collector.py b/scripts/data_collector/contrib/future_trading_date_collector/future_trading_date_collector.py similarity index 98% rename from scripts/data_collector/contrib/future_trading_date_collector.py rename to scripts/data_collector/contrib/future_trading_date_collector/future_trading_date_collector.py index 4da62d465..8df0a4972 100644 --- a/scripts/data_collector/contrib/future_trading_date_collector.py +++ b/scripts/data_collector/contrib/future_trading_date_collector/future_trading_date_collector.py @@ -14,7 +14,7 @@ from loguru import logger import baostock as bs CUR_DIR = Path(__file__).resolve().parent -sys.path.append(str(CUR_DIR.parent.parent)) +sys.path.append(str(CUR_DIR.parent.parent.parent)) from data_collector.utils import generate_minutes_calendar_from_daily diff --git a/scripts/data_collector/contrib/requirements.txt b/scripts/data_collector/contrib/future_trading_date_collector/requirements.txt similarity index 100% rename from scripts/data_collector/contrib/requirements.txt rename to scripts/data_collector/contrib/future_trading_date_collector/requirements.txt diff --git a/scripts/data_collector/fund/collector.py b/scripts/data_collector/fund/collector.py index 10800a7a3..fc06a27e4 100644 --- a/scripts/data_collector/fund/collector.py +++ b/scripts/data_collector/fund/collector.py @@ -3,18 +3,13 @@ import abc import sys -import copy -import time import datetime -import importlib import json from abc import ABC from pathlib import Path -from typing import Iterable, Type import fire import requests -import numpy as np import pandas as pd from loguru import logger from dateutil.tz import tzlocal @@ -38,7 +33,7 @@ class FundCollector(BaseCollector): max_workers=4, max_collector_count=2, delay=0, - check_data_length: bool = False, + check_data_length: int = None, limit_nums: int = None, ): """ @@ -59,8 +54,8 @@ class FundCollector(BaseCollector): start datetime, default None end: str end datetime, default None - check_data_length: bool - check data length, by default False + check_data_length: int + check data length, if not None and greater than 0, each symbol will be considered complete if its data length is greater than or equal to this value, otherwise it will be fetched again, the maximum number of fetches being (max_collector_count). By default None. limit_nums: int using for debug, by default None """ @@ -168,10 +163,7 @@ class FundollectorCN(FundCollector, ABC): class FundCollectorCN1d(FundollectorCN): - @property - def min_numbers_trading(self): - return 252 / 4 - + pass class FundNormalize(BaseNormalize): DAILY_FORMAT = "%Y-%m-%d" @@ -261,7 +253,7 @@ class Run(BaseRun): start=None, end=None, interval="1d", - check_data_length=False, + check_data_length=None, limit_nums=None, ): """download data from Internet @@ -278,8 +270,8 @@ class Run(BaseRun): start datetime, default "2000-01-01" end: str end datetime, default ``pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))`` - check_data_length: bool # if this param useful? - check data length, by default False + check_data_length: int # if this param useful? + check data length, if not None and greater than 0, each symbol will be considered complete if its data length is greater than or equal to this value, otherwise it will be fetched again, the maximum number of fetches being (max_collector_count). By default None. limit_nums: int using for debug, by default None diff --git a/scripts/data_collector/yahoo/collector.py b/scripts/data_collector/yahoo/collector.py index a48c5f16a..16b0a32ba 100644 --- a/scripts/data_collector/yahoo/collector.py +++ b/scripts/data_collector/yahoo/collector.py @@ -137,7 +137,7 @@ class YahooCollector(BaseCollector): def get_data( self, symbol: str, interval: str, start_datetime: pd.Timestamp, end_datetime: pd.Timestamp ) -> pd.DataFrame: - @deco_retry(retry_sleep=1) + @deco_retry(retry_sleep=self.delay) def _get_simple(start_, end_): self.sleep() _remote_interval = "1m" if interval == self.INTERVAL_1min else interval @@ -200,10 +200,6 @@ class YahooCollectorCN(YahooCollector, ABC): class YahooCollectorCN1d(YahooCollectorCN): - @property - def min_numbers_trading(self): - return 252 / 4 - def download_index_data(self): # TODO: from MSN _format = "%Y%m%d" @@ -237,10 +233,6 @@ class YahooCollectorCN1d(YahooCollectorCN): class YahooCollectorCN1min(YahooCollectorCN): - @property - def min_numbers_trading(self): - return 60 * 4 * 5 - def download_index_data(self): # TODO: 1m logger.warning(f"{self.__class__.__name__} {self.interval} does not support: download_index_data") @@ -269,15 +261,11 @@ class YahooCollectorUS(YahooCollector, ABC): class YahooCollectorUS1d(YahooCollectorUS): - @property - def min_numbers_trading(self): - return 252 / 4 + pass class YahooCollectorUS1min(YahooCollectorUS): - @property - def min_numbers_trading(self): - return 60 * 6.5 * 5 + pass class YahooNormalize(BaseNormalize): @@ -514,7 +502,17 @@ class YahooNormalize1min(YahooNormalize, ABC): ) def get_1d_data(self, symbol: str, start: str, end: str) -> pd.DataFrame: + """get 1d data + + Returns + ------ + data_1d: pd.DataFrame + set(data_1d.columns) - set([self._date_field_name, self._symbol_field_name, "paused", "volume", "factor"]) == {} + + """ data_1d = YahooCollector.get_data_from_remote(self.symbol_to_yahoo(symbol), interval="1d", start=start, end=end) + if not (data_1d is None or data_1d.empty): + data_1d = self.data_1d_obj.normalize(data_1d) return data_1d def adjusted_price(self, df: pd.DataFrame) -> pd.DataFrame: @@ -526,13 +524,12 @@ class YahooNormalize1min(YahooNormalize, ABC): # get 1d data from yahoo _start = pd.Timestamp(df[self._date_field_name].min()).strftime(self.DAILY_FORMAT) _end = (pd.Timestamp(df[self._date_field_name].max()) + pd.Timedelta(days=1)).strftime(self.DAILY_FORMAT) - data_1d = self.get_1d_data(symbol, _start, _end) + data_1d: pd.DataFrame = self.get_1d_data(symbol, _start, _end) if data_1d is None or data_1d.empty: df["factor"] = 1 # TODO: np.nan or 1 or 0 df["paused"] = np.nan else: - data_1d = self.data_1d_obj.normalize(data_1d) # type: pd.DataFrame # NOTE: volume is np.nan or volume <= 0, paused = 1 # FIXME: find a more accurate data source data_1d["paused"] = 0 @@ -621,12 +618,12 @@ class YahooNormalize1min(YahooNormalize, ABC): raise NotImplementedError("rewrite symbol_to_yahoo") @abc.abstractmethod - def _get_1d_calendar_list(self): + def _get_1d_calendar_list(self) -> Iterable[pd.Timestamp]: raise NotImplementedError("rewrite _get_1d_calendar_list") class YahooNormalizeUS: - def _get_calendar_list(self): + def _get_calendar_list(self) -> Iterable[pd.Timestamp]: # TODO: from MSN return get_calendar_list("US_ALL") @@ -638,7 +635,7 @@ class YahooNormalizeUS1d(YahooNormalizeUS, YahooNormalize1d): class YahooNormalizeUS1min(YahooNormalizeUS, YahooNormalize1min): CONSISTENT_1d = False - def _get_calendar_list(self): + def _get_calendar_list(self) -> Iterable[pd.Timestamp]: # TODO: support 1min raise ValueError("Does not support 1min") @@ -650,7 +647,7 @@ class YahooNormalizeUS1min(YahooNormalizeUS, YahooNormalize1min): class YahooNormalizeCN: - def _get_calendar_list(self): + def _get_calendar_list(self) -> Iterable[pd.Timestamp]: # TODO: from MSN return get_calendar_list("ALL") @@ -670,7 +667,7 @@ class YahooNormalizeCN1min(YahooNormalizeCN, YahooNormalize1min): CONSISTENT_1d = True CALC_PAUSED_NUM = True - def _get_calendar_list(self): + def _get_calendar_list(self) -> Iterable[pd.Timestamp]: return self.generate_1min_from_daily(self.calendar_list_1d) def symbol_to_yahoo(self, symbol): @@ -680,10 +677,67 @@ class YahooNormalizeCN1min(YahooNormalizeCN, YahooNormalize1min): symbol = symbol[2:] + "." + _exchange return symbol - def _get_1d_calendar_list(self): + def _get_1d_calendar_list(self) -> Iterable[pd.Timestamp]: return get_calendar_list("ALL") +class YahooNormalizeCN1minOffline(YahooNormalizeCN1min): + """Normalised to 1min using local 1d data + 1d data usually from: Normalised to 1min using local 1d data + """ + + def __init__( + self, qlib_data_1d_dir: [str, Path], date_field_name: str = "date", symbol_field_name: str = "symbol", **kwargs + ): + """ + + Parameters + ---------- + qlib_data_1d_dir: str, Path + the qlib data to be updated for yahoo, usually from: Normalised to 1min using local 1d data + date_field_name: str + date field name, default is date + symbol_field_name: str + symbol field name, default is symbol + """ + super(YahooNormalizeCN1minOffline, self).__init__(date_field_name, symbol_field_name) + self.qlib_data_1d_dir = qlib_data_1d_dir + self._all_1d_data = self._get_all_1d_data() + + def _get_1d_calendar_list(self) -> Iterable[pd.Timestamp]: + import qlib + from qlib.data import D + + qlib.init(provider_uri=self.qlib_data_1d_dir) + return list(D.calendar(freq="day")) + + def _get_all_1d_data(self): + import qlib + from qlib.data import D + + qlib.init(provider_uri=self.qlib_data_1d_dir) + df = D.features(D.instruments("all"), ["$paused", "$volume", "$factor"], freq="day") + df.reset_index(inplace=True) + df.rename(columns={"datetime": self._date_field_name, "instrument": self._symbol_field_name}, inplace=True) + df.columns = list(map(lambda x: x[1:] if x.startswith("$") else x, df.columns)) + return df + + def get_1d_data(self, symbol: str, start: str, end: str) -> pd.DataFrame: + """get 1d data + + Returns + ------ + data_1d: pd.DataFrame + set(data_1d.columns) - set([self._date_field_name, self._symbol_field_name, "paused", "volume", "factor"]) == {} + + """ + return self._all_1d_data[ + (self._all_1d_data[self._symbol_field_name] == symbol.upper()) + & (self._all_1d_data[self._date_field_name] >= pd.Timestamp(start)) + & (self._all_1d_data[self._date_field_name] < pd.Timestamp(end)) + ] + + class Run(BaseRun): def __init__(self, source_dir=None, normalize_dir=None, max_workers=1, interval="1d", region=REGION_CN): """ @@ -722,7 +776,7 @@ class Run(BaseRun): delay=0, start=None, end=None, - check_data_length=False, + check_data_length=None, limit_nums=None, ): """download data from Internet @@ -734,14 +788,21 @@ class Run(BaseRun): delay: float time.sleep(delay), default 0 start: str - start datetime, default "2000-01-01" + start datetime, default "2000-01-01"; closed interval(including start) end: str - end datetime, default ``pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))`` - check_data_length: bool - check data length, by default False + end datetime, default ``pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))``; open interval(excluding end) + check_data_length: int + check data length, if not None and greater than 0, each symbol will be considered complete if its data length is greater than or equal to this value, otherwise it will be fetched again, the maximum number of fetches being (max_collector_count). By default None. limit_nums: int using for debug, by default None + Notes + ----- + check_data_length, example: + daily, one year: 252 // 4 + us 1min, a week: 6.5 * 60 * 5 + cn 1min, a week: 4 * 60 * 5 + Examples --------- # get daily data @@ -813,6 +874,85 @@ class Run(BaseRun): ) yc.normalize() + def normalize_data_1min_cn_offline( + self, qlib_data_1d_dir, date_field_name: str = "date", symbol_field_name: str = "symbol" + ): + """Normalised to 1min using local 1d data + + Parameters + ---------- + qlib_data_1d_dir: str + the qlib data to be updated for yahoo, usually from: https://github.com/microsoft/qlib/tree/main/scripts#download-cn-data + date_field_name: str + date field name, default date + symbol_field_name: str + symbol field name, default symbol + + Examples + --------- + $ python collector.py normalize_data_1min_cn_offline --qlib_data_1d_dir ~/.qlib/qlib_data/cn_1d --source_dir ~/.qlib/stock_data/source_cn_1min --normalize_dir ~/.qlib/stock_data/normalize_cn_1min --region CN --interval 1min + """ + _class = getattr(self._cur_module, f"{self.normalize_class_name}Offline") + yc = Normalize( + source_dir=self.source_dir, + target_dir=self.normalize_dir, + normalize_class=_class, + max_workers=self.max_workers, + date_field_name=date_field_name, + symbol_field_name=symbol_field_name, + qlib_data_1d_dir=qlib_data_1d_dir, + ) + yc.normalize() + + def download_today_data( + self, + max_collector_count=2, + delay=0, + check_data_length=None, + limit_nums=None, + ): + """download today data from Internet + + Parameters + ---------- + max_collector_count: int + default 2 + delay: float + time.sleep(delay), default 0 + check_data_length: int + check data length, if not None and greater than 0, each symbol will be considered complete if its data length is greater than or equal to this value, otherwise it will be fetched again, the maximum number of fetches being (max_collector_count). By default None. + limit_nums: int + using for debug, by default None + + Notes + ----- + Download today's data: + start_time = datetime.datetime.now().date(); closed interval(including start) + end_time = pd.Timestamp(start_time + pd.Timedelta(days=1)).date(); open interval(excluding end) + + check_data_length, example: + daily, one year: 252 // 4 + us 1min, a week: 6.5 * 60 * 5 + cn 1min, a week: 4 * 60 * 5 + + Examples + --------- + # get daily data + $ python collector.py download_today_data --source_dir ~/.qlib/stock_data/source --region CN --delay 0.1 --interval 1d + # get 1m data + $ python collector.py download_today_data --source_dir ~/.qlib/stock_data/source --region CN --delay 0.1 --interval 1m + """ + start = datetime.datetime.now().date() + end = pd.Timestamp(start + pd.Timedelta(days=1)).date() + self.download_data( + max_collector_count, + delay, + start.strftime("%Y-%m-%d"), + end.strftime("%Y-%m-%d"), + check_data_length, + limit_nums, + ) + if __name__ == "__main__": fire.Fire(Run) From 03eb0882de52b9c8476fd8456185e99b41e11dc5 Mon Sep 17 00:00:00 2001 From: zhupr Date: Tue, 8 Jun 2021 22:23:05 +0800 Subject: [PATCH 04/44] fix YahooNormalizeCN1minOffline bugs --- .../fill_cn_1min_data/fill_cn_1min_data.py | 10 ++++---- scripts/data_collector/yahoo/collector.py | 23 +++++-------------- 2 files changed, 12 insertions(+), 21 deletions(-) diff --git a/scripts/data_collector/contrib/fill_cn_1min_data/fill_cn_1min_data.py b/scripts/data_collector/contrib/fill_cn_1min_data/fill_cn_1min_data.py index 4abca3361..0a721298d 100644 --- a/scripts/data_collector/contrib/fill_cn_1min_data/fill_cn_1min_data.py +++ b/scripts/data_collector/contrib/fill_cn_1min_data/fill_cn_1min_data.py @@ -28,10 +28,9 @@ def get_date_range(data_1min_dir: Path, max_workers: int = 16, date_field_name: _dates = pd.to_datetime(_result[date_field_name]) _tmp_min = _dates.min() - min_date = min_date(min_date, _tmp_min) if min_date is not None else _tmp_min - + min_date = min(min_date, _tmp_min) if min_date is not None else _tmp_min _tmp_max = _dates.max() - max_date = min_date(max_date, _tmp_max) if max_date is not None else _tmp_max + max_date = max(max_date, _tmp_max) if max_date is not None else _tmp_max p_bar.update() return min_date, max_date @@ -81,7 +80,7 @@ def fill_1min_using_1d( tmp_df = pd.read_csv(list(data_1min_dir.glob("*.csv"))[0]) columns = tmp_df.columns _si = tmp_df[symbol_field_name].first_valid_index() - is_lower = tmp_df.loc[tmp_df][symbol_field_name].islower() + is_lower = tmp_df.loc[_si][symbol_field_name].islower() for symbol in tqdm(miss_symbols): if is_lower: symbol = symbol.lower() @@ -89,8 +88,11 @@ def fill_1min_using_1d( index_1min = generate_minutes_calendar_from_daily(index_1d) index_1min.name = date_field_name _df = pd.DataFrame(columns=columns, index=index_1min) + if date_field_name in _df.columns: + del _df[date_field_name] _df.reset_index(inplace=True) _df[symbol_field_name] = symbol + _df["paused_num"] = 0 _df.to_csv(data_1min_dir.joinpath(f"{symbol}.csv"), index=False) diff --git a/scripts/data_collector/yahoo/collector.py b/scripts/data_collector/yahoo/collector.py index 16b0a32ba..58e1d3009 100644 --- a/scripts/data_collector/yahoo/collector.py +++ b/scripts/data_collector/yahoo/collector.py @@ -473,21 +473,6 @@ class YahooNormalize1min(YahooNormalize, ABC): CONSISTENT_1d = False CALC_PAUSED_NUM = False - def __init__(self, date_field_name: str = "date", symbol_field_name: str = "symbol", **kwargs): - """ - - Parameters - ---------- - date_field_name: str - date field name, default is date - symbol_field_name: str - symbol field name, default is symbol - """ - super(YahooNormalize1min, self).__init__(date_field_name, symbol_field_name) - _class_name = self.__class__.__name__.replace("min", "d") - _class = getattr(importlib.import_module("collector"), _class_name) # type: Type[YahooNormalize] - self.data_1d_obj = _class(self._date_field_name, self._symbol_field_name) - @property def calendar_list_1d(self): calendar_list_1d = getattr(self, "_calendar_list_1d", None) @@ -512,7 +497,10 @@ class YahooNormalize1min(YahooNormalize, ABC): """ data_1d = YahooCollector.get_data_from_remote(self.symbol_to_yahoo(symbol), interval="1d", start=start, end=end) if not (data_1d is None or data_1d.empty): - data_1d = self.data_1d_obj.normalize(data_1d) + _class_name = self.__class__.__name__.replace("min", "d") + _class: type(YahooNormalize) = getattr(importlib.import_module("collector"), _class_name) + data_1d_obj = _class(self._date_field_name, self._symbol_field_name) + data_1d = data_1d_obj.normalize(data_1d) return data_1d def adjusted_price(self, df: pd.DataFrame) -> pd.DataFrame: @@ -525,6 +513,7 @@ class YahooNormalize1min(YahooNormalize, ABC): _start = pd.Timestamp(df[self._date_field_name].min()).strftime(self.DAILY_FORMAT) _end = (pd.Timestamp(df[self._date_field_name].max()) + pd.Timedelta(days=1)).strftime(self.DAILY_FORMAT) data_1d: pd.DataFrame = self.get_1d_data(symbol, _start, _end) + data_1d = data_1d.copy() if data_1d is None or data_1d.empty: df["factor"] = 1 # TODO: np.nan or 1 or 0 @@ -700,8 +689,8 @@ class YahooNormalizeCN1minOffline(YahooNormalizeCN1min): symbol_field_name: str symbol field name, default is symbol """ - super(YahooNormalizeCN1minOffline, self).__init__(date_field_name, symbol_field_name) self.qlib_data_1d_dir = qlib_data_1d_dir + super(YahooNormalizeCN1minOffline, self).__init__(date_field_name, symbol_field_name) self._all_1d_data = self._get_all_1d_data() def _get_1d_calendar_list(self) -> Iterable[pd.Timestamp]: From 9a44fbf9c1e797fe99c9ef283586544347eede0d Mon Sep 17 00:00:00 2001 From: zhupr Date: Tue, 8 Jun 2021 22:52:31 +0800 Subject: [PATCH 05/44] fix PEP8: qlib/scripts/data_collector/fund/collector.py --- scripts/data_collector/fund/collector.py | 1 + 1 file changed, 1 insertion(+) diff --git a/scripts/data_collector/fund/collector.py b/scripts/data_collector/fund/collector.py index fc06a27e4..7b5566f72 100644 --- a/scripts/data_collector/fund/collector.py +++ b/scripts/data_collector/fund/collector.py @@ -165,6 +165,7 @@ class FundollectorCN(FundCollector, ABC): class FundCollectorCN1d(FundollectorCN): pass + class FundNormalize(BaseNormalize): DAILY_FORMAT = "%Y-%m-%d" From d4b36bdab448859f04b163b89843594660b3323f Mon Sep 17 00:00:00 2001 From: Young Date: Fri, 11 Jun 2021 01:58:04 +0000 Subject: [PATCH 06/44] Online fix - Skip duplicated qlib.auto_init() - Fix TSDatasetH flt_col bug! - Resolve qlib log attribute confliction - Trainer API enhancement - More docs and user-friendly warning --- qlib/__init__.py | 13 ++++++-- qlib/data/dataset/__init__.py | 30 +++++++++++++---- qlib/log.py | 8 +++-- qlib/model/trainer.py | 56 +++++++++++++++++++++----------- qlib/utils/__init__.py | 22 +++++++++++++ qlib/workflow/exp.py | 11 +++++-- qlib/workflow/expm.py | 9 +++++ qlib/workflow/online/manager.py | 15 ++++++--- qlib/workflow/online/strategy.py | 6 ++++ qlib/workflow/task/collect.py | 10 +++++- qlib/workflow/task/gen.py | 4 ++- qlib/workflow/task/manage.py | 10 ++++-- 12 files changed, 150 insertions(+), 44 deletions(-) diff --git a/qlib/__init__.py b/qlib/__init__.py index 4fd48f8c1..5f45f4557 100644 --- a/qlib/__init__.py +++ b/qlib/__init__.py @@ -20,11 +20,17 @@ def init(default_conf="client", **kwargs): from .config import C from .data.cache import H - H.clear() - # FIXME: this logger ignored the level in config logger = get_module_logger("Initialization", level=logging.INFO) + skip_if_reg = kwargs.pop("skip_if_reg", False) + if skip_if_reg and C.registered: + # if we reinitialize Qlib during running an experiment `R.start`. + # it will result in loss of the recorder + logger.warning("Skip initialization because `skip_if_reg is True`") + return + + H.clear() C.set(default_conf, **kwargs) # check path if server/local @@ -197,14 +203,15 @@ def auto_init(**kwargs): - Find the project configuration and init qlib - The parsing process will be affected by the `conf_type` of the configuration file - Init qlib with default config + - Skip initialization if already initialized """ + kwargs["skip_if_reg"] = kwargs.get("skip_if_reg", True) try: pp = get_project_path(cur_path=kwargs.pop("cur_path", None)) except FileNotFoundError: init(**kwargs) else: - conf_pp = pp / "config.yaml" with conf_pp.open() as f: conf = yaml.safe_load(f) diff --git a/qlib/data/dataset/__init__.py b/qlib/data/dataset/__init__.py index 8d7786368..fe641be35 100644 --- a/qlib/data/dataset/__init__.py +++ b/qlib/data/dataset/__init__.py @@ -1,6 +1,6 @@ from ...utils.serial import Serializable from typing import Union, List, Tuple, Dict, Text, Optional -from ...utils import init_instance_by_config, np_ffill +from ...utils import init_instance_by_config, np_ffill, time_to_slc_point from ...log import get_module_logger from .handler import DataHandler, DataHandlerLP from copy import deepcopy @@ -243,6 +243,8 @@ class TSDataSampler: It works like `torch.data.utils.Dataset`, it provides a very convenient interface for constructing time-series dataset based on tabular data. + - On time step dimension, the smaller index indicates the historical data and the larger index indicates the future + data. If user have further requirements for processing data, user could process them based on `TSDataSampler` or create more powerful subclasses. @@ -309,11 +311,19 @@ class TSDataSampler: self.data_index = deepcopy(self.data.index) if flt_data is not None: - self.flt_data = np.array(flt_data.reindex(self.data_index)).reshape(-1) + if isinstance(flt_data, pd.DataFrame): + assert len(flt_data.columns) == 1 + flt_data = flt_data.iloc[:, 0] + # NOTE: bool(np.nan) is True !!!!!!!! + # make sure reindex comes first. Otherwise extra NaN may appear. + flt_data = flt_data.reindex(self.data_index).fillna(False).astype(np.bool) + self.flt_data = flt_data.values self.idx_map = self.flt_idx_map(self.flt_data, self.idx_map) self.data_index = self.data_index[np.where(self.flt_data == True)[0]] - self.start_idx, self.end_idx = self.data_index.slice_locs(start=pd.Timestamp(start), end=pd.Timestamp(end)) + self.start_idx, self.end_idx = self.data_index.slice_locs( + start=time_to_slc_point(start), end=time_to_slc_point(end) + ) self.idx_arr = np.array(self.idx_df.values, dtype=np.float64) # for better performance del self.data # save memory @@ -341,7 +351,7 @@ class TSDataSampler: setattr(self, k, v) @staticmethod - def build_index(data: pd.DataFrame) -> dict: + def build_index(data: pd.DataFrame) -> Tuple[pd.DataFrame, dict]: """ The relation of the data @@ -352,9 +362,15 @@ class TSDataSampler: Returns ------- - dict: - {: } - # get the previous index of a line given index + Tuple[pd.DataFrame, dict]: + 1) the first element: reshape the original index into a 2D dataframe + instrument SH600000 SH600004 SH600006 SH600007 SH600008 SH600009 ... + datetime + 2021-01-11 0 1 2 3 4 5 ... + 2021-01-12 4146 4147 4148 4149 4150 4151 ... + 2021-01-13 8293 8294 8295 8296 8297 8298 ... + 2021-01-14 12441 12442 12443 12444 12445 12446 ... + 2) the second element: {: } """ # object incase of pandas converting int to flaot idx_df = pd.Series(range(data.shape[0]), index=data.index, dtype=object) diff --git a/qlib/log.py b/qlib/log.py index 379544392..ad55e2200 100644 --- a/qlib/log.py +++ b/qlib/log.py @@ -28,16 +28,18 @@ class QlibLogger(metaclass=MetaLogger): def __init__(self, module_name): self.module_name = module_name - self.level = 0 + # this feature name conflicts with the attribute with Logger + # rename it to avoid some corner cases that result in comparing `str` and `int` + self.__level = 0 @property def logger(self): logger = logging.getLogger(self.module_name) - logger.setLevel(self.level) + logger.setLevel(self.__level) return logger def setLevel(self, level): - self.level = level + self.__level = level def __getattr__(self, name): # During unpickling, python will call __getattr__. Use this line to avoid maximum recursion error. diff --git a/qlib/model/trainer.py b/qlib/model/trainer.py index 28d854477..a534a7a3b 100644 --- a/qlib/model/trainer.py +++ b/qlib/model/trainer.py @@ -8,7 +8,7 @@ There are two steps in each Trainer including ``train``(make model recorder) and This is a concept called ``DelayTrainer``, which can be used in online simulating for parallel training. In ``DelayTrainer``, the first step is only to save some necessary info to model recorders, and the second step which will be finished in the end can do some concurrent and time-consuming operations such as model fitting. -``Qlib`` offer two kinds of Trainer, ``TrainerR`` is the simplest way and ``TrainerRM`` is based on TaskManager to help manager tasks lifecycle automatically. +``Qlib`` offer two kinds of Trainer, ``TrainerR`` is the simplest way and ``TrainerRM`` is based on TaskManager to help manager tasks lifecycle automatically. """ import socket @@ -153,6 +153,9 @@ class Trainer: """ return self.delay + def __call__(self, *args, **kwargs) -> list: + return self.end_train(self.train(*args, **kwargs)) + class TrainerR(Trainer): """ @@ -286,7 +289,9 @@ class TrainerRM(Trainer): # This tag is the _id in TaskManager to distinguish tasks. TM_ID = "_id in TaskManager" - def __init__(self, experiment_name: str = None, task_pool: str = None, train_func=task_train): + def __init__( + self, experiment_name: str = None, task_pool: str = None, train_func=task_train, skip_run_task: bool = False + ): """ Init TrainerR. @@ -294,11 +299,16 @@ class TrainerRM(Trainer): experiment_name (str): the default name of experiment. task_pool (str): task pool name in TaskManager. None for use same name as experiment_name. train_func (Callable, optional): default training method. Defaults to `task_train`. + skip_run_task (bool): + If skip_run_task == True: + Only run_task in the worker. Otherwise skip run_task. """ + super().__init__() self.experiment_name = experiment_name self.task_pool = task_pool self.train_func = train_func + self.skip_run_task = skip_run_task def train( self, @@ -340,15 +350,16 @@ class TrainerRM(Trainer): tm = TaskManager(task_pool=task_pool) _id_list = tm.create_task(tasks) # all tasks will be saved to MongoDB query = {"_id": {"$in": _id_list}} - run_task( - train_func, - task_pool, - query=query, # only train these tasks - experiment_name=experiment_name, - before_status=before_status, - after_status=after_status, - **kwargs, - ) + if not self.skip_run_task: + run_task( + train_func, + task_pool, + query=query, # only train these tasks + experiment_name=experiment_name, + before_status=before_status, + after_status=after_status, + **kwargs, + ) if not self.is_delay(): tm.wait(query=query) @@ -411,6 +422,7 @@ class DelayTrainerRM(TrainerRM): task_pool: str = None, train_func=begin_task_train, end_train_func=end_task_train, + skip_run_task: bool = False, ): """ Init DelayTrainerRM. @@ -420,10 +432,15 @@ class DelayTrainerRM(TrainerRM): task_pool (str): task pool name in TaskManager. None for use same name as experiment_name. train_func (Callable, optional): default train method. Defaults to `begin_task_train`. end_train_func (Callable, optional): default end_train method. Defaults to `end_task_train`. + skip_run_task (bool): + If skip_run_task == True: + Only run_task in the worker. Otherwise skip run_task. + E.g. Starting trainer on a CPU VM and then waiting tasks to be finished on GPU VMs. """ super().__init__(experiment_name, task_pool, train_func) self.end_train_func = end_train_func self.delay = True + self.skip_run_task = skip_run_task def train(self, tasks: list, train_func=None, experiment_name: str = None, **kwargs) -> List[Recorder]: """ @@ -477,14 +494,15 @@ class DelayTrainerRM(TrainerRM): _id_list.append(rec.list_tags()[self.TM_ID]) query = {"_id": {"$in": _id_list}} - run_task( - end_train_func, - task_pool, - query=query, # only train these tasks - experiment_name=experiment_name, - before_status=TaskManager.STATUS_PART_DONE, - **kwargs, - ) + if not self.skip_run_task: + run_task( + end_train_func, + task_pool, + query=query, # only train these tasks + experiment_name=experiment_name, + before_status=TaskManager.STATUS_PART_DONE, + **kwargs, + ) TaskManager(task_pool=task_pool).wait(query=query) diff --git a/qlib/utils/__init__.py b/qlib/utils/__init__.py index 1e8ee2e48..778d0e17a 100644 --- a/qlib/utils/__init__.py +++ b/qlib/utils/__init__.py @@ -642,6 +642,28 @@ def split_pred(pred, number=None, split_date=None): return pred_left, pred_right +def time_to_slc_point(t: Union[None, str, pd.Timestamp]) -> Union[None, pd.Timestamp]: + """ + Time slicing in Qlib or Pandas is a frequently-used action. + However, user often input all kinds of data format to represent time. + This function will help user to convert these inputs into a uniform format which is friendly to time slicing. + + Parameters + ---------- + t : Union[None, str, pd.Timestamp] + original time + + Returns + ------- + Union[None, pd.Timestamp]: + """ + if t is None: + # None represents unbounded in Qlib or Pandas(e.g. df.loc[slice(None, "20210303")]). + return t + else: + return pd.Timestamp(t) + + def can_use_cache(): res = True r = get_redis_connection() diff --git a/qlib/workflow/exp.py b/qlib/workflow/exp.py index 467c7c3f4..08f429eb3 100644 --- a/qlib/workflow/exp.py +++ b/qlib/workflow/exp.py @@ -213,11 +213,15 @@ class Experiment: """ raise NotImplementedError(f"Please implement the `_get_recorder` method") - def list_recorders(self): + def list_recorders(self, **flt_kwargs): """ List all the existing recorders of this experiment. Please first get the experiment instance before calling this method. If user want to use the method `R.list_recorders()`, please refer to the related API document in `QlibRecorder`. + flt_kwargs : dict + filter recorders by conditions + e.g. list_recorders(status=Recorder.STATUS_FI) + Returns ------- A dictionary (id -> recorder) of recorder information that being stored. @@ -320,11 +324,14 @@ class MLflowExperiment(Experiment): UNLIMITED = 50000 # FIXME: Mlflow can only list 50000 records at most!!!!!!! - def list_recorders(self, max_results=UNLIMITED): + def list_recorders(self, max_results=UNLIMITED, status=None): runs = self._client.search_runs(self.id, run_view_type=ViewType.ACTIVE_ONLY, max_results=max_results) recorders = dict() for i in range(len(runs)): recorder = MLflowRecorder(self.id, self._uri, mlflow_run=runs[i]) + if status is not None: + if recorder.status != status: + continue recorders[runs[i].info.run_id] = recorder return recorders diff --git a/qlib/workflow/expm.py b/qlib/workflow/expm.py index 04cc3bcb7..751459d81 100644 --- a/qlib/workflow/expm.py +++ b/qlib/workflow/expm.py @@ -351,6 +351,15 @@ class MLflowExpManager(ExpManager): experiment_id is not None or experiment_name is not None ), "Please input at least one of experiment/recorder id or name before retrieving experiment/recorder." if experiment_id is not None: + try: + experiment_id = int(experiment_id) + except ValueError as e: + msg = "The `experiment_id` for mlflow backend must be `int`" + logger.error(msg) + # We have to raise type error here + # - The error looks like type error + # - Value Error will be catched + raise TypeError(msg) try: exp = self.client.get_experiment(experiment_id) if exp.lifecycle_stage.upper() == "DELETED": diff --git a/qlib/workflow/online/manager.py b/qlib/workflow/online/manager.py index d3cc0cbf8..b4b509483 100644 --- a/qlib/workflow/online/manager.py +++ b/qlib/workflow/online/manager.py @@ -6,7 +6,7 @@ OnlineManager can manage a set of `Online Strategy <#Online Strategy>`_ and run With the change of time, the decisive models will be also changed. In this module, we call those contributing models `online` models. In every routine(such as every day or every minute), the `online` models may be changed and the prediction of them needs to be updated. -So this module provides a series of methods to control this process. +So this module provides a series of methods to control this process. This module also provides a method to simulate `Online Strategy <#Online Strategy>`_ in history. Which means you can verify your strategy or find a better one. @@ -31,7 +31,7 @@ Simulation + Trainer When your models have some temporal dependence on the Simulation + DelayTrainer When your models don't have any temporal dependence, you can use DelayTrainer for the ability to multitasking. It means all tasks in all routines - can be REAL trained at the end of simulating. The signals will be prepared well at + can be REAL trained at the end of simulating. The signals will be prepared well at different time segments (based on whether or not any new model is online). ========================= =================================================================================== """ @@ -113,6 +113,8 @@ class OnlineManager(Serializable): models = self.trainer.train(tasks, experiment_name=strategy.name_id) models_list.append(models) self.logger.info(f"Finished training {len(models)} models.") + # FIXME: Traing multiple online models at `first_train` will result in getting too much online models at the + # start. online_models = strategy.prepare_online_models(models, **model_kwargs) self.history.setdefault(self.cur_time, {})[strategy] = online_models @@ -148,8 +150,6 @@ class OnlineManager(Serializable): models_list = [] for strategy in self.strategies: self.logger.info(f"Strategy `{strategy.name_id}` begins routine...") - if self.status == self.STATUS_NORMAL: - strategy.tool.update_online_pred() tasks = strategy.prepare_tasks(self.cur_time, **task_kwargs) models = self.trainer.train(tasks, experiment_name=strategy.name_id) @@ -158,6 +158,11 @@ class OnlineManager(Serializable): online_models = strategy.prepare_online_models(models, **model_kwargs) self.history.setdefault(self.cur_time, {})[strategy] = online_models + # The online model may changes in the above processes + # So updating the predictions of online models should be the last step + if self.status == self.STATUS_NORMAL: + strategy.tool.update_online_pred() + if not self.status == self.STATUS_SIMULATING or not self.trainer.is_delay(): for strategy, models in zip(self.strategies, models_list): models = self.trainer.end_train(models, experiment_name=strategy.name_id) @@ -236,7 +241,7 @@ class OnlineManager(Serializable): SIM_LOG_NAME = "SIMULATE_INFO" def simulate( - self, end_time, frequency="day", task_kwargs={}, model_kwargs={}, signal_kwargs={} + self, end_time=None, frequency="day", task_kwargs={}, model_kwargs={}, signal_kwargs={} ) -> Union[pd.Series, pd.DataFrame]: """ Starting from the current time, this method will simulate every routine in OnlineManager until the end time. diff --git a/qlib/workflow/online/strategy.py b/qlib/workflow/online/strategy.py index a54eb32bf..1e8e85c0f 100644 --- a/qlib/workflow/online/strategy.py +++ b/qlib/workflow/online/strategy.py @@ -52,6 +52,12 @@ class OnlineStrategy: NOTE: Reset all online models to trained models. If there are no trained models, then do nothing. + **NOTE**: + Current implementation is very naive. Here is a more complex situation which is more closer to the + practical scenarios. + 1. Train new models at the day before `test_start` (at time stamp `T`) + 2. Switch models at the `test_start` (at time timestamp `T + 1` typically) + Args: models (list): a list of models. cur_time (pd.Dataframe): current time from OnlineManger. None for the latest. diff --git a/qlib/workflow/task/collect.py b/qlib/workflow/task/collect.py index 9410c2b9c..36ccf434d 100644 --- a/qlib/workflow/task/collect.py +++ b/qlib/workflow/task/collect.py @@ -6,6 +6,7 @@ Collector module can collect objects from everywhere and process them such as me """ from typing import Callable, Dict, List +from qlib.log import get_module_logger from qlib.utils.serial import Serializable from qlib.workflow import R @@ -192,6 +193,7 @@ class RecorderCollector(Collector): if rec_filter_func is None or rec_filter_func(rec): recs_flt[rid] = rec + logger = get_module_logger("RecorderCollector") for _, rec in recs_flt.items(): rec_key = self.rec_key_func(rec) for key in artifacts_key: @@ -205,7 +207,13 @@ class RecorderCollector(Collector): # only collect existing artifact continue raise e - collect_dict.setdefault(key, {})[rec_key] = artifact + # give user some warning if the values are overridden + cdd = collect_dict.setdefault(key, {}) + if rec_key in cdd: + logger.warning( + f"key '{rec_key}' is duplicated. Previous value will be overrides. Please check you `rec_key_func`" + ) + cdd[rec_key] = artifact return collect_dict diff --git a/qlib/workflow/task/gen.py b/qlib/workflow/task/gen.py index cdebf5049..ca7b8ae7f 100644 --- a/qlib/workflow/task/gen.py +++ b/qlib/workflow/task/gen.py @@ -6,6 +6,8 @@ TaskGenerator module can generate many tasks based on TaskGen and some task temp import abc import copy from typing import List, Union, Callable + +from qlib.utils import transform_end_date from .utils import TimeAdjuster @@ -199,7 +201,7 @@ class RollingGen(TaskGen): # First rolling # 1) prepare the end point segments: dict = copy.deepcopy(self.ta.align_seg(t["dataset"]["kwargs"]["segments"])) - test_end = self.ta.max() if segments[self.test_key][1] is None else segments[self.test_key][1] + test_end = transform_end_date(segments[self.test_key][1]) # 2) and init test segments test_start_idx = self.ta.align_idx(segments[self.test_key][0]) segments[self.test_key] = (self.ta.get(test_start_idx), self.ta.get(test_start_idx + self.step - 1)) diff --git a/qlib/workflow/task/manage.py b/qlib/workflow/task/manage.py index 7a85036da..01f79b1b4 100644 --- a/qlib/workflow/task/manage.py +++ b/qlib/workflow/task/manage.py @@ -275,7 +275,7 @@ class TaskManager: except Exception: if task is not None: self.logger.info("Returning task before raising error") - self.return_task(task) + self.return_task(task, status=status) # return task as the original status self.logger.info("Task returned") raise @@ -411,7 +411,11 @@ class TaskManager: self.task_pool.update_one({"_id": task["_id"]}, update_dict) def _get_undone_n(self, task_stat): - return task_stat.get(self.STATUS_WAITING, 0) + task_stat.get(self.STATUS_RUNNING, 0) + return ( + task_stat.get(self.STATUS_WAITING, 0) + + task_stat.get(self.STATUS_RUNNING, 0) + + task_stat.get(self.STATUS_PART_DONE, 0) + ) def _get_total(self, task_stat): return sum(task_stat.values()) @@ -429,7 +433,7 @@ class TaskManager: last_undone_n = self._get_undone_n(task_stat) if last_undone_n == 0: return - self.logger.warn(f"Waiting for {last_undone_n} undone tasks. Please make sure they are running.") + self.logger.warning(f"Waiting for {last_undone_n} undone tasks. Please make sure they are running.") with tqdm(total=total, initial=total - last_undone_n) as pbar: while True: time.sleep(10) From 5850490b245a78279adbb9dd10aeeec0d3a17f21 Mon Sep 17 00:00:00 2001 From: Young Date: Fri, 11 Jun 2021 08:23:57 +0000 Subject: [PATCH 07/44] simplify the code and add docs --- qlib/workflow/exp.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/qlib/workflow/exp.py b/qlib/workflow/exp.py index 08f429eb3..627b5ff82 100644 --- a/qlib/workflow/exp.py +++ b/qlib/workflow/exp.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +from typing import Union import mlflow, logging from mlflow.entities import ViewType from mlflow.exceptions import MlflowException @@ -324,14 +325,21 @@ class MLflowExperiment(Experiment): UNLIMITED = 50000 # FIXME: Mlflow can only list 50000 records at most!!!!!!! - def list_recorders(self, max_results=UNLIMITED, status=None): + def list_recorders(self, max_results: int = UNLIMITED, status: Union[str, None] = None): + """ + Parameters + ---------- + max_results : int + the number limitation of the results + status : str + the criteria based on status to filter results. + `None` indicates no filtering. + """ runs = self._client.search_runs(self.id, run_view_type=ViewType.ACTIVE_ONLY, max_results=max_results) recorders = dict() for i in range(len(runs)): recorder = MLflowRecorder(self.id, self._uri, mlflow_run=runs[i]) - if status is not None: - if recorder.status != status: - continue - recorders[runs[i].info.run_id] = recorder + if status is None or recorder.status == status: + recorders[runs[i].info.run_id] = recorder return recorders From 730f6258d6ec6b88ed3a8e42d2f8d70b3ddc12b7 Mon Sep 17 00:00:00 2001 From: Young Date: Fri, 11 Jun 2021 10:40:56 +0000 Subject: [PATCH 08/44] add warning and * --- qlib/workflow/__init__.py | 10 ++++++---- qlib/workflow/expm.py | 4 ++-- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/qlib/workflow/__init__.py b/qlib/workflow/__init__.py index 2b2535edc..63f63fb56 100644 --- a/qlib/workflow/__init__.py +++ b/qlib/workflow/__init__.py @@ -215,9 +215,9 @@ class QlibRecorder: ------- A dictionary (id -> recorder) of recorder information that being stored. """ - return self.get_exp(experiment_id, experiment_name).list_recorders() + return self.get_exp(experiment_id=experiment_id, experiment_name=experiment_name).list_recorders() - def get_exp(self, experiment_id=None, experiment_name=None, create: bool = True) -> Experiment: + def get_exp(self, *, experiment_id=None, experiment_name=None, create: bool = True) -> Experiment: """ Method for retrieving an experiment with given id or name. Once the `create` argument is set to True, if no valid experiment is found, this method will create one for you. Otherwise, it will @@ -262,7 +262,7 @@ class QlibRecorder: # Case 2 with R.start('test'): - exp = R.get_exp('test1') + exp = R.get_exp(experiment_name='test1') # Case 3 exp = R.get_exp() -> a default experiment. @@ -287,7 +287,9 @@ class QlibRecorder: ------- An experiment instance with given id or name. """ - return self.exp_manager.get_exp(experiment_id, experiment_name, create, start=False) + return self.exp_manager.get_exp( + experiment_id=experiment_id, experiment_name=experiment_name, create=create, start=False + ) def delete_exp(self, experiment_id=None, experiment_name=None): """ diff --git a/qlib/workflow/expm.py b/qlib/workflow/expm.py index 751459d81..7e39d3a32 100644 --- a/qlib/workflow/expm.py +++ b/qlib/workflow/expm.py @@ -109,7 +109,7 @@ class ExpManager: """ raise NotImplementedError(f"Please implement the `search_records` method.") - def get_exp(self, experiment_id=None, experiment_name=None, create: bool = True, start: bool = False): + def get_exp(self, *, experiment_id=None, experiment_name=None, create: bool = True, start: bool = False): """ Retrieve an experiment. This method includes getting an active experiment, and get_or_create a specific experiment. @@ -190,7 +190,7 @@ class ExpManager: except ValueError: if experiment_name is None: experiment_name = self._default_exp_name - logger.info(f"No valid experiment found. Create a new experiment with name {experiment_name}.") + logger.warning(f"No valid experiment found. Create a new experiment with name {experiment_name}.") return self.create_exp(experiment_name), True def _get_exp(self, experiment_id=None, experiment_name=None) -> Experiment: From 973c4137e442a2ad2a73fd94230a1c7abea75733 Mon Sep 17 00:00:00 2001 From: Young Date: Sat, 12 Jun 2021 13:54:26 +0000 Subject: [PATCH 09/44] fix mlflow & task bug --- qlib/model/base.py | 2 +- qlib/workflow/__init__.py | 6 ++++-- qlib/workflow/expm.py | 11 ++--------- qlib/workflow/task/manage.py | 2 +- 4 files changed, 8 insertions(+), 13 deletions(-) diff --git a/qlib/model/base.py b/qlib/model/base.py index 12caf5f73..493981133 100644 --- a/qlib/model/base.py +++ b/qlib/model/base.py @@ -97,7 +97,7 @@ class ModelFT(Model): # Finetune model based on previous trained model with R.start(experiment_name="finetune model"): - recorder = R.get_recorder(rid, experiment_name="init models") + recorder = R.get_recorder(recorder_id=rid, experiment_name="init models") model = recorder.load_object("init_model") model.finetune(dataset, num_boost_round=10) diff --git a/qlib/workflow/__init__.py b/qlib/workflow/__init__.py index 63f63fb56..a14f60c01 100644 --- a/qlib/workflow/__init__.py +++ b/qlib/workflow/__init__.py @@ -333,7 +333,9 @@ class QlibRecorder: """ self.exp_manager.set_uri(uri) - def get_recorder(self, recorder_id=None, recorder_name=None, experiment_name=None) -> Recorder: + def get_recorder( + self, *, recorder_id=None, recorder_name=None, experiment_id=None, experiment_name=None + ) -> Recorder: """ Method for retrieving a recorder. @@ -386,7 +388,7 @@ class QlibRecorder: ------- A recorder instance. """ - return self.get_exp(experiment_name=experiment_name, create=False).get_recorder( + return self.get_exp(experiment_name=experiment_name, experiment_id=experiment_id, create=False).get_recorder( recorder_id, recorder_name, create=False, start=False ) diff --git a/qlib/workflow/expm.py b/qlib/workflow/expm.py index 7e39d3a32..84cc6a13a 100644 --- a/qlib/workflow/expm.py +++ b/qlib/workflow/expm.py @@ -352,15 +352,8 @@ class MLflowExpManager(ExpManager): ), "Please input at least one of experiment/recorder id or name before retrieving experiment/recorder." if experiment_id is not None: try: - experiment_id = int(experiment_id) - except ValueError as e: - msg = "The `experiment_id` for mlflow backend must be `int`" - logger.error(msg) - # We have to raise type error here - # - The error looks like type error - # - Value Error will be catched - raise TypeError(msg) - try: + # NOTE: the mlflow's experiment_id must be str type... + # https://www.mlflow.org/docs/latest/python_api/mlflow.tracking.html#mlflow.tracking.MlflowClient.get_experiment exp = self.client.get_experiment(experiment_id) if exp.lifecycle_stage.upper() == "DELETED": raise MlflowException("No valid experiment has been found.") diff --git a/qlib/workflow/task/manage.py b/qlib/workflow/task/manage.py index 01f79b1b4..41e243b43 100644 --- a/qlib/workflow/task/manage.py +++ b/qlib/workflow/task/manage.py @@ -272,7 +272,7 @@ class TaskManager: task = self.fetch_task(query=query, status=status) try: yield task - except Exception: + except (Exception, KeyboardInterrupt): # KeyboardInterrupt is not a subclass of Exception if task is not None: self.logger.info("Returning task before raising error") self.return_task(task, status=status) # return task as the original status From 9e0e2ff7362989b2267702d20dfc013bc753f78b Mon Sep 17 00:00:00 2001 From: Jactus Date: Tue, 15 Jun 2021 14:46:31 +0800 Subject: [PATCH 10/44] Update QlibRecorder wrapper --- qlib/workflow/__init__.py | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/qlib/workflow/__init__.py b/qlib/workflow/__init__.py index 2b2535edc..7652ff709 100644 --- a/qlib/workflow/__init__.py +++ b/qlib/workflow/__init__.py @@ -525,14 +525,33 @@ class QlibRecorder: self.get_exp().get_recorder().set_tags(**kwargs) +# error type for reinitialization when starting an experiment +class RecorderInitializationError(Exception): + def __init__(self, message): + super(RecorderInitializationError, self).__init__(message) + + +class RecorderWrapper(Wrapper): + """ + Wrapper class for QlibRecorder, which detects whether users reinitialize qlib when already starting an experiment. + """ + + def register(self, provider): + if self._provider is not None: + raise RecorderInitializationError( + "Please don't reinitialize Qlib if QlibRecorder is already acivated. Otherwise, the experiment stored location will be modified." + ) + self._provider = provider + + import sys if sys.version_info >= (3, 9): from typing import Annotated - QlibRecorderWrapper = Annotated[QlibRecorder, Wrapper] + QlibRecorderWrapper = Annotated[QlibRecorder, RecorderWrapper] else: QlibRecorderWrapper = QlibRecorder # global record -R: QlibRecorderWrapper = Wrapper() +R: QlibRecorderWrapper = RecorderWrapper() From 64582e9d4622a7f1f6b1a16a67099e289c03e1b9 Mon Sep 17 00:00:00 2001 From: Jactus Date: Tue, 15 Jun 2021 15:02:11 +0800 Subject: [PATCH 11/44] Add QlibException --- qlib/utils/exceptions.py | 12 ++++++++++++ qlib/workflow/__init__.py | 7 +------ 2 files changed, 13 insertions(+), 6 deletions(-) create mode 100644 qlib/utils/exceptions.py diff --git a/qlib/utils/exceptions.py b/qlib/utils/exceptions.py new file mode 100644 index 000000000..69712172b --- /dev/null +++ b/qlib/utils/exceptions.py @@ -0,0 +1,12 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# Base exception class +class QlibException(Exception): + def __init__(self, message): + super(QlibException, self).__init__(message) + + +# Error type for reinitialization when starting an experiment +class RecorderInitializationError(QlibException): + pass diff --git a/qlib/workflow/__init__.py b/qlib/workflow/__init__.py index 7652ff709..e5cdbb71c 100644 --- a/qlib/workflow/__init__.py +++ b/qlib/workflow/__init__.py @@ -7,6 +7,7 @@ from .expm import MLflowExpManager from .exp import Experiment from .recorder import Recorder from ..utils import Wrapper +from ..utils.exceptions import RecorderInitializationError class QlibRecorder: @@ -525,12 +526,6 @@ class QlibRecorder: self.get_exp().get_recorder().set_tags(**kwargs) -# error type for reinitialization when starting an experiment -class RecorderInitializationError(Exception): - def __init__(self, message): - super(RecorderInitializationError, self).__init__(message) - - class RecorderWrapper(Wrapper): """ Wrapper class for QlibRecorder, which detects whether users reinitialize qlib when already starting an experiment. From 5331ab93f883a10f951a07c3f6d87ad33d88e2e5 Mon Sep 17 00:00:00 2001 From: lewwang Date: Wed, 16 Jun 2021 12:18:16 +0800 Subject: [PATCH 12/44] Update TCTS README. --- examples/benchmarks/TCTS/README.md | 52 ++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) create mode 100644 examples/benchmarks/TCTS/README.md diff --git a/examples/benchmarks/TCTS/README.md b/examples/benchmarks/TCTS/README.md new file mode 100644 index 000000000..ee67ffbeb --- /dev/null +++ b/examples/benchmarks/TCTS/README.md @@ -0,0 +1,52 @@ +# Temporally Correlated Task Scheduling for Sequence Learning +We provide the [code](https://github.com/microsoft/qlib/blob/main/qlib/contrib/model/pytorch_tcts.py) for reproducing the stock trend forecasting experiments. + +### Background +Sequence learning has attracted much research attention from the machine learning community in recent years. In many applications, a sequence learning task is usually associated with multiple temporally correlated auxiliary tasks, which are different in terms of how much input information to use or which future step to predict. In stock trend forecasting, as demonstrated in Figure1, one can predict the price of a stock in different future days (e.g., tomorrow, the day after tomorrow). In this paper, we propose a framework to make use of those temporally correlated tasks to help each other. + +

+ +

+ + +### Method +Given that there are usually multiple temporally correlated tasks, the key challenge lies in which tasks to use and when to use them in the training process. In this work, we introduce a learnable task scheduler for sequence learning, which adaptively selects temporally correlated tasks during the training process. The scheduler accesses the model status and the current training data (e.g., in current minibatch), and selects the best auxiliary task to help the training of the main task. The scheduler and the model for the main task are jointly trained through bi-level optimization: the scheduler is trained to maximize the validation performance of the model, and the model is trained to minimize the training loss guided by the scheduler. The process is demonstrated in Figure2. + +

+ +

+ +At step , with training data , the scheduler chooses a suitable task (green solid lines) to update the model (blue solid lines). After steps, we evaluate the model on the validation set and update the scheduler (green dashed lines). + +### DataSet +* We use the historical transaction data for 300 stocks on [CSI300](http://www.csindex.com.cn/en/indices/index-detail/000300) from 01/01/2008 to 08/01/2020. +* We split the data into training (01/01/2008-12/31/2013), validation (01/01/2014-12/31/2015), and test sets (01/01/2016-08/01/2020) based on the transaction time. + +### Experiments +#### Task Description +* The main tasks ( in Figure1) refers to forecasting return of stock as following, +
+ +
+ +* Temporally correlated task sets , in this paper, , and are used. +#### Baselines +* GRU/MLP/LightGBM (LGB)/Graph Attention Networks (GAT) +* Multi-task learning (MTL): In multi-task learning, multiple tasks are jointly trained and mutually boosted. Each task is treated equally, while in our setting, we focus on the main task. +* Curriculum transfer learning (CL): Transfer learning also leverages auxiliary tasks to boost the main task. [Curriculum transfer learning](https://arxiv.org/pdf/1804.00810.pdf) is one kind of transfer learning which schedules auxiliary tasks according to certain rules. Our problem can also be regarded as a special kind of transfer learning, where the auxiliary tasks are temporally correlated with the main task. Our learning process is dynamically controlled by a scheduler rather than some pre-defined rules. In the CL baseline, we start from the task , then , and gradually move to the last one. +#### Result +| Methods | | | | +| :----: | :----: | :----: | :----: | +| GRU | 0.049 / 1.903 | 0.018 / 1.972 | 0.014 / 1.989 | +| MLP | 0.023 / 1.961 | 0.022 / 1.962 | 0.015 / 1.978 | +| LGB | 0.038 / 1.883 | 0.023 / 1.952 | 0.007 / 1.987 | +| GAT | 0.052 / 1.898 | 0.024 / 1.954 | 0.015 / 1.973 | +| MTL() | 0.061 / 1.862 | 0.023 / 1.942 | 0.012 / 1.956 | +| CL() | 0.051 / 1.880 | 0.028 / 1.941 | 0.016 / 1.962 | +| Ours() | 0.071 / 1.851 | 0.030 / 1.939 | 0.017 / 1.963 | +| MTL() | 0.057 / 1.875 | 0.021 / 1.939 | 0.017 / 1.959 | +| CL() | 0.056 / 1.877 | 0.028 / 1.942 | 0.015 / 1.962 | +| Ours() | 0.075 / 1.849 | 0.032 /1.939 | 0.021 / 1.955 | +| MTL() | 0.052 / 1.882 | 0.020 / 1.947 | 0.019 / 1.952 | +| CL() | 0.051 / 1.882 | 0.028 / 1.950 | 0.016 / 1.961 | +| Ours() | 0.067 / 1.867 | 0.030 / 1.960 | 0.022 / 1.942| \ No newline at end of file From 0fe8b281ba7d4e1269d8d273196a853a1aa686a2 Mon Sep 17 00:00:00 2001 From: Jactus Date: Wed, 16 Jun 2021 12:28:20 +0800 Subject: [PATCH 13/44] Update R wrapper logic --- qlib/workflow/__init__.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/qlib/workflow/__init__.py b/qlib/workflow/__init__.py index e5cdbb71c..98b2c9925 100644 --- a/qlib/workflow/__init__.py +++ b/qlib/workflow/__init__.py @@ -533,9 +533,11 @@ class RecorderWrapper(Wrapper): def register(self, provider): if self._provider is not None: - raise RecorderInitializationError( - "Please don't reinitialize Qlib if QlibRecorder is already acivated. Otherwise, the experiment stored location will be modified." - ) + expm = getattr(self._provider, "exp_manager") + if expm.active_experiment is not None: + raise RecorderInitializationError( + "Please don't reinitialize Qlib if QlibRecorder is already acivated. Otherwise, the experiment stored location will be modified." + ) self._provider = provider From b4efbd53b2f8889b984a5f283e8d62cd3ecf1976 Mon Sep 17 00:00:00 2001 From: zhupr Date: Wed, 16 Jun 2021 22:00:43 +0800 Subject: [PATCH 14/44] Fix 'report' compatibility with matplotlib versions --- .../analysis_model_performance.py | 35 ++++++++++++++++--- 1 file changed, 31 insertions(+), 4 deletions(-) diff --git a/qlib/contrib/report/analysis_model/analysis_model_performance.py b/qlib/contrib/report/analysis_model/analysis_model_performance.py index 1cb14d261..1d444b104 100644 --- a/qlib/contrib/report/analysis_model/analysis_model_performance.py +++ b/qlib/contrib/report/analysis_model/analysis_model_performance.py @@ -3,7 +3,6 @@ import pandas as pd -import plotly.tools as tls import plotly.graph_objs as go import statsmodels.api as sm @@ -80,9 +79,37 @@ def _plot_qq(data: pd.Series = None, dist=stats.norm) -> go.Figure: :param dist: :return: """ - fig, ax = plt.subplots(figsize=(8, 5)) - _mpl_fig = sm.qqplot(data.dropna(), dist, fit=True, line="45", ax=ax) - return tls.mpl_to_plotly(_mpl_fig) + _plt_fig = sm.qqplot(data.dropna(), dist=dist, fit=True, line="45") + plt.close(_plt_fig) + qqplot_data = _plt_fig.gca().lines + fig = go.Figure() + + fig.add_trace({ + 'type': 'scatter', + 'x': qqplot_data[0].get_xdata(), + # 'x': [0, 1], + 'y': qqplot_data[0].get_ydata(), + # 'y': [1, 2], + 'mode': 'markers', + 'marker': { + 'color': '#19d3f3' + } + }) + + fig.add_trace({ + 'type': 'scatter', + 'x': qqplot_data[1].get_xdata(), + # 'x': [0, 1], + 'y': qqplot_data[1].get_ydata(), + # 'y': [1, 2], + 'mode': 'lines', + 'line': { + 'color': '#636efa' + } + + }) + del qqplot_data + return fig def _pred_ic(pred_label: pd.DataFrame = None, rank: bool = False, **kwargs) -> tuple: From 9c8d423a86ec4c2ef306e6ac80fb67ba340cb4dc Mon Sep 17 00:00:00 2001 From: Young Date: Wed, 16 Jun 2021 14:10:51 +0000 Subject: [PATCH 15/44] fix ModelUpdater --- qlib/workflow/online/update.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/qlib/workflow/online/update.py b/qlib/workflow/online/update.py index 561f7e18a..96cbf4d65 100644 --- a/qlib/workflow/online/update.py +++ b/qlib/workflow/online/update.py @@ -136,7 +136,7 @@ class PredUpdater(RecordUpdater): # https://github.com/pytorch/pytorch/issues/16797 start_time = get_date_by_shift(self.last_end, 1, freq=self.freq) - if start_time >= self.to_date: + if start_time > self.to_date: self.logger.info( f"The prediction in {self.record.info['id']} are latest ({start_time}). No need to update to {self.to_date}." ) From a4f6e0419943428def7d5bd12958d065d07ecc9f Mon Sep 17 00:00:00 2001 From: zhupr Date: Thu, 17 Jun 2021 22:33:31 +0800 Subject: [PATCH 16/44] modify dump_update starts with the last end date of each symbol --- scripts/dump_bin.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/scripts/dump_bin.py b/scripts/dump_bin.py index b3a18cc90..83daa28bc 100644 --- a/scripts/dump_bin.py +++ b/scripts/dump_bin.py @@ -401,6 +401,8 @@ class DumpDataUpdate(DumpDataBase): ) self._mode = self.UPDATE_MODE self._old_calendar_list = self._read_calendars(self._calendars_dir.joinpath(f"{self.freq}.txt")) + # NOTE: all.txt only exists once for each stock + # NOTE: if a stock corresponds to multiple different time ranges, user need to modify self._update_instruments self._update_instruments = ( self._read_instruments(self._instruments_dir.joinpath(self.INSTRUMENTS_FILE_NAME)) .set_index([self.symbol_field_name]) @@ -409,10 +411,9 @@ class DumpDataUpdate(DumpDataBase): # load all csv files self._all_data = self._load_all_source_data() # type: pd.DataFrame - self._update_calendars = sorted( + self._new_calendar_list = self._old_calendar_list + sorted( filter(lambda x: x > self._old_calendar_list[-1], self._all_data[self.date_field_name].unique()) ) - self._new_calendar_list = self._old_calendar_list + self._update_calendars def _load_all_source_data(self): # NOTE: Need more memory @@ -452,8 +453,16 @@ class DumpDataUpdate(DumpDataBase): if not (isinstance(_start, pd.Timestamp) and isinstance(_end, pd.Timestamp)): continue if _code in self._update_instruments: + # exists stock, will append data + _update_calendars = ( + _df[_df[self.date_field_name] > self._update_instruments[_code][self.INSTRUMENTS_START_FIELD]][ + self.date_field_name + ] + .sort_values() + .to_list() + ) self._update_instruments[_code][self.INSTRUMENTS_END_FIELD] = self._format_datetime(_end) - futures[executor.submit(self._dump_bin, _df, self._update_calendars)] = _code + futures[executor.submit(self._dump_bin, _df, _update_calendars)] = _code else: # new stock _dt_range = self._update_instruments.setdefault(_code, dict()) From b6c31540e8c8bd7559b58bb6e4e268e9f91d32d5 Mon Sep 17 00:00:00 2001 From: zhupr Date: Thu, 17 Jun 2021 23:01:08 +0800 Subject: [PATCH 17/44] add function to automatically update daily frequency data --- README.md | 22 +++++ docs/component/data.rst | 28 ++++++ scripts/data_collector/cn_index/collector.py | 2 +- scripts/data_collector/us_index/collector.py | 2 +- scripts/data_collector/yahoo/README.md | 61 +++++++++++-- scripts/data_collector/yahoo/collector.py | 95 +++++++++++++++++--- 6 files changed, 189 insertions(+), 21 deletions(-) diff --git a/README.md b/README.md index 8276c4951..635b143f4 100644 --- a/README.md +++ b/README.md @@ -159,6 +159,28 @@ Users could create the same dataset with it. *Please pay **ATTENTION** that the data is collected from [Yahoo Finance](https://finance.yahoo.com/lookup), and the data might not be perfect. We recommend users to prepare their own data if they have a high-quality dataset. For more information, users can refer to the [related document](https://qlib.readthedocs.io/en/latest/component/data.html#converting-csv-format-into-qlib-format)*. +### Automatic update of daily frequency data(from yahoo finance) + > It is recommended that users update the data manually once (--trading_date 2021-05-25) and then set it to update automatically. + + > For more information refer to: [yahoo collector](https://github.com/microsoft/qlib/tree/main/scripts/data_collector/yahoo#Automatic-update-of-daily-frequency-data) + + * Automatic update of data to the "qlib" directory each trading day(Linux) + * use *crontab*: `crontab -e` + * set up timed tasks: + + ``` + * * * * 1-5 python