diff --git a/examples/rolling_process_data/workflow.py b/examples/rolling_process_data/workflow.py index 719d93a1b..0be88dddc 100644 --- a/examples/rolling_process_data/workflow.py +++ b/examples/rolling_process_data/workflow.py @@ -98,6 +98,7 @@ class RollingDataWorkflow(object): dataset = init_instance_by_config(dataset_config) for rolling_offset in range(self.rolling_cnt): + print(f"===========rolling{rolling_offset} start===========") if rolling_offset: dataset.init( @@ -105,6 +106,8 @@ class RollingDataWorkflow(object): "init_type": DataHandlerLP.IT_FIT_SEQ, "start_time": datetime(train_start_time[0] + rolling_offset, *train_start_time[1:]), "end_time": datetime(test_end_time[0] + rolling_offset, *test_end_time[1:]), + "fit_start_time": datetime(train_start_time[0] + rolling_offset, *train_start_time[1:]), + "fit_end_time": datetime(train_end_time[0] + rolling_offset, *train_end_time[1:]), }, segment_kwargs={ "train": ( @@ -123,6 +126,7 @@ class RollingDataWorkflow(object): ) dtrain, dvalid, dtest = dataset.prepare(["train", "valid", "test"]) + print(dtrain, dvalid, dtest) ## print or dump data print(f"===========rolling{rolling_offset} end===========") diff --git a/qlib/data/dataset/__init__.py b/qlib/data/dataset/__init__.py index 0f5d2baba..518b8eecd 100644 --- a/qlib/data/dataset/__init__.py +++ b/qlib/data/dataset/__init__.py @@ -98,7 +98,7 @@ class DatasetH(Dataset): raise TypeError(f"param handler_kwargs must be type dict, not {type(handler_kwargs)}") kwargs_init = {} kwargs_conf_data = {} - conf_data_arg = {"instruments", "start_time", "end_time"} + conf_data_arg = {"instruments", "start_time", "end_time", "fit_start_time", "fit_end_time"} for k, v in handler_kwargs.items(): if k in conf_data_arg: kwargs_conf_data.update({k: v}) diff --git a/qlib/data/dataset/handler.py b/qlib/data/dataset/handler.py index f4795c566..40db5e4f3 100644 --- a/qlib/data/dataset/handler.py +++ b/qlib/data/dataset/handler.py @@ -115,8 +115,7 @@ class DataHandler(Serializable): for k, v in kwargs.items(): if k in attr_list: setattr(self, k, v) - else: - raise KeyError("Such config is not supported.") + def init(self, enable_cache: bool = False): """ @@ -405,11 +404,34 @@ class DataHandlerLP(DataHandler): if self.drop_raw: del self._data + + def conf_data(self, **kwargs): + """ + configuration of data. + # what data to be loaded from data source + + This method will be used when loading pickled handler from dataset. + The data will be initialized with different time range. + + """ + attr_list = {"fit_start_time", "fit_end_time"} + for k, v in kwargs.items(): + if k in attr_list: + for infer_processor in self.infer_processors: + if getattr(infer_processor, k, None): + setattr(infer_processor, k, v) + + for learn_processor in self.learn_processors: + if getattr(learn_processor, k, None): + setattr(learn_processor, k, v) + + super().conf_data(**kwargs) + # init type IT_FIT_SEQ = "fit_seq" # the input of `fit` will be the output of the previous processor IT_FIT_IND = "fit_ind" # the input of `fit` will be the original df IT_LS = "load_state" # The state of the object has been load by pickle - + def init(self, init_type: str = IT_FIT_SEQ, enable_cache: bool = False): """ Initialize the data of Qlib