mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
Add multi pass portfolio analysis record (#1546)
* Add multi pass port ana record * Add list function * Add documentation and support <MODEL> tag * Add drop in replacement example * reformat * Change according to comments * update format * Update record_temp.py Fix type hint * Update record_temp.py
This commit is contained in:
@@ -0,0 +1,78 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
start_time: 2017-01-01
|
||||
end_time: 2020-08-01
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
exchange_kwargs:
|
||||
limit_threshold: 0.095
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: LinearModel
|
||||
module_path: qlib.contrib.model.linear
|
||||
kwargs:
|
||||
estimator: ols
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha158
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: True
|
||||
ann_scaler: 252
|
||||
- class: MultiPassPortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -4,8 +4,10 @@
|
||||
import logging
|
||||
import warnings
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from tqdm import trange
|
||||
from pprint import pprint
|
||||
from typing import Union, List, Optional
|
||||
from typing import Union, List, Optional, Dict
|
||||
|
||||
from qlib.utils.exceptions import LoadObjectError
|
||||
from ..contrib.evaluate import risk_analysis, indicator_analysis
|
||||
@@ -17,6 +19,7 @@ from ..log import get_module_logger
|
||||
from ..utils import fill_placeholder, flatten_dict, class_casting, get_date_by_shift
|
||||
from ..utils.time import Freq
|
||||
from ..utils.data import deepcopy_basic_type
|
||||
from ..utils.exceptions import QlibException
|
||||
from ..contrib.eva.alpha import calc_ic, calc_long_short_return, calc_long_short_prec
|
||||
|
||||
|
||||
@@ -230,9 +233,16 @@ class ACRecordTemp(RecordTemp):
|
||||
except FileNotFoundError:
|
||||
logger.warning("The dependent data does not exists. Generation skipped.")
|
||||
return
|
||||
return self._generate(*args, **kwargs)
|
||||
artifact_dict = self._generate(*args, **kwargs)
|
||||
if isinstance(artifact_dict, dict):
|
||||
self.save(**artifact_dict)
|
||||
return artifact_dict
|
||||
|
||||
def _generate(self, *args, **kwargs):
|
||||
def _generate(self, *args, **kwargs) -> Dict[str, object]:
|
||||
"""
|
||||
Run the concrete generating task, return the dictionary of the generated results.
|
||||
The caller method will save the results to the recorder.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `_generate` method")
|
||||
|
||||
|
||||
@@ -336,8 +346,8 @@ class SigAnaRecord(ACRecordTemp):
|
||||
}
|
||||
)
|
||||
self.recorder.log_metrics(**metrics)
|
||||
self.save(**objects)
|
||||
pprint(metrics)
|
||||
return objects
|
||||
|
||||
def list(self):
|
||||
paths = ["ic.pkl", "ric.pkl"]
|
||||
@@ -468,17 +478,18 @@ class PortAnaRecord(ACRecordTemp):
|
||||
if self.backtest_config["end_time"] is None:
|
||||
self.backtest_config["end_time"] = get_date_by_shift(dt_values.max(), 1)
|
||||
|
||||
artifact_objects = {}
|
||||
# custom strategy and get backtest
|
||||
portfolio_metric_dict, indicator_dict = normal_backtest(
|
||||
executor=self.executor_config, strategy=self.strategy_config, **self.backtest_config
|
||||
)
|
||||
for _freq, (report_normal, positions_normal) in portfolio_metric_dict.items():
|
||||
self.save(**{f"report_normal_{_freq}.pkl": report_normal})
|
||||
self.save(**{f"positions_normal_{_freq}.pkl": positions_normal})
|
||||
artifact_objects.update({f"report_normal_{_freq}.pkl": report_normal})
|
||||
artifact_objects.update({f"positions_normal_{_freq}.pkl": positions_normal})
|
||||
|
||||
for _freq, indicators_normal in indicator_dict.items():
|
||||
self.save(**{f"indicators_normal_{_freq}.pkl": indicators_normal[0]})
|
||||
self.save(**{f"indicators_normal_{_freq}_obj.pkl": indicators_normal[1]})
|
||||
artifact_objects.update({f"indicators_normal_{_freq}.pkl": indicators_normal[0]})
|
||||
artifact_objects.update({f"indicators_normal_{_freq}_obj.pkl": indicators_normal[1]})
|
||||
|
||||
for _analysis_freq in self.risk_analysis_freq:
|
||||
if _analysis_freq not in portfolio_metric_dict:
|
||||
@@ -500,7 +511,7 @@ class PortAnaRecord(ACRecordTemp):
|
||||
analysis_dict = flatten_dict(analysis_df["risk"].unstack().T.to_dict())
|
||||
self.recorder.log_metrics(**{f"{_analysis_freq}.{k}": v for k, v in analysis_dict.items()})
|
||||
# save results
|
||||
self.save(**{f"port_analysis_{_analysis_freq}.pkl": analysis_df})
|
||||
artifact_objects.update({f"port_analysis_{_analysis_freq}.pkl": analysis_df})
|
||||
logger.info(
|
||||
f"Portfolio analysis record 'port_analysis_{_analysis_freq}.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}"
|
||||
)
|
||||
@@ -525,12 +536,13 @@ class PortAnaRecord(ACRecordTemp):
|
||||
analysis_dict = analysis_df["value"].to_dict()
|
||||
self.recorder.log_metrics(**{f"{_analysis_freq}.{k}": v for k, v in analysis_dict.items()})
|
||||
# save results
|
||||
self.save(**{f"indicator_analysis_{_analysis_freq}.pkl": analysis_df})
|
||||
artifact_objects.update({f"indicator_analysis_{_analysis_freq}.pkl": analysis_df})
|
||||
logger.info(
|
||||
f"Indicator analysis record 'indicator_analysis_{_analysis_freq}.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}"
|
||||
)
|
||||
pprint(f"The following are analysis results of indicators({_analysis_freq}).")
|
||||
pprint(analysis_df)
|
||||
return artifact_objects
|
||||
|
||||
def list(self):
|
||||
list_path = []
|
||||
@@ -553,3 +565,124 @@ class PortAnaRecord(ACRecordTemp):
|
||||
else:
|
||||
warnings.warn(f"indicator_analysis freq {_analysis_freq} is not found")
|
||||
return list_path
|
||||
|
||||
|
||||
class MultiPassPortAnaRecord(PortAnaRecord):
|
||||
"""
|
||||
This is the Multiple Pass Portfolio Analysis Record class that run backtest multiple times and generates the analysis results such as those of backtest. This class inherits the ``PortAnaRecord`` class.
|
||||
|
||||
If shuffle_init_score enabled, the prediction score of the first backtest date will be shuffled, so that initial position will be random.
|
||||
The shuffle_init_score will only works when the signal is used as <PRED> placeholder. The placeholder will be replaced by pred.pkl saved in recorder.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
recorder : Recorder
|
||||
The recorder used to save the backtest results.
|
||||
pass_num : int
|
||||
The number of backtest passes.
|
||||
shuffle_init_score : bool
|
||||
Whether to shuffle the prediction score of the first backtest date.
|
||||
"""
|
||||
|
||||
depend_cls = SignalRecord
|
||||
|
||||
def __init__(self, recorder, pass_num=10, shuffle_init_score=True, **kwargs):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
recorder : Recorder
|
||||
The recorder used to save the backtest results.
|
||||
pass_num : int
|
||||
The number of backtest passes.
|
||||
shuffle_init_score : bool
|
||||
Whether to shuffle the prediction score of the first backtest date.
|
||||
"""
|
||||
self.pass_num = pass_num
|
||||
self.shuffle_init_score = shuffle_init_score
|
||||
|
||||
super().__init__(recorder, **kwargs)
|
||||
|
||||
# Save original strategy so that pred df can be replaced in next generate
|
||||
self.original_strategy = deepcopy_basic_type(self.strategy_config)
|
||||
if not isinstance(self.original_strategy, dict):
|
||||
raise QlibException("MultiPassPortAnaRecord require the passed in strategy to be a dict")
|
||||
if "signal" not in self.original_strategy.get("kwargs", {}):
|
||||
raise QlibException("MultiPassPortAnaRecord require the passed in strategy to have signal as a parameter")
|
||||
|
||||
def random_init(self):
|
||||
pred_df = self.load("pred.pkl")
|
||||
|
||||
all_pred_dates = pred_df.index.get_level_values("datetime")
|
||||
bt_start_date = pd.to_datetime(self.backtest_config.get("start_time"))
|
||||
if bt_start_date is None:
|
||||
first_bt_pred_date = all_pred_dates.min()
|
||||
else:
|
||||
first_bt_pred_date = all_pred_dates[all_pred_dates >= bt_start_date].min()
|
||||
|
||||
# Shuffle the first backtest date's pred score
|
||||
first_date_score = pred_df.loc[first_bt_pred_date]["score"]
|
||||
np.random.shuffle(first_date_score.values)
|
||||
|
||||
# Use shuffled signal as the strategy signal
|
||||
self.strategy_config = deepcopy_basic_type(self.original_strategy)
|
||||
self.strategy_config["kwargs"]["signal"] = pred_df
|
||||
|
||||
def _generate(self, **kwargs):
|
||||
risk_analysis_df_map = {}
|
||||
|
||||
# Collect each frequency's analysis df as df list
|
||||
for i in trange(self.pass_num):
|
||||
if self.shuffle_init_score:
|
||||
self.random_init()
|
||||
|
||||
# Not check for cache file list
|
||||
single_run_artifacts = super()._generate(**kwargs)
|
||||
|
||||
for _analysis_freq in self.risk_analysis_freq:
|
||||
risk_analysis_df_list = risk_analysis_df_map.get(_analysis_freq, [])
|
||||
risk_analysis_df_map[_analysis_freq] = risk_analysis_df_list
|
||||
|
||||
analysis_df = single_run_artifacts[f"port_analysis_{_analysis_freq}.pkl"]
|
||||
analysis_df["run_id"] = i
|
||||
risk_analysis_df_list.append(analysis_df)
|
||||
|
||||
result_artifacts = {}
|
||||
# Concat df list
|
||||
for _analysis_freq in self.risk_analysis_freq:
|
||||
combined_df = pd.concat(risk_analysis_df_map[_analysis_freq])
|
||||
|
||||
# Calculate return and information ratio's mean, std and mean/std
|
||||
multi_pass_port_analysis_df = combined_df.groupby(level=[0, 1]).apply(
|
||||
lambda x: pd.Series(
|
||||
{"mean": x["risk"].mean(), "std": x["risk"].std(), "mean_std": x["risk"].mean() / x["risk"].std()}
|
||||
)
|
||||
)
|
||||
|
||||
# Only look at "annualized_return" and "information_ratio"
|
||||
multi_pass_port_analysis_df = multi_pass_port_analysis_df.loc[
|
||||
(slice(None), ["annualized_return", "information_ratio"]), :
|
||||
]
|
||||
pprint(multi_pass_port_analysis_df)
|
||||
|
||||
# Save new df
|
||||
result_artifacts.update({f"multi_pass_port_analysis_{_analysis_freq}.pkl": multi_pass_port_analysis_df})
|
||||
|
||||
# Log metrics
|
||||
metrics = flatten_dict(
|
||||
{
|
||||
"mean": multi_pass_port_analysis_df["mean"].unstack().T.to_dict(),
|
||||
"std": multi_pass_port_analysis_df["std"].unstack().T.to_dict(),
|
||||
"mean_std": multi_pass_port_analysis_df["mean_std"].unstack().T.to_dict(),
|
||||
}
|
||||
)
|
||||
self.recorder.log_metrics(**metrics)
|
||||
return result_artifacts
|
||||
|
||||
def list(self):
|
||||
list_path = []
|
||||
for _analysis_freq in self.risk_analysis_freq:
|
||||
if _analysis_freq in self.all_freq:
|
||||
list_path.append(f"multi_pass_port_analysis_{_analysis_freq}.pkl")
|
||||
else:
|
||||
warnings.warn(f"risk_analysis freq {_analysis_freq} is not found")
|
||||
return list_path
|
||||
|
||||
Reference in New Issue
Block a user