diff --git a/scripts/data_collector/base.py b/scripts/data_collector/base.py index 12983f6a5..cb51f9b22 100644 --- a/scripts/data_collector/base.py +++ b/scripts/data_collector/base.py @@ -226,11 +226,7 @@ class BaseCollector(abc.ABC): class BaseNormalize(abc.ABC): - def __init__( - self, - date_field_name: str = "date", - symbol_field_name: str = "symbol", - ): + def __init__(self, date_field_name: str = "date", symbol_field_name: str = "symbol", **kwargs): """ Parameters @@ -265,6 +261,7 @@ class Normalize: max_workers: int = 16, date_field_name: str = "date", symbol_field_name: str = "symbol", + **kwargs, ): """ @@ -291,7 +288,9 @@ class Normalize: self._max_workers = max_workers - self._normalize_obj = normalize_class(date_field_name=date_field_name, symbol_field_name=symbol_field_name) + self._normalize_obj = normalize_class( + date_field_name=date_field_name, symbol_field_name=symbol_field_name, **kwargs + ) def _executor(self, file_path: Path): file_path = Path(file_path) @@ -404,7 +403,7 @@ class BaseRun(abc.ABC): limit_nums=limit_nums, ).collector_data() - def normalize_data(self, date_field_name: str = "date", symbol_field_name: str = "symbol"): + def normalize_data(self, date_field_name: str = "date", symbol_field_name: str = "symbol", **kwargs): """normalize data Parameters @@ -426,5 +425,6 @@ class BaseRun(abc.ABC): max_workers=self.max_workers, date_field_name=date_field_name, symbol_field_name=symbol_field_name, + **kwargs, ) yc.normalize() diff --git a/scripts/data_collector/cn_index/collector.py b/scripts/data_collector/cn_index/collector.py index 5af9785ec..1f8434c58 100644 --- a/scripts/data_collector/cn_index/collector.py +++ b/scripts/data_collector/cn_index/collector.py @@ -24,7 +24,10 @@ from data_collector.utils import get_calendar_list, get_trading_date_by_shift NEW_COMPANIES_URL = "http://www.csindex.com.cn/uploads/file/autofile/cons/{index_code}cons.xls" -INDEX_CHANGES_URL = "http://www.csindex.com.cn/zh-CN/search/total?key=%E5%85%B3%E4%BA%8E%E8%B0%83%E6%95%B4%E6%B2%AA%E6%B7%B1300%E5%92%8C%E4%B8%AD%E8%AF%81%E9%A6%99%E6%B8%AF100%E7%AD%89%E6%8C%87%E6%95%B0%E6%A0%B7%E6%9C%AC%E8%82%A1%E7%9A%84%E5%85%AC%E5%91%8A" + +# INDEX_CHANGES_URL = "http://www.csindex.com.cn/zh-CN/search/total?key=%E5%85%B3%E4%BA%8E%E8%B0%83%E6%95%B4%E6%B2%AA%E6%B7%B1300%E5%92%8C%E4%B8%AD%E8%AF%81%E9%A6%99%E6%B8%AF100%E7%AD%89%E6%8C%87%E6%95%B0%E6%A0%B7%E6%9C%AC%E8%82%A1%E7%9A%84%E5%85%AC%E5%91%8A" +# 2020-11-27 Announcement title change +INDEX_CHANGES_URL = "http://www.csindex.com.cn/zh-CN/search/total?key=%E5%85%B3%E4%BA%8E%E8%B0%83%E6%95%B4%E6%B2%AA%E6%B7%B1300%E5%92%8C%E4%B8%AD%E8%AF%81%E9%A6%99%E6%B8%AF100%E7%AD%89" class CSIIndex(IndexBase): diff --git a/scripts/data_collector/yahoo/collector.py b/scripts/data_collector/yahoo/collector.py index 2cd080199..fcaa9ff92 100644 --- a/scripts/data_collector/yahoo/collector.py +++ b/scripts/data_collector/yahoo/collector.py @@ -23,7 +23,7 @@ from qlib.config import REG_CN as REGION_CN CUR_DIR = Path(__file__).resolve().parent sys.path.append(str(CUR_DIR.parent.parent)) -from data_collector.base import BaseCollector, BaseNormalize, BaseRun +from data_collector.base import BaseCollector, BaseNormalize, BaseRun, Normalize from data_collector.utils import ( get_calendar_list, get_hs_stock_symbols, @@ -297,6 +297,7 @@ class YahooNormalize(BaseNormalize): calendar_list: list = None, date_field_name: str = "date", symbol_field_name: str = "symbol", + last_close: float = None, ): if df.empty: return df @@ -318,7 +319,10 @@ class YahooNormalize(BaseNormalize): df.sort_index(inplace=True) df.loc[(df["volume"] <= 0) | np.isnan(df["volume"]), set(df.columns) - {symbol_field_name}] = np.nan _tmp_series = df["close"].fillna(method="ffill") - df["change"] = _tmp_series / _tmp_series.shift(1) - 1 + _tmp_shift_series = _tmp_series.shift(1) + if last_close is not None and isinstance(last_close, (int, float)): + _tmp_shift_series.iloc[0] = last_close + df["change"] = _tmp_series / _tmp_shift_series - 1 columns += ["change"] df.loc[(df["volume"] <= 0) | np.isnan(df["volume"]), columns] = np.nan @@ -367,6 +371,17 @@ class YahooNormalize1d(YahooNormalize, ABC): df = self._manual_adj_data(df) return df + def _get_first_close(self, df: pd.DataFrame) -> float: + """get first close value + + Notes + ----- + For incremental updates(append) to Yahoo 1D data, user need to use a close that is not 0 on the first trading day of the existing data + """ + df = df.loc[df["close"].first_valid_index() :] + _close = df["close"].iloc[0] + return _close + def _manual_adj_data(self, df: pd.DataFrame) -> pd.DataFrame: """manual adjust data: All fields (except change) are standardized according to the close of the first day""" if df.empty: @@ -374,8 +389,7 @@ class YahooNormalize1d(YahooNormalize, ABC): df = df.copy() df.sort_values(self._date_field_name, inplace=True) df = df.set_index(self._date_field_name) - df = df.loc[df["close"].first_valid_index() :] - _close = df["close"].iloc[0] + _close = self._get_first_close(df) for _col in df.columns: if _col == self._symbol_field_name: continue @@ -388,18 +402,97 @@ class YahooNormalize1d(YahooNormalize, ABC): return df.reset_index() +class YahooNormalize1dExtend(YahooNormalize1d): + def __init__( + self, old_qlib_data_dir: [str, Path], date_field_name: str = "date", symbol_field_name: str = "symbol", **kwargs + ): + """ + + Parameters + ---------- + old_qlib_data_dir: str, Path + the qlib data to be updated for yahoo, usually from: https://github.com/microsoft/qlib/tree/main/scripts#download-cn-data + date_field_name: str + date field name, default is date + symbol_field_name: str + symbol field name, default is symbol + """ + super(YahooNormalize1dExtend, self).__init__(date_field_name, symbol_field_name) + self._end_date, self._old_close = self._get_old_data(old_qlib_data_dir) + self._end_date = pd.Timestamp(self._end_date).strftime(self.DAILY_FORMAT) + + def _get_old_data(self, qlib_data_dir: [str, Path]): + import qlib + from qlib.data import D + + qlib_data_dir = str(Path(qlib_data_dir).expanduser().resolve()) + qlib.init(provider_uri=qlib_data_dir, expression_cache=None, dataset_cache=None) + df = D.features(D.instruments("all"), ["$close/$factor"]) + df.columns = ["close"] + return D.calendar()[-1], df + + def _get_first_close(self, df: pd.DataFrame) -> float: + _symbol = df.iloc[0][self._symbol_field_name] + try: + _df = self._old_close.loc(axis=0)[_symbol.upper()] + _close = _df.loc[_df.first_valid_index()]["close"] + except KeyError: + _close = super(YahooNormalize1dExtend, self)._get_first_close(df) + return _close + + def _get_last_close(self, df: pd.DataFrame) -> float: + _symbol = df.iloc[0][self._symbol_field_name] + try: + _df = self._old_close.loc(axis=0)[_symbol.upper()] + _close = _df.loc[_df.last_valid_index()]["close"] + except KeyError: + _close = None + return _close + + def _get_last_date(self, df: pd.DataFrame) -> pd.Timestamp: + _symbol = df.iloc[0][self._symbol_field_name] + try: + _df = self._old_close.loc(axis=0)[_symbol.upper()] + _date = _df.index.max() + except KeyError: + _date = None + return _date + + def normalize(self, df: pd.DataFrame) -> pd.DataFrame: + _last_close = self._get_last_close(df) + # reindex + _last_date = self._get_last_date(df) + if _last_date is not None: + df = df.set_index(self._date_field_name) + df.index = pd.to_datetime(df.index) + df = df[~df.index.duplicated(keep="first")] + _max_date = df.index.max() + df = df.reindex(self._calendar_list).loc[:_max_date].reset_index() + df = df[df[self._date_field_name] > _last_date] + _si = df["close"].first_valid_index() + if _si > df.index[0]: + logger.warning( + f"{df.iloc[0][self._symbol_field_name]} missing data: {df.loc[:_si][self._date_field_name]}" + ) + # normalize + df = self.normalize_yahoo( + df, self._calendar_list, self._date_field_name, self._symbol_field_name, last_close=_last_close + ) + # adjusted price + df = self.adjusted_price(df) + df = self._manual_adj_data(df) + return df + + class YahooNormalize1min(YahooNormalize, ABC): AM_RANGE = None # type: tuple # eg: ("09:30:00", "11:29:00") PM_RANGE = None # type: tuple # eg: ("13:00:00", "14:59:00") # Whether the trading day of 1min data is consistent with 1d CONSISTENT_1d = False + CALC_PAUSED_NUM = False - def __init__( - self, - date_field_name: str = "date", - symbol_field_name: str = "symbol", - ): + def __init__(self, date_field_name: str = "date", symbol_field_name: str = "symbol", **kwargs): """ Parameters @@ -478,6 +571,54 @@ class YahooNormalize1min(YahooNormalize, ABC): df[_col] = df[_col] / df["factor"] else: df[_col] = df[_col] * df["factor"] + + if self.CALC_PAUSED_NUM: + df = self.calc_paused_num(df) + return df + + def calc_paused_num(self, df: pd.DataFrame): + _symbol = df.iloc[0][self._symbol_field_name] + df = df.copy() + df["date"] = df[self._date_field_name].apply(lambda x: pd.Timestamp(x).date()) + # remove data that starts and ends with `np.nan` all day + all_data = [] + # Record the number of consecutive trading days where the whole day is nan, to remove the last trading day where the whole day is nan + all_nan_nums = 0 + # Record the number of consecutive occurrences of trading days that are not nan throughout the day + not_nan_nums = 0 + for _date, _df in df.groupby(level="date"): + _df["paused"] = 0 + if not _df.loc[_df["volume"] < 0].empty: + logger.warning(f"volume < 0, will fill np.nan: {_date} {_symbol}") + _df.loc[_df["volume"] < 0, "volume"] = np.nan + + check_fields = set(_df.columns) - { + "date", + "paused", + "factor", + self._date_field_name, + self._symbol_field_name, + } + if _df.loc[:, check_fields].isna().values.all() or (_df["volume"] == 0).all(): + all_nan_nums += 1 + not_nan_nums = 0 + _df["paused"] = 1 + if all_data: + _df["paused_num"] = not_nan_nums + all_data.append(_df) + else: + all_nan_nums = 0 + not_nan_nums += 1 + _df["paused_num"] = not_nan_nums + all_data.append(_df) + all_data = all_data[: len(all_data) - all_nan_nums] + if all_data: + df = pd.concat(all_data, sort=False) + else: + logger.warning(f"data is empty: {_symbol}") + df = pd.DataFrame() + return df + del df["date"] return df @abc.abstractmethod @@ -523,11 +664,16 @@ class YahooNormalizeCN1d(YahooNormalizeCN, YahooNormalize1d): pass +class YahooNormalizeCN1dExtend(YahooNormalizeCN, YahooNormalize1dExtend): + pass + + class YahooNormalizeCN1min(YahooNormalizeCN, YahooNormalize1min): AM_RANGE = ("09:30:00", "11:29:00") PM_RANGE = ("13:00:00", "14:59:00") CONSISTENT_1d = True + CALC_PAUSED_NUM = True def _get_calendar_list(self): return self.generate_1min_from_daily(self.calendar_list_1d) @@ -624,10 +770,54 @@ class Run(BaseRun): Examples --------- - $ python collector.py normalize_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region CN --interval 1d + $ python collector.py normalize_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region cn --interval 1d """ super(Run, self).normalize_data(date_field_name, symbol_field_name) + def normalize_data_1d_extend( + self, old_qlib_data_dir, date_field_name: str = "date", symbol_field_name: str = "symbol" + ): + """normalize data extend; extending yahoo qlib data(from: https://github.com/microsoft/qlib/tree/main/scripts#download-cn-data) + + Notes + ----- + Steps to extend yahoo qlib data: + + 1. download qlib data: https://github.com/microsoft/qlib/tree/main/scripts#download-cn-data; save to + + 2. collector source data: https://github.com/microsoft/qlib/tree/main/scripts/data_collector/yahoo#collector-data; save to + + 3. normalize new source data(from step 2): python scripts/data_collector/yahoo/collector.py normalize_data_1d_extend --old_qlib_dir --source_dir --normalize_dir --region CN --interval 1d + + 4. dump data: python scripts/dump_bin.py dump_update --csv_path --qlib_dir --freq day --date_field_name date --symbol_field_name symbol --exclude_fields symbol,date + + 5. update instrument(eg. csi300): python python scripts/data_collector/cn_index/collector.py --index_name CSI300 --qlib_dir --method parse_instruments + + Parameters + ---------- + old_qlib_data_dir: str + the qlib data to be updated for yahoo, usually from: https://github.com/microsoft/qlib/tree/main/scripts#download-cn-data + date_field_name: str + date field name, default date + symbol_field_name: str + symbol field name, default symbol + + Examples + --------- + $ python collector.py normalize_data_1d_extend --old_qlib_dir ~/.qlib/qlib_data/cn_1d --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region CN --interval 1d + """ + _class = getattr(self._cur_module, f"{self.normalize_class_name}Extend") + yc = Normalize( + source_dir=self.source_dir, + target_dir=self.normalize_dir, + normalize_class=_class, + max_workers=self.max_workers, + date_field_name=date_field_name, + symbol_field_name=symbol_field_name, + old_qlib_data_dir=old_qlib_data_dir, + ) + yc.normalize() + if __name__ == "__main__": fire.Fire(Run)