1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-03 11:00:57 +08:00

simplify record tmp

This commit is contained in:
Young
2021-11-05 11:34:21 +00:00
parent 4f2d6b0d84
commit 3fa48d7017
3 changed files with 90 additions and 70 deletions

View File

@@ -49,7 +49,7 @@ class MultiSegRecord(RecordTemp):
if save:
save_name = "results-{:}.pkl".format(key)
self.recorder.save_objects(**{save_name: results})
self.save(**{save_name: results})
logger.info(
"The record '{:}' has been saved as the artifact of the Experiment {:}".format(
save_name, self.recorder.experiment_id
@@ -79,9 +79,8 @@ class SignalMseRecord(RecordTemp):
metrics = {"MSE": mse, "RMSE": np.sqrt(mse)}
objects = {"mse.pkl": mse, "rmse.pkl": np.sqrt(mse)}
self.recorder.log_metrics(**metrics)
self.recorder.save_objects(**objects, artifact_path=self.get_path())
self.save(**objects)
logger.info("The evaluation results in SignalMseRecord is {:}".format(metrics))
def list(self):
paths = [self.get_path("mse.pkl"), self.get_path("rmse.pkl")]
return paths
return ["mse.pkl", "rmse.pkl"]

View File

@@ -9,6 +9,9 @@ import pandas as pd
from pathlib import Path
from pprint import pprint
from typing import Union, List
from collections import defaultdict
from qlib.utils.exceptions import LoadObjectError
from ..contrib.evaluate import indicator_analysis, risk_analysis, indicator_analysis
from ..data.dataset import DatasetH
@@ -45,6 +48,16 @@ class RecordTemp:
return "/".join(names)
def save(self, **kwargs):
"""
Save objects through ``self.recorder.save_objects`` under this record's own path.

It behaves the same as ``self.recorder.save_objects``, but it is an easier interface:
the ``artifact_path`` argument is filled in from ``self.get_path()``, so users don't
have to care about `get_path` and `artifact_path` themselves.

Parameters
----------
**kwargs
    forwarded unchanged to ``self.recorder.save_objects``; callers in this file pass
    ``file name -> object`` pairs such as ``{"pred.pkl": pred}``.
"""
art_path = self.get_path()
# An empty path is forwarded as None instead of "".
# NOTE(review): presumably None makes the recorder use its default (root) location -- confirm.
if art_path == "":
art_path = None
self.recorder.save_objects(artifact_path=art_path, **kwargs)
def __init__(self, recorder):
self._recorder = recorder
@@ -67,31 +80,37 @@ class RecordTemp:
"""
raise NotImplementedError(f"Please implement the `generate` method.")
def load(self, name):
def load(self, name: str, parents: bool = True):
"""
Load the stored records. Due to the fact that some problems occured when we tried to balancing a clean API
with the Python's inheritance. This method has to be used in a rather ugly way, and we will try to fix them
in the future::
sar = SigAnaRecord(recorder)
ic = sar.load(sar.get_path("ic.pkl"))
It behaves the same as self.recorder.load_object.
But it is an easier interface because users don't have to care about `get_path` and `artifact_path`
Parameters
----------
name : str
the name of the file to be loaded.
parents : bool
Each recorder has different `artifact_path`.
If True and the file cannot be loaded from this record's path, the search continues recursively in the parent records (`depend_cls`).
Subclasses have higher priority.
Return
------
The stored records.
"""
# try to load the saved object
obj = self.recorder.load_object(name)
return obj
try:
return self.recorder.load_object(self.get_path(name))
except LoadObjectError:
if parents:
if self.depend_cls is not None:
with class_casting(self, self.depend_cls):
return self.load(name, parents=True)
def list(self):
"""
List the supported artifacts.
Users don't have to consider self.get_path
Return
------
@@ -99,7 +118,7 @@ class RecordTemp:
"""
return []
def check(self, include_self: bool = False):
def check(self, include_self: bool = False, parents: bool = True):
"""
Check if the records are properly generated and saved.
It is useful in following examples
@@ -110,19 +129,34 @@ class RecordTemp:
----------
include_self : bool
is the file generated by self included
parents : bool
whether the parent records (`depend_cls`) are checked as well
Raise
------
FileExistsError: whether the records are stored properly.
FileNotFoundError
: whether the records are stored properly.
"""
artifacts = set(self.recorder.list_artifacts())
if include_self:
# Some mlflow backends do not list directories recursively.
# So we explicitly list each directory that contains an expected file.
artifacts = {}
def _get_arts(dirn):
if dirn not in artifacts:
artifacts[dirn] = self.recorder.list_artifacts(dirn)
return artifacts[dirn]
for item in self.list():
if item not in artifacts:
raise FileExistsError(item)
if self.depend_cls is not None:
with class_casting(self, self.depend_cls):
self.check(include_self=True)
ps = self.get_path(item).split("/")
dirn, fn = "/".join(ps[:-1]), ps[-1]
if self.get_path(item) not in _get_arts(dirn):
raise FileNotFoundError
if parents:
if self.depend_cls is not None:
with class_casting(self, self.depend_cls):
self.check(include_self=True)
class SignalRecord(RecordTemp):
@@ -158,7 +192,7 @@ class SignalRecord(RecordTemp):
pred = self.model.predict(self.dataset)
if isinstance(pred, pd.Series):
pred = pred.to_frame("score")
self.recorder.save_objects(**{"pred.pkl": pred})
self.save(**{"pred.pkl": pred})
logger.info(
f"Signal record 'pred.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}"
@@ -169,15 +203,11 @@ class SignalRecord(RecordTemp):
if isinstance(self.dataset, DatasetH):
raw_label = self.generate_label(self.dataset)
self.recorder.save_objects(**{"label.pkl": raw_label})
self.save(**{"label.pkl": raw_label})
@staticmethod
def list():
def list(self):
return ["pred.pkl", "label.pkl"]
def load(self, name="pred.pkl"):
return super().load(name)
class HFSignalRecord(SignalRecord):
"""
@@ -218,19 +248,11 @@ class HFSignalRecord(SignalRecord):
}
)
self.recorder.log_metrics(**metrics)
self.recorder.save_objects(**objects, artifact_path=self.get_path())
self.save(**objects)
pprint(metrics)
def list(self):
paths = [
self.get_path("ic.pkl"),
self.get_path("ric.pkl"),
self.get_path("long_pre.pkl"),
self.get_path("short_pre.pkl"),
self.get_path("long_short_r.pkl"),
self.get_path("long_avg_r.pkl"),
]
return paths
return ["ic.pkl", "ric.pkl", "long_pre.pkl", "short_pre.pkl", "long_short_r.pkl", "long_avg_r.pkl"]
class SigAnaRecord(RecordTemp):
@@ -241,13 +263,23 @@ class SigAnaRecord(RecordTemp):
artifact_path = "sig_analysis"
depend_cls = SignalRecord
def __init__(self, recorder, ana_long_short=False, ann_scaler=252, label_col=0):
def __init__(self, recorder, ana_long_short=False, ann_scaler=252, label_col=0, skip_existing=False):
super().__init__(recorder=recorder)
self.ana_long_short = ana_long_short
self.ann_scaler = ann_scaler
self.label_col = label_col
self.skip_existing = skip_existing
def generate(self, **kwargs):
if self.skip_existing:
try:
self.check(include_self=True, parents=False)
except FileNotFoundError:
pass # continue to generating metrics
else:
logger.info("The results has previously generated, generation skipped.")
return
self.check()
pred = self.load("pred.pkl")
@@ -280,13 +312,13 @@ class SigAnaRecord(RecordTemp):
}
)
self.recorder.log_metrics(**metrics)
self.recorder.save_objects(**objects, artifact_path=self.get_path())
self.save(**objects)
pprint(metrics)
def list(self):
paths = [self.get_path("ic.pkl"), self.get_path("ric.pkl")]
paths = ["ic.pkl", "ric.pkl"]
if self.ana_long_short:
paths.extend([self.get_path("long_short_r.pkl"), self.get_path("long_avg_r.pkl")])
paths.extend(["long_short_r.pkl", "long_avg_r.pkl"])
return paths
@@ -373,17 +405,11 @@ class PortAnaRecord(RecordTemp):
executor=self.executor_config, strategy=self.strategy_config, **self.backtest_config
)
for _freq, (report_normal, positions_normal) in portfolio_metric_dict.items():
self.recorder.save_objects(
**{f"report_normal_{_freq}.pkl": report_normal}, artifact_path=PortAnaRecord.get_path()
)
self.recorder.save_objects(
**{f"positions_normal_{_freq}.pkl": positions_normal}, artifact_path=PortAnaRecord.get_path()
)
self.save(**{f"report_normal_{_freq}.pkl": report_normal})
self.save(**{f"positions_normal_{_freq}.pkl": positions_normal})
for _freq, indicators_normal in indicator_dict.items():
self.recorder.save_objects(
**{f"indicators_normal_{_freq}.pkl": indicators_normal}, artifact_path=PortAnaRecord.get_path()
)
self.save(**{f"indicators_normal_{_freq}.pkl": indicators_normal})
for _analysis_freq in self.risk_analysis_freq:
if _analysis_freq not in portfolio_metric_dict:
@@ -405,9 +431,7 @@ class PortAnaRecord(RecordTemp):
analysis_dict = flatten_dict(analysis_df["risk"].unstack().T.to_dict())
self.recorder.log_metrics(**{f"{_analysis_freq}.{k}": v for k, v in analysis_dict.items()})
# save results
self.recorder.save_objects(
**{f"port_analysis_{_analysis_freq}.pkl": analysis_df}, artifact_path=PortAnaRecord.get_path()
)
self.save(**{f"port_analysis_{_analysis_freq}.pkl": analysis_df})
logger.info(
f"Portfolio analysis record 'port_analysis_{_analysis_freq}.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}"
)
@@ -432,9 +456,7 @@ class PortAnaRecord(RecordTemp):
analysis_dict = analysis_df["value"].to_dict()
self.recorder.log_metrics(**{f"{_analysis_freq}.{k}": v for k, v in analysis_dict.items()})
# save results
self.recorder.save_objects(
**{f"indicator_analysis_{_analysis_freq}.pkl": analysis_df}, artifact_path=PortAnaRecord.get_path()
)
self.save(**{f"indicator_analysis_{_analysis_freq}.pkl": analysis_df})
logger.info(
f"Indicator analysis record 'indicator_analysis_{_analysis_freq}.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}"
)
@@ -446,20 +468,19 @@ class PortAnaRecord(RecordTemp):
for _freq in self.all_freq:
list_path.extend(
[
PortAnaRecord.get_path(f"report_normal_{_freq}.pkl"),
PortAnaRecord.get_path(f"positions_normal_{_freq}.pkl"),
f"report_normal_{_freq}.pkl",
f"positions_normal_{_freq}.pkl",
]
)
for _analysis_freq in self.risk_analysis_freq:
if _analysis_freq in self.all_freq:
list_path.append(PortAnaRecord.get_path(f"port_analysis_{_analysis_freq}.pkl"))
list_path.append(f"port_analysis_{_analysis_freq}.pkl")
else:
warnings.warn(f"risk_analysis freq {_analysis_freq} is not found")
for _analysis_freq in self.indicator_analysis_freq:
if _analysis_freq in self.all_freq:
list_path.append(PortAnaRecord.get_path(f"indicator_analysis_{_analysis_freq}.pkl"))
list_path.append(f"indicator_analysis_{_analysis_freq}.pkl")
else:
warnings.warn(f"indicator_analysis freq {_analysis_freq} is not found")
return list_path

