1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

fix record_tmp bug

This commit is contained in:
Young
2020-11-22 03:17:50 +00:00
parent 89977320e3
commit c8d7d3ea2a
4 changed files with 30 additions and 14 deletions

View File

@@ -56,4 +56,4 @@ jobs:
- name: Test workflow by config
run: |
workflow_by_config examples/benchmarks/GBDT/workflow_config_gbdt.yaml
qrun examples/benchmarks/GBDT/workflow_config_gbdt.yaml

View File

@@ -25,6 +25,19 @@ class RecordTemp:
backtest in a certain format.
"""
artifact_path = None
@classmethod
def get_path(cls, path=None):
names = []
if cls.artifact_path is not None:
names.append(cls.artifact_path)
if path is not None:
names.append(path)
return "/".join(names)
def __init__(self, recorder):
self.recorder = recorder
@@ -79,7 +92,7 @@ class RecordTemp:
artifacts = set(self.recorder.list_artifacts())
if parent:
# Downcasting have to be done here instead of using `super`
flist = self.__class__.__base__.list(self)
flist = self.__class__.__base__.list(self) # pylint: disable=E1101
else:
flist = self.list()
for item in flist:
@@ -132,10 +145,12 @@ class SignalRecord(RecordTemp):
class SigAnaRecord(SignalRecord):
artifact_path = "sig_analysis"
def __init__(self, recorder, **kwargs):
super().__init__(recorder=recorder, **kwargs)
# The name must be unique. Otherwise it will be overridden
self.artifact_path_sig = "sig_analysis"
def generate(self):
self.check(parent=True)
@@ -150,11 +165,11 @@ class SigAnaRecord(SignalRecord):
"Rank ICIR": ric.mean() / ric.std(),
}
self.recorder.log_metrics(**metrics)
self.recorder.save_objects(**{"ic.pkl": ic, "ric.pkl": ric}, artifact_path=self.artifact_path_sig)
self.recorder.save_objects(**{"ic.pkl": ic, "ric.pkl": ric}, artifact_path=self.get_path())
pprint(metrics)
def list(self):
return ["{self.artifact_path_sig}/ic.pkl", "{self.artifact_path_sig}/ric.pkl"]
return [self.get_path("ic.pkl"), self.get_path("ric.pkl")]
class PortAnaRecord(SignalRecord):
@@ -162,6 +177,8 @@ class PortAnaRecord(SignalRecord):
This is the Portfolio Analysis Record class that generates the results such as those of backtest.
"""
artifact_path = "portfolio_analysis"
def __init__(self, recorder, config, **kwargs):
"""
config["strategy"] : dict
@@ -174,7 +191,6 @@ class PortAnaRecord(SignalRecord):
self.strategy_config = config["strategy"]
self.backtest_config = config["backtest"]
self.strategy = init_instance_by_config(self.strategy_config)
self.artifact_path_port = "portfolio_analysis"
def generate(self, **kwargs):
# check previously stored prediction results
@@ -183,8 +199,8 @@ class PortAnaRecord(SignalRecord):
# custom strategy and get backtest
pred_score = super().load()
report_normal, positions_normal = normal_backtest(pred_score, strategy=self.strategy, **self.backtest_config)
self.recorder.save_objects(**{"report_normal.pkl": report_normal}, artifact_path=self.artifact_path_port)
self.recorder.save_objects(**{"positions_normal.pkl": positions_normal}, artifact_path=self.artifact_path_port)
self.recorder.save_objects(**{"report_normal.pkl": report_normal}, artifact_path=PortAnaRecord.get_path())
self.recorder.save_objects(**{"positions_normal.pkl": positions_normal}, artifact_path=PortAnaRecord.get_path())
# analysis
analysis = dict()
@@ -197,7 +213,7 @@ class PortAnaRecord(SignalRecord):
# log metrics
self.recorder.log_metrics(**flatten_dict(analysis_df["risk"].unstack().T.to_dict()))
# save results
self.recorder.save_objects(**{"port_analysis.pkl": analysis_df}, artifact_path=self.artifact_path_port)
self.recorder.save_objects(**{"port_analysis.pkl": analysis_df}, artifact_path=PortAnaRecord.get_path())
logger.info(
f"Portfolio analysis record 'port_analysis.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}"
)
@@ -209,7 +225,7 @@ class PortAnaRecord(SignalRecord):
def list(self):
return [
f"{self.artifact_path_port}/report_normal.pkl",
f"{self.artifact_path_port}/positions_normal.pkl",
f"{self.artifact_path_port}/port_analysis.pkl",
PortAnaRecord.get_path("report_normal.pkl"),
PortAnaRecord.get_path("positions_normal.pkl"),
PortAnaRecord.get_path("port_analysis.pkl"),
]

View File

@@ -101,7 +101,7 @@ setup(
entry_points={
# 'console_scripts': ['mycli=mymodule:cli'],
"console_scripts": [
"workflow_by_config=qlib.workflow.cli:run",
"qrun=qlib.workflow.cli:run",
],
},
ext_modules=extensions,

View File

@@ -147,7 +147,7 @@ def backtest_analysis(pred, rid):
# backtest
par = PortAnaRecord(recorder, port_analysis_config)
par.generate()
analysis_df = par.load("port_analysis.pkl")
analysis_df = par.load(par.get_path("port_analysis.pkl"))
print(analysis_df)
return analysis_df