1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-01 18:11:18 +08:00

update port_ana_record

This commit is contained in:
bxdd
2021-04-29 02:28:22 +08:00
parent 86a6f565e8
commit 49cdaf8f5d
3 changed files with 34 additions and 24 deletions

View File

@@ -214,6 +214,7 @@ class Account:
# finish today's updation
# reset the daily variables
self.rtn = 0
self.ct = 0
self.to = 0

View File

@@ -9,7 +9,7 @@ from .account import Account
def backtest(start_time, end_time, trade_strategy, trade_env, benchmark, account):
trade_account = Account(init_cash=account)
trade_account = Account(init_cash=account, benchmark=benchmark, start_time=start_time, end_time=end_time)
trade_env.reset(start_time=start_time, end_time=end_time, trade_account=trade_account)
trade_strategy.reset(start_time=start_time, end_time=end_time)

View File

@@ -2,6 +2,7 @@
# Licensed under the MIT License.
import re
import warnings
import pandas as pd
from pathlib import Path
from pprint import pprint
@@ -223,18 +224,23 @@ class PortAnaRecord(SignalRecord):
artifact_path = "portfolio_analysis"
def __init__(self, recorder, config, **kwargs):
def __init__(self, recorder, config, risk_analysis_dep, **kwargs):
"""
config["strategy"] : dict
define the strategy class as well as the kwargs.
config["env"] : dict
define the env class as well as the kwargs.
config["backtest"] : dict
define the backtest kwargs.
risk_analysis_dep : int
risk analyze the dep'th env report
"""
super().__init__(recorder=recorder, **kwargs)
self.strategy_config = config["strategy"]
self.env_config = config["env"]
self.backtest_config = config["backtest"]
self.risk_analysis_dep = risk_analysis_dep
def generate(self, **kwargs):
# check previously stored prediction results
@@ -245,31 +251,34 @@ class PortAnaRecord(SignalRecord):
# custom strategy and get backtest
report_list = normal_backtest(env=self.env_config, strategy=self.strategy_config, **self.backtest_config)
for report_id, (report_normal, positions_normal) in enumerate(report_list):
for report_dep, (report_normal, positions_normal) in enumerate(report_list):
if report_dict is None:
if self.risk_analysis_dep == report_dep:
warnings.warn(f"the report in dep {risk_analysis_dep} is None, please set the corresponding env with `generate_report==True`")
continue
self.recorder.save_objects(**{f"report_normal_{report_id}.pkl": report_normal}, artifact_path=PortAnaRecord.get_path())
self.recorder.save_objects(**{f"positions_norma_{report_id}l.pkl": positions_normal}, artifact_path=PortAnaRecord.get_path())
self.recorder.save_objects(**{f"report_normal_{report_dep}.pkl": report_normal}, artifact_path=PortAnaRecord.get_path())
self.recorder.save_objects(**{f"positions_norma_{report_dep}l.pkl": positions_normal}, artifact_path=PortAnaRecord.get_path())
# analysis
analysis = dict()
analysis["excess_return_without_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"])
analysis["excess_return_with_cost"] = risk_analysis(
report_normal["return"] - report_normal["bench"] - report_normal["cost"]
)
analysis_df = pd.concat(analysis) # type: pd.DataFrame
# log metrics
self.recorder.log_metrics(**flatten_dict(analysis_df["risk"].unstack().T.to_dict()))
# save results
self.recorder.save_objects(**{f"port_analysis.pkl_{report_id}": analysis_df}, artifact_path=PortAnaRecord.get_path())
logger.info(
f"Portfolio analysis record 'port_analysis_{report_id}.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}"
)
# print out results
pprint("The following are analysis results of the excess return without cost.")
pprint(analysis["excess_return_without_cost"])
pprint("The following are analysis results of the excess return with cost.")
pprint(analysis["excess_return_with_cost"])
self.risk_analysis_dep == report_dep:
analysis = dict()
analysis["excess_return_without_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"])
analysis["excess_return_with_cost"] = risk_analysis(
report_normal["return"] - report_normal["bench"] - report_normal["cost"]
)
analysis_df = pd.concat(analysis) # type: pd.DataFrame
# log metrics
self.recorder.log_metrics(**flatten_dict(analysis_df["risk"].unstack().T.to_dict()))
# save results
self.recorder.save_objects(**{f"port_analysis.pkl_{report_dep}": analysis_df}, artifact_path=PortAnaRecord.get_path())
logger.info(
f"Portfolio analysis record 'port_analysis_{report_dep}.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}"
)
# print out results
pprint("The following are analysis results of the excess return without cost.")
pprint(analysis["excess_return_without_cost"])
pprint("The following are analysis results of the excess return with cost.")
pprint(analysis["excess_return_with_cost"])
def list(self):
return [