From a939445da39ae3e289aaeaba17cff6ee7d93fa2e Mon Sep 17 00:00:00 2001 From: Young Date: Sun, 29 Nov 2020 13:00:35 +0000 Subject: [PATCH] fix mlflow bug --- qlib/workflow/exp.py | 13 ++++++++----- qlib/workflow/expm.py | 3 ++- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/qlib/workflow/exp.py b/qlib/workflow/exp.py index c23f27f09..09c680e59 100644 --- a/qlib/workflow/exp.py +++ b/qlib/workflow/exp.py @@ -2,6 +2,7 @@ # Licensed under the MIT License. import mlflow +from mlflow.entities import ViewType from mlflow.exceptions import MlflowException from pathlib import Path from .recorder import Recorder, MLflowRecorder @@ -226,7 +227,7 @@ class MLflowExperiment(Experiment): if recorder_name is None: recorder_name = self._default_rec_name logger.info(f"No valid recorder found. Create a new recorder with name {recorder_name}.") - return self.create(recorder_name), True + return self.create_recorder(recorder_name), True def _get_recorder(self, recorder_id=None, recorder_name=None): """ @@ -241,7 +242,7 @@ class MLflowExperiment(Experiment): run = self.client.get_run(recorder_id) recorder = MLflowRecorder(self.id, self._uri, mlflow_run=run) return recorder - except MlflowException as e: + except MlflowException: raise ValueError("No valid recorder has been found, please make sure the input recorder id is correct.") elif recorder_name is not None: logger.warning( @@ -269,15 +270,17 @@ class MLflowExperiment(Experiment): if recorder_id is not None: self.client.delete_run(recorder_id) else: - recorder = self._get_recorder_by_name(recorder_name) + recorder = self._get_recorder(recorder_name=recorder_name) self.client.delete_run(recorder.id) except MlflowException as e: raise Exception( f"Error: {e}. Something went wrong when deleting recorder. Please check if the name/id of the recorder is correct." ) - def list_recorders(self): - runs = self.client.search_runs(self.id, run_view_type=1)[::-1] + UNLIMITED = 50000 # FIXME: Mlflow can only list 50000 records at most!!!!!!! + + def list_recorders(self, max_results=UNLIMITED): + runs = self.client.search_runs(self.id, run_view_type=ViewType.ACTIVE_ONLY, max_results=max_results)[::-1] recorders = dict() for i in range(len(runs)): recorder = MLflowRecorder(self.id, self._uri, mlflow_run=runs[i]) diff --git a/qlib/workflow/expm.py b/qlib/workflow/expm.py index 80d471845..cfb0290fc 100644 --- a/qlib/workflow/expm.py +++ b/qlib/workflow/expm.py @@ -3,6 +3,7 @@ import mlflow from mlflow.exceptions import MlflowException +from mlflow.entities import ViewType import os from pathlib import Path from contextlib import contextmanager @@ -324,7 +325,7 @@ class MLflowExpManager(ExpManager): def list_experiments(self): # retrieve all the existing experiments - exps = self.client.list_experiments(view_type=1) + exps = self.client.list_experiments(view_type=ViewType.ACTIVE_ONLY) experiments = dict() for exp in exps: experiment = MLflowExperiment(exp.experiment_id, exp.name, self.uri)