mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 14:01:28 +08:00
Adjust interface
This commit is contained in:
@@ -96,7 +96,7 @@ class DatasetH(Dataset):
|
||||
}
|
||||
"""
|
||||
self._handler = init_instance_by_config(handler, accept_types=DataHandler)
|
||||
self._segments = segments
|
||||
self._segments = segments.copy()
|
||||
|
||||
def prepare(
|
||||
self, segments: Union[List[str], Tuple[str], str, slice], col_set=DataHandler.CS_ALL, **kwargs
|
||||
|
||||
@@ -156,8 +156,9 @@ class DataHandler(Serializable):
|
||||
-------
|
||||
pd.DataFrame:
|
||||
"""
|
||||
df = fetch_df_by_index(self._data, selector, level, fetch_orig=self.fetch_orig)
|
||||
df = self._fetch_df_by_col(df, col_set)
|
||||
# Fetching columns first is more friendly to SepDataFrame
|
||||
df = self._fetch_df_by_col(self._data, col_set)
|
||||
df = fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig)
|
||||
if squeeze:
|
||||
# squeeze columns
|
||||
df = df.squeeze()
|
||||
@@ -417,8 +418,9 @@ class DataHandlerLP(DataHandler):
|
||||
pd.DataFrame:
|
||||
"""
|
||||
df = self._get_df_by_key(data_key)
|
||||
df = fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig)
|
||||
return self._fetch_df_by_col(df, col_set)
|
||||
# Fetching columns first is more friendly to SepDataFrame
|
||||
df = self._fetch_df_by_col(df, col_set)
|
||||
return fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig)
|
||||
|
||||
def get_cols(self, col_set=DataHandler.CS_ALL, data_key: str = DK_I) -> list:
|
||||
"""
|
||||
|
||||
@@ -10,13 +10,51 @@ class Serializable:
|
||||
Serializable behaves like pickle.
|
||||
But it only saves the state whose name **does not** start with `_`
|
||||
"""
|
||||
def __init__(self):
|
||||
self._dump_all = False
|
||||
self._exclude = []
|
||||
|
||||
def __getstate__(self) -> dict:
|
||||
return {k: v for k, v in self.__dict__.items() if not k.startswith("_")}
|
||||
return {
|
||||
k: v
|
||||
for k, v in self.__dict__.items() if k not in self.exclude and (self.dump_all or not k.startswith("_"))
|
||||
}
|
||||
|
||||
def __setstate__(self, state: dict):
|
||||
self.__dict__.update(state)
|
||||
|
||||
def to_pickle(self, path: [Path, str]):
|
||||
@property
|
||||
def dump_all(self):
|
||||
"""
|
||||
Whether the object will dump all of its attributes
|
||||
|
||||
Parameters
|
||||
----------
|
||||
self : [TODO:type]
|
||||
[TODO:description]
|
||||
"""
|
||||
return getattr(self, "_dump_all", False)
|
||||
|
||||
@property
|
||||
def exclude(self):
|
||||
"""
|
||||
Which attributes will be excluded from the dump
|
||||
|
||||
Parameters
|
||||
----------
|
||||
self : [TODO:type]
|
||||
[TODO:description]
|
||||
"""
|
||||
return getattr(self, "_exclude", [])
|
||||
|
||||
def config(self, dump_all: bool = None, exclude: list = None):
|
||||
if dump_all is not None:
|
||||
self._dump_all = dump_all
|
||||
|
||||
if exclude is not None:
|
||||
self._exclude = exclude
|
||||
|
||||
def to_pickle(self, path: [Path, str], dump_all: bool = None, exclude: list = None):
|
||||
self.config(dump_all=dump_all, exclude=exclude)
|
||||
with Path(path).open("wb") as f:
|
||||
pickle.dump(self, f)
|
||||
|
||||
@@ -6,7 +6,7 @@ from mlflow.exceptions import MlflowException
|
||||
import os
|
||||
from pathlib import Path
|
||||
from contextlib import contextmanager
|
||||
from .exp import MLflowExperiment
|
||||
from .exp import MLflowExperiment, Experiment
|
||||
from .recorder import Recorder, MLflowRecorder
|
||||
from ..log import get_module_logger
|
||||
|
||||
@@ -128,7 +128,61 @@ class ExpManager:
|
||||
-------
|
||||
An experiment object.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `get_exp` method.")
|
||||
# special case of getting experiment
|
||||
if experiment_id is None and experiment_name is None:
|
||||
if self.active_experiment is not None:
|
||||
return self.active_experiment
|
||||
# Reaching this point means the active experiment is not being returned.
|
||||
# Don't assume the underlying code can handle the case where both are None
|
||||
if experiment_id is None and experiment_name is None:
|
||||
experiment_name = self.default_exp_name
|
||||
|
||||
if create:
|
||||
exp, is_new = self._get_or_create_exp(experiment_id=experiment_id, experiment_name=experiment_name)
|
||||
else:
|
||||
exp, is_new = self._get_exp(experiment_id=experiment_id, experiment_name=experiment_name), False
|
||||
if is_new:
|
||||
self.active_experiment = exp
|
||||
# start the recorder
|
||||
self.active_experiment.start()
|
||||
return exp
|
||||
|
||||
def _get_or_create_exp(self, experiment_id=None, experiment_name=None) -> (object, bool):
|
||||
"""
|
||||
Method for getting or creating an experiment. It first tries to get a valid experiment; if a ValueError occurs, it will
|
||||
automatically create a new experiment with the given name (or the default name if none is given).
|
||||
"""
|
||||
try:
|
||||
if experiment_id is None and experiment_name is None:
|
||||
experiment_name = self.default_exp_name
|
||||
return self._get_exp(experiment_id=experiment_id, experiment_name=experiment_name), False
|
||||
except ValueError:
|
||||
if experiment_name is None:
|
||||
experiment_name = self.default_exp_name
|
||||
logger.info(f"No valid experiment found. Create a new experiment with name {experiment_name}.")
|
||||
return self.create_exp(experiment_name), True
|
||||
|
||||
def _get_exp(self, experiment_id=None, experiment_name=None) -> Experiment:
|
||||
"""
|
||||
Get a specific experiment by name or id. Raise ValueError if it does not exist.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
experiment_id :
|
||||
The id of the experiment
|
||||
experiment_name :
|
||||
The name of the experiment
|
||||
|
||||
Returns
|
||||
-------
|
||||
Experiment:
|
||||
The searched experiment
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `_get_exp` method")
|
||||
|
||||
def delete_exp(self, experiment_id=None, experiment_name=None):
|
||||
"""
|
||||
@@ -197,6 +251,7 @@ class MLflowExpManager(ExpManager):
|
||||
self.active_experiment = None
|
||||
|
||||
def create_exp(self, experiment_name=None):
|
||||
assert(experiment_name is not None)
|
||||
# init experiment
|
||||
experiment_id = self.client.create_experiment(experiment_name)
|
||||
experiment = MLflowExperiment(experiment_id, experiment_name, self.uri)
|
||||
@@ -204,34 +259,6 @@ class MLflowExpManager(ExpManager):
|
||||
|
||||
return experiment
|
||||
|
||||
def get_exp(self, experiment_id=None, experiment_name=None, create=True):
|
||||
# special case of getting experiment
|
||||
if experiment_id is None and experiment_name is None:
|
||||
if self.active_experiment is not None:
|
||||
return self.active_experiment
|
||||
if create:
|
||||
exp, is_new = self._get_or_create_exp(experiment_id=experiment_id, experiment_name=experiment_name)
|
||||
else:
|
||||
exp, is_new = self._get_exp(experiment_id=experiment_id, experiment_name=experiment_name), False
|
||||
if is_new:
|
||||
self.active_experiment = exp
|
||||
# start the recorder
|
||||
self.active_experiment.start()
|
||||
return exp
|
||||
|
||||
def _get_or_create_exp(self, experiment_id=None, experiment_name=None) -> (object, bool):
|
||||
"""
|
||||
Method for getting or creating an experiment. It will try to first get a valid experiment, if exception occurs, it will
|
||||
automatically create a new experiment based on the given id and name.
|
||||
"""
|
||||
try:
|
||||
return self._get_exp(experiment_id=experiment_id, experiment_name=experiment_name), False
|
||||
except ValueError:
|
||||
if experiment_name is None:
|
||||
experiment = self.default_exp_name
|
||||
logger.info(f"No valid experiment found. Create a new experiment with name {experiment_name}.")
|
||||
return self.create_exp(experiment_name), True
|
||||
|
||||
def _get_exp(self, experiment_id=None, experiment_name=None):
|
||||
"""
|
||||
Method for getting or creating an experiment. It will try to first get a valid experiment, if exception occurs, it will
|
||||
@@ -247,7 +274,7 @@ class MLflowExpManager(ExpManager):
|
||||
raise MlflowException("No valid experiment has been found.")
|
||||
experiment = MLflowExperiment(exp.experiment_id, exp.name, self.uri)
|
||||
return experiment
|
||||
except MlflowException as e:
|
||||
except MlflowException:
|
||||
raise ValueError(
|
||||
"No valid experiment has been found, please make sure the input experiment id is correct."
|
||||
)
|
||||
@@ -293,6 +320,5 @@ class MLflowExpManager(ExpManager):
|
||||
experiments = dict()
|
||||
for exp in exps:
|
||||
experiment = MLflowExperiment(exp.experiment_id, exp.name, self.uri)
|
||||
experiments[ename] = experiment
|
||||
|
||||
experiments[exp.name] = experiment
|
||||
return experiments
|
||||
|
||||
Reference in New Issue
Block a user