From 9985befe6955c4953e1bb8b57854171b5df24181 Mon Sep 17 00:00:00 2001 From: bxdd Date: Tue, 29 Jun 2021 12:02:27 +0000 Subject: [PATCH] update HashingStockStorage --- qlib/data/dataset/handler.py | 65 ++++++++++++++------- qlib/data/dataset/storage.py | 28 ++++++++- tests/test_handler_storage.py | 107 ++++++++++++++++++++++++++++++++++ 3 files changed, 176 insertions(+), 24 deletions(-) create mode 100644 tests/test_handler_storage.py diff --git a/qlib/data/dataset/handler.py b/qlib/data/dataset/handler.py index 30cfa7732..475601625 100644 --- a/qlib/data/dataset/handler.py +++ b/qlib/data/dataset/handler.py @@ -175,7 +175,7 @@ class DataHandler(Serializable): select a set of meaningful columns.(e.g. features, columns) - if cal_set == CS_RAW: + if col_set == CS_RAW: the raw dataset will be returned. - if isinstance(col_set, List[str]): @@ -197,23 +197,33 @@ class DataHandler(Serializable): ------- pd.DataFrame. """ - if proc_func is None: - df = self._data - else: - # FIXME: fetching by time first will be more friendly to `proc_func` - # Copy in case of `proc_func` changing the data inplace.... - df = proc_func(fetch_df_by_index(self._data, selector, level, fetch_orig=self.fetch_orig).copy()) + from .storage import HasingStockStorage + + data_storage = self._data + if isinstance(data_storage, pd.DataFrame): + data_df = data_storage + if proc_func is not None: + # FIXME: fetching by time first will be more friendly to `proc_func` + # Copy in case of `proc_func` changing the data inplace.... + data_df = proc_func(fetch_df_by_index(data_df, selector, level, fetch_orig=self.fetch_orig).copy()) + + # Fetch column first will be more friendly to SepDataFrame + data_df = fetch_df_by_col(data_df, col_set) + data_df = fetch_df_by_index(data_df, selector, level, fetch_orig=self.fetch_orig) + elif isinstance(data_storage, HasingStockStorage): + if proc_func is not None: + warnings.warn(f"proc_func is not supported by the HasingStockStorage") + data_df = data_storage.fetch(selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig) + else: + raise TypeError(f"data_storage should be pd.DataFrame|HasingStockStorage, not {type(data_storage)}") - # Fetch column first will be more friendly to SepDataFrame - df = fetch_df_by_col(df, col_set) - df = fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig) if squeeze: # squeeze columns - df = df.squeeze() + data_df = data_df.squeeze() # squeeze index if isinstance(selector, (str, pd.Timestamp)): - df = df.reset_index(level=level, drop=True) - return df + data_df = data_df.reset_index(level=level, drop=True) + return data_df def get_cols(self, col_set=CS_ALL) -> list: """ @@ -511,14 +521,27 @@ class DataHandlerLP(DataHandler): ------- pd.DataFrame: """ - df = self._get_df_by_key(data_key) - if proc_func is not None: - # FIXME: fetch by time first will be more friendly to proc_func - # Copy incase of `proc_func` changing the data inplace.... - df = proc_func(fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig).copy()) - # Fetch column first will be more friendly to SepDataFrame - df = fetch_df_by_col(df, col_set) - return fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig) + from .storage import HasingStockStorage + + data_storage = self._get_df_by_key(data_key) + if isinstance(data_storage, pd.DataFrame): + data_df = data_storage + if proc_func is not None: + # FIXME: fetch by time first will be more friendly to proc_func + # Copy incase of `proc_func` changing the data inplace.... + data_df = proc_func(fetch_df_by_index(data_df, selector, level, fetch_orig=self.fetch_orig).copy()) + # Fetch column first will be more friendly to SepDataFrame + data_df = fetch_df_by_col(data_df, col_set) + data_df = fetch_df_by_index(data_df, selector, level, fetch_orig=self.fetch_orig) + + elif isinstance(data_storage, HasingStockStorage): + if proc_func is not None: + warnings.warn(f"proc_func is not supported by the HasingStockStorage") + data_df = data_storage.fetch(selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig) + else: + raise TypeError(f"data_storage should be pd.DataFrame|HasingStockStorage, not {type(data_storage)}") + + return data_df def get_cols(self, col_set=DataHandler.CS_ALL, data_key: str = DK_I) -> list: """ diff --git a/qlib/data/dataset/storage.py b/qlib/data/dataset/storage.py index 1849b6fcb..66895cfe7 100644 --- a/qlib/data/dataset/storage.py +++ b/qlib/data/dataset/storage.py @@ -2,7 +2,7 @@ import pandas as pd import numpy as np from .handler import DataHandler -from typing import Tuple, Union, List +from typing import Tuple, Union, List, Callable from .utils import get_level_index, fetch_df_by_index, fetch_df_by_col @@ -13,8 +13,29 @@ class BaseHandlerStorage: selector: Union[pd.Timestamp, slice, str, list] = slice(None, None), level: Union[str, int] = "datetime", col_set: Union[str, List[str]] = DataHandler.CS_ALL, + fetch_orig: bool = True, **kwargs, ) -> pd.DataFrame: + """fetch data from the data storage + + Parameters + ---------- + selector : Union[pd.Timestamp, slice, str] + describe how to select data by index + level : Union[str, int] + which index level to select the data + col_set : Union[str, List[str]] + - if isinstance(col_set, str): + select a set of meaningful columns.(e.g. features, columns) + if col_set == DataHandler.CS_RAW: + the raw dataset will be returned. + - if isinstance(col_set, List[str]): + select several sets of meaningful columns, the returned data has multiple level + fetch_orig : bool + Return the original data instead of copy if possible. + + """ + raise NotImplementedError("fetch is method not implemented!") @staticmethod @@ -68,11 +89,12 @@ class HasingStockStorage(BaseHandlerStorage): selector: Union[pd.Timestamp, slice, str] = slice(None, None), level: Union[str, int] = "datetime", col_set: Union[str, List[str]] = DataHandler.CS_ALL, + fetch_orig: bool = True, ) -> pd.DataFrame: fetch_stock_df_list = list(self._fetch_hash_df_by_stock(selector=selector, level=level).values()) for _index, stock_df in enumerate(fetch_stock_df_list): fetch_col_df = fetch_df_by_col(df=stock_df, col_set=col_set) - fetch_index_df = fetch_df_by_index(df=fetch_col_df, selector=selector, level=level) + fetch_index_df = fetch_df_by_index(df=fetch_col_df, selector=selector, level=level, fetch_orig=fetch_orig) fetch_stock_df_list[_index] = fetch_index_df if len(fetch_stock_df_list) == 0: index_names = ("instrument", "datetime") if self.stock_level == 0 else ("datetime", "instrument") @@ -82,4 +104,4 @@ class HasingStockStorage(BaseHandlerStorage): elif len(fetch_stock_df_list) == 1: return fetch_stock_df_list[0] else: - return pd.concat(fetch_stock_df_list, axis=0, sort=False) + return pd.concat(fetch_stock_df_list, sort=False, copy=~fetch_orig) diff --git a/tests/test_handler_storage.py b/tests/test_handler_storage.py new file mode 100644 index 000000000..be36788bd --- /dev/null +++ b/tests/test_handler_storage.py @@ -0,0 +1,107 @@ +import unittest +import qlib +import time +import pandas as pd + +from qlib.data import D +from qlib.tests import TestAutoData + +from qlib.data.dataset.handler import DataHandlerLP +from qlib.data.dataset.processor import Processor +from qlib.contrib.data.handler import check_transform_proc +from qlib.utils import init_instance_by_config +from qlib.log import TimeInspector + + +class TestHandler(DataHandlerLP): + def __init__( + self, + instruments="csi300", + start_time=None, + end_time=None, + infer_processors=[], + learn_processors=[], + fit_start_time=None, + fit_end_time=None, + drop_raw=True, + ): + + infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time) + learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time) + + data_loader = { + "class": "QlibDataLoader", + "kwargs": { + "freq": "day", + "config": self.get_feature_config(), + "swap_level": False, + }, + } + + super().__init__( + instruments=instruments, + start_time=start_time, + end_time=end_time, + data_loader=data_loader, + infer_processors=infer_processors, + learn_processors=learn_processors, + drop_raw=drop_raw, + ) + + def get_feature_config(self): + fields = ["Ref($open, 1)", "Ref($close, 1)", "Ref($volume, 1)", "$open", "$close", "$volume"] + names = ["open_0", "close_0", "volume_0", "open_1", "close_1", "volume_1"] + return fields, names + + +class MiniTimer: + def __init__(self, name): + self.name = name + + def __enter__(self): + self.start = time.time() + + def __exit__(self, exc_type, exc_val, exc_tb): + self.end = time.time() + print(f"[MyTimer Info] <{self.name}> process costs {self.end - self.start} seconds") + + +class TestHandlerStorage(TestAutoData): + + market = "all" + + start_time = "2020-01-01" + end_time = "2020-12-31" + train_end_time = "2020-05-31" + test_start_time = "2020-06-01" + + data_handler_kwargs = { + "start_time": start_time, + "end_time": end_time, + "fit_start_time": start_time, + "fit_end_time": train_end_time, + "instruments": market, + "infer_processors": ["HashingStock"], + } + + def test_handler_storage(self): + with MiniTimer("init data hanlder"): + data_handler = TestHandler(**self.data_handler_kwargs) + + with MiniTimer("random fetch"): + print(data_handler.fetch(selector=("SH600170", slice(None)), level=None)) + print( + data_handler.fetch( + selector=("SH600170", slice(pd.Timestamp("2020-01-01"), pd.Timestamp("2020-02-01"))), level=None + ) + ) + print( + data_handler.fetch( + selector=(["SH600170", "SH600383"], slice(pd.Timestamp("2020-01-01"), pd.Timestamp("2020-02-01"))), + level=None, + ) + ) + + +if __name__ == "__main__": + unittest.main()