mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-05 03:50:57 +08:00
Update R and workflow
This commit is contained in:
@@ -58,3 +58,8 @@ Besides `provider_uri` and `region`, `qlib.init` has other parameters. The follo
|
||||
.. note::
|
||||
|
||||
If Qlib fails to connect redis via `redis_host` and `redis_port`, cache mechanism will not be used! Please refer to `Cache <../component/data.html#cache>`_ for details.
|
||||
- `exp_manager`
|
||||
Type: str, optional parameter(default: "MLflowExpManager"), the experiment manager to be used in qlib.
|
||||
- `exp_uri`
|
||||
Type: str, optional parameter(default: "mlruns" in local execution path), the tracking uri of the experiment manager.
|
||||
It can either be a local path or a remote uri.
|
||||
@@ -14,10 +14,9 @@ from qlib.contrib.evaluate import (
|
||||
backtest as normal_backtest,
|
||||
risk_analysis,
|
||||
)
|
||||
from qlib.utils import exists_qlib_data
|
||||
|
||||
# from qlib.model.learner import train_model
|
||||
from qlib.utils import init_instance_by_config
|
||||
from qlib.utils import exists_qlib_data, init_instance_by_config
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.record_temp import SignalRecord, PortAnaRecord
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -93,55 +92,41 @@ if __name__ == "__main__":
|
||||
),
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
# You shoud record the data in specific sequence
|
||||
# "record": ['SignalRecord', 'SigAnaRecord', 'PortAnaRecord'],
|
||||
"record": ["SignalRecord", "PortAnaRecord"],
|
||||
}
|
||||
|
||||
# model = train_model(task)
|
||||
port_analysis_config = {
|
||||
"strategy": {
|
||||
"topk": 50,
|
||||
"n_drop": 5,
|
||||
},
|
||||
"backtest": {
|
||||
"verbose": False,
|
||||
"limit_threshold": 0.095,
|
||||
"account": 100000000,
|
||||
"benchmark": BENCHMARK,
|
||||
"deal_price": "close",
|
||||
"open_cost": 0.0005,
|
||||
"close_cost": 0.0015,
|
||||
"min_cost": 5,
|
||||
},
|
||||
}
|
||||
|
||||
# model initiaiton
|
||||
model = init_instance_by_config(task["model"])
|
||||
dataset = init_instance_by_config(task["dataset"])
|
||||
|
||||
model.fit(dataset)
|
||||
# start exp
|
||||
with R.start("workflow"):
|
||||
model.fit(dataset)
|
||||
|
||||
pred_score = model.predict(dataset)
|
||||
# prediction
|
||||
recorder = R.get_recorder()
|
||||
sr = SignalRecord(model, dataset, recorder)
|
||||
sr.generate()
|
||||
|
||||
# save pred_score to file
|
||||
pred_score_path = Path("~/tmp/qlib/pred_score.pkl").expanduser()
|
||||
pred_score_path.parent.mkdir(exist_ok=True, parents=True)
|
||||
pred_score.to_pickle(pred_score_path)
|
||||
|
||||
###################################
|
||||
# backtest
|
||||
###################################
|
||||
STRATEGY_CONFIG = {
|
||||
"topk": 50,
|
||||
"n_drop": 5,
|
||||
}
|
||||
BACKTEST_CONFIG = {
|
||||
"verbose": False,
|
||||
"limit_threshold": 0.095,
|
||||
"account": 100000000,
|
||||
"benchmark": BENCHMARK,
|
||||
"deal_price": "close",
|
||||
"open_cost": 0.0005,
|
||||
"close_cost": 0.0015,
|
||||
"min_cost": 5,
|
||||
}
|
||||
|
||||
# use default strategy
|
||||
# custom Strategy, refer to: TODO: Strategy API url
|
||||
strategy = TopkDropoutStrategy(**STRATEGY_CONFIG)
|
||||
report_normal, positions_normal = normal_backtest(pred_score, strategy=strategy, **BACKTEST_CONFIG)
|
||||
|
||||
###################################
|
||||
# analyze
|
||||
# If need a more detailed analysis, refer to: examples/train_and_bakctest.ipynb
|
||||
###################################
|
||||
analysis = dict()
|
||||
analysis["excess_return_without_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"])
|
||||
analysis["excess_return_with_cost"] = risk_analysis(
|
||||
report_normal["return"] - report_normal["bench"] - report_normal["cost"]
|
||||
)
|
||||
analysis_df = pd.concat(analysis) # type: pd.DataFrame
|
||||
print(analysis_df)
|
||||
# backtest
|
||||
par = PortAnaRecord(recorder, port_analysis_config)
|
||||
par.generate()
|
||||
|
||||
@@ -5,17 +5,19 @@
|
||||
__version__ = "0.5.1.dev0"
|
||||
|
||||
import os
|
||||
import copy
|
||||
import logging
|
||||
import re
|
||||
import subprocess
|
||||
import platform
|
||||
import sys
|
||||
import copy
|
||||
import yaml
|
||||
import atexit
|
||||
import signal
|
||||
import logging
|
||||
import platform
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
from .utils import can_use_cache, init_instance_by_config, get_module_by_module_path
|
||||
|
||||
from .workflow.utils import experiment_exception_hook, experiment_kill_signal_handler
|
||||
|
||||
# init qlib
|
||||
def init(default_conf="client", **kwargs):
|
||||
@@ -44,9 +46,14 @@ def init(default_conf="client", **kwargs):
|
||||
C.set_region(kwargs.get("region", C["region"] if "region" in C else REG_CN))
|
||||
|
||||
for k, v in kwargs.items():
|
||||
C[k] = v
|
||||
if k not in C:
|
||||
LOG.warning("Unrecognized config %s" % k)
|
||||
if k == "exp_manager":
|
||||
C["exp_manager"].update({"class": v})
|
||||
elif k == "exp_uri":
|
||||
C["exp_manager"]["kwargs"].update({"uri": v})
|
||||
else:
|
||||
C[k] = v
|
||||
if k not in C:
|
||||
LOG.warning("Unrecognized config %s" % k)
|
||||
|
||||
C.resolve_path()
|
||||
|
||||
@@ -86,7 +93,9 @@ def init(default_conf="client", **kwargs):
|
||||
qr = QlibRecorder(exp_manager)
|
||||
R.register(qr)
|
||||
# clean up experiment when python program ends
|
||||
atexit.register(R.end_exp, status="FAILED") # will not take effect if experiment ends
|
||||
atexit.register(R.end_exp, recorder_status="FINISHED") # will not take effect if experiment ends
|
||||
signal.signal(signal.SIGINT, experiment_kill_signal_handler)
|
||||
sys.excepthook = experiment_exception_hook
|
||||
|
||||
|
||||
def _mount_nfs_uri(C):
|
||||
|
||||
@@ -222,7 +222,9 @@ class QlibConfig(Config):
|
||||
|
||||
def get_uri_type(self):
|
||||
is_win = re.match("^[a-zA-Z]:.*", self["provider_uri"]) is not None # such as 'C:\\data', 'D:'
|
||||
is_nfs_or_win = re.match("^[^/]+:.+", self["provider_uri"]) is not None # such as 'host:/data/' (User may define short hostname by themselves or use localhost)
|
||||
is_nfs_or_win = (
|
||||
re.match("^[^/]+:.+", self["provider_uri"]) is not None
|
||||
) # such as 'host:/data/' (User may define short hostname by themselves or use localhost)
|
||||
|
||||
if is_nfs_or_win and not is_win:
|
||||
return QlibConfig.NFS_URI
|
||||
|
||||
@@ -161,7 +161,7 @@ class DNNModelPytorch(Model):
|
||||
try:
|
||||
wdf_train, wdf_valid = dataset.prepare(["train", "valid"], col_set=["weight"], data_key=DataHandlerLP.DK_L)
|
||||
w_train, w_valid = wdf_train["weight"], wdf_valid["weight"]
|
||||
except:
|
||||
except KeyError as e:
|
||||
w_train = pd.DataFrame(np.ones_like(y_train.values), index=y_train.index)
|
||||
w_valid = pd.DataFrame(np.ones_like(y_valid.values), index=y_valid.index)
|
||||
|
||||
@@ -287,20 +287,6 @@ class DNNModelPytorch(Model):
|
||||
preds = self.dnn_model(x_test).detach().numpy()
|
||||
return pd.Series(np.squeeze(preds), index=x_test_pd.index)
|
||||
|
||||
def score(self, x_test, y_test, w_test=None):
|
||||
# Remove rows from x, y and w, which contain Nan in any columns in y_test.
|
||||
df_test = dataset.prepare("test", col_set=["feature", "label"])
|
||||
x_test, y_test = df_test["feature"], df_test["label"]
|
||||
x_test, y_test, w_test = drop_nan_by_y_index(x_test, y_test, w_test)
|
||||
preds = self.predict(x_test)
|
||||
try:
|
||||
df_test = dataset.prepare("test", col_set=["weight"])
|
||||
w_test = df_test["weight"]
|
||||
w_test_weight = w_test.values
|
||||
except:
|
||||
w_test_weight = None
|
||||
return self._scorer(y_test.values, preds, sample_weight=w_test_weight)
|
||||
|
||||
def save(self, filename, **kwargs):
|
||||
with save_multiple_parts_file(filename) as model_dir:
|
||||
model_path = os.path.join(model_dir, os.path.split(model_dir)[-1])
|
||||
@@ -318,14 +304,6 @@ class DNNModelPytorch(Model):
|
||||
self.dnn_model.load_state_dict(torch.load(_model_path))
|
||||
self._fitted = True
|
||||
|
||||
def finetune(self, dataset, w_train=None, w_valid=None, **kwargs):
|
||||
df_train, df_valid = dataset.prepare(
|
||||
["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
|
||||
)
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
self.fit(x_train, y_train, x_valid, y_valid, w_train=w_train, w_valid=w_valid, **kwargs)
|
||||
|
||||
|
||||
class AverageMeter(object):
|
||||
"""Computes and stores the average and current value"""
|
||||
|
||||
@@ -52,4 +52,6 @@ def fetch_df_by_index(
|
||||
idx_slc = (selector, slice(None, None))
|
||||
if get_level_index(df, level) == 1:
|
||||
idx_slc = idx_slc[1], idx_slc[0]
|
||||
return df.loc[pd.IndexSlice[idx_slc], ] # This could be faster than df.loc(axis=0)[idx_slc]
|
||||
return df.loc[
|
||||
pd.IndexSlice[idx_slc],
|
||||
] # This could be faster than df.loc(axis=0)[idx_slc]
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
from contextlib import contextmanager
|
||||
from .expm import MLflowExpManager
|
||||
from .recorder import Recorder
|
||||
from ..utils import Wrapper
|
||||
|
||||
|
||||
@@ -31,7 +32,7 @@ class QlibRecorder:
|
||||
self.exp_manager = exp_manager
|
||||
|
||||
@contextmanager
|
||||
def start(self, experiment_name):
|
||||
def start(self, experiment_name=None):
|
||||
"""
|
||||
Method to start an experiment. This method can only be called within a Python's `with` statement.
|
||||
|
||||
@@ -53,13 +54,13 @@ class QlibRecorder:
|
||||
try:
|
||||
yield run
|
||||
except Exception as e:
|
||||
self.end_exp("FAILED") # end the experiment if something went wrong
|
||||
self.end_exp(Recorder.STATUS_FA) # end the experiment if something went wrong
|
||||
raise e
|
||||
self.end_exp("FINISHED")
|
||||
self.end_exp(Recorder.STATUS_FI)
|
||||
|
||||
def start_exp(self, experiment_name=None, uri=None):
|
||||
"""
|
||||
Lower leverl method for starting an experiment. When use this method, one should end the experiment manually
|
||||
Lower level method for starting an experiment. When use this method, one should end the experiment manually
|
||||
and the status of the recorder may not be handled properly.
|
||||
|
||||
Use case:
|
||||
@@ -67,7 +68,7 @@ class QlibRecorder:
|
||||
```
|
||||
R.start_exp(experiment_name='test')
|
||||
... # further operations
|
||||
R.end_exp('FINISHED')
|
||||
R.end_exp('FINISHED') or R.end_exp(Recorder.STATUS_S)
|
||||
```
|
||||
|
||||
Parameters
|
||||
@@ -83,7 +84,7 @@ class QlibRecorder:
|
||||
"""
|
||||
return self.exp_manager.start_exp(experiment_name, uri)
|
||||
|
||||
def end_exp(self, status):
|
||||
def end_exp(self, recorder_status=Recorder.STATUS_FI):
|
||||
"""
|
||||
Method for ending an experiment manually. It will end the current active experiment, as well as its
|
||||
active recorder with the specified `status` type.
|
||||
@@ -93,7 +94,7 @@ class QlibRecorder:
|
||||
```
|
||||
R.start_exp(experiment_name='test')
|
||||
... # further operations
|
||||
R.end_exp('FINISHED')
|
||||
R.end_exp('FINISHED') or R.end_exp(Recorder.STATUS_S)
|
||||
```
|
||||
|
||||
Parameters
|
||||
@@ -101,7 +102,7 @@ class QlibRecorder:
|
||||
status : str
|
||||
The status of a recorder, which can be SCHEDULED, RUNNING, FINISHED, FAILED.
|
||||
"""
|
||||
self.exp_manager.end_exp(status)
|
||||
self.exp_manager.end_exp(recorder_status)
|
||||
|
||||
def search_records(self, experiment_ids, **kwargs):
|
||||
"""
|
||||
@@ -175,7 +176,7 @@ class QlibRecorder:
|
||||
"""
|
||||
return self.get_exp(experiment_id, experiment_name).list_recorders()
|
||||
|
||||
def get_exp(self, experiment_id=None, experiment_name=None, create=True):
|
||||
def get_exp(self, experiment_id=None, experiment_name=None, create: bool = True):
|
||||
"""
|
||||
Method for retrieving an experiment with given id or name. Once the `create` argument is set to
|
||||
True, if no valid experiment is found, this method will create one for you. Otherwise, it will
|
||||
@@ -185,18 +186,18 @@ class QlibRecorder:
|
||||
If R's running:
|
||||
1) no id or name specified, return the active experiment.
|
||||
2) if id or name is specified, return the specified experiment. If no such exp found,
|
||||
create a new experiment with given id or name.
|
||||
create a new experiment with given id or name, and the experiment is set to be running.
|
||||
If R's not running:
|
||||
1) no id or name specified, create a default experiment.
|
||||
2) if id or name is specified, return the specified experiment. If no such exp found,
|
||||
create a new experiment with given id or name.
|
||||
create a new experiment with given id or name, and the experiment is set to be running.
|
||||
Else If `create` is False:
|
||||
If R's running:
|
||||
1) no id or name specified, return the active experiment.
|
||||
2) if id or name is specified, return the specified experiment. If no such exp found,
|
||||
raise Error.
|
||||
If R's not running:
|
||||
1) no id or name specified, raise Error.
|
||||
1) no id or name specified. If the default experiment exists, return it, otherwise, raise Error.
|
||||
2) if id or name is specified, return the specified experiment. If no such exp found,
|
||||
raise Error.
|
||||
|
||||
@@ -219,7 +220,7 @@ class QlibRecorder:
|
||||
exp = R.get_exp(experiment_name='test')
|
||||
|
||||
# Case 5
|
||||
exp = R.get_exp(create=False) -> Error
|
||||
exp = R.get_exp(create=False) -> the default experiment if exists.
|
||||
```
|
||||
|
||||
Parameters
|
||||
@@ -229,7 +230,8 @@ class QlibRecorder:
|
||||
experiment_name : str
|
||||
name of the experiment.
|
||||
create : boolean
|
||||
decide whether to create an default experiment.
|
||||
an argument determines whether the method will automatically create a new experiment
|
||||
according to user's specification if the experiment hasn't been created before.
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -348,7 +350,8 @@ class QlibRecorder:
|
||||
def save_objects(self, local_path=None, artifact_path=None, **kwargs):
|
||||
"""
|
||||
Method for saving objects as artifacts in the experiment to the uri. It supports either saving
|
||||
from a local file/directory, or directly saving objects.
|
||||
from a local file/directory, or directly saving objects. User can use valid python's keywords arguments
|
||||
to specify the object to be saved as well as its name (name: value).
|
||||
|
||||
If R's running: it will save the objects through the running recorder.
|
||||
If R's not running: the system will create a default experiment, and a new recorder and
|
||||
@@ -364,28 +367,16 @@ class QlibRecorder:
|
||||
# Case 1
|
||||
with R.start('test'):
|
||||
pred = model.predict(dataset)
|
||||
R.save_objects(data=pred, name='pred.pkl', artifact_path='prediction')
|
||||
kwargs = {"pred.pkl": pred}
|
||||
R.save_objects(**kwargs, artifact_path='prediction')
|
||||
|
||||
# Case 2
|
||||
with R.start('test'):
|
||||
pred1 = model1.predict(dataset)
|
||||
pred2 = model2.predict(dataset)
|
||||
dn_list = [(pred1, 'pred1.pkl'), (pred2, 'pred2.pkl')]
|
||||
R.save_objects(data_name_list=dn_list)
|
||||
|
||||
# Case 3
|
||||
with R.start('test'):
|
||||
R.save_objects(local_path='results/pred.pkl')
|
||||
```
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data : any type
|
||||
the data to be saved.
|
||||
name : str
|
||||
name of the file to be saved.
|
||||
data_name_list : list
|
||||
list of (data, name) pairs
|
||||
local_path : str
|
||||
if provided, them save the file or directory to the artifact URI.
|
||||
artifact_path=None : str
|
||||
@@ -464,10 +455,10 @@ class QlibRecorder:
|
||||
```
|
||||
# Case 1
|
||||
with R.start('test'):
|
||||
R.set_tags(release_version=2.2.0)
|
||||
R.set_tags(release_version="2.2.0")
|
||||
|
||||
# Case 2
|
||||
R.set_tags(release_version=2.2.0)
|
||||
R.set_tags(release_version="2.2.0")
|
||||
```
|
||||
|
||||
Parameters
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
import mlflow
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from .recorder import MLflowRecorder
|
||||
from .recorder import Recorder, MLflowRecorder
|
||||
from ..log import get_module_logger
|
||||
|
||||
logger = get_module_logger("workflow", "INFO")
|
||||
@@ -20,7 +20,6 @@ class Experiment:
|
||||
self.id = id
|
||||
self.name = name
|
||||
self.active_recorder = None # only one recorder can running each time
|
||||
self.recorders = dict() # recorder id -> object
|
||||
|
||||
def __repr__(self):
|
||||
return str(self.info)
|
||||
@@ -30,31 +29,32 @@ class Experiment:
|
||||
|
||||
@property
|
||||
def info(self):
|
||||
recorders = self.list_recorders()
|
||||
output = dict()
|
||||
output["class"] = "Experiment"
|
||||
output["id"] = self.id
|
||||
output["name"] = self.name
|
||||
output["active_recorder"] = self.active_recorder.id if self.active_recorder is not None else None
|
||||
output["recorders"] = list(self.recorders.keys())
|
||||
output["recorders"] = list(recorders.keys())
|
||||
return output
|
||||
|
||||
def start(self):
|
||||
"""
|
||||
Start the experiment.
|
||||
Start the experiment and set it to be active. This method will also start a new recorder.
|
||||
|
||||
Returns
|
||||
-------
|
||||
A running recorder instance.
|
||||
An active recorder.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `start` method.")
|
||||
|
||||
def end(self, status):
|
||||
def end(self, recorder_status=Recorder.STATUS_S):
|
||||
"""
|
||||
End the experiment.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
status : str
|
||||
recorder_status : str
|
||||
the status the recorder to be set with when ending (SCHEDULED, RUNNING, FINISHED, FAILED).
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `end` method.")
|
||||
@@ -72,17 +72,7 @@ class Experiment:
|
||||
def search_records(self, **kwargs):
|
||||
"""
|
||||
Get a pandas DataFrame of records that fit the search criteria of the experiment.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
filter_string : str
|
||||
filter query string, defaults to searching all runs.
|
||||
run_view_type : int
|
||||
one of enum values ACTIVE_ONLY, DELETED_ONLY, or ALL (e.g. in mlflow.entities.ViewType).
|
||||
max_results : int
|
||||
the maximum number of runs to put in the dataframe.
|
||||
order_by : list
|
||||
list of columns to order by (e.g., “metrics.rmse”).
|
||||
Inputs are the search critera user want to apply.
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -104,9 +94,31 @@ class Experiment:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `delete_recorder` method.")
|
||||
|
||||
def get_recorder(self, recorder_id=None, recorder_name=None):
|
||||
def get_recorder(self, recorder_id=None, recorder_name=None, create: bool = True):
|
||||
"""
|
||||
Get the current active Recorder.
|
||||
Retrieve a Recorder for user. When user specify recorder id and name, the method will try to return the
|
||||
specific recorder. When user does not provide recorder id or name, the method will try to return the current
|
||||
active recorder. The `create` argument determines whether the method will automatically create a new recorder
|
||||
according to user's specification if the recorder hasn't been created before
|
||||
|
||||
If `create` is True:
|
||||
If R's running:
|
||||
1) no id or name specified, return the active recorder.
|
||||
2) if id or name is specified, return the specified recorder. If no such exp found,
|
||||
create a new recorder with given id or name, and the recorder shoud be running.
|
||||
If R's not running:
|
||||
1) no id or name specified, create a new recorder.
|
||||
2) if id or name is specified, return the specified experiment. If no such exp found,
|
||||
create a new recorder with given id or name, and the recorder shoud be running.
|
||||
Else If `create` is False:
|
||||
If R's running:
|
||||
1) no id or name specified, return the active recorder.
|
||||
2) if id or name is specified, return the specified recorder. If no such exp found,
|
||||
raise Error.
|
||||
If R's not running:
|
||||
1) no id or name specified, raise Error.
|
||||
2) if id or name is specified, return the specified recorder. If no such exp found,
|
||||
raise Error.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -140,32 +152,29 @@ class MLflowExperiment(Experiment):
|
||||
def __init__(self, id, name, uri):
|
||||
super(MLflowExperiment, self).__init__(id, name)
|
||||
self._uri = uri
|
||||
self._total_recorders = 0
|
||||
self._default_name = None
|
||||
self.client = mlflow.tracking.MlflowClient(tracking_uri=self._uri)
|
||||
|
||||
def start(self):
|
||||
# get all the recorders of the experiment
|
||||
self.recorders = self.list_recorders()
|
||||
# set the active experiment
|
||||
mlflow.set_experiment(self.name)
|
||||
logger.info(f"Experiment {self.id} starts running ...")
|
||||
# set up recorder
|
||||
recorder = self.create_recorder()
|
||||
self.active_recorder = recorder
|
||||
# start the recorder
|
||||
run = self.active_recorder.start_run()
|
||||
# store the recorder
|
||||
self.recorders[self.active_recorder.id] = recorder
|
||||
self._total_recorders += 1 # update recorder num
|
||||
logger.info(f"Experiment {self.id} starts running ...")
|
||||
|
||||
return self.active_recorder
|
||||
|
||||
def end(self, status):
|
||||
def end(self, recorder_status):
|
||||
if self.active_recorder is not None:
|
||||
self.active_recorder.end_run(status)
|
||||
self.active_recorder.end_run(recorder_status)
|
||||
self.active_recorder = None
|
||||
self._total_recorders -= 1
|
||||
|
||||
def create_recorder(self):
|
||||
num = len(self.recorders)
|
||||
recorders = self.list_recorders()
|
||||
num = len(recorders)
|
||||
name = "Recorder_{}".format(num + 1)
|
||||
recorder = MLflowRecorder(name, self.id, self._uri)
|
||||
|
||||
@@ -177,7 +186,7 @@ class MLflowExperiment(Experiment):
|
||||
max_results = 100000 if kwargs.get("max_results") is None else kwargs.get("max_results")
|
||||
order_by = kwargs.get("order_by")
|
||||
|
||||
return mlflow.search_runs([self.id], filter_string, run_view_type, max_results, order_by)
|
||||
return self.client.search_runs([self.id], filter_string, run_view_type, max_results, order_by)
|
||||
|
||||
def delete_recorder(self, recorder_id=None, recorder_name=None):
|
||||
assert (
|
||||
@@ -185,20 +194,26 @@ class MLflowExperiment(Experiment):
|
||||
), "Please input a valid recorder id or name before deleting."
|
||||
try:
|
||||
if recorder_id is not None:
|
||||
mlflow.delete_run(recorder_id)
|
||||
self.recorders = [r for r in self.recorders if r == recorder_id]
|
||||
self.client.delete_run(recorder_id)
|
||||
else:
|
||||
for r in self.recorders:
|
||||
if self.recorders[r].name == recorder_name:
|
||||
recorders = self.list_recorders()
|
||||
for r in recorders:
|
||||
if recorders[r].name == recorder_name:
|
||||
recorder_id = r
|
||||
break
|
||||
mlflow.delete_run(recorder_id)
|
||||
self.client.delete_run(recorder_id)
|
||||
except:
|
||||
raise Exception(
|
||||
"Something went wrong when deleting recorder. Please check if the name/id of the recorder is correct."
|
||||
)
|
||||
|
||||
def get_recorder(self, recorder_id=None, recorder_name=None, create=True):
|
||||
"""
|
||||
MLflow doesn't support create recorder with a specific id. Thus, when user only provides recorder id and `create`
|
||||
is set to True, this method will not automatically create an active recorder.
|
||||
"""
|
||||
# retrive all the recorders under this experiment
|
||||
recorders = self.list_recorders()
|
||||
if recorder_id is None and recorder_name is None:
|
||||
if self.active_recorder:
|
||||
return self.active_recorder
|
||||
@@ -215,19 +230,19 @@ class MLflowExperiment(Experiment):
|
||||
)
|
||||
else:
|
||||
if recorder_id is not None:
|
||||
if recorder_id in self.recorders:
|
||||
return self.recorders[recorder_id]
|
||||
if recorder_id in recorders:
|
||||
return recorders[recorder_id]
|
||||
else:
|
||||
# mlflow does not support create a run with given id
|
||||
raise Exception(
|
||||
"Something went wrong when retrieving recorders. Please check if QlibRecorder is running or the name/id of the recorder is correct."
|
||||
)
|
||||
else:
|
||||
for rid in self.recorders:
|
||||
if self.recorders[rid].name == recorder_name:
|
||||
return self.recorders[rid]
|
||||
for rid in recorders:
|
||||
if recorders[rid].name == recorder_name:
|
||||
return recorders[rid]
|
||||
if create:
|
||||
self.recorders = self.list_recorders()
|
||||
recorders = self.list_recorders()
|
||||
logger.warning(f"No valid recorder found. Create a new recorder with name {recorder_name}.")
|
||||
recorder = self.create_recorder()
|
||||
recorder.name = recorder_name
|
||||
@@ -239,10 +254,8 @@ class MLflowExperiment(Experiment):
|
||||
)
|
||||
|
||||
def list_recorders(self):
|
||||
client = mlflow.tracking.MlflowClient(tracking_uri=self._uri)
|
||||
runs = client.list_run_infos(self.id)[::-1]
|
||||
runs = self.client.list_run_infos(self.id, run_view_type=1)[::-1]
|
||||
recorders = dict()
|
||||
self._total_recorders = len(runs)
|
||||
for i in range(len(runs)):
|
||||
rid = runs[i].run_id
|
||||
status = runs[i].status
|
||||
|
||||
@@ -6,7 +6,7 @@ import os
|
||||
from pathlib import Path
|
||||
from contextlib import contextmanager
|
||||
from .exp import MLflowExperiment
|
||||
from .recorder import MLflowRecorder
|
||||
from .recorder import Recorder, MLflowRecorder
|
||||
from ..log import get_module_logger
|
||||
|
||||
logger = get_module_logger("workflow", "INFO")
|
||||
@@ -22,11 +22,10 @@ class ExpManager:
|
||||
self.uri = uri
|
||||
self.default_exp_name = default_exp_name
|
||||
self.active_experiment = None # only one experiment can running each time
|
||||
self.experiments = dict() # store the experiment name --> Experiment object
|
||||
|
||||
def start_exp(self, experiment_name=None, uri=None, **kwargs):
|
||||
"""
|
||||
Start running an experiment.
|
||||
Start an experiment.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -37,11 +36,18 @@ class ExpManager:
|
||||
|
||||
Returns
|
||||
-------
|
||||
An active recorder.
|
||||
An active experiment.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `start_exp` method.")
|
||||
# create experiment
|
||||
experiment = self.create_exp(experiment_name, uri)
|
||||
# set up active experiment
|
||||
self.active_experiment = experiment
|
||||
# start the experiment
|
||||
self.active_experiment.start()
|
||||
|
||||
def end_exp(self, **kwargs):
|
||||
return self.active_experiment
|
||||
|
||||
def end_exp(self, recorder_status: str = Recorder.STATUS_S, **kwargs):
|
||||
"""
|
||||
End an running experiment.
|
||||
|
||||
@@ -49,25 +55,17 @@ class ExpManager:
|
||||
----------
|
||||
experiment_name : str
|
||||
name of the active experiment.
|
||||
recorder_status : str
|
||||
the status of the active recorder of the experiment.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `end_exp` method.")
|
||||
if self.active_experiment is not None:
|
||||
self.active_experiment.end(recorder_status)
|
||||
self.active_experiment = None
|
||||
|
||||
def search_records(self, experiment_ids=None, **kwargs):
|
||||
"""
|
||||
Get a pandas DataFrame of records that fit the search criteria.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
experiment_ids : list
|
||||
list of experiment IDs.
|
||||
filter_string : str
|
||||
filter query string, defaults to searching all runs.
|
||||
run_view_type : int
|
||||
one of enum values ACTIVE_ONLY, DELETED_ONLY, or ALL (e.g. in mlflow.entities.ViewType).
|
||||
max_results : int
|
||||
the maximum number of runs to put in the dataframe.
|
||||
order_by : list
|
||||
list of columns to order by (e.g., “metrics.rmse”).
|
||||
Get a pandas DataFrame of records that fit the search criteria of the experiment.
|
||||
Inputs are the search critera user want to apply.
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -78,7 +76,7 @@ class ExpManager:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `search_records` method.")
|
||||
|
||||
def create_exp(self, experiment_name, artifact_location=None):
|
||||
def create_exp(self, experiment_name=None, uri=None):
|
||||
"""
|
||||
Create an experiment.
|
||||
|
||||
@@ -86,8 +84,8 @@ class ExpManager:
|
||||
----------
|
||||
experiment_name : str
|
||||
the experiment name, which must be unique.
|
||||
artifact_location : str
|
||||
the location to store run artifacts.
|
||||
uri : str
|
||||
the tracking uri of the experiment.
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -97,14 +95,36 @@ class ExpManager:
|
||||
|
||||
def get_exp(self, experiment_id=None, experiment_name=None, create: bool = True):
|
||||
"""
|
||||
Retrieve an experiment by experiment_id from the backend store.
|
||||
Retrieve an experiment. When user specify experiment id and name, the method will try to return the
|
||||
specific experiment. When user does not provide recorder id or name, the method will try to return the current
|
||||
active experiment. The `create` argument determines whether the method will automatically create a new experiment
|
||||
according to user's specification if the experiment hasn't been created before
|
||||
|
||||
If `create` is True:
|
||||
If R's running:
|
||||
1) no id or name specified, return the active experiment.
|
||||
2) if id or name is specified, return the specified experiment. If no such exp found,
|
||||
create a new experiment with given id or name, and the experiment is set to be running.
|
||||
If R's not running:
|
||||
1) no id or name specified, create a default experiment.
|
||||
2) if id or name is specified, return the specified experiment. If no such exp found,
|
||||
create a new experiment with given id or name, and the experiment is set to be running.
|
||||
Else If `create` is False:
|
||||
If R's running:
|
||||
1) no id or name specified, return the active experiment.
|
||||
2) if id or name is specified, return the specified experiment. If no such exp found,
|
||||
raise Error.
|
||||
If R's not running:
|
||||
1) no id or name specified. If the default experiment exists, return it, otherwise, raise Error.
|
||||
2) if id or name is specified, return the specified experiment. If no such exp found,
|
||||
raise Error.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
experiment_id : str
|
||||
the experiment id to return.
|
||||
create : boolean
|
||||
create the experiment if it does not exists
|
||||
create the experiment if hasn't been created before.
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -153,28 +173,11 @@ class MLflowExpManager(ExpManager):
|
||||
|
||||
def __init__(self, uri, default_exp_name):
|
||||
super(MLflowExpManager, self).__init__(uri, default_exp_name)
|
||||
self._total_exps = 0
|
||||
# get all the exps
|
||||
self.experiments = self.list_experiments()
|
||||
|
||||
def start_exp(self, experiment_name=None, uri=None):
|
||||
# create experiment
|
||||
experiment = self.create_exp(experiment_name, uri)
|
||||
# set up active experiment
|
||||
self.active_experiment = experiment
|
||||
# start the experiment
|
||||
self.active_experiment.start()
|
||||
self._total_exps += 1 # update exp num
|
||||
|
||||
return self.active_experiment
|
||||
|
||||
def end_exp(self, status):
|
||||
if self.active_experiment is not None:
|
||||
self.active_experiment.end(status)
|
||||
self.active_experiment = None
|
||||
self._total_exps -= 1
|
||||
self.client = mlflow.tracking.MlflowClient(tracking_uri=self.uri)
|
||||
|
||||
def create_exp(self, experiment_name=None, uri=None):
|
||||
# retrieve all created experiments
|
||||
experiments = self.list_experiments()
|
||||
# set the tracking uri
|
||||
if uri is None:
|
||||
logger.info(
|
||||
@@ -188,29 +191,28 @@ class MLflowExpManager(ExpManager):
|
||||
logger.info(
|
||||
f"No experiment name provided. The default experiment name is set as `{self.default_exp_name}`."
|
||||
)
|
||||
experiment_id = mlflow.create_experiment(self.default_exp_name)
|
||||
if self.default_exp_name not in experiments:
|
||||
experiment_id = self.client.create_experiment(self.default_exp_name)
|
||||
else:
|
||||
experiment_id = self.client.get_experiment_by_name(self.default_exp_name).experiment_id
|
||||
# set the active experiment
|
||||
mlflow.set_experiment(self.default_exp_name)
|
||||
experiment_name = self.default_exp_name
|
||||
else:
|
||||
if experiment_name not in self.experiments:
|
||||
if mlflow.get_experiment_by_name(experiment_name) is not None:
|
||||
if experiment_name not in experiments:
|
||||
if self.client.get_experiment_by_name(experiment_name) is not None:
|
||||
logger.info(
|
||||
"The experiment has already been created before. Try to resume the experiment with a new recorder..."
|
||||
)
|
||||
experiment_id = mlflow.get_experiment_by_name(experiment_name).experiment_id
|
||||
experiment_id = self.client.get_experiment_by_name(experiment_name).experiment_id
|
||||
else:
|
||||
experiment_id = mlflow.create_experiment(experiment_name)
|
||||
experiment_id = self.client.create_experiment(experiment_name)
|
||||
else:
|
||||
experiment_id = self.experiments[experiment_name].id
|
||||
experiment = self.experiments[experiment_name]
|
||||
# set the active experiment
|
||||
mlflow.set_experiment(experiment_name)
|
||||
experiment_id = experiments[experiment_name].id
|
||||
experiment = experiments[experiment_name]
|
||||
# init experiment
|
||||
experiment = MLflowExperiment(experiment_id, experiment_name, self.uri)
|
||||
experiment._default_name = self.default_exp_name
|
||||
# store the experiment
|
||||
self.experiments[experiment_name] = experiment
|
||||
|
||||
return experiment
|
||||
|
||||
@@ -219,9 +221,11 @@ class MLflowExpManager(ExpManager):
|
||||
run_view_type = 1 if kwargs.get("run_view_type") is None else kwargs.get("run_view_type")
|
||||
max_results = 100000 if kwargs.get("max_results") is None else kwargs.get("max_results")
|
||||
order_by = kwargs.get("order_by")
|
||||
return mlflow.search_runs(experiment_ids, filter_string, run_view_type, max_results, order_by)
|
||||
return self.client.search_runs(experiment_ids, filter_string, run_view_type, max_results, order_by)
|
||||
|
||||
def get_exp(self, experiment_id=None, experiment_name=None, create=True):
|
||||
# retrive all created experiments
|
||||
experiments = self.list_experiments()
|
||||
if experiment_id is None and experiment_name is None:
|
||||
if self.active_experiment:
|
||||
return self.active_experiment
|
||||
@@ -230,13 +234,15 @@ class MLflowExpManager(ExpManager):
|
||||
logger.warning("QlibRecorder is not running. Use the Default experiment for further process.")
|
||||
return self.start_exp()
|
||||
else:
|
||||
if self.default_exp_name in experiments:
|
||||
return experiments[self.default_exp_name]
|
||||
raise Exception(
|
||||
"Something went wrong when retrieving experiments. Please check if QlibRecorder is running or the name/id of the experiment is correct."
|
||||
)
|
||||
else:
|
||||
if experiment_name is not None:
|
||||
if experiment_name in self.experiments:
|
||||
return self.experiments[experiment_name]
|
||||
if experiment_name in experiments:
|
||||
return experiments[experiment_name]
|
||||
else:
|
||||
if create:
|
||||
logger.warning(
|
||||
@@ -248,9 +254,9 @@ class MLflowExpManager(ExpManager):
|
||||
"Something went wrong when retrieving experiments. Please check if QlibRecorder is running or the name/id of the experiment is correct."
|
||||
)
|
||||
else:
|
||||
for name in self.experiments:
|
||||
if self.experiments[name].id == experiment_id:
|
||||
return self.experiments[name]
|
||||
for name in experiments:
|
||||
if experiments[name].id == experiment_id:
|
||||
return experiments[name]
|
||||
if create:
|
||||
logger.warning(f"No valid experiment found. Use the Default experiment for further process.")
|
||||
return self.start_exp()
|
||||
@@ -265,11 +271,10 @@ class MLflowExpManager(ExpManager):
|
||||
), "Please input a valid experiment id or name before deleting."
|
||||
try:
|
||||
if experiment_id is not None:
|
||||
mlflow.delete_experiment(experiment_id)
|
||||
self.experiments = {key: val for key, val in self.experiments.items() if val.id != experiment_id}
|
||||
self.client.delete_experiment(experiment_id)
|
||||
else:
|
||||
experiment_id = self.experiments[experiment_name].id
|
||||
mlflow.delete_experiment(experiment_id)
|
||||
experiment = self.client.get_experiment_by_name(experiment_name)
|
||||
self.client.delete_experiment(experiment.experiment_id)
|
||||
except:
|
||||
raise Exception(
|
||||
"Something went wrong when deleting experiment. Please check if the name/id of the experiment is correct."
|
||||
@@ -277,10 +282,8 @@ class MLflowExpManager(ExpManager):
|
||||
|
||||
def list_experiments(self):
|
||||
# retrieve all the existing experiments
|
||||
client = mlflow.tracking.MlflowClient(tracking_uri=self.uri)
|
||||
exps = client.list_experiments()
|
||||
exps = self.client.list_experiments(view_type=1)
|
||||
experiments = dict()
|
||||
self._total_exps = len(exps)
|
||||
for i in range(len(exps)):
|
||||
eid = exps[i].experiment_id
|
||||
ename = exps[i].name
|
||||
|
||||
@@ -8,6 +8,9 @@ from ..contrib.evaluate import (
|
||||
risk_analysis,
|
||||
)
|
||||
from ..utils import init_instance_by_config, get_module_by_module_path
|
||||
from ..log import get_module_logger
|
||||
|
||||
logger = get_module_logger("workflow", "INFO")
|
||||
|
||||
|
||||
class RecordTemp:
|
||||
@@ -76,7 +79,10 @@ class SignalRecord(RecordTemp):
|
||||
def generate(self, **kwargs):
|
||||
# generate prediciton
|
||||
pred = self.model.predict(self.dataset)
|
||||
self.recorder.save_objects(data=pred, name="pred.pkl")
|
||||
self.recorder.save_objects(**{"pred.pkl": pred})
|
||||
logger.info(
|
||||
f"Signal record 'pred.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}"
|
||||
)
|
||||
|
||||
def load(self):
|
||||
# try to load the saved object
|
||||
@@ -133,8 +139,8 @@ class PortAnaRecord(SignalRecord):
|
||||
# custom strategy and get backtest
|
||||
pred_score = super().load()
|
||||
report_normal, positions_normal = normal_backtest(pred_score, strategy=self.strategy, **self.backtest_config)
|
||||
self.recorder.save_objects(data=report_normal, name="report_normal.pkl", artifact_path=self.artifact_path)
|
||||
self.recorder.save_objects(data=positions_normal, name="positions_normal.pkl", artifact_path=self.artifact_path)
|
||||
self.recorder.save_objects(**{"report_normal.pkl": report_normal}, artifact_path=self.artifact_path)
|
||||
self.recorder.save_objects(**{"positions_normal.pkl": positions_normal}, artifact_path=self.artifact_path)
|
||||
|
||||
# analysis
|
||||
analysis = dict()
|
||||
@@ -143,7 +149,10 @@ class PortAnaRecord(SignalRecord):
|
||||
report_normal["return"] - report_normal["bench"] - report_normal["cost"]
|
||||
)
|
||||
analysis_df = pd.concat(analysis) # type: pd.DataFrame
|
||||
self.recorder.save_objects(data=analysis_df, name="port_analysis.pkl", artifact_path=self.artifact_path)
|
||||
self.recorder.save_objects(**{"port_analysis.pkl": analysis_df}, artifact_path=self.artifact_path)
|
||||
logger.info(
|
||||
f"Portfolio analysis record 'port_analysis.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}"
|
||||
)
|
||||
|
||||
def load(self):
|
||||
# try to load the saved object
|
||||
|
||||
@@ -5,6 +5,9 @@ import mlflow
|
||||
import shutil, os, pickle, tempfile, codecs, datetime
|
||||
from pathlib import Path
|
||||
from ..utils.objm import FileManager
|
||||
from ..log import get_module_logger
|
||||
|
||||
logger = get_module_logger("workflow", "INFO")
|
||||
|
||||
|
||||
class Recorder:
|
||||
@@ -15,13 +18,19 @@ class Recorder:
|
||||
The status of the recorder can be SCHEDULED, RUNNING, FINISHED, FAILED.
|
||||
"""
|
||||
|
||||
# status type
|
||||
STATUS_S = "SCHEDULED"
|
||||
STATUS_R = "RUNNING"
|
||||
STATUS_FI = "FINISHED"
|
||||
STATUS_FA = "FAILED"
|
||||
|
||||
def __init__(self, name, experiment_id):
|
||||
self.id = None
|
||||
self.name = name
|
||||
self.experiment_id = experiment_id
|
||||
self.start_time = None
|
||||
self.end_time = None
|
||||
self.status = "SCHEDULED"
|
||||
self.status = Recorder.STATUS_S
|
||||
|
||||
def __repr__(self):
|
||||
return str(self.info)
|
||||
@@ -46,16 +55,11 @@ class Recorder:
|
||||
|
||||
def save_objects(self, local_path=None, artifact_path=None, **kwargs):
|
||||
"""
|
||||
Save objects such as prediction file or model checkpoints to the artifact URI.
|
||||
Save objects such as prediction file or model checkpoints to the artifact URI. User
|
||||
can save object through keywords arguments (name:value).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data : any type
|
||||
the data to be saved.
|
||||
name : str
|
||||
name of the file to be saved.
|
||||
data_name_list : list
|
||||
list of (data, name) pairs
|
||||
local_path : str
|
||||
if provided, them save the file or directory to the artifact URI.
|
||||
artifact_path=None : str
|
||||
@@ -170,6 +174,7 @@ class MLflowRecorder(Recorder):
|
||||
# set up file manager for saving objects
|
||||
self.temp_dir = tempfile.mkdtemp()
|
||||
self.fm = FileManager(Path(self.temp_dir).absolute())
|
||||
self.client = mlflow.tracking.MlflowClient(tracking_uri=self._uri)
|
||||
|
||||
def start_run(self):
|
||||
# start the run
|
||||
@@ -178,38 +183,36 @@ class MLflowRecorder(Recorder):
|
||||
self.id = run.info.run_id
|
||||
self.artifact_uri = run.info.artifact_uri
|
||||
self.start_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
self.status = "RUNNING"
|
||||
self.status = Recorder.STATUS_R
|
||||
logger.info(f"Recorder {self.id} starts running under Experiment {self.experiment_id} ...")
|
||||
|
||||
return run
|
||||
|
||||
def end_run(self, status):
|
||||
assert status in ["SCHEDULED", "RUNNING", "FINISHED", "FAILED"], f"The status type {status} is not supported."
|
||||
def end_run(self, status: str = Recorder.STATUS_S):
|
||||
assert status in [
|
||||
Recorder.STATUS_S,
|
||||
Recorder.STATUS_R,
|
||||
Recorder.STATUS_FI,
|
||||
Recorder.STATUS_FA,
|
||||
], f"The status type {status} is not supported."
|
||||
mlflow.end_run(status)
|
||||
self.end_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
if self.status != "FINISHED":
|
||||
if self.status != Recorder.STATUS_S:
|
||||
self.status = status
|
||||
shutil.rmtree(self.temp_dir)
|
||||
|
||||
def save_objects(self, data_name_list=None, local_path=None, artifact_path=None, **kwargs):
|
||||
def save_objects(self, local_path=None, artifact_path=None, **kwargs):
|
||||
assert self._uri is not None, "Please start the experiment and recorder first before using recorder directly."
|
||||
client = mlflow.tracking.MlflowClient(tracking_uri=self._uri)
|
||||
if local_path is not None:
|
||||
client.log_artifacts(self.id, local_path, artifact_path)
|
||||
elif kwargs.get("data") is not None and kwargs.get("name") is not None:
|
||||
data, name = kwargs.get("data"), kwargs.get("name")
|
||||
self.fm.save_obj(data, name)
|
||||
client.log_artifact(self.id, self.fm.path / name, artifact_path)
|
||||
elif kwargs.get("data_name_list") is not None:
|
||||
data_name_list = kwargs.get("data_name_list")
|
||||
self.fm.save_objs(data_name_list)
|
||||
client.log_artifacts(self.id, self.fm.path, artifact_path)
|
||||
self.client.log_artifacts(self.id, local_path, artifact_path)
|
||||
else:
|
||||
raise Exception("Please provide valid arguments in order to save object properly.")
|
||||
for name, data in kwargs.items():
|
||||
self.fm.save_obj(data, name)
|
||||
self.client.log_artifact(self.id, self.fm.path / name, artifact_path)
|
||||
|
||||
def load_object(self, name):
|
||||
assert self._uri is not None, "Please start the experiment and recorder first before using recorder directly."
|
||||
client = mlflow.tracking.MlflowClient(tracking_uri=self._uri)
|
||||
path = client.download_artifacts(self.id, name)
|
||||
path = self.client.download_artifacts(self.id, name)
|
||||
try:
|
||||
with Path(path).open("rb") as f:
|
||||
f.seek(0)
|
||||
@@ -220,28 +223,22 @@ class MLflowRecorder(Recorder):
|
||||
|
||||
def log_params(self, **kwargs):
|
||||
keys = list(kwargs.keys())
|
||||
if len(keys) == 0:
|
||||
mlflow.log_param(keys[0], kwargs.get(keys[0]))
|
||||
else:
|
||||
mlflow.log_params(dict(kwargs))
|
||||
for name, data in kwargs.items():
|
||||
self.client.log_param(self.id, name, data)
|
||||
|
||||
def log_metrics(self, step=None, **kwargs):
|
||||
keys = list(kwargs.keys())
|
||||
if len(keys) == 0:
|
||||
mlflow.log_metric(keys[0], kwargs.get(keys[0]))
|
||||
else:
|
||||
mlflow.log_metrics(dict(kwargs))
|
||||
for name, data in kwargs.items():
|
||||
self.client.log_metric(self.id, name, data)
|
||||
|
||||
def set_tags(self, **kwargs):
|
||||
keys = list(kwargs.keys())
|
||||
if len(keys) == 0:
|
||||
mlflow.set_tag(keys[0], kwargs.get(keys[0]))
|
||||
else:
|
||||
mlflow.set_tags(dict(kwargs))
|
||||
for name, data in kwargs.items():
|
||||
self.client.set_tag(self.id, name, data)
|
||||
|
||||
def delete_tags(self, *keys):
|
||||
for count, key in enumerate(keys):
|
||||
mlflow.delete_tag(key)
|
||||
self.client.delete_tag(self.id, key)
|
||||
|
||||
def get_artifact_uri(self):
|
||||
if self.artifact_uri is not None:
|
||||
@@ -253,6 +250,5 @@ class MLflowRecorder(Recorder):
|
||||
|
||||
def list_artifacts(self, artifact_path=None):
|
||||
assert self._uri is not None, "Please start the experiment and recorder first before using recorder directly."
|
||||
client = mlflow.tracking.MlflowClient(tracking_uri=self._uri)
|
||||
artifacts = client.list_artifacts(self.id, artifact_path)
|
||||
artifacts = self.client.list_artifacts(self.id, artifact_path)
|
||||
return artifacts
|
||||
|
||||
33
qlib/workflow/utils.py
Normal file
33
qlib/workflow/utils.py
Normal file
@@ -0,0 +1,33 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import sys, traceback, signal
|
||||
from . import R
|
||||
from .recorder import Recorder
|
||||
from ..log import get_module_logger
|
||||
|
||||
logger = get_module_logger("workflow", "INFO")
|
||||
|
||||
|
||||
def experiment_exception_hook(type, value, tb):
|
||||
"""
|
||||
End an experiment with status to be "FAILED". This exception tries to catch those uncaught exception
|
||||
and end the experiment automatically.
|
||||
|
||||
Parameters
|
||||
type: Exception type
|
||||
value: Exception's value
|
||||
tb: Exception's traceback
|
||||
"""
|
||||
error_msg = "An exception has been raised.\n" f"Type: {type}\n" f"Value: {value}\n"
|
||||
logger.error(error_msg)
|
||||
traceback.print_tb(tb)
|
||||
|
||||
R.end_exp(recorder_status=Recorder.STATUS_FA)
|
||||
|
||||
|
||||
def experiment_kill_signal_handler(signum, frame):
|
||||
"""
|
||||
End an experiment when user kill the program (CTRL+C, etc.).
|
||||
"""
|
||||
R.end_exp(recorder_status=Recorder.STATUS_FA)
|
||||
@@ -137,5 +137,5 @@ def symbol_prefix_to_sufix(symbol: str, capital: bool = True) -> str:
|
||||
return res.upper() if capital else res.lower()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
assert len(get_hs_stock_symbols()) >= MINIMUM_SYMBOLS_NUM
|
||||
|
||||
Reference in New Issue
Block a user