mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-06 04:20:57 +08:00
Update structure for resuming
This commit is contained in:
@@ -39,7 +39,7 @@ class Experiment:
|
||||
output["recorders"] = list(recorders.keys())
|
||||
return output
|
||||
|
||||
def start(self, recorder_name=None):
|
||||
def start(self, recorder_name=None, resume=False):
|
||||
"""
|
||||
Start the experiment and set it to be active. This method will also start a new recorder.
|
||||
|
||||
@@ -47,6 +47,8 @@ class Experiment:
|
||||
----------
|
||||
recorder_name : str
|
||||
the name of the recorder to be created.
|
||||
resume : bool
|
||||
whether to resume the first recorder
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -149,59 +151,6 @@ class Experiment:
|
||||
-------
|
||||
A recorder object.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `get_recorder` method.")
|
||||
|
||||
def list_recorders(self):
|
||||
"""
|
||||
List all the existing recorders of this experiment. Please first get the experiment instance before calling this method.
|
||||
If user want to use the method `R.list_recorders()`, please refer to the related API document in `QlibRecorder`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
A dictionary (id -> recorder) of recorder information that being stored.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `list_recorders` method.")
|
||||
|
||||
|
||||
class MLflowExperiment(Experiment):
|
||||
"""
|
||||
Use mlflow to implement Experiment.
|
||||
"""
|
||||
|
||||
def __init__(self, id, name, uri):
|
||||
super(MLflowExperiment, self).__init__(id, name)
|
||||
self._uri = uri
|
||||
self._default_name = None
|
||||
self._default_rec_name = "mlflow_recorder"
|
||||
self._client = mlflow.tracking.MlflowClient(tracking_uri=self._uri)
|
||||
|
||||
def __repr__(self):
|
||||
return "{name}(id={id}, info={info})".format(name=self.__class__.__name__, id=self.id, info=self.info)
|
||||
|
||||
def start(self, recorder_name=None):
|
||||
logger.info(f"Experiment {self.id} starts running ...")
|
||||
# Get or create recorder
|
||||
recorder, _ = self._get_or_create_rec(recorder_name=recorder_name)
|
||||
# Set up active recorder
|
||||
self.active_recorder = recorder
|
||||
# Start the recorder
|
||||
self.active_recorder.start_run()
|
||||
|
||||
return self.active_recorder
|
||||
|
||||
def end(self, recorder_status):
|
||||
if self.active_recorder is not None:
|
||||
self.active_recorder.end_run(recorder_status)
|
||||
self.active_recorder = None
|
||||
|
||||
def create_recorder(self, recorder_name=None):
|
||||
if recorder_name is None:
|
||||
recorder_name = self._default_rec_name
|
||||
recorder = MLflowRecorder(self.id, self._uri, recorder_name)
|
||||
|
||||
return recorder
|
||||
|
||||
def get_recorder(self, recorder_id=None, recorder_name=None, create=True):
|
||||
# special case of getting the recorder
|
||||
if recorder_id is None and recorder_name is None:
|
||||
if self.active_recorder is not None:
|
||||
@@ -232,6 +181,63 @@ class MLflowExperiment(Experiment):
|
||||
logger.info(f"No valid recorder found. Create a new recorder with name {recorder_name}.")
|
||||
return self.create_recorder(recorder_name), True
|
||||
|
||||
def list_recorders(self):
|
||||
"""
|
||||
List all the existing recorders of this experiment. Please first get the experiment instance before calling this method.
|
||||
If user want to use the method `R.list_recorders()`, please refer to the related API document in `QlibRecorder`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
A dictionary (id -> recorder) of recorder information that being stored.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `list_recorders` method.")
|
||||
|
||||
|
||||
class MLflowExperiment(Experiment):
|
||||
"""
|
||||
Use mlflow to implement Experiment.
|
||||
"""
|
||||
|
||||
def __init__(self, id, name, uri):
|
||||
super(MLflowExperiment, self).__init__(id, name)
|
||||
self._uri = uri
|
||||
self._default_name = None
|
||||
self._default_rec_name = "mlflow_recorder"
|
||||
self._client = mlflow.tracking.MlflowClient(tracking_uri=self._uri)
|
||||
|
||||
def __repr__(self):
|
||||
return "{name}(id={id}, info={info})".format(name=self.__class__.__name__, id=self.id, info=self.info)
|
||||
|
||||
def start(self, recorder_name=None, resume=False):
|
||||
logger.info(f"Experiment {self.id} starts running ...")
|
||||
# Get or create recorder
|
||||
if recorder_name is None:
|
||||
recorder_name = self._default_rec_name
|
||||
# resume the recorder
|
||||
if resume:
|
||||
recorder, _ = self._get_or_create_rec(recorder_name=recorder_name)
|
||||
# create a new recorder
|
||||
else:
|
||||
recorder = self.create_recorder(recorder_name)
|
||||
# Set up active recorder
|
||||
self.active_recorder = recorder
|
||||
# Start the recorder
|
||||
self.active_recorder.start_run()
|
||||
|
||||
return self.active_recorder
|
||||
|
||||
def end(self, recorder_status):
|
||||
if self.active_recorder is not None:
|
||||
self.active_recorder.end_run(recorder_status)
|
||||
self.active_recorder = None
|
||||
|
||||
def create_recorder(self, recorder_name=None):
|
||||
if recorder_name is None:
|
||||
recorder_name = self._default_rec_name
|
||||
recorder = MLflowRecorder(self.id, self._uri, recorder_name)
|
||||
|
||||
return recorder
|
||||
|
||||
def _get_recorder(self, recorder_id=None, recorder_name=None):
|
||||
"""
|
||||
Method for getting or creating a recorder. It will try to first get a valid recorder, if exception occurs, it will
|
||||
@@ -249,7 +255,7 @@ class MLflowExperiment(Experiment):
|
||||
raise ValueError("No valid recorder has been found, please make sure the input recorder id is correct.")
|
||||
elif recorder_name is not None:
|
||||
logger.warning(
|
||||
f"Please make sure the recorder name {recorder_name} is unique, we will only return the first recorder if there exist several matched the given name."
|
||||
f"Please make sure the recorder name {recorder_name} is unique, we will only return the latest recorder if there exist several matched the given name."
|
||||
)
|
||||
recorders = self.list_recorders()
|
||||
for rid in recorders:
|
||||
@@ -283,7 +289,7 @@ class MLflowExperiment(Experiment):
|
||||
UNLIMITED = 50000 # FIXME: Mlflow can only list 50000 records at most!!!!!!!
|
||||
|
||||
def list_recorders(self, max_results=UNLIMITED):
|
||||
runs = self._client.search_runs(self.id, run_view_type=ViewType.ACTIVE_ONLY, max_results=max_results)[::-1]
|
||||
runs = self._client.search_runs(self.id, run_view_type=ViewType.ACTIVE_ONLY, max_results=max_results)
|
||||
recorders = dict()
|
||||
for i in range(len(runs)):
|
||||
recorder = MLflowRecorder(self.id, self._uri, mlflow_run=runs[i])
|
||||
|
||||
Reference in New Issue
Block a user