diff --git a/qlib/workflow/expm.py b/qlib/workflow/expm.py index 82265b585..362b1a82b 100644 --- a/qlib/workflow/expm.py +++ b/qlib/workflow/expm.py @@ -7,8 +7,10 @@ from mlflow.entities import ViewType import os from pathlib import Path from contextlib import contextmanager +from typing import Optional, Text + from .exp import MLflowExperiment, Experiment -from .recorder import Recorder, MLflowRecorder +from .recorder import Recorder from ..log import get_module_logger logger = get_module_logger("workflow", "INFO") @@ -20,16 +22,24 @@ class ExpManager: (The link: https://mlflow.org/docs/latest/python_api/mlflow.html) """ - def __init__(self, uri, default_exp_name): + def __init__(self, uri: Text, default_exp_name: Optional[Text]): self._default_uri = uri self._current_uri = None self.default_exp_name = default_exp_name self.active_experiment = None # only one experiment can active each time def __repr__(self): - return "{name}(default_uri={duri}, current_uri={curi})".format(name=self.__class__.__name__, duri=self._default_uri, curi=self._current_uri) + return "{name}(default_uri={duri}, current_uri={curi})".format( + name=self.__class__.__name__, duri=self._default_uri, curi=self._current_uri + ) - def start_exp(self, experiment_name=None, recorder_name=None, uri=None, **kwargs): + def start_exp( + self, + experiment_name: Optional[Text] = None, + recorder_name: Optional[Text] = None, + uri: Optional[Text] = None, + **kwargs, + ): """ Start an experiment. This method includes first get_or_create an experiment, and then set it to be active. @@ -49,7 +59,7 @@ class ExpManager: """ raise NotImplementedError(f"Please implement the `start_exp` method.") - def end_exp(self, recorder_status: str = Recorder.STATUS_S, **kwargs): + def end_exp(self, recorder_status: Text = Recorder.STATUS_S, **kwargs): """ End an active experiment. @@ -62,7 +72,7 @@ class ExpManager: """ raise NotImplementedError(f"Please implement the `end_exp` method.") - def create_exp(self, experiment_name=None): + def create_exp(self, experiment_name: Optional[Text] = None): """ Create an experiment. @@ -218,6 +228,30 @@ class ExpManager: """ return self._current_uri or self._default_uri + def set_uri(self, uri: Optional[Text] = None): + """ + Set the current tracking URI and the corresponding variables. + + Parameters + ---------- + uri : str + + """ + if uri is None: + logger.info("No tracking URI is provided. Use the default tracking URI.") + self._current_uri = self._default_uri + else: + # Temporarily re-set the current uri as the uri argument. + self._current_uri = uri + # Customized features for subclasses. + self._set_uri() + + def _set_uri(self): + """ + Customized features for subclasses' set_uri function. + """ + raise NotImplementedError(f"Please implement the `_set_uri` method.") + def list_experiments(self): """ List all the existing experiments. @@ -234,10 +268,14 @@ class MLflowExpManager(ExpManager): Use mlflow to implement ExpManager. """ - def __init__(self, uri, default_exp_name): + def __init__(self, uri: Text, default_exp_name: Optional[Text]): super(MLflowExpManager, self).__init__(uri, default_exp_name) self._client = None + def _set_uri(self): + self._client = mlflow.tracking.MlflowClient(tracking_uri=self.uri) + logger.info('{:}'.format(self._client)) + @property def client(self): # Delay the creation of mlflow client in case of creating `mlruns` folder when importing qlib @@ -245,13 +283,11 @@ class MLflowExpManager(ExpManager): self._client = mlflow.tracking.MlflowClient(tracking_uri=self.uri) return self._client - def start_exp(self, experiment_name=None, recorder_name=None, uri=None): + def start_exp( + self, experiment_name: Optional[Text] = None, recorder_name: Optional[Text] = None, uri: Optional[Text] = None + ): # Set the tracking uri - if uri is None: - logger.info("No tracking URI is provided. Use the default tracking URI.") - else: - # Temporarily re-set the current uri as the uri argument. - self._current_uri = uri + self.set_uri(uri) # Create experiment experiment, _ = self._get_or_create_exp(experiment_name=experiment_name) # Set up active experiment @@ -261,14 +297,14 @@ class MLflowExpManager(ExpManager): return self.active_experiment - def end_exp(self, recorder_status: str = Recorder.STATUS_S): + def end_exp(self, recorder_status: Text = Recorder.STATUS_S): if self.active_experiment is not None: self.active_experiment.end(recorder_status) self.active_experiment = None # When an experiment end, we will release the current uri. self._current_uri = None - def create_exp(self, experiment_name=None): + def create_exp(self, experiment_name: Optional[Text] = None): assert experiment_name is not None # init experiment experiment_id = self.client.create_experiment(experiment_name)