View File

@@ -47,13 +47,13 @@ def train(uri_path: str = None):
rid = recorder.id
sr = SignalRecord(model, dataset, recorder)
sr.generate()
pred_score = sr.load(sr.get_path("pred.pkl"))
pred_score = sr.load("pred.pkl")
# calculate ic and ric
sar = SigAnaRecord(recorder)
sar.generate()
ic = sar.load(sar.get_path("ic.pkl"))
ric = sar.load(sar.get_path("ric.pkl"))
ic = sar.load("ic.pkl")
ric = sar.load("ric.pkl")
return pred_score, {"ic": ic, "ric": ric}, rid
@@ -78,13 +78,13 @@ def train_with_sigana(uri_path: str = None):
sr = SignalRecord(model, dataset, recorder)
sr.generate()
pred_score = sr.load(sr.get_path("pred.pkl"))
pred_score = sr.load("pred.pkl")
# predict and calculate ic and ric
sar = SigAnaRecord(recorder)
sar.generate()
ic = sar.load(sar.get_path("ic.pkl"))
ric = sar.load(sar.get_path("ric.pkl"))
ic = sar.load("ic.pkl")
ric = sar.load("ric.pkl")
uri_path = R.get_uri()
return pred_score, {"ic": ic, "ric": ric}, uri_path
@@ -169,7 +169,7 @@ def backtest_analysis(pred, rid, uri_path: str = None):
# backtest
par = PortAnaRecord(recorder, port_analysis_config, risk_analysis_freq="day")
par.generate()
analysis_df = par.load(par.get_path("port_analysis_1day.pkl"))
analysis_df = par.load("port_analysis_1day.pkl")
print(analysis_df)
return analysis_df