mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
adjust for SepDataframe
This commit is contained in:
@@ -56,7 +56,24 @@ class DataHandler(Serializable):
|
||||
end_time=None,
|
||||
data_loader: Tuple[dict, str, DataLoader] = None,
|
||||
init_data=True,
|
||||
fetch_orig=True,
|
||||
):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
instruments :
|
||||
The stock list to retrive
|
||||
start_time :
|
||||
start_time of the original data
|
||||
end_time :
|
||||
end_time of the original data
|
||||
data_loader : Tuple[dict, str, DataLoader]
|
||||
data loader to load the data
|
||||
init_data :
|
||||
intialize the original data in the constructor
|
||||
fetch_orig : bool
|
||||
Return the original data instead of copy if possible
|
||||
"""
|
||||
# Set logger
|
||||
self.logger = get_module_logger("DataHandler")
|
||||
|
||||
@@ -72,6 +89,7 @@ class DataHandler(Serializable):
|
||||
self.instruments = instruments
|
||||
self.start_time = start_time
|
||||
self.end_time = end_time
|
||||
self.fetch_orig = fetch_orig
|
||||
if init_data:
|
||||
with TimeInspector.logt("Init data"):
|
||||
self.init()
|
||||
@@ -138,7 +156,7 @@ class DataHandler(Serializable):
|
||||
-------
|
||||
pd.DataFrame:
|
||||
"""
|
||||
df = fetch_df_by_index(self._data, selector, level)
|
||||
df = fetch_df_by_index(self._data, selector, level, fetch_orig=self.fetch_orig)
|
||||
df = self._fetch_df_by_col(df, col_set)
|
||||
if squeeze:
|
||||
# squeeze columns
|
||||
@@ -269,8 +287,10 @@ class DataHandlerLP(DataHandler):
|
||||
for pname in "infer_processors", "learn_processors":
|
||||
for proc in locals()[pname]:
|
||||
getattr(self, pname).append(
|
||||
init_instance_by_config(proc, processor_module, accept_types=(processor_module.Processor,))
|
||||
)
|
||||
init_instance_by_config(
|
||||
proc,
|
||||
None if (isinstance(data_loader, dict) and "module_path" in data_loader) else data_loader_module,
|
||||
accept_types=processor_module.Processor))
|
||||
|
||||
self.process_type = process_type
|
||||
super().__init__(instruments, start_time, end_time, data_loader, **kwargs)
|
||||
@@ -354,15 +374,16 @@ class DataHandlerLP(DataHandler):
|
||||
# init raw data
|
||||
super().init(enable_cache=enable_cache)
|
||||
|
||||
if init_type == DataHandlerLP.IT_FIT_IND:
|
||||
self.fit()
|
||||
self.process_data()
|
||||
elif init_type == DataHandlerLP.IT_LS:
|
||||
self.process_data()
|
||||
elif init_type == DataHandlerLP.IT_FIT_SEQ:
|
||||
self.fit_process_data()
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
with TimeInspector.logt("fit & process data"):
|
||||
if init_type == DataHandlerLP.IT_FIT_IND:
|
||||
self.fit()
|
||||
self.process_data()
|
||||
elif init_type == DataHandlerLP.IT_LS:
|
||||
self.process_data()
|
||||
elif init_type == DataHandlerLP.IT_FIT_SEQ:
|
||||
self.fit_process_data()
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
|
||||
# TODO: Be able to cache handler data. Save the memory for data processing
|
||||
|
||||
@@ -396,7 +417,7 @@ class DataHandlerLP(DataHandler):
|
||||
pd.DataFrame:
|
||||
"""
|
||||
df = self._get_df_by_key(data_key)
|
||||
df = fetch_df_by_index(df, selector, level)
|
||||
df = fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig)
|
||||
return self._fetch_df_by_col(df, col_set)
|
||||
|
||||
def get_cols(self, col_set=DataHandler.CS_ALL, data_key: str = DK_I) -> list:
|
||||
|
||||
@@ -32,7 +32,7 @@ def get_level_index(df: pd.DataFrame, level=Union[str, int]) -> int:
|
||||
|
||||
|
||||
def fetch_df_by_index(
|
||||
df: pd.DataFrame, selector: Union[pd.Timestamp, slice, str, list], level: Union[str, int]
|
||||
df: pd.DataFrame, selector: Union[pd.Timestamp, slice, str, list], level: Union[str, int], fetch_orig=True,
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
fetch data from `data` with `selector` and `level`
|
||||
@@ -52,6 +52,11 @@ def fetch_df_by_index(
|
||||
idx_slc = (selector, slice(None, None))
|
||||
if get_level_index(df, level) == 1:
|
||||
idx_slc = idx_slc[1], idx_slc[0]
|
||||
return df.loc[
|
||||
pd.IndexSlice[idx_slc],
|
||||
] # This could be faster than df.loc(axis=0)[idx_slc]
|
||||
if fetch_orig:
|
||||
for slc in idx_slc:
|
||||
if slc != slice(None, None):
|
||||
return df.loc[pd.IndexSlice[idx_slc],]
|
||||
else:
|
||||
return df
|
||||
else:
|
||||
return df.loc[pd.IndexSlice[idx_slc],]
|
||||
|
||||
@@ -5,9 +5,9 @@ import sys, traceback, signal, atexit
|
||||
from . import R
|
||||
from .recorder import Recorder
|
||||
from ..log import get_module_logger
|
||||
|
||||
logger = get_module_logger("workflow", "INFO")
|
||||
|
||||
|
||||
# function to handle the experiment when unusual program ending occurs
|
||||
def experiment_exit_handler():
|
||||
"""
|
||||
@@ -31,10 +31,12 @@ def experiment_exception_hook(type, value, tb):
|
||||
value: Exception's value
|
||||
tb: Exception's traceback
|
||||
"""
|
||||
error_msg = "An exception has been raised.\n" f"Type: {type}\n"
|
||||
error_msg = f"An exception has been raised[{type.__name__}: {value}]."
|
||||
logger.error(error_msg)
|
||||
|
||||
# Same as original format
|
||||
traceback.print_tb(tb)
|
||||
logger.error(f"Value: {value}")
|
||||
print(f"{type.__name__}: {value}")
|
||||
|
||||
R.end_exp(recorder_status=Recorder.STATUS_FA)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user