mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-03 19:10:58 +08:00
Add RecordTemp & update
This commit is contained in:
@@ -82,12 +82,11 @@ def init(default_conf="client", **kwargs):
|
||||
LOG.info(f"flask_server={C['flask_server']}, flask_port={C['flask_port']}")
|
||||
|
||||
# set up QlibRecorder
|
||||
default_uri = str(Path(os.getcwd()).resolve() / "mlruns")
|
||||
current_uri = C["exp_uri"] if C["exp_uri"] is not None else default_uri
|
||||
uri = C["exp_uri"]
|
||||
# exp manager module
|
||||
module = get_module_by_module_path("qlib.workflow")
|
||||
module = get_module_by_module_path("qlib.workflow.expm")
|
||||
exp_manager = init_instance_by_config(C["exp_manager"], module)
|
||||
qr = QlibRecorder(exp_manager, default_uri, current_uri)
|
||||
qr = QlibRecorder(exp_manager, uri)
|
||||
R.register(qr)
|
||||
|
||||
|
||||
|
||||
@@ -126,7 +126,7 @@ _default_config = {
|
||||
},
|
||||
# Defatult config for experiment manager
|
||||
"exp_manager": {"class": "MLflowExpManager", "kwargs": {}},
|
||||
"exp_uri": None,
|
||||
"exp_uri": str(Path(os.getcwd()).resolve() / "mlruns"),
|
||||
}
|
||||
|
||||
MODE_CONF = {
|
||||
|
||||
@@ -24,7 +24,7 @@ from ..log import get_module_logger
|
||||
from ..utils import parse_field, read_bin, hash_args, normalize_cache_fields
|
||||
from .base import Feature
|
||||
from .cache import DiskDatasetCache, DiskExpressionCache
|
||||
from ..utils import Wrapper, get_provider_obj, register_wrapper
|
||||
from ..utils import Wrapper, init_instance_by_config, register_wrapper, get_module_by_module_path
|
||||
|
||||
|
||||
class CalendarProvider(abc.ABC):
|
||||
@@ -1031,34 +1031,44 @@ D = Wrapper()
|
||||
def register_all_wrappers():
|
||||
"""register_all_wrappers"""
|
||||
logger = get_module_logger("data")
|
||||
|
||||
_calendar_provider = get_provider_obj(C.calendar_provider)
|
||||
module = get_module_by_module_path("qlib.data")
|
||||
|
||||
_calendar_provider = init_instance_by_config(C.calendar_provider, module)
|
||||
if getattr(C, "calendar_cache", None) is not None:
|
||||
_calendar_provider = get_provider_obj(C.calendar_cache, provider=_calendar_provider)
|
||||
register_wrapper(Cal, _calendar_provider)
|
||||
_calendar_cache_config = {}
|
||||
_calendar_cache_config.update(C.calendar_cache)
|
||||
_calendar_cache_config['kwargs'].update(provider=_calendar_provider)
|
||||
_calendar_provider = init_instance_by_config(_calendar_cache_config, module)
|
||||
register_wrapper(Cal, _calendar_provider, "qlib.data")
|
||||
logger.debug(f"registering Cal {C.calendar_provider}-{C.calenar_cache}")
|
||||
|
||||
register_wrapper(Inst, C.instrument_provider)
|
||||
register_wrapper(Inst, C.instrument_provider, "qlib.data")
|
||||
logger.debug(f"registering Inst {C.instrument_provider}")
|
||||
|
||||
if getattr(C, "feature_provider", None) is not None:
|
||||
feature_provider = get_provider_obj(C.feature_provider)
|
||||
register_wrapper(FeatureD, feature_provider)
|
||||
feature_provider = init_instance_by_config(C.feature_provider, module)
|
||||
register_wrapper(FeatureD, feature_provider, "qlib.data")
|
||||
logger.debug(f"registering FeatureD {C.feature_provider}")
|
||||
|
||||
if getattr(C, "expression_provider", None) is not None:
|
||||
# This provider is unnecessary in client provider
|
||||
_eprovider = get_provider_obj(C.expression_provider)
|
||||
_eprovider = init_instance_by_config(C.expression_provider, module)
|
||||
if getattr(C, "expression_cache", None) is not None:
|
||||
_eprovider = get_provider_obj(C.expression_cache, provider=_eprovider)
|
||||
register_wrapper(ExpressionD, _eprovider)
|
||||
_expression_cache_config = {}
|
||||
_expression_cache_config.update(C.expression_cache)
|
||||
_expression_cache_config['kwargs'].update(provider=_eprovider)
|
||||
_eprovider = init_instance_by_config(C.expression_cache, module)
|
||||
register_wrapper(ExpressionD, _eprovider, "qlib.data")
|
||||
logger.debug(f"registering ExpressioneD {C.expression_provider}-{C.expression_cache}")
|
||||
|
||||
_dprovider = get_provider_obj(C.dataset_provider)
|
||||
_dprovider = init_instance_by_config(C.dataset_provider, module)
|
||||
if getattr(C, "dataset_cache", None) is not None:
|
||||
_dprovider = get_provider_obj(C.dataset_cache, provider=_dprovider)
|
||||
register_wrapper(DatasetD, _dprovider)
|
||||
_dataset_cache_config = {}
|
||||
_dataset_cache_config.update(C.dataset_cache)
|
||||
_dataset_cache_config['kwargs'].update(provider=_dprovider)
|
||||
_dprovider = init_instance_by_config(_dataset_cache_config, module)
|
||||
register_wrapper(DatasetD, _dprovider, "qlib.data")
|
||||
logger.debug(f"registering DataseteD {C.dataset_provider}-{C.dataset_cache}")
|
||||
|
||||
register_wrapper(D, C.provider)
|
||||
register_wrapper(D, C.provider, "qlib.data")
|
||||
logger.debug(f"registering D {C.provider}")
|
||||
|
||||
@@ -632,21 +632,14 @@ class Wrapper(object):
|
||||
return getattr(self._provider, key)
|
||||
|
||||
|
||||
def get_provider_obj(config, **params):
|
||||
module = get_module_by_module_path("qlib.data")
|
||||
klass, kwargs = get_cls_kwargs(config, module)
|
||||
kwargs.update(params)
|
||||
return klass(**kwargs)
|
||||
|
||||
|
||||
def register_wrapper(wrapper, cls_or_obj):
|
||||
def register_wrapper(wrapper, cls_or_obj, module_path=None):
|
||||
"""register_wrapper
|
||||
|
||||
:param wrapper: A wrapper of all kinds of providers
|
||||
:param cls_or_obj: A class or class name or object instance in data/data.py
|
||||
:param wrapper: A wrapper.
|
||||
:param cls_or_obj: A class or class name or object instance.
|
||||
"""
|
||||
if isinstance(cls_or_obj, str):
|
||||
module = get_module_by_module_path("qlib.data")
|
||||
module = get_module_by_module_path(module_path)
|
||||
cls_or_obj = getattr(module, cls_or_obj)
|
||||
obj = cls_or_obj() if isinstance(cls_or_obj, type) else cls_or_obj
|
||||
wrapper.register(obj)
|
||||
|
||||
@@ -2,24 +2,29 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from contextlib import contextmanager
|
||||
from .expm import *
|
||||
from .expm import MLflowExpManager
|
||||
from ..utils import Wrapper
|
||||
|
||||
|
||||
class QlibRecorder:
|
||||
def __init__(self, exp_manager, default_uri, current_uri):
|
||||
"""
|
||||
A global system that helps to manage the experiments.
|
||||
"""
|
||||
def __init__(self, exp_manager, uri):
|
||||
self.exp_manager = exp_manager
|
||||
self.default_uri = default_uri
|
||||
self.current_uri = current_uri
|
||||
self.uri = uri
|
||||
|
||||
@contextmanager
|
||||
def start(self, experiment_name):
|
||||
run = self.start_exp(experiment_name, self.current_uri)
|
||||
yield run
|
||||
run = self.start_exp(experiment_name, self.uri)
|
||||
try:
|
||||
yield run
|
||||
except:
|
||||
self.end_exp() # end the experiment if something went wrong
|
||||
self.end_exp()
|
||||
|
||||
def start_exp(self, experiment_name=None):
|
||||
return self.exp_manager.start_exp(experiment_name, self.current_uri)
|
||||
return self.exp_manager.start_exp(experiment_name, self.uri)
|
||||
|
||||
def end_exp(self):
|
||||
self.exp_manager.end_exp()
|
||||
@@ -33,8 +38,8 @@ class QlibRecorder:
|
||||
def delete_exp(self, experiment_id):
|
||||
self.exp_manager.delete_exp(experiment_id)
|
||||
|
||||
def get_uri(self, type):
|
||||
return self.exp_manager.get_uri(type)
|
||||
def get_uri(self):
|
||||
return self.exp_manager.get_uri()
|
||||
|
||||
def get_recorder(self):
|
||||
return self.exp_manager.active_recorder
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
import mlflow
|
||||
from pathlib import Path
|
||||
|
||||
from .recorder import MLflowRecorder
|
||||
|
||||
class Experiment:
|
||||
"""
|
||||
@@ -15,6 +15,19 @@ class Experiment:
|
||||
self.id = None
|
||||
self.recorders = list()
|
||||
|
||||
def create_recorder(self):
|
||||
"""
|
||||
Create a recorder for each experiment.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
||||
Returns
|
||||
-------
|
||||
A recorder instance.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `create_recorder` method.")
|
||||
|
||||
def search_records(self, **kwargs):
|
||||
"""
|
||||
Get a pandas DataFrame of records that fit the search criteria of the experiment.
|
||||
@@ -36,15 +49,39 @@ class Experiment:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `search_records` method.")
|
||||
|
||||
def delete_recorder(self, rid):
|
||||
"""
|
||||
Create a recorder for each experiment.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
rid : str
|
||||
the id of the recorder to be deleted.
|
||||
|
||||
Returns
|
||||
-------
|
||||
A recorder instance.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `delete_recorder` method.")
|
||||
|
||||
|
||||
class MLflowExperiment(Experiment):
|
||||
"""
|
||||
Use mlflow to implement Experiment.
|
||||
"""
|
||||
|
||||
def create_recorder(self):
|
||||
recorder = MLflowRecorder(self.id)
|
||||
self.recorders.append(recorder)
|
||||
return recorders
|
||||
|
||||
def search_records(self, **kwargs):
|
||||
filter_string = "" if kwargs.get("filter_string") is None else kwargs.get("filter_string")
|
||||
run_view_type = 1 if kwargs.get("run_view_type") is None else kwargs.get("run_view_type")
|
||||
max_results = 100000 if kwargs.get("max_results") is None else kwargs.get("max_results")
|
||||
order_by = kwargs.get("order_by")
|
||||
return mlflow.search_runs([self.experiment_id], filter_string, run_view_type, max_results, order_by)
|
||||
|
||||
def delete_recorder(self, rid):
|
||||
mlflow.delete_run(rid)
|
||||
self.recorders = [r for r in self.recorders if r.recorder_id == rid]
|
||||
@@ -6,8 +6,10 @@ import os
|
||||
from pathlib import Path
|
||||
from contextlib import contextmanager
|
||||
from .exp import MLflowExperiment
|
||||
from .record import MLflowRecorder
|
||||
from .recorder import MLflowRecorder
|
||||
from ..log import get_module_logger
|
||||
|
||||
logger = get_module_logger('workflow', 'Warning')
|
||||
|
||||
class ExpManager:
|
||||
"""
|
||||
@@ -16,7 +18,7 @@ class ExpManager:
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.default_uri = None
|
||||
self.uri = None
|
||||
self.active_recorder = None # only one recorder can running each time
|
||||
self.experiments = dict() # store the experiment name --> Experiment object
|
||||
|
||||
@@ -117,20 +119,18 @@ class ExpManager:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `create_exp` method.")
|
||||
|
||||
def get_uri(self, type):
|
||||
def get_uri(self):
|
||||
"""
|
||||
Get the default tracking URI or current URI.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
type : str
|
||||
the type of the tracking URI one wants to retrieve.
|
||||
|
||||
Returns
|
||||
-------
|
||||
The tracking URI string.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `create_exp` method.")
|
||||
return self.uri
|
||||
|
||||
def get_recorder(self):
|
||||
"""
|
||||
@@ -143,7 +143,7 @@ class ExpManager:
|
||||
-------
|
||||
An Recorder object.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `get_recorder` method.")
|
||||
return self.active_recorder
|
||||
|
||||
|
||||
class MLflowExpManager(ExpManager):
|
||||
@@ -153,17 +153,14 @@ class MLflowExpManager(ExpManager):
|
||||
|
||||
def __init__(self):
|
||||
super(MLflowExpManager, self).__init__()
|
||||
self.default_uri = None
|
||||
self.current_uri = None
|
||||
self.uri = None
|
||||
|
||||
def start_exp(self, experiment_name=None, uri=None):
|
||||
# create experiment
|
||||
experiment = self.__create_exp(experiment_name, uri)
|
||||
# set up recorder
|
||||
recorder = MLflowRecorder(experiment.id)
|
||||
recorder = experiment.create_recorder()
|
||||
self.active_recorder = recorder
|
||||
# store the recorder
|
||||
experiment.recorders.append(self.active_recorder)
|
||||
# store the experiment
|
||||
self.experiments[experiment_name] = experiment
|
||||
|
||||
@@ -178,15 +175,15 @@ class MLflowExpManager(ExpManager):
|
||||
experiment = MLflowExperiment()
|
||||
# set the tracking uri
|
||||
if uri is None:
|
||||
print(
|
||||
logger.warning(
|
||||
"No tracking URI is provided. The default tracking URI is set as `mlruns` under the working directory."
|
||||
)
|
||||
else:
|
||||
self.current_uri = uri
|
||||
mlflow.set_tracking_uri(self.current_uri)
|
||||
self.uri = uri
|
||||
mlflow.set_tracking_uri(self.uri)
|
||||
# start the experiment
|
||||
if experiment_name is None:
|
||||
print("No experiment name provided. The default experiment name is set as `experiment`.")
|
||||
logger.warning("No experiment name provided. The default experiment name is set as `experiment`.")
|
||||
experiment_id = mlflow.create_experiment("experiment")
|
||||
# set the active experiment
|
||||
mlflow.set_experiment("experiment")
|
||||
@@ -227,19 +224,8 @@ class MLflowExpManager(ExpManager):
|
||||
if self.experiments[name].id == experiment_id:
|
||||
return self.experiments[name]
|
||||
else:
|
||||
print("No valid experiment is found. Please make sure the id and name are correctly given.")
|
||||
raise Exception("No valid experiment is found. Please make sure the id and name are correctly given.")
|
||||
|
||||
def delete_exp(self, experiment_id):
|
||||
mlflow.delete_experiment(experiment_id)
|
||||
self.experiments = {key: val for key, val in self.experiments.items() if val.id != experiment_id}
|
||||
|
||||
def get_uri(self, type):
|
||||
if uri == "default":
|
||||
return self.default_uri
|
||||
elif uri == "current":
|
||||
return self.current_uri
|
||||
else:
|
||||
raise ValueError("Input type is not supported. Please choose type default or current to get the uri.")
|
||||
|
||||
def get_recorder(self):
|
||||
return self.active_recorder
|
||||
|
||||
137
qlib/workflow/record_temp.py
Normal file
137
qlib/workflow/record_temp.py
Normal file
@@ -0,0 +1,137 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
from ..contrib.evaluate import (
|
||||
backtest as normal_backtest,
|
||||
risk_analysis,
|
||||
)
|
||||
from ..utils import init_instance_by_config, get_module_by_module_path
|
||||
|
||||
|
||||
class RecordTemp:
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def generate(self, **kwargs):
|
||||
"""
|
||||
Generate certain records such as IC, backtest etc., and save them.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
kwargs
|
||||
|
||||
Return
|
||||
------
|
||||
The generated records.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `generate` method.")
|
||||
|
||||
def check(self, **kwargs):
|
||||
"""
|
||||
Check if the records is properly generated and saved.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
kwargs
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `check` method.")
|
||||
|
||||
|
||||
# TODO: this can only be run under R's running experiment.
|
||||
class SignalRecord(RecordTemp):
|
||||
def __init__(self, model, dataset, recorder, **kwargs):
|
||||
super(SignalRecord, self).__init__()
|
||||
self.model = model
|
||||
self.dataset = dataset
|
||||
self.recorder = recorder
|
||||
|
||||
def generate(self, **kwargs):
|
||||
# generate prediciton
|
||||
pred = self.model.predict(self.dataset)
|
||||
self.recorder.save_object(pred, 'pred.pkl')
|
||||
|
||||
def load(self):
|
||||
# try to load the saved object
|
||||
try:
|
||||
pred = self.recorder.load_object('pred.pkl')
|
||||
return pred
|
||||
except:
|
||||
raise Exception('Something went wrong when loading the saved object.')
|
||||
|
||||
def check(self, **kwargs):
|
||||
return self.recorder.check('pred.pkl')
|
||||
|
||||
|
||||
# TODO
|
||||
class SigAnaRecord(SignalRecord):
|
||||
def __init__(self, recorder, **kwargs):
|
||||
|
||||
def generate(self):
|
||||
pass
|
||||
|
||||
def load(self):
|
||||
pass
|
||||
|
||||
def check(self):
|
||||
pass
|
||||
|
||||
|
||||
class PortAnaRecord(SignalRecord):
|
||||
def __init__(self, recorder, STRATEGY_CONFIG, BACKTEST_CONFIG, **kwargs):
|
||||
self.recorder = recorder
|
||||
self.STRATEGY_CONFIG = STRATEGY_CONFIG
|
||||
self.BACKTEST_CONFIG = BACKTEST_CONFIG
|
||||
module = get_module_by_module_path("qlib.contrib.strategy")
|
||||
self.strategy = init_instance_by_config(STRATEGY_CONFIG, module)
|
||||
self.artifact_path = Path('portfolio_analysis').resolve()
|
||||
|
||||
def generate(self, **kwargs):
|
||||
"""
|
||||
STRATEGY_CONFIG : dict
|
||||
define the strategy class as well as the kwargs.
|
||||
BACKTEST_CONFIG : dict
|
||||
define the backtest kwargs.
|
||||
"""
|
||||
# check previously stored prediction results
|
||||
assert super().check(), "Make sure the parent process is completed and store the data properly."
|
||||
# custom strategy and get backtest
|
||||
pred_score = super().load()
|
||||
report_normal, positions_normal = normal_backtest(pred_score, strategy=self.strategy, **self.BACKTEST_CONFIG)
|
||||
self.recorder.save_object(report_normal, 'report_normal.pkl', self.artifact_path)
|
||||
self.recorder.save_object(positions_normal, 'positions_normal.pkl', self.artifact_path)
|
||||
|
||||
# analysis
|
||||
analysis = dict()
|
||||
analysis["excess_return_without_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"])
|
||||
analysis["excess_return_with_cost"] = risk_analysis(
|
||||
report_normal["return"] - report_normal["bench"] - report_normal["cost"]
|
||||
)
|
||||
analysis_df = pd.concat(analysis) # type: pd.DataFrame
|
||||
self.recorder.save_object(pred, 'port_analysis.pkl', self.artifact_path)
|
||||
|
||||
def load(self):
|
||||
# try to load the saved object
|
||||
try:
|
||||
pred = self.recorder.load_object(self.artifact_path / 'port_analysis.pkl'')
|
||||
return pred
|
||||
except:
|
||||
raise Exception('Something went wrong when loading the saved object.')
|
||||
|
||||
def check(self):
|
||||
return self.recorder.check('port_analysis.pkl', self.artifact_path)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ class Recorder:
|
||||
def set_recorder_name(self, rname):
|
||||
self.recorder_name = rname
|
||||
|
||||
def save_object(self, data, name, local_path=None):
|
||||
def save_object(self, data=None, name=None, local_path=None, artifact_path=None):
|
||||
"""
|
||||
Save object such as prediction file or model checkpoints to the artifact URI.
|
||||
|
||||
@@ -33,10 +33,12 @@ class Recorder:
|
||||
name of the file to be saved.
|
||||
local_path : str
|
||||
if provided, them save the file or directory to the artifact URI.
|
||||
artifact_path=None : str
|
||||
the relative path for the artifact to be stored in the URI.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `save_object` method.")
|
||||
|
||||
def save_objects(self, data_name_list, local_path=None):
|
||||
def save_objects(self, data_name_list=None, local_path=None, artifact_path=None):
|
||||
"""
|
||||
Save objects such as prediction file or model checkpoints to the artifact URI.
|
||||
|
||||
@@ -46,6 +48,8 @@ class Recorder:
|
||||
list of (data, name) pairs
|
||||
local_path : str
|
||||
if provided, them save the file or directory to the artifact URI.
|
||||
artifact_path=None : str
|
||||
the relative path for the artifact to be stored in the URI.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `save_objects` method.")
|
||||
|
||||
@@ -162,6 +166,7 @@ class MLflowRecorder(Recorder):
|
||||
# save the run id and artifact_uri
|
||||
self.recorder_id = run.info.run_id
|
||||
self.artifact_uri = run.info.artifact_uri
|
||||
self._uri = mlflow.get_tracking_uri() # Fix!!! : this is not proper to have uri in recorder
|
||||
# set up file manager for saving objects
|
||||
self.temp_dir = tempfile.mkdtemp()
|
||||
self.fm = FileManager(Path(self.temp_dir).absolute())
|
||||
@@ -171,27 +176,27 @@ class MLflowRecorder(Recorder):
|
||||
mlflow.end_run()
|
||||
shutil.rmtree(self.temp_dir)
|
||||
|
||||
def save_object(self, data, name, local_path=None):
|
||||
def save_object(self, data=None, name=None, local_path=None, artifact_path=None):
|
||||
client = mlflow.tracking.MlflowClient(tracking_uri=self._uri)
|
||||
if local_path is None:
|
||||
assert data is not None and name is not None, "Please provide data and name input."
|
||||
self.fm.save_obj(data, name)
|
||||
mlflow.log_artifact(self.fm.path / name)
|
||||
self.fm.remove(name)
|
||||
client.log_artifact(self.recorder_id, self.fm.path / name, artifact_path)
|
||||
else:
|
||||
mlflow.log_artifact(local_path)
|
||||
assert local_path is not None, "Please provide a valid local path for the "
|
||||
client.log_artifact(self.recorder_id, local_path, artifact_path)
|
||||
|
||||
def save_objects(self, data_name_list, local_path=None):
|
||||
def save_objects(self, data_name_list=None, local_path=None, artifact_path=None):
|
||||
client = mlflow.tracking.MlflowClient(tracking_uri=self._uri)
|
||||
if local_path is None:
|
||||
assert data_name_list is not None, "Please provide data_name_list input."
|
||||
self.fm.save_objs(data_name_list)
|
||||
mlflow.log_artifacts(self.fm.path)
|
||||
for obj, name in data_name_list:
|
||||
self.fm.remove(name)
|
||||
client.log_artifacts(self.recorder_id, self.fm.path, artifact_path)
|
||||
else:
|
||||
mlflow.log_artifacts(local_path)
|
||||
client.log_artifacts(self.recorder_id, local_path, artifact_path)
|
||||
|
||||
def load_object(self, name):
|
||||
client = mlflow.tracking.MlflowClient()
|
||||
client = mlflow.tracking.MlflowClient(tracking_uri=self._uri)
|
||||
path = client.download_artifacts(self.recorder_id, name)
|
||||
try:
|
||||
with Path(path).open("rb") as f:
|
||||
@@ -229,3 +234,11 @@ class MLflowRecorder(Recorder):
|
||||
if self.artifact_uri is not None:
|
||||
return self.artifact_uri
|
||||
return mlflow.get_artifact_uri(artifact_path)
|
||||
|
||||
def check(self, name, path=None):
|
||||
client = mlflow.tracking.MlflowClient(tracking_uri=self._uri)
|
||||
artifacts = client.list_artifacts(self.recorder_id, path)
|
||||
for artifact in artifacts
|
||||
if name in artifact.path:
|
||||
return True
|
||||
return False
|
||||
@@ -22,3 +22,4 @@ scikit_learn==0.23.2
|
||||
torch==1.6.0
|
||||
tqdm==4.49.0
|
||||
yahooquery==2.2.7
|
||||
mlflow==1.11.0
|
||||
Reference in New Issue
Block a user