diff --git a/qlib/model/trainer.py b/qlib/model/trainer.py index e4fc8eef9..0ef062021 100644 --- a/qlib/model/trainer.py +++ b/qlib/model/trainer.py @@ -1,5 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. + from qlib.utils import init_instance_by_config, flatten_dict from qlib.workflow import R from qlib.workflow.record_temp import SignalRecord diff --git a/qlib/workflow/expm.py b/qlib/workflow/expm.py index 156beb690..80d471845 100644 --- a/qlib/workflow/expm.py +++ b/qlib/workflow/expm.py @@ -239,20 +239,17 @@ class MLflowExpManager(ExpManager): return self._client def start_exp(self, experiment_name=None, recorder_name=None, uri=None): + # set the tracking uri + if uri is None: + logger.info("No tracking URI is provided. Use the default tracking URI.") + else: + self.uri = uri # create experiment experiment, _ = self._get_or_create_exp(experiment_name=experiment_name) # set up active experiment self.active_experiment = experiment # start the experiment self.active_experiment.start(recorder_name) - # set the tracking uri - if uri is None: - logger.info( - "No tracking URI is provided. The default tracking URI is set as `mlruns` under the working directory." - ) - else: - self.uri = uri - mlflow.set_tracking_uri(self.uri) return self.active_experiment diff --git a/qlib/workflow/recorder.py b/qlib/workflow/recorder.py index b3069b9ac..b381a914a 100644 --- a/qlib/workflow/recorder.py +++ b/qlib/workflow/recorder.py @@ -224,6 +224,8 @@ class MLflowRecorder(Recorder): ) def start_run(self): + # set the tracking uri + mlflow.set_tracking_uri(self.uri) # start the run run = mlflow.start_run(self.id, self.experiment_id, self.name) # save the run id and artifact_uri