diff --git a/examples/nested_decision_execution/workflow.py b/examples/nested_decision_execution/workflow.py index b6c1362fd..3108960c8 100644 --- a/examples/nested_decision_execution/workflow.py +++ b/examples/nested_decision_execution/workflow.py @@ -124,14 +124,14 @@ class NestedDecisionExecutionWorkflow: def _init_qlib(self): """initialize qlib""" - # provider_uri_day = "/data/stock_data/huaxia/qlib" - # provider_uri_1min = "/data2/stock_data/huaxia_1min_qlib" - provider_uri_day = "~/.qlib/qlib_data/cn_data" # target_dir - GetData().qlib_data(target_dir=provider_uri_day, region=REG_CN, version="v2", exists_skip=True) - provider_uri_1min = HIGH_FREQ_CONFIG.get("provider_uri") - GetData().qlib_data( - target_dir=provider_uri_1min, interval="1min", region=REG_CN, version="v2", exists_skip=True - ) + provider_uri_day = "/data/stock_data/huaxia/qlib" + provider_uri_1min = "/data2/stock_data/huaxia_1min_qlib" + # provider_uri_day = "~/.qlib/qlib_data/cn_data" # target_dir + # GetData().qlib_data(target_dir=provider_uri_day, region=REG_CN, version="v2", exists_skip=True) + # provider_uri_1min = HIGH_FREQ_CONFIG.get("provider_uri") + # GetData().qlib_data( + # target_dir=provider_uri_1min, interval="1min", region=REG_CN, version="v2", exists_skip=True + # ) provider_uri_map = {"1min": provider_uri_1min, "day": provider_uri_day} client_config = { "calendar_provider": { diff --git a/qlib/backtest/report.py b/qlib/backtest/report.py index f217ea169..7623af551 100644 --- a/qlib/backtest/report.py +++ b/qlib/backtest/report.py @@ -91,7 +91,7 @@ class Report: if freq is None: raise ValueError("benchmark freq can't be None!") - _codes = benchmark if isinstance(benchmark, list) else [benchmark] + _codes = benchmark if isinstance(benchmark, (list, dict)) else [benchmark] fields = ["$close/Ref($close,1)-1"] _temp_result, _ = get_higher_eq_freq_feature(_codes, fields, start_time, end_time, freq=freq) if len(_temp_result) == 0: diff --git a/qlib/data/dataset/handler.py b/qlib/data/dataset/handler.py index edcc1ede2..2d5159292 100644 --- a/qlib/data/dataset/handler.py +++ b/qlib/data/dataset/handler.py @@ -197,7 +197,7 @@ class DataHandler(Serializable): ------- pd.DataFrame. """ - from .storage import HasingStockStorage + from .storage import BaseHandlerStorage data_storage = self._data if isinstance(data_storage, pd.DataFrame): @@ -211,10 +211,17 @@ class DataHandler(Serializable): # Fetch column first will be more friendly to SepDataFrame data_df = fetch_df_by_col(data_df, col_set) data_df = fetch_df_by_index(data_df, selector, level, fetch_orig=self.fetch_orig) - elif isinstance(data_storage, HasingStockStorage): - if proc_func is not None: - raise ValueError("proc_func is not supported by the HasingStockStorage") - data_df = data_storage.fetch(selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig) + elif isinstance(data_storage, BaseHandlerStorage): + if not data_storage.is_proc_func_supported(): + if proc_func is not None: + raise ValueError(f"proc_func is not supported by the storage {type(data_storage)}") + data_df = data_storage.fetch( + selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig + ) + else: + data_df = data_storage.fetch( + selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig, proc_func=proc_func + ) else: raise TypeError(f"data_storage should be pd.DataFrame|HasingStockStorage, not {type(data_storage)}") @@ -522,7 +529,7 @@ class DataHandlerLP(DataHandler): ------- pd.DataFrame: """ - from .storage import HasingStockStorage + from .storage import BaseHandlerStorage data_storage = self._get_df_by_key(data_key) if isinstance(data_storage, pd.DataFrame): @@ -537,10 +544,17 @@ class DataHandlerLP(DataHandler): data_df = fetch_df_by_col(data_df, col_set) data_df = fetch_df_by_index(data_df, selector, level, fetch_orig=self.fetch_orig) - elif isinstance(data_storage, HasingStockStorage): - if proc_func is not None: - raise ValueError("proc_func is not supported by the HasingStockStorage") - data_df = data_storage.fetch(selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig) + elif isinstance(data_storage, BaseHandlerStorage): + if not data_storage.is_proc_func_supported(): + if proc_func is not None: + raise ValueError(f"proc_func is not supported by the storage {type(data_storage)}") + data_df = data_storage.fetch( + selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig + ) + else: + data_df = data_storage.fetch( + selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig, proc_func=proc_func + ) else: raise TypeError(f"data_storage should be pd.DataFrame|HasingStockStorage, not {type(data_storage)}") diff --git a/qlib/data/dataset/storage.py b/qlib/data/dataset/storage.py index 247970481..cd38bbefa 100644 --- a/qlib/data/dataset/storage.py +++ b/qlib/data/dataset/storage.py @@ -14,6 +14,7 @@ class BaseHandlerStorage: level: Union[str, int] = "datetime", col_set: Union[str, List[str]] = DataHandler.CS_ALL, fetch_orig: bool = True, + proc_func: Callable = None, **kwargs, ) -> pd.DataFrame: """fetch data from the data storage @@ -24,6 +25,7 @@ class BaseHandlerStorage: describe how to select data by index level : Union[str, int] which index level to select the data + - if level is None, apply selector to df directly col_set : Union[str, List[str]] - if isinstance(col_set, str): select a set of meaningful columns.(e.g. features, columns) @@ -33,7 +35,8 @@ class BaseHandlerStorage: select several sets of meaningful columns, the returned data has multiple level fetch_orig : bool Return the original data instead of copy if possible. - + proc_func: Callable + please refer to the doc of DataHandler.fetch """ raise NotImplementedError("fetch is method not implemented!") @@ -42,6 +45,9 @@ class BaseHandlerStorage: def from_df(df: pd.DataFrame): raise NotImplementedError("from_df method is not implemented!") + def is_proc_func_supported(self): + raise NotImplementedError("is_proc_func_supported method is not implemented!") + class HasingStockStorage(BaseHandlerStorage): def __init__(self, df): @@ -105,3 +111,6 @@ class HasingStockStorage(BaseHandlerStorage): return fetch_stock_df_list[0] else: return pd.concat(fetch_stock_df_list, sort=False, copy=~fetch_orig) + + def is_proc_func_supported(self): + return False diff --git a/qlib/utils/resam.py b/qlib/utils/resam.py index 7782b8486..7e0dc141c 100644 --- a/qlib/utils/resam.py +++ b/qlib/utils/resam.py @@ -270,6 +270,7 @@ def get_valid_value(series, last=True): Parameters ---------- series : pd.Seires + series should not be empty last : bool, optional wether to get the last valid value, by default True - if last is True, get the last valid value @@ -280,11 +281,7 @@ def get_valid_value(series, last=True): Nan | float the first/last valid value """ - x = series.dropna() - if x.empty: - return np.nan - else: - return x.iloc[-1] if last else x.iloc[0] + return series.fillna(method="ffill").iloc[-1] if last else series.fillna(method="bfill").iloc[0] def ts_data_last(ts_feature):