mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-02 02:21:18 +08:00
fix config_data bug
This commit is contained in:
@@ -98,6 +98,7 @@ class RollingDataWorkflow(object):
|
||||
dataset = init_instance_by_config(dataset_config)
|
||||
|
||||
for rolling_offset in range(self.rolling_cnt):
|
||||
|
||||
print(f"===========rolling{rolling_offset} start===========")
|
||||
if rolling_offset:
|
||||
dataset.init(
|
||||
@@ -105,6 +106,8 @@ class RollingDataWorkflow(object):
|
||||
"init_type": DataHandlerLP.IT_FIT_SEQ,
|
||||
"start_time": datetime(train_start_time[0] + rolling_offset, *train_start_time[1:]),
|
||||
"end_time": datetime(test_end_time[0] + rolling_offset, *test_end_time[1:]),
|
||||
"fit_start_time": datetime(train_start_time[0] + rolling_offset, *train_start_time[1:]),
|
||||
"fit_end_time": datetime(train_end_time[0] + rolling_offset, *train_end_time[1:]),
|
||||
},
|
||||
segment_kwargs={
|
||||
"train": (
|
||||
@@ -123,6 +126,7 @@ class RollingDataWorkflow(object):
|
||||
)
|
||||
|
||||
dtrain, dvalid, dtest = dataset.prepare(["train", "valid", "test"])
|
||||
print(dtrain, dvalid, dtest)
|
||||
## print or dump data
|
||||
print(f"===========rolling{rolling_offset} end===========")
|
||||
|
||||
|
||||
@@ -98,7 +98,7 @@ class DatasetH(Dataset):
|
||||
raise TypeError(f"param handler_kwargs must be type dict, not {type(handler_kwargs)}")
|
||||
kwargs_init = {}
|
||||
kwargs_conf_data = {}
|
||||
conf_data_arg = {"instruments", "start_time", "end_time"}
|
||||
conf_data_arg = {"instruments", "start_time", "end_time", "fit_start_time", "fit_end_time"}
|
||||
for k, v in handler_kwargs.items():
|
||||
if k in conf_data_arg:
|
||||
kwargs_conf_data.update({k: v})
|
||||
|
||||
@@ -115,8 +115,7 @@ class DataHandler(Serializable):
|
||||
for k, v in kwargs.items():
|
||||
if k in attr_list:
|
||||
setattr(self, k, v)
|
||||
else:
|
||||
raise KeyError("Such config is not supported.")
|
||||
|
||||
|
||||
def init(self, enable_cache: bool = False):
|
||||
"""
|
||||
@@ -405,11 +404,34 @@ class DataHandlerLP(DataHandler):
|
||||
if self.drop_raw:
|
||||
del self._data
|
||||
|
||||
|
||||
def conf_data(self, **kwargs):
|
||||
"""
|
||||
configuration of data.
|
||||
# what data to be loaded from data source
|
||||
|
||||
This method will be used when loading pickled handler from dataset.
|
||||
The data will be initialized with different time range.
|
||||
|
||||
"""
|
||||
attr_list = {"fit_start_time", "fit_end_time"}
|
||||
for k, v in kwargs.items():
|
||||
if k in attr_list:
|
||||
for infer_processor in self.infer_processors:
|
||||
if getattr(infer_processor, k, None):
|
||||
setattr(infer_processor, k, v)
|
||||
|
||||
for learn_processor in self.learn_processors:
|
||||
if getattr(learn_processor, k, None):
|
||||
setattr(learn_processor, k, v)
|
||||
|
||||
super().conf_data(**kwargs)
|
||||
|
||||
# init type
|
||||
IT_FIT_SEQ = "fit_seq" # the input of `fit` will be the output of the previous processor
|
||||
IT_FIT_IND = "fit_ind" # the input of `fit` will be the original df
|
||||
IT_LS = "load_state" # The state of the object has been load by pickle
|
||||
|
||||
|
||||
def init(self, init_type: str = IT_FIT_SEQ, enable_cache: bool = False):
|
||||
"""
|
||||
Initialize the data of Qlib
|
||||
|
||||
Reference in New Issue
Block a user