diff --git a/qlib/data/dataset/handler.py b/qlib/data/dataset/handler.py index 9812864af..d32b251de 100644 --- a/qlib/data/dataset/handler.py +++ b/qlib/data/dataset/handler.py @@ -56,7 +56,24 @@ class DataHandler(Serializable): end_time=None, data_loader: Tuple[dict, str, DataLoader] = None, init_data=True, + fetch_orig=True, ): + """ + Parameters + ---------- + instruments : + The stock list to retrieve + start_time : + start_time of the original data + end_time : + end_time of the original data + data_loader : Tuple[dict, str, DataLoader] + data loader to load the data + init_data : + initialize the original data in the constructor + fetch_orig : bool + Return the original data instead of copy if possible + """ # Set logger self.logger = get_module_logger("DataHandler") @@ -72,6 +89,7 @@ class DataHandler(Serializable): self.instruments = instruments self.start_time = start_time self.end_time = end_time + self.fetch_orig = fetch_orig if init_data: with TimeInspector.logt("Init data"): self.init() @@ -138,7 +156,7 @@ class DataHandler(Serializable): ------- pd.DataFrame: """ - df = fetch_df_by_index(self._data, selector, level) + df = fetch_df_by_index(self._data, selector, level, fetch_orig=self.fetch_orig) df = self._fetch_df_by_col(df, col_set) if squeeze: # squeeze columns @@ -269,8 +287,10 @@ class DataHandlerLP(DataHandler): for pname in "infer_processors", "learn_processors": for proc in locals()[pname]: getattr(self, pname).append( - init_instance_by_config(proc, processor_module, accept_types=(processor_module.Processor,)) - ) + init_instance_by_config( + proc, + None if (isinstance(proc, dict) and "module_path" in proc) else processor_module, + accept_types=processor_module.Processor)) self.process_type = process_type super().__init__(instruments, start_time, end_time, data_loader, **kwargs) @@ -354,15 +374,16 @@ class DataHandlerLP(DataHandler): # init raw data super().init(enable_cache=enable_cache) - if init_type == 
DataHandlerLP.IT_FIT_IND: - self.fit() - self.process_data() - elif init_type == DataHandlerLP.IT_LS: - self.process_data() - elif init_type == DataHandlerLP.IT_FIT_SEQ: - self.fit_process_data() - else: - raise NotImplementedError(f"This type of input is not supported") + with TimeInspector.logt("fit & process data"): + if init_type == DataHandlerLP.IT_FIT_IND: + self.fit() + self.process_data() + elif init_type == DataHandlerLP.IT_LS: + self.process_data() + elif init_type == DataHandlerLP.IT_FIT_SEQ: + self.fit_process_data() + else: + raise NotImplementedError(f"This type of input is not supported") # TODO: Be able to cache handler data. Save the memory for data processing @@ -396,7 +417,7 @@ class DataHandlerLP(DataHandler): pd.DataFrame: """ df = self._get_df_by_key(data_key) - df = fetch_df_by_index(df, selector, level) + df = fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig) return self._fetch_df_by_col(df, col_set) def get_cols(self, col_set=DataHandler.CS_ALL, data_key: str = DK_I) -> list: diff --git a/qlib/data/dataset/utils.py b/qlib/data/dataset/utils.py index 8ee199bc0..85a5e8389 100644 --- a/qlib/data/dataset/utils.py +++ b/qlib/data/dataset/utils.py @@ -32,7 +32,7 @@ def get_level_index(df: pd.DataFrame, level=Union[str, int]) -> int: def fetch_df_by_index( - df: pd.DataFrame, selector: Union[pd.Timestamp, slice, str, list], level: Union[str, int] + df: pd.DataFrame, selector: Union[pd.Timestamp, slice, str, list], level: Union[str, int], fetch_orig=True, ) -> pd.DataFrame: """ fetch data from `data` with `selector` and `level` @@ -52,6 +52,11 @@ def fetch_df_by_index( idx_slc = (selector, slice(None, None)) if get_level_index(df, level) == 1: idx_slc = idx_slc[1], idx_slc[0] - return df.loc[ - pd.IndexSlice[idx_slc], - ] # This could be faster than df.loc(axis=0)[idx_slc] + if fetch_orig: + for slc in idx_slc: + if slc != slice(None, None): + return df.loc[pd.IndexSlice[idx_slc],] + else: + return df + else: + return 
df.loc[pd.IndexSlice[idx_slc],] diff --git a/qlib/workflow/utils.py b/qlib/workflow/utils.py index d4594d28e..f5a73a157 100644 --- a/qlib/workflow/utils.py +++ b/qlib/workflow/utils.py @@ -5,9 +5,9 @@ import sys, traceback, signal, atexit from . import R from .recorder import Recorder from ..log import get_module_logger - logger = get_module_logger("workflow", "INFO") + + # function to handle the experiment when unusual program ending occurs def experiment_exit_handler(): """ @@ -31,10 +31,12 @@ def experiment_exception_hook(type, value, tb): value: Exception's value tb: Exception's traceback """ - error_msg = "An exception has been raised.\n" f"Type: {type}\n" + error_msg = f"An exception has been raised[{type.__name__}: {value}]." logger.error(error_msg) + + # Same as original format traceback.print_tb(tb) - logger.error(f"Value: {value}") + print(f"{type.__name__}: {value}") R.end_exp(recorder_status=Recorder.STATUS_FA)