From f476ada22d9ea7a050ee5e01465da3bcc6561d7e Mon Sep 17 00:00:00 2001 From: Young Date: Sat, 21 Nov 2020 08:54:11 +0000 Subject: [PATCH] Adjust interface --- qlib/data/dataset/__init__.py | 2 +- qlib/data/dataset/handler.py | 10 ++-- qlib/utils/serial.py | 42 +++++++++++++++- qlib/workflow/expm.py | 92 ++++++++++++++++++++++------------- 4 files changed, 106 insertions(+), 40 deletions(-) diff --git a/qlib/data/dataset/__init__.py b/qlib/data/dataset/__init__.py index d5b8a12e9..e7a149c65 100644 --- a/qlib/data/dataset/__init__.py +++ b/qlib/data/dataset/__init__.py @@ -96,7 +96,7 @@ class DatasetH(Dataset): } """ self._handler = init_instance_by_config(handler, accept_types=DataHandler) - self._segments = segments + self._segments = segments.copy() def prepare( self, segments: Union[List[str], Tuple[str], str, slice], col_set=DataHandler.CS_ALL, **kwargs diff --git a/qlib/data/dataset/handler.py b/qlib/data/dataset/handler.py index 422cc6b1d..b3608464d 100644 --- a/qlib/data/dataset/handler.py +++ b/qlib/data/dataset/handler.py @@ -156,8 +156,9 @@ class DataHandler(Serializable): ------- pd.DataFrame: """ - df = fetch_df_by_index(self._data, selector, level, fetch_orig=self.fetch_orig) - df = self._fetch_df_by_col(df, col_set) + # Fetch column first will be more friendly to SepDataFrame + df = self._fetch_df_by_col(self._data, col_set) + df = fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig) if squeeze: # squeeze columns df = df.squeeze() @@ -417,8 +418,9 @@ class DataHandlerLP(DataHandler): pd.DataFrame: """ df = self._get_df_by_key(data_key) - df = fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig) - return self._fetch_df_by_col(df, col_set) + # Fetch column first will be more friendly to SepDataFrame + df = self._fetch_df_by_col(df, col_set) + return fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig) def get_cols(self, col_set=DataHandler.CS_ALL, data_key: str = DK_I) -> list: """ diff --git a/qlib/utils/serial.py b/qlib/utils/serial.py index 9bc8ce94a..b5734d726 100644 --- a/qlib/utils/serial.py +++ b/qlib/utils/serial.py @@ -10,13 +10,51 @@ class Serializable: Serializable behaves like pickle. But it only saves the state whose name **does not** start with `_` """ + def __init__(self): + self._dump_all = False + self._exclude = [] def __getstate__(self) -> dict: - return {k: v for k, v in self.__dict__.items() if not k.startswith("_")} + return { + k: v + for k, v in self.__dict__.items() if k not in self.exclude and (self.dump_all or not k.startswith("_")) + } def __setstate__(self, state: dict): self.__dict__.update(state) - def to_pickle(self, path: [Path, str]): + @property + def dump_all(self): + """ + will the object dump all object + + Parameters + ---------- + self : [TODO:type] + [TODO:description] + """ + return getattr(self, "_dump_all", False) + + @property + def exclude(self): + """ + What attribute will be dumped + + Parameters + ---------- + self : [TODO:type] + [TODO:description] + """ + return getattr(self, "_exclude", []) + + def config(self, dump_all: bool = None, exclude: list = None): + if dump_all is not None: + self._dump_all = dump_all + + if exclude is not None: + self._exclude = exclude + + def to_pickle(self, path: [Path, str], dump_all: bool = None, exclude: list = None): + self.config(dump_all=dump_all, exclude=exclude) with Path(path).open("wb") as f: pickle.dump(self, f) diff --git a/qlib/workflow/expm.py b/qlib/workflow/expm.py index 8fb7962e9..25c5d4661 100644 --- a/qlib/workflow/expm.py +++ b/qlib/workflow/expm.py @@ -6,7 +6,7 @@ from mlflow.exceptions import MlflowException import os from pathlib import Path from contextlib import contextmanager -from .exp import MLflowExperiment +from .exp import MLflowExperiment, Experiment from .recorder import Recorder, MLflowRecorder from ..log import get_module_logger @@ -128,7 +128,61 @@ class ExpManager: ------- An experiment object. """ - raise NotImplementedError(f"Please implement the `get_exp` method.") + # special case of getting experiment + if experiment_id is None and experiment_name is None: + if self.active_experiment is not None: + return self.active_experiment + # User don't want get active code now. + # Don't assume underlying code could handle the case of two None + if experiment_id is None and experiment_name is None: + experiment_name = self.default_exp_name + + if create: + exp, is_new = self._get_or_create_exp(experiment_id=experiment_id, experiment_name=experiment_name) + else: + exp, is_new = self._get_exp(experiment_id=experiment_id, experiment_name=experiment_name), False + if is_new: + self.active_experiment = exp + # start the recorder + self.active_experiment.start() + return exp + + def _get_or_create_exp(self, experiment_id=None, experiment_name=None) -> (object, bool): + """ + Method for getting or creating an experiment. It will try to first get a valid experiment, if exception occurs, it will + automatically create a new experiment based on the given id and name. + """ + try: + if experiment_id is None and experiment_name is None: + experiment_name = self.default_exp_name + return self._get_exp(experiment_id=experiment_id, experiment_name=experiment_name), False + except ValueError: + if experiment_name is None: + experiment_name = self.default_exp_name + logger.info(f"No valid experiment found. Create a new experiment with name {experiment_name}.") + return self.create_exp(experiment_name), True + + def _get_exp(self, experiment_id=None, experiment_name=None) -> Experiment: + """ + get specific experiment by name or id. If it does not exist, raise ValueError + + Parameters + ---------- + experiment_id : + The id of experiment + experiment_name : + The id name experiment + + Returns + ------- + Experiment: + The searched experiment + + Raises + ------ + ValueError + """ + raise NotImplementedError(f"Please implement the `_get_exp` method") def delete_exp(self, experiment_id=None, experiment_name=None): """ @@ -197,6 +251,7 @@ class MLflowExpManager(ExpManager): self.active_experiment = None def create_exp(self, experiment_name=None): + assert(experiment_name is not None) # init experiment experiment_id = self.client.create_experiment(experiment_name) experiment = MLflowExperiment(experiment_id, experiment_name, self.uri) @@ -204,34 +259,6 @@ class MLflowExpManager(ExpManager): return experiment - def get_exp(self, experiment_id=None, experiment_name=None, create=True): - # special case of getting experiment - if experiment_id is None and experiment_name is None: - if self.active_experiment is not None: - return self.active_experiment - if create: - exp, is_new = self._get_or_create_exp(experiment_id=experiment_id, experiment_name=experiment_name) - else: - exp, is_new = self._get_exp(experiment_id=experiment_id, experiment_name=experiment_name), False - if is_new: - self.active_experiment = exp - # start the recorder - self.active_experiment.start() - return exp - - def _get_or_create_exp(self, experiment_id=None, experiment_name=None) -> (object, bool): - """ - Method for getting or creating an experiment. It will try to first get a valid experiment, if exception occurs, it will - automatically create a new experiment based on the given id and name. - """ - try: - return self._get_exp(experiment_id=experiment_id, experiment_name=experiment_name), False - except ValueError: - if experiment_name is None: - experiment = self.default_exp_name - logger.info(f"No valid experiment found. Create a new experiment with name {experiment_name}.") - return self.create_exp(experiment_name), True - def _get_exp(self, experiment_id=None, experiment_name=None): """ Method for getting or creating an experiment. It will try to first get a valid experiment, if exception occurs, it will @@ -247,7 +274,7 @@ class MLflowExpManager(ExpManager): raise MlflowException("No valid experiment has been found.") experiment = MLflowExperiment(exp.experiment_id, exp.name, self.uri) return experiment - except MlflowException as e: + except MlflowException: raise ValueError( "No valid experiment has been found, please make sure the input experiment id is correct." ) @@ -293,6 +320,5 @@ class MLflowExpManager(ExpManager): experiments = dict() for exp in exps: experiment = MLflowExperiment(exp.experiment_id, exp.name, self.uri) - experiments[ename] = experiment - + experiments[exp.name] = experiment return experiments