From 91fd53ab4d0724df73ccf8855ed83b6e1760bb08 Mon Sep 17 00:00:00 2001 From: D-X-Y <280835372@qq.com> Date: Sat, 6 Mar 2021 05:33:08 -0800 Subject: [PATCH] Add reset_default_uri func for R and expm --- qlib/workflow/__init__.py | 11 ++++++++++- qlib/workflow/exp.py | 5 ++++- qlib/workflow/expm.py | 4 ++++ 3 files changed, 18 insertions(+), 2 deletions(-) diff --git a/qlib/workflow/__init__.py b/qlib/workflow/__init__.py index 7bff505ce..54297ecd7 100644 --- a/qlib/workflow/__init__.py +++ b/qlib/workflow/__init__.py @@ -2,6 +2,7 @@ # Licensed under the MIT License. from contextlib import contextmanager +from typing import Text, Optional from .expm import MLflowExpManager from .exp import Experiment from .recorder import Recorder @@ -20,7 +21,9 @@ class QlibRecorder: return "{name}(manager={manager})".format(name=self.__class__.__name__, manager=self.exp_manager) @contextmanager - def start(self, experiment_name=None, recorder_name=None, uri=None): + def start( + self, experiment_name: Optional[Text] = None, recorder_name: Optional[Text] = None, uri: Optional[Text] = None + ): """ Method to start an experiment. This method can only be called within a Python's `with` statement. Here is the example code: @@ -282,6 +285,12 @@ class QlibRecorder: """ return self.exp_manager.uri + def reset_default_uri(self, uri: Text): + """ + Method to reset the default uri of current experiment manager. + """ + self.exp_manager.reset_default_uri(uri) + def get_recorder(self, recorder_id=None, recorder_name=None, experiment_name=None): """ Method for retrieving a recorder. diff --git a/qlib/workflow/exp.py b/qlib/workflow/exp.py index 18b0a143d..9cda020c3 100644 --- a/qlib/workflow/exp.py +++ b/qlib/workflow/exp.py @@ -23,7 +23,7 @@ class Experiment: self.active_recorder = None # only one recorder can running each time def __repr__(self): - return "{name}(info={info})".format(name=self.__class__.__name__, info=self.info) + return "{name}(id={id}, info={info})".format(name=self.__class__.__name__, id=self.id, info=self.info) def __str__(self): return str(self.info) @@ -175,6 +175,9 @@ class MLflowExperiment(Experiment): self._default_rec_name = "mlflow_recorder" self._client = mlflow.tracking.MlflowClient(tracking_uri=self._uri) + def __repr__(self): + return "{name}(id={id}, info={info})".format(name=self.__class__.__name__, id=self.id, info=self.info) + def start(self, recorder_name=None): logger.info(f"Experiment {self.id} starts running ...") # set up recorder diff --git a/qlib/workflow/expm.py b/qlib/workflow/expm.py index 4ba72a634..f9a5d0252 100644 --- a/qlib/workflow/expm.py +++ b/qlib/workflow/expm.py @@ -33,6 +33,10 @@ class ExpManager: name=self.__class__.__name__, duri=self._default_uri, curi=self._current_uri ) + def reset_default_uri(self, uri: Text): + self._default_uri = uri + self.set_uri(None) + def start_exp( self, experiment_name: Optional[Text] = None,