1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

Add multi pass portfolio analysis record (#1546)

* Add multi pass port ana record

* Add list function

* Add documentation and support <MODEL> tag

* Add drop in replacement example

* reformat

* Change according to comments

* update format

* Update record_temp.py

Fix type hint

* Update record_temp.py
This commit is contained in:
Di
2023-08-04 17:41:12 +08:00
committed by GitHub
parent 38edac5069
commit 05d67b3828
2 changed files with 221 additions and 10 deletions

View File

@@ -0,0 +1,78 @@
qlib_init:
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
market: &market csi300
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
start_time: 2008-01-01
end_time: 2020-08-01
fit_start_time: 2008-01-01
fit_end_time: 2014-12-31
instruments: *market
infer_processors:
- class: RobustZScoreNorm
kwargs:
fields_group: feature
clip_outlier: true
- class: Fillna
kwargs:
fields_group: feature
learn_processors:
- class: DropnaLabel
- class: CSRankNorm
kwargs:
fields_group: label
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:
start_time: 2017-01-01
end_time: 2020-08-01
account: 100000000
benchmark: *benchmark
exchange_kwargs:
limit_threshold: 0.095
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
task:
model:
class: LinearModel
module_path: qlib.contrib.model.linear
kwargs:
estimator: ols
dataset:
class: DatasetH
module_path: qlib.data.dataset
kwargs:
handler:
class: Alpha158
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:
train: [2008-01-01, 2014-12-31]
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs:
model: <MODEL>
dataset: <DATASET>
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
ana_long_short: True
ann_scaler: 252
- class: MultiPassPortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config

View File

@@ -4,8 +4,10 @@
import logging
import warnings
import pandas as pd
import numpy as np
from tqdm import trange
from pprint import pprint
from typing import Union, List, Optional
from typing import Union, List, Optional, Dict
from qlib.utils.exceptions import LoadObjectError
from ..contrib.evaluate import risk_analysis, indicator_analysis
@@ -17,6 +19,7 @@ from ..log import get_module_logger
from ..utils import fill_placeholder, flatten_dict, class_casting, get_date_by_shift
from ..utils.time import Freq
from ..utils.data import deepcopy_basic_type
from ..utils.exceptions import QlibException
from ..contrib.eva.alpha import calc_ic, calc_long_short_return, calc_long_short_prec
@@ -230,9 +233,16 @@ class ACRecordTemp(RecordTemp):
except FileNotFoundError:
logger.warning("The dependent data does not exists. Generation skipped.")
return
return self._generate(*args, **kwargs)
artifact_dict = self._generate(*args, **kwargs)
if isinstance(artifact_dict, dict):
self.save(**artifact_dict)
return artifact_dict
def _generate(self, *args, **kwargs):
def _generate(self, *args, **kwargs) -> Dict[str, object]:
"""
Run the concrete generating task, return the dictionary of the generated results.
The caller method will save the results to the recorder.
"""
raise NotImplementedError(f"Please implement the `_generate` method")
@@ -336,8 +346,8 @@ class SigAnaRecord(ACRecordTemp):
}
)
self.recorder.log_metrics(**metrics)
self.save(**objects)
pprint(metrics)
return objects
def list(self):
paths = ["ic.pkl", "ric.pkl"]
@@ -468,17 +478,18 @@ class PortAnaRecord(ACRecordTemp):
if self.backtest_config["end_time"] is None:
self.backtest_config["end_time"] = get_date_by_shift(dt_values.max(), 1)
artifact_objects = {}
# custom strategy and get backtest
portfolio_metric_dict, indicator_dict = normal_backtest(
executor=self.executor_config, strategy=self.strategy_config, **self.backtest_config
)
for _freq, (report_normal, positions_normal) in portfolio_metric_dict.items():
self.save(**{f"report_normal_{_freq}.pkl": report_normal})
self.save(**{f"positions_normal_{_freq}.pkl": positions_normal})
artifact_objects.update({f"report_normal_{_freq}.pkl": report_normal})
artifact_objects.update({f"positions_normal_{_freq}.pkl": positions_normal})
for _freq, indicators_normal in indicator_dict.items():
self.save(**{f"indicators_normal_{_freq}.pkl": indicators_normal[0]})
self.save(**{f"indicators_normal_{_freq}_obj.pkl": indicators_normal[1]})
artifact_objects.update({f"indicators_normal_{_freq}.pkl": indicators_normal[0]})
artifact_objects.update({f"indicators_normal_{_freq}_obj.pkl": indicators_normal[1]})
for _analysis_freq in self.risk_analysis_freq:
if _analysis_freq not in portfolio_metric_dict:
@@ -500,7 +511,7 @@ class PortAnaRecord(ACRecordTemp):
analysis_dict = flatten_dict(analysis_df["risk"].unstack().T.to_dict())
self.recorder.log_metrics(**{f"{_analysis_freq}.{k}": v for k, v in analysis_dict.items()})
# save results
self.save(**{f"port_analysis_{_analysis_freq}.pkl": analysis_df})
artifact_objects.update({f"port_analysis_{_analysis_freq}.pkl": analysis_df})
logger.info(
f"Portfolio analysis record 'port_analysis_{_analysis_freq}.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}"
)
@@ -525,12 +536,13 @@ class PortAnaRecord(ACRecordTemp):
analysis_dict = analysis_df["value"].to_dict()
self.recorder.log_metrics(**{f"{_analysis_freq}.{k}": v for k, v in analysis_dict.items()})
# save results
self.save(**{f"indicator_analysis_{_analysis_freq}.pkl": analysis_df})
artifact_objects.update({f"indicator_analysis_{_analysis_freq}.pkl": analysis_df})
logger.info(
f"Indicator analysis record 'indicator_analysis_{_analysis_freq}.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}"
)
pprint(f"The following are analysis results of indicators({_analysis_freq}).")
pprint(analysis_df)
return artifact_objects
def list(self):
list_path = []
@@ -553,3 +565,124 @@ class PortAnaRecord(ACRecordTemp):
else:
warnings.warn(f"indicator_analysis freq {_analysis_freq} is not found")
return list_path
class MultiPassPortAnaRecord(PortAnaRecord):
"""
This is the Multiple Pass Portfolio Analysis Record class that run backtest multiple times and generates the analysis results such as those of backtest. This class inherits the ``PortAnaRecord`` class.
If shuffle_init_score enabled, the prediction score of the first backtest date will be shuffled, so that initial position will be random.
The shuffle_init_score will only works when the signal is used as <PRED> placeholder. The placeholder will be replaced by pred.pkl saved in recorder.
Parameters
----------
recorder : Recorder
The recorder used to save the backtest results.
pass_num : int
The number of backtest passes.
shuffle_init_score : bool
Whether to shuffle the prediction score of the first backtest date.
"""
depend_cls = SignalRecord
def __init__(self, recorder, pass_num=10, shuffle_init_score=True, **kwargs):
"""
Parameters
----------
recorder : Recorder
The recorder used to save the backtest results.
pass_num : int
The number of backtest passes.
shuffle_init_score : bool
Whether to shuffle the prediction score of the first backtest date.
"""
self.pass_num = pass_num
self.shuffle_init_score = shuffle_init_score
super().__init__(recorder, **kwargs)
# Save original strategy so that pred df can be replaced in next generate
self.original_strategy = deepcopy_basic_type(self.strategy_config)
if not isinstance(self.original_strategy, dict):
raise QlibException("MultiPassPortAnaRecord require the passed in strategy to be a dict")
if "signal" not in self.original_strategy.get("kwargs", {}):
raise QlibException("MultiPassPortAnaRecord require the passed in strategy to have signal as a parameter")
def random_init(self):
pred_df = self.load("pred.pkl")
all_pred_dates = pred_df.index.get_level_values("datetime")
bt_start_date = pd.to_datetime(self.backtest_config.get("start_time"))
if bt_start_date is None:
first_bt_pred_date = all_pred_dates.min()
else:
first_bt_pred_date = all_pred_dates[all_pred_dates >= bt_start_date].min()
# Shuffle the first backtest date's pred score
first_date_score = pred_df.loc[first_bt_pred_date]["score"]
np.random.shuffle(first_date_score.values)
# Use shuffled signal as the strategy signal
self.strategy_config = deepcopy_basic_type(self.original_strategy)
self.strategy_config["kwargs"]["signal"] = pred_df
def _generate(self, **kwargs):
risk_analysis_df_map = {}
# Collect each frequency's analysis df as df list
for i in trange(self.pass_num):
if self.shuffle_init_score:
self.random_init()
# Not check for cache file list
single_run_artifacts = super()._generate(**kwargs)
for _analysis_freq in self.risk_analysis_freq:
risk_analysis_df_list = risk_analysis_df_map.get(_analysis_freq, [])
risk_analysis_df_map[_analysis_freq] = risk_analysis_df_list
analysis_df = single_run_artifacts[f"port_analysis_{_analysis_freq}.pkl"]
analysis_df["run_id"] = i
risk_analysis_df_list.append(analysis_df)
result_artifacts = {}
# Concat df list
for _analysis_freq in self.risk_analysis_freq:
combined_df = pd.concat(risk_analysis_df_map[_analysis_freq])
# Calculate return and information ratio's mean, std and mean/std
multi_pass_port_analysis_df = combined_df.groupby(level=[0, 1]).apply(
lambda x: pd.Series(
{"mean": x["risk"].mean(), "std": x["risk"].std(), "mean_std": x["risk"].mean() / x["risk"].std()}
)
)
# Only look at "annualized_return" and "information_ratio"
multi_pass_port_analysis_df = multi_pass_port_analysis_df.loc[
(slice(None), ["annualized_return", "information_ratio"]), :
]
pprint(multi_pass_port_analysis_df)
# Save new df
result_artifacts.update({f"multi_pass_port_analysis_{_analysis_freq}.pkl": multi_pass_port_analysis_df})
# Log metrics
metrics = flatten_dict(
{
"mean": multi_pass_port_analysis_df["mean"].unstack().T.to_dict(),
"std": multi_pass_port_analysis_df["std"].unstack().T.to_dict(),
"mean_std": multi_pass_port_analysis_df["mean_std"].unstack().T.to_dict(),
}
)
self.recorder.log_metrics(**metrics)
return result_artifacts
def list(self):
list_path = []
for _analysis_freq in self.risk_analysis_freq:
if _analysis_freq in self.all_freq:
list_path.append(f"multi_pass_port_analysis_{_analysis_freq}.pkl")
else:
warnings.warn(f"risk_analysis freq {_analysis_freq} is not found")
return list_path