diff --git a/qlib/contrib/workflow/record_temp.py b/qlib/contrib/workflow/record_temp.py index e7c80cf6e..8d10b2ab4 100644 --- a/qlib/contrib/workflow/record_temp.py +++ b/qlib/contrib/workflow/record_temp.py @@ -49,7 +49,7 @@ class MultiSegRecord(RecordTemp): if save: save_name = "results-{:}.pkl".format(key) - self.recorder.save_objects(**{save_name: results}) + self.save(**{save_name: results}) logger.info( "The record '{:}' has been saved as the artifact of the Experiment {:}".format( save_name, self.recorder.experiment_id @@ -79,9 +79,8 @@ class SignalMseRecord(RecordTemp): metrics = {"MSE": mse, "RMSE": np.sqrt(mse)} objects = {"mse.pkl": mse, "rmse.pkl": np.sqrt(mse)} self.recorder.log_metrics(**metrics) - self.recorder.save_objects(**objects, artifact_path=self.get_path()) + self.save(**objects) logger.info("The evaluation results in SignalMseRecord is {:}".format(metrics)) def list(self): - paths = [self.get_path("mse.pkl"), self.get_path("rmse.pkl")] - return paths + return ["mse.pkl", "rmse.pkl"] diff --git a/qlib/workflow/record_temp.py b/qlib/workflow/record_temp.py index 0d85311ee..07422243d 100644 --- a/qlib/workflow/record_temp.py +++ b/qlib/workflow/record_temp.py @@ -9,6 +9,9 @@ import pandas as pd from pathlib import Path from pprint import pprint from typing import Union, List +from collections import defaultdict + +from qlib.utils.exceptions import LoadObjectError from ..contrib.evaluate import indicator_analysis, risk_analysis, indicator_analysis from ..data.dataset import DatasetH @@ -45,6 +48,16 @@ class RecordTemp: return "/".join(names) + def save(self, **kwargs): + """ + It behaves the same as self.recorder.save_objects. 
+ But it is an easier interface because users don't have to care about `get_path` and `artifact_path` + """ + art_path = self.get_path() + if art_path == "": + art_path = None + self.recorder.save_objects(artifact_path=art_path, **kwargs) + def __init__(self, recorder): self._recorder = recorder @@ -67,31 +80,37 @@ class RecordTemp: """ raise NotImplementedError(f"Please implement the `generate` method.") - def load(self, name): + def load(self, name: str, parents: bool = True): """ - Load the stored records. Due to the fact that some problems occured when we tried to balancing a clean API - with the Python's inheritance. This method has to be used in a rather ugly way, and we will try to fix them - in the future:: - - sar = SigAnaRecord(recorder) - ic = sar.load(sar.get_path("ic.pkl")) + It behaves the same as self.recorder.load_object. + But it is an easier interface because users don't have to care about `get_path` and `artifact_path` Parameters ---------- name : str the name for the file to be load. + parents : bool + Each recorder has a different `artifact_path`. + If True, recursively search the parent record classes for the path. + Subclasses have higher priority. + Return ------ The stored records. """ - # try to load the saved object - obj = self.recorder.load_object(name) - return obj + try: + return self.recorder.load_object(self.get_path(name)) + except LoadObjectError: + if parents: + if self.depend_cls is not None: + with class_casting(self, self.depend_cls): + return self.load(name, parents=True) def list(self): """ List the supported artifacts. + Users don't have to consider self.get_path Return ------ @@ -99,7 +118,7 @@ class RecordTemp: """ return [] - def check(self, include_self: bool = False): + def check(self, include_self: bool = False, parents: bool = True): """ Check if the records is properly generated and saved. 
+ It is useful in following examples @@ -110,19 +129,34 @@ class RecordTemp: ---------- include_self : bool is the file generated by self included + parents : bool + will we check parents Raise ------ - FileExistsError: whether the records are stored properly. + FileNotFoundError + : whether the records are stored properly. """ - artifacts = set(self.recorder.list_artifacts()) if include_self: + + # Some mlflow backends will not list the directory recursively. + # So we list each directory explicitly + artifacts = {} + + def _get_arts(dirn): + if dirn not in artifacts: + artifacts[dirn] = self.recorder.list_artifacts(dirn) + return artifacts[dirn] + for item in self.list(): - if item not in artifacts: - raise FileExistsError(item) - if self.depend_cls is not None: - with class_casting(self, self.depend_cls): - self.check(include_self=True) + ps = self.get_path(item).split("/") + dirn, fn = "/".join(ps[:-1]), ps[-1] + if self.get_path(item) not in _get_arts(dirn): + raise FileNotFoundError + if parents: + if self.depend_cls is not None: + with class_casting(self, self.depend_cls): + self.check(include_self=True) class SignalRecord(RecordTemp): @@ -158,7 +192,7 @@ class SignalRecord(RecordTemp): pred = self.model.predict(self.dataset) if isinstance(pred, pd.Series): pred = pred.to_frame("score") - self.recorder.save_objects(**{"pred.pkl": pred}) + self.save(**{"pred.pkl": pred}) logger.info( f"Signal record 'pred.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}" @@ -169,15 +203,11 @@ class SignalRecord(RecordTemp): if isinstance(self.dataset, DatasetH): raw_label = self.generate_label(self.dataset) - self.recorder.save_objects(**{"label.pkl": raw_label}) + self.save(**{"label.pkl": raw_label}) - @staticmethod - def list(): + def list(self): return ["pred.pkl", "label.pkl"] - def load(self, name="pred.pkl"): - return super().load(name) - class HFSignalRecord(SignalRecord): """ @@ -218,19 +248,11 @@ class HFSignalRecord(SignalRecord): } ) 
self.recorder.log_metrics(**metrics) - self.recorder.save_objects(**objects, artifact_path=self.get_path()) + self.save(**objects) pprint(metrics) def list(self): - paths = [ - self.get_path("ic.pkl"), - self.get_path("ric.pkl"), - self.get_path("long_pre.pkl"), - self.get_path("short_pre.pkl"), - self.get_path("long_short_r.pkl"), - self.get_path("long_avg_r.pkl"), - ] - return paths + return ["ic.pkl", "ric.pkl", "long_pre.pkl", "short_pre.pkl", "long_short_r.pkl", "long_avg_r.pkl"] class SigAnaRecord(RecordTemp): @@ -241,13 +263,23 @@ class SigAnaRecord(RecordTemp): artifact_path = "sig_analysis" depend_cls = SignalRecord - def __init__(self, recorder, ana_long_short=False, ann_scaler=252, label_col=0): + def __init__(self, recorder, ana_long_short=False, ann_scaler=252, label_col=0, skip_existing=False): super().__init__(recorder=recorder) self.ana_long_short = ana_long_short self.ann_scaler = ann_scaler self.label_col = label_col + self.skip_existing = skip_existing def generate(self, **kwargs): + if self.skip_existing: + try: + self.check(include_self=True, parents=False) + except FileNotFoundError: + pass # continue to generating metrics + else: + logger.info("The results has previously generated, generation skipped.") + return + self.check() pred = self.load("pred.pkl") @@ -280,13 +312,13 @@ class SigAnaRecord(RecordTemp): } ) self.recorder.log_metrics(**metrics) - self.recorder.save_objects(**objects, artifact_path=self.get_path()) + self.save(**objects) pprint(metrics) def list(self): - paths = [self.get_path("ic.pkl"), self.get_path("ric.pkl")] + paths = ["ic.pkl", "ric.pkl"] if self.ana_long_short: - paths.extend([self.get_path("long_short_r.pkl"), self.get_path("long_avg_r.pkl")]) + paths.extend(["long_short_r.pkl", "long_avg_r.pkl"]) return paths @@ -373,17 +405,11 @@ class PortAnaRecord(RecordTemp): executor=self.executor_config, strategy=self.strategy_config, **self.backtest_config ) for _freq, (report_normal, positions_normal) in 
portfolio_metric_dict.items(): - self.recorder.save_objects( - **{f"report_normal_{_freq}.pkl": report_normal}, artifact_path=PortAnaRecord.get_path() - ) - self.recorder.save_objects( - **{f"positions_normal_{_freq}.pkl": positions_normal}, artifact_path=PortAnaRecord.get_path() - ) + self.save(**{f"report_normal_{_freq}.pkl": report_normal}) + self.save(**{f"positions_normal_{_freq}.pkl": positions_normal}) for _freq, indicators_normal in indicator_dict.items(): - self.recorder.save_objects( - **{f"indicators_normal_{_freq}.pkl": indicators_normal}, artifact_path=PortAnaRecord.get_path() - ) + self.save(**{f"indicators_normal_{_freq}.pkl": indicators_normal}) for _analysis_freq in self.risk_analysis_freq: if _analysis_freq not in portfolio_metric_dict: @@ -405,9 +431,7 @@ class PortAnaRecord(RecordTemp): analysis_dict = flatten_dict(analysis_df["risk"].unstack().T.to_dict()) self.recorder.log_metrics(**{f"{_analysis_freq}.{k}": v for k, v in analysis_dict.items()}) # save results - self.recorder.save_objects( - **{f"port_analysis_{_analysis_freq}.pkl": analysis_df}, artifact_path=PortAnaRecord.get_path() - ) + self.save(**{f"port_analysis_{_analysis_freq}.pkl": analysis_df}) logger.info( f"Portfolio analysis record 'port_analysis_{_analysis_freq}.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}" ) @@ -432,9 +456,7 @@ class PortAnaRecord(RecordTemp): analysis_dict = analysis_df["value"].to_dict() self.recorder.log_metrics(**{f"{_analysis_freq}.{k}": v for k, v in analysis_dict.items()}) # save results - self.recorder.save_objects( - **{f"indicator_analysis_{_analysis_freq}.pkl": analysis_df}, artifact_path=PortAnaRecord.get_path() - ) + self.save(**{f"indicator_analysis_{_analysis_freq}.pkl": analysis_df}) logger.info( f"Indicator analysis record 'indicator_analysis_{_analysis_freq}.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}" ) @@ -446,20 +468,19 @@ class PortAnaRecord(RecordTemp): for 
_freq in self.all_freq: list_path.extend( [ - PortAnaRecord.get_path(f"report_normal_{_freq}.pkl"), - PortAnaRecord.get_path(f"positions_normal_{_freq}.pkl"), + f"report_normal_{_freq}.pkl", + f"positions_normal_{_freq}.pkl", ] ) for _analysis_freq in self.risk_analysis_freq: if _analysis_freq in self.all_freq: - list_path.append(PortAnaRecord.get_path(f"port_analysis_{_analysis_freq}.pkl")) + list_path.append(f"port_analysis_{_analysis_freq}.pkl") else: warnings.warn(f"risk_analysis freq {_analysis_freq} is not found") for _analysis_freq in self.indicator_analysis_freq: if _analysis_freq in self.all_freq: - list_path.append(PortAnaRecord.get_path(f"indicator_analysis_{_analysis_freq}.pkl")) + list_path.append(f"indicator_analysis_{_analysis_freq}.pkl") else: warnings.warn(f"indicator_analysis freq {_analysis_freq} is not found") - return list_path diff --git a/tests/test_all_pipeline.py b/tests/test_all_pipeline.py index 24c6765aa..de15d8722 100644 --- a/tests/test_all_pipeline.py +++ b/tests/test_all_pipeline.py @@ -47,13 +47,13 @@ def train(uri_path: str = None): rid = recorder.id sr = SignalRecord(model, dataset, recorder) sr.generate() - pred_score = sr.load(sr.get_path("pred.pkl")) + pred_score = sr.load("pred.pkl") # calculate ic and ric sar = SigAnaRecord(recorder) sar.generate() - ic = sar.load(sar.get_path("ic.pkl")) - ric = sar.load(sar.get_path("ric.pkl")) + ic = sar.load("ic.pkl") + ric = sar.load("ric.pkl") return pred_score, {"ic": ic, "ric": ric}, rid @@ -78,13 +78,13 @@ def train_with_sigana(uri_path: str = None): sr = SignalRecord(model, dataset, recorder) sr.generate() - pred_score = sr.load(sr.get_path("pred.pkl")) + pred_score = sr.load("pred.pkl") # predict and calculate ic and ric sar = SigAnaRecord(recorder) sar.generate() - ic = sar.load(sar.get_path("ic.pkl")) - ric = sar.load(sar.get_path("ric.pkl")) + ic = sar.load("ic.pkl") + ric = sar.load("ric.pkl") uri_path = R.get_uri() return pred_score, {"ic": ic, "ric": ric}, uri_path @@ -169,7 
+169,7 @@ def backtest_analysis(pred, rid, uri_path: str = None): # backtest par = PortAnaRecord(recorder, port_analysis_config, risk_analysis_freq="day") par.generate() - analysis_df = par.load(par.get_path("port_analysis_1day.pkl")) + analysis_df = par.load("port_analysis_1day.pkl") print(analysis_df) return analysis_df