diff --git a/examples/rolling_process_data/README.md b/examples/rolling_process_data/README.md index 3f1c8768d..6a6af0d3d 100644 --- a/examples/rolling_process_data/README.md +++ b/examples/rolling_process_data/README.md @@ -1,2 +1 @@ # Rolling Process Data - diff --git a/examples/rolling_process_data/workflow.py b/examples/rolling_process_data/workflow.py index d5f7fec10..29b1c19f8 100644 --- a/examples/rolling_process_data/workflow.py +++ b/examples/rolling_process_data/workflow.py @@ -38,9 +38,12 @@ class RollingDataWorkflow(object): "start_time": self.start_time, "end_time": self.end_time, "instruments": self.MARKET, + "infer_processors": [], + "learn_processors": [], }, } pre_handler = init_instance_by_config(handler_config) + pre_handler.config(dump_all=True) pre_handler.to_pickle(path) def _load_pre_handler(self, path): @@ -50,8 +53,8 @@ class RollingDataWorkflow(object): def rolling_process(self): self._init_qlib() - self._dump_pre_handler("pre_handler.py") - pre_handler = self._load_pre_handler("pre_handler.py") + self._dump_pre_handler("pre_handler.pkl") + pre_handler = self._load_pre_handler("pre_handler.pkl") train_start_time = (2010, 1, 1) train_end_time = (2012, 12, 31) @@ -72,6 +75,13 @@ class RollingDataWorkflow(object): "end_time": datetime(*test_end_time), "fit_start_time": datetime(*train_start_time), "fit_end_time": datetime(*train_end_time), + "infer_processors": [ + {"class":"RobustZScoreNorm", "kwargs": {"fields_group": "feature"}}, + ], + "learn_processors": [ + {"class": "DropnaLabel"}, + {"class": "CSZScoreNorm", "kwargs": {"fields_group": "label"}}, + ], "data_loader_kwargs": { "handler_config": pre_handler, }, @@ -87,7 +97,8 @@ class RollingDataWorkflow(object): dataset = init_instance_by_config(dataset_config) - for rolling_offset in range(rolling_cnt): + for rolling_offset in range(self.rolling_cnt): + print(f"===========rolling{rolling_offset} start===========") if rolling_offset: dataset.init( handler_kwargs={ @@ -112,6 +123,8 @@ class RollingDataWorkflow(object): ) dtrain, dvalid, dtest = dataset.prepare(["train", "valid", "test"]) + ## print or dump data + print(f"===========rolling{rolling_offset} end===========") if __name__ == "__main__": diff --git a/qlib/data/dataset/loader.py b/qlib/data/dataset/loader.py index 539b930ec..1cda5c025 100644 --- a/qlib/data/dataset/loader.py +++ b/qlib/data/dataset/loader.py @@ -250,7 +250,9 @@ class DataLoaderDH(DataLoader): is_group will be used to describe whether the key of handler_config is group """ - if self.is_group: + from qlib.data.dataset.handler import DataHandler + + if is_group: self.handlers = { grp: init_instance_by_config(config, accept_types=DataHandler) for grp, config in handler_config.items() } @@ -274,5 +276,5 @@ class DataLoaderDH(DataLoader): axis=1, ) else: - df = self.handler.fetch(selector=slice(start_time, end_time), level="datetime", **self.fetch_kwargs) + df = self.handlers.fetch(selector=slice(start_time, end_time), level="datetime", **self.fetch_kwargs) return df