mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
fix record_tmp bug
This commit is contained in:
2
.github/workflows/test.yml
vendored
2
.github/workflows/test.yml
vendored
@@ -56,4 +56,4 @@ jobs:
|
||||
|
||||
- name: Test workflow by config
|
||||
run: |
|
||||
workflow_by_config examples/benchmarks/GBDT/workflow_config_gbdt.yaml
|
||||
qrun examples/benchmarks/GBDT/workflow_config_gbdt.yaml
|
||||
|
||||
@@ -25,6 +25,19 @@ class RecordTemp:
|
||||
backtest in a certain format.
|
||||
"""
|
||||
|
||||
artifact_path = None
|
||||
|
||||
@classmethod
|
||||
def get_path(cls, path=None):
|
||||
names = []
|
||||
if cls.artifact_path is not None:
|
||||
names.append(cls.artifact_path)
|
||||
|
||||
if path is not None:
|
||||
names.append(path)
|
||||
|
||||
return "/".join(names)
|
||||
|
||||
def __init__(self, recorder):
|
||||
self.recorder = recorder
|
||||
|
||||
@@ -79,7 +92,7 @@ class RecordTemp:
|
||||
artifacts = set(self.recorder.list_artifacts())
|
||||
if parent:
|
||||
# Downcasting have to be done here instead of using `super`
|
||||
flist = self.__class__.__base__.list(self)
|
||||
flist = self.__class__.__base__.list(self) # pylint: disable=E1101
|
||||
else:
|
||||
flist = self.list()
|
||||
for item in flist:
|
||||
@@ -132,10 +145,12 @@ class SignalRecord(RecordTemp):
|
||||
|
||||
|
||||
class SigAnaRecord(SignalRecord):
|
||||
|
||||
artifact_path = "sig_analysis"
|
||||
|
||||
def __init__(self, recorder, **kwargs):
|
||||
super().__init__(recorder=recorder, **kwargs)
|
||||
# The name must be unique. Otherwise it will be overridden
|
||||
self.artifact_path_sig = "sig_analysis"
|
||||
|
||||
def generate(self):
|
||||
self.check(parent=True)
|
||||
@@ -150,11 +165,11 @@ class SigAnaRecord(SignalRecord):
|
||||
"Rank ICIR": ric.mean() / ric.std(),
|
||||
}
|
||||
self.recorder.log_metrics(**metrics)
|
||||
self.recorder.save_objects(**{"ic.pkl": ic, "ric.pkl": ric}, artifact_path=self.artifact_path_sig)
|
||||
self.recorder.save_objects(**{"ic.pkl": ic, "ric.pkl": ric}, artifact_path=self.get_path())
|
||||
pprint(metrics)
|
||||
|
||||
def list(self):
|
||||
return ["{self.artifact_path_sig}/ic.pkl", "{self.artifact_path_sig}/ric.pkl"]
|
||||
return [self.get_path("ic.pkl"), self.get_path("ric.pkl")]
|
||||
|
||||
|
||||
class PortAnaRecord(SignalRecord):
|
||||
@@ -162,6 +177,8 @@ class PortAnaRecord(SignalRecord):
|
||||
This is the Portfolio Analysis Record class that generates the results such as those of backtest.
|
||||
"""
|
||||
|
||||
artifact_path = "portfolio_analysis"
|
||||
|
||||
def __init__(self, recorder, config, **kwargs):
|
||||
"""
|
||||
config["strategy"] : dict
|
||||
@@ -174,7 +191,6 @@ class PortAnaRecord(SignalRecord):
|
||||
self.strategy_config = config["strategy"]
|
||||
self.backtest_config = config["backtest"]
|
||||
self.strategy = init_instance_by_config(self.strategy_config)
|
||||
self.artifact_path_port = "portfolio_analysis"
|
||||
|
||||
def generate(self, **kwargs):
|
||||
# check previously stored prediction results
|
||||
@@ -183,8 +199,8 @@ class PortAnaRecord(SignalRecord):
|
||||
# custom strategy and get backtest
|
||||
pred_score = super().load()
|
||||
report_normal, positions_normal = normal_backtest(pred_score, strategy=self.strategy, **self.backtest_config)
|
||||
self.recorder.save_objects(**{"report_normal.pkl": report_normal}, artifact_path=self.artifact_path_port)
|
||||
self.recorder.save_objects(**{"positions_normal.pkl": positions_normal}, artifact_path=self.artifact_path_port)
|
||||
self.recorder.save_objects(**{"report_normal.pkl": report_normal}, artifact_path=PortAnaRecord.get_path())
|
||||
self.recorder.save_objects(**{"positions_normal.pkl": positions_normal}, artifact_path=PortAnaRecord.get_path())
|
||||
|
||||
# analysis
|
||||
analysis = dict()
|
||||
@@ -197,7 +213,7 @@ class PortAnaRecord(SignalRecord):
|
||||
# log metrics
|
||||
self.recorder.log_metrics(**flatten_dict(analysis_df["risk"].unstack().T.to_dict()))
|
||||
# save results
|
||||
self.recorder.save_objects(**{"port_analysis.pkl": analysis_df}, artifact_path=self.artifact_path_port)
|
||||
self.recorder.save_objects(**{"port_analysis.pkl": analysis_df}, artifact_path=PortAnaRecord.get_path())
|
||||
logger.info(
|
||||
f"Portfolio analysis record 'port_analysis.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}"
|
||||
)
|
||||
@@ -209,7 +225,7 @@ class PortAnaRecord(SignalRecord):
|
||||
|
||||
def list(self):
|
||||
return [
|
||||
f"{self.artifact_path_port}/report_normal.pkl",
|
||||
f"{self.artifact_path_port}/positions_normal.pkl",
|
||||
f"{self.artifact_path_port}/port_analysis.pkl",
|
||||
PortAnaRecord.get_path("report_normal.pkl"),
|
||||
PortAnaRecord.get_path("positions_normal.pkl"),
|
||||
PortAnaRecord.get_path("port_analysis.pkl"),
|
||||
]
|
||||
|
||||
2
setup.py
2
setup.py
@@ -101,7 +101,7 @@ setup(
|
||||
entry_points={
|
||||
# 'console_scripts': ['mycli=mymodule:cli'],
|
||||
"console_scripts": [
|
||||
"workflow_by_config=qlib.workflow.cli:run",
|
||||
"qrun=qlib.workflow.cli:run",
|
||||
],
|
||||
},
|
||||
ext_modules=extensions,
|
||||
|
||||
@@ -147,7 +147,7 @@ def backtest_analysis(pred, rid):
|
||||
# backtest
|
||||
par = PortAnaRecord(recorder, port_analysis_config)
|
||||
par.generate()
|
||||
analysis_df = par.load("port_analysis.pkl")
|
||||
analysis_df = par.load(par.get_path("port_analysis.pkl"))
|
||||
print(analysis_df)
|
||||
return analysis_df
|
||||
|
||||
|
||||
Reference in New Issue
Block a user