mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-03 11:00:57 +08:00
Make mlflow client consistant with uri
This commit is contained in:
@@ -7,8 +7,10 @@ from mlflow.entities import ViewType
|
||||
import os
|
||||
from pathlib import Path
|
||||
from contextlib import contextmanager
|
||||
from typing import Optional, Text
|
||||
|
||||
from .exp import MLflowExperiment, Experiment
|
||||
from .recorder import Recorder, MLflowRecorder
|
||||
from .recorder import Recorder
|
||||
from ..log import get_module_logger
|
||||
|
||||
logger = get_module_logger("workflow", "INFO")
|
||||
@@ -20,16 +22,24 @@ class ExpManager:
|
||||
(The link: https://mlflow.org/docs/latest/python_api/mlflow.html)
|
||||
"""
|
||||
|
||||
def __init__(self, uri, default_exp_name):
|
||||
def __init__(self, uri: Text, default_exp_name: Optional[Text]):
|
||||
self._default_uri = uri
|
||||
self._current_uri = None
|
||||
self.default_exp_name = default_exp_name
|
||||
self.active_experiment = None # only one experiment can active each time
|
||||
|
||||
def __repr__(self):
|
||||
return "{name}(default_uri={duri}, current_uri={curi})".format(name=self.__class__.__name__, duri=self._default_uri, curi=self._current_uri)
|
||||
return "{name}(default_uri={duri}, current_uri={curi})".format(
|
||||
name=self.__class__.__name__, duri=self._default_uri, curi=self._current_uri
|
||||
)
|
||||
|
||||
def start_exp(self, experiment_name=None, recorder_name=None, uri=None, **kwargs):
|
||||
def start_exp(
|
||||
self,
|
||||
experiment_name: Optional[Text] = None,
|
||||
recorder_name: Optional[Text] = None,
|
||||
uri: Optional[Text] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Start an experiment. This method includes first get_or_create an experiment, and then
|
||||
set it to be active.
|
||||
@@ -49,7 +59,7 @@ class ExpManager:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `start_exp` method.")
|
||||
|
||||
def end_exp(self, recorder_status: str = Recorder.STATUS_S, **kwargs):
|
||||
def end_exp(self, recorder_status: Text = Recorder.STATUS_S, **kwargs):
|
||||
"""
|
||||
End an active experiment.
|
||||
|
||||
@@ -62,7 +72,7 @@ class ExpManager:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `end_exp` method.")
|
||||
|
||||
def create_exp(self, experiment_name=None):
|
||||
def create_exp(self, experiment_name: Optional[Text] = None):
|
||||
"""
|
||||
Create an experiment.
|
||||
|
||||
@@ -218,6 +228,30 @@ class ExpManager:
|
||||
"""
|
||||
return self._current_uri or self._default_uri
|
||||
|
||||
def set_uri(self, uri: Optional[Text] = None):
|
||||
"""
|
||||
Set the current tracking URI and the corresponding variables.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
uri : str
|
||||
|
||||
"""
|
||||
if uri is None:
|
||||
logger.info("No tracking URI is provided. Use the default tracking URI.")
|
||||
self._current_uri = self._default_uri
|
||||
else:
|
||||
# Temporarily re-set the current uri as the uri argument.
|
||||
self._current_uri = uri
|
||||
# Customized features for subclasses.
|
||||
self._set_uri()
|
||||
|
||||
def _set_uri(self):
|
||||
"""
|
||||
Customized features for subclasses' set_uri function.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `_set_uri` method.")
|
||||
|
||||
def list_experiments(self):
|
||||
"""
|
||||
List all the existing experiments.
|
||||
@@ -234,10 +268,14 @@ class MLflowExpManager(ExpManager):
|
||||
Use mlflow to implement ExpManager.
|
||||
"""
|
||||
|
||||
def __init__(self, uri, default_exp_name):
|
||||
def __init__(self, uri: Text, default_exp_name: Optional[Text]):
|
||||
super(MLflowExpManager, self).__init__(uri, default_exp_name)
|
||||
self._client = None
|
||||
|
||||
def _set_uri(self):
|
||||
self._client = mlflow.tracking.MlflowClient(tracking_uri=self.uri)
|
||||
logger.info('{:}'.format(self._client))
|
||||
|
||||
@property
|
||||
def client(self):
|
||||
# Delay the creation of mlflow client in case of creating `mlruns` folder when importing qlib
|
||||
@@ -245,13 +283,11 @@ class MLflowExpManager(ExpManager):
|
||||
self._client = mlflow.tracking.MlflowClient(tracking_uri=self.uri)
|
||||
return self._client
|
||||
|
||||
def start_exp(self, experiment_name=None, recorder_name=None, uri=None):
|
||||
def start_exp(
|
||||
self, experiment_name: Optional[Text] = None, recorder_name: Optional[Text] = None, uri: Optional[Text] = None
|
||||
):
|
||||
# Set the tracking uri
|
||||
if uri is None:
|
||||
logger.info("No tracking URI is provided. Use the default tracking URI.")
|
||||
else:
|
||||
# Temporarily re-set the current uri as the uri argument.
|
||||
self._current_uri = uri
|
||||
self.set_uri(uri)
|
||||
# Create experiment
|
||||
experiment, _ = self._get_or_create_exp(experiment_name=experiment_name)
|
||||
# Set up active experiment
|
||||
@@ -261,14 +297,14 @@ class MLflowExpManager(ExpManager):
|
||||
|
||||
return self.active_experiment
|
||||
|
||||
def end_exp(self, recorder_status: str = Recorder.STATUS_S):
|
||||
def end_exp(self, recorder_status: Text = Recorder.STATUS_S):
|
||||
if self.active_experiment is not None:
|
||||
self.active_experiment.end(recorder_status)
|
||||
self.active_experiment = None
|
||||
# When an experiment end, we will release the current uri.
|
||||
self._current_uri = None
|
||||
|
||||
def create_exp(self, experiment_name=None):
|
||||
def create_exp(self, experiment_name: Optional[Text] = None):
|
||||
assert experiment_name is not None
|
||||
# init experiment
|
||||
experiment_id = self.client.create_experiment(experiment_name)
|
||||
|
||||
Reference in New Issue
Block a user