1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-03 02:50:58 +08:00

Refine default uri in expm

This commit is contained in:
D-X-Y
2021-03-11 02:49:03 +00:00
parent f6ed175070
commit cda96be8c3
2 changed files with 17 additions and 14 deletions

View File

@@ -285,11 +285,11 @@ class QlibRecorder:
"""
return self.exp_manager.uri
def reset_default_uri(self, uri: Text):
def set_uri(self, uri: Optional[Text]):
"""
Method to reset the default uri of current experiment manager.
Method to reset the current uri of current experiment manager.
"""
self.exp_manager.reset_default_uri(uri)
self.exp_manager.set_uri(uri)
def get_recorder(self, recorder_id=None, recorder_name=None, experiment_name=None):
"""

View File

@@ -10,6 +10,7 @@ from contextlib import contextmanager
from typing import Optional, Text
from .exp import MLflowExperiment, Experiment
from ..config import C
from .recorder import Recorder
from ..log import get_module_logger
@@ -23,19 +24,12 @@ class ExpManager:
"""
def __init__(self, uri: Text, default_exp_name: Optional[Text]):
self._default_uri = uri
self._current_uri = None
self._current_uri = uri
self.default_exp_name = default_exp_name
self.active_experiment = None # only one experiment can active each time
def __repr__(self):
return "{name}(default_uri={duri}, current_uri={curi})".format(
name=self.__class__.__name__, duri=self._default_uri, curi=self._current_uri
)
def reset_default_uri(self, uri: Text):
self._default_uri = uri
self.set_uri(None)
return "{name}(current_uri={curi})".format(name=self.__class__.__name__, curi=self._current_uri)
def start_exp(
self,
@@ -221,6 +215,15 @@ class ExpManager:
"""
raise NotImplementedError(f"Please implement the `delete_exp` method.")
@property
def default_uri(self):
"""
Get the default tracking URI from qlib.config.C
"""
if "kwargs" not in C.exp_manager or "uri" not in C.exp_manager["kwargs"]:
raise ValueError("The default URI is not set in qlib.config.C")
return C.exp_manager["kwargs"]["uri"]
@property
def uri(self):
"""
@@ -230,7 +233,7 @@ class ExpManager:
-------
The tracking URI string.
"""
return self._current_uri or self._default_uri
return self._current_uri or self.default_uri
def set_uri(self, uri: Optional[Text] = None):
"""
@@ -243,7 +246,7 @@ class ExpManager:
"""
if uri is None:
logger.info("No tracking URI is provided. Use the default tracking URI.")
self._current_uri = self._default_uri
self._current_uri = self.default_uri
else:
# Temporarily re-set the current uri as the uri argument.
self._current_uri = uri