mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-03 11:00:57 +08:00
simplify record tmp
This commit is contained in:
@@ -49,7 +49,7 @@ class MultiSegRecord(RecordTemp):
|
||||
|
||||
if save:
|
||||
save_name = "results-{:}.pkl".format(key)
|
||||
self.recorder.save_objects(**{save_name: results})
|
||||
self.save(**{save_name: results})
|
||||
logger.info(
|
||||
"The record '{:}' has been saved as the artifact of the Experiment {:}".format(
|
||||
save_name, self.recorder.experiment_id
|
||||
@@ -79,9 +79,8 @@ class SignalMseRecord(RecordTemp):
|
||||
metrics = {"MSE": mse, "RMSE": np.sqrt(mse)}
|
||||
objects = {"mse.pkl": mse, "rmse.pkl": np.sqrt(mse)}
|
||||
self.recorder.log_metrics(**metrics)
|
||||
self.recorder.save_objects(**objects, artifact_path=self.get_path())
|
||||
self.save(**objects)
|
||||
logger.info("The evaluation results in SignalMseRecord is {:}".format(metrics))
|
||||
|
||||
def list(self):
|
||||
paths = [self.get_path("mse.pkl"), self.get_path("rmse.pkl")]
|
||||
return paths
|
||||
return ["mse.pkl", "rmse.pkl"]
|
||||
|
||||
@@ -9,6 +9,9 @@ import pandas as pd
|
||||
from pathlib import Path
|
||||
from pprint import pprint
|
||||
from typing import Union, List
|
||||
from collections import defaultdict
|
||||
|
||||
from qlib.utils.exceptions import LoadObjectError
|
||||
from ..contrib.evaluate import indicator_analysis, risk_analysis, indicator_analysis
|
||||
|
||||
from ..data.dataset import DatasetH
|
||||
@@ -45,6 +48,16 @@ class RecordTemp:
|
||||
|
||||
return "/".join(names)
|
||||
|
||||
def save(self, **kwargs):
|
||||
"""
|
||||
It behaves the same as self.recorder.save_objects.
|
||||
But it is an easier interface because users don't have to care about `get_path` and `artifact_path`
|
||||
"""
|
||||
art_path = self.get_path()
|
||||
if art_path == "":
|
||||
art_path = None
|
||||
self.recorder.save_objects(artifact_path=art_path, **kwargs)
|
||||
|
||||
def __init__(self, recorder):
|
||||
self._recorder = recorder
|
||||
|
||||
@@ -67,31 +80,37 @@ class RecordTemp:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `generate` method.")
|
||||
|
||||
def load(self, name):
|
||||
def load(self, name: str, parents: bool = True):
|
||||
"""
|
||||
Load the stored records. Due to the fact that some problems occured when we tried to balancing a clean API
|
||||
with the Python's inheritance. This method has to be used in a rather ugly way, and we will try to fix them
|
||||
in the future::
|
||||
|
||||
sar = SigAnaRecord(recorder)
|
||||
ic = sar.load(sar.get_path("ic.pkl"))
|
||||
It behaves the same as self.recorder.load_object.
|
||||
But it is an easier interface because users don't have to care about `get_path` and `artifact_path`
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name : str
|
||||
the name for the file to be load.
|
||||
|
||||
parents : bool
|
||||
Each recorder has different `artifact_path`.
|
||||
So parents recursively find the path in parents
|
||||
Sub classes has higher priority
|
||||
|
||||
Return
|
||||
------
|
||||
The stored records.
|
||||
"""
|
||||
# try to load the saved object
|
||||
obj = self.recorder.load_object(name)
|
||||
return obj
|
||||
try:
|
||||
return self.recorder.load_object(self.get_path(name))
|
||||
except LoadObjectError:
|
||||
if parents:
|
||||
if self.depend_cls is not None:
|
||||
with class_casting(self, self.depend_cls):
|
||||
return self.load(name, parents=True)
|
||||
|
||||
def list(self):
|
||||
"""
|
||||
List the supported artifacts.
|
||||
Users don't have to consider self.get_path
|
||||
|
||||
Return
|
||||
------
|
||||
@@ -99,7 +118,7 @@ class RecordTemp:
|
||||
"""
|
||||
return []
|
||||
|
||||
def check(self, include_self: bool = False):
|
||||
def check(self, include_self: bool = False, parents: bool = True):
|
||||
"""
|
||||
Check if the records is properly generated and saved.
|
||||
It is useful in following examples
|
||||
@@ -110,19 +129,34 @@ class RecordTemp:
|
||||
----------
|
||||
include_self : bool
|
||||
is the file generated by self included
|
||||
parents : bool
|
||||
will we check parents
|
||||
|
||||
Raise
|
||||
------
|
||||
FileExistsError: whether the records are stored properly.
|
||||
FileNotFoundError
|
||||
: whether the records are stored properly.
|
||||
"""
|
||||
artifacts = set(self.recorder.list_artifacts())
|
||||
if include_self:
|
||||
|
||||
# Some mlflow backend will not list the directly recursively.
|
||||
# So we force to the directly
|
||||
artifacts = {}
|
||||
|
||||
def _get_arts(dirn):
|
||||
if dirn not in artifacts:
|
||||
artifacts[dirn] = self.recorder.list_artifacts(dirn)
|
||||
return artifacts[dirn]
|
||||
|
||||
for item in self.list():
|
||||
if item not in artifacts:
|
||||
raise FileExistsError(item)
|
||||
if self.depend_cls is not None:
|
||||
with class_casting(self, self.depend_cls):
|
||||
self.check(include_self=True)
|
||||
ps = self.get_path(item).split("/")
|
||||
dirn, fn = "/".join(ps[:-1]), ps[-1]
|
||||
if self.get_path(item) not in _get_arts(dirn):
|
||||
raise FileNotFoundError
|
||||
if parents:
|
||||
if self.depend_cls is not None:
|
||||
with class_casting(self, self.depend_cls):
|
||||
self.check(include_self=True)
|
||||
|
||||
|
||||
class SignalRecord(RecordTemp):
|
||||
@@ -158,7 +192,7 @@ class SignalRecord(RecordTemp):
|
||||
pred = self.model.predict(self.dataset)
|
||||
if isinstance(pred, pd.Series):
|
||||
pred = pred.to_frame("score")
|
||||
self.recorder.save_objects(**{"pred.pkl": pred})
|
||||
self.save(**{"pred.pkl": pred})
|
||||
|
||||
logger.info(
|
||||
f"Signal record 'pred.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}"
|
||||
@@ -169,15 +203,11 @@ class SignalRecord(RecordTemp):
|
||||
|
||||
if isinstance(self.dataset, DatasetH):
|
||||
raw_label = self.generate_label(self.dataset)
|
||||
self.recorder.save_objects(**{"label.pkl": raw_label})
|
||||
self.save(**{"label.pkl": raw_label})
|
||||
|
||||
@staticmethod
|
||||
def list():
|
||||
def list(self):
|
||||
return ["pred.pkl", "label.pkl"]
|
||||
|
||||
def load(self, name="pred.pkl"):
|
||||
return super().load(name)
|
||||
|
||||
|
||||
class HFSignalRecord(SignalRecord):
|
||||
"""
|
||||
@@ -218,19 +248,11 @@ class HFSignalRecord(SignalRecord):
|
||||
}
|
||||
)
|
||||
self.recorder.log_metrics(**metrics)
|
||||
self.recorder.save_objects(**objects, artifact_path=self.get_path())
|
||||
self.save(**objects)
|
||||
pprint(metrics)
|
||||
|
||||
def list(self):
|
||||
paths = [
|
||||
self.get_path("ic.pkl"),
|
||||
self.get_path("ric.pkl"),
|
||||
self.get_path("long_pre.pkl"),
|
||||
self.get_path("short_pre.pkl"),
|
||||
self.get_path("long_short_r.pkl"),
|
||||
self.get_path("long_avg_r.pkl"),
|
||||
]
|
||||
return paths
|
||||
return ["ic.pkl", "ric.pkl", "long_pre.pkl", "short_pre.pkl", "long_short_r.pkl", "long_avg_r.pkl"]
|
||||
|
||||
|
||||
class SigAnaRecord(RecordTemp):
|
||||
@@ -241,13 +263,23 @@ class SigAnaRecord(RecordTemp):
|
||||
artifact_path = "sig_analysis"
|
||||
depend_cls = SignalRecord
|
||||
|
||||
def __init__(self, recorder, ana_long_short=False, ann_scaler=252, label_col=0):
|
||||
def __init__(self, recorder, ana_long_short=False, ann_scaler=252, label_col=0, skip_existing=False):
|
||||
super().__init__(recorder=recorder)
|
||||
self.ana_long_short = ana_long_short
|
||||
self.ann_scaler = ann_scaler
|
||||
self.label_col = label_col
|
||||
self.skip_existing = skip_existing
|
||||
|
||||
def generate(self, **kwargs):
|
||||
if self.skip_existing:
|
||||
try:
|
||||
self.check(include_self=True, parents=False)
|
||||
except FileNotFoundError:
|
||||
pass # continue to generating metrics
|
||||
else:
|
||||
logger.info("The results has previously generated, generation skipped.")
|
||||
return
|
||||
|
||||
self.check()
|
||||
|
||||
pred = self.load("pred.pkl")
|
||||
@@ -280,13 +312,13 @@ class SigAnaRecord(RecordTemp):
|
||||
}
|
||||
)
|
||||
self.recorder.log_metrics(**metrics)
|
||||
self.recorder.save_objects(**objects, artifact_path=self.get_path())
|
||||
self.save(**objects)
|
||||
pprint(metrics)
|
||||
|
||||
def list(self):
|
||||
paths = [self.get_path("ic.pkl"), self.get_path("ric.pkl")]
|
||||
paths = ["ic.pkl", "ric.pkl"]
|
||||
if self.ana_long_short:
|
||||
paths.extend([self.get_path("long_short_r.pkl"), self.get_path("long_avg_r.pkl")])
|
||||
paths.extend(["long_short_r.pkl", "long_avg_r.pkl"])
|
||||
return paths
|
||||
|
||||
|
||||
@@ -373,17 +405,11 @@ class PortAnaRecord(RecordTemp):
|
||||
executor=self.executor_config, strategy=self.strategy_config, **self.backtest_config
|
||||
)
|
||||
for _freq, (report_normal, positions_normal) in portfolio_metric_dict.items():
|
||||
self.recorder.save_objects(
|
||||
**{f"report_normal_{_freq}.pkl": report_normal}, artifact_path=PortAnaRecord.get_path()
|
||||
)
|
||||
self.recorder.save_objects(
|
||||
**{f"positions_normal_{_freq}.pkl": positions_normal}, artifact_path=PortAnaRecord.get_path()
|
||||
)
|
||||
self.save(**{f"report_normal_{_freq}.pkl": report_normal})
|
||||
self.save(**{f"positions_normal_{_freq}.pkl": positions_normal})
|
||||
|
||||
for _freq, indicators_normal in indicator_dict.items():
|
||||
self.recorder.save_objects(
|
||||
**{f"indicators_normal_{_freq}.pkl": indicators_normal}, artifact_path=PortAnaRecord.get_path()
|
||||
)
|
||||
self.save(**{f"indicators_normal_{_freq}.pkl": indicators_normal})
|
||||
|
||||
for _analysis_freq in self.risk_analysis_freq:
|
||||
if _analysis_freq not in portfolio_metric_dict:
|
||||
@@ -405,9 +431,7 @@ class PortAnaRecord(RecordTemp):
|
||||
analysis_dict = flatten_dict(analysis_df["risk"].unstack().T.to_dict())
|
||||
self.recorder.log_metrics(**{f"{_analysis_freq}.{k}": v for k, v in analysis_dict.items()})
|
||||
# save results
|
||||
self.recorder.save_objects(
|
||||
**{f"port_analysis_{_analysis_freq}.pkl": analysis_df}, artifact_path=PortAnaRecord.get_path()
|
||||
)
|
||||
self.save(**{f"port_analysis_{_analysis_freq}.pkl": analysis_df})
|
||||
logger.info(
|
||||
f"Portfolio analysis record 'port_analysis_{_analysis_freq}.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}"
|
||||
)
|
||||
@@ -432,9 +456,7 @@ class PortAnaRecord(RecordTemp):
|
||||
analysis_dict = analysis_df["value"].to_dict()
|
||||
self.recorder.log_metrics(**{f"{_analysis_freq}.{k}": v for k, v in analysis_dict.items()})
|
||||
# save results
|
||||
self.recorder.save_objects(
|
||||
**{f"indicator_analysis_{_analysis_freq}.pkl": analysis_df}, artifact_path=PortAnaRecord.get_path()
|
||||
)
|
||||
self.save(**{f"indicator_analysis_{_analysis_freq}.pkl": analysis_df})
|
||||
logger.info(
|
||||
f"Indicator analysis record 'indicator_analysis_{_analysis_freq}.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}"
|
||||
)
|
||||
@@ -446,20 +468,19 @@ class PortAnaRecord(RecordTemp):
|
||||
for _freq in self.all_freq:
|
||||
list_path.extend(
|
||||
[
|
||||
PortAnaRecord.get_path(f"report_normal_{_freq}.pkl"),
|
||||
PortAnaRecord.get_path(f"positions_normal_{_freq}.pkl"),
|
||||
f"report_normal_{_freq}.pkl",
|
||||
f"positions_normal_{_freq}.pkl",
|
||||
]
|
||||
)
|
||||
for _analysis_freq in self.risk_analysis_freq:
|
||||
if _analysis_freq in self.all_freq:
|
||||
list_path.append(PortAnaRecord.get_path(f"port_analysis_{_analysis_freq}.pkl"))
|
||||
list_path.append(f"port_analysis_{_analysis_freq}.pkl")
|
||||
else:
|
||||
warnings.warn(f"risk_analysis freq {_analysis_freq} is not found")
|
||||
|
||||
for _analysis_freq in self.indicator_analysis_freq:
|
||||
if _analysis_freq in self.all_freq:
|
||||
list_path.append(PortAnaRecord.get_path(f"indicator_analysis_{_analysis_freq}.pkl"))
|
||||
list_path.append(f"indicator_analysis_{_analysis_freq}.pkl")
|
||||
else:
|
||||
warnings.warn(f"indicator_analysis freq {_analysis_freq} is not found")
|
||||
|
||||
return list_path
|
||||
|
||||
@@ -47,13 +47,13 @@ def train(uri_path: str = None):
|
||||
rid = recorder.id
|
||||
sr = SignalRecord(model, dataset, recorder)
|
||||
sr.generate()
|
||||
pred_score = sr.load(sr.get_path("pred.pkl"))
|
||||
pred_score = sr.load("pred.pkl")
|
||||
|
||||
# calculate ic and ric
|
||||
sar = SigAnaRecord(recorder)
|
||||
sar.generate()
|
||||
ic = sar.load(sar.get_path("ic.pkl"))
|
||||
ric = sar.load(sar.get_path("ric.pkl"))
|
||||
ic = sar.load("ic.pkl")
|
||||
ric = sar.load("ric.pkl")
|
||||
|
||||
return pred_score, {"ic": ic, "ric": ric}, rid
|
||||
|
||||
@@ -78,13 +78,13 @@ def train_with_sigana(uri_path: str = None):
|
||||
|
||||
sr = SignalRecord(model, dataset, recorder)
|
||||
sr.generate()
|
||||
pred_score = sr.load(sr.get_path("pred.pkl"))
|
||||
pred_score = sr.load("pred.pkl")
|
||||
|
||||
# predict and calculate ic and ric
|
||||
sar = SigAnaRecord(recorder)
|
||||
sar.generate()
|
||||
ic = sar.load(sar.get_path("ic.pkl"))
|
||||
ric = sar.load(sar.get_path("ric.pkl"))
|
||||
ic = sar.load("ic.pkl")
|
||||
ric = sar.load("ric.pkl")
|
||||
|
||||
uri_path = R.get_uri()
|
||||
return pred_score, {"ic": ic, "ric": ric}, uri_path
|
||||
@@ -169,7 +169,7 @@ def backtest_analysis(pred, rid, uri_path: str = None):
|
||||
# backtest
|
||||
par = PortAnaRecord(recorder, port_analysis_config, risk_analysis_freq="day")
|
||||
par.generate()
|
||||
analysis_df = par.load(par.get_path("port_analysis_1day.pkl"))
|
||||
analysis_df = par.load("port_analysis_1day.pkl")
|
||||
print(analysis_df)
|
||||
return analysis_df
|
||||
|
||||
|
||||
Reference in New Issue
Block a user