diff --git a/examples/nested_decision_execution/workflow.py b/examples/nested_decision_execution/workflow.py index 3108960c8..b6c1362fd 100644 --- a/examples/nested_decision_execution/workflow.py +++ b/examples/nested_decision_execution/workflow.py @@ -124,14 +124,14 @@ class NestedDecisionExecutionWorkflow: def _init_qlib(self): """initialize qlib""" - provider_uri_day = "/data/stock_data/huaxia/qlib" - provider_uri_1min = "/data2/stock_data/huaxia_1min_qlib" - # provider_uri_day = "~/.qlib/qlib_data/cn_data" # target_dir - # GetData().qlib_data(target_dir=provider_uri_day, region=REG_CN, version="v2", exists_skip=True) - # provider_uri_1min = HIGH_FREQ_CONFIG.get("provider_uri") - # GetData().qlib_data( - # target_dir=provider_uri_1min, interval="1min", region=REG_CN, version="v2", exists_skip=True - # ) + # provider_uri_day = "/data/stock_data/huaxia/qlib" + # provider_uri_1min = "/data2/stock_data/huaxia_1min_qlib" + provider_uri_day = "~/.qlib/qlib_data/cn_data" # target_dir + GetData().qlib_data(target_dir=provider_uri_day, region=REG_CN, version="v2", exists_skip=True) + provider_uri_1min = HIGH_FREQ_CONFIG.get("provider_uri") + GetData().qlib_data( + target_dir=provider_uri_1min, interval="1min", region=REG_CN, version="v2", exists_skip=True + ) provider_uri_map = {"1min": provider_uri_1min, "day": provider_uri_day} client_config = { "calendar_provider": { diff --git a/qlib/data/dataset/storage.py b/qlib/data/dataset/storage.py index cd38bbefa..9325807f9 100644 --- a/qlib/data/dataset/storage.py +++ b/qlib/data/dataset/storage.py @@ -37,8 +37,12 @@ class BaseHandlerStorage: Return the original data instead of copy if possible. 
proc_func: Callable please refer to the doc of DataHandler.fetch - """ + Returns + ------- + pd.DataFrame + the dataframe fetched + """ raise NotImplementedError("fetch is method not implemented!") @staticmethod @@ -46,6 +50,7 @@ class BaseHandlerStorage: raise NotImplementedError("from_df method is not implemented!") def is_proc_func_supported(self): + """return whether the arg `proc_func` of the `fetch` method is supported by this storage.""" raise NotImplementedError("is_proc_func_supported method is not implemented!") @@ -113,4 +118,5 @@ class HasingStockStorage(BaseHandlerStorage): return pd.concat(fetch_stock_df_list, sort=False, copy=~fetch_orig) def is_proc_func_supported(self): + """always return False: the arg `proc_func` of the `fetch` method is not supported in HasingStockStorage""" return False diff --git a/qlib/utils/resam.py b/qlib/utils/resam.py index 7e0dc141c..9e9590e30 100644 --- a/qlib/utils/resam.py +++ b/qlib/utils/resam.py @@ -3,6 +3,8 @@ import datetime import numpy as np import pandas as pd + +from functools import partial from typing import Tuple, List, Union, Optional, Callable from . 
import lazy_sort_index @@ -284,21 +286,15 @@ def get_valid_value(series, last=True): return series.fillna(method="ffill").iloc[-1] if last else series.fillna(method="bfill").iloc[0] -def ts_data_last(ts_feature): - """get the last not nan value of pd.Series|DataFrame with single level index""" +def _ts_data_valid(ts_feature, last=False): + """get the first (last=False) or last (last=True) not nan value of pd.Series|DataFrame with single level index""" if isinstance(ts_feature, pd.DataFrame): - return ts_feature.apply(lambda column: get_valid_value(column, last=True)) + return ts_feature.apply(lambda column: get_valid_value(column, last=last)) elif isinstance(ts_feature, pd.Series): - return get_valid_value(ts_feature, last=True) + return get_valid_value(ts_feature, last=last) else: raise TypeError(f"ts_feature should be pd.DataFrame/Series, not {type(ts_feature)}") -def ts_data_first(ts_feature): - """get the first not nan value of pd.Series|DataFrame with single level index""" - if isinstance(ts_feature, pd.DataFrame): - return ts_feature.apply(lambda column: get_valid_value(column, last=False)) - elif isinstance(ts_feature, pd.Series): - return get_valid_value(ts_feature, last=False) - else: - raise TypeError(f"ts_feature should be pd.DataFrame/Series, not {type(ts_feature)}") +ts_data_last = partial(_ts_data_valid, last=True) +ts_data_first = partial(_ts_data_valid, last=False)