diff --git a/CHANGES.rst b/CHANGES.rst index 114d577f3..3daa1e8e6 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -159,6 +159,21 @@ Version 0.5.0 - Add baselines - public data crawler -Version greater than Version 0.5.0 + +Version 0.8.0 +-------------------- +- The backtest is greatly refactored. + - Nested decision execution framework is supported + - There are lots of changes for daily trading, it is hard to list all of them. But a few important changes could be noticed + - The trading limitation is more accurate; + - In `previous version `_, longing and shorting actions share the same action. + - In `current verison `_, the trading limitation is different between loging and shorting action. + - The constant is different when calculating annualized metrics. + - `Current version `_ uses more accurate constant than `previous version `_ + - `A new version `_ of data is released. Due to the unstability of Yahoo data source, the data may be different after downloading data again. + - Users could chec kout the backtesting results between `Current version `_ and `previous version `_ + + +Other Versions ---------------------------------- Please refer to `Github release Notes `_ diff --git a/docs/component/workflow.rst b/docs/component/workflow.rst index 84522af99..1b15212ac 100644 --- a/docs/component/workflow.rst +++ b/docs/component/workflow.rst @@ -53,6 +53,9 @@ Below is a typical config file of ``qrun``. kwargs: topk: 50 n_drop: 5 + signal: + - + - backtest: limit_threshold: 0.095 account: 100000000 @@ -240,6 +243,9 @@ The following script is the configuration of `backtest` and the `strategy` used kwargs: topk: 50 n_drop: 5 + signal: + - + - backtest: limit_threshold: 0.095 account: 100000000 diff --git a/examples/benchmarks/ALSTM/workflow_config_alstm_Alpha158.yaml b/examples/benchmarks/ALSTM/workflow_config_alstm_Alpha158.yaml index 039040d8f..a8e89e360 100755 --- a/examples/benchmarks/ALSTM/workflow_config_alstm_Alpha158.yaml +++ b/examples/benchmarks/ALSTM/workflow_config_alstm_Alpha158.yaml @@ -36,8 +36,9 @@ port_analysis_config: &port_analysis_config class: TopkDropoutStrategy module_path: qlib.contrib.strategy kwargs: - model: - dataset: + signal: + - + - topk: 50 n_drop: 5 backtest: diff --git a/examples/benchmarks/ALSTM/workflow_config_alstm_Alpha360.yaml b/examples/benchmarks/ALSTM/workflow_config_alstm_Alpha360.yaml index 88c6fcd07..3aa8147fc 100644 --- a/examples/benchmarks/ALSTM/workflow_config_alstm_Alpha360.yaml +++ b/examples/benchmarks/ALSTM/workflow_config_alstm_Alpha360.yaml @@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config class: TopkDropoutStrategy module_path: qlib.contrib.strategy kwargs: - model: - dataset: + signal: + - + - topk: 50 n_drop: 5 backtest: @@ -86,4 +87,4 @@ task: - class: PortAnaRecord module_path: qlib.workflow.record_temp kwargs: - config: *port_analysis_config \ No newline at end of file + config: *port_analysis_config diff --git a/examples/benchmarks/CatBoost/workflow_config_catboost_Alpha158.yaml b/examples/benchmarks/CatBoost/workflow_config_catboost_Alpha158.yaml index 18e19bd0f..2eb642741 100644 --- a/examples/benchmarks/CatBoost/workflow_config_catboost_Alpha158.yaml +++ b/examples/benchmarks/CatBoost/workflow_config_catboost_Alpha158.yaml @@ -14,8 +14,9 @@ port_analysis_config: &port_analysis_config class: TopkDropoutStrategy module_path: qlib.contrib.strategy kwargs: - model: - dataset: + signal: + - + - topk: 50 n_drop: 5 backtest: diff --git a/examples/benchmarks/CatBoost/workflow_config_catboost_Alpha360.yaml b/examples/benchmarks/CatBoost/workflow_config_catboost_Alpha360.yaml index a6cdd1882..982963eea 100644 --- a/examples/benchmarks/CatBoost/workflow_config_catboost_Alpha360.yaml +++ b/examples/benchmarks/CatBoost/workflow_config_catboost_Alpha360.yaml @@ -21,8 +21,9 @@ port_analysis_config: &port_analysis_config class: TopkDropoutStrategy module_path: qlib.contrib.strategy kwargs: - model: - dataset: + signal: + - + - topk: 50 n_drop: 5 backtest: diff --git a/examples/benchmarks/DoubleEnsemble/workflow_config_doubleensemble_Alpha158.yaml b/examples/benchmarks/DoubleEnsemble/workflow_config_doubleensemble_Alpha158.yaml index fb8cce74d..12da23171 100644 --- a/examples/benchmarks/DoubleEnsemble/workflow_config_doubleensemble_Alpha158.yaml +++ b/examples/benchmarks/DoubleEnsemble/workflow_config_doubleensemble_Alpha158.yaml @@ -14,8 +14,9 @@ port_analysis_config: &port_analysis_config class: TopkDropoutStrategy module_path: qlib.contrib.strategy kwargs: - model: - dataset: + signal: + - + - topk: 50 n_drop: 5 backtest: diff --git a/examples/benchmarks/DoubleEnsemble/workflow_config_doubleensemble_Alpha360.yaml b/examples/benchmarks/DoubleEnsemble/workflow_config_doubleensemble_Alpha360.yaml index d1fbd7807..d9481f12d 100644 --- a/examples/benchmarks/DoubleEnsemble/workflow_config_doubleensemble_Alpha360.yaml +++ b/examples/benchmarks/DoubleEnsemble/workflow_config_doubleensemble_Alpha360.yaml @@ -21,8 +21,9 @@ port_analysis_config: &port_analysis_config class: TopkDropoutStrategy module_path: qlib.contrib.strategy kwargs: - model: - dataset: + signal: + - + - topk: 50 n_drop: 5 backtest: @@ -100,4 +101,4 @@ task: - class: PortAnaRecord module_path: qlib.workflow.record_temp kwargs: - config: *port_analysis_config \ No newline at end of file + config: *port_analysis_config diff --git a/examples/benchmarks/GATs/workflow_config_gats_Alpha158.yaml b/examples/benchmarks/GATs/workflow_config_gats_Alpha158.yaml index 5387adc24..e056bc845 100644 --- a/examples/benchmarks/GATs/workflow_config_gats_Alpha158.yaml +++ b/examples/benchmarks/GATs/workflow_config_gats_Alpha158.yaml @@ -35,8 +35,9 @@ port_analysis_config: &port_analysis_config class: TopkDropoutStrategy module_path: qlib.contrib.strategy kwargs: - model: - dataset: + signal: + - + - topk: 50 n_drop: 5 backtest: @@ -94,4 +95,4 @@ task: - class: PortAnaRecord module_path: qlib.workflow.record_temp kwargs: - config: *port_analysis_config \ No newline at end of file + config: *port_analysis_config diff --git a/examples/benchmarks/GATs/workflow_config_gats_Alpha360.yaml b/examples/benchmarks/GATs/workflow_config_gats_Alpha360.yaml index 1ffd6780e..2effecd61 100644 --- a/examples/benchmarks/GATs/workflow_config_gats_Alpha360.yaml +++ b/examples/benchmarks/GATs/workflow_config_gats_Alpha360.yaml @@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config class: TopkDropoutStrategy module_path: qlib.contrib.strategy kwargs: - model: - dataset: + signal: + - + - topk: 50 n_drop: 5 backtest: diff --git a/examples/benchmarks/GRU/workflow_config_gru_Alpha158.yaml b/examples/benchmarks/GRU/workflow_config_gru_Alpha158.yaml index 82c690889..7c525c12a 100755 --- a/examples/benchmarks/GRU/workflow_config_gru_Alpha158.yaml +++ b/examples/benchmarks/GRU/workflow_config_gru_Alpha158.yaml @@ -36,8 +36,9 @@ port_analysis_config: &port_analysis_config class: TopkDropoutStrategy module_path: qlib.contrib.strategy kwargs: - model: - dataset: + signal: + - + - topk: 50 n_drop: 5 backtest: diff --git a/examples/benchmarks/GRU/workflow_config_gru_Alpha360.yaml b/examples/benchmarks/GRU/workflow_config_gru_Alpha360.yaml index 02c81c850..2daaa0136 100644 --- a/examples/benchmarks/GRU/workflow_config_gru_Alpha360.yaml +++ b/examples/benchmarks/GRU/workflow_config_gru_Alpha360.yaml @@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config class: TopkDropoutStrategy module_path: qlib.contrib.strategy kwargs: - model: - dataset: + signal: + - + - topk: 50 n_drop: 5 backtest: @@ -85,4 +86,4 @@ task: - class: PortAnaRecord module_path: qlib.workflow.record_temp kwargs: - config: *port_analysis_config \ No newline at end of file + config: *port_analysis_config diff --git a/examples/benchmarks/LSTM/workflow_config_lstm_Alpha158.yaml b/examples/benchmarks/LSTM/workflow_config_lstm_Alpha158.yaml index f4412c262..bf3738bc0 100755 --- a/examples/benchmarks/LSTM/workflow_config_lstm_Alpha158.yaml +++ b/examples/benchmarks/LSTM/workflow_config_lstm_Alpha158.yaml @@ -36,8 +36,9 @@ port_analysis_config: &port_analysis_config class: TopkDropoutStrategy module_path: qlib.contrib.strategy kwargs: - model: - dataset: + signal: + - + - topk: 50 n_drop: 5 backtest: diff --git a/examples/benchmarks/LSTM/workflow_config_lstm_Alpha360.yaml b/examples/benchmarks/LSTM/workflow_config_lstm_Alpha360.yaml index 10a1dc5df..d550cacb2 100644 --- a/examples/benchmarks/LSTM/workflow_config_lstm_Alpha360.yaml +++ b/examples/benchmarks/LSTM/workflow_config_lstm_Alpha360.yaml @@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config class: TopkDropoutStrategy module_path: qlib.contrib.strategy kwargs: - model: - dataset: + signal: + - + - topk: 50 n_drop: 5 backtest: @@ -85,4 +86,4 @@ task: - class: PortAnaRecord module_path: qlib.workflow.record_temp kwargs: - config: *port_analysis_config \ No newline at end of file + config: *port_analysis_config diff --git a/examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml b/examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml index 2bb21d41d..2d441dea9 100644 --- a/examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml +++ b/examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml @@ -14,7 +14,7 @@ port_analysis_config: &port_analysis_config class: TopkDropoutStrategy module_path: qlib.contrib.strategy kwargs: - model: + model: dataset: topk: 50 n_drop: 5 diff --git a/examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158_multi_freq.yaml b/examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158_multi_freq.yaml index 46b5c0f80..3d0a7859c 100644 --- a/examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158_multi_freq.yaml +++ b/examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158_multi_freq.yaml @@ -33,6 +33,9 @@ port_analysis_config: &port_analysis_config kwargs: topk: 50 n_drop: 5 + signal: + - + - backtest: verbose: False limit_threshold: 0.095 @@ -80,4 +83,4 @@ task: - class: PortAnaRecord module_path: qlib.workflow.record_temp kwargs: - config: *port_analysis_config \ No newline at end of file + config: *port_analysis_config diff --git a/examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha360.yaml b/examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha360.yaml index b8af19ec1..053c5bd29 100644 --- a/examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha360.yaml +++ b/examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha360.yaml @@ -21,8 +21,9 @@ port_analysis_config: &port_analysis_config class: TopkDropoutStrategy module_path: qlib.contrib.strategy kwargs: - model: - dataset: + signal: + - + - topk: 50 n_drop: 5 backtest: @@ -76,4 +77,4 @@ task: - class: PortAnaRecord module_path: qlib.workflow.record_temp kwargs: - config: *port_analysis_config \ No newline at end of file + config: *port_analysis_config diff --git a/examples/benchmarks/LightGBM/workflow_config_lightgbm_configurable_dataset.yaml b/examples/benchmarks/LightGBM/workflow_config_lightgbm_configurable_dataset.yaml index a92f342a1..f1ffc45da 100644 --- a/examples/benchmarks/LightGBM/workflow_config_lightgbm_configurable_dataset.yaml +++ b/examples/benchmarks/LightGBM/workflow_config_lightgbm_configurable_dataset.yaml @@ -29,8 +29,9 @@ port_analysis_config: &port_analysis_config class: TopkDropoutStrategy module_path: qlib.contrib.strategy kwargs: - model: - dataset: + signal: + - + - topk: 50 n_drop: 5 backtest: diff --git a/examples/benchmarks/LightGBM/workflow_config_lightgbm_multi_freq.yaml b/examples/benchmarks/LightGBM/workflow_config_lightgbm_multi_freq.yaml index 89fbcb153..20cf7de6e 100644 --- a/examples/benchmarks/LightGBM/workflow_config_lightgbm_multi_freq.yaml +++ b/examples/benchmarks/LightGBM/workflow_config_lightgbm_multi_freq.yaml @@ -31,8 +31,9 @@ port_analysis_config: &port_analysis_config class: TopkDropoutStrategy module_path: qlib.contrib.strategy kwargs: - model: - dataset: + signal: + - + - topk: 50 n_drop: 5 backtest: diff --git a/examples/benchmarks/Linear/workflow_config_linear_Alpha158.yaml b/examples/benchmarks/Linear/workflow_config_linear_Alpha158.yaml index 9f055a62c..c4e4d8e21 100644 --- a/examples/benchmarks/Linear/workflow_config_linear_Alpha158.yaml +++ b/examples/benchmarks/Linear/workflow_config_linear_Alpha158.yaml @@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config class: TopkDropoutStrategy module_path: qlib.contrib.strategy kwargs: - model: - dataset: + signal: + - + - topk: 50 n_drop: 5 backtest: diff --git a/examples/benchmarks/Localformer/workflow_config_localformer_Alpha158.yaml b/examples/benchmarks/Localformer/workflow_config_localformer_Alpha158.yaml index cd31ecd1e..7f5a78e74 100644 --- a/examples/benchmarks/Localformer/workflow_config_localformer_Alpha158.yaml +++ b/examples/benchmarks/Localformer/workflow_config_localformer_Alpha158.yaml @@ -36,8 +36,9 @@ port_analysis_config: &port_analysis_config class: TopkDropoutStrategy module_path: qlib.contrib.strategy kwargs: - model: - dataset: + signal: + - + - topk: 50 n_drop: 5 backtest: diff --git a/examples/benchmarks/Localformer/workflow_config_localformer_Alpha360.yaml b/examples/benchmarks/Localformer/workflow_config_localformer_Alpha360.yaml index f9cc091fd..9de80a350 100644 --- a/examples/benchmarks/Localformer/workflow_config_localformer_Alpha360.yaml +++ b/examples/benchmarks/Localformer/workflow_config_localformer_Alpha360.yaml @@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config class: TopkDropoutStrategy module_path: qlib.contrib.strategy kwargs: - model: - dataset: + signal: + - + - topk: 50 n_drop: 5 backtest: diff --git a/examples/benchmarks/MLP/workflow_config_mlp_Alpha158.yaml b/examples/benchmarks/MLP/workflow_config_mlp_Alpha158.yaml index 8303f3945..b0f95e696 100644 --- a/examples/benchmarks/MLP/workflow_config_mlp_Alpha158.yaml +++ b/examples/benchmarks/MLP/workflow_config_mlp_Alpha158.yaml @@ -41,8 +41,9 @@ port_analysis_config: &port_analysis_config class: TopkDropoutStrategy module_path: qlib.contrib.strategy kwargs: - model: - dataset: + signal: + - + - topk: 50 n_drop: 5 backtest: @@ -98,4 +99,4 @@ task: - class: PortAnaRecord module_path: qlib.workflow.record_temp kwargs: - config: *port_analysis_config \ No newline at end of file + config: *port_analysis_config diff --git a/examples/benchmarks/MLP/workflow_config_mlp_Alpha360.yaml b/examples/benchmarks/MLP/workflow_config_mlp_Alpha360.yaml index f52c5930d..053dd455a 100644 --- a/examples/benchmarks/MLP/workflow_config_mlp_Alpha360.yaml +++ b/examples/benchmarks/MLP/workflow_config_mlp_Alpha360.yaml @@ -29,8 +29,9 @@ port_analysis_config: &port_analysis_config class: TopkDropoutStrategy module_path: qlib.contrib.strategy kwargs: - model: - dataset: + signal: + - + - topk: 50 n_drop: 5 backtest: @@ -85,4 +86,4 @@ task: - class: PortAnaRecord module_path: qlib.workflow.record_temp kwargs: - config: *port_analysis_config \ No newline at end of file + config: *port_analysis_config diff --git a/examples/benchmarks/README.md b/examples/benchmarks/README.md index 0088128df..1a7d2fc26 100644 --- a/examples/benchmarks/README.md +++ b/examples/benchmarks/README.md @@ -13,6 +13,10 @@ The numbers shown below demonstrate the performance of the entire `workflow` of > > In the new version of qlib, the default dataset is **v2**. Since the data is collected from the YahooFinance API (which is not very stable), the results of *v2* and *v1* may differ --> +> NOTE: +> The backtest start from 0.8.0 is quite different from previous version. Please check out the changelog for the difference. + + ## Alpha158 dataset | Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown | diff --git a/examples/benchmarks/SFM/workflow_config_sfm_Alpha360.yaml b/examples/benchmarks/SFM/workflow_config_sfm_Alpha360.yaml index 5c66400bb..d750a9980 100644 --- a/examples/benchmarks/SFM/workflow_config_sfm_Alpha360.yaml +++ b/examples/benchmarks/SFM/workflow_config_sfm_Alpha360.yaml @@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config class: TopkDropoutStrategy module_path: qlib.contrib.strategy kwargs: - model: - dataset: + signal: + - + - topk: 50 n_drop: 5 backtest: diff --git a/examples/benchmarks/TCTS/workflow_config_tcts_Alpha360.yaml b/examples/benchmarks/TCTS/workflow_config_tcts_Alpha360.yaml index 484ed45b1..9e0e735d1 100644 --- a/examples/benchmarks/TCTS/workflow_config_tcts_Alpha360.yaml +++ b/examples/benchmarks/TCTS/workflow_config_tcts_Alpha360.yaml @@ -30,8 +30,9 @@ port_analysis_config: &port_analysis_config class: TopkDropoutStrategy module_path: qlib.contrib.strategy kwargs: - model: - dataset: + signal: + - + - topk: 50 n_drop: 5 backtest: @@ -94,4 +95,4 @@ task: - class: PortAnaRecord module_path: qlib.workflow.record_temp kwargs: - config: *port_analysis_config \ No newline at end of file + config: *port_analysis_config diff --git a/examples/benchmarks/TFT/workflow_config_tft_Alpha158.yaml b/examples/benchmarks/TFT/workflow_config_tft_Alpha158.yaml index 0508ce676..d83878e3e 100644 --- a/examples/benchmarks/TFT/workflow_config_tft_Alpha158.yaml +++ b/examples/benchmarks/TFT/workflow_config_tft_Alpha158.yaml @@ -16,8 +16,9 @@ port_analysis_config: &port_analysis_config class: TopkDropoutStrategy module_path: qlib.contrib.strategy kwargs: - model: - dataset: + signal: + - + - topk: 50 n_drop: 5 backtest: diff --git a/examples/benchmarks/TRA/workflow_config_tra_Alpha158.yaml b/examples/benchmarks/TRA/workflow_config_tra_Alpha158.yaml index f273f62ee..c86f87fc6 100644 --- a/examples/benchmarks/TRA/workflow_config_tra_Alpha158.yaml +++ b/examples/benchmarks/TRA/workflow_config_tra_Alpha158.yaml @@ -57,8 +57,9 @@ port_analysis_config: &port_analysis_config class: TopkDropoutStrategy module_path: qlib.contrib.strategy kwargs: - model: - dataset: + signal: + - + - topk: 50 n_drop: 5 backtest: diff --git a/examples/benchmarks/TRA/workflow_config_tra_Alpha158_full.yaml b/examples/benchmarks/TRA/workflow_config_tra_Alpha158_full.yaml index 8dc82cb99..75f18f3ee 100644 --- a/examples/benchmarks/TRA/workflow_config_tra_Alpha158_full.yaml +++ b/examples/benchmarks/TRA/workflow_config_tra_Alpha158_full.yaml @@ -51,8 +51,9 @@ port_analysis_config: &port_analysis_config class: TopkDropoutStrategy module_path: qlib.contrib.strategy kwargs: - model: - dataset: + signal: + - + - topk: 50 n_drop: 5 backtest: diff --git a/examples/benchmarks/TRA/workflow_config_tra_Alpha360.yaml b/examples/benchmarks/TRA/workflow_config_tra_Alpha360.yaml index bd5b132ee..9ab5b904b 100644 --- a/examples/benchmarks/TRA/workflow_config_tra_Alpha360.yaml +++ b/examples/benchmarks/TRA/workflow_config_tra_Alpha360.yaml @@ -51,8 +51,9 @@ port_analysis_config: &port_analysis_config class: TopkDropoutStrategy module_path: qlib.contrib.strategy kwargs: - model: - dataset: + signal: + - + - topk: 50 n_drop: 5 backtest: diff --git a/examples/benchmarks/TabNet/workflow_config_TabNet_Alpha158.yaml b/examples/benchmarks/TabNet/workflow_config_TabNet_Alpha158.yaml index 1d1c7da1c..d9b94e86c 100644 --- a/examples/benchmarks/TabNet/workflow_config_TabNet_Alpha158.yaml +++ b/examples/benchmarks/TabNet/workflow_config_TabNet_Alpha158.yaml @@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config class: TopkDropoutStrategy module_path: qlib.contrib.strategy kwargs: - model: - dataset: + signal: + - + - topk: 50 n_drop: 5 backtest: diff --git a/examples/benchmarks/TabNet/workflow_config_TabNet_Alpha360.yaml b/examples/benchmarks/TabNet/workflow_config_TabNet_Alpha360.yaml index 3d11efe60..830943d6b 100644 --- a/examples/benchmarks/TabNet/workflow_config_TabNet_Alpha360.yaml +++ b/examples/benchmarks/TabNet/workflow_config_TabNet_Alpha360.yaml @@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config class: TopkDropoutStrategy module_path: qlib.contrib.strategy kwargs: - model: - dataset: + signal: + - + - topk: 50 n_drop: 5 backtest: diff --git a/examples/benchmarks/Transformer/workflow_config_transformer_Alpha158.yaml b/examples/benchmarks/Transformer/workflow_config_transformer_Alpha158.yaml index 6174abf2e..e36d44c43 100644 --- a/examples/benchmarks/Transformer/workflow_config_transformer_Alpha158.yaml +++ b/examples/benchmarks/Transformer/workflow_config_transformer_Alpha158.yaml @@ -36,8 +36,9 @@ port_analysis_config: &port_analysis_config class: TopkDropoutStrategy module_path: qlib.contrib.strategy kwargs: - model: - dataset: + signal: + - + - topk: 50 n_drop: 5 backtest: diff --git a/examples/benchmarks/Transformer/workflow_config_transformer_Alpha360.yaml b/examples/benchmarks/Transformer/workflow_config_transformer_Alpha360.yaml index 883c18cdc..cab46a4d4 100644 --- a/examples/benchmarks/Transformer/workflow_config_transformer_Alpha360.yaml +++ b/examples/benchmarks/Transformer/workflow_config_transformer_Alpha360.yaml @@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config class: TopkDropoutStrategy module_path: qlib.contrib.strategy kwargs: - model: - dataset: + signal: + - + - topk: 50 n_drop: 5 backtest: diff --git a/examples/benchmarks/XGBoost/workflow_config_xgboost_Alpha158.yaml b/examples/benchmarks/XGBoost/workflow_config_xgboost_Alpha158.yaml index 502a5e73c..5ee38cf70 100644 --- a/examples/benchmarks/XGBoost/workflow_config_xgboost_Alpha158.yaml +++ b/examples/benchmarks/XGBoost/workflow_config_xgboost_Alpha158.yaml @@ -14,8 +14,9 @@ port_analysis_config: &port_analysis_config class: TopkDropoutStrategy module_path: qlib.contrib.strategy kwargs: - model: - dataset: + signal: + - + - topk: 50 n_drop: 5 backtest: diff --git a/examples/benchmarks/XGBoost/workflow_config_xgboost_Alpha360.yaml b/examples/benchmarks/XGBoost/workflow_config_xgboost_Alpha360.yaml index a2e40eefb..7c98bd40c 100644 --- a/examples/benchmarks/XGBoost/workflow_config_xgboost_Alpha360.yaml +++ b/examples/benchmarks/XGBoost/workflow_config_xgboost_Alpha360.yaml @@ -21,8 +21,9 @@ port_analysis_config: &port_analysis_config class: TopkDropoutStrategy module_path: qlib.contrib.strategy kwargs: - model: - dataset: + signal: + - + - topk: 50 n_drop: 5 backtest: diff --git a/examples/nested_decision_execution/workflow.py b/examples/nested_decision_execution/workflow.py index ef6906018..d7f5fc813 100644 --- a/examples/nested_decision_execution/workflow.py +++ b/examples/nested_decision_execution/workflow.py @@ -151,10 +151,9 @@ class NestedDecisionExecutionWorkflow: self._train_model(model, dataset) strategy_config = { "class": "TopkDropoutStrategy", - "module_path": "qlib.contrib.strategy.model_strategy", + "module_path": "qlib.contrib.strategy.signal_strategy", "kwargs": { - "model": model, - "dataset": dataset, + "signal": (model, dataset), "topk": 50, "n_drop": 5, }, @@ -189,10 +188,9 @@ class NestedDecisionExecutionWorkflow: backtest_config["benchmark"] = self.benchmark strategy_config = { "class": "TopkDropoutStrategy", - "module_path": "qlib.contrib.strategy.model_strategy", + "module_path": "qlib.contrib.strategy.signal_strategy", "kwargs": { - "model": model, - "dataset": dataset, + "signal": (model, dataset), "topk": 50, "n_drop": 5, }, diff --git a/examples/workflow_by_code.py b/examples/workflow_by_code.py index 486e694a7..7fd299338 100644 --- a/examples/workflow_by_code.py +++ b/examples/workflow_by_code.py @@ -31,10 +31,9 @@ if __name__ == "__main__": }, "strategy": { "class": "TopkDropoutStrategy", - "module_path": "qlib.contrib.strategy.model_strategy", + "module_path": "qlib.contrib.strategy.signal_strategy", "kwargs": { - "model": model, - "dataset": dataset, + "signal": (model, dataset), "topk": 50, "n_drop": 5, }, diff --git a/qlib/__init__.py b/qlib/__init__.py index 107819860..19a7e09af 100644 --- a/qlib/__init__.py +++ b/qlib/__init__.py @@ -152,8 +152,11 @@ def init_from_yaml_conf(conf_path, **kwargs): :param conf_path: A path to the qlib config in yml format """ - with open(conf_path) as f: - config = yaml.safe_load(f) + if conf_path is None: + config = {} + else: + with open(conf_path) as f: + config = yaml.safe_load(f) config.update(kwargs) default_conf = config.pop("default_conf", "client") init(default_conf, **config) @@ -216,7 +219,7 @@ def auto_init(**kwargs): .. code-block:: yaml conf_type: ref - qlib_cfg: '' + qlib_cfg: '' # this could be null reference no config from other files # following configs in `qlib_cfg_update` is project=specific qlib_cfg_update: exp_manager: @@ -246,6 +249,7 @@ def auto_init(**kwargs): except FileNotFoundError: init(**kwargs) else: + logger = get_module_logger("Initialization") conf_pp = pp / "config.yaml" with conf_pp.open() as f: conf = yaml.safe_load(f) @@ -259,8 +263,14 @@ def auto_init(**kwargs): # - There is a shared configure file and you don't want to edit it inplace. # - The shared configure may be updated later and you don't want to copy it. # - You have some customized config. - qlib_conf_path = conf["qlib_cfg"] - qlib_conf_update = conf.get("qlib_cfg_update") - init_from_yaml_conf(qlib_conf_path, **qlib_conf_update, **kwargs) - logger = get_module_logger("Initialization") + qlib_conf_path = conf.get("qlib_cfg", None) + + # merge the arguments + qlib_conf_update = conf.get("qlib_cfg_update", {}) + for k, v in kwargs.items(): + if k in qlib_conf_update: + logger.warning(f"`qlib_conf_update` from conf_pp is override by `kwargs` on key '{k}'") + qlib_conf_update.update(kwargs) + + init_from_yaml_conf(qlib_conf_path, **qlib_conf_update) logger.info(f"Auto load project config: {conf_pp}") diff --git a/qlib/backtest/exchange.py b/qlib/backtest/exchange.py index 9e40e1877..cc88528fd 100644 --- a/qlib/backtest/exchange.py +++ b/qlib/backtest/exchange.py @@ -34,6 +34,7 @@ class Exchange: open_cost=0.0015, close_cost=0.0025, min_cost=5, + impact_cost=0.0, extra_quote=None, quote_cls=NumpyQuote, **kwargs, @@ -95,6 +96,7 @@ class Exchange: **NOTE**: `trade_unit` is included in the `kwargs`. It is necessary because we must distinguish `not set` and `disable trade_unit` :param min_cost: min cost, default 5 + :param impact_cost: market impact cost rate (a.k.a. slippage). A recommended value is 0.1. :param extra_quote: pandas, dataframe consists of columns: like ['$vwap', '$close', '$volume', '$factor', 'limit_sell', 'limit_buy']. The limit indicates that the etf is tradable on a specific day. @@ -164,9 +166,12 @@ class Exchange: all_fields = list(all_fields | set(subscribe_fields)) self.all_fields = all_fields + self.open_cost = open_cost self.close_cost = close_cost self.min_cost = min_cost + self.impact_cost = impact_cost + self.limit_threshold: Union[Tuple[str, str], float, None] = limit_threshold self.volume_threshold = volume_threshold self.extra_quote = extra_quote @@ -685,12 +690,14 @@ class Exchange: f"Order clipped due to volume limitation: {order}, {[(vol, rule) for vol, rule in zip(vol_limit_num, vol_limit)]}" ) - def _get_buy_amount_by_cash_limit(self, trade_price, cash): + def _get_buy_amount_by_cash_limit(self, trade_price, cash, cost_ratio): """return the real order amount after cash limit for buying. Parameters ---------- trade_price : float position : cash + cost_ratio : float + Return ---------- float @@ -699,10 +706,10 @@ class Exchange: max_trade_amount = 0 if cash >= self.min_cost: # critical_price means the stock transaction price when the service fee is equal to min_cost. - critical_price = self.min_cost / self.open_cost + self.min_cost + critical_price = self.min_cost / cost_ratio + self.min_cost if cash >= critical_price: - # the service fee is equal to open_cost * trade_amount - max_trade_amount = cash / (1 + self.open_cost) / trade_price + # the service fee is equal to cost_ratio * trade_amount + max_trade_amount = cash / (1 + cost_ratio) / trade_price else: # the service fee is equal to min_cost max_trade_amount = (cash - self.min_cost) / trade_price @@ -718,6 +725,7 @@ class Exchange: :return: trade_price, trade_val, trade_cost """ trade_price = self.get_deal_price(order.stock_id, order.start_time, order.end_time, direction=order.direction) + total_trade_val = self.get_volume(order.stock_id, order.start_time, order.end_time) * trade_price order.factor = self.get_factor(order.stock_id, order.start_time, order.end_time) order.deal_amount = order.amount # set to full amount and clip it step by step # Clipping amount first @@ -726,8 +734,12 @@ class Exchange: # - It simulates that the large order is submitted, but partial is dealt regardless of rounding by trading unit. self._clip_amount_by_volume(order, dealt_order_amount) + # TODO: the adjusted cost ratio can be overestimated as deal_amount will be clipped in the next steps + trade_val = order.deal_amount * trade_price + adj_cost_ratio = self.impact_cost * (trade_val / total_trade_val) ** 2 + if order.direction == Order.SELL: - cost_ratio = self.close_cost + cost_ratio = self.close_cost + adj_cost_ratio # sell # if we don't know current position, we choose to sell all # Otherwise, we clip the amount based on current position @@ -750,14 +762,18 @@ class Exchange: self.logger.debug(f"Order clipped due to cash limitation: {order}") elif order.direction == Order.BUY: - cost_ratio = self.open_cost + cost_ratio = self.open_cost + adj_cost_ratio # buy if position is not None: cash = position.get_cash() trade_val = order.deal_amount * trade_price - if cash < trade_val + max(trade_val * cost_ratio, self.min_cost): + if cash < max(trade_val * cost_ratio, self.min_cost): + # cash cannot cover cost + order.deal_amount = 0 + self.logger.debug(f"Order clipped due to cost higher than cash: {order}") + elif cash < trade_val + max(trade_val * cost_ratio, self.min_cost): # The money is not enough - max_buy_amount = self._get_buy_amount_by_cash_limit(trade_price, cash) + max_buy_amount = self._get_buy_amount_by_cash_limit(trade_price, cash, cost_ratio) order.deal_amount = self.round_amount_by_trade_unit( min(max_buy_amount, order.deal_amount), order.factor ) diff --git a/qlib/backtest/high_performance_ds.py b/qlib/backtest/high_performance_ds.py index 235bd054b..51847cac3 100644 --- a/qlib/backtest/high_performance_ds.py +++ b/qlib/backtest/high_performance_ds.py @@ -160,6 +160,11 @@ class NumpyQuote(BaseQuote): if is_single_value(start_time, end_time, self.freq, self.region): # this is a very special case. # skip aggregating function to speed-up the query calculation + + # FIXME: + # it will go to the else logic when it comes to the + # 1) the day before holiday when daily trading + # 2) the last minute of the day when intraday trading try: return self.data[stock_id].loc[start_time, field] except KeyError: diff --git a/qlib/backtest/signal.py b/qlib/backtest/signal.py new file mode 100644 index 000000000..a342a58be --- /dev/null +++ b/qlib/backtest/signal.py @@ -0,0 +1,102 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from qlib.utils import init_instance_by_config +from typing import Dict, List, Text, Tuple, Union +from ..model.base import BaseModel +from ..data.dataset import Dataset +from ..data.dataset.utils import convert_index_format +from ..utils.resam import resam_ts_data +import pandas as pd +import abc + + +class Signal(metaclass=abc.ABCMeta): + """ + Some trading strategy make decisions based on other prediction signals + The signals may comes from different sources(e.g. prepared data, online prediction from model and dataset) + + This interface is tries to provide unified interface for those different sources + """ + + @abc.abstractmethod + def get_signal(self, start_time, end_time) -> Union[pd.Series, pd.DataFrame, None]: + """ + get the signal at the end of the decision step(from `start_time` to `end_time`) + + Returns + ------- + Union[pd.Series, pd.DataFrame, None]: + returns None if no signal in the specific day + """ + ... + + +class SignalWCache(Signal): + """ + Signal With pandas with based Cache + SignalWCache will store the prepared signal as a attribute and give the according signal based on input query + """ + + def __init__(self, signal: Union[pd.Series, pd.DataFrame]): + """ + + Parameters + ---------- + signal : Union[pd.Series, pd.DataFrame] + The expected format of the signal is like the data below (the order of index is not important and can be automatically adjusted) + + instrument datetime + SH600000 2008-01-02 0.079704 + 2008-01-03 0.120125 + 2008-01-04 0.878860 + 2008-01-07 0.505539 + 2008-01-08 0.395004 + """ + self.signal_cache = convert_index_format(signal, level="datetime") + + def get_signal(self, start_time, end_time) -> Union[pd.Series, pd.DataFrame]: + # the frequency of the signal may not algin with the decision frequency of strategy + # so resampling from the data is necessary + # the latest signal leverage more recent data and therefore is used in trading. + signal = resam_ts_data(self.signal_cache, start_time=start_time, end_time=end_time, method="last") + return signal + + +class ModelSignal(SignalWCache): + def __init__(self, model: BaseModel, dataset: Dataset): + self.model = model + self.dataset = dataset + pred_scores = self.model.predict(dataset) + if isinstance(pred_scores, pd.DataFrame): + pred_scores = pred_scores.iloc[:, 0] + super().__init__(pred_scores) + + def _update_model(self): + """ + When using online data, update model in each bar as the following steps: + - update dataset with online data, the dataset should support online update + - make the latest prediction scores of the new bar + - update the pred score into the latest prediction + """ + # TODO: this method is not included in the framework and could be refactor later + raise NotImplementedError("_update_model is not implemented!") + + +def create_signal_from( + obj: Union[Signal, Tuple[BaseModel, Dataset], List, Dict, Text, pd.Series, pd.DataFrame] +) -> Signal: + """ + create signal from diverse information + This method will choose the right method to create a signal based on `obj` + Please refer to the code below. + """ + if isinstance(obj, Signal): + return obj + elif isinstance(obj, (tuple, list)): + return ModelSignal(*obj) + elif isinstance(obj, (dict, str)): + return init_instance_by_config(obj) + elif isinstance(obj, (pd.DataFrame, pd.Series)): + return SignalWCache(signal=obj) + else: + raise NotImplementedError(f"This type of signal is not supported") diff --git a/qlib/backtest/utils.py b/qlib/backtest/utils.py index 51130712d..5db7658b0 100644 --- a/qlib/backtest/utils.py +++ b/qlib/backtest/utils.py @@ -70,7 +70,7 @@ class TradeCalendarManager: - If self.trade_step >= self.self.trade_len, it means the trading is finished - If self.trade_step < self.self.trade_len, it means the number of trading step finished is self.trade_step """ - return self.trade_step >= self.trade_len + return self.trade_step >= self.trade_len - 1 def step(self): if self.finished(): diff --git a/qlib/contrib/data/utils/__init__.py b/qlib/contrib/data/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/qlib/contrib/data/utils/sepdf.py b/qlib/contrib/data/utils/sepdf.py new file mode 100644 index 000000000..58664c46c --- /dev/null +++ b/qlib/contrib/data/utils/sepdf.py @@ -0,0 +1,183 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import pandas as pd +from typing import Dict, Iterable + + +def align_index(df_dict, join): + res = {} + for k, df in df_dict.items(): + if join is not None and k != join: + df = df.reindex(df_dict[join].index) + res[k] = df + return res + + +# Mocking the pd.DataFrame class +class SepDataFrame: + """ + (Sep)erate DataFrame + We usually concat multiple dataframe to be processed together(Such as feature, label, weight, filter). + However, they are usally be used seperately at last. + This will result in extra cost for concating and spliting data(reshaping and copying data in the memory is very expensive) + + SepDataFrame tries to act like a DataFrame whose column with multiindex + """ + + def __init__(self, df_dict: Dict[str, pd.DataFrame], join: str, skip_align=False): + """ + initialize the data based on the dataframe dictionary + + Parameters + ---------- + df_dict : Dict[str, pd.DataFrame] + dataframe dictionary + join : str + how to join the data + It will reindex the dataframe based on the join key. + If join is None, the reindex step will be skipped + + skip_align : + for some cases, we can improve performance by skipping aligning index + """ + self.join = join + + if skip_align: + self._df_dict = df_dict + else: + self._df_dict = align_index(df_dict, join) + + @property + def loc(self): + return SDFLoc(self, join=self.join) + + @property + def index(self): + return self._df_dict[self.join].index + + def apply_each(self, method: str, skip_align=True, *args, **kwargs): + """ + Assumptions: + - inplace methods will return None + """ + inplace = False + df_dict = {} + for k, df in self._df_dict.items(): + df_dict[k] = getattr(df, method)(*args, **kwargs) + if df_dict[k] is None: + inplace = True + if not inplace: + return SepDataFrame(df_dict=df_dict, join=self.join, skip_align=skip_align) + + def sort_index(self, *args, **kwargs): + return self.apply_each("sort_index", True, *args, **kwargs) + + def copy(self, *args, **kwargs): + return self.apply_each("copy", True, *args, **kwargs) + + def _update_join(self): + if self.join not in self: + self.join = next(iter(self._df_dict.keys())) + + def __getitem__(self, item): + return self._df_dict[item] + + def __setitem__(self, item: str, df: pd.DataFrame): + # TODO: consider the join behavior + self._df_dict[item] = df + + def __delitem__(self, item: str): + del self._df_dict[item] + self._update_join() + + def __contains__(self, item): + return item in self._df_dict + + def __len__(self): + return len(self._df_dict[self.join]) + + def droplevel(self, *args, **kwargs): + raise NotImplementedError(f"Please implement the `droplevel` method") + + @property + def columns(self): + dfs = [] + for k, df in self._df_dict.items(): + df = df.head(0) + df.columns = pd.MultiIndex.from_product([[k], df.columns]) + dfs.append(df) + return pd.concat(dfs, axis=1).columns + + # Useless methods + @staticmethod + def merge(df_dict: Dict[str, pd.DataFrame], join: str): + all_df = df_dict[join] + for k, df in df_dict.items(): + if k != join: + all_df = all_df.join(df) + return all_df + + +class SDFLoc: + """Mock Class""" + + def __init__(self, sdf: SepDataFrame, join): + self._sdf = sdf + self.axis = None + self.join = join + + def __call__(self, axis): + self.axis = axis + return self + + def __getitem__(self, args): + if self.axis == 1: + if isinstance(args, str): + return self._sdf[args] + elif isinstance(args, (tuple, list)): + new_df_dict = {k: self._sdf[k] for k in args} + return SepDataFrame(new_df_dict, join=self.join if self.join in args else args[0], skip_align=True) + else: + raise NotImplementedError(f"This type of input is not supported") + elif self.axis == 0: + return SepDataFrame( + {k: df.loc(axis=0)[args] for k, df in self._sdf._df_dict.items()}, join=self.join, skip_align=True + ) + else: + df = self._sdf + if isinstance(args, tuple): + ax0, *ax1 = args + if len(ax1) == 0: + ax1 = None + if ax1 is not None: + df = df.loc(axis=1)[ax1] + if ax0 is not None: + df = df.loc(axis=0)[ax0] + return df + else: + return df.loc(axis=0)[args] + + +# Patch pandas DataFrame +# Tricking isinstance to accept SepDataFrame as its subclass +import builtins + + +def _isinstance(instance, cls): + if isinstance_orig(instance, SepDataFrame): # pylint: disable=E0602 + if isinstance(cls, Iterable): + for c in cls: + if c is pd.DataFrame: + return True + elif cls is pd.DataFrame: + return True + return isinstance_orig(instance, cls) # pylint: disable=E0602 + + +builtins.isinstance_orig = builtins.isinstance +builtins.isinstance = _isinstance + +if __name__ == "__main__": + sdf = SepDataFrame({}, join=None) + print(isinstance(sdf, (pd.DataFrame,))) + print(isinstance(sdf, pd.DataFrame)) diff --git a/qlib/contrib/strategy/__init__.py b/qlib/contrib/strategy/__init__.py index e308c1a05..adc1679c1 100644 --- a/qlib/contrib/strategy/__init__.py +++ b/qlib/contrib/strategy/__init__.py @@ -2,7 +2,7 @@ # Licensed under the MIT License. -from .model_strategy import ( +from .signal_strategy import ( TopkDropoutStrategy, WeightStrategyBase, ) diff --git a/qlib/contrib/strategy/cost_control.py b/qlib/contrib/strategy/cost_control.py index b45c03ae9..aaebe3543 100644 --- a/qlib/contrib/strategy/cost_control.py +++ b/qlib/contrib/strategy/cost_control.py @@ -6,7 +6,7 @@ This strategy is not well maintained from .order_generator import OrderGenWInteract -from .model_strategy import WeightStrategyBase +from .signal_strategy import WeightStrategyBase import copy diff --git a/qlib/contrib/strategy/order_generator.py b/qlib/contrib/strategy/order_generator.py index eff938dd7..5dfef1510 100644 --- a/qlib/contrib/strategy/order_generator.py +++ b/qlib/contrib/strategy/order_generator.py @@ -80,18 +80,22 @@ class OrderGenWInteract(OrderGenerator): :rtype: list """ + if target_weight_position is None: + return [] + # calculate current_tradable_value current_amount_dict = current.get_stock_amount_dict() + current_total_value = trade_exchange.calculate_amount_position_value( amount_dict=current_amount_dict, - trade_start_time=trade_start_time, - trade_end_time=trade_end_time, + start_time=trade_start_time, + end_time=trade_end_time, only_tradable=False, ) current_tradable_value = trade_exchange.calculate_amount_position_value( amount_dict=current_amount_dict, - trade_start_time=trade_start_time, - trade_end_time=trade_end_time, + start_time=trade_start_time, + end_time=trade_end_time, only_tradable=True, ) # add cash @@ -105,9 +109,7 @@ class OrderGenWInteract(OrderGenerator): # value. Then just sell all the stocks target_amount_dict = copy.deepcopy(current_amount_dict.copy()) for stock_id in list(target_amount_dict.keys()): - if trade_exchange.is_stock_tradable( - stock_id, trade_start_time=trade_start_time, trade_end_time=trade_end_time - ): + if trade_exchange.is_stock_tradable(stock_id, start_time=trade_start_time, end_time=trade_end_time): del target_amount_dict[stock_id] else: # consider cost rate @@ -118,16 +120,16 @@ class OrderGenWInteract(OrderGenerator): target_amount_dict = trade_exchange.generate_amount_position_from_weight_position( weight_position=target_weight_position, cash=current_tradable_value, - trade_start_time=trade_start_time, - trade_end_time=trade_end_time, + start_time=trade_start_time, + end_time=trade_end_time, ) order_list = trade_exchange.generate_order_for_target_amount_position( target_position=target_amount_dict, current_position=current_amount_dict, - trade_start_time=trade_start_time, - trade_end_time=trade_end_time, + start_time=trade_start_time, + end_time=trade_end_time, ) - return TradeDecisionWO(order_list, self) + return order_list class OrderGenWOInteract(OrderGenerator): @@ -163,8 +165,11 @@ class OrderGenWOInteract(OrderGenerator): :param trade_date: :type trade_date: pd.Timestamp - :rtype: list + :rtype: list of generated orders """ + if target_weight_position is None: + return [] + risk_total_value = risk_degree * current.calculate_value() current_stock = current.get_stock_list() @@ -172,13 +177,17 @@ class OrderGenWOInteract(OrderGenerator): for stock_id in target_weight_position: # Current rule will ignore the stock that not hold and cannot be traded at predict date if trade_exchange.is_stock_tradable( - stock_id=stock_id, trade_start_time=trade_start_time, trade_end_time=trade_end_time + stock_id=stock_id, start_time=trade_start_time, end_time=trade_end_time + ) and trade_exchange.is_stock_tradable( + stock_id=stock_id, start_time=pred_start_time, end_time=pred_end_time ): amount_dict[stock_id] = ( risk_total_value * target_weight_position[stock_id] - / trade_exchange.get_close(stock_id, trade_start_time=pred_start_time, trade_end_time=pred_end_time) + / trade_exchange.get_close(stock_id, start_time=pred_start_time, end_time=pred_end_time) ) + # TODO: Qlib use None to represent trading suspension. So last close price can't be the estimated trading price. + # Maybe a close price with forward fill will be a better solution. elif stock_id in current_stock: amount_dict[stock_id] = ( risk_total_value * target_weight_position[stock_id] / current.get_stock_price(stock_id) @@ -188,7 +197,7 @@ class OrderGenWOInteract(OrderGenerator): order_list = trade_exchange.generate_order_for_target_amount_position( target_position=amount_dict, current_position=current.get_stock_amount_dict(), - trade_start_time=trade_start_time, - trade_end_time=trade_end_time, + start_time=trade_start_time, + end_time=trade_end_time, ) - return TradeDecisionWO(order_list, self) + return order_list diff --git a/qlib/contrib/strategy/rule_strategy.py b/qlib/contrib/strategy/rule_strategy.py index 23fdd2991..dcf4667ff 100644 --- a/qlib/contrib/strategy/rule_strategy.py +++ b/qlib/contrib/strategy/rule_strategy.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. from pathlib import Path import warnings import numpy as np diff --git a/qlib/contrib/strategy/model_strategy.py b/qlib/contrib/strategy/signal_strategy.py similarity index 88% rename from qlib/contrib/strategy/model_strategy.py rename to qlib/contrib/strategy/signal_strategy.py index 1d22153a7..ae69b4bb6 100644 --- a/qlib/contrib/strategy/model_strategy.py +++ b/qlib/contrib/strategy/signal_strategy.py @@ -1,27 +1,33 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. import copy +from qlib.backtest.signal import Signal, create_signal_from +from typing import Dict, List, Text, Tuple, Union +from qlib.data.dataset import Dataset +from qlib.model.base import BaseModel from qlib.backtest.position import Position import warnings import numpy as np import pandas as pd from ...utils.resam import resam_ts_data -from ...strategy.base import ModelStrategy +from ...strategy.base import BaseStrategy from ...backtest.decision import Order, BaseTradeDecision, OrderDir, TradeDecisionWO from .order_generator import OrderGenWInteract -class TopkDropoutStrategy(ModelStrategy): +class TopkDropoutStrategy(BaseStrategy): # TODO: # 1. Supporting leverage the get_range_limit result from the decision # 2. Supporting alter_outer_trade_decision # 3. Supporting checking the availability of trade decision def __init__( self, - model, - dataset, + *, topk, n_drop, + signal: Union[Signal, Tuple[BaseModel, Dataset], List, Dict, Text, pd.Series, pd.DataFrame] = None, method_sell="bottom", method_buy="top", risk_degree=0.95, @@ -30,6 +36,8 @@ class TopkDropoutStrategy(ModelStrategy): trade_exchange=None, level_infra=None, common_infra=None, + model=None, + dataset=None, **kwargs, ): """ @@ -39,6 +47,9 @@ class TopkDropoutStrategy(ModelStrategy): the number of stocks in the portfolio. n_drop : int number of stocks to be replaced in each trading date. + signal : + the information to describe a signal. Please refer to the docs of `qlib.backtest.signal.create_signal_from` + the decision of the strategy will base on the given signal method_sell : str dropout method_sell, random/bottom. method_buy : str @@ -64,7 +75,7 @@ class TopkDropoutStrategy(ModelStrategy): """ super(TopkDropoutStrategy, self).__init__( - model, dataset, level_infra=level_infra, common_infra=common_infra, trade_exchange=trade_exchange, **kwargs + level_infra=level_infra, common_infra=common_infra, trade_exchange=trade_exchange, **kwargs ) self.topk = topk self.n_drop = n_drop @@ -74,6 +85,13 @@ class TopkDropoutStrategy(ModelStrategy): self.hold_thresh = hold_thresh self.only_tradable = only_tradable + # This is trying to be compatible with previous version of qlib task config + if model is not None and dataset is not None: + warnings.warn("`model` `dataset` is deprecated; use `signal`.", DeprecationWarning) + signal = model, dataset + + self.signal: Signal = create_signal_from(signal) + def get_risk_degree(self, trade_step=None): """get_risk_degree Return the proportion of your total value you will used in investment. @@ -87,7 +105,7 @@ class TopkDropoutStrategy(ModelStrategy): trade_step = self.trade_calendar.get_trade_step() trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step) pred_start_time, pred_end_time = self.trade_calendar.get_step_time(trade_step, shift=1) - pred_score = resam_ts_data(self.pred_scores, start_time=pred_start_time, end_time=pred_end_time, method="last") + pred_score = self.signal.get_signal(start_time=pred_start_time, end_time=pred_end_time) if pred_score is None: return TradeDecisionWO([], self) if self.only_tradable: @@ -235,15 +253,15 @@ class TopkDropoutStrategy(ModelStrategy): return TradeDecisionWO(sell_order_list + buy_order_list, self) -class WeightStrategyBase(ModelStrategy): +class WeightStrategyBase(BaseStrategy): # TODO: # 1. Supporting leverage the get_range_limit result from the decision # 2. Supporting alter_outer_trade_decision # 3. Supporting checking the availability of trade decision def __init__( self, - model, - dataset, + *, + signal: Union[Signal, Tuple[BaseModel, Dataset], List, Dict, Text, pd.Series, pd.DataFrame], order_generator_cls_or_obj=OrderGenWInteract, trade_exchange=None, level_infra=None, @@ -251,6 +269,9 @@ class WeightStrategyBase(ModelStrategy): **kwargs, ): """ + signal : + the information to describe a signal. Please refer to the docs of `qlib.backtest.signal.create_signal_from` + the decision of the strategy will base on the given signal trade_exchange : Exchange exchange that provides market info, used to deal order and generate report - If `trade_exchange` is None, self.trade_exchange will be set with common_infra @@ -260,13 +281,15 @@ class WeightStrategyBase(ModelStrategy): - In minutely execution, the daily exchange is not usable, only the minutely exchange is recommended. """ super(WeightStrategyBase, self).__init__( - model, dataset, level_infra=level_infra, common_infra=common_infra, trade_exchange=trade_exchange, **kwargs + level_infra=level_infra, common_infra=common_infra, trade_exchange=trade_exchange, **kwargs ) if isinstance(order_generator_cls_or_obj, type): self.order_generator = order_generator_cls_or_obj() else: self.order_generator = order_generator_cls_or_obj + self.signal: Signal = create_signal_from(signal) + def get_risk_degree(self, trade_step=None): """get_risk_degree Return the proportion of your total value you will used in investment. @@ -298,7 +321,7 @@ class WeightStrategyBase(ModelStrategy): trade_step = self.trade_calendar.get_trade_step() trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step) pred_start_time, pred_end_time = self.trade_calendar.get_step_time(trade_step, shift=1) - pred_score = resam_ts_data(self.pred_scores, start_time=pred_start_time, end_time=pred_end_time, method="last") + pred_score = self.signal.get_signal(start_time=pred_start_time, end_time=pred_end_time) if pred_score is None: return TradeDecisionWO([], self) current_temp = copy.deepcopy(self.trade_position) diff --git a/qlib/contrib/workflow/record_temp.py b/qlib/contrib/workflow/record_temp.py index e7c80cf6e..8d10b2ab4 100644 --- a/qlib/contrib/workflow/record_temp.py +++ b/qlib/contrib/workflow/record_temp.py @@ -49,7 +49,7 @@ class MultiSegRecord(RecordTemp): if save: save_name = "results-{:}.pkl".format(key) - self.recorder.save_objects(**{save_name: results}) + self.save(**{save_name: results}) logger.info( "The record '{:}' has been saved as the artifact of the Experiment {:}".format( save_name, self.recorder.experiment_id @@ -79,9 +79,8 @@ class SignalMseRecord(RecordTemp): metrics = {"MSE": mse, "RMSE": np.sqrt(mse)} objects = {"mse.pkl": mse, "rmse.pkl": np.sqrt(mse)} self.recorder.log_metrics(**metrics) - self.recorder.save_objects(**objects, artifact_path=self.get_path()) + self.save(**objects) logger.info("The evaluation results in SignalMseRecord is {:}".format(metrics)) def list(self): - paths = [self.get_path("mse.pkl"), self.get_path("rmse.pkl")] - return paths + return ["mse.pkl", "rmse.pkl"] diff --git a/qlib/data/dataset/__init__.py b/qlib/data/dataset/__init__.py index 1002df8ba..46b90402d 100644 --- a/qlib/data/dataset/__init__.py +++ b/qlib/data/dataset/__init__.py @@ -320,6 +320,7 @@ class TSDataSampler: self.flt_data = flt_data.values self.idx_map = self.flt_idx_map(self.flt_data, self.idx_map) self.data_index = self.data_index[np.where(self.flt_data == True)[0]] + self.idx_map = self.idx_map2arr(self.idx_map) self.start_idx, self.end_idx = self.data_index.slice_locs( start=time_to_slc_point(start), end=time_to_slc_point(end) @@ -328,6 +329,25 @@ class TSDataSampler: del self.data # save memory + @staticmethod + def idx_map2arr(idx_map): + # pytorch data sampler will have better memory control without large dict or list + # - https://github.com/pytorch/pytorch/issues/13243 + # - https://github.com/airctic/icevision/issues/613 + # So we convert the dict into int array. + # The arr_map is expected to behave the same as idx_map + + dtype = np.int32 + # set a index out of bound to indicate the none existing + no_existing_idx = (np.iinfo(dtype).max, np.iinfo(dtype).max) + + max_idx = max(idx_map.keys()) + arr_map = [] + for i in range(max_idx + 1): + arr_map.append(idx_map.get(i, no_existing_idx)) + arr_map = np.array(arr_map, dtype=dtype) + return arr_map + @staticmethod def flt_idx_map(flt_data, idx_map): idx = 0 @@ -524,20 +544,18 @@ class TSDatasetH(DatasetH): def setup_data(self, **kwargs): super().setup_data(**kwargs) + # make sure the calendar is updated to latest when loading data from new config cal = self.handler.fetch(col_set=self.handler.CS_RAW).index.get_level_values("datetime").unique() - cal = sorted(cal) - self.cal = cal + self.cal = sorted(cal) - def _prepare_raw_seg(self, slc: slice, **kwargs) -> pd.DataFrame: + @staticmethod + def _extend_slice(slc: slice, cal: list, step_len: int) -> slice: # Dataset decide how to slice data(Get more data for timeseries). start, end = slc.start, slc.stop - start_idx = bisect.bisect_left(self.cal, pd.Timestamp(start)) - pad_start_idx = max(0, start_idx - self.step_len) - pad_start = self.cal[pad_start_idx] - - # TSDatasetH will retrieve more data for complete - data = super()._prepare_seg(slice(pad_start, end), **kwargs) - return data + start_idx = bisect.bisect_left(cal, pd.Timestamp(start)) + pad_start_idx = max(0, start_idx - step_len) + pad_start = cal[pad_start_idx] + return slice(pad_start, end) def _prepare_seg(self, slc: slice, **kwargs) -> TSDataSampler: """ @@ -546,13 +564,15 @@ class TSDatasetH(DatasetH): dtype = kwargs.pop("dtype", None) start, end = slc.start, slc.stop flt_col = kwargs.pop("flt_col", None) - # TSDatasetH will retrieve more data for complete - data = self._prepare_raw_seg(slc, **kwargs) + # TSDatasetH will retrieve more data for complete time-series + + ext_slice = self._extend_slice(slc, self.cal, self.step_len) + data = super()._prepare_seg(ext_slice, **kwargs) flt_kwargs = deepcopy(kwargs) if flt_col is not None: flt_kwargs["col_set"] = flt_col - flt_data = self._prepare_raw_seg(slc, **flt_kwargs) + flt_data = self._prepare_seg(ext_slice, **flt_kwargs) assert len(flt_data.columns) == 1 else: flt_data = None diff --git a/qlib/data/dataset/handler.py b/qlib/data/dataset/handler.py index 507e5ea81..134091c22 100644 --- a/qlib/data/dataset/handler.py +++ b/qlib/data/dataset/handler.py @@ -82,8 +82,6 @@ class DataHandler(Serializable): fetch_orig : bool Return the original data instead of copy if possible. """ - # Set logger - self.logger = get_module_logger("DataHandler") # Setup data loader assert data_loader is not None # to make start_time end_time could have None default value @@ -302,6 +300,7 @@ class DataHandlerLP(DataHandler): DK_R = "raw" DK_I = "infer" DK_L = "learn" + ATTR_MAP = {DK_R: "_data", DK_I: "_infer", DK_L: "_learn"} # process type PTYPE_I = "independent" @@ -543,7 +542,7 @@ class DataHandlerLP(DataHandler): raise AttributeError( "DataHandlerLP has not attribute _data, please set drop_raw = False if you want to use raw data" ) - df = getattr(self, {self.DK_R: "_data", self.DK_I: "_infer", self.DK_L: "_learn"}[data_key]) + df = getattr(self, self.ATTR_MAP[data_key]) return df def fetch( @@ -624,3 +623,33 @@ class DataHandlerLP(DataHandler): df = self._get_df_by_key(data_key).head() df = fetch_df_by_col(df, col_set) return df.columns.to_list() + + @classmethod + def cast(cls, handler: "DataHandlerLP") -> "DataHandlerLP": + """ + Motivation + - A user create a datahandler in his customized package. Then he want to share the processed handler to other users without introduce the package dependency and complicated data processing logic. + - This class make it possible by casting the class to DataHandlerLP and only keep the processed data + + Parameters + ---------- + handler : DataHandlerLP + A subclass of DataHandlerLP + + Returns + ------- + DataHandlerLP: + the converted processed data + """ + new_hd: DataHandlerLP = object.__new__(DataHandlerLP) + new_hd.from_cast = True # add a mark for the casted instance + + for key in list(DataHandlerLP.ATTR_MAP.values()) + [ + "instruments", + "start_time", + "end_time", + "fetch_orig", + "drop_raw", + ]: + setattr(new_hd, key, getattr(handler, key, None)) + return new_hd diff --git a/qlib/strategy/base.py b/qlib/strategy/base.py index bd5d3dbd3..860477544 100644 --- a/qlib/strategy/base.py +++ b/qlib/strategy/base.py @@ -17,7 +17,7 @@ from ..utils import init_instance_by_config from ..backtest.utils import CommonInfrastructure, LevelInfrastructure, TradeCalendarManager from ..backtest.decision import BaseTradeDecision -__all__ = ["BaseStrategy", "ModelStrategy", "RLStrategy", "RLIntStrategy"] +__all__ = ["BaseStrategy", "RLStrategy", "RLIntStrategy"] class BaseStrategy: @@ -194,45 +194,6 @@ class BaseStrategy: return max(cal_range[0], range_limit[0]), min(cal_range[1], range_limit[1]) -class ModelStrategy(BaseStrategy): - """Model-based trading strategy, use model to make predictions for trading""" - - def __init__( - self, - model: BaseModel, - dataset: DatasetH, - outer_trade_decision: BaseTradeDecision = None, - level_infra: LevelInfrastructure = None, - common_infra: CommonInfrastructure = None, - **kwargs, - ): - """ - Parameters - ---------- - model : BaseModel - the model used in when making predictions - dataset : DatasetH - provide test data for model - kwargs : dict - arguments that will be passed into `reset` method - """ - super(ModelStrategy, self).__init__(outer_trade_decision, level_infra, common_infra, **kwargs) - self.model = model - self.dataset = dataset - self.pred_scores = convert_index_format(self.model.predict(dataset), level="datetime") - if isinstance(self.pred_scores, pd.DataFrame): - self.pred_scores = self.pred_scores.iloc[:, 0] - - def _update_model(self): - """ - When using online data, pdate model in each bar as the following steps: - - update dataset with online data, the dataset should support online update - - make the latest prediction scores of the new bar - - update the pred score into the latest prediction - """ - raise NotImplementedError("_update_model is not implemented!") - - class RLStrategy(BaseStrategy): """RL-based strategy""" diff --git a/qlib/utils/__init__.py b/qlib/utils/__init__.py index f6a6632ea..12553411c 100644 --- a/qlib/utils/__init__.py +++ b/qlib/utils/__init__.py @@ -199,6 +199,7 @@ def get_callable_kwargs(config: Union[dict, str], default_module: Union[str, Mod ---------- config : [dict, str] similar to config + please refer to the doc of init_instance_by_config default_module : Python module or str It should be a python module to load the class type @@ -219,9 +220,12 @@ def get_callable_kwargs(config: Union[dict, str], default_module: Union[str, Mod _callable = config["class"] # the class type itself is passed in kwargs = config.get("kwargs", {}) elif isinstance(config, str): - module = get_module_by_module_path(default_module) + # a.b.c.ClassName + *m_path, cls = config.split(".") + m_path = ".".join(m_path) + module = get_module_by_module_path(default_module if m_path == "" else m_path) - _callable = getattr(module, config) + _callable = getattr(module, cls) kwargs = {} else: raise NotImplementedError(f"This type of input is not supported") @@ -260,7 +264,9 @@ def init_instance_by_config( 1) specify a pickle object - path like 'file:////obj.pkl' 2) specify a class name - - "ClassName": getattr(module, config)() will be used. + - "ClassName": getattr(module, "ClassName")() will be used. + 3) specify module path with class name + - "a.b.c.ClassName" getattr(, "ClassName")() will be used. object example: instance of accept_types default_module : Python module diff --git a/qlib/utils/index_data.py b/qlib/utils/index_data.py index 5e3942db5..06fb42a5e 100644 --- a/qlib/utils/index_data.py +++ b/qlib/utils/index_data.py @@ -401,6 +401,10 @@ class IndexData(metaclass=index_data_ops_creator): def columns(self): return self.indices[1] + def __getitem__(self, args): + # NOTE: this tries to behave like a numpy array to be compatible with numpy aggregating function like nansum and nanmean + return self.iloc[args] + def _align_indices(self, other: "IndexData") -> "IndexData": """ Align all indices of `other` to `self` before performing the arithmetic operations. @@ -409,7 +413,7 @@ class IndexData(metaclass=index_data_ops_creator): Parameters ---------- other : "IndexData" - the index in `other` is to be chagned + the index in `other` is to be changed Returns ------- @@ -455,7 +459,8 @@ class IndexData(metaclass=index_data_ops_creator): """ return len(self.data) - def sum(self, axis=None): + def sum(self, axis=None, dtype=None, out=None): + assert out is None and dtype is None, "`out` is just for compatible with numpy's aggregating function" # FIXME: weird logic and not general if axis is None: return np.nansum(self.data) @@ -468,7 +473,8 @@ class IndexData(metaclass=index_data_ops_creator): else: raise ValueError(f"axis must be None, 0 or 1") - def mean(self, axis=None): + def mean(self, axis=None, dtype=None, out=None): + assert out is None and dtype is None, "`out` is just for compatible with numpy's aggregating function" # FIXME: weird logic and not general if axis is None: return np.nanmean(self.data) diff --git a/qlib/utils/paral.py b/qlib/utils/paral.py index 075a1adb8..48b427a28 100644 --- a/qlib/utils/paral.py +++ b/qlib/utils/paral.py @@ -1,9 +1,15 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import pandas as pd +from functools import partial +from threading import Thread +from typing import Callable + from joblib import Parallel, delayed from joblib._parallel_backends import MultiprocessingBackend +import pandas as pd + +from queue import Queue class ParallelExt(Parallel): @@ -46,3 +52,54 @@ def datetime_groupby_apply(df, apply_func, axis=0, level="datetime", resample_ru return pd.concat(dfs, axis=axis).sort_index() else: return _naive_group_apply(df) + + +class AsyncCaller: + """ + This AsyncCaller tries to make it easier to async call + + Currently, it is used in MLflowRecorder to make functions like `log_params` async + + NOTE: + - This caller didn't consider the return value + """ + + STOP_MARK = "__STOP" + + def __init__(self) -> None: + self._q = Queue() + self._stop = False + self._t = Thread(target=self.run) + self._t.start() + + def close(self): + self._q.put(self.STOP_MARK) + + def run(self): + while True: + data = self._q.get() + if data == self.STOP_MARK: + break + else: + data() + + def __call__(self, func, *args, **kwargs): + self._q.put(partial(func, *args, **kwargs)) + + def wait(self, close=True): + if close: + self.close() + self._t.join() + + @staticmethod + def async_dec(ac_attr): + def decorator_func(func): + def wrapper(self, *args, **kwargs): + if isinstance(getattr(self, ac_attr, None), Callable): + return getattr(self, ac_attr)(func, self, *args, **kwargs) + else: + return func(self, *args, **kwargs) + + return wrapper + + return decorator_func diff --git a/qlib/workflow/online/manager.py b/qlib/workflow/online/manager.py index b4b509483..e9f0fe9d2 100644 --- a/qlib/workflow/online/manager.py +++ b/qlib/workflow/online/manager.py @@ -21,19 +21,65 @@ Situations Description Online + Trainer When you want to do a REAL routine, the Trainer will help you train the models. It will train models task by task and strategy by strategy. -Online + DelayTrainer When your models don't have any temporal dependence, the DelayTrainer will train - nothing until all tasks have been prepared. It makes user can train all tasks in - the end of `routine` or `first_train`. +Online + DelayTrainer DelayTrainer will skip concrete training until all tasks have been prepared by + different strategies. It makes users can parallelly train all tasks at the end of + `routine` or `first_train`. Otherwise, these functions will get stuck when each + strategy prepare tasks. -Simulation + Trainer When your models have some temporal dependence on the previous models, then you - need to consider using Trainer. This means it will REAL train your models in - every routine and prepare signals for every routine. +Simulation + Trainer It will behave in the same way as `Online + Trainer`. The only difference is that it + is for simulation/backtesting instead of online trading Simulation + DelayTrainer When your models don't have any temporal dependence, you can use DelayTrainer for the ability to multitasking. It means all tasks in all routines can be REAL trained at the end of simulating. The signals will be prepared well at different time segments (based on whether or not any new model is online). ========================= =================================================================================== + +Here is some pseudo code the demonstrate the workflow of each situation + +For simplicity + - Only one strategy is used in the strategy + - `update_online_pred` is only called in the online mode and is ignored + +1) `Online + Trainer` + +.. code-block:: python + + tasks = first_train() + models = trainer.train(tasks) + trainer.end_train(models) + for day in online_trading_days: + # OnlineManager.routine + models = trainer.train(strategy.prepare_tasks()) # for each strategy + strategy.prepare_online_models(models) # for each strategy + + trainer.end_train(models) + prepare_signals() # prepare trading signals daily + + +`Online + DelayTrainer`: the workflow is the same as `Online + Trainer`. + + +2) `Simulation + DelayTrainer` + +.. code-block:: python + + # simulate + tasks = first_train() + models = trainer.train(tasks) + for day in historical_calendars: + # OnlineManager.routine + models = trainer.train(strategy.prepare_tasks()) # for each strategy + strategy.prepare_online_models(models) # for each strategy + # delay_prepare() + # FIXME: Currently the delay_prepare is not implemented in a proper way. + trainer.end_train() + prepare_signals() + + +# Can we simplify current workflow? +- Can reduce the number of state of tasks? + - For each task, we have three phases (i.e. task, partly trained task, final trained task) """ import logging @@ -58,7 +104,7 @@ class OnlineManager(Serializable): """ STATUS_SIMULATING = "simulating" # when calling `simulate` - STATUS_NORMAL = "normal" # the normal status + STATUS_ONLINE = "online" # the normal status. It is used when online trading def __init__( self, @@ -87,12 +133,24 @@ class OnlineManager(Serializable): self.begin_time = pd.Timestamp(begin_time) self.cur_time = self.begin_time # OnlineManager will recorder the history of online models, which is a dict like {pd.Timestamp, {strategy, [online_models]}}. + # It records the online servnig models of each strategy for each day. self.history = {} if trainer is None: trainer = TrainerR() self.trainer = trainer self.signals = None - self.status = self.STATUS_NORMAL + self.status = self.STATUS_ONLINE + + def _postpone_action(self): + """ + Should the workflow to postpone the following actions to the end (in delay_prepare) + - trainer.end_train + - prepare_signals + + Postpone these actions is to support simulating/backtest online strategies without time dependencies. + All the actions can be done parallelly at the end. + """ + return self.status == self.STATUS_SIMULATING and self.trainer.is_delay() def first_train(self, strategies: List[OnlineStrategy] = None, model_kwargs: dict = {}): """ @@ -113,12 +171,12 @@ class OnlineManager(Serializable): models = self.trainer.train(tasks, experiment_name=strategy.name_id) models_list.append(models) self.logger.info(f"Finished training {len(models)} models.") - # FIXME: Traing multiple online models at `first_train` will result in getting too much online models at the + # FIXME: Train multiple online models at `first_train` will result in getting too much online models at the # start. online_models = strategy.prepare_online_models(models, **model_kwargs) self.history.setdefault(self.cur_time, {})[strategy] = online_models - if not self.status == self.STATUS_SIMULATING or not self.trainer.is_delay(): + if not self._postpone_action(): for strategy, models in zip(strategies, models_list): models = self.trainer.end_train(models, experiment_name=strategy.name_id) @@ -160,10 +218,10 @@ class OnlineManager(Serializable): # The online model may changes in the above processes # So updating the predictions of online models should be the last step - if self.status == self.STATUS_NORMAL: + if self.status == self.STATUS_ONLINE: strategy.tool.update_online_pred() - if not self.status == self.STATUS_SIMULATING or not self.trainer.is_delay(): + if not self._postpone_action(): for strategy, models in zip(self.strategies, models_list): models = self.trainer.end_train(models, experiment_name=strategy.name_id) self.prepare_signals(**signal_kwargs) @@ -278,13 +336,13 @@ class OnlineManager(Serializable): signal_kwargs=signal_kwargs, ) # delay prepare the models and signals - if self.trainer.is_delay(): + if self._postpone_action(): self.delay_prepare(model_kwargs=model_kwargs, signal_kwargs=signal_kwargs) # FIXME: get logging level firstly and restore it here set_global_logger_level(logging.DEBUG) self.logger.info(f"Finished preparing signals") - self.status = self.STATUS_NORMAL + self.status = self.STATUS_ONLINE return self.get_signals() def delay_prepare(self, model_kwargs={}, signal_kwargs={}): @@ -295,6 +353,8 @@ class OnlineManager(Serializable): model_kwargs: the params for `end_train` signal_kwargs: the params for `prepare_signals` """ + # FIXME: + # This method is not implemented in the proper way!!! last_models = {} signals_time = D.calendar()[0] need_prepare = False diff --git a/qlib/workflow/record_temp.py b/qlib/workflow/record_temp.py index 0d85311ee..07422243d 100644 --- a/qlib/workflow/record_temp.py +++ b/qlib/workflow/record_temp.py @@ -9,6 +9,9 @@ import pandas as pd from pathlib import Path from pprint import pprint from typing import Union, List +from collections import defaultdict + +from qlib.utils.exceptions import LoadObjectError from ..contrib.evaluate import indicator_analysis, risk_analysis, indicator_analysis from ..data.dataset import DatasetH @@ -45,6 +48,16 @@ class RecordTemp: return "/".join(names) + def save(self, **kwargs): + """ + It behaves the same as self.recorder.save_objects. + But it is an easier interface because users don't have to care about `get_path` and `artifact_path` + """ + art_path = self.get_path() + if art_path == "": + art_path = None + self.recorder.save_objects(artifact_path=art_path, **kwargs) + def __init__(self, recorder): self._recorder = recorder @@ -67,31 +80,37 @@ class RecordTemp: """ raise NotImplementedError(f"Please implement the `generate` method.") - def load(self, name): + def load(self, name: str, parents: bool = True): """ - Load the stored records. Due to the fact that some problems occured when we tried to balancing a clean API - with the Python's inheritance. This method has to be used in a rather ugly way, and we will try to fix them - in the future:: - - sar = SigAnaRecord(recorder) - ic = sar.load(sar.get_path("ic.pkl")) + It behaves the same as self.recorder.load_object. + But it is an easier interface because users don't have to care about `get_path` and `artifact_path` Parameters ---------- name : str the name for the file to be load. + parents : bool + Each recorder has different `artifact_path`. + So parents recursively find the path in parents + Sub classes has higher priority + Return ------ The stored records. """ - # try to load the saved object - obj = self.recorder.load_object(name) - return obj + try: + return self.recorder.load_object(self.get_path(name)) + except LoadObjectError: + if parents: + if self.depend_cls is not None: + with class_casting(self, self.depend_cls): + return self.load(name, parents=True) def list(self): """ List the supported artifacts. + Users don't have to consider self.get_path Return ------ @@ -99,7 +118,7 @@ class RecordTemp: """ return [] - def check(self, include_self: bool = False): + def check(self, include_self: bool = False, parents: bool = True): """ Check if the records is properly generated and saved. It is useful in following examples @@ -110,19 +129,34 @@ class RecordTemp: ---------- include_self : bool is the file generated by self included + parents : bool + will we check parents Raise ------ - FileExistsError: whether the records are stored properly. + FileNotFoundError + : whether the records are stored properly. """ - artifacts = set(self.recorder.list_artifacts()) if include_self: + + # Some mlflow backend will not list the directly recursively. + # So we force to the directly + artifacts = {} + + def _get_arts(dirn): + if dirn not in artifacts: + artifacts[dirn] = self.recorder.list_artifacts(dirn) + return artifacts[dirn] + for item in self.list(): - if item not in artifacts: - raise FileExistsError(item) - if self.depend_cls is not None: - with class_casting(self, self.depend_cls): - self.check(include_self=True) + ps = self.get_path(item).split("/") + dirn, fn = "/".join(ps[:-1]), ps[-1] + if self.get_path(item) not in _get_arts(dirn): + raise FileNotFoundError + if parents: + if self.depend_cls is not None: + with class_casting(self, self.depend_cls): + self.check(include_self=True) class SignalRecord(RecordTemp): @@ -158,7 +192,7 @@ class SignalRecord(RecordTemp): pred = self.model.predict(self.dataset) if isinstance(pred, pd.Series): pred = pred.to_frame("score") - self.recorder.save_objects(**{"pred.pkl": pred}) + self.save(**{"pred.pkl": pred}) logger.info( f"Signal record 'pred.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}" @@ -169,15 +203,11 @@ class SignalRecord(RecordTemp): if isinstance(self.dataset, DatasetH): raw_label = self.generate_label(self.dataset) - self.recorder.save_objects(**{"label.pkl": raw_label}) + self.save(**{"label.pkl": raw_label}) - @staticmethod - def list(): + def list(self): return ["pred.pkl", "label.pkl"] - def load(self, name="pred.pkl"): - return super().load(name) - class HFSignalRecord(SignalRecord): """ @@ -218,19 +248,11 @@ class HFSignalRecord(SignalRecord): } ) self.recorder.log_metrics(**metrics) - self.recorder.save_objects(**objects, artifact_path=self.get_path()) + self.save(**objects) pprint(metrics) def list(self): - paths = [ - self.get_path("ic.pkl"), - self.get_path("ric.pkl"), - self.get_path("long_pre.pkl"), - self.get_path("short_pre.pkl"), - self.get_path("long_short_r.pkl"), - self.get_path("long_avg_r.pkl"), - ] - return paths + return ["ic.pkl", "ric.pkl", "long_pre.pkl", "short_pre.pkl", "long_short_r.pkl", "long_avg_r.pkl"] class SigAnaRecord(RecordTemp): @@ -241,13 +263,23 @@ class SigAnaRecord(RecordTemp): artifact_path = "sig_analysis" depend_cls = SignalRecord - def __init__(self, recorder, ana_long_short=False, ann_scaler=252, label_col=0): + def __init__(self, recorder, ana_long_short=False, ann_scaler=252, label_col=0, skip_existing=False): super().__init__(recorder=recorder) self.ana_long_short = ana_long_short self.ann_scaler = ann_scaler self.label_col = label_col + self.skip_existing = skip_existing def generate(self, **kwargs): + if self.skip_existing: + try: + self.check(include_self=True, parents=False) + except FileNotFoundError: + pass # continue to generating metrics + else: + logger.info("The results has previously generated, generation skipped.") + return + self.check() pred = self.load("pred.pkl") @@ -280,13 +312,13 @@ class SigAnaRecord(RecordTemp): } ) self.recorder.log_metrics(**metrics) - self.recorder.save_objects(**objects, artifact_path=self.get_path()) + self.save(**objects) pprint(metrics) def list(self): - paths = [self.get_path("ic.pkl"), self.get_path("ric.pkl")] + paths = ["ic.pkl", "ric.pkl"] if self.ana_long_short: - paths.extend([self.get_path("long_short_r.pkl"), self.get_path("long_avg_r.pkl")]) + paths.extend(["long_short_r.pkl", "long_avg_r.pkl"]) return paths @@ -373,17 +405,11 @@ class PortAnaRecord(RecordTemp): executor=self.executor_config, strategy=self.strategy_config, **self.backtest_config ) for _freq, (report_normal, positions_normal) in portfolio_metric_dict.items(): - self.recorder.save_objects( - **{f"report_normal_{_freq}.pkl": report_normal}, artifact_path=PortAnaRecord.get_path() - ) - self.recorder.save_objects( - **{f"positions_normal_{_freq}.pkl": positions_normal}, artifact_path=PortAnaRecord.get_path() - ) + self.save(**{f"report_normal_{_freq}.pkl": report_normal}) + self.save(**{f"positions_normal_{_freq}.pkl": positions_normal}) for _freq, indicators_normal in indicator_dict.items(): - self.recorder.save_objects( - **{f"indicators_normal_{_freq}.pkl": indicators_normal}, artifact_path=PortAnaRecord.get_path() - ) + self.save(**{f"indicators_normal_{_freq}.pkl": indicators_normal}) for _analysis_freq in self.risk_analysis_freq: if _analysis_freq not in portfolio_metric_dict: @@ -405,9 +431,7 @@ class PortAnaRecord(RecordTemp): analysis_dict = flatten_dict(analysis_df["risk"].unstack().T.to_dict()) self.recorder.log_metrics(**{f"{_analysis_freq}.{k}": v for k, v in analysis_dict.items()}) # save results - self.recorder.save_objects( - **{f"port_analysis_{_analysis_freq}.pkl": analysis_df}, artifact_path=PortAnaRecord.get_path() - ) + self.save(**{f"port_analysis_{_analysis_freq}.pkl": analysis_df}) logger.info( f"Portfolio analysis record 'port_analysis_{_analysis_freq}.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}" ) @@ -432,9 +456,7 @@ class PortAnaRecord(RecordTemp): analysis_dict = analysis_df["value"].to_dict() self.recorder.log_metrics(**{f"{_analysis_freq}.{k}": v for k, v in analysis_dict.items()}) # save results - self.recorder.save_objects( - **{f"indicator_analysis_{_analysis_freq}.pkl": analysis_df}, artifact_path=PortAnaRecord.get_path() - ) + self.save(**{f"indicator_analysis_{_analysis_freq}.pkl": analysis_df}) logger.info( f"Indicator analysis record 'indicator_analysis_{_analysis_freq}.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}" ) @@ -446,20 +468,19 @@ class PortAnaRecord(RecordTemp): for _freq in self.all_freq: list_path.extend( [ - PortAnaRecord.get_path(f"report_normal_{_freq}.pkl"), - PortAnaRecord.get_path(f"positions_normal_{_freq}.pkl"), + f"report_normal_{_freq}.pkl", + f"positions_normal_{_freq}.pkl", ] ) for _analysis_freq in self.risk_analysis_freq: if _analysis_freq in self.all_freq: - list_path.append(PortAnaRecord.get_path(f"port_analysis_{_analysis_freq}.pkl")) + list_path.append(f"port_analysis_{_analysis_freq}.pkl") else: warnings.warn(f"risk_analysis freq {_analysis_freq} is not found") for _analysis_freq in self.indicator_analysis_freq: if _analysis_freq in self.all_freq: - list_path.append(PortAnaRecord.get_path(f"indicator_analysis_{_analysis_freq}.pkl")) + list_path.append(f"indicator_analysis_{_analysis_freq}.pkl") else: warnings.warn(f"indicator_analysis freq {_analysis_freq} is not found") - return list_path diff --git a/qlib/workflow/recorder.py b/qlib/workflow/recorder.py index 0bf6f4841..056d75be1 100644 --- a/qlib/workflow/recorder.py +++ b/qlib/workflow/recorder.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +import os from qlib.utils.serial import Serializable import mlflow, logging import shutil, os, pickle, tempfile, codecs, pickle @@ -8,8 +9,10 @@ from pathlib import Path from datetime import datetime from qlib.utils.exceptions import LoadObjectError +from qlib.utils.paral import AsyncCaller from ..utils.objm import FileManager -from ..log import get_module_logger +from ..log import TimeInspector, get_module_logger +from mlflow.store.artifact.azure_blob_artifact_repo import AzureBlobArtifactRepository logger = get_module_logger("workflow", logging.INFO) @@ -227,6 +230,7 @@ class MLflowRecorder(Recorder): if mlflow_run.info.end_time is not None else None ) + self.async_log = None def __repr__(self): name = self.__class__.__name__ @@ -285,6 +289,10 @@ class MLflowRecorder(Recorder): self.status = Recorder.STATUS_R logger.info(f"Recorder {self.id} starts running under Experiment {self.experiment_id} ...") + # NOTE: making logging async. + # - This may cause delay when uploading results + # - The logging time may not be accurate + self.async_log = AsyncCaller() return run def end_run(self, status: str = Recorder.STATUS_S): @@ -298,6 +306,9 @@ class MLflowRecorder(Recorder): self.end_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") if self.status != Recorder.STATUS_S: self.status = status + with TimeInspector.logt("waiting `async_log`"): + self.async_log.wait() + self.async_log = None def save_objects(self, local_path=None, artifact_path=None, **kwargs): assert self.uri is not None, "Please start the experiment and recorder first before using recorder directly." @@ -333,18 +344,27 @@ class MLflowRecorder(Recorder): try: path = self.client.download_artifacts(self.id, name) with Path(path).open("rb") as f: - return pickle.load(f) + data = pickle.load(f) + ar = self.client._tracking_client._get_artifact_repo(self.id) + if isinstance(ar, AzureBlobArtifactRepository): + # for saving disk space + # For safety, only remove redundant file for specific ArtifactRepository + shutil.rmtree(Path(path).absolute().parent) + return data except Exception as e: raise LoadObjectError(message=str(e)) + @AsyncCaller.async_dec(ac_attr="async_log") def log_params(self, **kwargs): for name, data in kwargs.items(): self.client.log_param(self.id, name, data) + @AsyncCaller.async_dec(ac_attr="async_log") def log_metrics(self, step=None, **kwargs): for name, data in kwargs.items(): self.client.log_metric(self.id, name, data, step=step) + @AsyncCaller.async_dec(ac_attr="async_log") def set_tags(self, **kwargs): for name, data in kwargs.items(): self.client.set_tag(self.id, name, data) diff --git a/qlib/workflow/task/collect.py b/qlib/workflow/task/collect.py index 467281666..13fcd0202 100644 --- a/qlib/workflow/task/collect.py +++ b/qlib/workflow/task/collect.py @@ -5,6 +5,7 @@ Collector module can collect objects from everywhere and process them such as merging, grouping, averaging and so on. """ +from libs.qlib.qlib.log import TimeInspector from typing import Callable, Dict, List from qlib.log import get_module_logger from qlib.utils.serial import Serializable @@ -190,7 +191,9 @@ class RecorderCollector(Collector): collect_dict = {} # filter records - recs = self.experiment.list_recorders(**self.list_kwargs) + + with TimeInspector.logt("Time to `list_recorders` in RecorderCollector"): + recs = self.experiment.list_recorders(**self.list_kwargs) recs_flt = {} for rid, rec in recs.items(): if rec_filter_func is None or rec_filter_func(rec): diff --git a/qlib/workflow/task/gen.py b/qlib/workflow/task/gen.py index 2fc87b1a4..45fba12da 100644 --- a/qlib/workflow/task/gen.py +++ b/qlib/workflow/task/gen.py @@ -94,6 +94,11 @@ class TaskGen(metaclass=abc.ABCMeta): def handler_mod(task: dict, rolling_gen): """ Help to modify the handler end time when using RollingGen + It try to handle the following case + - Hander's data end_time is earlier than dataset's test_data's segments. + - To handle this, handler's data's end_time is extended. + + If the handler's end_time is None, then it is not necessary to change it's end time. Args: task (dict): a task template @@ -112,6 +117,9 @@ def handler_mod(task: dict, rolling_gen): except KeyError: # Maybe dataset do not have handler, then do nothing. pass + except TypeError: + # May be the handler is a string. `"handler.pkl"["kwargs"]` will raise TypeError + pass class RollingGen(TaskGen): @@ -259,3 +267,56 @@ class RollingGen(TaskGen): # Update the following rolling res.extend(self.gen_following_tasks(t, test_end)) return res + + +class MultiHorizonGenBase(TaskGen): + def __init__(self, horizon: List[int] = [5], label_leak_n=2): + """ + This task generator tries to genrate tasks for different horizons based on an existing task + + Parameters + ---------- + horizon : List[int] + the possible horizons of the tasks + label_leak_n : int + How many future days it will take to get complete label after the day making prediction + For example: + - User make prediction on day `T`(after getting the close price on `T`) + - The label is the return of buying stock on `T + 1` and selling it on `T + 2` + - the `label_leak_n` will be 2 (e.g. two days of information is leaked to leverage this sample) + """ + self.horizon = list(horizon) + self.label_leak_n = label_leak_n + self.ta = TimeAdjuster() + self.test_key = "test" + + @abc.abstractmethod + def set_horizon(self, task: dict, hr: int): + """ + This method is designed to change the task **in place** + + Parameters + ---------- + task : dict + Qlib's task + hr : int + the horizon of task + """ + + def generate(self, task: dict): + res = [] + for hr in self.horizon: + + # Add horizon + t = copy.deepcopy(task) + self.set_horizon(t, hr) + + # adjust segment + segments = self.ta.align_seg(t["dataset"]["kwargs"]["segments"]) + test_start = min(t for t in segments[self.test_key] if t is not None) + for k in list(segments.keys()): + if k != self.test_key: + segments[k] = self.ta.truncate(segments[k], test_start, hr + self.label_leak_n) + t["dataset"]["kwargs"]["segments"] = segments + res.append(t) + return res diff --git a/tests/misc/test_index_data.py b/tests/misc/test_index_data.py index 3cd819a0f..20cda69ff 100644 --- a/tests/misc/test_index_data.py +++ b/tests/misc/test_index_data.py @@ -1,6 +1,5 @@ import numpy as np import pandas as pd - import qlib.utils.index_data as idd import unittest @@ -115,6 +114,19 @@ class IndexDataTest(unittest.TestCase): # sd2 = idd.SingleData([1, 2, 3, 4], index=["foo", "bar", "f", "g"]) # 2 * sd2 + def test_squeeze(self): + sd1 = idd.SingleData([1, 2, 3, 4], index=["foo", "bar", "f", "g"]) + # automatically squeezing + self.assertTrue(not isinstance(np.nansum(sd1), idd.IndexData)) + self.assertTrue(not isinstance(np.sum(sd1), idd.IndexData)) + self.assertTrue(not isinstance(sd1.sum(), idd.IndexData)) + self.assertEqual(np.nansum(sd1), 10) + self.assertEqual(np.sum(sd1), 10) + self.assertEqual(sd1.sum(), 10) + self.assertEqual(np.nanmean(sd1), 2.5) + self.assertEqual(np.mean(sd1), 2.5) + self.assertEqual(sd1.mean(), 2.5) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_all_pipeline.py b/tests/test_all_pipeline.py index da68139a8..de15d8722 100644 --- a/tests/test_all_pipeline.py +++ b/tests/test_all_pipeline.py @@ -47,13 +47,13 @@ def train(uri_path: str = None): rid = recorder.id sr = SignalRecord(model, dataset, recorder) sr.generate() - pred_score = sr.load(sr.get_path("pred.pkl")) + pred_score = sr.load("pred.pkl") # calculate ic and ric sar = SigAnaRecord(recorder) sar.generate() - ic = sar.load(sar.get_path("ic.pkl")) - ric = sar.load(sar.get_path("ric.pkl")) + ic = sar.load("ic.pkl") + ric = sar.load("ric.pkl") return pred_score, {"ic": ic, "ric": ric}, rid @@ -78,13 +78,13 @@ def train_with_sigana(uri_path: str = None): sr = SignalRecord(model, dataset, recorder) sr.generate() - pred_score = sr.load(sr.get_path("pred.pkl")) + pred_score = sr.load("pred.pkl") # predict and calculate ic and ric sar = SigAnaRecord(recorder) sar.generate() - ic = sar.load(sar.get_path("ic.pkl")) - ric = sar.load(sar.get_path("ric.pkl")) + ic = sar.load("ic.pkl") + ric = sar.load("ric.pkl") uri_path = R.get_uri() return pred_score, {"ic": ic, "ric": ric}, uri_path @@ -144,10 +144,9 @@ def backtest_analysis(pred, rid, uri_path: str = None): }, "strategy": { "class": "TopkDropoutStrategy", - "module_path": "qlib.contrib.strategy.model_strategy", + "module_path": "qlib.contrib.strategy.signal_strategy", "kwargs": { - "model": model, - "dataset": dataset, + "signal": (model, dataset), "topk": 50, "n_drop": 5, }, @@ -170,7 +169,7 @@ def backtest_analysis(pred, rid, uri_path: str = None): # backtest par = PortAnaRecord(recorder, port_analysis_config, risk_analysis_freq="day") par.generate() - analysis_df = par.load(par.get_path("port_analysis_1day.pkl")) + analysis_df = par.load("port_analysis_1day.pkl") print(analysis_df) return analysis_df