mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
Optimize the implementation of uri & Fix async log bug (#1364)
* Optimize the implementation of uri * remove redundant func * Set the right order of _set_client_uri * Update qlib/workflow/expm.py * Simplify client & add test.Add docs; Fix async bug * Fix comments & pylint * Improve README
This commit is contained in:
@@ -8,7 +8,6 @@ from .exp import Experiment
|
||||
from .recorder import Recorder
|
||||
from ..utils import Wrapper
|
||||
from ..utils.exceptions import RecorderInitializationError
|
||||
from qlib.config import C
|
||||
|
||||
|
||||
class QlibRecorder:
|
||||
@@ -347,14 +346,14 @@ class QlibRecorder:
|
||||
|
||||
def set_uri(self, uri: Optional[Text]):
|
||||
"""
|
||||
Method to reset the current uri of current experiment manager.
|
||||
Method to reset the **default** uri of current experiment manager.
|
||||
|
||||
NOTE:
|
||||
|
||||
- When the uri is refer to a file path, please using the absolute path instead of strings like "~/mlruns/"
|
||||
The backend don't support strings like this.
|
||||
"""
|
||||
self.exp_manager.set_uri(uri)
|
||||
self.exp_manager.default_uri = uri
|
||||
|
||||
@contextmanager
|
||||
def uri_context(self, uri: Text):
|
||||
@@ -370,11 +369,11 @@ class QlibRecorder:
|
||||
the temporal uri
|
||||
"""
|
||||
prev_uri = self.exp_manager.default_uri
|
||||
C.exp_manager["kwargs"]["uri"] = uri
|
||||
self.exp_manager.default_uri = uri
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
C.exp_manager["kwargs"]["uri"] = prev_uri
|
||||
self.exp_manager.default_uri = prev_uri
|
||||
|
||||
def get_recorder(
|
||||
self,
|
||||
|
||||
@@ -249,7 +249,6 @@ class MLflowExperiment(Experiment):
|
||||
def __init__(self, id, name, uri):
|
||||
super(MLflowExperiment, self).__init__(id, name)
|
||||
self._uri = uri
|
||||
self._default_name = None
|
||||
self._default_rec_name = "mlflow_recorder"
|
||||
self._client = mlflow.tracking.MlflowClient(tracking_uri=self._uri)
|
||||
|
||||
|
||||
@@ -15,23 +15,32 @@ from .recorder import Recorder
|
||||
from ..log import get_module_logger
|
||||
from ..utils.exceptions import ExpAlreadyExistError
|
||||
|
||||
|
||||
logger = get_module_logger("workflow")
|
||||
|
||||
|
||||
class ExpManager:
|
||||
"""
|
||||
This is the `ExpManager` class for managing experiments. The API is designed similar to mlflow.
|
||||
(The link: https://mlflow.org/docs/latest/python_api/mlflow.html)
|
||||
This is the `ExpManager` class for managing experiments. The API is designed similar to mlflow.
|
||||
(The link: https://mlflow.org/docs/latest/python_api/mlflow.html)
|
||||
|
||||
The `ExpManager` is expected to be a singleton (btw, we can have multiple `Experiment`s with different uri. user can get different experiments from different uri, and then compare records of them). Global Config (i.e. `C`) is also a singleton.
|
||||
So we try to align them together. They share the same variable, which is called **default uri**. Please refer to `ExpManager.default_uri` for details of variable sharing.
|
||||
|
||||
When the user starts an experiment, the user may want to set the uri to a specific uri (it will override **default uri** during this period), and then unset the **specific uri** and fallback to the **default uri**. `ExpManager._active_exp_uri` is that **specific uri**.
|
||||
"""
|
||||
|
||||
active_experiment: Optional[Experiment]
|
||||
|
||||
def __init__(self, uri: Text, default_exp_name: Optional[Text]):
|
||||
self._current_uri = uri
|
||||
self.default_uri = uri
|
||||
self._active_exp_uri = None # No active experiments. So it is set to None
|
||||
self._default_exp_name = default_exp_name
|
||||
self.active_experiment = None # only one experiment can be active each time
|
||||
logger.info(f"experiment manager uri is at {self._current_uri}")
|
||||
logger.info(f"experiment manager uri is at {self.uri}")
|
||||
|
||||
def __repr__(self):
|
||||
return "{name}(current_uri={curi})".format(name=self.__class__.__name__, curi=self._current_uri)
|
||||
return "{name}(uri={uri})".format(name=self.__class__.__name__, uri=self.uri)
|
||||
|
||||
def start_exp(
|
||||
self,
|
||||
@@ -43,11 +52,13 @@ class ExpManager:
|
||||
uri: Optional[Text] = None,
|
||||
resume: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
) -> Experiment:
|
||||
"""
|
||||
Start an experiment. This method includes first get_or_create an experiment, and then
|
||||
set it to be active.
|
||||
|
||||
Maintaining `_active_exp_uri` is included in start_exp, remaining implementation should be included in _end_exp in subclass
|
||||
|
||||
Parameters
|
||||
----------
|
||||
experiment_id : str
|
||||
@@ -67,12 +78,28 @@ class ExpManager:
|
||||
-------
|
||||
An active experiment.
|
||||
"""
|
||||
self._active_exp_uri = uri
|
||||
# The subclass may set the underlying uri back.
|
||||
# So setting `_active_exp_uri` come before `_start_exp`
|
||||
return self._start_exp(
|
||||
experiment_id=experiment_id,
|
||||
experiment_name=experiment_name,
|
||||
recorder_id=recorder_id,
|
||||
recorder_name=recorder_name,
|
||||
resume=resume,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _start_exp(self, *args, **kwargs) -> Experiment:
|
||||
"""Please refer to the doc of `start_exp`"""
|
||||
raise NotImplementedError(f"Please implement the `start_exp` method.")
|
||||
|
||||
def end_exp(self, recorder_status: Text = Recorder.STATUS_S, **kwargs):
|
||||
"""
|
||||
End an active experiment.
|
||||
|
||||
Maintaining `_active_exp_uri` is included in end_exp, remaining implementation should be included in _end_exp in subclass
|
||||
|
||||
Parameters
|
||||
----------
|
||||
experiment_name : str
|
||||
@@ -80,6 +107,12 @@ class ExpManager:
|
||||
recorder_status : str
|
||||
the status of the active recorder of the experiment.
|
||||
"""
|
||||
self._active_exp_uri = None
|
||||
# The subclass may set the underlying uri back.
|
||||
# So setting `_active_exp_uri` come before `_end_exp`
|
||||
self._end_exp(recorder_status=recorder_status, **kwargs)
|
||||
|
||||
def _end_exp(self, recorder_status: Text = Recorder.STATUS_S, **kwargs):
|
||||
raise NotImplementedError(f"Please implement the `end_exp` method.")
|
||||
|
||||
def create_exp(self, experiment_name: Optional[Text] = None):
|
||||
@@ -254,6 +287,10 @@ class ExpManager:
|
||||
raise ValueError("The default URI is not set in qlib.config.C")
|
||||
return C.exp_manager["kwargs"]["uri"]
|
||||
|
||||
@default_uri.setter
|
||||
def default_uri(self, value):
|
||||
C.exp_manager.setdefault("kwargs", {})["uri"] = value
|
||||
|
||||
@property
|
||||
def uri(self):
|
||||
"""
|
||||
@@ -263,33 +300,7 @@ class ExpManager:
|
||||
-------
|
||||
The tracking URI string.
|
||||
"""
|
||||
return self._current_uri or self.default_uri
|
||||
|
||||
def set_uri(self, uri: Optional[Text] = None):
|
||||
"""
|
||||
Set the current tracking URI and the corresponding variables.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
uri : str
|
||||
|
||||
"""
|
||||
if uri is None:
|
||||
if self._current_uri is None:
|
||||
logger.debug("No tracking URI is provided. Use the default tracking URI.")
|
||||
self._current_uri = self.default_uri
|
||||
else:
|
||||
# Temporarily re-set the current uri as the uri argument.
|
||||
self._current_uri = uri
|
||||
# Customized features for subclasses.
|
||||
self._set_uri()
|
||||
|
||||
def _set_uri(self):
|
||||
"""
|
||||
Customized features for subclasses' set_uri function.
|
||||
This method is designed for the underlying experiment backend storage.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `_set_uri` method.")
|
||||
return self._active_exp_uri or self.default_uri
|
||||
|
||||
def list_experiments(self):
|
||||
"""
|
||||
@@ -307,33 +318,21 @@ class MLflowExpManager(ExpManager):
|
||||
Use mlflow to implement ExpManager.
|
||||
"""
|
||||
|
||||
def __init__(self, uri: Text, default_exp_name: Optional[Text]):
|
||||
super(MLflowExpManager, self).__init__(uri, default_exp_name)
|
||||
self._client = None
|
||||
|
||||
def _set_uri(self):
|
||||
self._client = mlflow.tracking.MlflowClient(tracking_uri=self.uri)
|
||||
logger.info("{:}".format(self._client))
|
||||
|
||||
@property
|
||||
def client(self):
|
||||
# Delay the creation of mlflow client in case of creating `mlruns` folder when importing qlib
|
||||
if self._client is None:
|
||||
self._client = mlflow.tracking.MlflowClient(tracking_uri=self.uri)
|
||||
return self._client
|
||||
# Please refer to `tests/dependency_tests/test_mlflow.py::MLflowTest::test_creating_client`
|
||||
# The test ensure the speed of create a new client
|
||||
return mlflow.tracking.MlflowClient(tracking_uri=self.uri)
|
||||
|
||||
def start_exp(
|
||||
def _start_exp(
|
||||
self,
|
||||
*,
|
||||
experiment_id: Optional[Text] = None,
|
||||
experiment_name: Optional[Text] = None,
|
||||
recorder_id: Optional[Text] = None,
|
||||
recorder_name: Optional[Text] = None,
|
||||
uri: Optional[Text] = None,
|
||||
resume: bool = False,
|
||||
):
|
||||
# Set the tracking uri
|
||||
self.set_uri(uri)
|
||||
# Create experiment
|
||||
if experiment_name is None:
|
||||
experiment_name = self._default_exp_name
|
||||
@@ -345,12 +344,10 @@ class MLflowExpManager(ExpManager):
|
||||
|
||||
return self.active_experiment
|
||||
|
||||
def end_exp(self, recorder_status: Text = Recorder.STATUS_S):
|
||||
def _end_exp(self, recorder_status: Text = Recorder.STATUS_S):
|
||||
if self.active_experiment is not None:
|
||||
self.active_experiment.end(recorder_status)
|
||||
self.active_experiment = None
|
||||
# When an experiment end, we will release the current uri.
|
||||
self._current_uri = None
|
||||
|
||||
def create_exp(self, experiment_name: Optional[Text] = None):
|
||||
assert experiment_name is not None
|
||||
@@ -362,9 +359,7 @@ class MLflowExpManager(ExpManager):
|
||||
raise ExpAlreadyExistError() from e
|
||||
raise e
|
||||
|
||||
experiment = MLflowExperiment(experiment_id, experiment_name, self.uri)
|
||||
experiment._default_name = self._default_exp_name
|
||||
return experiment
|
||||
return MLflowExperiment(experiment_id, experiment_name, self.uri)
|
||||
|
||||
def _get_exp(self, experiment_id=None, experiment_name=None):
|
||||
"""
|
||||
|
||||
@@ -378,14 +378,15 @@ class MLflowRecorder(Recorder):
|
||||
Recorder.STATUS_FI,
|
||||
Recorder.STATUS_FA,
|
||||
], f"The status type {status} is not supported."
|
||||
mlflow.end_run(status)
|
||||
self.end_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
if self.status != Recorder.STATUS_S:
|
||||
self.status = status
|
||||
if self.async_log is not None:
|
||||
# Waiting Queue should go before mlflow.end_run. Otherwise mlflow will raise error
|
||||
with TimeInspector.logt("waiting `async_log`"):
|
||||
self.async_log.wait()
|
||||
self.async_log = None
|
||||
mlflow.end_run(status)
|
||||
|
||||
def save_objects(self, local_path=None, artifact_path=None, **kwargs):
|
||||
assert self.uri is not None, "Please start the experiment and recorder first before using recorder directly."
|
||||
|
||||
2
setup.py
2
setup.py
@@ -62,7 +62,7 @@ REQUIRED = [
|
||||
"matplotlib>=3.3",
|
||||
"tables>=3.6.1",
|
||||
"pyyaml>=5.3.1",
|
||||
"mlflow>=1.12.1",
|
||||
"mlflow>=1.12.1, <=1.30.0",
|
||||
"tqdm",
|
||||
"loguru",
|
||||
"lightgbm>=3.3.0",
|
||||
|
||||
3
tests/dependency_tests/README.md
Normal file
3
tests/dependency_tests/README.md
Normal file
@@ -0,0 +1,3 @@
|
||||
Some implementations of Qlib depend on some assumptions of its dependencies.
|
||||
|
||||
So some tests are requried to ensure that these assumptions are valid.
|
||||
34
tests/dependency_tests/test_mlflow.py
Normal file
34
tests/dependency_tests/test_mlflow.py
Normal file
@@ -0,0 +1,34 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
import unittest
|
||||
import mlflow
|
||||
import time
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
|
||||
|
||||
class MLflowTest(unittest.TestCase):
|
||||
TMP_PATH = Path("./.mlruns_tmp/")
|
||||
|
||||
def tearDown(self) -> None:
|
||||
if self.TMP_PATH.exists():
|
||||
shutil.rmtree(self.TMP_PATH)
|
||||
|
||||
def test_creating_client(self):
|
||||
"""
|
||||
Please refer to qlib/workflow/expm.py:MLflowExpManager._client
|
||||
we don't cache _client (this is helpful to reduce maintainance work when MLflowExpManager's uri is chagned)
|
||||
|
||||
This implementation is based on the assumption creating a client is fast
|
||||
"""
|
||||
start = time.time()
|
||||
for i in range(10):
|
||||
_ = mlflow.tracking.MlflowClient(tracking_uri=str(self.TMP_PATH))
|
||||
end = time.time()
|
||||
elasped = end - start
|
||||
self.assertLess(elasped, 1e-2) # it can be done in less than 10ms
|
||||
print(elasped)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user