1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-06 04:20:57 +08:00

Update structure for resuming

This commit is contained in:
Jactus
2021-03-16 17:16:00 +08:00
parent 08b44ed727
commit 447fed8e54
3 changed files with 91 additions and 72 deletions

View File

@@ -39,7 +39,7 @@ class Experiment:
output["recorders"] = list(recorders.keys())
return output
def start(self, recorder_name=None):
def start(self, recorder_name=None, resume=False):
"""
Start the experiment and set it to be active. This method will also start a new recorder.
@@ -47,6 +47,8 @@ class Experiment:
----------
recorder_name : str
the name of the recorder to be created.
resume : bool
whether to resume the first recorder
Returns
-------
@@ -149,59 +151,6 @@ class Experiment:
-------
A recorder object.
"""
raise NotImplementedError(f"Please implement the `get_recorder` method.")
def list_recorders(self):
    """
    Enumerate all existing recorders of this experiment.

    This base implementation only raises; a concrete experiment class has to provide the listing.
    Get the experiment instance before calling this method. To use `R.list_recorders()` instead,
    refer to the related API document in `QlibRecorder`.

    Returns
    -------
    A dictionary (id -> recorder) of the recorder information being stored.

    Raises
    ------
    NotImplementedError
        always, in this base implementation.
    """
    message = "Please implement the `list_recorders` method."
    raise NotImplementedError(message)
class MLflowExperiment(Experiment):
"""
Use mlflow to implement Experiment.
"""
def __init__(self, id, name, uri):
    """
    Initialise an mlflow-backed experiment.

    Parameters
    ----------
    id
        experiment id, forwarded to the parent initializer.
    name
        experiment name, forwarded to the parent initializer.
    uri
        tracking URI; stored on the instance and passed to the mlflow client.
    """
    super().__init__(id, name)
    self._uri = uri
    self._default_name = None
    # name used when a recorder is created without an explicit name
    self._default_rec_name = "mlflow_recorder"
    self._client = mlflow.tracking.MlflowClient(tracking_uri=self._uri)
def __repr__(self):
    """Return `ClassName(id=..., info=...)` built from `self.id` and `self.info`."""
    cls_name = self.__class__.__name__
    return f"{cls_name}(id={self.id}, info={self.info})"
def start(self, recorder_name=None):
    """
    Start the experiment: obtain a recorder, make it active and start its run.

    Parameters
    ----------
    recorder_name : str
        name passed to `_get_or_create_rec` to look up or create the recorder.

    Returns
    -------
    The active recorder.
    """
    logger.info(f"Experiment {self.id} starts running ...")
    # `_get_or_create_rec` returns a pair; only the recorder itself is needed here
    rec, _created = self._get_or_create_rec(recorder_name=recorder_name)
    self.active_recorder = rec
    self.active_recorder.start_run()
    return self.active_recorder
def end(self, recorder_status):
    """
    End the run of the active recorder, if there is one, and clear it.

    Parameters
    ----------
    recorder_status
        status forwarded to the recorder's `end_run`.
    """
    if self.active_recorder is None:
        # nothing is running, so there is nothing to end
        return
    self.active_recorder.end_run(recorder_status)
    self.active_recorder = None
def create_recorder(self, recorder_name=None):
    """
    Build a new `MLflowRecorder` for this experiment.

    Parameters
    ----------
    recorder_name : str
        name of the recorder; `self._default_rec_name` is used when None.

    Returns
    -------
    The newly constructed recorder.
    """
    name = self._default_rec_name if recorder_name is None else recorder_name
    return MLflowRecorder(self.id, self._uri, name)
def get_recorder(self, recorder_id=None, recorder_name=None, create=True):
# special case of getting the recorder
if recorder_id is None and recorder_name is None:
if self.active_recorder is not None:
@@ -232,6 +181,63 @@ class MLflowExperiment(Experiment):
logger.info(f"No valid recorder found. Create a new recorder with name {recorder_name}.")
return self.create_recorder(recorder_name), True
def list_recorders(self):
    """
    Enumerate all existing recorders of this experiment.

    This base implementation only raises; a concrete experiment class has to provide the listing.
    Get the experiment instance before calling this method. To use `R.list_recorders()` instead,
    refer to the related API document in `QlibRecorder`.

    Returns
    -------
    A dictionary (id -> recorder) of the recorder information being stored.

    Raises
    ------
    NotImplementedError
        always, in this base implementation.
    """
    message = "Please implement the `list_recorders` method."
    raise NotImplementedError(message)
