mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-05 12:00:58 +08:00
Fix workflow
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from qlib.utils import init_instance_by_config, flatten_dict
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.record_temp import SignalRecord
|
||||
|
||||
@@ -239,20 +239,17 @@ class MLflowExpManager(ExpManager):
|
||||
return self._client
|
||||
|
||||
def start_exp(self, experiment_name=None, recorder_name=None, uri=None):
|
||||
# set the tracking uri
|
||||
if uri is None:
|
||||
logger.info("No tracking URI is provided. Use the default tracking URI.")
|
||||
else:
|
||||
self.uri = uri
|
||||
# create experiment
|
||||
experiment, _ = self._get_or_create_exp(experiment_name=experiment_name)
|
||||
# set up active experiment
|
||||
self.active_experiment = experiment
|
||||
# start the experiment
|
||||
self.active_experiment.start(recorder_name)
|
||||
# set the tracking uri
|
||||
if uri is None:
|
||||
logger.info(
|
||||
"No tracking URI is provided. The default tracking URI is set as `mlruns` under the working directory."
|
||||
)
|
||||
else:
|
||||
self.uri = uri
|
||||
mlflow.set_tracking_uri(self.uri)
|
||||
|
||||
return self.active_experiment
|
||||
|
||||
|
||||
@@ -224,6 +224,8 @@ class MLflowRecorder(Recorder):
|
||||
)
|
||||
|
||||
def start_run(self):
|
||||
# set the tracking uri
|
||||
mlflow.set_tracking_uri(self.uri)
|
||||
# start the run
|
||||
run = mlflow.start_run(self.id, self.experiment_id, self.name)
|
||||
# save the run id and artifact_uri
|
||||
|
||||
Reference in New Issue
Block a user