1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

adjust for SepDataframe

This commit is contained in:
Young
2020-11-19 04:08:05 +00:00
parent 64ed43b791
commit afcfa0a478
3 changed files with 48 additions and 20 deletions

View File

@@ -56,7 +56,24 @@ class DataHandler(Serializable):
end_time=None,
data_loader: Tuple[dict, str, DataLoader] = None,
init_data=True,
fetch_orig=True,
):
"""
Parameters
----------
instruments :
The stock list to retrieve
start_time :
start_time of the original data
end_time :
end_time of the original data
data_loader : Tuple[dict, str, DataLoader]
data loader to load the data
init_data :
Initialize the original data in the constructor
fetch_orig : bool
Return the original data instead of a copy if possible
"""
# Set logger
self.logger = get_module_logger("DataHandler")
@@ -72,6 +89,7 @@ class DataHandler(Serializable):
self.instruments = instruments
self.start_time = start_time
self.end_time = end_time
self.fetch_orig = fetch_orig
if init_data:
with TimeInspector.logt("Init data"):
self.init()
@@ -138,7 +156,7 @@ class DataHandler(Serializable):
-------
pd.DataFrame:
"""
df = fetch_df_by_index(self._data, selector, level)
df = fetch_df_by_index(self._data, selector, level, fetch_orig=self.fetch_orig)
df = self._fetch_df_by_col(df, col_set)
if squeeze:
# squeeze columns
@@ -269,8 +287,10 @@ class DataHandlerLP(DataHandler):
for pname in "infer_processors", "learn_processors":
for proc in locals()[pname]:
getattr(self, pname).append(
init_instance_by_config(proc, processor_module, accept_types=(processor_module.Processor,))
)
init_instance_by_config(
proc,
None if (isinstance(data_loader, dict) and "module_path" in data_loader) else data_loader_module,
accept_types=processor_module.Processor))
self.process_type = process_type
super().__init__(instruments, start_time, end_time, data_loader, **kwargs)
@@ -354,15 +374,16 @@ class DataHandlerLP(DataHandler):
# init raw data
super().init(enable_cache=enable_cache)
if init_type == DataHandlerLP.IT_FIT_IND:
self.fit()
self.process_data()
elif init_type == DataHandlerLP.IT_LS:
self.process_data()
elif init_type == DataHandlerLP.IT_FIT_SEQ:
self.fit_process_data()
else:
raise NotImplementedError(f"This type of input is not supported")
with TimeInspector.logt("fit & process data"):
if init_type == DataHandlerLP.IT_FIT_IND:
self.fit()
self.process_data()
elif init_type == DataHandlerLP.IT_LS:
self.process_data()
elif init_type == DataHandlerLP.IT_FIT_SEQ:
self.fit_process_data()
else:
raise NotImplementedError(f"This type of input is not supported")
# TODO: Be able to cache handler data. Save the memory for data processing
@@ -396,7 +417,7 @@ class DataHandlerLP(DataHandler):
pd.DataFrame:
"""
df = self._get_df_by_key(data_key)
df = fetch_df_by_index(df, selector, level)
df = fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig)
return self._fetch_df_by_col(df, col_set)
def get_cols(self, col_set=DataHandler.CS_ALL, data_key: str = DK_I) -> list:

View File

@@ -32,7 +32,7 @@ def get_level_index(df: pd.DataFrame, level=Union[str, int]) -> int:
def fetch_df_by_index(
df: pd.DataFrame, selector: Union[pd.Timestamp, slice, str, list], level: Union[str, int]
df: pd.DataFrame, selector: Union[pd.Timestamp, slice, str, list], level: Union[str, int], fetch_orig=True,
) -> pd.DataFrame:
"""
fetch data from `data` with `selector` and `level`
@@ -52,6 +52,11 @@ def fetch_df_by_index(
idx_slc = (selector, slice(None, None))
if get_level_index(df, level) == 1:
idx_slc = idx_slc[1], idx_slc[0]
return df.loc[
pd.IndexSlice[idx_slc],
] # This could be faster than df.loc(axis=0)[idx_slc]
if fetch_orig:
for slc in idx_slc:
if slc != slice(None, None):
return df.loc[pd.IndexSlice[idx_slc],]
else:
return df
else:
return df.loc[pd.IndexSlice[idx_slc],]

View File

@@ -5,9 +5,9 @@ import sys, traceback, signal, atexit
from . import R
from .recorder import Recorder
from ..log import get_module_logger
logger = get_module_logger("workflow", "INFO")
# function to handle the experiment when unusual program ending occurs
def experiment_exit_handler():
"""
@@ -31,10 +31,12 @@ def experiment_exception_hook(type, value, tb):
value: Exception's value
tb: Exception's traceback
"""
error_msg = "An exception has been raised.\n" f"Type: {type}\n"
error_msg = f"An exception has been raised[{type.__name__}: {value}]."
logger.error(error_msg)
# Same as original format
traceback.print_tb(tb)
logger.error(f"Value: {value}")
print(f"{type.__name__}: {value}")
R.end_exp(recorder_status=Recorder.STATUS_FA)