diff --git a/qlib/contrib/backtest/account.py b/qlib/contrib/backtest/account.py index 981e3c07a..8bf7dedb7 100644 --- a/qlib/contrib/backtest/account.py +++ b/qlib/contrib/backtest/account.py @@ -214,6 +214,7 @@ class Account: # finish today's updation # reset the daily variables + self.rtn = 0 self.ct = 0 self.to = 0 diff --git a/qlib/contrib/backtest/backtest.py b/qlib/contrib/backtest/backtest.py index 2bc349be3..a7e009a9a 100644 --- a/qlib/contrib/backtest/backtest.py +++ b/qlib/contrib/backtest/backtest.py @@ -9,7 +9,7 @@ from .account import Account def backtest(start_time, end_time, trade_strategy, trade_env, benchmark, account): - trade_account = Account(init_cash=account) + trade_account = Account(init_cash=account, benchmark=benchmark, start_time=start_time, end_time=end_time) trade_env.reset(start_time=start_time, end_time=end_time, trade_account=trade_account) trade_strategy.reset(start_time=start_time, end_time=end_time) diff --git a/qlib/workflow/record_temp.py b/qlib/workflow/record_temp.py index 51a9a305c..3d7188bcc 100644 --- a/qlib/workflow/record_temp.py +++ b/qlib/workflow/record_temp.py @@ -2,6 +2,7 @@ # Licensed under the MIT License. import re +import warnings import pandas as pd from pathlib import Path from pprint import pprint @@ -223,18 +224,23 @@ class PortAnaRecord(SignalRecord): artifact_path = "portfolio_analysis" - def __init__(self, recorder, config, **kwargs): + def __init__(self, recorder, config, risk_analysis_dep, **kwargs): """ config["strategy"] : dict define the strategy class as well as the kwargs. + config["env"] : dict + define the env class as well as the kwargs. config["backtest"] : dict define the backtest kwargs. + risk_analysis_dep : int + risk analyze the dep'th env report """ super().__init__(recorder=recorder, **kwargs) self.strategy_config = config["strategy"] self.env_config = config["env"] self.backtest_config = config["backtest"] + self.risk_analysis_dep = risk_analysis_dep def generate(self, **kwargs): # check previously stored prediction results @@ -245,31 +251,34 @@ class PortAnaRecord(SignalRecord): # custom strategy and get backtest report_list = normal_backtest(env=self.env_config, strategy=self.strategy_config, **self.backtest_config) - for report_id, (report_normal, positions_normal) in enumerate(report_list): + for report_dep, (report_normal, positions_normal) in enumerate(report_list): if report_dict is None: + if self.risk_analysis_dep == report_dep: + warnings.warn(f"the report in dep {risk_analysis_dep} is None, please set the corresponding env with `generate_report==True`") continue - - self.recorder.save_objects(**{f"report_normal_{report_id}.pkl": report_normal}, artifact_path=PortAnaRecord.get_path()) - self.recorder.save_objects(**{f"positions_norma_{report_id}l.pkl": positions_normal}, artifact_path=PortAnaRecord.get_path()) + + self.recorder.save_objects(**{f"report_normal_{report_dep}.pkl": report_normal}, artifact_path=PortAnaRecord.get_path()) + self.recorder.save_objects(**{f"positions_norma_{report_dep}l.pkl": positions_normal}, artifact_path=PortAnaRecord.get_path()) # analysis - analysis = dict() - analysis["excess_return_without_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"]) - analysis["excess_return_with_cost"] = risk_analysis( - report_normal["return"] - report_normal["bench"] - report_normal["cost"] - ) - analysis_df = pd.concat(analysis) # type: pd.DataFrame - # log metrics - self.recorder.log_metrics(**flatten_dict(analysis_df["risk"].unstack().T.to_dict())) - # save results - self.recorder.save_objects(**{f"port_analysis.pkl_{report_id}": analysis_df}, artifact_path=PortAnaRecord.get_path()) - logger.info( - f"Portfolio analysis record 'port_analysis_{report_id}.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}" - ) - # print out results - pprint("The following are analysis results of the excess return without cost.") - pprint(analysis["excess_return_without_cost"]) - pprint("The following are analysis results of the excess return with cost.") - pprint(analysis["excess_return_with_cost"]) + self.risk_analysis_dep == report_dep: + analysis = dict() + analysis["excess_return_without_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"]) + analysis["excess_return_with_cost"] = risk_analysis( + report_normal["return"] - report_normal["bench"] - report_normal["cost"] + ) + analysis_df = pd.concat(analysis) # type: pd.DataFrame + # log metrics + self.recorder.log_metrics(**flatten_dict(analysis_df["risk"].unstack().T.to_dict())) + # save results + self.recorder.save_objects(**{f"port_analysis.pkl_{report_dep}": analysis_df}, artifact_path=PortAnaRecord.get_path()) + logger.info( + f"Portfolio analysis record 'port_analysis_{report_dep}.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}" + ) + # print out results + pprint("The following are analysis results of the excess return without cost.") + pprint(analysis["excess_return_without_cost"]) + pprint("The following are analysis results of the excess return with cost.") + pprint(analysis["excess_return_with_cost"]) def list(self): return [