From 853410c16eb79708371b87a4af37de18840e03f2 Mon Sep 17 00:00:00 2001 From: Jactus Date: Mon, 9 Nov 2020 16:42:21 +0800 Subject: [PATCH] Update exp related and pytorch_nn --- qlib/contrib/model/pytorch_nn.py | 54 ++++++++---- qlib/workflow/__init__.py | 28 +++---- qlib/workflow/exp.py | 114 ++++++++++++++++++++++--- qlib/workflow/expm.py | 61 ++++++-------- qlib/workflow/record_temp.py | 57 ++++++++++--- qlib/workflow/recorder.py | 140 ++++++++++++++++--------------- 6 files changed, 297 insertions(+), 157 deletions(-) diff --git a/qlib/contrib/model/pytorch_nn.py b/qlib/contrib/model/pytorch_nn.py index b5bf91472..1acb5c843 100644 --- a/qlib/contrib/model/pytorch_nn.py +++ b/qlib/contrib/model/pytorch_nn.py @@ -6,18 +6,20 @@ from __future__ import division from __future__ import print_function import os +import logging import numpy as np import pandas as pd from sklearn.metrics import roc_auc_score, mean_squared_error -import logging -from ...utils import unpack_archive_with_buffer, save_multiple_parts_file, create_save_path, drop_nan_by_y_index -from ...log import get_module_logger, TimeInspector import torch import torch.nn as nn import torch.optim as optim from ...model.base import Model +from ...data.dataset import DatasetH +from ...data.dataset.handler import DataHandlerLP +from ...utils import unpack_archive_with_buffer, save_multiple_parts_file, create_save_path, drop_nan_by_y_index +from ...log import get_module_logger, TimeInspector class DNNModelPytorch(Model): @@ -144,20 +146,25 @@ class DNNModelPytorch(Model): def fit( self, - x_train, - y_train, - x_valid, - y_valid, - w_train=None, - w_valid=None, + dataset: DatasetH, evals_result=dict(), verbose=True, save_path=None, ): - if w_train is None: + df_train, df_valid = dataset.prepare( + ["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L + ) + x_train, y_train = df_train["feature"], df_train["label"] + x_valid, y_valid = df_valid["feature"], df_valid["label"] + + try: + wdf_train, wdf_valid = dataset.prepare( + ["train", "valid"], col_set=["weight"], data_key=DataHandlerLP.DK_L + ) + w_train, w_valid = wdf_train["weight"], wdf_valid["weight"] + except: w_train = pd.DataFrame(np.ones_like(y_train.values), index=y_train.index) - if w_valid is None: w_valid = pd.DataFrame(np.ones_like(y_valid.values), index=y_valid.index) save_path = create_save_path(save_path) @@ -188,6 +195,7 @@ class DNNModelPytorch(Model): w_val_auto = w_val_auto.cuda() for step in range(self.max_steps): + self.logger.info(step) if stop_steps >= self.early_stop_rounds: if verbose: self.logger.info("\tearly stop") @@ -195,6 +203,7 @@ class DNNModelPytorch(Model): loss = AverageMeter() self.dnn_model.train() self.train_optimizer.zero_grad() + self.logger.info("INIT") choice = np.random.choice(train_num, self.batch_size) x_batch_auto = x_train_values[choice] @@ -264,10 +273,11 @@ class DNNModelPytorch(Model): else: raise NotImplementedError("loss {} is not supported!".format(loss_type)) - def predict(self, x_test): + def predict(self, dataset): if not self._fitted: raise ValueError("model is not fitted yet!") - x_test = torch.from_numpy(x_test.values).float() + x_test_pd = dataset.prepare("test", col_set="feature") + x_test = torch.from_numpy(x_test_pd.values).float() if self.use_gpu: x_test = x_test.cuda() self.dnn_model.eval() @@ -277,13 +287,20 @@ class DNNModelPytorch(Model): preds = self.dnn_model(x_test).detach().cpu().numpy() else: preds = self.dnn_model(x_test).detach().numpy() - return preds + return pd.Series(np.squeeze(preds), index=x_test_pd.index) def score(self, x_test, y_test, w_test=None): # Remove rows from x, y and w, which contain Nan in any columns in y_test. + df_test = dataset.prepare("test", col_set=["feature", "label"]) + x_test, y_test = df_test["feature"], df_test["label"] x_test, y_test, w_test = drop_nan_by_y_index(x_test, y_test, w_test) preds = self.predict(x_test) - w_test_weight = None if w_test is None else w_test.values + try: + df_test = dataset.prepare("test", col_set=["weight"]) + w_test = df_test["weight"] + w_test_weight = w_test.values + except: + w_test_weight = None return self._scorer(y_test.values, preds, sample_weight=w_test_weight) def save(self, filename, **kwargs): @@ -303,7 +320,12 @@ class DNNModelPytorch(Model): self.dnn_model.load_state_dict(torch.load(_model_path)) self._fitted = True - def finetune(self, x_train, y_train, x_valid, y_valid, w_train=None, w_valid=None, **kwargs): + def finetune(self, dataset, w_train=None, w_valid=None, **kwargs): + df_train, df_valid = dataset.prepare( + ["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L + ) + x_train, y_train = df_train["feature"], df_train["label"] + x_valid, y_valid = df_valid["feature"], df_valid["label"] self.fit(x_train, y_train, x_valid, y_valid, w_train=w_train, w_valid=w_valid, **kwargs) diff --git a/qlib/workflow/__init__.py b/qlib/workflow/__init__.py index 91880b281..a941ed7cf 100644 --- a/qlib/workflow/__init__.py +++ b/qlib/workflow/__init__.py @@ -4,31 +4,32 @@ from contextlib import contextmanager from .expm import MLflowExpManager from ..utils import Wrapper - +from ..config import C class QlibRecorder: """ A global system that helps to manage the experiments. """ - def __init__(self, exp_manager, uri): + def __init__(self, exp_manager): self.exp_manager = exp_manager - self.uri = uri + self.uri = C["exp_uri"] @contextmanager def start(self, experiment_name): run = self.start_exp(experiment_name) try: yield run - except: - self.end_exp() # end the experiment if something went wrong - self.end_exp() + except Exception as e: + self.end_exp("FAILED") # end the experiment if something went wrong + raise e + self.end_exp("FINISHED") def start_exp(self, experiment_name=None): return self.exp_manager.start_exp(experiment_name, self.uri) - def end_exp(self): - self.exp_manager.end_exp() + def end_exp(self, status): + self.exp_manager.end_exp(status) def search_records(self, experiment_ids, **kwargs): return self.exp_manager.search_records(experiment_ids, **kwargs) @@ -45,11 +46,8 @@ class QlibRecorder: def get_recorder(self): return self.exp_manager.active_recorder - def save_object(self, data=None, name=None, local_path=None): - self.exp_manager.active_recorder.save_object(data, name, local_path) - - def save_objects(self, data_name_list=None, local_path=None): - self.exp_manager.active_recorder.save_objects(data_name_list, local_path) + def save_objects(self, local_path=None, artifact_path=None, **kwargs): + self.exp_manager.active_recorder.save_objects(local_path, artifact_path, **kwargs) def load_object(self, name): return self.exp_manager.active_recorder.load_object(name) @@ -63,8 +61,8 @@ class QlibRecorder: def set_tags(self, **kwargs): self.exp_manager.active_recorder.set_tags(**kwargs) - def delete_tag(self, key): - self.exp_manager.active_recorder.delete_tag(key) + def delete_tag(self, *key): + self.exp_manager.active_recorder.delete_tag(*key) # global record diff --git a/qlib/workflow/exp.py b/qlib/workflow/exp.py index 9b5517471..86163c0ea 100644 --- a/qlib/workflow/exp.py +++ b/qlib/workflow/exp.py @@ -14,7 +14,47 @@ class Experiment: def __init__(self): self.name = None self.id = None - self.recorders = list() + self.active_recorder = None # only one recorder can running each time + self.recorders = dict() # recorder id -> object + + def __repr__(self): + return str(self.info) + + def __str__(self): + return str(self.info) + + @property + def info(self): + output = dict() + output['class'] = "Experiment" + output['id'] = self.id + output['name'] = self.name + output['active_recorder'] = self.active_recorder.id + output['recorders'] = list(self.recorders.keys()) + + def start(self): + """ + Start the experiment. + + Parameters + ---------- + + Returns + ------- + A running recorder instance. + """ + raise NotImplementedError(f"Please implement the `start` method.") + + def end(self, status): + """ + End the experiment. + + Parameters + ---------- + status : str + the status the recorder to be set with when ending (SCHEDULED, RUNNING, FINISHED, FAILED). + """ + raise NotImplementedError(f"Please implement the `end` method.") def create_recorder(self): """ @@ -25,7 +65,7 @@ class Experiment: Returns ------- - A recorder instance. + A recorder object. """ raise NotImplementedError(f"Please implement the `create_recorder` method.") @@ -46,24 +86,40 @@ class Experiment: Returns ------- - A pandas.DataFrame of records. + A pandas.DataFrame of records, where each metric, parameter, and tag + are expanded into their own columns named metrics.*, params.*, and tags.* + respectively. For records that don't have a particular metric, parameter, or tag, their + value will be (NumPy) Nan, None, or None respectively. """ raise NotImplementedError(f"Please implement the `search_records` method.") - def delete_recorder(self, rid): + def delete_recorder(self, recorder_id): """ Create a recorder for each experiment. Parameters ---------- - rid : str + recorder_id : str the id of the recorder to be deleted. + """ + raise NotImplementedError(f"Please implement the `delete_recorder` method.") + + def get_recorder(self, recorder_id=None, recorder_name=None): + """ + Get the current active Recorder. + + Parameters + ---------- + recorder_id : str + the id of the recorder to be deleted. + recorder_name : str + the name of the recorder to be deleted. Returns ------- - A recorder instance. + A recorder object. """ - raise NotImplementedError(f"Please implement the `delete_recorder` method.") + raise NotImplementedError(f"Please implement the `get_recorder` method.") class MLflowExperiment(Experiment): @@ -71,9 +127,26 @@ class MLflowExperiment(Experiment): Use mlflow to implement Experiment. """ + def start(self): + # set up recorder + recorder = self.create_recorder() + self.active_recorder = recorder + # start the recorder + run = self.active_recorder.start_run() + # store the recorder + self.recorders[self.active_recorder.id] = recorder + + return self.active_recorder + + def end(self, status): + if self.active_recorder is not None: + self.active_recorder.end_run(status) + self.active_recorder = None + def create_recorder(self): - recorder = MLflowRecorder(self.id) - self.recorders.append(recorder) + num = len(self.recorders) + name = "Recorder_{}".format(num+1) + recorder = MLflowRecorder(name, self.id) return recorder def search_records(self, **kwargs): @@ -81,8 +154,23 @@ class MLflowExperiment(Experiment): run_view_type = 1 if kwargs.get("run_view_type") is None else kwargs.get("run_view_type") max_results = 100000 if kwargs.get("max_results") is None else kwargs.get("max_results") order_by = kwargs.get("order_by") - return mlflow.search_runs([self.experiment_id], filter_string, run_view_type, max_results, order_by) + return mlflow.search_runs([self.id], filter_string, run_view_type, max_results, order_by) - def delete_recorder(self, rid): - mlflow.delete_run(rid) - self.recorders = [r for r in self.recorders if r.recorder_id == rid] + def delete_recorder(self, recorder_id): + mlflow.delete_run(recorder_id) + self.recorders = [r for r in self.recorders if r.id == recorder_id] + + def get_recorder(self, recorder_id=None, recorder_name=None): + if recorder_id is not None: + return self.recorders[recorder_id] + elif recorder_name is not None: + for rid in self.recorders: + if self.recorders[rid].name == recorder_name: + return self.recorders[rid] + elif self.active_recorder is None: + raise Exception('No valid active recorder exists. Please make sure the experiment is running.') + else: + logger.info( + "No experiment id or name is given. Return the current active experiment." + ) + return self.active_recorder \ No newline at end of file diff --git a/qlib/workflow/expm.py b/qlib/workflow/expm.py index d5c0c247e..e81da0fcb 100644 --- a/qlib/workflow/expm.py +++ b/qlib/workflow/expm.py @@ -9,7 +9,7 @@ from .exp import MLflowExperiment from .recorder import MLflowRecorder from ..log import get_module_logger -logger = get_module_logger("workflow", "WARNING") +logger = get_module_logger("workflow", "INFO") class ExpManager: @@ -20,7 +20,7 @@ class ExpManager: def __init__(self): self.uri = None - self.active_recorder = None # only one recorder can running each time + self.active_experiment = None # only one experiment can running each time self.experiments = dict() # store the experiment name --> Experiment object def start_exp(self, experiment_name=None, uri=None, **kwargs): @@ -39,7 +39,7 @@ class ExpManager: controls whether run is nested in parent run. Returns - An object wrapped by context manager. + An active recorder. """ raise NotImplementedError(f"Please implement the `start_exp` method.") @@ -73,11 +73,14 @@ class ExpManager: Returns ------- - A pandas.DataFrame of runs. + A pandas.DataFrame of records, where each metric, parameter, and tag + are expanded into their own columns named metrics.*, params.*, and tags.* + respectively. For records that don't have a particular metric, parameter, or tag, their + value will be (NumPy) Nan, None, or None respectively. """ raise NotImplementedError(f"Please implement the `search_records` method.") - def __create_exp(self, experiment_name, artifact_location=None): + def create_exp(self, experiment_name, artifact_location=None): """ Create an experiment. @@ -133,19 +136,6 @@ class ExpManager: """ return self.uri - def get_recorder(self): - """ - Get the current active Recorder. - - Parameters - ---------- - - Returns - ------- - An Recorder object. - """ - return self.active_recorder - class MLflowExpManager(ExpManager): """ @@ -158,26 +148,27 @@ class MLflowExpManager(ExpManager): def start_exp(self, experiment_name=None, uri=None): # create experiment - experiment = self.__create_exp(experiment_name, uri) - # set up recorder - recorder = experiment.create_recorder() - self.active_recorder = recorder + experiment = self.create_exp(experiment_name, uri) + # set up active experiment + self.active_experiment = experiment # store the experiment self.experiments[experiment_name] = experiment + # start the experiment + self.active_experiment.start() - return self.active_recorder.start_run(experiment_id=experiment.id) + return self.active_experiment - def end_exp(self): - if self.active_recorder is not None: - self.active_recorder.end_run() - self.active_recorder = None + def end_exp(self, status): + if self.active_experiment is not None: + self.active_experiment.end(status) + self.active_experiment = None - def __create_exp(self, experiment_name=None, uri=None): + def create_exp(self, experiment_name=None, uri=None): # init experiment experiment = MLflowExperiment() # set the tracking uri if uri is None: - logger.warning( + logger.info( "No tracking URI is provided. The default tracking URI is set as `mlruns` under the working directory." ) else: @@ -185,7 +176,7 @@ class MLflowExpManager(ExpManager): mlflow.set_tracking_uri(self.uri) # start the experiment if experiment_name is None: - logger.warning("No experiment name provided. The default experiment name is set as `experiment`.") + logger.info("No experiment name provided. The default experiment name is set as `experiment`.") experiment_id = mlflow.create_experiment("experiment") # set the active experiment mlflow.set_experiment("experiment") @@ -216,17 +207,19 @@ class MLflowExpManager(ExpManager): return mlflow.search_runs(experiment_ids, filter_string, run_view_type, max_results, order_by) def get_exp(self, experiment_id=None, experiment_name=None): - assert ( - experiment_id is not None or experiment_name is not None - ), "Please provide at least one of the experiment id or name to retrieve an experiment." if experiment_name is not None: return self.experiments[experiment_name] elif experiment_id is not None: for name in self.experiments: if self.experiments[name].id == experiment_id: return self.experiments[name] + elif self.active_experiment is None: + raise Exception('No valid active experiment exists. Please make sure experiment manager is running.') else: - raise Exception("No valid experiment is found. Please make sure the id and name are correctly given.") + logger.info( + "No experiment id or name is given. Return the current active experiment." + ) + return self.active_experiment def delete_exp(self, experiment_id): mlflow.delete_experiment(experiment_id) diff --git a/qlib/workflow/record_temp.py b/qlib/workflow/record_temp.py index e45ef47b6..cf3a86f7f 100644 --- a/qlib/workflow/record_temp.py +++ b/qlib/workflow/record_temp.py @@ -11,6 +11,11 @@ from ..utils import init_instance_by_config, get_module_by_module_path class RecordTemp: + """ + This is the Records Template class that enables user to generate experiment results such as IC and + backtest in a certain format. + """ + def __init__(self, *args, **kwargs): pass @@ -24,10 +29,23 @@ class RecordTemp: Return ------ - The generated records. """ raise NotImplementedError(f"Please implement the `generate` method.") + def load(self, **kwargs): + """ + Load the stored records. + + Parameters + ---------- + kwargs + + Return + ------ + The stored records. + """ + raise NotImplementedError(f"Please implement the `load` method.") + def check(self, **kwargs): """ Check if the records is properly generated and saved. @@ -35,12 +53,20 @@ class RecordTemp: Parameters ---------- kwargs + + Return + ------ + Boolean: whether the records are stored properly. """ raise NotImplementedError(f"Please implement the `check` method.") # TODO: this can only be run under R's running experiment. class SignalRecord(RecordTemp): + """ + This is the Signal Record class that generates the signal prediction. + """ + def __init__(self, model, dataset, recorder, **kwargs): super(SignalRecord, self).__init__() self.model = model @@ -61,12 +87,16 @@ class SignalRecord(RecordTemp): raise Exception("Something went wrong when loading the saved object.") def check(self, **kwargs): - return self.recorder.check("pred.pkl") + artifacts = self.recorder.list_artifacts() + for artifact in artifacts: + if "pred.pkl" in artifact.path: + return True + return False # TODO class SigAnaRecord(SignalRecord): - def __init__(self, recorder, **kwargs): + def __init__(self, recorder, config, **kwargs): pass def generate(self): @@ -80,13 +110,16 @@ class SigAnaRecord(SignalRecord): class PortAnaRecord(SignalRecord): - def __init__(self, recorder, STRATEGY_CONFIG, BACKTEST_CONFIG, **kwargs): + """ + This is the Portfolio Analysis Record class that generates the results such as those of backtest. + """ + + def __init__(self, recorder, config, **kwargs): self.recorder = recorder - self.STRATEGY_CONFIG = STRATEGY_CONFIG - self.BACKTEST_CONFIG = BACKTEST_CONFIG - module = get_module_by_module_path("qlib.contrib.strategy") - self.strategy = init_instance_by_config(STRATEGY_CONFIG, module) - self.artifact_path = Path("portfolio_analysis").resolve() + self.strategy_config = config['strategy'] + self.backtest_config = config['backtest'] + self.strategy = init_instance_by_config(self.strategy_config) + self.artifact_path = "portfolio_analysis" def generate(self, **kwargs): """ @@ -121,4 +154,8 @@ class PortAnaRecord(SignalRecord): raise Exception("Something went wrong when loading the saved object.") def check(self): - return self.recorder.check("port_analysis.pkl", self.artifact_path) + artifacts = self.recorder.list_artifacts(self.artifact_path) + for artifact in artifacts: + if "port_analysis.pkl" in artifact.path: + return True + return False diff --git a/qlib/workflow/recorder.py b/qlib/workflow/recorder.py index 307a740b6..157e29347 100644 --- a/qlib/workflow/recorder.py +++ b/qlib/workflow/recorder.py @@ -11,19 +11,37 @@ class Recorder: """ This is the `Recorder` class for logging the experiments. The API is designed similar to mlflow. (The link: https://mlflow.org/docs/latest/python_api/mlflow.html) + + The status of the recorder can be SCHEDULED, RUNNING, FINISHED, FAILED. """ - def __init__(self, experiment_id): + def __init__(self, name, experiment_id): + self.id = None + self.name = name self.experiment_id = experiment_id - self.recorder_id = None - self.recorder_name = None + self.status = "SCHEDULED" + + def __repr__(self): + return str(self.info) + + def __str__(self): + return str(self.info) + + @property + def info(self): + output = dict() + output['class'] = "Recorder" + output['id'] = self.id + output['name'] = self.name + output['experiment_id'] = self.experiment_id + output['status'] = self.status def set_recorder_name(self, rname): self.recorder_name = rname - def save_object(self, data=None, name=None, local_path=None, artifact_path=None): + def save_objects(self, local_path=None, artifact_path=None, **kwargs): """ - Save object such as prediction file or model checkpoints to the artifact URI. + Save objects such as prediction file or model checkpoints to the artifact URI. Parameters ---------- @@ -31,19 +49,6 @@ class Recorder: the data to be saved. name : str name of the file to be saved. - local_path : str - if provided, them save the file or directory to the artifact URI. - artifact_path=None : str - the relative path for the artifact to be stored in the URI. - """ - raise NotImplementedError(f"Please implement the `save_object` method.") - - def save_objects(self, data_name_list=None, local_path=None, artifact_path=None): - """ - Save objects such as prediction file or model checkpoints to the artifact URI. - - Parameters - ---------- data_name_list : list list of (data, name) pairs local_path : str @@ -68,21 +73,13 @@ class Recorder: """ raise NotImplementedError(f"Please implement the `load_object` method.") - def start_run(self, run_id=None, experiment_id=None, run_name=None, nested=False): + def start_run(self): """ - Start running the Recorder. The return value can be used as a context manager within a `with` block; + Start running or resuming the Recorder. The return value can be used as a context manager within a `with` block; otherwise, you must call end_run() to terminate the current run. (See `ActiveRun` class in mlflow) Parameters ---------- - run_id : str - id of the active Recorder. - experiment_id : str - id of the active experiment. - run_name : str - name of the Recorder. - nested : boolean - controls whether run is nested in parent run. Returns ------- @@ -127,18 +124,33 @@ class Recorder: keyword arguments key, value pair to be logged as tags. """ - raise NotImplementedError(f"Please implement the `log_tags` method.") + raise NotImplementedError(f"Please implement the `set_tags` method.") - def delete_tag(self, key): + def delete_tags(self, *keys): """ - Delete a tag from a run. + Delete some tags from a run. Parameters ---------- - key : str - the name of the tag to be deleted. + keys : series of strs of the keys + all the name of the tag to be deleted. """ - raise NotImplementedError(f"Please implement the `delete_tag` method.") + raise NotImplementedError(f"Please implement the `delete_tags` method.") + + def list_artifacts(self, artifact_path=None): + """ + Delete some tags from a run. + + Parameters + ---------- + artifact_path=None : str + the relative path for the artifact to be stored in the URI. + + Returns + ------- + A list of artifacts information (name, path, etc.) that being stored. + """ + raise NotImplementedError(f"Please implement the `list_artifacts` method.") class MLflowRecorder(Recorder): @@ -149,51 +161,43 @@ class MLflowRecorder(Recorder): use file manager to help maintain the objects in the project. """ - def __init__(self, experiment_id): - super(MLflowRecorder, self).__init__(experiment_id) + def __init__(self, name, experiment_id): + super(MLflowRecorder, self).__init__(name, experiment_id) self.fm = None self.temp_dir = None - def start_run(self, run_id=None, experiment_id=None, run_name=None, nested=False): - if run_id is None: - run_id = self.recorder_id - if experiment_id is None: - experiment_id = self.experiment_id - if run_name is None: - run_name = self.recorder_name + def start_run(self): # start the run - run = mlflow.start_run(run_id, experiment_id, run_name, nested) + run = mlflow.start_run(self.id, self.experiment_id, self.name) # save the run id and artifact_uri - self.recorder_id = run.info.run_id + self.id = run.info.run_id self.artifact_uri = run.info.artifact_uri self._uri = mlflow.get_tracking_uri() # Fix!!! : this is not proper to have uri in recorder # set up file manager for saving objects self.temp_dir = tempfile.mkdtemp() self.fm = FileManager(Path(self.temp_dir).absolute()) + self.status = "RUNNING" return run - def end_run(self): - mlflow.end_run() + def end_run(self, status): + mlflow.end_run(status) + self.status = status shutil.rmtree(self.temp_dir) - def save_object(self, data=None, name=None, local_path=None, artifact_path=None): + def save_objects(self, data_name_list=None, local_path=None, artifact_path=None, **kwargs): client = mlflow.tracking.MlflowClient(tracking_uri=self._uri) - if local_path is None: - assert data is not None and name is not None, "Please provide data and name input." + if local_path is not None: + client.log_artifacts(self.id, local_path, artifact_path) + elif kwargs.get('data') is not None and kwargs.get('name') is not None: + data, name = kwargs.get('data'), kwargs.get('name') self.fm.save_obj(data, name) - client.log_artifact(self.recorder_id, self.fm.path / name, artifact_path) - else: - assert local_path is not None, "Please provide a valid local path for the " - client.log_artifact(self.recorder_id, local_path, artifact_path) - - def save_objects(self, data_name_list=None, local_path=None, artifact_path=None): - client = mlflow.tracking.MlflowClient(tracking_uri=self._uri) - if local_path is None: - assert data_name_list is not None, "Please provide data_name_list input." + client.log_artifact(self.id, self.fm.path / name, artifact_path) + elif kwargs.get('data_name_list') is not None: + data_name_list = kwargs.get('data_name_list') self.fm.save_objs(data_name_list) - client.log_artifacts(self.recorder_id, self.fm.path, artifact_path) + client.log_artifacts(self.id, self.fm.path, artifact_path) else: - client.log_artifacts(self.recorder_id, local_path, artifact_path) + raise Exception('Please provide valid arguments in order to save object properly.') def load_object(self, name): client = mlflow.tracking.MlflowClient(tracking_uri=self._uri) @@ -227,18 +231,16 @@ class MLflowRecorder(Recorder): else: mlflow.set_tags(dict(kwargs)) - def delete_tag(self, key): - mlflow.delete_tag(key) + def delete_tags(self, *keys): + for count, key in enumerate(keys): + mlflow.delete_tag(key) def get_artifact_uri(self, artifact_path=None): if self.artifact_uri is not None: return self.artifact_uri return mlflow.get_artifact_uri(artifact_path) - def check(self, name, path=None): + def list_artifacts(self, artifact_path=None): client = mlflow.tracking.MlflowClient(tracking_uri=self._uri) - artifacts = client.list_artifacts(self.recorder_id, path) - for artifact in artifacts: - if name in artifact.path: - return True - return False + artifacts = client.list_artifacts(self.id, path) + return artifacts