mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-30 09:31:18 +08:00
fix mlflow bug
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import mlflow
|
||||
from mlflow.entities import ViewType
|
||||
from mlflow.exceptions import MlflowException
|
||||
from pathlib import Path
|
||||
from .recorder import Recorder, MLflowRecorder
|
||||
@@ -226,7 +227,7 @@ class MLflowExperiment(Experiment):
|
||||
if recorder_name is None:
|
||||
recorder_name = self._default_rec_name
|
||||
logger.info(f"No valid recorder found. Create a new recorder with name {recorder_name}.")
|
||||
return self.create(recorder_name), True
|
||||
return self.create_recorder(recorder_name), True
|
||||
|
||||
def _get_recorder(self, recorder_id=None, recorder_name=None):
|
||||
"""
|
||||
@@ -241,7 +242,7 @@ class MLflowExperiment(Experiment):
|
||||
run = self.client.get_run(recorder_id)
|
||||
recorder = MLflowRecorder(self.id, self._uri, mlflow_run=run)
|
||||
return recorder
|
||||
except MlflowException as e:
|
||||
except MlflowException:
|
||||
raise ValueError("No valid recorder has been found, please make sure the input recorder id is correct.")
|
||||
elif recorder_name is not None:
|
||||
logger.warning(
|
||||
@@ -269,15 +270,17 @@ class MLflowExperiment(Experiment):
|
||||
if recorder_id is not None:
|
||||
self.client.delete_run(recorder_id)
|
||||
else:
|
||||
recorder = self._get_recorder_by_name(recorder_name)
|
||||
recorder = self._get_recorder(recorder_name=recorder_name)
|
||||
self.client.delete_run(recorder.id)
|
||||
except MlflowException as e:
|
||||
raise Exception(
|
||||
f"Error: {e}. Something went wrong when deleting recorder. Please check if the name/id of the recorder is correct."
|
||||
)
|
||||
|
||||
def list_recorders(self):
|
||||
runs = self.client.search_runs(self.id, run_view_type=1)[::-1]
|
||||
UNLIMITED = 50000 # FIXME: Mlflow can only list 50000 records at most!!!!!!!
|
||||
|
||||
def list_recorders(self, max_results=UNLIMITED):
|
||||
runs = self.client.search_runs(self.id, run_view_type=ViewType.ACTIVE_ONLY, max_results=max_results)[::-1]
|
||||
recorders = dict()
|
||||
for i in range(len(runs)):
|
||||
recorder = MLflowRecorder(self.id, self._uri, mlflow_run=runs[i])
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
import mlflow
|
||||
from mlflow.exceptions import MlflowException
|
||||
from mlflow.entities import ViewType
|
||||
import os
|
||||
from pathlib import Path
|
||||
from contextlib import contextmanager
|
||||
@@ -324,7 +325,7 @@ class MLflowExpManager(ExpManager):
|
||||
|
||||
def list_experiments(self):
|
||||
# retrieve all the existing experiments
|
||||
exps = self.client.list_experiments(view_type=1)
|
||||
exps = self.client.list_experiments(view_type=ViewType.ACTIVE_ONLY)
|
||||
experiments = dict()
|
||||
for exp in exps:
|
||||
experiment = MLflowExperiment(exp.experiment_id, exp.name, self.uri)
|
||||
|
||||
Reference in New Issue
Block a user