mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-04 03:21:00 +08:00
simplify the code and add docs
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from typing import Union
|
||||
import mlflow, logging
|
||||
from mlflow.entities import ViewType
|
||||
from mlflow.exceptions import MlflowException
|
||||
@@ -324,14 +325,21 @@ class MLflowExperiment(Experiment):
|
||||
|
||||
UNLIMITED = 50000 # FIXME: Mlflow can only list 50000 records at most!!!!!!!
|
||||
|
||||
def list_recorders(self, max_results=UNLIMITED, status=None):
|
||||
def list_recorders(self, max_results: int = UNLIMITED, status: Union[str, None] = None):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
max_results : int
|
||||
the number limitation of the results
|
||||
status : str
|
||||
the criteria based on status to filter results.
|
||||
`None` indicates no filtering.
|
||||
"""
|
||||
runs = self._client.search_runs(self.id, run_view_type=ViewType.ACTIVE_ONLY, max_results=max_results)
|
||||
recorders = dict()
|
||||
for i in range(len(runs)):
|
||||
recorder = MLflowRecorder(self.id, self._uri, mlflow_run=runs[i])
|
||||
if status is not None:
|
||||
if recorder.status != status:
|
||||
continue
|
||||
recorders[runs[i].info.run_id] = recorder
|
||||
if status is None or recorder.status == status:
|
||||
recorders[runs[i].info.run_id] = recorder
|
||||
|
||||
return recorders
|
||||
|
||||
Reference in New Issue
Block a user