1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-03 11:00:57 +08:00

Make mlflow client consistant with uri

This commit is contained in:
D-X-Y
2021-03-04 22:33:35 -08:00
parent c4d6e00470
commit 452fb8f013

View File

@@ -7,8 +7,10 @@ from mlflow.entities import ViewType
import os
from pathlib import Path
from contextlib import contextmanager
from typing import Optional, Text
from .exp import MLflowExperiment, Experiment
from .recorder import Recorder, MLflowRecorder
from .recorder import Recorder
from ..log import get_module_logger
logger = get_module_logger("workflow", "INFO")
@@ -20,16 +22,24 @@ class ExpManager:
(The link: https://mlflow.org/docs/latest/python_api/mlflow.html)
"""
def __init__(self, uri, default_exp_name):
def __init__(self, uri: Text, default_exp_name: Optional[Text]):
self._default_uri = uri
self._current_uri = None
self.default_exp_name = default_exp_name
self.active_experiment = None # only one experiment can active each time
def __repr__(self):
return "{name}(default_uri={duri}, current_uri={curi})".format(name=self.__class__.__name__, duri=self._default_uri, curi=self._current_uri)
return "{name}(default_uri={duri}, current_uri={curi})".format(
name=self.__class__.__name__, duri=self._default_uri, curi=self._current_uri
)
def start_exp(self, experiment_name=None, recorder_name=None, uri=None, **kwargs):
def start_exp(
self,
experiment_name: Optional[Text] = None,
recorder_name: Optional[Text] = None,
uri: Optional[Text] = None,
**kwargs,
):
"""
Start an experiment. This method includes first get_or_create an experiment, and then
set it to be active.
@@ -49,7 +59,7 @@ class ExpManager:
"""
raise NotImplementedError(f"Please implement the `start_exp` method.")
def end_exp(self, recorder_status: str = Recorder.STATUS_S, **kwargs):
def end_exp(self, recorder_status: Text = Recorder.STATUS_S, **kwargs):
"""
End an active experiment.
@@ -62,7 +72,7 @@ class ExpManager:
"""
raise NotImplementedError(f"Please implement the `end_exp` method.")
def create_exp(self, experiment_name=None):
def create_exp(self, experiment_name: Optional[Text] = None):
"""
Create an experiment.
@@ -218,6 +228,30 @@ class ExpManager:
"""
return self._current_uri or self._default_uri
def set_uri(self, uri: Optional[Text] = None):
"""
Set the current tracking URI and the corresponding variables.
Parameters
----------
uri : str
"""
if uri is None:
logger.info("No tracking URI is provided. Use the default tracking URI.")
self._current_uri = self._default_uri
else:
# Temporarily re-set the current uri as the uri argument.
self._current_uri = uri
# Customized features for subclasses.
self._set_uri()
def _set_uri(self):
"""
Customized features for subclasses' set_uri function.
"""
raise NotImplementedError(f"Please implement the `_set_uri` method.")
def list_experiments(self):
"""
List all the existing experiments.
@@ -234,10 +268,14 @@ class MLflowExpManager(ExpManager):
Use mlflow to implement ExpManager.
"""
def __init__(self, uri, default_exp_name):
def __init__(self, uri: Text, default_exp_name: Optional[Text]):
super(MLflowExpManager, self).__init__(uri, default_exp_name)
self._client = None
def _set_uri(self):
self._client = mlflow.tracking.MlflowClient(tracking_uri=self.uri)
logger.info('{:}'.format(self._client))
@property
def client(self):
# Delay the creation of mlflow client in case of creating `mlruns` folder when importing qlib
@@ -245,13 +283,11 @@ class MLflowExpManager(ExpManager):
self._client = mlflow.tracking.MlflowClient(tracking_uri=self.uri)
return self._client
def start_exp(self, experiment_name=None, recorder_name=None, uri=None):
def start_exp(
self, experiment_name: Optional[Text] = None, recorder_name: Optional[Text] = None, uri: Optional[Text] = None
):
# Set the tracking uri
if uri is None:
logger.info("No tracking URI is provided. Use the default tracking URI.")
else:
# Temporarily re-set the current uri as the uri argument.
self._current_uri = uri
self.set_uri(uri)
# Create experiment
experiment, _ = self._get_or_create_exp(experiment_name=experiment_name)
# Set up active experiment
@@ -261,14 +297,14 @@ class MLflowExpManager(ExpManager):
return self.active_experiment
def end_exp(self, recorder_status: str = Recorder.STATUS_S):
def end_exp(self, recorder_status: Text = Recorder.STATUS_S):
if self.active_experiment is not None:
self.active_experiment.end(recorder_status)
self.active_experiment = None
# When an experiment end, we will release the current uri.
self._current_uri = None
def create_exp(self, experiment_name=None):
def create_exp(self, experiment_name: Optional[Text] = None):
assert experiment_name is not None
# init experiment
experiment_id = self.client.create_experiment(experiment_name)