class MLflowExperiment(Experiment):
"""
Use mlflow to implement Experiment.
"""
def __init__(self, id, name, uri):
    """
    Initialise an mlflow-backed experiment.

    Parameters
    ----------
    id
        experiment id, forwarded to the parent initializer.
    name
        experiment name, forwarded to the parent initializer.
    uri
        tracking URI; stored on the instance and passed to the mlflow client.
    """
    super().__init__(id, name)
    self._uri = uri
    self._default_name = None
    # name used when a recorder is created without an explicit name
    self._default_rec_name = "mlflow_recorder"
    self._client = mlflow.tracking.MlflowClient(tracking_uri=self._uri)
def __repr__(self):
    """Return `ClassName(id=..., info=...)` built from `self.id` and `self.info`."""
    cls_name = self.__class__.__name__
    return f"{cls_name}(id={self.id}, info={self.info})"
def start(self, recorder_name=None, resume=False):
    """
    Start the experiment: pick a recorder, make it active and start its run.

    Parameters
    ----------
    recorder_name : str
        name of the recorder; `self._default_rec_name` is used when None.
    resume : bool
        when True the recorder is obtained through `_get_or_create_rec`;
        when False a new recorder is always created.

    Returns
    -------
    The active recorder.
    """
    logger.info(f"Experiment {self.id} starts running ...")
    name = self._default_rec_name if recorder_name is None else recorder_name
    if not resume:
        # a fresh recorder was requested
        rec = self.create_recorder(name)
    else:
        # `_get_or_create_rec` returns a pair; only the recorder itself is needed here
        rec, _created = self._get_or_create_rec(recorder_name=name)
    self.active_recorder = rec
    self.active_recorder.start_run()
    return self.active_recorder
def end(self, recorder_status):
    """
    End the run of the active recorder, if there is one, and clear it.

    Parameters
    ----------
    recorder_status
        status forwarded to the recorder's `end_run`.
    """
    if self.active_recorder is None:
        # nothing is running, so there is nothing to end
        return
    self.active_recorder.end_run(recorder_status)
    self.active_recorder = None
def create_recorder(self, recorder_name=None):
    """
    Build a new `MLflowRecorder` for this experiment.

    Parameters
    ----------
    recorder_name : str
        name of the recorder; `self._default_rec_name` is used when None.

    Returns
    -------
    The newly constructed recorder.
    """
    name = self._default_rec_name if recorder_name is None else recorder_name
    return MLflowRecorder(self.id, self._uri, name)
def _get_recorder(self, recorder_id=None, recorder_name=None):
"""
Method for getting or creating a recorder. It will try to first get a valid recorder, if exception occurs, it will
@@ -249,7 +255,7 @@ class MLflowExperiment(Experiment):
raise ValueError("No valid recorder has been found, please make sure the input recorder id is correct.")
elif recorder_name is not None:
logger.warning(
f"Please make sure the recorder name {recorder_name} is unique, we will only return the first recorder if there exist several matched the given name."
f"Please make sure the recorder name {recorder_name} is unique, we will only return the latest recorder if there exist several matched the given name."
)
recorders = self.list_recorders()
for rid in recorders:
@@ -283,7 +289,7 @@ class MLflowExperiment(Experiment):
UNLIMITED = 50000 # FIXME: Mlflow can only list 50000 records at most!!!!!!!
def list_recorders(self, max_results=UNLIMITED):
runs = self._client.search_runs(self.id, run_view_type=ViewType.ACTIVE_ONLY, max_results=max_results)[::-1]
runs = self._client.search_runs(self.id, run_view_type=ViewType.ACTIVE_ONLY, max_results=max_results)
recorders = dict()
for i in range(len(runs)):
recorder = MLflowRecorder(self.id, self._uri, mlflow_run=runs[i])