1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-04 19:41:00 +08:00

Update exp related and pytorch_nn

This commit is contained in:
Jactus
2020-11-09 16:42:21 +08:00
parent 9a826eefa3
commit 853410c16e
6 changed files with 297 additions and 157 deletions

View File

@@ -6,18 +6,20 @@ from __future__ import division
from __future__ import print_function
import os
import logging
import numpy as np
import pandas as pd
from sklearn.metrics import roc_auc_score, mean_squared_error
import logging
from ...utils import unpack_archive_with_buffer, save_multiple_parts_file, create_save_path, drop_nan_by_y_index
from ...log import get_module_logger, TimeInspector
import torch
import torch.nn as nn
import torch.optim as optim
from ...model.base import Model
from ...data.dataset import DatasetH
from ...data.dataset.handler import DataHandlerLP
from ...utils import unpack_archive_with_buffer, save_multiple_parts_file, create_save_path, drop_nan_by_y_index
from ...log import get_module_logger, TimeInspector
class DNNModelPytorch(Model):
@@ -144,20 +146,25 @@ class DNNModelPytorch(Model):
def fit(
self,
x_train,
y_train,
x_valid,
y_valid,
w_train=None,
w_valid=None,
dataset: DatasetH,
evals_result=dict(),
verbose=True,
save_path=None,
):
if w_train is None:
df_train, df_valid = dataset.prepare(
["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
)
x_train, y_train = df_train["feature"], df_train["label"]
x_valid, y_valid = df_valid["feature"], df_valid["label"]
try:
wdf_train, wdf_valid = dataset.prepare(
["train", "valid"], col_set=["weight"], data_key=DataHandlerLP.DK_L
)
w_train, w_valid = wdf_train["weight"], wdf_valid["weight"]
except:
w_train = pd.DataFrame(np.ones_like(y_train.values), index=y_train.index)
if w_valid is None:
w_valid = pd.DataFrame(np.ones_like(y_valid.values), index=y_valid.index)
save_path = create_save_path(save_path)
@@ -188,6 +195,7 @@ class DNNModelPytorch(Model):
w_val_auto = w_val_auto.cuda()
for step in range(self.max_steps):
self.logger.info(step)
if stop_steps >= self.early_stop_rounds:
if verbose:
self.logger.info("\tearly stop")
@@ -195,6 +203,7 @@ class DNNModelPytorch(Model):
loss = AverageMeter()
self.dnn_model.train()
self.train_optimizer.zero_grad()
self.logger.info("INIT")
choice = np.random.choice(train_num, self.batch_size)
x_batch_auto = x_train_values[choice]
@@ -264,10 +273,11 @@ class DNNModelPytorch(Model):
else:
raise NotImplementedError("loss {} is not supported!".format(loss_type))
def predict(self, x_test):
def predict(self, dataset):
if not self._fitted:
raise ValueError("model is not fitted yet!")
x_test = torch.from_numpy(x_test.values).float()
x_test_pd = dataset.prepare("test", col_set="feature")
x_test = torch.from_numpy(x_test_pd.values).float()
if self.use_gpu:
x_test = x_test.cuda()
self.dnn_model.eval()
@@ -277,13 +287,20 @@ class DNNModelPytorch(Model):
preds = self.dnn_model(x_test).detach().cpu().numpy()
else:
preds = self.dnn_model(x_test).detach().numpy()
return preds
return pd.Series(np.squeeze(preds), index=x_test_pd.index)
def score(self, x_test, y_test, w_test=None):
# Remove rows from x, y and w, which contain Nan in any columns in y_test.
df_test = dataset.prepare("test", col_set=["feature", "label"])
x_test, y_test = df_test["feature"], df_test["label"]
x_test, y_test, w_test = drop_nan_by_y_index(x_test, y_test, w_test)
preds = self.predict(x_test)
w_test_weight = None if w_test is None else w_test.values
try:
df_test = dataset.prepare("test", col_set=["weight"])
w_test = df_test["weight"]
w_test_weight = w_test.values
except:
w_test_weight = None
return self._scorer(y_test.values, preds, sample_weight=w_test_weight)
def save(self, filename, **kwargs):
@@ -303,7 +320,12 @@ class DNNModelPytorch(Model):
self.dnn_model.load_state_dict(torch.load(_model_path))
self._fitted = True
def finetune(self, x_train, y_train, x_valid, y_valid, w_train=None, w_valid=None, **kwargs):
def finetune(self, dataset, w_train=None, w_valid=None, **kwargs):
df_train, df_valid = dataset.prepare(
["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
)
x_train, y_train = df_train["feature"], df_train["label"]
x_valid, y_valid = df_valid["feature"], df_valid["label"]
self.fit(x_train, y_train, x_valid, y_valid, w_train=w_train, w_valid=w_valid, **kwargs)

View File

@@ -4,31 +4,32 @@
from contextlib import contextmanager
from .expm import MLflowExpManager
from ..utils import Wrapper
from ..config import C
class QlibRecorder:
"""
A global system that helps to manage the experiments.
"""
def __init__(self, exp_manager, uri):
def __init__(self, exp_manager):
self.exp_manager = exp_manager
self.uri = uri
self.uri = C["exp_uri"]
@contextmanager
def start(self, experiment_name):
run = self.start_exp(experiment_name)
try:
yield run
except:
self.end_exp() # end the experiment if something went wrong
self.end_exp()
except Exception as e:
self.end_exp("FAILED") # end the experiment if something went wrong
raise e
self.end_exp("FINISHED")
def start_exp(self, experiment_name=None):
return self.exp_manager.start_exp(experiment_name, self.uri)
def end_exp(self):
self.exp_manager.end_exp()
def end_exp(self, status):
self.exp_manager.end_exp(status)
def search_records(self, experiment_ids, **kwargs):
return self.exp_manager.search_records(experiment_ids, **kwargs)
@@ -45,11 +46,8 @@ class QlibRecorder:
def get_recorder(self):
return self.exp_manager.active_recorder
def save_object(self, data=None, name=None, local_path=None):
self.exp_manager.active_recorder.save_object(data, name, local_path)
def save_objects(self, data_name_list=None, local_path=None):
self.exp_manager.active_recorder.save_objects(data_name_list, local_path)
def save_objects(self, local_path=None, artifact_path=None, **kwargs):
self.exp_manager.active_recorder.save_objects(local_path, artifact_path, **kwargs)
def load_object(self, name):
return self.exp_manager.active_recorder.load_object(name)
@@ -63,8 +61,8 @@ class QlibRecorder:
def set_tags(self, **kwargs):
self.exp_manager.active_recorder.set_tags(**kwargs)
def delete_tag(self, key):
self.exp_manager.active_recorder.delete_tag(key)
def delete_tag(self, *key):
self.exp_manager.active_recorder.delete_tag(*key)
# global record

View File

@@ -14,7 +14,47 @@ class Experiment:
def __init__(self):
self.name = None
self.id = None
self.recorders = list()
self.active_recorder = None # only one recorder can running each time
self.recorders = dict() # recorder id -> object
def __repr__(self):
return str(self.info)
def __str__(self):
return str(self.info)
@property
def info(self):
output = dict()
output['class'] = "Experiment"
output['id'] = self.id
output['name'] = self.name
output['active_recorder'] = self.active_recorder.id
output['recorders'] = list(self.recorders.keys())
def start(self):
"""
Start the experiment.
Parameters
----------
Returns
-------
A running recorder instance.
"""
raise NotImplementedError(f"Please implement the `start` method.")
def end(self, status):
"""
End the experiment.
Parameters
----------
status : str
the status the recorder to be set with when ending (SCHEDULED, RUNNING, FINISHED, FAILED).
"""
raise NotImplementedError(f"Please implement the `end` method.")
def create_recorder(self):
"""
@@ -25,7 +65,7 @@ class Experiment:
Returns
-------
A recorder instance.
A recorder object.
"""
raise NotImplementedError(f"Please implement the `create_recorder` method.")
@@ -46,24 +86,40 @@ class Experiment:
Returns
-------
A pandas.DataFrame of records.
A pandas.DataFrame of records, where each metric, parameter, and tag
are expanded into their own columns named metrics.*, params.*, and tags.*
respectively. For records that don't have a particular metric, parameter, or tag, their
value will be (NumPy) Nan, None, or None respectively.
"""
raise NotImplementedError(f"Please implement the `search_records` method.")
def delete_recorder(self, rid):
def delete_recorder(self, recorder_id):
"""
Create a recorder for each experiment.
Parameters
----------
rid : str
recorder_id : str
the id of the recorder to be deleted.
"""
raise NotImplementedError(f"Please implement the `delete_recorder` method.")
def get_recorder(self, recorder_id=None, recorder_name=None):
"""
Get the current active Recorder.
Parameters
----------
recorder_id : str
the id of the recorder to be deleted.
recorder_name : str
the name of the recorder to be deleted.
Returns
-------
A recorder instance.
A recorder object.
"""
raise NotImplementedError(f"Please implement the `delete_recorder` method.")
raise NotImplementedError(f"Please implement the `get_recorder` method.")
class MLflowExperiment(Experiment):
@@ -71,9 +127,26 @@ class MLflowExperiment(Experiment):
Use mlflow to implement Experiment.
"""
def start(self):
# set up recorder
recorder = self.create_recorder()
self.active_recorder = recorder
# start the recorder
run = self.active_recorder.start_run()
# store the recorder
self.recorders[self.active_recorder.id] = recorder
return self.active_recorder
def end(self, status):
if self.active_recorder is not None:
self.active_recorder.end_run(status)
self.active_recorder = None
def create_recorder(self):
recorder = MLflowRecorder(self.id)
self.recorders.append(recorder)
num = len(self.recorders)
name = "Recorder_{}".format(num+1)
recorder = MLflowRecorder(name, self.id)
return recorder
def search_records(self, **kwargs):
@@ -81,8 +154,23 @@ class MLflowExperiment(Experiment):
run_view_type = 1 if kwargs.get("run_view_type") is None else kwargs.get("run_view_type")
max_results = 100000 if kwargs.get("max_results") is None else kwargs.get("max_results")
order_by = kwargs.get("order_by")
return mlflow.search_runs([self.experiment_id], filter_string, run_view_type, max_results, order_by)
return mlflow.search_runs([self.id], filter_string, run_view_type, max_results, order_by)
def delete_recorder(self, rid):
mlflow.delete_run(rid)
self.recorders = [r for r in self.recorders if r.recorder_id == rid]
def delete_recorder(self, recorder_id):
mlflow.delete_run(recorder_id)
self.recorders = [r for r in self.recorders if r.id == recorder_id]
def get_recorder(self, recorder_id=None, recorder_name=None):
if recorder_id is not None:
return self.recorders[recorder_id]
elif recorder_name is not None:
for rid in self.recorders:
if self.recorders[rid].name == recorder_name:
return self.recorders[rid]
elif self.active_recorder is None:
raise Exception('No valid active recorder exists. Please make sure the experiment is running.')
else:
logger.info(
"No experiment id or name is given. Return the current active experiment."
)
return self.active_recorder

View File

@@ -9,7 +9,7 @@ from .exp import MLflowExperiment
from .recorder import MLflowRecorder
from ..log import get_module_logger
logger = get_module_logger("workflow", "WARNING")
logger = get_module_logger("workflow", "INFO")
class ExpManager:
@@ -20,7 +20,7 @@ class ExpManager:
def __init__(self):
self.uri = None
self.active_recorder = None # only one recorder can running each time
self.active_experiment = None # only one experiment can running each time
self.experiments = dict() # store the experiment name --> Experiment object
def start_exp(self, experiment_name=None, uri=None, **kwargs):
@@ -39,7 +39,7 @@ class ExpManager:
controls whether run is nested in parent run.
Returns
An object wrapped by context manager.
An active recorder.
"""
raise NotImplementedError(f"Please implement the `start_exp` method.")
@@ -73,11 +73,14 @@ class ExpManager:
Returns
-------
A pandas.DataFrame of runs.
A pandas.DataFrame of records, where each metric, parameter, and tag
are expanded into their own columns named metrics.*, params.*, and tags.*
respectively. For records that don't have a particular metric, parameter, or tag, their
value will be (NumPy) Nan, None, or None respectively.
"""
raise NotImplementedError(f"Please implement the `search_records` method.")
def __create_exp(self, experiment_name, artifact_location=None):
def create_exp(self, experiment_name, artifact_location=None):
"""
Create an experiment.
@@ -133,19 +136,6 @@ class ExpManager:
"""
return self.uri
def get_recorder(self):
"""
Get the current active Recorder.
Parameters
----------
Returns
-------
An Recorder object.
"""
return self.active_recorder
class MLflowExpManager(ExpManager):
"""
@@ -158,26 +148,27 @@ class MLflowExpManager(ExpManager):
def start_exp(self, experiment_name=None, uri=None):
# create experiment
experiment = self.__create_exp(experiment_name, uri)
# set up recorder
recorder = experiment.create_recorder()
self.active_recorder = recorder
experiment = self.create_exp(experiment_name, uri)
# set up active experiment
self.active_experiment = experiment
# store the experiment
self.experiments[experiment_name] = experiment
# start the experiment
self.active_experiment.start()
return self.active_recorder.start_run(experiment_id=experiment.id)
return self.active_experiment
def end_exp(self):
if self.active_recorder is not None:
self.active_recorder.end_run()
self.active_recorder = None
def end_exp(self, status):
if self.active_experiment is not None:
self.active_experiment.end(status)
self.active_experiment = None
def __create_exp(self, experiment_name=None, uri=None):
def create_exp(self, experiment_name=None, uri=None):
# init experiment
experiment = MLflowExperiment()
# set the tracking uri
if uri is None:
logger.warning(
logger.info(
"No tracking URI is provided. The default tracking URI is set as `mlruns` under the working directory."
)
else:
@@ -185,7 +176,7 @@ class MLflowExpManager(ExpManager):
mlflow.set_tracking_uri(self.uri)
# start the experiment
if experiment_name is None:
logger.warning("No experiment name provided. The default experiment name is set as `experiment`.")
logger.info("No experiment name provided. The default experiment name is set as `experiment`.")
experiment_id = mlflow.create_experiment("experiment")
# set the active experiment
mlflow.set_experiment("experiment")
@@ -216,17 +207,19 @@ class MLflowExpManager(ExpManager):
return mlflow.search_runs(experiment_ids, filter_string, run_view_type, max_results, order_by)
def get_exp(self, experiment_id=None, experiment_name=None):
assert (
experiment_id is not None or experiment_name is not None
), "Please provide at least one of the experiment id or name to retrieve an experiment."
if experiment_name is not None:
return self.experiments[experiment_name]
elif experiment_id is not None:
for name in self.experiments:
if self.experiments[name].id == experiment_id:
return self.experiments[name]
elif self.active_experiment is None:
raise Exception('No valid active experiment exists. Please make sure experiment manager is running.')
else:
raise Exception("No valid experiment is found. Please make sure the id and name are correctly given.")
logger.info(
"No experiment id or name is given. Return the current active experiment."
)
return self.active_experiment
def delete_exp(self, experiment_id):
mlflow.delete_experiment(experiment_id)

View File

@@ -11,6 +11,11 @@ from ..utils import init_instance_by_config, get_module_by_module_path
class RecordTemp:
"""
This is the Records Template class that enables user to generate experiment results such as IC and
backtest in a certain format.
"""
def __init__(self, *args, **kwargs):
pass
@@ -24,10 +29,23 @@ class RecordTemp:
Return
------
The generated records.
"""
raise NotImplementedError(f"Please implement the `generate` method.")
def load(self, **kwargs):
"""
Load the stored records.
Parameters
----------
kwargs
Return
------
The stored records.
"""
raise NotImplementedError(f"Please implement the `load` method.")
def check(self, **kwargs):
"""
Check if the records is properly generated and saved.
@@ -35,12 +53,20 @@ class RecordTemp:
Parameters
----------
kwargs
Return
------
Boolean: whether the records are stored properly.
"""
raise NotImplementedError(f"Please implement the `check` method.")
# TODO: this can only be run under R's running experiment.
class SignalRecord(RecordTemp):
"""
This is the Signal Record class that generates the signal prediction.
"""
def __init__(self, model, dataset, recorder, **kwargs):
super(SignalRecord, self).__init__()
self.model = model
@@ -61,12 +87,16 @@ class SignalRecord(RecordTemp):
raise Exception("Something went wrong when loading the saved object.")
def check(self, **kwargs):
return self.recorder.check("pred.pkl")
artifacts = self.recorder.list_artifacts()
for artifact in artifacts:
if "pred.pkl" in artifact.path:
return True
return False
# TODO
class SigAnaRecord(SignalRecord):
def __init__(self, recorder, **kwargs):
def __init__(self, recorder, config, **kwargs):
pass
def generate(self):
@@ -80,13 +110,16 @@ class SigAnaRecord(SignalRecord):
class PortAnaRecord(SignalRecord):
def __init__(self, recorder, STRATEGY_CONFIG, BACKTEST_CONFIG, **kwargs):
"""
This is the Portfolio Analysis Record class that generates the results such as those of backtest.
"""
def __init__(self, recorder, config, **kwargs):
self.recorder = recorder
self.STRATEGY_CONFIG = STRATEGY_CONFIG
self.BACKTEST_CONFIG = BACKTEST_CONFIG
module = get_module_by_module_path("qlib.contrib.strategy")
self.strategy = init_instance_by_config(STRATEGY_CONFIG, module)
self.artifact_path = Path("portfolio_analysis").resolve()
self.strategy_config = config['strategy']
self.backtest_config = config['backtest']
self.strategy = init_instance_by_config(self.strategy_config)
self.artifact_path = "portfolio_analysis"
def generate(self, **kwargs):
"""
@@ -121,4 +154,8 @@ class PortAnaRecord(SignalRecord):
raise Exception("Something went wrong when loading the saved object.")
def check(self):
return self.recorder.check("port_analysis.pkl", self.artifact_path)
artifacts = self.recorder.list_artifacts(self.artifact_path)
for artifact in artifacts:
if "port_analysis.pkl" in artifact.path:
return True
return False

View File

@@ -11,19 +11,37 @@ class Recorder:
"""
This is the `Recorder` class for logging the experiments. The API is designed similar to mlflow.
(The link: https://mlflow.org/docs/latest/python_api/mlflow.html)
The status of the recorder can be SCHEDULED, RUNNING, FINISHED, FAILED.
"""
def __init__(self, experiment_id):
def __init__(self, name, experiment_id):
self.id = None
self.name = name
self.experiment_id = experiment_id
self.recorder_id = None
self.recorder_name = None
self.status = "SCHEDULED"
def __repr__(self):
return str(self.info)
def __str__(self):
return str(self.info)
@property
def info(self):
output = dict()
output['class'] = "Recorder"
output['id'] = self.id
output['name'] = self.name
output['experiment_id'] = self.experiment_id
output['status'] = self.status
def set_recorder_name(self, rname):
self.recorder_name = rname
def save_object(self, data=None, name=None, local_path=None, artifact_path=None):
def save_objects(self, local_path=None, artifact_path=None, **kwargs):
"""
Save object such as prediction file or model checkpoints to the artifact URI.
Save objects such as prediction file or model checkpoints to the artifact URI.
Parameters
----------
@@ -31,19 +49,6 @@ class Recorder:
the data to be saved.
name : str
name of the file to be saved.
local_path : str
if provided, them save the file or directory to the artifact URI.
artifact_path=None : str
the relative path for the artifact to be stored in the URI.
"""
raise NotImplementedError(f"Please implement the `save_object` method.")
def save_objects(self, data_name_list=None, local_path=None, artifact_path=None):
"""
Save objects such as prediction file or model checkpoints to the artifact URI.
Parameters
----------
data_name_list : list
list of (data, name) pairs
local_path : str
@@ -68,21 +73,13 @@ class Recorder:
"""
raise NotImplementedError(f"Please implement the `load_object` method.")
def start_run(self, run_id=None, experiment_id=None, run_name=None, nested=False):
def start_run(self):
"""
Start running the Recorder. The return value can be used as a context manager within a `with` block;
Start running or resuming the Recorder. The return value can be used as a context manager within a `with` block;
otherwise, you must call end_run() to terminate the current run. (See `ActiveRun` class in mlflow)
Parameters
----------
run_id : str
id of the active Recorder.
experiment_id : str
id of the active experiment.
run_name : str
name of the Recorder.
nested : boolean
controls whether run is nested in parent run.
Returns
-------
@@ -127,18 +124,33 @@ class Recorder:
keyword arguments
key, value pair to be logged as tags.
"""
raise NotImplementedError(f"Please implement the `log_tags` method.")
raise NotImplementedError(f"Please implement the `set_tags` method.")
def delete_tag(self, key):
def delete_tags(self, *keys):
"""
Delete a tag from a run.
Delete some tags from a run.
Parameters
----------
key : str
the name of the tag to be deleted.
keys : series of strs of the keys
all the name of the tag to be deleted.
"""
raise NotImplementedError(f"Please implement the `delete_tag` method.")
raise NotImplementedError(f"Please implement the `delete_tags` method.")
def list_artifacts(self, artifact_path=None):
"""
Delete some tags from a run.
Parameters
----------
artifact_path=None : str
the relative path for the artifact to be stored in the URI.
Returns
-------
A list of artifacts information (name, path, etc.) that being stored.
"""
raise NotImplementedError(f"Please implement the `list_artifacts` method.")
class MLflowRecorder(Recorder):
@@ -149,51 +161,43 @@ class MLflowRecorder(Recorder):
use file manager to help maintain the objects in the project.
"""
def __init__(self, experiment_id):
super(MLflowRecorder, self).__init__(experiment_id)
def __init__(self, name, experiment_id):
super(MLflowRecorder, self).__init__(name, experiment_id)
self.fm = None
self.temp_dir = None
def start_run(self, run_id=None, experiment_id=None, run_name=None, nested=False):
if run_id is None:
run_id = self.recorder_id
if experiment_id is None:
experiment_id = self.experiment_id
if run_name is None:
run_name = self.recorder_name
def start_run(self):
# start the run
run = mlflow.start_run(run_id, experiment_id, run_name, nested)
run = mlflow.start_run(self.id, self.experiment_id, self.name)
# save the run id and artifact_uri
self.recorder_id = run.info.run_id
self.id = run.info.run_id
self.artifact_uri = run.info.artifact_uri
self._uri = mlflow.get_tracking_uri() # Fix!!! : this is not proper to have uri in recorder
# set up file manager for saving objects
self.temp_dir = tempfile.mkdtemp()
self.fm = FileManager(Path(self.temp_dir).absolute())
self.status = "RUNNING"
return run
def end_run(self):
mlflow.end_run()
def end_run(self, status):
mlflow.end_run(status)
self.status = status
shutil.rmtree(self.temp_dir)
def save_object(self, data=None, name=None, local_path=None, artifact_path=None):
def save_objects(self, data_name_list=None, local_path=None, artifact_path=None, **kwargs):
client = mlflow.tracking.MlflowClient(tracking_uri=self._uri)
if local_path is None:
assert data is not None and name is not None, "Please provide data and name input."
if local_path is not None:
client.log_artifacts(self.id, local_path, artifact_path)
elif kwargs.get('data') is not None and kwargs.get('name') is not None:
data, name = kwargs.get('data'), kwargs.get('name')
self.fm.save_obj(data, name)
client.log_artifact(self.recorder_id, self.fm.path / name, artifact_path)
else:
assert local_path is not None, "Please provide a valid local path for the "
client.log_artifact(self.recorder_id, local_path, artifact_path)
def save_objects(self, data_name_list=None, local_path=None, artifact_path=None):
client = mlflow.tracking.MlflowClient(tracking_uri=self._uri)
if local_path is None:
assert data_name_list is not None, "Please provide data_name_list input."
client.log_artifact(self.id, self.fm.path / name, artifact_path)
elif kwargs.get('data_name_list') is not None:
data_name_list = kwargs.get('data_name_list')
self.fm.save_objs(data_name_list)
client.log_artifacts(self.recorder_id, self.fm.path, artifact_path)
client.log_artifacts(self.id, self.fm.path, artifact_path)
else:
client.log_artifacts(self.recorder_id, local_path, artifact_path)
raise Exception('Please provide valid arguments in order to save object properly.')
def load_object(self, name):
client = mlflow.tracking.MlflowClient(tracking_uri=self._uri)
@@ -227,18 +231,16 @@ class MLflowRecorder(Recorder):
else:
mlflow.set_tags(dict(kwargs))
def delete_tag(self, key):
mlflow.delete_tag(key)
def delete_tags(self, *keys):
for count, key in enumerate(keys):
mlflow.delete_tag(key)
def get_artifact_uri(self, artifact_path=None):
if self.artifact_uri is not None:
return self.artifact_uri
return mlflow.get_artifact_uri(artifact_path)
def check(self, name, path=None):
def list_artifacts(self, artifact_path=None):
client = mlflow.tracking.MlflowClient(tracking_uri=self._uri)
artifacts = client.list_artifacts(self.recorder_id, path)
for artifact in artifacts:
if name in artifact.path:
return True
return False
artifacts = client.list_artifacts(self.id, path)
return artifacts