mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-04 19:41:00 +08:00
Update exp related and pytorch_nn
This commit is contained in:
@@ -6,18 +6,20 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import logging
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from sklearn.metrics import roc_auc_score, mean_squared_error
|
||||
import logging
|
||||
from ...utils import unpack_archive_with_buffer, save_multiple_parts_file, create_save_path, drop_nan_by_y_index
|
||||
from ...log import get_module_logger, TimeInspector
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
from ...utils import unpack_archive_with_buffer, save_multiple_parts_file, create_save_path, drop_nan_by_y_index
|
||||
from ...log import get_module_logger, TimeInspector
|
||||
|
||||
|
||||
class DNNModelPytorch(Model):
|
||||
@@ -144,20 +146,25 @@ class DNNModelPytorch(Model):
|
||||
|
||||
def fit(
|
||||
self,
|
||||
x_train,
|
||||
y_train,
|
||||
x_valid,
|
||||
y_valid,
|
||||
w_train=None,
|
||||
w_valid=None,
|
||||
dataset: DatasetH,
|
||||
evals_result=dict(),
|
||||
verbose=True,
|
||||
save_path=None,
|
||||
):
|
||||
|
||||
if w_train is None:
|
||||
df_train, df_valid = dataset.prepare(
|
||||
["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
|
||||
)
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
|
||||
try:
|
||||
wdf_train, wdf_valid = dataset.prepare(
|
||||
["train", "valid"], col_set=["weight"], data_key=DataHandlerLP.DK_L
|
||||
)
|
||||
w_train, w_valid = wdf_train["weight"], wdf_valid["weight"]
|
||||
except:
|
||||
w_train = pd.DataFrame(np.ones_like(y_train.values), index=y_train.index)
|
||||
if w_valid is None:
|
||||
w_valid = pd.DataFrame(np.ones_like(y_valid.values), index=y_valid.index)
|
||||
|
||||
save_path = create_save_path(save_path)
|
||||
@@ -188,6 +195,7 @@ class DNNModelPytorch(Model):
|
||||
w_val_auto = w_val_auto.cuda()
|
||||
|
||||
for step in range(self.max_steps):
|
||||
self.logger.info(step)
|
||||
if stop_steps >= self.early_stop_rounds:
|
||||
if verbose:
|
||||
self.logger.info("\tearly stop")
|
||||
@@ -195,6 +203,7 @@ class DNNModelPytorch(Model):
|
||||
loss = AverageMeter()
|
||||
self.dnn_model.train()
|
||||
self.train_optimizer.zero_grad()
|
||||
self.logger.info("INIT")
|
||||
|
||||
choice = np.random.choice(train_num, self.batch_size)
|
||||
x_batch_auto = x_train_values[choice]
|
||||
@@ -264,10 +273,11 @@ class DNNModelPytorch(Model):
|
||||
else:
|
||||
raise NotImplementedError("loss {} is not supported!".format(loss_type))
|
||||
|
||||
def predict(self, x_test):
|
||||
def predict(self, dataset):
|
||||
if not self._fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
x_test = torch.from_numpy(x_test.values).float()
|
||||
x_test_pd = dataset.prepare("test", col_set="feature")
|
||||
x_test = torch.from_numpy(x_test_pd.values).float()
|
||||
if self.use_gpu:
|
||||
x_test = x_test.cuda()
|
||||
self.dnn_model.eval()
|
||||
@@ -277,13 +287,20 @@ class DNNModelPytorch(Model):
|
||||
preds = self.dnn_model(x_test).detach().cpu().numpy()
|
||||
else:
|
||||
preds = self.dnn_model(x_test).detach().numpy()
|
||||
return preds
|
||||
return pd.Series(np.squeeze(preds), index=x_test_pd.index)
|
||||
|
||||
def score(self, x_test, y_test, w_test=None):
|
||||
# Remove rows from x, y and w, which contain Nan in any columns in y_test.
|
||||
df_test = dataset.prepare("test", col_set=["feature", "label"])
|
||||
x_test, y_test = df_test["feature"], df_test["label"]
|
||||
x_test, y_test, w_test = drop_nan_by_y_index(x_test, y_test, w_test)
|
||||
preds = self.predict(x_test)
|
||||
w_test_weight = None if w_test is None else w_test.values
|
||||
try:
|
||||
df_test = dataset.prepare("test", col_set=["weight"])
|
||||
w_test = df_test["weight"]
|
||||
w_test_weight = w_test.values
|
||||
except:
|
||||
w_test_weight = None
|
||||
return self._scorer(y_test.values, preds, sample_weight=w_test_weight)
|
||||
|
||||
def save(self, filename, **kwargs):
|
||||
@@ -303,7 +320,12 @@ class DNNModelPytorch(Model):
|
||||
self.dnn_model.load_state_dict(torch.load(_model_path))
|
||||
self._fitted = True
|
||||
|
||||
def finetune(self, x_train, y_train, x_valid, y_valid, w_train=None, w_valid=None, **kwargs):
|
||||
def finetune(self, dataset, w_train=None, w_valid=None, **kwargs):
|
||||
df_train, df_valid = dataset.prepare(
|
||||
["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
|
||||
)
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
self.fit(x_train, y_train, x_valid, y_valid, w_train=w_train, w_valid=w_valid, **kwargs)
|
||||
|
||||
|
||||
|
||||
@@ -4,31 +4,32 @@
|
||||
from contextlib import contextmanager
|
||||
from .expm import MLflowExpManager
|
||||
from ..utils import Wrapper
|
||||
|
||||
from ..config import C
|
||||
|
||||
class QlibRecorder:
|
||||
"""
|
||||
A global system that helps to manage the experiments.
|
||||
"""
|
||||
|
||||
def __init__(self, exp_manager, uri):
|
||||
def __init__(self, exp_manager):
|
||||
self.exp_manager = exp_manager
|
||||
self.uri = uri
|
||||
self.uri = C["exp_uri"]
|
||||
|
||||
@contextmanager
|
||||
def start(self, experiment_name):
|
||||
run = self.start_exp(experiment_name)
|
||||
try:
|
||||
yield run
|
||||
except:
|
||||
self.end_exp() # end the experiment if something went wrong
|
||||
self.end_exp()
|
||||
except Exception as e:
|
||||
self.end_exp("FAILED") # end the experiment if something went wrong
|
||||
raise e
|
||||
self.end_exp("FINISHED")
|
||||
|
||||
def start_exp(self, experiment_name=None):
|
||||
return self.exp_manager.start_exp(experiment_name, self.uri)
|
||||
|
||||
def end_exp(self):
|
||||
self.exp_manager.end_exp()
|
||||
def end_exp(self, status):
|
||||
self.exp_manager.end_exp(status)
|
||||
|
||||
def search_records(self, experiment_ids, **kwargs):
|
||||
return self.exp_manager.search_records(experiment_ids, **kwargs)
|
||||
@@ -45,11 +46,8 @@ class QlibRecorder:
|
||||
def get_recorder(self):
|
||||
return self.exp_manager.active_recorder
|
||||
|
||||
def save_object(self, data=None, name=None, local_path=None):
|
||||
self.exp_manager.active_recorder.save_object(data, name, local_path)
|
||||
|
||||
def save_objects(self, data_name_list=None, local_path=None):
|
||||
self.exp_manager.active_recorder.save_objects(data_name_list, local_path)
|
||||
def save_objects(self, local_path=None, artifact_path=None, **kwargs):
|
||||
self.exp_manager.active_recorder.save_objects(local_path, artifact_path, **kwargs)
|
||||
|
||||
def load_object(self, name):
|
||||
return self.exp_manager.active_recorder.load_object(name)
|
||||
@@ -63,8 +61,8 @@ class QlibRecorder:
|
||||
def set_tags(self, **kwargs):
|
||||
self.exp_manager.active_recorder.set_tags(**kwargs)
|
||||
|
||||
def delete_tag(self, key):
|
||||
self.exp_manager.active_recorder.delete_tag(key)
|
||||
def delete_tag(self, *key):
|
||||
self.exp_manager.active_recorder.delete_tag(*key)
|
||||
|
||||
|
||||
# global record
|
||||
|
||||
@@ -14,7 +14,47 @@ class Experiment:
|
||||
def __init__(self):
|
||||
self.name = None
|
||||
self.id = None
|
||||
self.recorders = list()
|
||||
self.active_recorder = None # only one recorder can running each time
|
||||
self.recorders = dict() # recorder id -> object
|
||||
|
||||
def __repr__(self):
|
||||
return str(self.info)
|
||||
|
||||
def __str__(self):
|
||||
return str(self.info)
|
||||
|
||||
@property
|
||||
def info(self):
|
||||
output = dict()
|
||||
output['class'] = "Experiment"
|
||||
output['id'] = self.id
|
||||
output['name'] = self.name
|
||||
output['active_recorder'] = self.active_recorder.id
|
||||
output['recorders'] = list(self.recorders.keys())
|
||||
|
||||
def start(self):
|
||||
"""
|
||||
Start the experiment.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
||||
Returns
|
||||
-------
|
||||
A running recorder instance.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `start` method.")
|
||||
|
||||
def end(self, status):
|
||||
"""
|
||||
End the experiment.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
status : str
|
||||
the status the recorder to be set with when ending (SCHEDULED, RUNNING, FINISHED, FAILED).
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `end` method.")
|
||||
|
||||
def create_recorder(self):
|
||||
"""
|
||||
@@ -25,7 +65,7 @@ class Experiment:
|
||||
|
||||
Returns
|
||||
-------
|
||||
A recorder instance.
|
||||
A recorder object.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `create_recorder` method.")
|
||||
|
||||
@@ -46,24 +86,40 @@ class Experiment:
|
||||
|
||||
Returns
|
||||
-------
|
||||
A pandas.DataFrame of records.
|
||||
A pandas.DataFrame of records, where each metric, parameter, and tag
|
||||
are expanded into their own columns named metrics.*, params.*, and tags.*
|
||||
respectively. For records that don't have a particular metric, parameter, or tag, their
|
||||
value will be (NumPy) Nan, None, or None respectively.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `search_records` method.")
|
||||
|
||||
def delete_recorder(self, rid):
|
||||
def delete_recorder(self, recorder_id):
|
||||
"""
|
||||
Create a recorder for each experiment.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
rid : str
|
||||
recorder_id : str
|
||||
the id of the recorder to be deleted.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `delete_recorder` method.")
|
||||
|
||||
def get_recorder(self, recorder_id=None, recorder_name=None):
|
||||
"""
|
||||
Get the current active Recorder.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
recorder_id : str
|
||||
the id of the recorder to be deleted.
|
||||
recorder_name : str
|
||||
the name of the recorder to be deleted.
|
||||
|
||||
Returns
|
||||
-------
|
||||
A recorder instance.
|
||||
A recorder object.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `delete_recorder` method.")
|
||||
raise NotImplementedError(f"Please implement the `get_recorder` method.")
|
||||
|
||||
|
||||
class MLflowExperiment(Experiment):
|
||||
@@ -71,9 +127,26 @@ class MLflowExperiment(Experiment):
|
||||
Use mlflow to implement Experiment.
|
||||
"""
|
||||
|
||||
def start(self):
|
||||
# set up recorder
|
||||
recorder = self.create_recorder()
|
||||
self.active_recorder = recorder
|
||||
# start the recorder
|
||||
run = self.active_recorder.start_run()
|
||||
# store the recorder
|
||||
self.recorders[self.active_recorder.id] = recorder
|
||||
|
||||
return self.active_recorder
|
||||
|
||||
def end(self, status):
|
||||
if self.active_recorder is not None:
|
||||
self.active_recorder.end_run(status)
|
||||
self.active_recorder = None
|
||||
|
||||
def create_recorder(self):
|
||||
recorder = MLflowRecorder(self.id)
|
||||
self.recorders.append(recorder)
|
||||
num = len(self.recorders)
|
||||
name = "Recorder_{}".format(num+1)
|
||||
recorder = MLflowRecorder(name, self.id)
|
||||
return recorder
|
||||
|
||||
def search_records(self, **kwargs):
|
||||
@@ -81,8 +154,23 @@ class MLflowExperiment(Experiment):
|
||||
run_view_type = 1 if kwargs.get("run_view_type") is None else kwargs.get("run_view_type")
|
||||
max_results = 100000 if kwargs.get("max_results") is None else kwargs.get("max_results")
|
||||
order_by = kwargs.get("order_by")
|
||||
return mlflow.search_runs([self.experiment_id], filter_string, run_view_type, max_results, order_by)
|
||||
return mlflow.search_runs([self.id], filter_string, run_view_type, max_results, order_by)
|
||||
|
||||
def delete_recorder(self, rid):
|
||||
mlflow.delete_run(rid)
|
||||
self.recorders = [r for r in self.recorders if r.recorder_id == rid]
|
||||
def delete_recorder(self, recorder_id):
|
||||
mlflow.delete_run(recorder_id)
|
||||
self.recorders = [r for r in self.recorders if r.id == recorder_id]
|
||||
|
||||
def get_recorder(self, recorder_id=None, recorder_name=None):
|
||||
if recorder_id is not None:
|
||||
return self.recorders[recorder_id]
|
||||
elif recorder_name is not None:
|
||||
for rid in self.recorders:
|
||||
if self.recorders[rid].name == recorder_name:
|
||||
return self.recorders[rid]
|
||||
elif self.active_recorder is None:
|
||||
raise Exception('No valid active recorder exists. Please make sure the experiment is running.')
|
||||
else:
|
||||
logger.info(
|
||||
"No experiment id or name is given. Return the current active experiment."
|
||||
)
|
||||
return self.active_recorder
|
||||
@@ -9,7 +9,7 @@ from .exp import MLflowExperiment
|
||||
from .recorder import MLflowRecorder
|
||||
from ..log import get_module_logger
|
||||
|
||||
logger = get_module_logger("workflow", "WARNING")
|
||||
logger = get_module_logger("workflow", "INFO")
|
||||
|
||||
|
||||
class ExpManager:
|
||||
@@ -20,7 +20,7 @@ class ExpManager:
|
||||
|
||||
def __init__(self):
|
||||
self.uri = None
|
||||
self.active_recorder = None # only one recorder can running each time
|
||||
self.active_experiment = None # only one experiment can running each time
|
||||
self.experiments = dict() # store the experiment name --> Experiment object
|
||||
|
||||
def start_exp(self, experiment_name=None, uri=None, **kwargs):
|
||||
@@ -39,7 +39,7 @@ class ExpManager:
|
||||
controls whether run is nested in parent run.
|
||||
|
||||
Returns
|
||||
An object wrapped by context manager.
|
||||
An active recorder.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `start_exp` method.")
|
||||
|
||||
@@ -73,11 +73,14 @@ class ExpManager:
|
||||
|
||||
Returns
|
||||
-------
|
||||
A pandas.DataFrame of runs.
|
||||
A pandas.DataFrame of records, where each metric, parameter, and tag
|
||||
are expanded into their own columns named metrics.*, params.*, and tags.*
|
||||
respectively. For records that don't have a particular metric, parameter, or tag, their
|
||||
value will be (NumPy) Nan, None, or None respectively.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `search_records` method.")
|
||||
|
||||
def __create_exp(self, experiment_name, artifact_location=None):
|
||||
def create_exp(self, experiment_name, artifact_location=None):
|
||||
"""
|
||||
Create an experiment.
|
||||
|
||||
@@ -133,19 +136,6 @@ class ExpManager:
|
||||
"""
|
||||
return self.uri
|
||||
|
||||
def get_recorder(self):
|
||||
"""
|
||||
Get the current active Recorder.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
||||
Returns
|
||||
-------
|
||||
An Recorder object.
|
||||
"""
|
||||
return self.active_recorder
|
||||
|
||||
|
||||
class MLflowExpManager(ExpManager):
|
||||
"""
|
||||
@@ -158,26 +148,27 @@ class MLflowExpManager(ExpManager):
|
||||
|
||||
def start_exp(self, experiment_name=None, uri=None):
|
||||
# create experiment
|
||||
experiment = self.__create_exp(experiment_name, uri)
|
||||
# set up recorder
|
||||
recorder = experiment.create_recorder()
|
||||
self.active_recorder = recorder
|
||||
experiment = self.create_exp(experiment_name, uri)
|
||||
# set up active experiment
|
||||
self.active_experiment = experiment
|
||||
# store the experiment
|
||||
self.experiments[experiment_name] = experiment
|
||||
# start the experiment
|
||||
self.active_experiment.start()
|
||||
|
||||
return self.active_recorder.start_run(experiment_id=experiment.id)
|
||||
return self.active_experiment
|
||||
|
||||
def end_exp(self):
|
||||
if self.active_recorder is not None:
|
||||
self.active_recorder.end_run()
|
||||
self.active_recorder = None
|
||||
def end_exp(self, status):
|
||||
if self.active_experiment is not None:
|
||||
self.active_experiment.end(status)
|
||||
self.active_experiment = None
|
||||
|
||||
def __create_exp(self, experiment_name=None, uri=None):
|
||||
def create_exp(self, experiment_name=None, uri=None):
|
||||
# init experiment
|
||||
experiment = MLflowExperiment()
|
||||
# set the tracking uri
|
||||
if uri is None:
|
||||
logger.warning(
|
||||
logger.info(
|
||||
"No tracking URI is provided. The default tracking URI is set as `mlruns` under the working directory."
|
||||
)
|
||||
else:
|
||||
@@ -185,7 +176,7 @@ class MLflowExpManager(ExpManager):
|
||||
mlflow.set_tracking_uri(self.uri)
|
||||
# start the experiment
|
||||
if experiment_name is None:
|
||||
logger.warning("No experiment name provided. The default experiment name is set as `experiment`.")
|
||||
logger.info("No experiment name provided. The default experiment name is set as `experiment`.")
|
||||
experiment_id = mlflow.create_experiment("experiment")
|
||||
# set the active experiment
|
||||
mlflow.set_experiment("experiment")
|
||||
@@ -216,17 +207,19 @@ class MLflowExpManager(ExpManager):
|
||||
return mlflow.search_runs(experiment_ids, filter_string, run_view_type, max_results, order_by)
|
||||
|
||||
def get_exp(self, experiment_id=None, experiment_name=None):
|
||||
assert (
|
||||
experiment_id is not None or experiment_name is not None
|
||||
), "Please provide at least one of the experiment id or name to retrieve an experiment."
|
||||
if experiment_name is not None:
|
||||
return self.experiments[experiment_name]
|
||||
elif experiment_id is not None:
|
||||
for name in self.experiments:
|
||||
if self.experiments[name].id == experiment_id:
|
||||
return self.experiments[name]
|
||||
elif self.active_experiment is None:
|
||||
raise Exception('No valid active experiment exists. Please make sure experiment manager is running.')
|
||||
else:
|
||||
raise Exception("No valid experiment is found. Please make sure the id and name are correctly given.")
|
||||
logger.info(
|
||||
"No experiment id or name is given. Return the current active experiment."
|
||||
)
|
||||
return self.active_experiment
|
||||
|
||||
def delete_exp(self, experiment_id):
|
||||
mlflow.delete_experiment(experiment_id)
|
||||
|
||||
@@ -11,6 +11,11 @@ from ..utils import init_instance_by_config, get_module_by_module_path
|
||||
|
||||
|
||||
class RecordTemp:
|
||||
"""
|
||||
This is the Records Template class that enables user to generate experiment results such as IC and
|
||||
backtest in a certain format.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
@@ -24,10 +29,23 @@ class RecordTemp:
|
||||
|
||||
Return
|
||||
------
|
||||
The generated records.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `generate` method.")
|
||||
|
||||
def load(self, **kwargs):
|
||||
"""
|
||||
Load the stored records.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
kwargs
|
||||
|
||||
Return
|
||||
------
|
||||
The stored records.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `load` method.")
|
||||
|
||||
def check(self, **kwargs):
|
||||
"""
|
||||
Check if the records is properly generated and saved.
|
||||
@@ -35,12 +53,20 @@ class RecordTemp:
|
||||
Parameters
|
||||
----------
|
||||
kwargs
|
||||
|
||||
Return
|
||||
------
|
||||
Boolean: whether the records are stored properly.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `check` method.")
|
||||
|
||||
|
||||
# TODO: this can only be run under R's running experiment.
|
||||
class SignalRecord(RecordTemp):
|
||||
"""
|
||||
This is the Signal Record class that generates the signal prediction.
|
||||
"""
|
||||
|
||||
def __init__(self, model, dataset, recorder, **kwargs):
|
||||
super(SignalRecord, self).__init__()
|
||||
self.model = model
|
||||
@@ -61,12 +87,16 @@ class SignalRecord(RecordTemp):
|
||||
raise Exception("Something went wrong when loading the saved object.")
|
||||
|
||||
def check(self, **kwargs):
|
||||
return self.recorder.check("pred.pkl")
|
||||
artifacts = self.recorder.list_artifacts()
|
||||
for artifact in artifacts:
|
||||
if "pred.pkl" in artifact.path:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
# TODO
|
||||
class SigAnaRecord(SignalRecord):
|
||||
def __init__(self, recorder, **kwargs):
|
||||
def __init__(self, recorder, config, **kwargs):
|
||||
pass
|
||||
|
||||
def generate(self):
|
||||
@@ -80,13 +110,16 @@ class SigAnaRecord(SignalRecord):
|
||||
|
||||
|
||||
class PortAnaRecord(SignalRecord):
|
||||
def __init__(self, recorder, STRATEGY_CONFIG, BACKTEST_CONFIG, **kwargs):
|
||||
"""
|
||||
This is the Portfolio Analysis Record class that generates the results such as those of backtest.
|
||||
"""
|
||||
|
||||
def __init__(self, recorder, config, **kwargs):
|
||||
self.recorder = recorder
|
||||
self.STRATEGY_CONFIG = STRATEGY_CONFIG
|
||||
self.BACKTEST_CONFIG = BACKTEST_CONFIG
|
||||
module = get_module_by_module_path("qlib.contrib.strategy")
|
||||
self.strategy = init_instance_by_config(STRATEGY_CONFIG, module)
|
||||
self.artifact_path = Path("portfolio_analysis").resolve()
|
||||
self.strategy_config = config['strategy']
|
||||
self.backtest_config = config['backtest']
|
||||
self.strategy = init_instance_by_config(self.strategy_config)
|
||||
self.artifact_path = "portfolio_analysis"
|
||||
|
||||
def generate(self, **kwargs):
|
||||
"""
|
||||
@@ -121,4 +154,8 @@ class PortAnaRecord(SignalRecord):
|
||||
raise Exception("Something went wrong when loading the saved object.")
|
||||
|
||||
def check(self):
|
||||
return self.recorder.check("port_analysis.pkl", self.artifact_path)
|
||||
artifacts = self.recorder.list_artifacts(self.artifact_path)
|
||||
for artifact in artifacts:
|
||||
if "port_analysis.pkl" in artifact.path:
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -11,19 +11,37 @@ class Recorder:
|
||||
"""
|
||||
This is the `Recorder` class for logging the experiments. The API is designed similar to mlflow.
|
||||
(The link: https://mlflow.org/docs/latest/python_api/mlflow.html)
|
||||
|
||||
The status of the recorder can be SCHEDULED, RUNNING, FINISHED, FAILED.
|
||||
"""
|
||||
|
||||
def __init__(self, experiment_id):
|
||||
def __init__(self, name, experiment_id):
|
||||
self.id = None
|
||||
self.name = name
|
||||
self.experiment_id = experiment_id
|
||||
self.recorder_id = None
|
||||
self.recorder_name = None
|
||||
self.status = "SCHEDULED"
|
||||
|
||||
def __repr__(self):
|
||||
return str(self.info)
|
||||
|
||||
def __str__(self):
|
||||
return str(self.info)
|
||||
|
||||
@property
|
||||
def info(self):
|
||||
output = dict()
|
||||
output['class'] = "Recorder"
|
||||
output['id'] = self.id
|
||||
output['name'] = self.name
|
||||
output['experiment_id'] = self.experiment_id
|
||||
output['status'] = self.status
|
||||
|
||||
def set_recorder_name(self, rname):
|
||||
self.recorder_name = rname
|
||||
|
||||
def save_object(self, data=None, name=None, local_path=None, artifact_path=None):
|
||||
def save_objects(self, local_path=None, artifact_path=None, **kwargs):
|
||||
"""
|
||||
Save object such as prediction file or model checkpoints to the artifact URI.
|
||||
Save objects such as prediction file or model checkpoints to the artifact URI.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -31,19 +49,6 @@ class Recorder:
|
||||
the data to be saved.
|
||||
name : str
|
||||
name of the file to be saved.
|
||||
local_path : str
|
||||
if provided, them save the file or directory to the artifact URI.
|
||||
artifact_path=None : str
|
||||
the relative path for the artifact to be stored in the URI.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `save_object` method.")
|
||||
|
||||
def save_objects(self, data_name_list=None, local_path=None, artifact_path=None):
|
||||
"""
|
||||
Save objects such as prediction file or model checkpoints to the artifact URI.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data_name_list : list
|
||||
list of (data, name) pairs
|
||||
local_path : str
|
||||
@@ -68,21 +73,13 @@ class Recorder:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `load_object` method.")
|
||||
|
||||
def start_run(self, run_id=None, experiment_id=None, run_name=None, nested=False):
|
||||
def start_run(self):
|
||||
"""
|
||||
Start running the Recorder. The return value can be used as a context manager within a `with` block;
|
||||
Start running or resuming the Recorder. The return value can be used as a context manager within a `with` block;
|
||||
otherwise, you must call end_run() to terminate the current run. (See `ActiveRun` class in mlflow)
|
||||
|
||||
Parameters
|
||||
----------
|
||||
run_id : str
|
||||
id of the active Recorder.
|
||||
experiment_id : str
|
||||
id of the active experiment.
|
||||
run_name : str
|
||||
name of the Recorder.
|
||||
nested : boolean
|
||||
controls whether run is nested in parent run.
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -127,18 +124,33 @@ class Recorder:
|
||||
keyword arguments
|
||||
key, value pair to be logged as tags.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `log_tags` method.")
|
||||
raise NotImplementedError(f"Please implement the `set_tags` method.")
|
||||
|
||||
def delete_tag(self, key):
|
||||
def delete_tags(self, *keys):
|
||||
"""
|
||||
Delete a tag from a run.
|
||||
Delete some tags from a run.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
key : str
|
||||
the name of the tag to be deleted.
|
||||
keys : series of strs of the keys
|
||||
all the name of the tag to be deleted.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `delete_tag` method.")
|
||||
raise NotImplementedError(f"Please implement the `delete_tags` method.")
|
||||
|
||||
def list_artifacts(self, artifact_path=None):
|
||||
"""
|
||||
Delete some tags from a run.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
artifact_path=None : str
|
||||
the relative path for the artifact to be stored in the URI.
|
||||
|
||||
Returns
|
||||
-------
|
||||
A list of artifacts information (name, path, etc.) that being stored.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `list_artifacts` method.")
|
||||
|
||||
|
||||
class MLflowRecorder(Recorder):
|
||||
@@ -149,51 +161,43 @@ class MLflowRecorder(Recorder):
|
||||
use file manager to help maintain the objects in the project.
|
||||
"""
|
||||
|
||||
def __init__(self, experiment_id):
|
||||
super(MLflowRecorder, self).__init__(experiment_id)
|
||||
def __init__(self, name, experiment_id):
|
||||
super(MLflowRecorder, self).__init__(name, experiment_id)
|
||||
self.fm = None
|
||||
self.temp_dir = None
|
||||
|
||||
def start_run(self, run_id=None, experiment_id=None, run_name=None, nested=False):
|
||||
if run_id is None:
|
||||
run_id = self.recorder_id
|
||||
if experiment_id is None:
|
||||
experiment_id = self.experiment_id
|
||||
if run_name is None:
|
||||
run_name = self.recorder_name
|
||||
def start_run(self):
|
||||
# start the run
|
||||
run = mlflow.start_run(run_id, experiment_id, run_name, nested)
|
||||
run = mlflow.start_run(self.id, self.experiment_id, self.name)
|
||||
# save the run id and artifact_uri
|
||||
self.recorder_id = run.info.run_id
|
||||
self.id = run.info.run_id
|
||||
self.artifact_uri = run.info.artifact_uri
|
||||
self._uri = mlflow.get_tracking_uri() # Fix!!! : this is not proper to have uri in recorder
|
||||
# set up file manager for saving objects
|
||||
self.temp_dir = tempfile.mkdtemp()
|
||||
self.fm = FileManager(Path(self.temp_dir).absolute())
|
||||
self.status = "RUNNING"
|
||||
return run
|
||||
|
||||
def end_run(self):
|
||||
mlflow.end_run()
|
||||
def end_run(self, status):
|
||||
mlflow.end_run(status)
|
||||
self.status = status
|
||||
shutil.rmtree(self.temp_dir)
|
||||
|
||||
def save_object(self, data=None, name=None, local_path=None, artifact_path=None):
|
||||
def save_objects(self, data_name_list=None, local_path=None, artifact_path=None, **kwargs):
|
||||
client = mlflow.tracking.MlflowClient(tracking_uri=self._uri)
|
||||
if local_path is None:
|
||||
assert data is not None and name is not None, "Please provide data and name input."
|
||||
if local_path is not None:
|
||||
client.log_artifacts(self.id, local_path, artifact_path)
|
||||
elif kwargs.get('data') is not None and kwargs.get('name') is not None:
|
||||
data, name = kwargs.get('data'), kwargs.get('name')
|
||||
self.fm.save_obj(data, name)
|
||||
client.log_artifact(self.recorder_id, self.fm.path / name, artifact_path)
|
||||
else:
|
||||
assert local_path is not None, "Please provide a valid local path for the "
|
||||
client.log_artifact(self.recorder_id, local_path, artifact_path)
|
||||
|
||||
def save_objects(self, data_name_list=None, local_path=None, artifact_path=None):
|
||||
client = mlflow.tracking.MlflowClient(tracking_uri=self._uri)
|
||||
if local_path is None:
|
||||
assert data_name_list is not None, "Please provide data_name_list input."
|
||||
client.log_artifact(self.id, self.fm.path / name, artifact_path)
|
||||
elif kwargs.get('data_name_list') is not None:
|
||||
data_name_list = kwargs.get('data_name_list')
|
||||
self.fm.save_objs(data_name_list)
|
||||
client.log_artifacts(self.recorder_id, self.fm.path, artifact_path)
|
||||
client.log_artifacts(self.id, self.fm.path, artifact_path)
|
||||
else:
|
||||
client.log_artifacts(self.recorder_id, local_path, artifact_path)
|
||||
raise Exception('Please provide valid arguments in order to save object properly.')
|
||||
|
||||
def load_object(self, name):
|
||||
client = mlflow.tracking.MlflowClient(tracking_uri=self._uri)
|
||||
@@ -227,18 +231,16 @@ class MLflowRecorder(Recorder):
|
||||
else:
|
||||
mlflow.set_tags(dict(kwargs))
|
||||
|
||||
def delete_tag(self, key):
|
||||
mlflow.delete_tag(key)
|
||||
def delete_tags(self, *keys):
|
||||
for count, key in enumerate(keys):
|
||||
mlflow.delete_tag(key)
|
||||
|
||||
def get_artifact_uri(self, artifact_path=None):
|
||||
if self.artifact_uri is not None:
|
||||
return self.artifact_uri
|
||||
return mlflow.get_artifact_uri(artifact_path)
|
||||
|
||||
def check(self, name, path=None):
|
||||
def list_artifacts(self, artifact_path=None):
|
||||
client = mlflow.tracking.MlflowClient(tracking_uri=self._uri)
|
||||
artifacts = client.list_artifacts(self.recorder_id, path)
|
||||
for artifact in artifacts:
|
||||
if name in artifact.path:
|
||||
return True
|
||||
return False
|
||||
artifacts = client.list_artifacts(self.id, path)
|
||||
return artifacts
|
||||
|
||||
Reference in New Issue
Block a user