diff --git a/README.md b/README.md
index 4f2509188..89d14e9eb 100644
--- a/README.md
+++ b/README.md
@@ -41,7 +41,7 @@ For more details, please refer to our paper ["Qlib: An AI-oriented Quantitative
# Framework of Qlib
-

+
@@ -192,24 +192,6 @@ The automatic workflow may not suite the research workflow of all Quant research
# [Quant Model Zoo](examples/benchmarks)
-## Run a single model
-`Qlib` provides three different ways to run a single model, users can pick the one that fits their cases best:
-- User can use the tool `qrun` mentioned above to run a model's workflow based from a config file.
-- User can create a `workflow_by_code` python script based on the [one](examples/workflow_by_code.py) listed in the `examples` folder.
-- User can use the script [`run_all_model.py`](examples/run_all_model.py) listed in the `examples` folder to run a model. Here is an example of the specific shell command to be used: `python run_all_model.py --models=lightgbm`. For more use cases, please refer to the file's [docstrings](examples/run_all_model.py).
-
-## Run multiple models
-`Qlib` also provides a script [`run_all_model.py`](examples/run_all_model.py) which can run multiple models for several iterations. (**Note**: the script only supprots *Linux* now. Other OS will be supported in the future.)
-
-The script will create a unique virtual environment for each model, and delete the environments after training. Thus, only experiment results such as `IC` and `backtest` results will be generated and stored.
-
-Here is an example of running all the models for 10 iterations:
-```python
-python run_all_model.py 10
-```
-
-It also provides the API to run specific models at once. For more use cases, please refer to the file's [docstrings](examples/run_all_model.py).
-
Here is a list of models built on `Qlib`.
- [GBDT based on LightGBM](qlib/contrib/model/gbdt.py)
- [GBDT based on Catboost](qlib/contrib/model/catboost_model.py)
@@ -219,13 +201,30 @@ Here is a list of models built on `Qlib`.
- [LSTM based on pytorcn](qlib/contrib/model/pytorch_lstm.py)
- [ALSTM based on pytorcn](qlib/contrib/model/pytorch_alstm.py)
- [GATs based on pytorch](qlib/contrib/model/pytorch_gats.py)
-- [TabNet based on pytorch](qlib/contrib/model/tabnet.py)
- [SFM based on pytorch](qlib/contrib/model/pytorch_sfm.py)
-- [HATs based on pytorch](qlib/contrib/model/pytorch_hats.py)
- [TFT based on tensorflow](examples/benchmarks/TFT/tft.py)
Your PR of new Quant models is highly welcomed.
+## Run a single model
+`Qlib` provides three different ways to run a single model, users can pick the one that fits their cases best:
+- User can use the tool `qrun` mentioned above to run a model's workflow based from a config file.
+- User can create a `workflow_by_code` python script based on the [one](examples/workflow_by_code.py) listed in the `examples` folder.
+- User can use the script [`run_all_model.py`](examples/run_all_model.py) listed in the `examples` folder to run a model. Here is an example of the specific shell command to be used: `python run_all_model.py --models=lightgbm`. For more use cases, please refer to the file's [docstrings](examples/run_all_model.py).
+
+## Run multiple models
+`Qlib` also provides a script [`run_all_model.py`](examples/run_all_model.py) which can run multiple models for several iterations. (**Note**: the script only supprots *Linux* now. Other OS will be supported in the future.)
+
+The script will create a unique virtual environment for each model, and delete the environments after training. Thus, only experiment results such as `IC` and `backtest` results will be generated and stored. (**Note**: the script will erase your previous experiment records created by running itself.)
+
+Here is an example of running all the models for 10 iterations:
+```python
+python run_all_model.py 10
+```
+
+It also provides the API to run specific models at once. For more use cases, please refer to the file's [docstrings](examples/run_all_model.py).
+
+
# Quant Dataset Zoo
Dataset plays a very important role in Quant. Here is a list of the datasets built on `Qlib`.
diff --git a/examples/benchmarks/ALSTM/README.md b/examples/benchmarks/ALSTM/README.md
index cd9dd3493..1b749bd80 100644
--- a/examples/benchmarks/ALSTM/README.md
+++ b/examples/benchmarks/ALSTM/README.md
@@ -2,9 +2,7 @@
- ALSTM contains a temporal attentive aggregation layer based on normal LSTM.
-- The code used in Qlib is a pyTorch implementation of Code: https://github.com/fulifeng/Adv-ALSTM
-
- Paper: A dual-stage attention-based recurrent neural network for time series prediction.
- https://www.ijcai.org/Proceedings/2017/0366.pdf
+ [https://www.ijcai.org/Proceedings/2017/0366.pdf](https://www.ijcai.org/Proceedings/2017/0366.pdf)
diff --git a/examples/benchmarks/GATs/workflow_config_gats.yaml b/examples/benchmarks/GATs/workflow_config_gats.yaml
index 33aa0fe8d..c38b4b312 100644
--- a/examples/benchmarks/GATs/workflow_config_gats.yaml
+++ b/examples/benchmarks/GATs/workflow_config_gats.yaml
@@ -8,6 +8,20 @@ data_handler_config: &data_handler_config
fit_start_time: 2008-01-01
fit_end_time: 2014-12-31
instruments: *market
+ infer_processors:
+ - class: RobustZScoreNorm
+ kwargs:
+ fields_group: feature
+ clip_outlier: true
+ - class: Fillna
+ kwargs:
+ fields_group: feature
+ learn_processors:
+ - class: DropnaLabel
+ - class: CSRankNorm
+ kwargs:
+ fields_group: label
+ label: ["Ref($close, -2) / Ref($close, -1) - 1"]
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
@@ -26,20 +40,19 @@ port_analysis_config: &port_analysis_config
min_cost: 5
task:
model:
- class: GAT
+ class: GATs
module_path: qlib.contrib.model.pytorch_gats
kwargs:
d_feat: 6
hidden_size: 64
num_layers: 2
- dropout: 0.0
+ dropout: 0.7
n_epochs: 200
- lr: 1e-3
+ lr: 1e-4
early_stop: 20
metric: loss
loss: mse
base_model: LSTM
- with_pretrain: True
seed: 0
GPU: 0
dataset:
@@ -47,7 +60,7 @@ task:
module_path: qlib.data.dataset
kwargs:
handler:
- class: ALPHA360_Denoise
+ class: ALPHA360
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:
@@ -58,11 +71,6 @@ task:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs: {}
- - class: SigAnaRecord
- module_path: qlib.workflow.record_temp
- kwargs:
- ana_long_short: False
- ann_scaler: 252
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
diff --git a/examples/benchmarks/HATS/README.md b/examples/benchmarks/HATS/README.md
deleted file mode 100644
index b70dbff25..000000000
--- a/examples/benchmarks/HATS/README.md
+++ /dev/null
@@ -1,15 +0,0 @@
-## Requirement
-
-* pandas==1.1.2
-* numpy==1.17.4
-* scikit_learn==0.23.2
-* torch==1.7.0
-
-## HATS
-
-* HATS is a a hierarchical attention network for stock prediction which uses relational data for stock market prediction. HATS selectively aggregates information
-on different relation types and adds the information to the representations of each company. HATS is used as a relational modeling module with initialized node representations.Furthermore, HATS
-can predict not only individual stock prices but also market index movements, which is similar to the graph classification task.
-
-* HATS uses pretrained model of GRU and LSTM. The code of GRU and LSTM used in Qlib is a pyTorch implemention of GRU and LSTM.
-* Paper address:HATS: A Hierarchical Graph Attention Network for Stock Movement Prediction https://arxiv.org/pdf/1908.07999.pdf
\ No newline at end of file
diff --git a/examples/benchmarks/HATS/requirements.txt b/examples/benchmarks/HATS/requirements.txt
deleted file mode 100644
index 16de0a438..000000000
--- a/examples/benchmarks/HATS/requirements.txt
+++ /dev/null
@@ -1,4 +0,0 @@
-pandas==1.1.2
-numpy==1.17.4
-scikit_learn==0.23.2
-torch==1.7.0
diff --git a/examples/benchmarks/HATS/worflow_config_hats.yaml b/examples/benchmarks/HATS/worflow_config_hats.yaml
deleted file mode 100644
index b08df14e0..000000000
--- a/examples/benchmarks/HATS/worflow_config_hats.yaml
+++ /dev/null
@@ -1,77 +0,0 @@
-provider_uri: "~/.qlib/qlib_data/cn_data"
-region: cn
-market: &market csi300
-benchmark: &benchmark SH000300
-data_handler_config: &data_handler_config
- start_time: 2008-01-01
- end_time: 2020-08-01
- fit_start_time: 2008-01-01
- fit_end_time: 2014-12-31
- instruments: *market
- infer_processors:
- - class: RobustZScoreNorm
- kwargs:
- fields_group: feature
- clip_outlier: true
- - class: Fillna
- kwargs:
- fields_group: feature
- learn_processors:
- - class: DropnaLabel
- - class: CSRankNorm
- kwargs:
- fields_group: label
- label: ["Ref($close, -2) / Ref($close, -1) - 1"]
-port_analysis_config: &port_analysis_config
- strategy:
- class: TopkDropoutStrategy
- module_path: qlib.contrib.strategy.strategy
- kwargs:
- topk: 50
- n_drop: 5
- backtest:
- verbose: False
- limit_threshold: 0.095
- account: 100000000
- benchmark: *benchmark
- deal_price: close
- open_cost: 0.0005
- close_cost: 0.0015
- min_cost: 5
-task:
- model:
- class: HATS
- module_path: qlib.contrib.model.pytorch_hats
- kwargs:
- d_feat: 6
- hidden_size: 64
- num_layers: 2
- dropout: 0.6
- n_epochs: 200
- lr: 1e-3
- early_stop: 20
- metric: loss
- loss: mse
- base_model: GRU
- seed: 0
- GPU: 0
- dataset:
- class: DatasetH
- module_path: qlib.data.dataset
- kwargs:
- handler:
- class: ALPHA360
- module_path: qlib.contrib.data.handler
- kwargs: *data_handler_config
- segments:
- train: [2008-01-01, 2014-12-31]
- valid: [2015-01-01, 2016-12-31]
- test: [2017-01-01, 2020-08-01]
- record:
- - class: SignalRecord
- module_path: qlib.workflow.record_temp
- kwargs: {}
- - class: PortAnaRecord
- module_path: qlib.workflow.record_temp
- kwargs:
- config: *port_analysis_config
\ No newline at end of file
diff --git a/examples/benchmarks/LSTM/model_lstm_csi300.pkl b/examples/benchmarks/LSTM/model_lstm_csi300.pkl
index ff7fee450..84d6419da 100644
Binary files a/examples/benchmarks/LSTM/model_lstm_csi300.pkl and b/examples/benchmarks/LSTM/model_lstm_csi300.pkl differ
diff --git a/examples/benchmarks/DNN/requirements.txt b/examples/benchmarks/MLP/requirements.txt
similarity index 100%
rename from examples/benchmarks/DNN/requirements.txt
rename to examples/benchmarks/MLP/requirements.txt
diff --git a/examples/benchmarks/DNN/workflow_config_dnn.yaml b/examples/benchmarks/MLP/workflow_config_mlp.yaml
similarity index 100%
rename from examples/benchmarks/DNN/workflow_config_dnn.yaml
rename to examples/benchmarks/MLP/workflow_config_mlp.yaml
diff --git a/examples/benchmarks/SFM/README.md b/examples/benchmarks/SFM/README.md
index 06ca50485..5f74c15d2 100644
--- a/examples/benchmarks/SFM/README.md
+++ b/examples/benchmarks/SFM/README.md
@@ -1,4 +1,3 @@
# State-Frequency-Memory
-- State Frequency Memory (SFM) is a novel recurrent network that uses Discrete Fourier Transform (DFT) to decompose the hidden states of memory cells and capture the multi-frequency trading patterns from past market data to make stock price predictions.
-- The code used in Qlib is a pyTorch implementation of SFM (Zhang, L., Aggarwal, C., & Qi, G. J. (2017,)).
-- Paper: Stock Price Prediction via Discovering Multi-Frequency Trading Patterns. https://www.cs.ucf.edu/~gqi/publications/kdd2017_stock.pdf.
\ No newline at end of file
+- State Frequency Memory (SFM) is a novel recurrent network that uses Discrete Fourier Transform to decompose the hidden states of memory cells and capture the multi-frequency trading patterns from past market data to make stock price predictions.
+- Paper: Stock Price Prediction via Discovering Multi-Frequency Trading Patterns. [https://www.cs.ucf.edu/~gqi/publications/kdd2017_stock.pdf.](https://www.cs.ucf.edu/~gqi/publications/kdd2017_stock.pdf.)
\ No newline at end of file
diff --git a/examples/benchmarks/TFT/tft.py b/examples/benchmarks/TFT/tft.py
index a3b4fc919..3387a5947 100644
--- a/examples/benchmarks/TFT/tft.py
+++ b/examples/benchmarks/TFT/tft.py
@@ -233,9 +233,8 @@ class TFTModel(ModelFT):
tf.keras.backend.set_session(default_keras_session)
predict = format_score(p90_forecast, "pred", 0) # self.label_shift
- label = format_score(targets, "label", 0)
# ===========================Predicting Process===========================
- return predict, label
+ return predict
def finetune(self, dataset: DatasetH):
"""
diff --git a/examples/benchmarks/TabNet/README.md b/examples/benchmarks/TabNet/README.md
deleted file mode 100644
index 3a233df46..000000000
--- a/examples/benchmarks/TabNet/README.md
+++ /dev/null
@@ -1,4 +0,0 @@
-# TabNet
-* TabNet is a novel high-performance and interpretable canonical deep tabular data learning architectur. TabNet uses sequential attention to choose which features to reason from at each decision step, enabling interpretability and more effcient learning as the learning capacity is used for the most salient features.
-* The code used in Qlib is a pyTorch implementation of Tabnet (Arik, S. O., & Pfister, T. (2019). [https://github.com/dreamquark-ai/tabnet](https://github.com/dreamquark-ai/tabnet)
-* Paper: TabNet: Attentive Interpretable Tabular Learning. [https://arxiv.org/pdf/1908.07442.pdf](https://arxiv.org/pdf/1908.07442.pdf).
\ No newline at end of file
diff --git a/examples/benchmarks/TabNet/requirements.txt b/examples/benchmarks/TabNet/requirements.txt
deleted file mode 100644
index 244b74b19..000000000
--- a/examples/benchmarks/TabNet/requirements.txt
+++ /dev/null
@@ -1,5 +0,0 @@
-pandas==1.1.2
-numpy==1.17.4
-scikit_learn==0.23.2
-torch==1.7.0
-pytorch-tabnet==2.0.1
\ No newline at end of file
diff --git a/examples/benchmarks/TabNet/workflow_config_tabnet.yaml b/examples/benchmarks/TabNet/workflow_config_tabnet.yaml
deleted file mode 100644
index 5f6aa8b6d..000000000
--- a/examples/benchmarks/TabNet/workflow_config_tabnet.yaml
+++ /dev/null
@@ -1,66 +0,0 @@
-provider_uri: "~/.qlib/qlib_data/cn_data"
-region: cn
-market: &market csi300
-benchmark: &benchmark SH000300
-data_handler_config: &data_handler_config
- start_time: 2008-01-01
- end_time: 2020-08-01
- fit_start_time: 2008-01-01
- fit_end_time: 2014-12-31
- instruments: *market
-port_analysis_config: &port_analysis_config
- strategy:
- class: TopkDropoutStrategy
- module_path: qlib.contrib.strategy.strategy
- kwargs:
- topk: 50
- n_drop: 5
- backtest:
- verbose: False
- limit_threshold: 0.095
- account: 100000000
- benchmark: *benchmark
- deal_price: close
- open_cost: 0.0005
- close_cost: 0.0015
- min_cost: 5
-task:
- model:
- class: TabNetModel
- module_path: qlib.contrib.model.tabnet
- kwargs:
- n_d: 8
- n_a: 8
- n_steps: 3
- gamma: 1.3
- n_independent: 2
- n_shared: 2
- seed: 0
- momentum: 0.02
- lambda_sparse: 1e-3
- optimizer_params: {lr: 2e-3}
- dataset:
- class: DatasetH
- module_path: qlib.data.dataset
- kwargs:
- handler:
- class: ALPHA360_Denoise
- module_path: qlib.contrib.data.handler
- kwargs: *data_handler_config
- segments:
- train: [2008-01-01, 2014-12-31]
- valid: [2015-01-01, 2016-12-31]
- test: [2017-01-01, 2020-08-01]
- record:
- - class: SignalRecord
- module_path: qlib.workflow.record_temp
- kwargs: {}
- - class: SigAnaRecord
- module_path: qlib.workflow.record_temp
- kwargs:
- ana_long_short: False
- ann_scaler: 252
- - class: PortAnaRecord
- module_path: qlib.workflow.record_temp
- kwargs:
- config: *port_analysis_config
\ No newline at end of file
diff --git a/examples/benchmarks/XGBoost/workflow_config_xgboost.yaml b/examples/benchmarks/XGBoost/workflow_config_xgboost.yaml
index 31eee8206..1352c496d 100644
--- a/examples/benchmarks/XGBoost/workflow_config_xgboost.yaml
+++ b/examples/benchmarks/XGBoost/workflow_config_xgboost.yaml
@@ -30,14 +30,12 @@ task:
module_path: qlib.contrib.model.xgboost
kwargs:
eval_metric: rmse
- colsample_bytree: 0.5
- eta: 0.2
- gamma: 0.55
- max_depth: 2
- min_child_weight: 1.0
+ colsample_bytree: 0.8879
+ eta: 0.0421
+ max_depth: 8
n_estimators: 647
- subsample: 0.8
- nthread: 4
+ subsample: 0.8789
+ nthread: 20
dataset:
class: DatasetH
module_path: qlib.data.dataset
@@ -62,4 +60,4 @@ task:
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
- config: *port_analysis_config
\ No newline at end of file
+ config: *port_analysis_config
diff --git a/examples/portfolio_optimization_example.ipynb b/examples/portfolio_optimization_example.ipynb
deleted file mode 100644
index 4d6c2b3d2..000000000
--- a/examples/portfolio_optimization_example.ipynb
+++ /dev/null
@@ -1,446 +0,0 @@
-{
- "metadata": {
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.7.9-final"
- },
- "orig_nbformat": 2,
- "kernelspec": {
- "name": "python3",
- "display_name": "Python 3"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2,
- "cells": [
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {},
- "outputs": [],
- "source": [
- "import sys\n",
- "import copy\n",
- "from pathlib import Path\n",
- "\n",
- "import qlib\n",
- "import numpy as np\n",
- "import pandas as pd\n",
- "from qlib.config import REG_CN\n",
- "from qlib.contrib.model.gbdt import LGBModel\n",
- "from qlib.contrib.data.handler import Alpha158\n",
- "from qlib.contrib.strategy.strategy import TopkDropoutStrategy\n",
- "from qlib.contrib.evaluate import (\n",
- " backtest as normal_backtest,\n",
- " risk_analysis,\n",
- ")\n",
- "from qlib.utils import exists_qlib_data, init_instance_by_config\n",
- "from qlib.workflow import R\n",
- "from qlib.workflow.record_temp import SignalRecord, PortAnaRecord\n",
- "from qlib.utils import flatten_dict"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [
- {
- "output_type": "stream",
- "name": "stderr",
- "text": [
- "[35366:MainThread](2020-11-27 10:31:09,528) INFO - qlib.Initialization - [__init__.py:41] - default_conf: client.\n",
- "[35366:MainThread](2020-11-27 10:31:09,531) WARNING - qlib.Initialization - [__init__.py:57] - redis connection failed(host=127.0.0.1 port=6379), cache will not be used!\n",
- "[35366:MainThread](2020-11-27 10:31:09,531) INFO - qlib.Initialization - [__init__.py:76] - qlib successfully initialized based on client settings.\n",
- "[35366:MainThread](2020-11-27 10:31:09,532) INFO - qlib.Initialization - [__init__.py:79] - data_path=/home/dongzho/.qlib/qlib_data/cn_data\n"
- ]
- }
- ],
- "source": [
- "# use default data\n",
- "# NOTE: need to download data from remote: python scripts/get_data.py qlib_data_cn --target_dir ~/.qlib/qlib_data/cn_data\n",
- "provider_uri = \"~/.qlib/qlib_data/cn_data\" # target_dir\n",
- "if not exists_qlib_data(provider_uri):\n",
- " print(f\"Qlib data is not found in {provider_uri}\")\n",
- " sys.path.append(str(Path.cwd().parent.joinpath(\"scripts\")))\n",
- " from get_data import GetData\n",
- " GetData().qlib_data(target_dir=provider_uri, region=REG_CN)\n",
- "qlib.init(provider_uri=provider_uri, region=REG_CN)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {},
- "outputs": [],
- "source": [
- "market = \"csi300\"\n",
- "benchmark = \"SH000300\""
- ]
- },
- {
- "source": [
- "## Model Training"
- ],
- "cell_type": "markdown",
- "metadata": {}
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {},
- "outputs": [
- {
- "output_type": "stream",
- "name": "stderr",
- "text": [
- "[35366:MainThread](2020-11-27 10:31:29,731) INFO - qlib.timer - [log.py:81] - Time cost: 20.103s | Loading data Done\n",
- "[35366:MainThread](2020-11-27 10:31:30,557) INFO - qlib.timer - [log.py:81] - Time cost: 0.241s | DropnaLabel Done\n",
- "[35366:MainThread](2020-11-27 10:31:38,518) INFO - qlib.timer - [log.py:81] - Time cost: 7.960s | CSZScoreNorm Done\n",
- "[35366:MainThread](2020-11-27 10:31:38,519) INFO - qlib.timer - [log.py:81] - Time cost: 8.786s | fit & process data Done\n",
- "[35366:MainThread](2020-11-27 10:31:38,520) INFO - qlib.timer - [log.py:81] - Time cost: 28.891s | Init data Done\n",
- "[35366:MainThread](2020-11-27 10:31:38,527) INFO - qlib.workflow - [exp.py:180] - Experiment 2 starts running ...\n",
- "[35366:MainThread](2020-11-27 10:31:38,651) INFO - qlib.workflow - [recorder.py:234] - Recorder c81375e3b5474feb9c77711babd158c3 starts running under Experiment 2 ...\n",
- "[35366:MainThread](2020-11-27 10:31:38,652) INFO - qlib.workflow - [expm.py:251] - No tracking URI is provided. The default tracking URI is set as `mlruns` under the working directory.\n",
- "Training until validation scores don't improve for 50 rounds\n",
- "[20]\ttrain's l2: 0.990559\tvalid's l2: 0.994332\n",
- "[40]\ttrain's l2: 0.98687\tvalid's l2: 0.993702\n",
- "[60]\ttrain's l2: 0.984308\tvalid's l2: 0.993503\n",
- "[80]\ttrain's l2: 0.982202\tvalid's l2: 0.993446\n",
- "[100]\ttrain's l2: 0.980318\tvalid's l2: 0.993423\n",
- "[120]\ttrain's l2: 0.97854\tvalid's l2: 0.993409\n",
- "[140]\ttrain's l2: 0.97679\tvalid's l2: 0.993413\n",
- "[160]\ttrain's l2: 0.975116\tvalid's l2: 0.993473\n",
- "Early stopping, best iteration is:\n",
- "[127]\ttrain's l2: 0.977957\tvalid's l2: 0.993381\n"
- ]
- }
- ],
- "source": [
- "###################################\n",
- "# train model\n",
- "###################################\n",
- "data_handler_config = {\n",
- " \"start_time\": \"2008-01-01\",\n",
- " \"end_time\": \"2020-08-01\",\n",
- " \"fit_start_time\": \"2008-01-01\",\n",
- " \"fit_end_time\": \"2014-12-31\",\n",
- " \"instruments\": market,\n",
- "}\n",
- "\n",
- "task = {\n",
- " \"model\": {\n",
- " \"class\": \"LGBModel\",\n",
- " \"module_path\": \"qlib.contrib.model.gbdt\",\n",
- " \"kwargs\": {\n",
- " \"loss\": \"mse\",\n",
- " \"colsample_bytree\": 0.8879,\n",
- " \"learning_rate\": 0.0421,\n",
- " \"subsample\": 0.8789,\n",
- " \"lambda_l1\": 205.6999,\n",
- " \"lambda_l2\": 580.9768,\n",
- " \"max_depth\": 8,\n",
- " \"num_leaves\": 210,\n",
- " \"num_threads\": 20,\n",
- " },\n",
- " },\n",
- " \"dataset\": {\n",
- " \"class\": \"DatasetH\",\n",
- " \"module_path\": \"qlib.data.dataset\",\n",
- " \"kwargs\": {\n",
- " \"handler\": {\n",
- " \"class\": \"Alpha158\",\n",
- " \"module_path\": \"qlib.contrib.data.handler\",\n",
- " \"kwargs\": data_handler_config,\n",
- " },\n",
- " \"segments\": {\n",
- " \"train\": (\"2008-01-01\", \"2014-12-31\"),\n",
- " \"valid\": (\"2015-01-01\", \"2016-12-31\"),\n",
- " \"test\": (\"2017-01-01\", \"2020-08-01\"),\n",
- " },\n",
- " },\n",
- " },\n",
- "}\n",
- "\n",
- "# model initiaiton\n",
- "model = init_instance_by_config(task[\"model\"])\n",
- "dataset = init_instance_by_config(task[\"dataset\"])\n",
- "\n",
- "# start exp to train model\n",
- "with R.start(experiment_name=\"train_model\"):\n",
- " R.log_params(**flatten_dict(task))\n",
- " model.fit(dataset)\n",
- " R.save_objects(trained_model=model)\n",
- " rid = R.get_recorder().id\n"
- ]
- },
- {
- "source": [
- "## Optimization Based Strategy"
- ],
- "cell_type": "markdown",
- "metadata": {}
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "metadata": {},
- "outputs": [],
- "source": [
- "from qlib.contrib.strategy.strategy import BaseStrategy\n",
- "\n",
- "\n",
- "class OptBasedStrategy(BaseStrategy):\n",
- " \"\"\"Optimization Based Strategy\"\"\"\n",
- "\n",
- " def __init__(self, data_handler, cov_estimator, optimizer):\n",
- " self.data_handler = data_handler\n",
- " self.cov_estimator = cov_estimator\n",
- " self.optimizer = optimizer\n",
- "\n",
- " def generate_order_list(self, score_series, current, trade_exchange, pred_date, trade_date):\n",
- " \"\"\"\n",
- " Parameters\n",
- " -----------\n",
- " score_series : pd.Seires\n",
- " stock_id , score.\n",
- " current : Position()\n",
- " current of account.\n",
- " trade_exchange : Exchange()\n",
- " exchange.\n",
- " trade_date : pd.Timestamp\n",
- " date.\n",
- " \"\"\"\n",
- " score_series = score_series.dropna()\n",
- "\n",
- " # check stock holdings, if\n",
- " # 1. doesn't have score: target amount = 0 (force sell)\n",
- " # 2. stock not tradable: target amount = current amount\n",
- " current_position = current.get_stock_amount_dict()\n",
- " target_position = {}\n",
- " for stock_id in current_position:\n",
- " if not trade_exchange.is_stock_tradable(stock_id=stock_id, trade_date=trade_date):\n",
- " target_position[stock_id] = current_position[stock_id]\n",
- " elif stock_id not in score_series.index:\n",
- " target_position[stock_id] = 0\n",
- " else:\n",
- " # need to be solved by optimizer\n",
- " pass\n",
- "\n",
- " # filter scores, if\n",
- " # 1. kept in `amount_dict` by previous rules\n",
- " # 2. not tradable\n",
- " skipped = []\n",
- " for stock_id in score_series.index:\n",
- " if stock_id in target_position:\n",
- " skipped.append(stock_id)\n",
- " elif not trade_exchange.is_stock_tradable(stock_id=stock_id, trade_date=trade_date):\n",
- " skipped.append(stock_id)\n",
- " score_series = score_series[~score_series.index.isin(skipped)]\n",
- "\n",
- " # calc remaining value\n",
- " current_value = pd.Series({\n",
- " stock_id: current.get_stock_price(stock_id) * amount\n",
- " for stock_id, amount in current_position.items()\n",
- " })\n",
- " risk_total_value = self.get_risk_degree(trade_date) * current.calculate_value()\n",
- " traded_value = risk_total_value - current_value.loc[list(target_position)].sum()\n",
- "\n",
- " # portfolio init weight\n",
- " init_weight = current_value.reindex(score_series.index, fill_value=0)\n",
- " init_weight_sum = init_weight.sum()\n",
- " if init_weight_sum > 0:\n",
- " init_weight /= init_weight_sum\n",
- "\n",
- " # covariance estimation\n",
- " selector = (self.data_handler.get_range_selector(pred_date, 252), score_series.index)\n",
- " price = self.data_handler.fetch(selector, level=None, squeeze=True)\n",
- " cov = self.cov_estimator(price)\n",
- " cov = cov.reindex(\n",
- " index=score_series.index, \n",
- " columns=score_series.index, \n",
- " #fill_value=cov.max().max()\n",
- " )\n",
- "\n",
- " # optimize target portfolio\n",
- " if init_weight.sum() > 0:\n",
- " target_weight = self.optimizer(cov, score_series, init_weight)\n",
- " else:\n",
- " target_weight = self.optimizer(cov, score_series)\n",
- " target_weight = target_weight[target_weight > 1e-6]\n",
- " for stock_id, weight in target_weight.items():\n",
- " try:\n",
- " target_position[stock_id] = int(traded_value * weight / trade_exchange.get_close(stock_id, pred_date))\n",
- " except Exception as e:\n",
- " # TODO: unknown exception\n",
- " print('Exception:', e)\n",
- "\n",
- " # for debug\n",
- " print('trade date:', trade_date)\n",
- " print('target weight:', target_weight.to_dict())\n",
- " print('target position:', target_position)\n",
- "\n",
- " # generate order list\n",
- " order_list = trade_exchange.generate_order_for_target_amount_position(\n",
- " target_position=target_position,\n",
- " current_position=current_position,\n",
- " trade_date=trade_date,\n",
- " )\n",
- "\n",
- " return order_list"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "metadata": {},
- "outputs": [],
- "source": [
- "from qlib.data.dataset.loader import QlibDataLoader\n",
- "from qlib.data.dataset.handler import DataHandler\n",
- "from qlib.model.riskmodel import ShrinkCovEstimator\n",
- "from qlib.portfolio.optimizer import PortfolioOptimizer"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "metadata": {},
- "outputs": [
- {
- "output_type": "stream",
- "name": "stderr",
- "text": [
- "[35366:MainThread](2020-11-27 10:31:56,951) INFO - qlib.timer - [log.py:81] - Time cost: 6.763s | Loading data Done\n",
- "[35366:MainThread](2020-11-27 10:31:56,953) INFO - qlib.timer - [log.py:81] - Time cost: 6.766s | Init data Done\n"
- ]
- }
- ],
- "source": [
- "data_loader = QlibDataLoader([\"$close\"])\n",
- "data_handler = DataHandler(\"all\", \"2015-01-01\", \"2020-08-01\", data_loader)\n",
- "cov_estimator = ShrinkCovEstimator(nan_option=\"mask\")\n",
- "optimizer = PortfolioOptimizer(\"mvo\", lamb=2, delta=0.2, tol=1e-5)\n",
- "strategy = OptBasedStrategy(data_handler, cov_estimator, optimizer)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 49,
- "metadata": {
- "tags": []
- },
- "outputs": [
- {
- "output_type": "stream",
- "name": "stderr",
- "text": [
- "1': 0.08936553334387595, 'SH601800': 0.011014844457113308, 'SH601939': 0.013378001170219945, 'SH603993': 0.013820193926861863, 'SZ000338': 0.002455991798001457, 'SZ000423': 0.004893338273543826, 'SZ000538': 0.010686211189620477, 'SZ002065': 0.09095125419435357, 'SZ002074': 0.010299013738522475, 'SZ002085': 0.19844965949420615, 'SZ002236': 0.09210003831704765, 'SZ002310': 0.05664352912360013, 'SZ300017': 0.0197442255539771}\n",
- "target position: {'SZ002299': 6184584.0980107365, 'SH600000': 272224, 'SH600009': 604839, 'SH600018': 3097398, 'SH600028': 335726, 'SH600196': 23243, 'SH600276': 71634, 'SH600519': 17354, 'SH600585': 269686, 'SH600900': 2501521, 'SH601111': 2400659, 'SH601800': 334062, 'SH601939': 1283164, 'SH603993': 742901, 'SZ000338': 95285, 'SZ000423': 21697, 'SZ000538': 14518, 'SZ002065': 498253, 'SZ002074': 111674, 'SZ002085': 591507, 'SZ002236': 394197, 'SZ002310': 2202674, 'SZ300017': 206128}\n",
- "target weight: {'SH600000': 0.02310668460556249, 'SH600009': 0.06170206213753432, 'SH600018': 0.027608180837257277, 'SH600028': 0.00971532319525714, 'SH600196': 0.0036133308423111116, 'SH600276': 0.093195014492093, 'SH600519': 0.013476706174774766, 'SH600585': 0.036024919027310476, 'SH600660': 0.04512159672692613, 'SH600900': 0.12506534473579556, 'SH601939': 0.013494851810297546, 'SH603993': 0.07619418669734077, 'SZ000338': 0.0024673392047414363, 'SZ000423': 0.00485981529404862, 'SZ000538': 0.010602880875660015, 'SZ002065': 0.09064325205359221, 'SZ002074': 0.0011889996597580427, 'SZ002085': 0.1982091371262038, 'SZ002236': 0.09254320484936242, 'SZ002310': 0.05152917909181458, 'SZ002466': 0.00014732765084648903, 'SZ300017': 0.019490662910321074}\n",
- "target position: {'SZ002299': 6184584.0980107365, 'SH600000': 272079, 'SH600009': 604359, 'SH600018': 3095205, 'SH600028': 335471, 'SH600196': 23407, 'SH600276': 71567, 'SH600519': 17345, 'SH600585': 269447, 'SH600660': 129265, 'SH600900': 2499305, 'SH601939': 1282317, 'SH603993': 4058172, 'SZ000338': 95223, 'SZ000423': 21703, 'SZ000538': 14509, 'SZ002065': 497821, 'SZ002074': 12787, 'SZ002085': 590955, 'SZ002236': 393895, 'SZ002310': 2190685, 'SZ002466': 4483, 'SZ300017': 205994}\n",
- "target weight: {'SH600000': 0.0014042138463464568, 'SH600009': 0.11511740651805806, 'SH600018': 0.026968513725965638, 'SH600028': 0.009566603496832042, 'SH600150': 0.016339328084607228, 'SH600276': 0.09374974543357856, 'SH600489': 0.021876512936684123, 'SH600585': 0.035840818294258524, 'SH600900': 0.12414161958870683, 'SH601888': 0.005682635273269834, 'SH601939': 0.013289788356428228, 'SH603993': 0.07491407610535435, 'SZ000338': 0.002426716760042838, 'SZ000423': 0.00492071038737461, 'SZ000503': 0.005617017904986693, 'SZ000538': 0.010859006699485451, 'SZ002065': 0.08924691553942904, 'SZ002085': 0.19757848255238786, 'SZ002236': 0.09381012783787722, 'SZ002310': 0.03737359938389514, 'SZ300017': 0.01927616131502695}\n",
- "target position: {'SZ002299': 6184584.0980107365, 'SH600000': 16809, 'SH600009': 1075516, 'SH600018': 3091248, 'SH600028': 335128, 'SH600150': 114804, 'SH600276': 71473, 'SH600489': 66586, 'SH600585': 268644, 'SH600900': 2496175, 'SH601888': 173824, 'SH601939': 1281108, 'SH603993': 4052802, 'SZ000338': 95107, 'SZ000423': 21684, 'SZ000503': 80461, 'SZ000538': 14507, 'SZ002065': 497197, 'SZ002085': 590211, 'SZ002236': 393412, 'SZ002310': 1573728, 'SZ300017': 205818}\n",
- "target weight: {'SH600000': 0.0013962189421662084, 'SH600009': 0.09330267135244051, 'SH600018': 0.026443154116291615, 'SH600028': 0.009581412428525829, 'SH600150': 0.016443917649559808, 'SH600276': 0.09378402212481758, 'SH600703': 0.0005233118350013756, 'SH600741': 0.10117549074044105, 'SH600900': 0.12435147566444608, 'SH601888': 0.00560250787284307, 'SH601939': 0.013238798853730008, 'SH603993': 0.07455231781733267, 'SZ000423': 0.0048695925705555185, 'SZ000503': 0.006070996956328167, 'SZ000538': 0.010870567565742796, 'SZ002065': 0.08722983720892508, 'SZ002074': 0.00037126948590009574, 'SZ002085': 0.19840484837030906, 'SZ002236': 0.09365186287123867, 'SZ002310': 0.03806080531862309, 'SZ300017': 7.492025186876957e-05}\n",
- "target position: {'SZ002299': 6184584.0980107365, 'SH600000': 16889, 'SH600009': 867443, 'SH600018': 3086467, 'SH600028': 334573, 'SH600150': 114383, 'SH600276': 71360, 'SH600703': 1760, 'SH600741': 665366, 'SH600900': 2491839, 'SH601888': 173465, 'SH601939': 1278590, 'SH603993': 4045939, 'SZ000423': 21674, 'SZ000503': 80212, 'SZ000538': 14499, 'SZ002065': 496361, 'SZ002074': 4086, 'SZ002085': 589224, 'SZ002236': 392766, 'SZ002310': 1571463, 'SZ300017': 805}\n",
- "target weight: {'SH600000': 0.0014143911110003147, 'SH600018': 0.026834186435965166, 'SH600028': 0.00961324990522086, 'SH600150': 0.015905361405158292, 'SH600276': 0.09486308638260738, 'SH600685': 1.0253334545374858e-06, 'SH600703': 0.0005108576602907958, 'SH600741': 0.10252334336233063, 'SH600900': 0.1250632059809011, 'SH601888': 0.005830869532670813, 'SH601939': 0.01336945356138906, 'SH603993': 0.07101851124599835, 'SZ000423': 0.004899981502195361, 'SZ000503': 0.006113894785564276, 'SZ000538': 0.011081925761176491, 'SZ000709': 1.06442568357325e-06, 'SZ002065': 0.08812103684766726, 'SZ002074': 0.0003564773234700175, 'SZ002085': 0.19097427428977284, 'SZ002236': 0.09299395368630246, 'SZ002310': 0.03841630892378685, 'SZ002475': 0.10001934454071283, 'SZ300017': 7.322667303400442e-05}\n",
- "target position: {'SZ002299': 6184584.0980107365, 'SH600000': 16886, 'SH600018': 3080789, 'SH600028': 334087, 'SH600150': 114360, 'SH600276': 71234, 'SH600685': 10, 'SH600703': 1709, 'SH600741': 663932, 'SH600900': 2486951, 'SH601888': 173417, 'SH601939': 1276335, 'SH603993': 3740672, 'SZ000423': 21667, 'SZ000503': 80191, 'SZ000538': 14495, 'SZ000709': 11, 'SZ002065': 495371, 'SZ002074': 3867, 'SZ002085': 588051, 'SZ002236': 392002, 'SZ002310': 1568834, 'SZ002475': 1264636, 'SZ300017': 809}\n",
- "target weight: {'SH600000': 0.0013872765178790307, 'SH600018': 0.026321999857337998, 'SH600028': 0.009491029058787367, 'SH600150': 0.015749871987744815, 'SH600276': 0.09581999547114961, 'SH600703': 0.000518490273176083, 'SH600741': 0.1037547619508012, 'SH600900': 0.12396253436063161, 'SH601258': 0.02298494942988327, 'SH601888': 0.005915886046387033, 'SH601939': 0.013177336599075601, 'SH603993': 0.06888468621566025, 'SZ000423': 0.005102036718661418, 'SZ000503': 0.00602692511970311, 'SZ000538': 0.011127923667697532, 'SZ000709': 0.07688609680386178, 'SZ002065': 0.08693397271897534, 'SZ002074': 0.000347445594871718, 'SZ002085': 0.1905176824564206, 'SZ002236': 0.035835596544641496, 'SZ002475': 0.09918059167278087, 'SZ300017': 7.291118905149903e-05}\n",
- "target position: {'SZ002299': 6184584.0980107365, 'SH600000': 16948, 'SH600018': 3086676, 'SH600028': 334750, 'SH600150': 114560, 'SH600276': 71372, 'SH600703': 1715, 'SH600741': 665129, 'SH600900': 2491433, 'SH601258': 4190669, 'SH601888': 174070, 'SH601939': 1278836, 'SH603993': 3747283, 'SZ000423': 21744, 'SZ000503': 80490, 'SZ000538': 14538, 'SZ000709': 871429, 'SZ002065': 496245, 'SZ002074': 3887, 'SZ002085': 589120, 'SZ002236': 145147, 'SZ002475': 1268582, 'SZ300017': 814}\n",
- "target weight: {'SH600000': 0.001373124016867567, 'SH600018': 0.02646941123076474, 'SH600028': 0.009458335378810856, 'SH600150': 0.015442533996257352, 'SH600276': 0.09620341387657301, 'SH600649': 0.012613476480118908, 'SH600703': 0.0005280976985716832, 'SH600741': 0.06577156829314017, 'SH600900': 0.12455488881029539, 'SH601258': 0.02270943336842379, 'SH601939': 0.013066707696697587, 'SH603993': 0.0649427819283919, 'SZ000423': 0.0051167756388828005, 'SZ000503': 0.006076486564538039, 'SZ000709': 0.0770418453012855, 'SZ000778': 0.08738918304165759, 'SZ002065': 0.08804613990036694, 'SZ002074': 0.00034315924263262563, 'SZ002085': 0.18241434394629127, 'SZ002475': 0.10035998625624482, 'SZ300017': 7.809604376099223e-05}\n",
- "target position: {'SZ002299': 6184584.0980107365, 'SH600000': 16935, 'SH600018': 3089469, 'SH600028': 334906, 'SH600150': 114496, 'SH600276': 71430, 'SH600649': 337388, 'SH600703': 1714, 'SH600741': 419916, 'SH600900': 2493978, 'SH601258': 4194599, 'SH601939': 1279661, 'SH603993': 3750968, 'SZ000423': 21734, 'SZ000503': 80440, 'SZ000709': 872293, 'SZ000778': 366855, 'SZ002065': 496756, 'SZ002074': 3880, 'SZ002085': 564610, 'SZ002475': 1269872, 'SZ300017': 812}\n",
- "target weight: {'SH600000': 0.0013497287789570015, 'SH600018': 0.02647482761554837, 'SH600028': 0.00941080088689994, 'SH600150': 0.01556139303593115, 'SH600276': 0.09732218714743374, 'SH600649': 0.012606184789019243, 'SH600703': 0.0005334649726542859, 'SH600900': 0.12593267687041163, 'SH601258': 0.021199485570796834, 'SH601939': 0.013025993149697816, 'SH603993': 0.06446918682668012, 'SZ000423': 0.005311875734339093, 'SZ000503': 0.006125989728635501, 'SZ000709': 0.0707610058353687, 'SZ000778': 0.14004715956352495, 'SZ002065': 0.08746446321200681, 'SZ002074': 0.00033710686535540885, 'SZ002085': 0.15238971653801253, 'SZ002146': 0.042585776887618575, 'SZ002475': 0.10701429615740456, 'SZ300017': 7.667981013711115e-05}\n",
- "target position: {'SZ002299': 6184584.0980107365, 'SH600000': 17031, 'SH600018': 3109084, 'SH600028': 336978, 'SH600150': 115126, 'SH600276': 71888, 'SH600649': 339316, 'SH600703': 1724, 'SH600900': 2510148, 'SH601258': 4237748, 'SH601939': 1287810, 'SH603993': 3775382, 'SZ000423': 21853, 'SZ000503': 80885, 'SZ000709': 878077, 'SZ000778': 625157, 'SZ002065': 499988, 'SZ002074': 3901, 'SZ002085': 469624, 'SZ002146': 2000993, 'SZ002475': 1278084, 'SZ300017': 814}\n",
- "target weight: {'SH600000': 0.0013594926998639766, 'SH600009': 0.021101252574639438, 'SH600028': 0.009528554544265834, 'SH600150': 0.015013601602404225, 'SH600276': 0.09860402207319302, 'SH600649': 0.01292550325031454, 'SH600685': 0.00703471182662378, 'SH600703': 0.0005218767517596246, 'SH600900': 0.12786995199482584, 'SH601258': 0.04401496515184404, 'SH601398': 0.025932829520167643, 'SH601939': 0.0134408200189716, 'SH603993': 0.06319752369639879, 'SZ000423': 0.005221187626834546, 'SZ000503': 0.006085670359590286, 'SZ000568': 0.003081214755480397, 'SZ000709': 0.07061122716452324, 'SZ000778': 0.1379488795662632, 'SZ000839': 0.019142903464547063, 'SZ002065': 0.04714685528331623, 'SZ002074': 0.00033291622875151913, 'SZ002085': 0.11947661465752588, 'SZ002146': 0.043205942689553425, 'SZ002310': 0.0009243182551654129, 'SZ002475': 0.106199974013018, 'SZ300017': 7.709323254732814e-05}\n",
- "target position: {'SZ002299': 6184584.0980107365, 'SH600000': 16933, 'SH600009': 196068, 'SH600028': 337025, 'SH600150': 115100, 'SH600276': 71926, 'SH600649': 339354, 'SH600685': 75328, 'SH600703': 1713, 'SH600900': 2511928, 'SH601258': 8791935, 'SH601398': 1146896, 'SH601939': 1288215, 'SH603993': 3777819, 'SZ000423': 21728, 'SZ000503': 80869, 'SZ000568': 10375, 'SZ000709': 878683, 'SZ000778': 625604, 'SZ000839': 312116, 'SZ002065': 268413, 'SZ002074': 3860, 'SZ002085': 369761, 'SZ002146': 2002072, 'SZ002310': 40341, 'SZ002475': 1278918, 'SZ300017': 811}\n",
- "target weight: {'SH600000': 0.0013764694393366029, 'SH600009': 0.021541655860797534, 'SH600028': 0.009752609535237182, 'SH600276': 0.06514222178877259, 'SH600649': 0.01273168785031133, 'SH600685': 0.006989932070614982, 'SH600900': 0.12998548252109676, 'SH601258': 0.13157540821422453, 'SH601398': 0.02641881439805636, 'SH601939': 0.0136141957873422, 'SH603993': 0.0602411123337629, 'SZ000503': 0.006084251045333903, 'SZ000709': 0.06977363144499521, 'SZ000778': 0.1385461140272643, 'SZ000839': 0.018579865431307987, 'SZ002065': 0.046270476942690986, 'SZ002074': 0.00025974854597178115, 'SZ002085': 0.10060756172850334, 'SZ002146': 0.043204792194791966, 'SZ002310': 0.0009022784286642987, 'SZ002466': 0.011748866835406593, 'SZ002475': 0.08457581284822364, 'SZ300017': 7.701070501151889e-05}\n",
- "target position: {'SZ002299': 6184584.0980107365, 'SH600000': 16938, 'SH600009': 196239, 'SH600028': 337355, 'SH600276': 46535, 'SH600649': 339479, 'SH600685': 75274, 'SH600900': 2514488, 'SH601258': 26730440, 'SH601398': 1148157, 'SH601939': 1289259, 'SH603993': 3781937, 'SZ000503': 80900, 'SZ000709': 879645, 'SZ000778': 626285, 'SZ000839': 312384, 'SZ002065': 268717, 'SZ002074': 3093, 'SZ002085': 309206, 'SZ002146': 2003901, 'SZ002310': 39782, 'SZ002466': 367691, 'SZ002475': 1026389, 'SZ300017': 812}\n",
- "target weight: {'SH600000': 0.0013689894888766726, 'SH600009': 0.021087495457198752, 'SH600028': 0.009589419355091226, 'SH600276': 0.0644304399184473, 'SH600535': 0.016420787426513667, 'SH600649': 0.0267771761277641, 'SH600900': 0.12784455237901315, 'SH601169': 0.004374459372110214, 'SH601258': 0.13288651981531077, 'SH601398': 0.02615927477879055, 'SH601939': 0.013573361058977978, 'SH603993': 1.157895161672162e-06, 'SZ000503': 0.009069218941980683, 'SZ000709': 0.07014466816191627, 'SZ000778': 0.13956352821962528, 'SZ002065': 0.045206445945654664, 'SZ002085': 0.08649963592018277, 'SZ002146': 0.04234588186007612, 'SZ002310': 0.0008924777422846245, 'SZ002466': 0.07334842360184116, 'SZ002475': 0.08834296814868704, 'SZ300017': 7.311841306821287e-05}\n",
- "Exception: ('SH601169', Timestamp('2017-04-25 00:00:00'))\n",
- "target position: {'SZ002299': 6184584.0980107365, 'SH600000': 16929, 'SH600009': 196092, 'SH600028': 337333, 'SH600276': 46571, 'SH600535': 57649, 'SH600649': 731641, 'SH600900': 2515321, 'SH601258': 26740467, 'SH601398': 1148635, 'SH601939': 1289434, 'SH603993': 72, 'SZ000503': 122157, 'SZ000709': 879908, 'SZ000778': 626506, 'SZ002065': 268767, 'SZ002085': 267906, 'SZ002146': 2004576, 'SZ002310': 39745, 'SZ002466': 2332750, 'SZ002475': 1026858, 'SZ300017': 806}\n",
- "target weight: {'SH600000': 0.0013439859873209908, 'SH600009': 0.02075652616964347, 'SH600028': 0.00939963933310415, 'SH600276': 0.06236017906066887, 'SH600535': 0.016369568294734148, 'SH600649': 0.025541724367766302, 'SH600900': 0.12768966131041845, 'SH601258': 0.1370446945486361, 'SH601398': 0.02601619218529119, 'SH601939': 0.013440958024818669, 'SH603993': 4.144559709761373e-06, 'SZ000503': 0.0084237188568659, 'SZ000568': 0.020576387679160105, 'SZ000709': 0.056783757531829446, 'SZ000778': 0.06920027928808208, 'SZ002008': 0.07943378393922318, 'SZ002065': 0.045339177613740886, 'SZ002085': 0.08505902525865962, 'SZ002146': 0.031624633954490035, 'SZ002310': 0.0008996156348854183, 'SZ002466': 0.0764983539831682, 'SZ002475': 0.086193992434369}\n",
- "target position: {'SZ002299': 6184584.0980107365, 'SZ300017': 812.4573136217659, 'SH600000': 16923, 'SH600009': 196076, 'SH600028': 337279, 'SH600276': 46567, 'SH600535': 57624, 'SH600649': 731549, 'SH600900': 2515891, 'SH601258': 26747448, 'SH601398': 1148886, 'SH601939': 1289307, 'SH603993': 263, 'SZ000503': 122158, 'SZ000568': 69471, 'SZ000709': 700781, 'SZ000778': 302643, 'SZ002008': 746285, 'SZ002065': 268804, 'SZ002085': 267988, 'SZ002146': 1473970, 'SZ002310': 39739, 'SZ002466': 2333288, 'SZ002475': 1027134}\n",
- "target weight: {'SH600000': 0.0014508867295425067, 'SH600009': 0.022137935734971876, 'SH600028': 0.01003980705499816, 'SH600276': 0.065554410760754, 'SH600535': 0.017337663954140436, 'SH600649': 0.026752732524884384, 'SH600900': 0.13610376526017787, 'SH601258': 0.14230666244775886, 'SH601398': 0.027847743092481312, 'SH601939': 0.014306563408357105, 'SH603993': 2.7770868647848817e-06, 'SZ000069': 0.10104502775773525, 'SZ000503': 0.009049444347506782, 'SZ000568': 0.005686495401232644, 'SZ000778': 0.0715782861850023, 'SZ002008': 0.08609584908472251, 'SZ002065': 0.04706561122827146, 'SZ002085': 0.09099179117275048, 'SZ002146': 0.03204301334262787, 'SZ002475': 0.09241758644387384, 'SZ300017': 0.00018594702102337797}\n",
- "target position: {'SZ000709': 700825.0269758024, 'SZ002299': 6184584.0980107365, 'SH600000': 16845, 'SH600009': 195098, 'SH600028': 335689, 'SH600276': 46340, 'SH600535': 57343, 'SH600649': 728078, 'SH600900': 2504242, 'SH601258': 26624542, 'SH601398': 1143577, 'SH601939': 1283067, 'SH603993': 160, 'SZ000069': 367637, 'SZ000503': 121565, 'SZ000568': 17626, 'SZ000778': 301250, 'SZ002008': 742790, 'SZ002065': 267559, 'SZ002085': 266737, 'SZ002146': 1467579, 'SZ002475': 1022346, 'SZ300017': 1776}\n",
- "target weight: {'SH600000': 0.0013484985106016394, 'SH600009': 0.020750773768622693, 'SH600028': 0.009285673867962157, 'SH600104': 2.9067007814076732e-05, 'SH600196': 0.10012804077099052, 'SH600276': 0.05943563439541343, 'SH600535': 0.015902136087846228, 'SH600649': 0.025189836387314323, 'SH600900': 0.12584805827140388, 'SH601111': 6.857382365314848e-06, 'SH601258': 0.03895938466363849, 'SH601398': 0.025753888553878806, 'SH601939': 0.013275755331575599, 'SH603993': 4.249178615404585e-06, 'SZ000069': 0.09445579375504781, 'SZ000503': 0.008532747266799033, 'SZ000568': 0.0052599046052527266, 'SZ000709': 0.06003418476540357, 'SZ000778': 0.06923031488245988, 'SZ002008': 0.07903025205993618, 'SZ002065': 0.04448484691775433, 'SZ002085': 0.08426354045447453, 'SZ002146': 0.031142767130486235, 'SZ002475': 0.08747938111190227, 'SZ300017': 0.00016841662419817417}\n",
- "target position: {'SZ002299': 6184584.0980107365, 'SH600000': 16906, 'SH600009': 195107, 'SH600028': 335257, 'SH600104': 197, 'SH600196': 630404, 'SH600276': 46282, 'SH600535': 57311, 'SH600649': 727170, 'SH600900': 2500379, 'SH601111': 203, 'SH601258': 7443096, 'SH601398': 1142014, 'SH601939': 1281361, 'SH603993': 263, 'SZ000069': 366998, 'SZ000503': 121479, 'SZ000568': 17699, 'SZ000709': 699639, 'SZ000778': 300752, 'SZ002008': 741767, 'SZ002065': 267133, 'SZ002085': 266334, 'SZ002146': 1465489, 'SZ002475': 1020693, 'SZ300017': 1756}\n",
- "target weight: {'SH600000': 0.0012976336004362882, 'SH600009': 0.0204756895024156, 'SH600028': 0.008883617000656601, 'SH600104': 2.592943319382378e-05, 'SH600196': 0.09617041827497698, 'SH600276': 0.05681162545715886, 'SH600535': 0.015294256733040745, 'SH600649': 0.02417676167926707, 'SH600900': 0.12233373885315162, 'SH601398': 0.024531954099214746, 'SH601628': 0.005044154324745466, 'SH601888': 0.09500034426651846, 'SH601939': 0.012657033879067425, 'SH603993': 4.079522960136806e-06, 'SZ000069': 0.09054142453059062, 'SZ000503': 0.008036587259744734, 'SZ000568': 0.0049533657881637655, 'SZ000778': 0.06904486736535222, 'SZ002008': 0.06688985213943154, 'SZ002065': 0.04278977877238287, 'SZ002085': 0.0820368284038888, 'SZ002299': 0.06899317887598991, 'SZ002475': 0.08384652594205952, 'SZ300017': 0.00016035416530955983}\n",
- "target position: {'SH601258': 7443495.190430395, 'SH600000': 16952, 'SH600009': 195676, 'SH600028': 336044, 'SH600104': 183, 'SH600196': 631454, 'SH600276': 46372, 'SH600535': 57498, 'SH600649': 728582, 'SH600900': 2504660, 'SH601398': 1143938, 'SH601628': 695470, 'SH601888': 2951253, 'SH601939': 1283887, 'SH603993': 255, 'SZ000069': 367641, 'SZ000503': 121875, 'SZ000568': 17775, 'SZ000778': 301255, 'SZ002008': 638620, 'SZ002065': 267645, 'SZ002085': 266802, 'SZ002299': 6194843, 'SZ002475': 1022527, 'SZ300017': 1765}\n",
- "target weight: {'SH600000': 0.0013469483722729403, 'SH600028': 0.009286467498269333, 'SH600104': 2.368500734977497e-05, 'SH600196': 0.10145424564201923, 'SH600276': 0.06002237364700993, 'SH600535': 0.01588332650422844, 'SH600649': 0.025440421851940002, 'SH600900': 0.1279028471227695, 'SH601258': 0.035917606048396986, 'SH601398': 0.02559318344055778, 'SH601628': 0.005221942888216608, 'SH601888': 0.14928498761757883, 'SH601939': 0.013161430940131148, 'SH603993': 4.350147095904942e-06, 'SZ000069': 0.14038473724819095, 'SZ000503': 0.008556251357999256, 'SZ000568': 0.005243511514392524, 'SZ002008': 0.06824325050397591, 'SZ002065': 0.04420632869308568, 'SZ002085': 0.074424247013131, 'SZ002299': 0.0010812901181988855, 'SZ002475': 0.0871460668952185, 'SZ300017': 0.00017049992832446128}\n",
- "target position: {'SZ000778': 301254.84776855103, 'SH600000': 16873, 'SH600028': 335064, 'SH600104': 156, 'SH600196': 629613, 'SH600276': 46235, 'SH600535': 57245, 'SH600649': 726346, 'SH600900': 2497776, 'SH601258': 7423462, 'SH601398': 1140689, 'SH601628': 692346, 'SH601888': 4557826, 'SH601939': 1279908, 'SH603993': 261, 'SZ000069': 551887, 'SZ000503': 121344, 'SZ000568': 17697, 'SZ002008': 636943, 'SZ002065': 266904, 'SZ002085': 231781, 'SZ002299': 97527, 'SZ002475': 1019747, 'SZ300017': 1749}\n"
- ]
- },
- {
- "output_type": "error",
- "ename": "KeyboardInterrupt",
- "evalue": "",
- "traceback": [
- "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
- "\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
- "\u001b[1;32m\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[0;32m 30\u001b[0m \u001b[1;31m# backtest & analysis\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 31\u001b[0m \u001b[0mpar\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mPortAnaRecord\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mrecorder\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mport_analysis_config\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 32\u001b[1;33m \u001b[0mpar\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mgenerate\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m",
- "\u001b[1;32md:\\qlib\\qlib\\workflow\\record_temp.py\u001b[0m in \u001b[0;36mgenerate\u001b[1;34m(self, **kwargs)\u001b[0m\n\u001b[0;32m 230\u001b[0m \u001b[1;31m# custom strategy and get backtest\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 231\u001b[0m \u001b[0mpred_score\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0msuper\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mload\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 232\u001b[1;33m \u001b[0mreport_normal\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mpositions_normal\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mnormal_backtest\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mpred_score\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mstrategy\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mstrategy\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbacktest_config\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 233\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrecorder\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msave_objects\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m**\u001b[0m\u001b[1;33m{\u001b[0m\u001b[1;34m\"report_normal.pkl\"\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mreport_normal\u001b[0m\u001b[1;33m}\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0martifact_path\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mPortAnaRecord\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mget_path\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 234\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrecorder\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msave_objects\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m**\u001b[0m\u001b[1;33m{\u001b[0m\u001b[1;34m\"positions_normal.pkl\"\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mpositions_normal\u001b[0m\u001b[1;33m}\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0martifact_path\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mPortAnaRecord\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mget_path\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
- "\u001b[1;32md:\\qlib\\qlib\\contrib\\evaluate.py\u001b[0m in \u001b[0;36mbacktest\u001b[1;34m(pred, account, shift, benchmark, verbose, **kwargs)\u001b[0m\n\u001b[0;32m 269\u001b[0m \u001b[0mverbose\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mverbose\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 270\u001b[0m \u001b[0maccount\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0maccount\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 271\u001b[1;33m \u001b[0mbenchmark\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mbenchmark\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 272\u001b[0m )\n\u001b[0;32m 273\u001b[0m \u001b[1;31m# for compatibility of the old API. return the dict positions\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
- "\u001b[1;32md:\\qlib\\qlib\\contrib\\backtest\\backtest.py\u001b[0m in \u001b[0;36mbacktest\u001b[1;34m(pred, strategy, trade_exchange, shift, verbose, account, benchmark)\u001b[0m\n\u001b[0;32m 100\u001b[0m \u001b[0mtrade_exchange\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mtrade_exchange\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 101\u001b[0m \u001b[0mpred_date\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mpred_date\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 102\u001b[1;33m \u001b[0mtrade_date\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mtrade_date\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 103\u001b[0m )\n\u001b[0;32m 104\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
- "\u001b[1;32m\u001b[0m in \u001b[0;36mgenerate_order_list\u001b[1;34m(self, score_series, current, trade_exchange, pred_date, trade_date)\u001b[0m\n\u001b[0;32m 76\u001b[0m \u001b[1;31m# optimize target portfolio\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 77\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0minit_weight\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msum\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m>\u001b[0m \u001b[1;36m0\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 78\u001b[1;33m \u001b[0mtarget_weight\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0moptimizer\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mcov\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mscore_series\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minit_weight\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 79\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 80\u001b[0m \u001b[0mtarget_weight\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0moptimizer\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mcov\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mscore_series\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
- "\u001b[1;32md:\\qlib\\qlib\\portfolio\\optimizer.py\u001b[0m in \u001b[0;36m__call__\u001b[1;34m(self, S, u, w0)\u001b[0m\n\u001b[0;32m 100\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 101\u001b[0m \u001b[1;31m# optimize\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 102\u001b[1;33m \u001b[0mw\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_optimize\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mS\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mu\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mw0\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 103\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 104\u001b[0m \u001b[1;31m# restore index if needed\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
- "\u001b[1;32md:\\qlib\\qlib\\portfolio\\optimizer.py\u001b[0m in \u001b[0;36m_optimize\u001b[1;34m(self, S, u, w0)\u001b[0m\n\u001b[0;32m 126\u001b[0m \u001b[1;31m# mean-variance\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 127\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmethod\u001b[0m \u001b[1;33m==\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mOPT_MVO\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 128\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_optimize_mvo\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mS\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mu\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mw0\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 129\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 130\u001b[0m \u001b[1;31m# risk parity\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
- "\u001b[1;32md:\\qlib\\qlib\\portfolio\\optimizer.py\u001b[0m in \u001b[0;36m_optimize_mvo\u001b[1;34m(self, S, u, w0)\u001b[0m\n\u001b[0;32m 162\u001b[0m \u001b[1;32mand\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m`\u001b[0m\u001b[0mlamb\u001b[0m\u001b[0;31m`\u001b[0m \u001b[1;32mis\u001b[0m \u001b[0mthe\u001b[0m \u001b[0mrisk\u001b[0m \u001b[0maversion\u001b[0m \u001b[0mparameter\u001b[0m\u001b[1;33m.\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 163\u001b[0m \"\"\"\n\u001b[1;32m--> 164\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_solve\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mS\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_get_objective_mvo\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mS\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mu\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m*\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_get_constrains\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mw0\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 165\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 166\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m_optimize_rp\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mS\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mndarray\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mw0\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mOptional\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mndarray\u001b[0m\u001b[1;33m]\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m->\u001b[0m \u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mndarray\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
- "\u001b[1;32md:\\qlib\\qlib\\portfolio\\optimizer.py\u001b[0m in \u001b[0;36m_solve\u001b[1;34m(self, n, obj, bounds, cons)\u001b[0m\n\u001b[0;32m 252\u001b[0m \u001b[1;31m# solve\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 253\u001b[0m \u001b[0mx0\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mones\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mn\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m/\u001b[0m \u001b[0mn\u001b[0m \u001b[1;31m# init results\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 254\u001b[1;33m \u001b[0msol\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mso\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mminimize\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mwrapped_obj\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mx0\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mbounds\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mbounds\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mconstraints\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mcons\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtol\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtol\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 255\u001b[0m \u001b[1;32mif\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[0msol\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msuccess\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 256\u001b[0m \u001b[0mwarnings\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mwarn\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34mf\"optimization not success ({sol.status})\"\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
- "\u001b[1;32m~\\AppData\\Local\\Continuum\\miniconda3\\envs\\qlib\\lib\\site-packages\\scipy\\optimize\\_minimize.py\u001b[0m in \u001b[0;36mminimize\u001b[1;34m(fun, x0, args, method, jac, hess, hessp, bounds, constraints, tol, callback, options)\u001b[0m\n\u001b[0;32m 624\u001b[0m \u001b[1;32melif\u001b[0m \u001b[0mmeth\u001b[0m \u001b[1;33m==\u001b[0m \u001b[1;34m'slsqp'\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 625\u001b[0m return _minimize_slsqp(fun, x0, args, jac, bounds,\n\u001b[1;32m--> 626\u001b[1;33m constraints, callback=callback, **options)\n\u001b[0m\u001b[0;32m 627\u001b[0m \u001b[1;32melif\u001b[0m \u001b[0mmeth\u001b[0m \u001b[1;33m==\u001b[0m \u001b[1;34m'trust-constr'\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 628\u001b[0m return _minimize_trustregion_constr(fun, x0, args, jac, hess, hessp,\n",
- "\u001b[1;32m~\\AppData\\Local\\Continuum\\miniconda3\\envs\\qlib\\lib\\site-packages\\scipy\\optimize\\slsqp.py\u001b[0m in \u001b[0;36m_minimize_slsqp\u001b[1;34m(func, x0, args, jac, bounds, constraints, maxiter, ftol, iprint, disp, eps, callback, finite_diff_rel_step, **unknown_options)\u001b[0m\n\u001b[0;32m 419\u001b[0m n1, n2, n3)\n\u001b[0;32m 420\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 421\u001b[1;33m \u001b[1;32mif\u001b[0m \u001b[0mmode\u001b[0m \u001b[1;33m==\u001b[0m \u001b[1;36m1\u001b[0m\u001b[1;33m:\u001b[0m \u001b[1;31m# objective and constraint evaluation required\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 422\u001b[0m \u001b[0mfx\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0msf\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfun\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 423\u001b[0m \u001b[0mc\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0m_eval_constraint\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcons\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
- "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
- ]
- }
- ],
- "source": [
- "###################################\n",
- "# prediction, backtest & analysis\n",
- "###################################\n",
- "port_analysis_config = {\n",
- " \"strategy\": strategy,\n",
- " \"backtest\": {\n",
- " \"verbose\": False,\n",
- " \"limit_threshold\": 0.095,\n",
- " \"account\": 100000000,\n",
- " \"benchmark\": benchmark,\n",
- " \"deal_price\": \"close\",\n",
- " \"open_cost\": 0.0005,\n",
- " \"close_cost\": 0.0015,\n",
- " \"min_cost\": 5,\n",
- " },\n",
- "}\n",
- "\n",
- "\n",
- "# backtest and analysis\n",
- "with R.start(experiment_name=\"backtest_analysis\"):\n",
- " recorder = R.get_recorder(rid, experiment_name=\"train_model\")\n",
- " model = recorder.load_object(\"trained_model\")\n",
- "\n",
- " # prediction\n",
- " recorder = R.get_recorder()\n",
- " ba_rid = recorder.id\n",
- " sr = SignalRecord(model, dataset, recorder)\n",
- " sr.generate()\n",
- "\n",
- " # backtest & analysis\n",
- " par = PortAnaRecord(recorder, port_analysis_config)\n",
- " par.generate()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ]
-}
\ No newline at end of file
diff --git a/examples/run_all_model.py b/examples/run_all_model.py
index 2f6c4299e..05839a125 100644
--- a/examples/run_all_model.py
+++ b/examples/run_all_model.py
@@ -4,18 +4,20 @@
import os
import sys
import fire
+import time
import venv
import glob
import shutil
+import signal
+import inspect
import tempfile
+import traceback
+import functools
import statistics
+import subprocess
from pathlib import Path
from operator import xor
-from subprocess import Popen, PIPE
-from threading import Thread
from pprint import pprint
-from urllib.parse import urlparse
-from urllib.request import urlretrieve
import qlib
from qlib.config import REG_CN
@@ -23,144 +25,53 @@ from qlib.workflow import R
from qlib.workflow.cli import workflow
from qlib.utils import exists_qlib_data
+
# init qlib
provider_uri = "~/.qlib/qlib_data/cn_data"
+exp_folder_name = "run_all_model_records"
+exp_path = str(Path(os.getcwd()).resolve() / exp_folder_name)
+exp_manager = {
+ "class": "MLflowExpManager",
+ "module_path": "qlib.workflow.expm",
+ "kwargs": {
+ "uri": "file:" + exp_path,
+ "default_exp_name": "Experiment",
+ },
+}
if not exists_qlib_data(provider_uri):
print(f"Qlib data is not found in {provider_uri}")
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
from get_data import GetData
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
-qlib.init(provider_uri=provider_uri, region=REG_CN)
+qlib.init(provider_uri=provider_uri, region=REG_CN, exp_manager=exp_manager)
+if os.path.isdir(exp_path):
+ shutil.rmtree(exp_path)
+
+# decorator to check the arguments
+def only_allow_defined_args(function_to_decorate):
+ @functools.wraps(function_to_decorate)
+ def _return_wrapped(*args, **kwargs):
+ """Internal wrapper function."""
+ argspec = inspect.getfullargspec(function_to_decorate)
+ valid_names = set(argspec.args + argspec.kwonlyargs)
+ if "self" in valid_names:
+ valid_names.remove("self")
+ for arg_name in kwargs:
+ if arg_name not in valid_names:
+ raise ValueError("Unknown argument seen '%s', expected: [%s]" % (arg_name, ", ".join(valid_names)))
+ return function_to_decorate(*args, **kwargs)
+
+ return _return_wrapped
-class ExtendedEnvBuilder(venv.EnvBuilder):
- """
- Thie class is modified based on https://docs.python.org/3/library/venv.html.
- This builder installs setuptools and pip so that you can pip or
- easy_install other packages into the created virtual environment.
+# function to handle ctrl z and ctrl c
+def handler(signum, frame):
+ os.system("kill -9 %d" % os.getpid())
- :param nodist: If true, setuptools and pip are not installed into the
- created virtual environment.
- :param nopip: If true, pip is not installed into the created
- virtual environment.
- :param progress: If setuptools or pip are installed, the progress of the
- installation can be monitored by passing a progress
- callable. If specified, it is called with two
- arguments: a string indicating some progress, and a
- context indicating where the string is coming from.
- The context argument can have one of three values:
- 'main', indicating that it is called from virtualize()
- itself, and 'stdout' and 'stderr', which are obtained
- by reading lines from the output streams of a subprocess
- which is used to install the app.
-
- If a callable is not specified, default progress
- information is output to sys.stderr.
- """
-
- def __init__(self, *args, **kwargs):
- self.nodist = kwargs.pop("nodist", False)
- self.nopip = kwargs.pop("nopip", False)
- self.progress = kwargs.pop("progress", None)
- self.verbose = kwargs.pop("verbose", False)
- super().__init__(*args, **kwargs)
-
- def post_setup(self, context):
- """
- Set up any packages which need to be pre-installed into the
- virtual environment being created.
-
- :param context: The information for the virtual environment
- creation request being processed.
- """
- os.environ["VIRTUAL_ENV"] = context.env_dir
- if not self.nodist:
- self.install_setuptools(context)
- # Can't install pip without setuptools
- if not self.nopip and not self.nodist:
- self.install_pip(context)
-
- def reader(self, stream, context):
- """
- Read lines from a subprocess' output stream and either pass to a progress
- callable (if specified) or write progress information to sys.stderr.
- """
- progress = self.progress
- while True:
- s = stream.readline()
- if not s:
- break
- if progress is not None:
- progress(s, context)
- else:
- if not self.verbose:
- sys.stderr.write(".")
- else:
- sys.stderr.write(s.decode("utf-8"))
- sys.stderr.flush()
- stream.close()
-
- def install_script(self, context, name, url):
- _, _, path, _, _, _ = urlparse(url)
- fn = os.path.split(path)[-1]
- binpath = context.bin_path
- distpath = os.path.join(binpath, fn)
- # Download script into the virtual environment's binaries folder
- urlretrieve(url, distpath)
- progress = self.progress
- if self.verbose:
- term = "\n"
- else:
- term = ""
- if progress is not None:
- progress("Installing %s ...%s" % (name, term), "main")
- else:
- sys.stderr.write("Installing %s ...%s" % (name, term))
- sys.stderr.flush()
- # Install in the virtual environment
- args = [context.env_exe, fn]
- p = Popen(args, stdout=PIPE, stderr=PIPE, cwd=binpath)
- t1 = Thread(target=self.reader, args=(p.stdout, "stdout"))
- t1.start()
- t2 = Thread(target=self.reader, args=(p.stderr, "stderr"))
- t2.start()
- p.wait()
- t1.join()
- t2.join()
- if progress is not None:
- progress("done.", "main")
- else:
- sys.stderr.write("done.\n")
- # Clean up - no longer needed
- os.unlink(distpath)
-
- def install_setuptools(self, context):
- """
- Install setuptools in the virtual environment.
-
- :param context: The information for the virtual environment
- creation request being processed.
- """
- url = "https://bootstrap.pypa.io/ez_setup.py"
- self.install_script(context, "setuptools", url)
- # clear up the setuptools archive which gets downloaded
- pred = lambda o: o.startswith("setuptools-") and o.endswith(".tar.gz")
- files = filter(pred, os.listdir(context.bin_path))
- for f in files:
- f = os.path.join(context.bin_path, f)
- os.unlink(f)
-
- def install_pip(self, context):
- """
- Install pip in the virtual environment.
-
- :param context: The information for the virtual environment
- creation request being processed.
- """
- url = "https://bootstrap.pypa.io/get-pip.py"
- self.install_script(context, "pip", url)
+signal.signal(signal.SIGTSTP, handler)
+signal.signal(signal.SIGINT, handler)
# function to calculate the mean and std of a list in the results dictionary
def cal_mean_std(results) -> dict:
@@ -174,6 +85,36 @@ def cal_mean_std(results) -> dict:
return mean_std
+# function to create the environment ofr an anaconda environment
+def create_env():
+ # create env
+ temp_dir = tempfile.mkdtemp()
+ env_path = Path(temp_dir).absolute()
+ sys.stderr.write(f"Creating Virtual Environment with path: {env_path}...\n")
+ execute(f"conda create --prefix {env_path} python=3.7 -y")
+ python_path = env_path / "bin" / "python" # TODO: FIX ME!
+ sys.stderr.write("\n")
+ # get anaconda activate path
+ conda_activate = Path(os.environ["CONDA_PREFIX"]) / "bin" / "activate" # TODO: FIX ME!
+ return env_path, python_path, conda_activate
+
+
+# function to execute the cmd
+def execute(cmd):
+ with subprocess.Popen(cmd, stdout=subprocess.PIPE, bufsize=1, universal_newlines=True, shell=True) as p:
+ for line in p.stdout:
+ sys.stdout.write(line.split("\b")[0])
+ if "\b" in line:
+ sys.stdout.flush()
+ time.sleep(0.1)
+ sys.stdout.write("\b" * 10 + "\b".join(line.split("\b")[1:-1]))
+
+ if p.returncode != 0:
+ return p.stderr
+ else:
+ return None
+
+
# function to get all the folders benchmark folder
def get_all_folders(models, exclude) -> dict:
folders = dict()
@@ -212,11 +153,12 @@ def get_all_results(folders) -> dict:
result["information_ratio_with_cost"] = list()
result["max_drawdown_with_cost"] = list()
for recorder_id in recorders:
- recorder = R.get_recorder(recorder_id=recorder_id, experiment_name=fn)
- metrics = recorder.list_metrics()
- result["annualized_return_with_cost"].append(metrics["excess_return_with_cost.annualized_return"])
- result["information_ratio_with_cost"].append(metrics["excess_return_with_cost.information_ratio"])
- result["max_drawdown_with_cost"].append(metrics["excess_return_with_cost.max_drawdown"])
+ if recorders[recorder_id].status == "FINISHED":
+ recorder = R.get_recorder(recorder_id=recorder_id, experiment_name=fn)
+ metrics = recorder.list_metrics()
+ result["annualized_return_with_cost"].append(metrics["excess_return_with_cost.annualized_return"])
+ result["information_ratio_with_cost"].append(metrics["excess_return_with_cost.information_ratio"])
+ result["max_drawdown_with_cost"].append(metrics["excess_return_with_cost.max_drawdown"])
results[fn] = result
return results
@@ -237,6 +179,7 @@ def gen_and_save_md_table(metrics):
# function to run the all the models
+@only_allow_defined_args
def run(times=1, models=None, exclude=False):
"""
Please be aware that this function can only work under Linux. MacOS and Windows will be supported in the future.
@@ -275,53 +218,48 @@ def run(times=1, models=None, exclude=False):
"""
# get all folders
folders = get_all_folders(models, exclude)
- # set up
- compatible = True
- if sys.version_info < (3, 3):
- compatible = False
- elif not hasattr(sys, "base_prefix"):
- compatible = False
- if not compatible:
- raise ValueError("This script is only for use with " "Python 3.3 or later")
- if os.name == "nt":
- use_symlinks = False
- else:
- use_symlinks = True
- builder = ExtendedEnvBuilder(
- system_site_packages=False,
- clear=False,
- symlinks=use_symlinks,
- upgrade=False,
- nodist=False,
- nopip=False,
- verbose=False,
- )
+ # init error messages:
+ errors = dict()
# run all the model for iterations
for fn in folders:
- # create env
- temp_dir = tempfile.mkdtemp()
- env_path = Path(temp_dir).absolute()
- sys.stderr.write(f"Creating Virtual Environment with path: {env_path}...\n")
- builder.create(str(env_path))
- python_path = env_path / "bin" / "python" # TODO: FIX ME!
- sys.stderr.write("\n")
+ # create env by anaconda
+ env_path, python_path, conda_activate = create_env()
# get all files
sys.stderr.write("Retrieving files...\n")
yaml_path, req_path = get_all_files(folders[fn])
sys.stderr.write("\n")
# install requirements.txt
sys.stderr.write("Installing requirements.txt...\n")
- os.system(f"{python_path} -m pip install -r {req_path}")
+ execute(f"{python_path} -m pip install -r {req_path}")
sys.stderr.write("\n")
+ # setup gpu for tft
+ if fn == "TFT":
+ execute(
+ f"conda install -y --prefix {env_path} anaconda cudatoolkit=10.0 && conda install -y --prefix {env_path} cudnn"
+ )
+ sys.stderr.write("\n")
# install qlib
sys.stderr.write("Installing qlib...\n")
- os.system(f"{python_path} -m pip install --upgrade cython") # TODO: FIX ME!
- os.system(f"{python_path} -m pip install -e git+https://github.com/you-n-g/qlib#egg=pyqlib") # TODO: FIX ME!
+ execute(f"{python_path} -m pip install --upgrade cython") # TODO: FIX ME!
+ if fn == "TFT":
+ execute(
+ f"cd {env_path} && {python_path} -m pip install --upgrade --force-reinstall --ignore-installed PyYAML -e git+https://github.com/you-n-g/qlib#egg=pyqlib"
+ ) # TODO: FIX ME!
+ else:
+ execute(
+ f"cd {env_path} && {python_path} -m pip install --upgrade --force-reinstall -e git+https://github.com/you-n-g/qlib#egg=pyqlib"
+ ) # TODO: FIX ME!
sys.stderr.write("\n")
# run workflow_by_config for multiple times
for i in range(times):
sys.stderr.write(f"Running the model: {fn} for iteration {i+1}...\n")
- os.system(f"{python_path} {env_path / 'src/pyqlib/qlib/workflow/cli.py'} {yaml_path} {fn}")
+ errs = execute(
+ f"{python_path} {env_path / 'src/pyqlib/qlib/workflow/cli.py'} {yaml_path} {fn} {exp_folder_name}"
+ )
+ if errs is not None:
+ _errs = errors.get(fn, {})
+ _errs.update({i: errs})
+ errors[fn] = _errs
sys.stderr.write("\n")
# remove env
sys.stderr.write(f"Deleting the environment: {env_path}...\n")
@@ -335,13 +273,12 @@ def run(times=1, models=None, exclude=False):
# generating md table
sys.stderr.write(f"Generating markdown table...\n")
gen_and_save_md_table(results)
+ sys.stderr.write("\n")
+ # print erros
+ sys.stderr.write(f"Here are some of the errors of the models...\n")
+ pprint(errors)
+ sys.stderr.write("\n")
if __name__ == "__main__":
- rc = 1
- try:
- fire.Fire(run) # run all the model
- rc = 0
- except Exception as e:
- print("Error: %s" % e, file=sys.stderr)
- sys.exit(rc)
+ fire.Fire(run) # run all the model
diff --git a/examples/workflow_by_code.ipynb b/examples/workflow_by_code.ipynb
index 692e52078..5a992e339 100644
--- a/examples/workflow_by_code.ipynb
+++ b/examples/workflow_by_code.ipynb
@@ -1,5 +1,12 @@
{
"cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "
"
+ ]
+ },
{
"cell_type": "code",
"execution_count": null,
@@ -10,14 +17,43 @@
"# Licensed under the MIT License."
]
},
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [],
+ "source": [
+ "import sys, site\n",
+ "from pathlib import Path\n",
+ "\n",
+ "\n",
+ "try:\n",
+ " import qlib\n",
+ "except ImportError:\n",
+ " # install qlib\n",
+ " ! pip install pyqlib\n",
+ " # reload\n",
+ " site.main()\n",
+ "\n",
+ "scripts_dir = Path.cwd().parent.joinpath(\"scripts\")\n",
+ "if not scripts_dir.joinpath(\"get_data.py\").exists():\n",
+ " # download get_data.py script\n",
+ " scripts_dir = Path(\"~/tmp/qlib_code/scripts\").expanduser().resolve()\n",
+ " scripts_dir.mkdir(parents=True, exist_ok=True)\n",
+ " import requests\n",
+ " with requests.get(\"https://raw.githubusercontent.com/microsoft/qlib/main/scripts/get_data.py\") as resp:\n",
+ " with open(scripts_dir.joinpath(\"get_data.py\"), \"wb\") as fp:\n",
+ " fp.write(resp.content)"
+ ]
+ },
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
- "import sys\n",
- "from pathlib import Path\n",
"\n",
"import qlib\n",
"import pandas as pd\n",
@@ -32,7 +68,7 @@
"from qlib.utils import exists_qlib_data, init_instance_by_config\n",
"from qlib.workflow import R\n",
"from qlib.workflow.record_temp import SignalRecord, PortAnaRecord\n",
- "from qlib.utils import flatten_dict"
+ "from qlib.utils import flatten_dict\n"
]
},
{
@@ -48,7 +84,7 @@
"provider_uri = \"~/.qlib/qlib_data/cn_data\" # target_dir\n",
"if not exists_qlib_data(provider_uri):\n",
" print(f\"Qlib data is not found in {provider_uri}\")\n",
- " sys.path.append(str(Path.cwd().parent.joinpath(\"scripts\")))\n",
+ " sys.path.append(str(scripts_dir))\n",
" from get_data import GetData\n",
" GetData().qlib_data(target_dir=provider_uri, region=REG_CN)\n",
"qlib.init(provider_uri=provider_uri, region=REG_CN)"
@@ -202,7 +238,9 @@
{
"cell_type": "code",
"execution_count": null,
- "metadata": {},
+ "metadata": {
+ "scrolled": false
+ },
"outputs": [],
"source": [
"from qlib.contrib.report import analysis_model, analysis_position\n",
@@ -320,7 +358,8 @@
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
- "pygments_lexer": "ipython3"
+ "pygments_lexer": "ipython3",
+ "version": "3.7.9"
},
"toc": {
"base_numbering": 1,
diff --git a/examples/workflow_by_code_alstm.py b/examples/workflow_by_code_alstm.py
deleted file mode 100644
index 8fd9e3565..000000000
--- a/examples/workflow_by_code_alstm.py
+++ /dev/null
@@ -1,138 +0,0 @@
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
-
-import sys
-from pathlib import Path
-
-import qlib
-import pandas as pd
-from qlib.config import REG_CN
-from qlib.contrib.strategy.strategy import TopkDropoutStrategy
-from qlib.contrib.evaluate import (
- backtest as normal_backtest,
- risk_analysis,
-)
-from qlib.utils import exists_qlib_data
-from qlib.utils import init_instance_by_config
-
-if __name__ == "__main__":
-
- # use default data
- provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
- if not exists_qlib_data(provider_uri):
- print(f"Qlib data is not found in {provider_uri}")
- sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
- from get_data import GetData
-
- GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
-
- qlib.init(provider_uri=provider_uri, region=REG_CN)
-
- MARKET = "csi300"
- BENCHMARK = "SH000300"
-
- ###################################
- # train model
- ###################################
- DATA_HANDLER_CONFIG = {
- "start_time": "2008-01-01",
- "end_time": "2020-08-01",
- "fit_start_time": "2008-01-01",
- "fit_end_time": "2014-12-31",
- "instruments": MARKET,
- }
-
- TRAINER_CONFIG = {
- "train_start_time": "2008-01-01",
- "train_end_time": "2014-12-31",
- "validate_start_time": "2015-01-01",
- "validate_end_time": "2016-12-31",
- "test_start_time": "2017-01-01",
- "test_end_time": "2020-08-01",
- }
-
- task = {
- "model": {
- "class": "ALSTM",
- "module_path": "qlib.contrib.model.pytorch_alstm",
- "kwargs": {
- "d_feat": 6,
- "hidden_size": 64,
- "num_layers": 2,
- "dropout": 0.0,
- "n_epochs": 200,
- "lr": 1e-3,
- "early_stop": 20,
- "batch_size": 800,
- "metric": "IC",
- "loss": "mse",
- "seed": 0,
- "GPU": "0",
- "rnn_type": "GRU",
- },
- },
- "dataset": {
- "class": "DatasetH",
- "module_path": "qlib.data.dataset",
- "kwargs": {
- "handler": {
- "class": "ALPHA360_Denoise",
- "module_path": "qlib.contrib.data.handler",
- "kwargs": DATA_HANDLER_CONFIG,
- },
- "segments": {
- "train": ("2008-01-01", "2014-12-31"),
- "valid": ("2015-01-01", "2016-12-31"),
- "test": ("2017-01-01", "2020-08-01"),
- },
- },
- }
- # You shoud record the data in specific sequence
- # "record": ['SignalRecord', 'SigAnaRecord', 'PortAnaRecord'],
- }
-
- model = init_instance_by_config(task["model"])
- dataset = init_instance_by_config(task["dataset"])
- model.fit(dataset)
-
- pred_score = model.predict(dataset)
-
- # save pred_score to file
- pred_score_path = Path("~/tmp/qlib/pred_score.pkl").expanduser()
- pred_score_path.parent.mkdir(exist_ok=True, parents=True)
- pred_score.to_pickle(pred_score_path)
-
- ###################################
- # backtest
- ###################################
- STRATEGY_CONFIG = {
- "topk": 50,
- "n_drop": 5,
- }
- BACKTEST_CONFIG = {
- "verbose": False,
- "limit_threshold": 0.095,
- "account": 100000000,
- "benchmark": BENCHMARK,
- "deal_price": "close",
- "open_cost": 0.0005,
- "close_cost": 0.0015,
- "min_cost": 5,
- }
-
- # use default strategy
- # custom Strategy, refer to: TODO: Strategy API url
- strategy = TopkDropoutStrategy(**STRATEGY_CONFIG)
- report_normal, positions_normal = normal_backtest(pred_score, strategy=strategy, **BACKTEST_CONFIG)
-
- ###################################
- # analyze
- # If need a more detailed analysis, refer to: examples/train_and_bakctest.ipynb
- ###################################
- analysis = dict()
- analysis["excess_return_without_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"])
- analysis["excess_return_with_cost"] = risk_analysis(
- report_normal["return"] - report_normal["bench"] - report_normal["cost"]
- )
- analysis_df = pd.concat(analysis) # type: pd.DataFrame
- print(analysis_df)
diff --git a/examples/workflow_by_code_gats.py b/examples/workflow_by_code_gats.py
deleted file mode 100644
index 20f3ae552..000000000
--- a/examples/workflow_by_code_gats.py
+++ /dev/null
@@ -1,140 +0,0 @@
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
-
-import sys
-from pathlib import Path
-
-import qlib
-import pandas as pd
-from qlib.config import REG_CN
-
-from qlib.contrib.strategy.strategy import TopkDropoutStrategy
-from qlib.contrib.evaluate import (
- backtest as normal_backtest,
- risk_analysis,
-)
-from qlib.utils import exists_qlib_data
-from qlib.utils import init_instance_by_config
-
-
-if __name__ == "__main__":
-
- # use default data
- provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
- if not exists_qlib_data(provider_uri):
- print(f"Qlib data is not found in {provider_uri}")
- sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
- from get_data import GetData
-
- GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
-
- qlib.init(provider_uri=provider_uri, region=REG_CN)
-
- MARKET = "csi300"
- BENCHMARK = "SH000300"
-
- ###################################
- # train model
- ###################################
- DATA_HANDLER_CONFIG = {
- "start_time": "2008-01-01",
- "end_time": "2020-08-01",
- "fit_start_time": "2008-01-01",
- "fit_end_time": "2014-12-31",
- "instruments": MARKET,
- }
-
- TRAINER_CONFIG = {
- "train_start_time": "2008-01-01",
- "train_end_time": "2014-12-31",
- "validate_start_time": "2015-01-01",
- "validate_end_time": "2016-12-31",
- "test_start_time": "2017-01-01",
- "test_end_time": "2020-08-01",
- }
-
- task = {
- "model": {
- "class": "GAT",
- "module_path": "qlib.contrib.model.pytorch_gats",
- "kwargs": {
- "d_feat": 6,
- "hidden_size": 64,
- "num_layers": 2,
- "dropout": 0.7,
- "n_epochs": 200,
- "lr": 1e-4,
- "early_stop": 20,
- "metric": "loss",
- "loss": "mse",
- "base_model": "LSTM",
- "with_pretrain": True,
- "seed": 0,
- "GPU": "0",
- },
- },
- "dataset": {
- "class": "DatasetH",
- "module_path": "qlib.data.dataset",
- "kwargs": {
- "handler": {
- "class": "ALPHA360_Denoise",
- "module_path": "qlib.contrib.data.handler",
- "kwargs": DATA_HANDLER_CONFIG,
- },
- "segments": {
- "train": ("2008-01-01", "2014-12-31"),
- "valid": ("2015-01-01", "2016-12-31"),
- "test": ("2017-01-01", "2020-08-01"),
- },
- },
- }
- # You shoud record the data in specific sequence
- # "record": ['SignalRecord', 'SigAnaRecord', 'PortAnaRecord'],
- }
-
- model = init_instance_by_config(task["model"])
- dataset = init_instance_by_config(task["dataset"])
- model.fit(dataset)
-
- pred_score = model.predict(dataset)
-
- # save pred_score to file
- pred_score_path = Path("~/tmp/qlib/pred_score.pkl").expanduser()
- pred_score_path.parent.mkdir(exist_ok=True, parents=True)
- pred_score.to_pickle(pred_score_path)
-
- ###################################
- # backtest
- ###################################
- STRATEGY_CONFIG = {
- "topk": 50,
- "n_drop": 5,
- }
- BACKTEST_CONFIG = {
- "verbose": False,
- "limit_threshold": 0.095,
- "account": 100000000,
- "benchmark": BENCHMARK,
- "deal_price": "close",
- "open_cost": 0.0005,
- "close_cost": 0.0015,
- "min_cost": 5,
- }
-
- # use default strategy
- # custom Strategy, refer to: TODO: Strategy API url
- strategy = TopkDropoutStrategy(**STRATEGY_CONFIG)
- report_normal, positions_normal = normal_backtest(pred_score, strategy=strategy, **BACKTEST_CONFIG)
-
- ###################################
- # analyze
- # If need a more detailed analysis, refer to: examples/train_and_bakctest.ipynb
- ###################################
- analysis = dict()
- analysis["excess_return_without_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"])
- analysis["excess_return_with_cost"] = risk_analysis(
- report_normal["return"] - report_normal["bench"] - report_normal["cost"]
- )
- analysis_df = pd.concat(analysis) # type: pd.DataFrame
- print(analysis_df)
diff --git a/examples/workflow_by_code_gru.py b/examples/workflow_by_code_gru.py
deleted file mode 100644
index dece520d1..000000000
--- a/examples/workflow_by_code_gru.py
+++ /dev/null
@@ -1,144 +0,0 @@
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
-
-import sys
-from pathlib import Path
-
-import qlib
-import pandas as pd
-from qlib.config import REG_CN
-from qlib.contrib.model.pytorch_gru import GRU
-from qlib.contrib.data.handler import ALPHA360_Denoise
-from qlib.contrib.strategy.strategy import TopkDropoutStrategy
-from qlib.contrib.evaluate import (
- backtest as normal_backtest,
- risk_analysis,
-)
-from qlib.utils import exists_qlib_data
-
-# from qlib.model.learner import train_model
-from qlib.utils import init_instance_by_config
-
-import pickle
-
-if __name__ == "__main__":
-
- # use default data
- provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
- if not exists_qlib_data(provider_uri):
- print(f"Qlib data is not found in {provider_uri}")
- sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
- from get_data import GetData
-
- GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
-
- qlib.init(provider_uri=provider_uri, region=REG_CN)
-
- MARKET = "csi300"
- BENCHMARK = "SH000300"
-
- ###################################
- # train model
- ###################################
- DATA_HANDLER_CONFIG = {
- "start_time": "2008-01-01",
- "end_time": "2020-08-01",
- "fit_start_time": "2008-01-01",
- "fit_end_time": "2014-12-31",
- "instruments": MARKET,
- }
-
- TRAINER_CONFIG = {
- "train_start_time": "2008-01-01",
- "train_end_time": "2014-12-31",
- "validate_start_time": "2015-01-01",
- "validate_end_time": "2016-12-31",
- "test_start_time": "2017-01-01",
- "test_end_time": "2020-08-01",
- }
-
- task = {
- "model": {
- "class": "GRU",
- "module_path": "qlib.contrib.model.pytorch_gru",
- "kwargs": {
- "d_feat": 6,
- "hidden_size": 64,
- "num_layers": 2,
- "dropout": 0.0,
- "n_epochs": 200,
- "lr": 1e-3,
- "early_stop": 20,
- "batch_size": 800,
- "metric": "loss",
- "loss": "mse",
- "seed": 0,
- "GPU": 0,
- },
- },
- "dataset": {
- "class": "DatasetH",
- "module_path": "qlib.data.dataset",
- "kwargs": {
- "handler": {
- "class": "ALPHA360_Denoise",
- "module_path": "qlib.contrib.data.handler",
- "kwargs": DATA_HANDLER_CONFIG,
- },
- "segments": {
- "train": ("2008-01-01", "2014-12-31"),
- "valid": ("2015-01-01", "2016-12-31"),
- "test": ("2017-01-01", "2020-08-01"),
- },
- },
- }
- # You shoud record the data in specific sequence
- # "record": ['SignalRecord', 'SigAnaRecord', 'PortAnaRecord'],
- }
-
- # model = train_model(task)
- model = init_instance_by_config(task["model"])
- dataset = init_instance_by_config(task["dataset"])
- model.fit(dataset)
-
- pred_score = model.predict(dataset)
-
- # save pred_score to file
- pred_score_path = Path("~/tmp/qlib/pred_score.pkl").expanduser()
- pred_score_path.parent.mkdir(exist_ok=True, parents=True)
- pred_score.to_pickle(pred_score_path)
-
- ###################################
- # backtest
- ###################################
- STRATEGY_CONFIG = {
- "topk": 50,
- "n_drop": 5,
- }
- BACKTEST_CONFIG = {
- "verbose": False,
- "limit_threshold": 0.095,
- "account": 100000000,
- "benchmark": BENCHMARK,
- "deal_price": "close",
- "open_cost": 0.0005,
- "close_cost": 0.0015,
- "min_cost": 5,
- }
-
- # use default strategy
- # custom Strategy, refer to: TODO: Strategy API url
- strategy = TopkDropoutStrategy(**STRATEGY_CONFIG)
- report_normal, positions_normal = normal_backtest(pred_score, strategy=strategy, **BACKTEST_CONFIG)
-
- ###################################
- # analyze
- # If need a more detailed analysis, refer to: examples/train_and_bakctest.ipynb
- ###################################
- analysis = dict()
- analysis["excess_return_without_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"])
- analysis["excess_return_with_cost"] = risk_analysis(
- report_normal["return"] - report_normal["bench"] - report_normal["cost"]
- )
- analysis_df = pd.concat(analysis) # type: pd.DataFrame
- print(analysis_df)
diff --git a/examples/workflow_by_code_hats.py b/examples/workflow_by_code_hats.py
deleted file mode 100644
index 64bc860b4..000000000
--- a/examples/workflow_by_code_hats.py
+++ /dev/null
@@ -1,136 +0,0 @@
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
-
-import sys
-from pathlib import Path
-import qlib
-import pandas as pd
-from qlib.config import REG_CN
-from qlib.contrib.strategy.strategy import TopkDropoutStrategy
-from qlib.contrib.evaluate import (
- backtest as normal_backtest,
- risk_analysis,
-)
-from qlib.utils import exists_qlib_data
-from qlib.utils import init_instance_by_config
-
-if __name__ == "__main__":
-
- # use default data
- provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
- if not exists_qlib_data(provider_uri):
- print(f"Qlib data is not found in {provider_uri}")
- sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
- from get_data import GetData
-
- GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
-
- qlib.init(provider_uri=provider_uri, region=REG_CN)
-
- MARKET = "csi300"
- BENCHMARK = "SH000300"
-
- ###################################
- # train model
- ###################################
- DATA_HANDLER_CONFIG = {
- "start_time": "2008-01-01",
- "end_time": "2020-08-01",
- "fit_start_time": "2008-01-01",
- "fit_end_time": "2014-12-31",
- "instruments": MARKET,
- }
-
- TRAINER_CONFIG = {
- "train_start_time": "2008-01-01",
- "train_end_time": "2014-12-31",
- "validate_start_time": "2015-01-01",
- "validate_end_time": "2016-12-31",
- "test_start_time": "2017-01-01",
- "test_end_time": "2020-08-01",
- }
-
- task = {
- "model": {
- "class": "HATS",
- "module_path": "qlib.contrib.model.pytorch_hats",
- "kwargs": {
- "d_feat": 6,
- "hidden_size": 64,
- "num_layers": 2,
- "dropout": 0.7,
- "n_epochs": 200,
- "lr": 1e-4,
- "early_stop": 20,
- "metric": "loss",
- "loss": "mse",
- "base_model": "LSTM",
- "seed": 0,
- "GPU": "2",
- },
- },
- "dataset": {
- "class": "DatasetH",
- "module_path": "qlib.data.dataset",
- "kwargs": {
- "handler": {
- "class": "ALPHA360_Denoise",
- "module_path": "qlib.contrib.data.handler",
- "kwargs": DATA_HANDLER_CONFIG,
- },
- "segments": {
- "train": ("2008-01-01", "2014-12-31"),
- "valid": ("2015-01-01", "2016-12-31"),
- "test": ("2017-01-01", "2020-08-01"),
- },
- },
- }
- # You shoud record the data in specific sequence
- # "record": ['SignalRecord', 'SigAnaRecord', 'PortAnaRecord'],
- }
-
- model = init_instance_by_config(task["model"])
- dataset = init_instance_by_config(task["dataset"])
- model.fit(dataset, save_path="benchmarks/HATS/model_hat.pkl")
-
- pred_score = model.predict(dataset)
-
- # save pred_score to file
- pred_score_path = Path("~/tmp/qlib/pred_score.pkl").expanduser()
- pred_score_path.parent.mkdir(exist_ok=True, parents=True)
- pred_score.to_pickle(pred_score_path)
-
- ###################################
- # backtest
- ###################################
- STRATEGY_CONFIG = {
- "topk": 50,
- "n_drop": 5,
- }
- BACKTEST_CONFIG = {
- "verbose": False,
- "limit_threshold": 0.095,
- "account": 100000000,
- "benchmark": BENCHMARK,
- "deal_price": "close",
- "open_cost": 0.0005,
- "close_cost": 0.0015,
- "min_cost": 5,
- }
-
- # use default strategy
- # custom Strategy, refer to: TODO: Strategy API url
- strategy = TopkDropoutStrategy(**STRATEGY_CONFIG)
- report_normal, positions_normal = normal_backtest(pred_score, strategy=strategy, **BACKTEST_CONFIG)
-
- ###################################
- # analyze
- # If need a more detailed analysis, refer to: examples/train_and_bakctest.ipynb
- ###################################
- analysis = dict()
- analysis["excess_return_without_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"])
- analysis["excess_return_with_cost"] = risk_analysis(
- report_normal["return"] - report_normal["bench"] - report_normal["cost"]
- )
- analysis_df = pd.concat(analysis) # type: pd.DataFrame
- print(analysis_df)
diff --git a/examples/workflow_by_code_lstm.py b/examples/workflow_by_code_lstm.py
deleted file mode 100644
index ee50c9aff..000000000
--- a/examples/workflow_by_code_lstm.py
+++ /dev/null
@@ -1,144 +0,0 @@
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
-
-import sys
-from pathlib import Path
-
-import qlib
-import pandas as pd
-from qlib.config import REG_CN
-from qlib.contrib.model.pytorch_lstm import LSTM
-from qlib.contrib.data.handler import ALPHA360_Denoise
-from qlib.contrib.strategy.strategy import TopkDropoutStrategy
-from qlib.contrib.evaluate import (
- backtest as normal_backtest,
- risk_analysis,
-)
-from qlib.utils import exists_qlib_data
-
-# from qlib.model.learner import train_model
-from qlib.utils import init_instance_by_config
-
-import pickle
-
-if __name__ == "__main__":
-
- # use default data
- provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
- if not exists_qlib_data(provider_uri):
- print(f"Qlib data is not found in {provider_uri}")
- sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
- from get_data import GetData
-
- GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
-
- qlib.init(provider_uri=provider_uri, region=REG_CN)
-
- MARKET = "csi300"
- BENCHMARK = "SH000300"
-
- ###################################
- # train model
- ###################################
- DATA_HANDLER_CONFIG = {
- "start_time": "2008-01-01",
- "end_time": "2020-08-01",
- "fit_start_time": "2008-01-01",
- "fit_end_time": "2014-12-31",
- "instruments": MARKET,
- }
-
- TRAINER_CONFIG = {
- "train_start_time": "2008-01-01",
- "train_end_time": "2014-12-31",
- "validate_start_time": "2015-01-01",
- "validate_end_time": "2016-12-31",
- "test_start_time": "2017-01-01",
- "test_end_time": "2020-08-01",
- }
-
- task = {
- "model": {
- "class": "LSTM",
- "module_path": "qlib.contrib.model.pytorch_lstm",
- "kwargs": {
- "d_feat": 6,
- "hidden_size": 64,
- "num_layers": 2,
- "dropout": 0.0,
- "n_epochs": 200,
- "lr": 1e-3,
- "early_stop": 20,
- "batch_size": 800,
- "metric": "IC",
- "loss": "mse",
- "seed": 0,
- "GPU": 0,
- },
- },
- "dataset": {
- "class": "DatasetH",
- "module_path": "qlib.data.dataset",
- "kwargs": {
- "handler": {
- "class": "ALPHA360_Denoise",
- "module_path": "qlib.contrib.data.handler",
- "kwargs": DATA_HANDLER_CONFIG,
- },
- "segments": {
- "train": ("2008-01-01", "2014-12-31"),
- "valid": ("2015-01-01", "2016-12-31"),
- "test": ("2017-01-01", "2020-08-01"),
- },
- },
- }
- # You shoud record the data in specific sequence
- # "record": ['SignalRecord', 'SigAnaRecord', 'PortAnaRecord'],
- }
-
- # model = train_model(task)
- model = init_instance_by_config(task["model"])
- dataset = init_instance_by_config(task["dataset"])
- model.fit(dataset)
-
- pred_score = model.predict(dataset)
-
- # save pred_score to file
- pred_score_path = Path("~/tmp/qlib/pred_score.pkl").expanduser()
- pred_score_path.parent.mkdir(exist_ok=True, parents=True)
- pred_score.to_pickle(pred_score_path)
-
- ###################################
- # backtest
- ###################################
- STRATEGY_CONFIG = {
- "topk": 50,
- "n_drop": 5,
- }
- BACKTEST_CONFIG = {
- "verbose": False,
- "limit_threshold": 0.095,
- "account": 100000000,
- "benchmark": BENCHMARK,
- "deal_price": "close",
- "open_cost": 0.0005,
- "close_cost": 0.0015,
- "min_cost": 5,
- }
-
- # use default strategy
- # custom Strategy, refer to: TODO: Strategy API url
- strategy = TopkDropoutStrategy(**STRATEGY_CONFIG)
- report_normal, positions_normal = normal_backtest(pred_score, strategy=strategy, **BACKTEST_CONFIG)
-
- ###################################
- # analyze
- # If need a more detailed analysis, refer to: examples/train_and_bakctest.ipynb
- ###################################
- analysis = dict()
- analysis["excess_return_without_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"])
- analysis["excess_return_with_cost"] = risk_analysis(
- report_normal["return"] - report_normal["bench"] - report_normal["cost"]
- )
- analysis_df = pd.concat(analysis) # type: pd.DataFrame
- print(analysis_df)
diff --git a/examples/workflow_by_code_sfm.py b/examples/workflow_by_code_sfm.py
deleted file mode 100644
index 5bd91ded8..000000000
--- a/examples/workflow_by_code_sfm.py
+++ /dev/null
@@ -1,158 +0,0 @@
-# Copyright (c) Microsoft Corporation.
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import sys
-from pathlib import Path
-
-import qlib
-import pandas as pd
-from qlib.config import REG_CN
-from qlib.contrib.model.pytorch_gru import GRU
-from qlib.contrib.data.handler import ALPHA360_Denoise
-from qlib.contrib.strategy.strategy import TopkDropoutStrategy
-from qlib.contrib.evaluate import (
- backtest as normal_backtest,
- risk_analysis,
-)
-from qlib.utils import exists_qlib_data
-from qlib.utils import init_instance_by_config
-
-import pickle
-
-if __name__ == "__main__":
-
- # use default data
- provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
- if not exists_qlib_data(provider_uri):
- print(f"Qlib data is not found in {provider_uri}")
- sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
- from get_data import GetData
-
- GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
-
- qlib.init(provider_uri=provider_uri, region=REG_CN)
-
- MARKET = "csi300"
- BENCHMARK = "SH000300"
-
- ###################################
- # train model
- ###################################
- DATA_HANDLER_CONFIG = {
- "start_time": "2008-01-01",
- "end_time": "2020-08-01",
- "fit_start_time": "2008-01-01",
- "fit_end_time": "2014-12-31",
- "instruments": MARKET,
- }
-
- TRAINER_CONFIG = {
- "train_start_time": "2008-01-01",
- "train_end_time": "2014-12-31",
- "validate_start_time": "2015-01-01",
- "validate_end_time": "2016-12-31",
- "test_start_time": "2017-01-01",
- "test_end_time": "2020-08-01",
- }
-
- task = {
- "model": {
- "class": "SFM",
- "module_path": "qlib.contrib.model.pytorch_sfm",
- "kwargs": {
- "d_feat": 6,
- "hidden_size": 64,
- "output_dim": 32,
- "freq_dim": 25,
- "dropout_W": 0.5,
- "dropout_U": 0.5,
- "n_epochs": 15,
- "lr": 1e-3,
- "metric": "",
- "batch_size": 1600,
- "early_stop": 20,
- "eval_steps": 5,
- "loss": "mse",
- "lr_decay": 0.96,
- "lr_decay_steps": 100,
- "optimizer": "adam",
- "GPU": 3,
- "seed": 710,
- },
- },
- "dataset": {
- "class": "DatasetH",
- "module_path": "qlib.data.dataset",
- "kwargs": {
- "handler": {
- "class": "ALPHA360_Denoise",
- "module_path": "qlib.contrib.data.handler",
- "kwargs": DATA_HANDLER_CONFIG,
- },
- "segments": {
- "train": ("2008-01-01", "2014-12-31"),
- "valid": ("2015-01-01", "2016-12-31"),
- "test": ("2017-01-01", "2020-08-01"),
- },
- },
- }
- # You shoud record the data in specific sequence
- # "record": ['SignalRecord', 'SigAnaRecord', 'PortAnaRecord'],
- }
-
- # model = train_model(task)
- model = init_instance_by_config(task["model"])
- dataset = init_instance_by_config(task["dataset"])
- model.fit(dataset)
-
- pred_score = model.predict(dataset)
-
- # save pred_score to file
- pred_score_path = Path("~/tmp/qlib/pred_score.pkl").expanduser()
- pred_score_path.parent.mkdir(exist_ok=True, parents=True)
- pred_score.to_pickle(pred_score_path)
-
- ###################################
- # backtest
- ###################################
- STRATEGY_CONFIG = {
- "topk": 50,
- "n_drop": 5,
- }
- BACKTEST_CONFIG = {
- "verbose": False,
- "limit_threshold": 0.095,
- "account": 100000000,
- "benchmark": BENCHMARK,
- "deal_price": "close",
- "open_cost": 0.0005,
- "close_cost": 0.0015,
- "min_cost": 5,
- }
-
- # use default strategy
- # custom Strategy, refer to: TODO: Strategy API url
- strategy = TopkDropoutStrategy(**STRATEGY_CONFIG)
- report_normal, positions_normal = normal_backtest(pred_score, strategy=strategy, **BACKTEST_CONFIG)
-
- ###################################
- # analyze
- # If need a more detailed analysis, refer to: examples/train_and_bakctest.ipynb
- ###################################
- analysis = dict()
- analysis["excess_return_without_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"])
- analysis["excess_return_with_cost"] = risk_analysis(
- report_normal["return"] - report_normal["bench"] - report_normal["cost"]
- )
- analysis_df = pd.concat(analysis) # type: pd.DataFrame
- print(analysis_df)
diff --git a/examples/workflow_by_code_tabnet.py b/examples/workflow_by_code_tabnet.py
deleted file mode 100644
index 3778b9d59..000000000
--- a/examples/workflow_by_code_tabnet.py
+++ /dev/null
@@ -1,142 +0,0 @@
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
-
-import sys
-from pathlib import Path
-
-import qlib
-import pandas as pd
-from qlib.config import REG_CN
-from qlib.contrib.model.tabnet import TabNetModel
-from qlib.contrib.data.handler import ALPHA360_Denoise
-from qlib.contrib.strategy.strategy import TopkDropoutStrategy
-from qlib.contrib.evaluate import (
- backtest as normal_backtest,
- risk_analysis,
-)
-from qlib.utils import exists_qlib_data
-
-# from qlib.model.learner import train_model
-from qlib.utils import init_instance_by_config
-
-import pickle
-
-if __name__ == "__main__":
-
- # use default data
- provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
- if not exists_qlib_data(provider_uri):
- print(f"Qlib data is not found in {provider_uri}")
- sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
- from get_data import GetData
-
- GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
-
- qlib.init(provider_uri=provider_uri, region=REG_CN)
-
- MARKET = "csi300"
- BENCHMARK = "SH000300"
-
- ###################################
- # train model
- ###################################
- DATA_HANDLER_CONFIG = {
- "start_time": "2008-01-01",
- "end_time": "2020-08-01",
- "fit_start_time": "2008-01-01",
- "fit_end_time": "2014-12-31",
- "instruments": MARKET,
- }
-
- TRAINER_CONFIG = {
- "train_start_time": "2008-01-01",
- "train_end_time": "2014-12-31",
- "validate_start_time": "2015-01-01",
- "validate_end_time": "2016-12-31",
- "test_start_time": "2017-01-01",
- "test_end_time": "2020-08-01",
- }
-
- task = {
- "model": {
- "class": "TabNetModel",
- "module_path": "qlib.contrib.model.tabnet",
- "kwargs": {
- "n_d": 8,
- "n_a": 8,
- "n_steps": 3,
- "gamma": 1.3,
- "n_independent": 2,
- "n_shared": 2,
- "seed": 0,
- "momentum": 0.02,
- "lambda_sparse": 1e-3,
- "optimizer_params": {"lr": 2e-3},
- },
- },
- "dataset": {
- "class": "DatasetH",
- "module_path": "qlib.data.dataset",
- "kwargs": {
- "handler": {
- "class": "ALPHA360_Denoise",
- "module_path": "qlib.contrib.data.handler",
- "kwargs": DATA_HANDLER_CONFIG,
- },
- "segments": {
- "train": ("2008-01-01", "2014-12-31"),
- "valid": ("2015-01-01", "2016-12-31"),
- "test": ("2017-01-01", "2020-08-01"),
- },
- },
- }
- # You shoud record the data in specific sequence
- # "record": ['SignalRecord', 'SigAnaRecord', 'PortAnaRecord'],
- }
-
- # model = train_model(task)
- model = init_instance_by_config(task["model"])
- dataset = init_instance_by_config(task["dataset"])
- model.fit(dataset)
-
- pred_score = model.predict(dataset)
-
- # save pred_score to file
- pred_score_path = Path("~/tmp/qlib/pred_score.pkl").expanduser()
- pred_score_path.parent.mkdir(exist_ok=True, parents=True)
- pred_score.to_pickle(pred_score_path)
-
- ###################################
- # backtest
- ###################################
- STRATEGY_CONFIG = {
- "topk": 50,
- "n_drop": 5,
- }
- BACKTEST_CONFIG = {
- "verbose": False,
- "limit_threshold": 0.095,
- "account": 100000000,
- "benchmark": BENCHMARK,
- "deal_price": "close",
- "open_cost": 0.0005,
- "close_cost": 0.0015,
- "min_cost": 5,
- }
-
- # use default strategy
- # custom Strategy, refer to: TODO: Strategy API url
- strategy = TopkDropoutStrategy(**STRATEGY_CONFIG)
- report_normal, positions_normal = normal_backtest(pred_score, strategy=strategy, **BACKTEST_CONFIG)
-
- ###################################
- # analyze
- # If need a more detailed analysis, refer to: examples/train_and_bakctest.ipynb
- ###################################
- analysis = dict()
- analysis["excess_return_without_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"])
- analysis["excess_return_with_cost"] = risk_analysis(
- report_normal["return"] - report_normal["bench"] - report_normal["cost"]
- )
- analysis_df = pd.concat(analysis) # type: pd.DataFrame
- print(analysis_df)
diff --git a/qlib/contrib/model/catboost_model.py b/qlib/contrib/model/catboost_model.py
index bba006c35..01830d1b5 100644
--- a/qlib/contrib/model/catboost_model.py
+++ b/qlib/contrib/model/catboost_model.py
@@ -41,7 +41,9 @@ class CatBoostModel(Model):
**kwargs
):
df_train, df_valid = dataset.prepare(
- ["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
+ ["train", "valid"],
+ col_set=["feature", "label"],
+ data_key=DataHandlerLP.DK_L,
)
x_train, y_train = df_train["feature"], df_train["label"]
x_valid, y_valid = df_valid["feature"], df_valid["label"]
diff --git a/qlib/contrib/model/pytorch_alstm.py b/qlib/contrib/model/pytorch_alstm.py
index 1b23d2401..40c2f8226 100644
--- a/qlib/contrib/model/pytorch_alstm.py
+++ b/qlib/contrib/model/pytorch_alstm.py
@@ -11,7 +11,12 @@ import pandas as pd
import copy
from sklearn.metrics import roc_auc_score, mean_squared_error
import logging
-from ...utils import unpack_archive_with_buffer, save_multiple_parts_file, create_save_path, drop_nan_by_y_index
+from ...utils import (
+ unpack_archive_with_buffer,
+ save_multiple_parts_file,
+ create_save_path,
+ drop_nan_by_y_index,
+)
from ...log import get_module_logger, TimeInspector
import torch
@@ -109,7 +114,10 @@ class ALSTM(Model):
)
self.ALSTM_model = ALSTMModel(
- d_feat=self.d_feat, hidden_size=self.hidden_size, num_layers=self.num_layers, dropout=self.dropout
+ d_feat=self.d_feat,
+ hidden_size=self.hidden_size,
+ num_layers=self.num_layers,
+ dropout=self.dropout,
)
if optimizer.lower() == "adam":
self.train_optimizer = optim.Adam(self.ALSTM_model.parameters(), lr=self.lr)
@@ -141,7 +149,7 @@ class ALSTM(Model):
mask = torch.isfinite(label)
- if self.metric == "" or self.metric == "loss": # use loss
+ if self.metric == "" or self.metric == "loss":
return -self.loss_fn(pred[mask], label[mask])
raise ValueError("unknown metric `%s`" % self.metric)
@@ -219,7 +227,9 @@ class ALSTM(Model):
):
df_train, df_valid, df_test = dataset.prepare(
- ["train", "valid", "test"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
+ ["train", "valid", "test"],
+ col_set=["feature", "label"],
+ data_key=DataHandlerLP.DK_L,
)
x_train, y_train = df_train["feature"], df_train["label"]
@@ -328,10 +338,16 @@ class ALSTMModel(nn.Module):
)
self.fc_out = nn.Linear(in_features=self.hid_size * 2, out_features=1)
self.att_net = nn.Sequential()
- self.att_net.add_module("att_fc_in", nn.Linear(in_features=self.hid_size, out_features=int(self.hid_size / 2)))
+ self.att_net.add_module(
+ "att_fc_in",
+ nn.Linear(in_features=self.hid_size, out_features=int(self.hid_size / 2)),
+ )
self.att_net.add_module("att_dropout", torch.nn.Dropout(self.dropout))
self.att_net.add_module("att_act", nn.Tanh())
- self.att_net.add_module("att_fc_out", nn.Linear(in_features=int(self.hid_size / 2), out_features=1, bias=False))
+ self.att_net.add_module(
+ "att_fc_out",
+ nn.Linear(in_features=int(self.hid_size / 2), out_features=1, bias=False),
+ )
self.att_net.add_module("att_softmax", nn.Softmax(dim=1))
def forward(self, inputs):
diff --git a/qlib/contrib/model/pytorch_gats.py b/qlib/contrib/model/pytorch_gats.py
old mode 100755
new mode 100644
index 77a02a9b2..e9cbcf9cb
--- a/qlib/contrib/model/pytorch_gats.py
+++ b/qlib/contrib/model/pytorch_gats.py
@@ -12,6 +12,7 @@ import copy
from ...utils import create_save_path
from ...log import get_module_logger
+
import torch
import torch.nn as nn
import torch.optim as optim
@@ -19,10 +20,12 @@ import torch.optim as optim
from ...model.base import Model
from ...data.dataset import DatasetH
from ...data.dataset.handler import DataHandlerLP
+from ...contrib.model.pytorch_lstm import LSTMModel
+from ...contrib.model.pytorch_gru import GRUModel
-class GAT(Model):
- """GAT Model
+class GATs(Model):
+ """GATs Model
Parameters
----------
@@ -57,8 +60,8 @@ class GAT(Model):
**kwargs
):
# Set logger.
- self.logger = get_module_logger("GAT")
- self.logger.info("GAT pytorch version...")
+ self.logger = get_module_logger("GATs")
+ self.logger.info("GATs pytorch version...")
# set hyper-parameters.
self.d_feat = d_feat
@@ -78,7 +81,7 @@ class GAT(Model):
self.seed = seed
self.logger.info(
- "GAT parameters setting:"
+ "GATs parameters setting:"
"\nd_feat : {}"
"\nhidden_size : {}"
"\nnum_layers : {}"
@@ -149,18 +152,18 @@ class GAT(Model):
mask = torch.isfinite(label)
- if self.metric == "" or self.metric == "loss": # use loss
+ if self.metric == "" or self.metric == "loss":
return -self.loss_fn(pred[mask], label[mask])
raise ValueError("unknown metric `%s`" % self.metric)
def get_daily_inter(self, df, shuffle=False):
- # organize the train data into daily inter as daily batches
+ # organize the train data into daily batches
daily_count = df.groupby(level=0).size().values
daily_index = np.roll(np.cumsum(daily_count), 1)
daily_index[0] = 0
if shuffle:
- # shuffle the daily inter data
+ # shuffle data
daily_shuffle = list(zip(daily_index, daily_count))
np.random.shuffle(daily_shuffle)
daily_index, daily_count = zip(*daily_shuffle)
@@ -172,7 +175,7 @@ class GAT(Model):
y_train_values = np.squeeze(y_train.values)
self.GAT_model.train()
- # organize the train data into daily inter as daily batches
+ # organize the train data into daily batches
daily_index, daily_count = self.get_daily_inter(x_train, shuffle=True)
for idx, count in zip(daily_index, daily_count):
@@ -203,7 +206,7 @@ class GAT(Model):
scores = []
losses = []
- # organize the test data into daily inter as daily batches
+ # organize the test data into daily batches
daily_index, daily_count = self.get_daily_inter(data_x, shuffle=False)
for idx, count in zip(daily_index, daily_count):
@@ -233,7 +236,9 @@ class GAT(Model):
):
df_train, df_valid, df_test = dataset.prepare(
- ["train", "valid", "test"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
+ ["train", "valid", "test"],
+ col_set=["feature", "label"],
+ data_key=DataHandlerLP.DK_L,
)
x_train, y_train = df_train["feature"], df_train["label"]
@@ -251,15 +256,13 @@ class GAT(Model):
if self.with_pretrain:
self.logger.info("Loading pretrained model...")
if self.base_model == "LSTM":
- from ...contrib.model.pytorch_lstm import LSTMModel
-
pretrained_model = LSTMModel()
pretrained_model.load_state_dict(torch.load("benchmarks/LSTM/model_lstm_csi300.pkl"))
- elif self.base_model == "GRU":
- from ...contrib.model.pytorch_gru import GRUModel
+ elif self.base_model == "GRU":
pretrained_model = GRUModel()
pretrained_model.load_state_dict(torch.load("benchmarks/GRU/model_gru_csi300.pkl"))
+
model_dict = self.GAT_model.state_dict()
pretrained_dict = {k: v for k, v in pretrained_model.state_dict().items() if k in model_dict}
model_dict.update(pretrained_dict)
@@ -269,7 +272,6 @@ class GAT(Model):
# train
self.logger.info("training...")
self._fitted = True
- # return
for step in range(self.n_epochs):
self.logger.info("Epoch%d:", step)
@@ -310,7 +312,7 @@ class GAT(Model):
x_values = x_test.values
preds = []
- # organize the data into daily inter as daily batches
+ # organize the data into daily batches
daily_index, daily_count = self.get_daily_inter(x_test, shuffle=False)
for idx, count in zip(daily_index, daily_count):
@@ -355,22 +357,29 @@ class GATModel(nn.Module):
raise ValueError("unknown base model name `%s`" % base_model)
self.hidden_size = hidden_size
- self.bn1 = nn.BatchNorm1d(num_features=hidden_size, track_running_stats=False)
- self.fc = nn.Linear(hidden_size, hidden_size)
- self.bn2 = nn.BatchNorm1d(num_features=hidden_size, track_running_stats=False)
+ self.d_feat = d_feat
+ self.transformation = nn.Linear(self.hidden_size, self.hidden_size)
+ self.a = nn.Parameter(torch.randn(self.hidden_size * 2, 1))
+ self.a.requires_grad = True
+ self.fc = nn.Linear(self.hidden_size, self.hidden_size)
self.fc_out = nn.Linear(hidden_size, 1)
self.leaky_relu = nn.LeakyReLU()
self.softmax = nn.Softmax(dim=1)
- self.d_feat = d_feat
- def cal_convariance(self, x, y): # the 2nd dimension of x and y are the same
- e_x = torch.mean(x, dim=1).reshape(-1, 1)
- e_y = torch.mean(y, dim=1).reshape(-1, 1)
- e_x_e_y = e_x.mm(torch.t(e_y))
- x_extend = x.reshape(x.shape[0], 1, x.shape[1]).repeat(1, y.shape[0], 1)
- y_extend = y.reshape(1, y.shape[0], y.shape[1]).repeat(x.shape[0], 1, 1)
- e_xy = torch.mean(x_extend * y_extend, dim=2)
- return e_xy - e_x_e_y
+ def cal_attention(self, x, y):
+ x = self.transformation(x)
+ y = self.transformation(y)
+
+ sample_num = x.shape[0]
+ dim = x.shape[1]
+ e_x = x.expand(sample_num, sample_num, dim)
+ e_y = torch.transpose(e_x, 0, 1)
+ attention_in = torch.cat((e_x, e_y), 2).view(-1, dim * 2)
+ self.a_t = torch.t(self.a)
+ attention_out = self.a_t.mm(torch.t(attention_in)).view(sample_num, sample_num)
+ attention_out = self.leaky_relu(attention_out)
+ att_weight = self.softmax(attention_out)
+ return att_weight
def forward(self, x):
# x: [N, F*T]
@@ -378,10 +387,8 @@ class GATModel(nn.Module):
x = x.permute(0, 2, 1) # [N, T, F]
out, _ = self.rnn(x)
hidden = out[:, -1, :]
- hidden = self.bn1(hidden)
- gamma = self.cal_convariance(hidden, hidden)
- output = gamma.mm(hidden)
- output = self.fc(output)
- output = self.bn2(output)
- output = self.leaky_relu(output)
- return self.fc_out(output).squeeze()
+ att_weight = self.cal_attention(hidden, hidden)
+ hidden = att_weight.mm(hidden) + hidden
+ hidden = self.fc(hidden)
+ hidden = self.leaky_relu(hidden)
+ return self.fc_out(hidden).squeeze()
diff --git a/qlib/contrib/model/pytorch_gru.py b/qlib/contrib/model/pytorch_gru.py
index 02664b6ac..5daf4707e 100755
--- a/qlib/contrib/model/pytorch_gru.py
+++ b/qlib/contrib/model/pytorch_gru.py
@@ -11,7 +11,12 @@ import pandas as pd
import copy
from sklearn.metrics import roc_auc_score, mean_squared_error
import logging
-from ...utils import unpack_archive_with_buffer, save_multiple_parts_file, create_save_path, drop_nan_by_y_index
+from ...utils import (
+ unpack_archive_with_buffer,
+ save_multiple_parts_file,
+ create_save_path,
+ drop_nan_by_y_index,
+)
from ...log import get_module_logger, TimeInspector
import torch
@@ -109,7 +114,10 @@ class GRU(Model):
)
self.gru_model = GRUModel(
- d_feat=self.d_feat, hidden_size=self.hidden_size, num_layers=self.num_layers, dropout=self.dropout
+ d_feat=self.d_feat,
+ hidden_size=self.hidden_size,
+ num_layers=self.num_layers,
+ dropout=self.dropout,
)
if optimizer.lower() == "adam":
self.train_optimizer = optim.Adam(self.gru_model.parameters(), lr=self.lr)
@@ -141,7 +149,7 @@ class GRU(Model):
mask = torch.isfinite(label)
- if self.metric == "" or self.metric == "loss": # use loss
+ if self.metric == "" or self.metric == "loss":
return -self.loss_fn(pred[mask], label[mask])
raise ValueError("unknown metric `%s`" % self.metric)
@@ -219,7 +227,9 @@ class GRU(Model):
):
df_train, df_valid, df_test = dataset.prepare(
- ["train", "valid", "test"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
+ ["train", "valid", "test"],
+ col_set=["feature", "label"],
+ data_key=DataHandlerLP.DK_L,
)
x_train, y_train = df_train["feature"], df_train["label"]
diff --git a/qlib/contrib/model/pytorch_hats.py b/qlib/contrib/model/pytorch_hats.py
deleted file mode 100644
index 7affea73c..000000000
--- a/qlib/contrib/model/pytorch_hats.py
+++ /dev/null
@@ -1,491 +0,0 @@
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-from __future__ import division
-from __future__ import print_function
-
-import os
-import numpy as np
-import pandas as pd
-import copy
-from ...utils import create_save_path
-from ...log import get_module_logger
-
-import torch
-import torch.nn as nn
-import torch.optim as optim
-
-from ...model.base import Model
-from ...data.dataset import DatasetH
-from ...data.dataset.handler import DataHandlerLP
-
-
-class HATS(Model):
- """HATS Model
-
- Parameters
- ----------
- d_feat : int
- input dimension for each time step
- metric: str
- the evaluate metric used in early stop
- optimizer : str
- optimizer name
- GPU : str
- the GPU ID(s) used for training
- """
-
- def __init__(
- self,
- d_feat=6,
- hidden_size=64,
- num_layers=2,
- dropout=0.5,
- n_epochs=200,
- lr=0.01,
- metric="",
- early_stop=20,
- loss="mse",
- base_model="GRU",
- with_pretrain=True,
- optimizer="adam",
- GPU="0",
- seed=0,
- **kwargs
- ):
- # Set logger.
- self.logger = get_module_logger("HATS")
- self.logger.info("HATS pytorch version...")
-
- # set hyper-parameters.
- self.d_feat = d_feat
- self.hidden_size = hidden_size
- self.num_layers = num_layers
- self.dropout = dropout
- self.n_epochs = n_epochs
- self.lr = lr
- self.metric = metric
- self.early_stop = early_stop
- self.optimizer = optimizer.lower()
- self.loss = loss
- self.base_model = base_model
- self.with_pretrain = with_pretrain
- self.visible_GPU = GPU
- self.use_gpu = torch.cuda.is_available()
- self.seed = seed
-
- self.logger.info(
- "HATS parameters setting:"
- "\nd_feat : {}"
- "\nhidden_size : {}"
- "\nnum_layers : {}"
- "\ndropout : {}"
- "\nn_epochs : {}"
- "\nlr : {}"
- "\nmetric : {}"
- "\nearly_stop : {}"
- "\noptimizer : {}"
- "\nloss_type : {}"
- "\nbase_model : {}"
- "\nwith_pretrain : {}"
- "\nvisible_GPU : {}"
- "\nuse_GPU : {}"
- "\nseed : {}".format(
- d_feat,
- hidden_size,
- num_layers,
- dropout,
- n_epochs,
- lr,
- metric,
- early_stop,
- optimizer.lower(),
- loss,
- base_model,
- with_pretrain,
- GPU,
- self.use_gpu,
- seed,
- )
- )
-
- self.HATS_model = HATSModel(
- d_feat=self.d_feat,
- hidden_size=self.hidden_size,
- num_layers=self.num_layers,
- dropout=self.dropout,
- base_model=self.base_model,
- )
- if optimizer.lower() == "adam":
- self.train_optimizer = optim.Adam(self.HATS_model.parameters(), lr=self.lr)
- elif optimizer.lower() == "gd":
- self.train_optimizer = optim.SGD(self.HATS_model.parameters(), lr=self.lr)
- else:
- raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
-
- self._fitted = False
- if self.use_gpu:
- self.HATS_model.cuda()
- # set the visible GPU
- if self.visible_GPU:
- os.environ["CUDA_VISIBLE_DEVICES"] = self.visible_GPU
-
- def mse(self, pred, label):
- loss = (pred - label) ** 2
- return torch.mean(loss)
-
- def loss_fn(self, pred, label):
- mask = ~torch.isnan(label)
-
- if self.loss == "mse":
- return self.mse(pred[mask], label[mask])
-
- raise ValueError("unknown loss `%s`" % self.loss)
-
- def metric_fn(self, pred, label):
- mask = torch.isfinite(label)
-
- if self.metric == "" or self.metric == "loss": # use loss
- return -self.loss_fn(pred[mask], label[mask])
-
- raise ValueError("unknown metric `%s`" % self.metric)
-
- def get_daily_inter(self, df, shuffle=False):
- # organize the train data into daily inter as daily batches
- daily_count = df.groupby(level=0).size().values
- daily_index = np.roll(np.cumsum(daily_count), 1)
- daily_index[0] = 0
- if shuffle:
- # shuffle the daily inter data
- daily_shuffle = list(zip(daily_index, daily_count))
- np.random.shuffle(daily_shuffle)
- daily_index, daily_count = zip(*daily_shuffle)
- return daily_index, daily_count
-
- def train_epoch(self, x_train, y_train):
-
- x_train_values = x_train.values
- y_train_values = np.squeeze(y_train.values)
-
- self.HATS_model.train()
-
- # organize the train data into daily inter as daily batches
- daily_index, daily_count = self.get_daily_inter(x_train, shuffle=True)
-
- for idx, count in zip(daily_index, daily_count):
- batch = slice(idx, idx + count)
- feature = torch.from_numpy(x_train_values[batch]).float()
- label = torch.from_numpy(y_train_values[batch]).float()
-
- if self.use_gpu:
- feature = feature.cuda()
- label = label.cuda()
-
- pred = self.HATS_model(feature)
- loss = self.loss_fn(pred, label)
-
- self.train_optimizer.zero_grad()
- loss.backward()
- torch.nn.utils.clip_grad_value_(self.HATS_model.parameters(), 3.0)
- self.train_optimizer.step()
-
- def test_epoch(self, data_x, data_y):
-
- # prepare testing data
- x_values = data_x.values
- y_values = np.squeeze(data_y.values)
-
- self.HATS_model.eval()
-
- scores = []
- losses = []
-
- # organize the test data into daily inter as daily batches
- daily_index, daily_count = self.get_daily_inter(data_x, shuffle=False)
-
- for idx, count in zip(daily_index, daily_count):
- batch = slice(idx, idx + count)
- feature = torch.from_numpy(x_values[batch]).float()
- label = torch.from_numpy(y_values[batch]).float()
-
- if self.use_gpu:
- feature = feature.cuda()
- label = label.cuda()
-
- pred = self.HATS_model(feature)
- loss = self.loss_fn(pred, label)
- losses.append(loss.item())
-
- score = self.metric_fn(pred, label)
- scores.append(score.item())
-
- return np.mean(losses), np.mean(scores)
-
- def fit(
- self,
- dataset: DatasetH,
- evals_result=dict(),
- verbose=True,
- save_path=None,
- ):
-
- df_train, df_valid, df_test = dataset.prepare(
- ["train", "valid", "test"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
- )
-
- x_train, y_train = df_train["feature"], df_train["label"]
- x_valid, y_valid = df_valid["feature"], df_valid["label"]
-
- if save_path == None:
- save_path = create_save_path(save_path)
- stop_steps = 0
- best_score = -np.inf
- best_epoch = 0
- evals_result["train"] = []
- evals_result["valid"] = []
-
- # load pretrained base_model
- if self.with_pretrain:
- self.logger.info("Loading pretrained model...")
- if self.base_model == "LSTM":
- from ...contrib.model.pytorch_lstm import LSTMModel
-
- pretrained_model = LSTMModel()
- pretrained_model.load_state_dict(torch.load("benchmarks/LSTM/model_lstm_csi300.pkl"))
- elif self.base_model == "GRU":
- from ...contrib.model.pytorch_gru import GRUModel
-
- pretrained_model = GRUModel()
- pretrained_model.load_state_dict(torch.load("benchmarks/GRU/model_gru_csi300.pkl"))
- model_dict = self.HATS_model.state_dict()
- pretrained_dict = {k: v for k, v in pretrained_model.state_dict().items() if k in model_dict}
- model_dict.update(pretrained_dict)
- self.HATS_model.load_state_dict(model_dict)
- self.logger.info("Loading pretrained model Done...")
-
- # train
- self.logger.info("training...")
- self._fitted = True
-
- for step in range(self.n_epochs):
- self.logger.info("Epoch%d:", step)
- self.logger.info("training...")
- self.train_epoch(x_train, y_train)
- self.logger.info("evaluating...")
- train_loss, train_score = self.test_epoch(x_train, y_train)
- val_loss, val_score = self.test_epoch(x_valid, y_valid)
- self.logger.info("train %.6f, valid %.6f" % (train_score, val_score))
- evals_result["train"].append(train_score)
- evals_result["valid"].append(val_score)
-
- if val_score > best_score:
- best_score = val_score
- stop_steps = 0
- best_epoch = step
- best_param = copy.deepcopy(self.HATS_model.state_dict())
- else:
- stop_steps += 1
- if stop_steps >= self.early_stop:
- self.logger.info("early stop")
- break
-
- self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
- self.HATS_model.load_state_dict(best_param)
- torch.save(best_param, save_path)
-
- if self.use_gpu:
- torch.cuda.empty_cache()
-
- def predict(self, dataset):
- if not self._fitted:
- raise ValueError("model is not fitted yet!")
-
- x_test = dataset.prepare("test", col_set="feature")
- index = x_test.index
- self.HATS_model.eval()
- x_values = x_test.values
- sample_num = x_values.shape[0]
- preds = []
-
- # organize the data into daily inter as daily batches
- daily_index, daily_count = self.get_daily_inter(x_test, shuffle=False)
-
- for idx, count in zip(daily_index, daily_count):
- batch = slice(idx, idx + count)
- x_batch = torch.from_numpy(x_values[batch]).float()
-
- if self.use_gpu:
- x_batch = x_batch.cuda()
-
- with torch.no_grad():
- if self.use_gpu:
- pred = self.HATS_model(x_batch).detach().cpu().numpy()
- else:
- pred = self.HATS_model(x_batch).detach().numpy()
-
- preds.append(pred)
-
- return pd.Series(np.concatenate(preds), index=index)
-
-
-class HATSModel(nn.Module):
- def __init__(self, d_feat=6, hidden_size=64, num_layers=2, dropout=0.0, base_model="GRU"):
- super().__init__()
-
- if base_model == "GRU":
- self.model = nn.GRU(
- input_size=d_feat,
- hidden_size=hidden_size,
- num_layers=num_layers,
- batch_first=True,
- dropout=dropout,
- )
- elif base_model == "LSTM":
- self.model = nn.LSTM(
- input_size=d_feat,
- hidden_size=hidden_size,
- num_layers=num_layers,
- batch_first=True,
- dropout=dropout,
- )
- else:
- raise ValueError("unknown base model name `%s`" % base_model)
-
- self.hidden_size = hidden_size
- self.bn1 = nn.BatchNorm1d(num_features=hidden_size, track_running_stats=False)
- self.fc = nn.Linear(hidden_size, hidden_size)
- self.bn2 = nn.BatchNorm1d(num_features=hidden_size, track_running_stats=False)
- self.fc_out = nn.Linear(hidden_size, 1)
- self.leaky_relu = nn.LeakyReLU()
- self.softmax = nn.Softmax(dim=1)
- self.d_feat = d_feat
-
- num_head_att = [1] * num_layers
- hidden_dim = [hidden_size] * num_layers
- dims = [d_feat] + [d * nh for (d, nh) in zip(hidden_dim, num_head_att[:-1])] + [num_head_att[-1]]
- in_dims = dims[:-1]
- out_dims = [d // nh for (d, nh) in zip(dims[1:], num_head_att)]
- self.attn = nn.ModuleList(
- [GraphAttention(i, o, nh, dropout) for (i, o, nh) in zip(in_dims, out_dims, num_head_att)]
- )
- self.bns = nn.ModuleList([nn.BatchNorm1d(dim) for dim in dims[1:-1]])
- self.dropout = nn.Dropout(dropout)
- self.elu = nn.ELU()
-
- def forward(self, x):
- x = x.reshape(len(x), self.d_feat, -1) # [N, F, T]
- x = x.permute(0, 2, 1) # [N, T, F]
- out, _ = self.model(x)
- hidden = out[:, -1, :]
- hidden = self.bn1(hidden)
- attention = GraphAttention.cal_attention(hidden, hidden)
- output = attention.mm(hidden)
- output = self.fc(output)
- output = self.bn2(output)
- output = self.leaky_relu(output)
- return self.fc_out(output).squeeze()
-
-
-class GraphAttention(nn.Module):
- def __init__(self, input_dim, output_dim, num_heads, dropout=0.5):
-
- super().__init__()
-
- """
- Parameters
- ----------
- input_dim : int
- Dimension of input node features.
- output_dim : int
- Dimension of output node features.
- num_heads : list of ints
- Number of attention heads in each hidden layer and output layer. Must be non empty. Note that len(num_heads) = len(hidden_dims)+1.
- dropout : float
- Dropout rate. Default: 0.5.
- """
-
- self.input_dim = input_dim
- self.output_dim = output_dim
- self.num_heads = num_heads
-
- self.fcs = nn.ModuleList([nn.Linear(input_dim, output_dim) for _ in range(num_heads)])
- self.a = nn.ModuleList([nn.Linear(2 * output_dim, 1) for _ in range(num_heads)])
-
- self.dropout = nn.Dropout(dropout)
- self.softmax = nn.Softmax(dim=0)
- self.leakyrelu = nn.LeakyReLU()
-
- def forward(self, features, nodes, mappings, rows):
-
- """
- Parameters
- ----------
- features : torch.Tensor
- An (n' x input_dim) tensor of input node features.
- nodes : list of numpy array
- nodes[i] is an array of the nodes in the ith layer of the
- computation graph.
- mappings : list of dictionary
- mappings[i] is a dictionary mappings node v (labelled 0 to |V|-1)
- in nodes[i] to its position in nodes[i]. For example,
- if nodes[i] = [2,5], then mappings[i][2] = 0 and
- mappings[i][5] = 1.
- rows : numpy array
- rows[i] is an array of neighbors of node i.
- Returns
- -------
- out : torch.Tensor
- An (len(node_layers[-1]) x output_dim) tensor of output node features.
- """
-
- nprime = features.shape[0]
- rows = [np.array([mappings[v] for v in row], dtype=np.int64) for row in rows]
- sum_degs = np.hstack(([0], np.cumsum([len(row) for row in rows])))
- mapped_nodes = [mappings[v] for v in nodes]
- indices = torch.LongTensor([[v, c] for (v, row) in zip(mapped_nodes, rows) for c in row]).t()
-
- out = []
- for k in range(self.num_heads):
- h = self.fcs[k](features)
-
- nbr_h = torch.cat(tuple([h[row] for row in rows]), dim=0)
- self_h = torch.cat(
- tuple([h[mappings[nodes[i]]].repeat(len(row), 1) for (i, row) in enumerate(rows)]), dim=0
- )
- cat_h = torch.cat((self_h, nbr_h), dim=1)
-
- e = self.leakyrelu(self.a[k](cat_h))
-
- alpha = [self.softmax(e[lo:hi]) for (lo, hi) in zip(sum_degs, sum_degs[1:])]
- alpha = torch.cat(tuple(alpha), dim=0)
- alpha = alpha.squeeze(1)
- alpha = self.dropout(alpha)
-
- adj = torch.sparse.FloatTensor(indices, alpha, torch.Size([nprime, nprime]))
- out.append(torch.sparse.mm(adj, h)[mapped_nodes])
-
- return out
-
- @staticmethod
- def cal_attention(x, y):
- att_x = torch.mean(x, dim=1).reshape(-1, 1)
- att_y = torch.mean(y, dim=1).reshape(-1, 1)
- att = att_x.mm(torch.t(att_y))
- return (
- torch.mean(
- x.reshape(x.shape[0], 1, x.shape[1]).repeat(1, y.shape[0], 1)
- * y.reshape(1, y.shape[0], y.shape[1]).repeat(x.shape[0], 1, 1),
- dim=2,
- )
- - att
- )
diff --git a/qlib/contrib/model/pytorch_lstm.py b/qlib/contrib/model/pytorch_lstm.py
index f8951509a..eef1680ec 100755
--- a/qlib/contrib/model/pytorch_lstm.py
+++ b/qlib/contrib/model/pytorch_lstm.py
@@ -11,7 +11,12 @@ import pandas as pd
import copy
from sklearn.metrics import roc_auc_score, mean_squared_error
import logging
-from ...utils import unpack_archive_with_buffer, save_multiple_parts_file, create_save_path, drop_nan_by_y_index
+from ...utils import (
+ unpack_archive_with_buffer,
+ save_multiple_parts_file,
+ create_save_path,
+ drop_nan_by_y_index,
+)
from ...log import get_module_logger, TimeInspector
import torch
@@ -109,7 +114,10 @@ class LSTM(Model):
)
self.lstm_model = LSTMModel(
- d_feat=self.d_feat, hidden_size=self.hidden_size, num_layers=self.num_layers, dropout=self.dropout
+ d_feat=self.d_feat,
+ hidden_size=self.hidden_size,
+ num_layers=self.num_layers,
+ dropout=self.dropout,
)
if optimizer.lower() == "adam":
self.train_optimizer = optim.Adam(self.lstm_model.parameters(), lr=self.lr)
@@ -141,7 +149,7 @@ class LSTM(Model):
mask = torch.isfinite(label)
- if self.metric == "" or self.metric == "loss": # use loss
+ if self.metric == "" or self.metric == "loss":
return -self.loss_fn(pred[mask], label[mask])
raise ValueError("unknown metric `%s`" % self.metric)
@@ -219,7 +227,9 @@ class LSTM(Model):
):
df_train, df_valid, df_test = dataset.prepare(
- ["train", "valid", "test"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
+ ["train", "valid", "test"],
+ col_set=["feature", "label"],
+ data_key=DataHandlerLP.DK_L,
)
x_train, y_train = df_train["feature"], df_train["label"]
diff --git a/qlib/contrib/model/pytorch_sfm.py b/qlib/contrib/model/pytorch_sfm.py
index 1d27f3927..8fddd1612 100644
--- a/qlib/contrib/model/pytorch_sfm.py
+++ b/qlib/contrib/model/pytorch_sfm.py
@@ -19,7 +19,12 @@ import pandas as pd
import copy
from sklearn.metrics import roc_auc_score, mean_squared_error
import logging
-from ...utils import unpack_archive_with_buffer, save_multiple_parts_file, create_save_path, drop_nan_by_y_index
+from ...utils import (
+ unpack_archive_with_buffer,
+ save_multiple_parts_file,
+ create_save_path,
+ drop_nan_by_y_index,
+)
from ...log import get_module_logger, TimeInspector
import torch
@@ -33,7 +38,16 @@ from ...data.dataset.handler import DataHandlerLP
class SFM_Model(nn.Module):
- def __init__(self, d_feat=6, output_dim=1, freq_dim=10, hidden_size=64, dropout_W=0.0, dropout_U=0.0, device="cpu"):
+ def __init__(
+ self,
+ d_feat=6,
+ output_dim=1,
+ freq_dim=10,
+ hidden_size=64,
+ dropout_W=0.0,
+ dropout_U=0.0,
+ device="cpu",
+ ):
super().__init__()
self.input_dim = d_feat
@@ -157,7 +171,16 @@ class SFM_Model(nn.Module):
init_state_time = torch.tensor(0).to(self.device)
- self.states = [init_state_p, init_state_h, init_state_S_re, init_state_S_im, init_state_time, None, None, None]
+ self.states = [
+ init_state_p,
+ init_state_h,
+ init_state_S_re,
+ init_state_S_im,
+ init_state_time,
+ None,
+ None,
+ None,
+ ]
def get_constants(self, x):
constants = []
@@ -352,7 +375,9 @@ class SFM(Model):
):
df_train, df_valid = dataset.prepare(
- ["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
+ ["train", "valid"],
+ col_set=["feature", "label"],
+ data_key=DataHandlerLP.DK_L,
)
x_train, y_train = df_train["feature"], df_train["label"]
x_valid, y_valid = df_valid["feature"], df_valid["label"]
@@ -409,7 +434,7 @@ class SFM(Model):
mask = torch.isfinite(label)
- if self.metric == "" or self.metric == "loss": # use loss
+ if self.metric == "" or self.metric == "loss":
return -self.loss_fn(pred[mask], label[mask])
raise ValueError("unknown metric `%s`" % self.metric)
diff --git a/qlib/contrib/model/tabnet.py b/qlib/contrib/model/tabnet.py
deleted file mode 100644
index bc13d1f62..000000000
--- a/qlib/contrib/model/tabnet.py
+++ /dev/null
@@ -1,85 +0,0 @@
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
-
-import numpy as np
-import pandas as pd
-from pytorch_tabnet.tab_model import TabNetRegressor
-
-from ...model.base import Model
-from ...data.dataset import DatasetH
-from ...data.dataset.handler import DataHandlerLP
-
-
-class TabNetModel(Model):
- """TabNetModel Model"""
-
- def __init__(
- self,
- n_d,
- n_a,
- n_steps,
- gamma,
- n_independent,
- n_shared,
- seed,
- momentum,
- lambda_sparse,
- optimizer_params,
- **kwargs
- ):
- self.model = None
-
- self.n_d = n_d
- self.n_a = n_a
- self.n_steps = n_steps
- self.gamma = gamma
- self.n_independent = n_independent
- self.n_shared = n_shared
- self.seed = seed
- self.momentum = momentum
- self.lambda_sparse = lambda_sparse
- self.optimizer_params = optimizer_params
-
- def fit(
- self,
- dataset: DatasetH,
- n_d=8,
- n_a=8,
- n_steps=3,
- gamma=1.3,
- n_independent=2,
- n_shared=2,
- seed=0,
- momentum=0.02,
- lambda_sparse=1e-3,
- optimizer_params={"lr": 2e-3},
- **kwargs
- ):
-
- df_train, df_valid = dataset.prepare(
- ["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
- )
- x_train, y_train = df_train["feature"].values, df_train["label"].values * 100
- x_valid, y_valid = df_valid["feature"].values, df_valid["label"].values * 100
-
- self.model = TabNetRegressor(
- n_d=self.n_d,
- n_a=self.n_a,
- n_steps=self.n_steps,
- gamma=self.gamma,
- n_independent=self.n_independent,
- n_shared=self.n_shared,
- seed=self.seed,
- momentum=self.momentum,
- lambda_sparse=self.lambda_sparse,
- optimizer_params=self.optimizer_params,
- **kwargs
- )
- self.model.fit(x_train, y_train, eval_set=[(x_valid, y_valid)])
-
- def predict(self, dataset):
- if self.model is None:
- raise ValueError("model is not fitted yet!")
- x_test = dataset.prepare("test", col_set="feature")
- test_pred = self.model.predict(x_test.values)
- return pd.Series(test_pred.reshape([-1]), index=x_test.index)
diff --git a/qlib/contrib/model/xgboost.py b/qlib/contrib/model/xgboost.py
index 039fd2c80..c9e45d4ac 100755
--- a/qlib/contrib/model/xgboost.py
+++ b/qlib/contrib/model/xgboost.py
@@ -38,7 +38,9 @@ class XGBModel(Model):
):
df_train, df_valid = dataset.prepare(
- ["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
+ ["train", "valid"],
+ col_set=["feature", "label"],
+ data_key=DataHandlerLP.DK_L,
)
x_train, y_train = df_train["feature"], df_train["label"]
x_valid, y_valid = df_valid["feature"], df_valid["label"]
diff --git a/qlib/contrib/report/graph.py b/qlib/contrib/report/graph.py
index 15cc5fd0e..3fa688d36 100644
--- a/qlib/contrib/report/graph.py
+++ b/qlib/contrib/report/graph.py
@@ -96,7 +96,19 @@ class BaseGraph(object):
"""
py.init_notebook_mode()
for _fig in figure_list:
- py.iplot(_fig)
+ # NOTE: displays figures: https://plotly.com/python/renderers/
+ # default: plotly_mimetype+notebook
+ # support renderers: import plotly.io as pio; print(pio.renderers)
+ renderer = None
+ try:
+ # in notebook
+ _ipykernel = str(type(get_ipython()))
+ if "google.colab" in _ipykernel:
+ renderer = "colab"
+ except NameError:
+ pass
+
+ _fig.show(renderer=renderer)
def _get_layout(self) -> go.Layout:
"""
diff --git a/qlib/model/trainer.py b/qlib/model/trainer.py
index e4fc8eef9..0ef062021 100644
--- a/qlib/model/trainer.py
+++ b/qlib/model/trainer.py
@@ -1,5 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
+
from qlib.utils import init_instance_by_config, flatten_dict
from qlib.workflow import R
from qlib.workflow.record_temp import SignalRecord
diff --git a/qlib/workflow/cli.py b/qlib/workflow/cli.py
index 08c13de2a..65d9a14b4 100644
--- a/qlib/workflow/cli.py
+++ b/qlib/workflow/cli.py
@@ -1,13 +1,14 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
-import sys
+import sys, os
from pathlib import Path
import qlib
import fire
import pandas as pd
import ruamel.yaml as yaml
+from qlib.config import C
from qlib.model.trainer import task_train
@@ -41,7 +42,7 @@ def sys_config(config, config_path):
# worflow handler function
-def workflow(config_path, experiment_name="workflow"):
+def workflow(config_path, experiment_name="workflow", uri_folder="mlruns"):
with open(config_path) as fp:
config = yaml.load(fp, Loader=yaml.Loader)
@@ -50,7 +51,9 @@ def workflow(config_path, experiment_name="workflow"):
provider_uri = config.get("provider_uri")
region = config.get("region")
- qlib.init(provider_uri=provider_uri, region=region)
+ exp_manager = C["exp_manager"]
+ exp_manager["kwargs"]["uri"] = "file:" + str(Path(os.getcwd()).resolve() / uri_folder)
+ qlib.init(provider_uri=provider_uri, region=region, exp_manager=exp_manager)
task_train(config, experiment_name=experiment_name)
diff --git a/qlib/workflow/expm.py b/qlib/workflow/expm.py
index 156beb690..80d471845 100644
--- a/qlib/workflow/expm.py
+++ b/qlib/workflow/expm.py
@@ -239,20 +239,17 @@ class MLflowExpManager(ExpManager):
return self._client
def start_exp(self, experiment_name=None, recorder_name=None, uri=None):
+ # set the tracking uri
+ if uri is None:
+ logger.info("No tracking URI is provided. Use the default tracking URI.")
+ else:
+ self.uri = uri
# create experiment
experiment, _ = self._get_or_create_exp(experiment_name=experiment_name)
# set up active experiment
self.active_experiment = experiment
# start the experiment
self.active_experiment.start(recorder_name)
- # set the tracking uri
- if uri is None:
- logger.info(
- "No tracking URI is provided. The default tracking URI is set as `mlruns` under the working directory."
- )
- else:
- self.uri = uri
- mlflow.set_tracking_uri(self.uri)
return self.active_experiment
diff --git a/qlib/workflow/recorder.py b/qlib/workflow/recorder.py
index b3069b9ac..4c1ddfdfe 100644
--- a/qlib/workflow/recorder.py
+++ b/qlib/workflow/recorder.py
@@ -224,6 +224,8 @@ class MLflowRecorder(Recorder):
)
def start_run(self):
+ # set the tracking uri
+ mlflow.set_tracking_uri(self._uri)
# start the run
run = mlflow.start_run(self.id, self.experiment_id, self.name)
# save the run id and artifact_uri
diff --git a/requirements.txt b/requirements.txt
index d3511d780..638ce22f4 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -22,5 +22,4 @@ scikit_learn==0.23.2
torch==1.6.0
tqdm==4.49.0
yahooquery==2.2.7
-mlflow==1.12.1
-pytorch-tabnet==2.0.1
\ No newline at end of file
+mlflow==1.12.1
\ No newline at end of file
diff --git a/scripts/data_collector/yahoo/collector.py b/scripts/data_collector/yahoo/collector.py
index 69c7f8f15..0d41251f1 100644
--- a/scripts/data_collector/yahoo/collector.py
+++ b/scripts/data_collector/yahoo/collector.py
@@ -44,6 +44,7 @@ class YahooCollector:
delay=0,
check_data_length: bool = False,
limit_nums: int = None,
+ show_1m_logging: bool = False,
):
"""
@@ -67,10 +68,13 @@ class YahooCollector:
check data length, by default False
limit_nums: int
using for debug, by default None
+ show_1m_logging: bool
+ show 1m logging, by default False; if True, there may be many warning logs
"""
self.save_dir = Path(save_dir).expanduser().resolve()
self.save_dir.mkdir(parents=True, exist_ok=True)
self._delay = delay
+ self._show_1m_logging = show_1m_logging
self.stock_list = sorted(set(self.get_stock_list()))
if limit_nums is not None:
try:
@@ -83,7 +87,7 @@ class YahooCollector:
self._interval = interval
self._check_small_data = check_data_length
self._start_datetime = pd.Timestamp(str(start)) if start else self.START_DATETIME
- self._end_datetime = pd.Timestamp(str(end)) if end else self.END_DATETIME
+ self._end_datetime = min(pd.Timestamp(str(end)) if end else self.END_DATETIME, self.END_DATETIME)
if self._interval == "1m":
self._start_datetime = max(self._start_datetime, self.HIGH_FREQ_START_DATETIME)
elif self._interval == "1d":
@@ -91,8 +95,12 @@ class YahooCollector:
else:
raise ValueError(f"interval error: {self._interval}")
+ # using for 1m
+ self._next_datetime = self.convert_datetime(self._start_datetime.date() + pd.Timedelta(days=1))
+ self._latest_datetime = self.convert_datetime(self._end_datetime.date())
+
self._start_datetime = self.convert_datetime(self._start_datetime)
- self._end_datetime = self.convert_datetime(min(self._end_datetime, self.END_DATETIME))
+ self._end_datetime = self.convert_datetime(self._end_datetime)
@property
@abc.abstractmethod
@@ -100,20 +108,24 @@ class YahooCollector:
# daily, one year: 252 / 4
# us 1min, a week: 6.5 * 60 * 5
# cn 1min, a week: 4 * 60 * 5
- raise NotImplementedError("rewirte min_numbers_trading")
+ raise NotImplementedError("rewrite min_numbers_trading")
@abc.abstractmethod
def get_stock_list(self):
- raise NotImplementedError("rewirte get_stock_list")
+ raise NotImplementedError("rewrite get_stock_list")
@property
- @abc.abstractclassmethod
+ @abc.abstractmethod
def _timezone(self):
raise NotImplementedError("rewrite get_timezone")
- def convert_datetime(self, dt: pd.Timestamp):
- dt = pd.Timestamp(dt, tz=self._timezone).timestamp()
- return pd.Timestamp(dt, tz=tzlocal(), unit="s")
+ def convert_datetime(self, dt: [pd.Timestamp, datetime.date, str]):
+ try:
+ dt = pd.Timestamp(dt, tz=self._timezone).timestamp()
+ dt = pd.Timestamp(dt, tz=tzlocal(), unit="s")
+ except ValueError as e:
+ pass
+ return dt
def _sleep(self):
time.sleep(self._delay)
@@ -136,7 +148,7 @@ class YahooCollector:
df["symbol"] = symbol
if stock_path.exists():
with stock_path.open("a") as fp:
- df.to_csv(fp, index=False, header=None)
+ df.to_csv(fp, index=False, header=False)
else:
with stock_path.open("w") as fp:
df.to_csv(fp, index=False)
@@ -155,34 +167,47 @@ class YahooCollector:
def _get_from_remote(self, symbol):
def _get_simple(start_, end_):
self._sleep()
+ error_msg = f"{symbol}-{self._interval}-{start_}-{end_}"
+
+ def _show_logging_func():
+ if self._interval == "1m" and self._show_1m_logging:
+ logger.warning(f"{error_msg}:{_resp}")
+
try:
_resp = Ticker(symbol, asynchronous=False).history(interval=self._interval, start=start_, end=end_)
if isinstance(_resp, pd.DataFrame):
return _resp.reset_index()
+ elif isinstance(_resp, dict):
+ _temp_data = _resp.get(symbol, {})
+ if isinstance(_temp_data, str) or (
+ isinstance(_resp, dict) and _temp_data.get("indicators", {}).get("quote", None) is None
+ ):
+ _show_logging_func()
else:
- logger.warning(f"{symbol}-{self._interval}-{start_}-{end_}:{_resp}")
+ _show_logging_func()
except Exception as e:
- logger.warning(f"{symbol}-{self._interval}-{start_}-{end_}:{e}")
+ logger.warning(f"{error_msg}:{e}")
_result = None
if self._interval == "1d":
_result = _get_simple(self._start_datetime, self._end_datetime)
elif self._interval == "1m":
- _start_date = self._start_datetime.date() + pd.Timedelta(days=1)
- _end_date = self._end_datetime.date()
- if _start_date >= _end_date:
+ if self._next_datetime >= self._latest_datetime:
_result = _get_simple(self._start_datetime, self._end_datetime)
else:
_res = []
def _get_multi(start_, end_):
_resp = _get_simple(start_, end_)
- if _resp is not None:
+ if _resp is not None and not _resp.empty:
_res.append(_resp)
- for _s, _e in ((self._start_datetime, _start_date), (_end_date, self._end_datetime)):
+ for _s, _e in (
+ (self._start_datetime, self._next_datetime),
+ (self._latest_datetime, self._end_datetime),
+ ):
_get_multi(_s, _e)
- for _start in pd.date_range(_start_date, _end_date, closed="left"):
+ for _start in pd.date_range(self._next_datetime, self._latest_datetime, closed="left"):
_end = _start + pd.Timedelta(days=1)
self._sleep()
_get_multi(_start, _end)
@@ -472,6 +497,7 @@ class Run:
interval="1d",
check_data_length=False,
limit_nums=None,
+ show_1m_logging=False,
):
"""download data from Internet
@@ -491,6 +517,9 @@ class Run:
check data length, by default False
limit_nums: int
using for debug, by default None
+ show_1m_logging: bool
+ show 1m logging, by default False; if True, there may be many warning logs
+
Examples
---------
# get daily data
@@ -510,6 +539,7 @@ class Run:
interval=interval,
check_data_length=check_data_length,
limit_nums=limit_nums,
+ show_1m_logging=show_1m_logging,
).collector_data()
def normalize_data(self):
@@ -531,6 +561,7 @@ class Run:
interval="1d",
check_data_length=False,
limit_nums=None,
+ show_1m_logging=False,
):
"""download -> normalize
@@ -550,6 +581,9 @@ class Run:
check data length, by default False
limit_nums: int
using for debug, by default None
+ show_1m_logging: bool
+ show 1m logging, by default False; if True, there may be many warning logs
+
Examples
-------
python collector.py collector_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d
@@ -562,6 +596,7 @@ class Run:
interval=interval,
check_data_length=check_data_length,
limit_nums=limit_nums,
+ show_1m_logging=show_1m_logging,
)
self.normalize_data()