diff --git a/examples/highfreq/highfreq_processor.py b/examples/highfreq/highfreq_processor.py index f0ab0dec2..4ec8f3dd2 100644 --- a/examples/highfreq/highfreq_processor.py +++ b/examples/highfreq/highfreq_processor.py @@ -70,3 +70,10 @@ class HighFreqNorm(Processor): columns=["FEATURE_%d" % i for i in range(12 * 240)], ).sort_index() return df_new_features + + def config(self, fit_start_time=None, fit_end_time=None, **kwargs): + if fit_start_time: + self.fit_start_time = fit_start_time + if fit_end_time: + self.fit_end_time = fit_end_time + super().config(**kwargs) diff --git a/examples/highfreq/workflow.py b/examples/highfreq/workflow.py index c2ca36db3..0b48b971f 100644 --- a/examples/highfreq/workflow.py +++ b/examples/highfreq/workflow.py @@ -31,7 +31,7 @@ class HighfreqWorkflow(object): SPEC_CONF = {"custom_ops": [DayLast, FFillNan, BFillNan, Date, Select, IsNull, Cut], "expression_cache": None} - MARKET = "all" + MARKET = "csi300" start_time = "2020-09-15 00:00:00" end_time = "2021-01-18 16:00:00" @@ -145,35 +145,40 @@ class HighfreqWorkflow(object): self._prepare_calender_cache() ##=============reinit dataset============= - dataset.init( + dataset.config( + handler_kwargs={ + "start_time": "2021-01-19 00:00:00", + "end_time": "2021-01-25 16:00:00", + }, + segments={ + "test": ( + "2021-01-19 00:00:00", + "2021-01-25 16:00:00", + ), + }, + ) + dataset.setup_data( handler_kwargs={ "init_type": DataHandlerLP.IT_LS, - "start_time": "2021-01-19 00:00:00", - "end_time": "2021-01-25 16:00:00", - }, - segment_kwargs={ - "test": ( - "2021-01-19 00:00:00", - "2021-01-25 16:00:00", - ), }, ) - dataset_backtest.init( + dataset_backtest.config( handler_kwargs={ "start_time": "2021-01-19 00:00:00", "end_time": "2021-01-25 16:00:00", }, - segment_kwargs={ + segments={ "test": ( "2021-01-19 00:00:00", "2021-01-25 16:00:00", ), }, ) + dataset_backtest.setup_data(handler_kwargs={}) ##=============get data============= - xtest = dataset.prepare(["test"]) - backtest_test = 
dataset_backtest.prepare(["test"]) + xtest, = dataset.prepare(["test"]) + backtest_test, = dataset_backtest.prepare(["test"]) print(xtest, backtest_test) return diff --git a/qlib/data/dataset/__init__.py b/qlib/data/dataset/__init__.py index 518b8eecd..aa90cee2f 100644 --- a/qlib/data/dataset/__init__.py +++ b/qlib/data/dataset/__init__.py @@ -20,17 +20,25 @@ class Dataset(Serializable): """ init is designed to finish following steps: + - init instance + + - config the state of the dataset(info to prepare the data) + - The name of essential state for preparing data should not start with '_' so that it could be serialized on disk when serializing. + - setup data - The data related attributes' names should start with '_' so that it will not be saved on disk when serializing. - - initialize the state of the dataset(info to prepare the data) - - The name of essential state for preparing data should not start with '_' so that it could be serialized on disk when serializing. - The data could specify the info to caculate the essential data for preparation """ self.setup_data(*args, **kwargs) super().__init__() + def config(self, *arg, **kwargs): + """ + config is designed to configure the parameters that cannot be learned from the data + """ + super().config(*arg, **kwargs) + def setup_data(self, *args, **kwargs): """ Setup the data. @@ -39,7 +47,7 @@ class Dataset(Serializable): - User have a Dataset object with learned status on disk. - - User load the Dataset object from the disk(Note the init function is skiped). + - User load the Dataset object from the disk. - User call `setup_data` to load new data. @@ -76,44 +84,7 @@ class DatasetH(Dataset): - The processing is related to data split. 
""" - def init(self, handler_kwargs: dict = None, segment_kwargs: dict = None): - """ - Initialize the DatasetH - - Parameters - ---------- - handler_kwargs : dict - Config of DataHanlder, which could include the following arguments: - - - arguments of DataHandler.conf_data, such as 'instruments', 'start_time' and 'end_time'. - - - arguments of DataHandler.init, such as 'enable_cache', etc. - - segment_kwargs : dict - Config of segments which is same as 'segments' in DatasetH.setup_data - - """ - if handler_kwargs: - if not isinstance(handler_kwargs, dict): - raise TypeError(f"param handler_kwargs must be type dict, not {type(handler_kwargs)}") - kwargs_init = {} - kwargs_conf_data = {} - conf_data_arg = {"instruments", "start_time", "end_time", "fit_start_time", "fit_end_time"} - for k, v in handler_kwargs.items(): - if k in conf_data_arg: - kwargs_conf_data.update({k: v}) - else: - kwargs_init.update({k: v}) - - self.handler.conf_data(**kwargs_conf_data) - self.handler.init(**kwargs_init) - - if segment_kwargs: - if not isinstance(segment_kwargs, dict): - raise TypeError(f"param handler_kwargs must be type dict, not {type(segment_kwargs)}") - self.segments = segment_kwargs.copy() - - def setup_data(self, handler: Union[Dict, DataHandler], segments: Dict[Text, Tuple]): + def __init__(self, handler: Union[Dict, DataHandler], segments: Dict[Text, Tuple], **kwargs): """ Setup the underlying data. @@ -144,6 +115,52 @@ class DatasetH(Dataset): """ self.handler = init_instance_by_config(handler, accept_types=DataHandler) self.segments = segments.copy() + super().__init__(**kwargs) + + def config(self, handler_kwargs:dict = None, segments:dict = None, **kwargs): + """ + Configure the DatasetH + + Parameters + ---------- + handler_kwargs : dict + Config of DataHandler, which could include the following arguments: + + - arguments of DataHandler.config, such as 'instruments', 'start_time' and 'end_time'. 
+ + kwargs : dict + Config of DatasetH, such as + + - segments : dict + Config of segments which is same as 'segments' in self.__init__ + + """ + super().config(**kwargs) + if handler_kwargs is not None: + self.handler.config(**handler_kwargs) + if segments is not None: + self.segments = segments.copy() + + + + def setup_data(self, handler_kwargs: dict = None, **kwargs): + """ + Setup the Data + + Parameters + ---------- + handler_kwargs : dict + init arguments of DataHandler, which could include the following arguments: + + - init_type : Init Type of Handler + + - enable_cache : whether to enable cache + + """ + super().setup_data(**kwargs) + if handler_kwargs is not None: + self.handler.setup_data(**handler_kwargs) + def __repr__(self): return "{name}(handler={handler}, segments={segments})".format( @@ -433,16 +450,21 @@ class TSDatasetH(DatasetH): - The dimension of a batch of data """ - def __init__(self, step_len=30, *args, **kwargs): + def __init__(self, step_len=30, **kwargs): self.step_len = step_len - super().__init__(*args, **kwargs) + super().__init__(**kwargs) - def setup_data(self, *args, **kwargs): - super().setup_data(*args, **kwargs) + def config(self, step_len=None, **kwargs): + super().config(**kwargs) + if step_len: + self.step_len = step_len + + def setup_data(self, **kwargs): + super().setup_data(**kwargs) cal = self.handler.fetch(col_set=self.handler.CS_RAW).index.get_level_values("datetime").unique() cal = sorted(cal) - # Get the datatime index for building timestamp self.cal = cal + def _prepare_seg(self, slc: slice, **kwargs) -> TSDataSampler: # Dataset decide how to slice data(Get more data for timeseries). 
diff --git a/qlib/data/dataset/handler.py b/qlib/data/dataset/handler.py index 9aa05b9b9..712cd6232 100644 --- a/qlib/data/dataset/handler.py +++ b/qlib/data/dataset/handler.py @@ -6,6 +6,7 @@ import abc import bisect import logging import warnings +from inspect import getfullargspec from typing import Union, Tuple, List, Iterator, Optional import pandas as pd @@ -99,10 +100,10 @@ class DataHandler(Serializable): self.fetch_orig = fetch_orig if init_data: with TimeInspector.logt("Init data"): - self.init() + self.setup_data() super().__init__() - def conf_data(self, **kwargs): + def config(self, instruments=None, start_time=None, end_time=None, **kwargs): """ configuration of data. # what data to be loaded from data source @@ -111,14 +112,17 @@ class DataHandler(Serializable): The data will be initialized with different time range. """ - attr_list = {"instruments", "start_time", "end_time"} - for k, v in kwargs.items(): - if k in attr_list: - setattr(self, k, v) - - def init(self, enable_cache: bool = False): + super().config(**kwargs) + if instruments: + self.instruments = instruments + if start_time: + self.start_time = start_time + if end_time: + self.end_time = end_time + + def setup_data(self, enable_cache: bool = False): """ - initialize the data. + Set Up the data. In case of running intialization for multiple time, it will do nothing for the second time. It is responsible for maintaining following variable @@ -403,7 +407,7 @@ class DataHandlerLP(DataHandler): if self.drop_raw: del self._data - def conf_data(self, **kwargs): + def config(self, processors_kwargs:dict = None, **kwargs): """ configuration of data. # what data to be loaded from data source @@ -412,27 +416,19 @@ class DataHandlerLP(DataHandler): The data will be initialized with different time range. 
""" - attr_list = {"fit_start_time", "fit_end_time"} - for k, v in kwargs.items(): - if k in attr_list: - for infer_processor in self.infer_processors: - if getattr(infer_processor, k, None): - setattr(infer_processor, k, v) - - for learn_processor in self.learn_processors: - if getattr(learn_processor, k, None): - setattr(learn_processor, k, v) - - super().conf_data(**kwargs) + super().config(**kwargs) + if processors_kwargs is not None: + for processor in self.get_all_processors(): + processor.config(**processors_kwargs) # init type IT_FIT_SEQ = "fit_seq" # the input of `fit` will be the output of the previous processor IT_FIT_IND = "fit_ind" # the input of `fit` will be the original df IT_LS = "load_state" # The state of the object has been load by pickle - def init(self, init_type: str = IT_FIT_SEQ, enable_cache: bool = False): + def setup_data(self, init_type: str = IT_FIT_SEQ, **kwargs): """ - Initialize the data of Qlib + Set up the data of Qlib Parameters ---------- @@ -447,7 +443,7 @@ class DataHandlerLP(DataHandler): when we call `init` next time """ # init raw data - super().init(enable_cache=enable_cache) + super().setup_data(**kwargs) with TimeInspector.logt("fit & process data"): if init_type == DataHandlerLP.IT_FIT_IND: diff --git a/qlib/data/dataset/loader.py b/qlib/data/dataset/loader.py index 1cda5c025..a58bca5e8 100644 --- a/qlib/data/dataset/loader.py +++ b/qlib/data/dataset/loader.py @@ -53,7 +53,6 @@ class DataLoader(abc.ABC): """ pass - class DLWParser(DataLoader): """ (D)ata(L)oader (W)ith (P)arser for features and names diff --git a/qlib/data/dataset/processor.py b/qlib/data/dataset/processor.py index 5a06f66be..e14e85831 100755 --- a/qlib/data/dataset/processor.py +++ b/qlib/data/dataset/processor.py @@ -72,6 +72,9 @@ class Processor(Serializable): """ return True + def config(self, **kwargs): + super().config(kwargs.get("dump_all", None), kwargs.get("exclude", None)) + class DropnaProcessor(Processor): def __init__(self, fields_group=None): @@ 
-192,6 +195,12 @@ class MinMaxNorm(Processor): df.loc(axis=1)[self.cols] = normalize(df[self.cols].values) return df + def config(self, fit_start_time=None, fit_end_time=None, **kwargs): + if fit_start_time: + self.fit_start_time = fit_start_time + if fit_end_time: + self.fit_end_time = fit_end_time + super().config(**kwargs) class ZScoreNorm(Processor): """ZScore Normalization""" @@ -220,6 +229,13 @@ class ZScoreNorm(Processor): df.loc(axis=1)[self.cols] = normalize(df[self.cols].values) return df + + def config(self, fit_start_time=None, fit_end_time=None, **kwargs): + if fit_start_time: + self.fit_start_time = fit_start_time + if fit_end_time: + self.fit_end_time = fit_end_time + super().config(**kwargs) class RobustZScoreNorm(Processor): @@ -257,6 +273,12 @@ class RobustZScoreNorm(Processor): df.clip(-3, 3, inplace=True) return df + def config(self, fit_start_time=None, fit_end_time=None, **kwargs): + if fit_start_time: + self.fit_start_time = fit_start_time + if fit_end_time: + self.fit_end_time = fit_end_time + super().config(**kwargs) class CSZScoreNorm(Processor): """Cross Sectional ZScore Normalization"""