mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-05 12:00:58 +08:00
Merge branch 'main' of github.com:you-n-g/qlib into main
This commit is contained in:
41
README.md
41
README.md
@@ -41,7 +41,7 @@ For more details, please refer to our paper ["Qlib: An AI-oriented Quantitative
|
||||
# Framework of Qlib
|
||||
|
||||
<div style="align: center">
|
||||
<img src="http://fintech.msra.cn/images_v060/framework.png" />
|
||||
<img src="http://fintech.msra.cn/images_v060/framework.png?v=0.1" />
|
||||
</div>
|
||||
|
||||
|
||||
@@ -192,24 +192,6 @@ The automatic workflow may not suite the research workflow of all Quant research
|
||||
|
||||
# [Quant Model Zoo](examples/benchmarks)
|
||||
|
||||
## Run a single model
|
||||
`Qlib` provides three different ways to run a single model, users can pick the one that fits their cases best:
|
||||
- User can use the tool `qrun` mentioned above to run a model's workflow based from a config file.
|
||||
- User can create a `workflow_by_code` python script based on the [one](examples/workflow_by_code.py) listed in the `examples` folder.
|
||||
- User can use the script [`run_all_model.py`](examples/run_all_model.py) listed in the `examples` folder to run a model. Here is an example of the specific shell command to be used: `python run_all_model.py --models=lightgbm`. For more use cases, please refer to the file's [docstrings](examples/run_all_model.py).
|
||||
|
||||
## Run multiple models
|
||||
`Qlib` also provides a script [`run_all_model.py`](examples/run_all_model.py) which can run multiple models for several iterations. (**Note**: the script only supprots *Linux* now. Other OS will be supported in the future.)
|
||||
|
||||
The script will create a unique virtual environment for each model, and delete the environments after training. Thus, only experiment results such as `IC` and `backtest` results will be generated and stored.
|
||||
|
||||
Here is an example of running all the models for 10 iterations:
|
||||
```python
|
||||
python run_all_model.py 10
|
||||
```
|
||||
|
||||
It also provides the API to run specific models at once. For more use cases, please refer to the file's [docstrings](examples/run_all_model.py).
|
||||
|
||||
Here is a list of models built on `Qlib`.
|
||||
- [GBDT based on LightGBM](qlib/contrib/model/gbdt.py)
|
||||
- [GBDT based on Catboost](qlib/contrib/model/catboost_model.py)
|
||||
@@ -219,13 +201,30 @@ Here is a list of models built on `Qlib`.
|
||||
- [LSTM based on pytorcn](qlib/contrib/model/pytorch_lstm.py)
|
||||
- [ALSTM based on pytorcn](qlib/contrib/model/pytorch_alstm.py)
|
||||
- [GATs based on pytorch](qlib/contrib/model/pytorch_gats.py)
|
||||
- [TabNet based on pytorch](qlib/contrib/model/tabnet.py)
|
||||
- [SFM based on pytorch](qlib/contrib/model/pytorch_sfm.py)
|
||||
- [HATs based on pytorch](qlib/contrib/model/pytorch_hats.py)
|
||||
- [TFT based on tensorflow](examples/benchmarks/TFT/tft.py)
|
||||
|
||||
Your PR of new Quant models is highly welcomed.
|
||||
|
||||
## Run a single model
|
||||
`Qlib` provides three different ways to run a single model, users can pick the one that fits their cases best:
|
||||
- User can use the tool `qrun` mentioned above to run a model's workflow based from a config file.
|
||||
- User can create a `workflow_by_code` python script based on the [one](examples/workflow_by_code.py) listed in the `examples` folder.
|
||||
- User can use the script [`run_all_model.py`](examples/run_all_model.py) listed in the `examples` folder to run a model. Here is an example of the specific shell command to be used: `python run_all_model.py --models=lightgbm`. For more use cases, please refer to the file's [docstrings](examples/run_all_model.py).
|
||||
|
||||
## Run multiple models
|
||||
`Qlib` also provides a script [`run_all_model.py`](examples/run_all_model.py) which can run multiple models for several iterations. (**Note**: the script only supprots *Linux* now. Other OS will be supported in the future.)
|
||||
|
||||
The script will create a unique virtual environment for each model, and delete the environments after training. Thus, only experiment results such as `IC` and `backtest` results will be generated and stored. (**Note**: the script will erase your previous experiment records created by running itself.)
|
||||
|
||||
Here is an example of running all the models for 10 iterations:
|
||||
```python
|
||||
python run_all_model.py 10
|
||||
```
|
||||
|
||||
It also provides the API to run specific models at once. For more use cases, please refer to the file's [docstrings](examples/run_all_model.py).
|
||||
|
||||
|
||||
# Quant Dataset Zoo
|
||||
Dataset plays a very important role in Quant. Here is a list of the datasets built on `Qlib`.
|
||||
|
||||
|
||||
@@ -2,9 +2,7 @@
|
||||
|
||||
- ALSTM contains a temporal attentive aggregation layer based on normal LSTM.
|
||||
|
||||
- The code used in Qlib is a pyTorch implementation of Code: https://github.com/fulifeng/Adv-ALSTM
|
||||
|
||||
- Paper: A dual-stage attention-based recurrent neural network for time series prediction.
|
||||
|
||||
https://www.ijcai.org/Proceedings/2017/0366.pdf
|
||||
[https://www.ijcai.org/Proceedings/2017/0366.pdf](https://www.ijcai.org/Proceedings/2017/0366.pdf)
|
||||
|
||||
|
||||
@@ -8,6 +8,20 @@ data_handler_config: &data_handler_config
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
@@ -26,20 +40,19 @@ port_analysis_config: &port_analysis_config
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: GAT
|
||||
class: GATs
|
||||
module_path: qlib.contrib.model.pytorch_gats
|
||||
kwargs:
|
||||
d_feat: 6
|
||||
hidden_size: 64
|
||||
num_layers: 2
|
||||
dropout: 0.0
|
||||
dropout: 0.7
|
||||
n_epochs: 200
|
||||
lr: 1e-3
|
||||
lr: 1e-4
|
||||
early_stop: 20
|
||||
metric: loss
|
||||
loss: mse
|
||||
base_model: LSTM
|
||||
with_pretrain: True
|
||||
seed: 0
|
||||
GPU: 0
|
||||
dataset:
|
||||
@@ -47,7 +60,7 @@ task:
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: ALPHA360_Denoise
|
||||
class: ALPHA360
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
@@ -58,11 +71,6 @@ task:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
|
||||
@@ -1,15 +0,0 @@
|
||||
## Requirement
|
||||
|
||||
* pandas==1.1.2
|
||||
* numpy==1.17.4
|
||||
* scikit_learn==0.23.2
|
||||
* torch==1.7.0
|
||||
|
||||
## HATS
|
||||
|
||||
* HATS is a a hierarchical attention network for stock prediction which uses relational data for stock market prediction. HATS selectively aggregates information
|
||||
on different relation types and adds the information to the representations of each company. HATS is used as a relational modeling module with initialized node representations.Furthermore, HATS
|
||||
can predict not only individual stock prices but also market index movements, which is similar to the graph classification task.
|
||||
|
||||
* HATS uses pretrained model of GRU and LSTM. The code of GRU and LSTM used in Qlib is a pyTorch implemention of GRU and LSTM.
|
||||
* Paper address:HATS: A Hierarchical Graph Attention Network for Stock Movement Prediction https://arxiv.org/pdf/1908.07999.pdf
|
||||
@@ -1,4 +0,0 @@
|
||||
pandas==1.1.2
|
||||
numpy==1.17.4
|
||||
scikit_learn==0.23.2
|
||||
torch==1.7.0
|
||||
@@ -1,77 +0,0 @@
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: HATS
|
||||
module_path: qlib.contrib.model.pytorch_hats
|
||||
kwargs:
|
||||
d_feat: 6
|
||||
hidden_size: 64
|
||||
num_layers: 2
|
||||
dropout: 0.6
|
||||
n_epochs: 200
|
||||
lr: 1e-3
|
||||
early_stop: 20
|
||||
metric: loss
|
||||
loss: mse
|
||||
base_model: GRU
|
||||
seed: 0
|
||||
GPU: 0
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: ALPHA360
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
Binary file not shown.
@@ -1,4 +1,3 @@
|
||||
# State-Frequency-Memory
|
||||
- State Frequency Memory (SFM) is a novel recurrent network that uses Discrete Fourier Transform (DFT) to decompose the hidden states of memory cells and capture the multi-frequency trading patterns from past market data to make stock price predictions.
|
||||
- The code used in Qlib is a pyTorch implementation of SFM (Zhang, L., Aggarwal, C., & Qi, G. J. (2017,)).
|
||||
- Paper: Stock Price Prediction via Discovering Multi-Frequency Trading Patterns. https://www.cs.ucf.edu/~gqi/publications/kdd2017_stock.pdf.
|
||||
- State Frequency Memory (SFM) is a novel recurrent network that uses Discrete Fourier Transform to decompose the hidden states of memory cells and capture the multi-frequency trading patterns from past market data to make stock price predictions.
|
||||
- Paper: Stock Price Prediction via Discovering Multi-Frequency Trading Patterns. [https://www.cs.ucf.edu/~gqi/publications/kdd2017_stock.pdf.](https://www.cs.ucf.edu/~gqi/publications/kdd2017_stock.pdf.)
|
||||
@@ -233,9 +233,8 @@ class TFTModel(ModelFT):
|
||||
tf.keras.backend.set_session(default_keras_session)
|
||||
|
||||
predict = format_score(p90_forecast, "pred", 0) # self.label_shift
|
||||
label = format_score(targets, "label", 0)
|
||||
# ===========================Predicting Process===========================
|
||||
return predict, label
|
||||
return predict
|
||||
|
||||
def finetune(self, dataset: DatasetH):
|
||||
"""
|
||||
|
||||
@@ -1,4 +0,0 @@
|
||||
# TabNet
|
||||
* TabNet is a novel high-performance and interpretable canonical deep tabular data learning architectur. TabNet uses sequential attention to choose which features to reason from at each decision step, enabling interpretability and more effcient learning as the learning capacity is used for the most salient features.
|
||||
* The code used in Qlib is a pyTorch implementation of Tabnet (Arik, S. O., & Pfister, T. (2019). [https://github.com/dreamquark-ai/tabnet](https://github.com/dreamquark-ai/tabnet)
|
||||
* Paper: TabNet: Attentive Interpretable Tabular Learning. [https://arxiv.org/pdf/1908.07442.pdf](https://arxiv.org/pdf/1908.07442.pdf).
|
||||
@@ -1,5 +0,0 @@
|
||||
pandas==1.1.2
|
||||
numpy==1.17.4
|
||||
scikit_learn==0.23.2
|
||||
torch==1.7.0
|
||||
pytorch-tabnet==2.0.1
|
||||
@@ -1,66 +0,0 @@
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: TabNetModel
|
||||
module_path: qlib.contrib.model.tabnet
|
||||
kwargs:
|
||||
n_d: 8
|
||||
n_a: 8
|
||||
n_steps: 3
|
||||
gamma: 1.3
|
||||
n_independent: 2
|
||||
n_shared: 2
|
||||
seed: 0
|
||||
momentum: 0.02
|
||||
lambda_sparse: 1e-3
|
||||
optimizer_params: {lr: 2e-3}
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: ALPHA360_Denoise
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -30,14 +30,12 @@ task:
|
||||
module_path: qlib.contrib.model.xgboost
|
||||
kwargs:
|
||||
eval_metric: rmse
|
||||
colsample_bytree: 0.5
|
||||
eta: 0.2
|
||||
gamma: 0.55
|
||||
max_depth: 2
|
||||
min_child_weight: 1.0
|
||||
colsample_bytree: 0.8879
|
||||
eta: 0.0421
|
||||
max_depth: 8
|
||||
n_estimators: 647
|
||||
subsample: 0.8
|
||||
nthread: 4
|
||||
subsample: 0.8789
|
||||
nthread: 20
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
@@ -62,4 +60,4 @@ task:
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
config: *port_analysis_config
|
||||
|
||||
@@ -1,446 +0,0 @@
|
||||
{
|
||||
"metadata": {
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.7.9-final"
|
||||
},
|
||||
"orig_nbformat": 2,
|
||||
"kernelspec": {
|
||||
"name": "python3",
|
||||
"display_name": "Python 3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2,
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import sys\n",
|
||||
"import copy\n",
|
||||
"from pathlib import Path\n",
|
||||
"\n",
|
||||
"import qlib\n",
|
||||
"import numpy as np\n",
|
||||
"import pandas as pd\n",
|
||||
"from qlib.config import REG_CN\n",
|
||||
"from qlib.contrib.model.gbdt import LGBModel\n",
|
||||
"from qlib.contrib.data.handler import Alpha158\n",
|
||||
"from qlib.contrib.strategy.strategy import TopkDropoutStrategy\n",
|
||||
"from qlib.contrib.evaluate import (\n",
|
||||
" backtest as normal_backtest,\n",
|
||||
" risk_analysis,\n",
|
||||
")\n",
|
||||
"from qlib.utils import exists_qlib_data, init_instance_by_config\n",
|
||||
"from qlib.workflow import R\n",
|
||||
"from qlib.workflow.record_temp import SignalRecord, PortAnaRecord\n",
|
||||
"from qlib.utils import flatten_dict"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stderr",
|
||||
"text": [
|
||||
"[35366:MainThread](2020-11-27 10:31:09,528) INFO - qlib.Initialization - [__init__.py:41] - default_conf: client.\n",
|
||||
"[35366:MainThread](2020-11-27 10:31:09,531) WARNING - qlib.Initialization - [__init__.py:57] - redis connection failed(host=127.0.0.1 port=6379), cache will not be used!\n",
|
||||
"[35366:MainThread](2020-11-27 10:31:09,531) INFO - qlib.Initialization - [__init__.py:76] - qlib successfully initialized based on client settings.\n",
|
||||
"[35366:MainThread](2020-11-27 10:31:09,532) INFO - qlib.Initialization - [__init__.py:79] - data_path=/home/dongzho/.qlib/qlib_data/cn_data\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# use default data\n",
|
||||
"# NOTE: need to download data from remote: python scripts/get_data.py qlib_data_cn --target_dir ~/.qlib/qlib_data/cn_data\n",
|
||||
"provider_uri = \"~/.qlib/qlib_data/cn_data\" # target_dir\n",
|
||||
"if not exists_qlib_data(provider_uri):\n",
|
||||
" print(f\"Qlib data is not found in {provider_uri}\")\n",
|
||||
" sys.path.append(str(Path.cwd().parent.joinpath(\"scripts\")))\n",
|
||||
" from get_data import GetData\n",
|
||||
" GetData().qlib_data(target_dir=provider_uri, region=REG_CN)\n",
|
||||
"qlib.init(provider_uri=provider_uri, region=REG_CN)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"market = \"csi300\"\n",
|
||||
"benchmark = \"SH000300\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"source": [
|
||||
"## Model Training"
|
||||
],
|
||||
"cell_type": "markdown",
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stderr",
|
||||
"text": [
|
||||
"[35366:MainThread](2020-11-27 10:31:29,731) INFO - qlib.timer - [log.py:81] - Time cost: 20.103s | Loading data Done\n",
|
||||
"[35366:MainThread](2020-11-27 10:31:30,557) INFO - qlib.timer - [log.py:81] - Time cost: 0.241s | DropnaLabel Done\n",
|
||||
"[35366:MainThread](2020-11-27 10:31:38,518) INFO - qlib.timer - [log.py:81] - Time cost: 7.960s | CSZScoreNorm Done\n",
|
||||
"[35366:MainThread](2020-11-27 10:31:38,519) INFO - qlib.timer - [log.py:81] - Time cost: 8.786s | fit & process data Done\n",
|
||||
"[35366:MainThread](2020-11-27 10:31:38,520) INFO - qlib.timer - [log.py:81] - Time cost: 28.891s | Init data Done\n",
|
||||
"[35366:MainThread](2020-11-27 10:31:38,527) INFO - qlib.workflow - [exp.py:180] - Experiment 2 starts running ...\n",
|
||||
"[35366:MainThread](2020-11-27 10:31:38,651) INFO - qlib.workflow - [recorder.py:234] - Recorder c81375e3b5474feb9c77711babd158c3 starts running under Experiment 2 ...\n",
|
||||
"[35366:MainThread](2020-11-27 10:31:38,652) INFO - qlib.workflow - [expm.py:251] - No tracking URI is provided. The default tracking URI is set as `mlruns` under the working directory.\n",
|
||||
"Training until validation scores don't improve for 50 rounds\n",
|
||||
"[20]\ttrain's l2: 0.990559\tvalid's l2: 0.994332\n",
|
||||
"[40]\ttrain's l2: 0.98687\tvalid's l2: 0.993702\n",
|
||||
"[60]\ttrain's l2: 0.984308\tvalid's l2: 0.993503\n",
|
||||
"[80]\ttrain's l2: 0.982202\tvalid's l2: 0.993446\n",
|
||||
"[100]\ttrain's l2: 0.980318\tvalid's l2: 0.993423\n",
|
||||
"[120]\ttrain's l2: 0.97854\tvalid's l2: 0.993409\n",
|
||||
"[140]\ttrain's l2: 0.97679\tvalid's l2: 0.993413\n",
|
||||
"[160]\ttrain's l2: 0.975116\tvalid's l2: 0.993473\n",
|
||||
"Early stopping, best iteration is:\n",
|
||||
"[127]\ttrain's l2: 0.977957\tvalid's l2: 0.993381\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"###################################\n",
|
||||
"# train model\n",
|
||||
"###################################\n",
|
||||
"data_handler_config = {\n",
|
||||
" \"start_time\": \"2008-01-01\",\n",
|
||||
" \"end_time\": \"2020-08-01\",\n",
|
||||
" \"fit_start_time\": \"2008-01-01\",\n",
|
||||
" \"fit_end_time\": \"2014-12-31\",\n",
|
||||
" \"instruments\": market,\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"task = {\n",
|
||||
" \"model\": {\n",
|
||||
" \"class\": \"LGBModel\",\n",
|
||||
" \"module_path\": \"qlib.contrib.model.gbdt\",\n",
|
||||
" \"kwargs\": {\n",
|
||||
" \"loss\": \"mse\",\n",
|
||||
" \"colsample_bytree\": 0.8879,\n",
|
||||
" \"learning_rate\": 0.0421,\n",
|
||||
" \"subsample\": 0.8789,\n",
|
||||
" \"lambda_l1\": 205.6999,\n",
|
||||
" \"lambda_l2\": 580.9768,\n",
|
||||
" \"max_depth\": 8,\n",
|
||||
" \"num_leaves\": 210,\n",
|
||||
" \"num_threads\": 20,\n",
|
||||
" },\n",
|
||||
" },\n",
|
||||
" \"dataset\": {\n",
|
||||
" \"class\": \"DatasetH\",\n",
|
||||
" \"module_path\": \"qlib.data.dataset\",\n",
|
||||
" \"kwargs\": {\n",
|
||||
" \"handler\": {\n",
|
||||
" \"class\": \"Alpha158\",\n",
|
||||
" \"module_path\": \"qlib.contrib.data.handler\",\n",
|
||||
" \"kwargs\": data_handler_config,\n",
|
||||
" },\n",
|
||||
" \"segments\": {\n",
|
||||
" \"train\": (\"2008-01-01\", \"2014-12-31\"),\n",
|
||||
" \"valid\": (\"2015-01-01\", \"2016-12-31\"),\n",
|
||||
" \"test\": (\"2017-01-01\", \"2020-08-01\"),\n",
|
||||
" },\n",
|
||||
" },\n",
|
||||
" },\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"# model initiaiton\n",
|
||||
"model = init_instance_by_config(task[\"model\"])\n",
|
||||
"dataset = init_instance_by_config(task[\"dataset\"])\n",
|
||||
"\n",
|
||||
"# start exp to train model\n",
|
||||
"with R.start(experiment_name=\"train_model\"):\n",
|
||||
" R.log_params(**flatten_dict(task))\n",
|
||||
" model.fit(dataset)\n",
|
||||
" R.save_objects(trained_model=model)\n",
|
||||
" rid = R.get_recorder().id\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"source": [
|
||||
"## Optimization Based Strategy"
|
||||
],
|
||||
"cell_type": "markdown",
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from qlib.contrib.strategy.strategy import BaseStrategy\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class OptBasedStrategy(BaseStrategy):\n",
|
||||
" \"\"\"Optimization Based Strategy\"\"\"\n",
|
||||
"\n",
|
||||
" def __init__(self, data_handler, cov_estimator, optimizer):\n",
|
||||
" self.data_handler = data_handler\n",
|
||||
" self.cov_estimator = cov_estimator\n",
|
||||
" self.optimizer = optimizer\n",
|
||||
"\n",
|
||||
" def generate_order_list(self, score_series, current, trade_exchange, pred_date, trade_date):\n",
|
||||
" \"\"\"\n",
|
||||
" Parameters\n",
|
||||
" -----------\n",
|
||||
" score_series : pd.Seires\n",
|
||||
" stock_id , score.\n",
|
||||
" current : Position()\n",
|
||||
" current of account.\n",
|
||||
" trade_exchange : Exchange()\n",
|
||||
" exchange.\n",
|
||||
" trade_date : pd.Timestamp\n",
|
||||
" date.\n",
|
||||
" \"\"\"\n",
|
||||
" score_series = score_series.dropna()\n",
|
||||
"\n",
|
||||
" # check stock holdings, if\n",
|
||||
" # 1. doesn't have score: target amount = 0 (force sell)\n",
|
||||
" # 2. stock not tradable: target amount = current amount\n",
|
||||
" current_position = current.get_stock_amount_dict()\n",
|
||||
" target_position = {}\n",
|
||||
" for stock_id in current_position:\n",
|
||||
" if not trade_exchange.is_stock_tradable(stock_id=stock_id, trade_date=trade_date):\n",
|
||||
" target_position[stock_id] = current_position[stock_id]\n",
|
||||
" elif stock_id not in score_series.index:\n",
|
||||
" target_position[stock_id] = 0\n",
|
||||
" else:\n",
|
||||
" # need to be solved by optimizer\n",
|
||||
" pass\n",
|
||||
"\n",
|
||||
" # filter scores, if\n",
|
||||
" # 1. kept in `amount_dict` by previous rules\n",
|
||||
" # 2. not tradable\n",
|
||||
" skipped = []\n",
|
||||
" for stock_id in score_series.index:\n",
|
||||
" if stock_id in target_position:\n",
|
||||
" skipped.append(stock_id)\n",
|
||||
" elif not trade_exchange.is_stock_tradable(stock_id=stock_id, trade_date=trade_date):\n",
|
||||
" skipped.append(stock_id)\n",
|
||||
" score_series = score_series[~score_series.index.isin(skipped)]\n",
|
||||
"\n",
|
||||
" # calc remaining value\n",
|
||||
" current_value = pd.Series({\n",
|
||||
" stock_id: current.get_stock_price(stock_id) * amount\n",
|
||||
" for stock_id, amount in current_position.items()\n",
|
||||
" })\n",
|
||||
" risk_total_value = self.get_risk_degree(trade_date) * current.calculate_value()\n",
|
||||
" traded_value = risk_total_value - current_value.loc[list(target_position)].sum()\n",
|
||||
"\n",
|
||||
" # portfolio init weight\n",
|
||||
" init_weight = current_value.reindex(score_series.index, fill_value=0)\n",
|
||||
" init_weight_sum = init_weight.sum()\n",
|
||||
" if init_weight_sum > 0:\n",
|
||||
" init_weight /= init_weight_sum\n",
|
||||
"\n",
|
||||
" # covariance estimation\n",
|
||||
" selector = (self.data_handler.get_range_selector(pred_date, 252), score_series.index)\n",
|
||||
" price = self.data_handler.fetch(selector, level=None, squeeze=True)\n",
|
||||
" cov = self.cov_estimator(price)\n",
|
||||
" cov = cov.reindex(\n",
|
||||
" index=score_series.index, \n",
|
||||
" columns=score_series.index, \n",
|
||||
" #fill_value=cov.max().max()\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" # optimize target portfolio\n",
|
||||
" if init_weight.sum() > 0:\n",
|
||||
" target_weight = self.optimizer(cov, score_series, init_weight)\n",
|
||||
" else:\n",
|
||||
" target_weight = self.optimizer(cov, score_series)\n",
|
||||
" target_weight = target_weight[target_weight > 1e-6]\n",
|
||||
" for stock_id, weight in target_weight.items():\n",
|
||||
" try:\n",
|
||||
" target_position[stock_id] = int(traded_value * weight / trade_exchange.get_close(stock_id, pred_date))\n",
|
||||
" except Exception as e:\n",
|
||||
" # TODO: unknown exception\n",
|
||||
" print('Exception:', e)\n",
|
||||
"\n",
|
||||
" # for debug\n",
|
||||
" print('trade date:', trade_date)\n",
|
||||
" print('target weight:', target_weight.to_dict())\n",
|
||||
" print('target position:', target_position)\n",
|
||||
"\n",
|
||||
" # generate order list\n",
|
||||
" order_list = trade_exchange.generate_order_for_target_amount_position(\n",
|
||||
" target_position=target_position,\n",
|
||||
" current_position=current_position,\n",
|
||||
" trade_date=trade_date,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" return order_list"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from qlib.data.dataset.loader import QlibDataLoader\n",
|
||||
"from qlib.data.dataset.handler import DataHandler\n",
|
||||
"from qlib.model.riskmodel import ShrinkCovEstimator\n",
|
||||
"from qlib.portfolio.optimizer import PortfolioOptimizer"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stderr",
|
||||
"text": [
|
||||
"[35366:MainThread](2020-11-27 10:31:56,951) INFO - qlib.timer - [log.py:81] - Time cost: 6.763s | Loading data Done\n",
|
||||
"[35366:MainThread](2020-11-27 10:31:56,953) INFO - qlib.timer - [log.py:81] - Time cost: 6.766s | Init data Done\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"data_loader = QlibDataLoader([\"$close\"])\n",
|
||||
"data_handler = DataHandler(\"all\", \"2015-01-01\", \"2020-08-01\", data_loader)\n",
|
||||
"cov_estimator = ShrinkCovEstimator(nan_option=\"mask\")\n",
|
||||
"optimizer = PortfolioOptimizer(\"mvo\", lamb=2, delta=0.2, tol=1e-5)\n",
|
||||
"strategy = OptBasedStrategy(data_handler, cov_estimator, optimizer)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 49,
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stderr",
|
||||
"text": [
|
||||
"1': 0.08936553334387595, 'SH601800': 0.011014844457113308, 'SH601939': 0.013378001170219945, 'SH603993': 0.013820193926861863, 'SZ000338': 0.002455991798001457, 'SZ000423': 0.004893338273543826, 'SZ000538': 0.010686211189620477, 'SZ002065': 0.09095125419435357, 'SZ002074': 0.010299013738522475, 'SZ002085': 0.19844965949420615, 'SZ002236': 0.09210003831704765, 'SZ002310': 0.05664352912360013, 'SZ300017': 0.0197442255539771}\n",
|
||||
"target position: {'SZ002299': 6184584.0980107365, 'SH600000': 272224, 'SH600009': 604839, 'SH600018': 3097398, 'SH600028': 335726, 'SH600196': 23243, 'SH600276': 71634, 'SH600519': 17354, 'SH600585': 269686, 'SH600900': 2501521, 'SH601111': 2400659, 'SH601800': 334062, 'SH601939': 1283164, 'SH603993': 742901, 'SZ000338': 95285, 'SZ000423': 21697, 'SZ000538': 14518, 'SZ002065': 498253, 'SZ002074': 111674, 'SZ002085': 591507, 'SZ002236': 394197, 'SZ002310': 2202674, 'SZ300017': 206128}\n",
|
||||
"target weight: {'SH600000': 0.02310668460556249, 'SH600009': 0.06170206213753432, 'SH600018': 0.027608180837257277, 'SH600028': 0.00971532319525714, 'SH600196': 0.0036133308423111116, 'SH600276': 0.093195014492093, 'SH600519': 0.013476706174774766, 'SH600585': 0.036024919027310476, 'SH600660': 0.04512159672692613, 'SH600900': 0.12506534473579556, 'SH601939': 0.013494851810297546, 'SH603993': 0.07619418669734077, 'SZ000338': 0.0024673392047414363, 'SZ000423': 0.00485981529404862, 'SZ000538': 0.010602880875660015, 'SZ002065': 0.09064325205359221, 'SZ002074': 0.0011889996597580427, 'SZ002085': 0.1982091371262038, 'SZ002236': 0.09254320484936242, 'SZ002310': 0.05152917909181458, 'SZ002466': 0.00014732765084648903, 'SZ300017': 0.019490662910321074}\n",
|
||||
"target position: {'SZ002299': 6184584.0980107365, 'SH600000': 272079, 'SH600009': 604359, 'SH600018': 3095205, 'SH600028': 335471, 'SH600196': 23407, 'SH600276': 71567, 'SH600519': 17345, 'SH600585': 269447, 'SH600660': 129265, 'SH600900': 2499305, 'SH601939': 1282317, 'SH603993': 4058172, 'SZ000338': 95223, 'SZ000423': 21703, 'SZ000538': 14509, 'SZ002065': 497821, 'SZ002074': 12787, 'SZ002085': 590955, 'SZ002236': 393895, 'SZ002310': 2190685, 'SZ002466': 4483, 'SZ300017': 205994}\n",
|
||||
"target weight: {'SH600000': 0.0014042138463464568, 'SH600009': 0.11511740651805806, 'SH600018': 0.026968513725965638, 'SH600028': 0.009566603496832042, 'SH600150': 0.016339328084607228, 'SH600276': 0.09374974543357856, 'SH600489': 0.021876512936684123, 'SH600585': 0.035840818294258524, 'SH600900': 0.12414161958870683, 'SH601888': 0.005682635273269834, 'SH601939': 0.013289788356428228, 'SH603993': 0.07491407610535435, 'SZ000338': 0.002426716760042838, 'SZ000423': 0.00492071038737461, 'SZ000503': 0.005617017904986693, 'SZ000538': 0.010859006699485451, 'SZ002065': 0.08924691553942904, 'SZ002085': 0.19757848255238786, 'SZ002236': 0.09381012783787722, 'SZ002310': 0.03737359938389514, 'SZ300017': 0.01927616131502695}\n",
|
||||
"target position: {'SZ002299': 6184584.0980107365, 'SH600000': 16809, 'SH600009': 1075516, 'SH600018': 3091248, 'SH600028': 335128, 'SH600150': 114804, 'SH600276': 71473, 'SH600489': 66586, 'SH600585': 268644, 'SH600900': 2496175, 'SH601888': 173824, 'SH601939': 1281108, 'SH603993': 4052802, 'SZ000338': 95107, 'SZ000423': 21684, 'SZ000503': 80461, 'SZ000538': 14507, 'SZ002065': 497197, 'SZ002085': 590211, 'SZ002236': 393412, 'SZ002310': 1573728, 'SZ300017': 205818}\n",
|
||||
"target weight: {'SH600000': 0.0013962189421662084, 'SH600009': 0.09330267135244051, 'SH600018': 0.026443154116291615, 'SH600028': 0.009581412428525829, 'SH600150': 0.016443917649559808, 'SH600276': 0.09378402212481758, 'SH600703': 0.0005233118350013756, 'SH600741': 0.10117549074044105, 'SH600900': 0.12435147566444608, 'SH601888': 0.00560250787284307, 'SH601939': 0.013238798853730008, 'SH603993': 0.07455231781733267, 'SZ000423': 0.0048695925705555185, 'SZ000503': 0.006070996956328167, 'SZ000538': 0.010870567565742796, 'SZ002065': 0.08722983720892508, 'SZ002074': 0.00037126948590009574, 'SZ002085': 0.19840484837030906, 'SZ002236': 0.09365186287123867, 'SZ002310': 0.03806080531862309, 'SZ300017': 7.492025186876957e-05}\n",
|
||||
"target position: {'SZ002299': 6184584.0980107365, 'SH600000': 16889, 'SH600009': 867443, 'SH600018': 3086467, 'SH600028': 334573, 'SH600150': 114383, 'SH600276': 71360, 'SH600703': 1760, 'SH600741': 665366, 'SH600900': 2491839, 'SH601888': 173465, 'SH601939': 1278590, 'SH603993': 4045939, 'SZ000423': 21674, 'SZ000503': 80212, 'SZ000538': 14499, 'SZ002065': 496361, 'SZ002074': 4086, 'SZ002085': 589224, 'SZ002236': 392766, 'SZ002310': 1571463, 'SZ300017': 805}\n",
|
||||
"target weight: {'SH600000': 0.0014143911110003147, 'SH600018': 0.026834186435965166, 'SH600028': 0.00961324990522086, 'SH600150': 0.015905361405158292, 'SH600276': 0.09486308638260738, 'SH600685': 1.0253334545374858e-06, 'SH600703': 0.0005108576602907958, 'SH600741': 0.10252334336233063, 'SH600900': 0.1250632059809011, 'SH601888': 0.005830869532670813, 'SH601939': 0.01336945356138906, 'SH603993': 0.07101851124599835, 'SZ000423': 0.004899981502195361, 'SZ000503': 0.006113894785564276, 'SZ000538': 0.011081925761176491, 'SZ000709': 1.06442568357325e-06, 'SZ002065': 0.08812103684766726, 'SZ002074': 0.0003564773234700175, 'SZ002085': 0.19097427428977284, 'SZ002236': 0.09299395368630246, 'SZ002310': 0.03841630892378685, 'SZ002475': 0.10001934454071283, 'SZ300017': 7.322667303400442e-05}\n",
|
||||
"target position: {'SZ002299': 6184584.0980107365, 'SH600000': 16886, 'SH600018': 3080789, 'SH600028': 334087, 'SH600150': 114360, 'SH600276': 71234, 'SH600685': 10, 'SH600703': 1709, 'SH600741': 663932, 'SH600900': 2486951, 'SH601888': 173417, 'SH601939': 1276335, 'SH603993': 3740672, 'SZ000423': 21667, 'SZ000503': 80191, 'SZ000538': 14495, 'SZ000709': 11, 'SZ002065': 495371, 'SZ002074': 3867, 'SZ002085': 588051, 'SZ002236': 392002, 'SZ002310': 1568834, 'SZ002475': 1264636, 'SZ300017': 809}\n",
|
||||
"target weight: {'SH600000': 0.0013872765178790307, 'SH600018': 0.026321999857337998, 'SH600028': 0.009491029058787367, 'SH600150': 0.015749871987744815, 'SH600276': 0.09581999547114961, 'SH600703': 0.000518490273176083, 'SH600741': 0.1037547619508012, 'SH600900': 0.12396253436063161, 'SH601258': 0.02298494942988327, 'SH601888': 0.005915886046387033, 'SH601939': 0.013177336599075601, 'SH603993': 0.06888468621566025, 'SZ000423': 0.005102036718661418, 'SZ000503': 0.00602692511970311, 'SZ000538': 0.011127923667697532, 'SZ000709': 0.07688609680386178, 'SZ002065': 0.08693397271897534, 'SZ002074': 0.000347445594871718, 'SZ002085': 0.1905176824564206, 'SZ002236': 0.035835596544641496, 'SZ002475': 0.09918059167278087, 'SZ300017': 7.291118905149903e-05}\n",
|
||||
"target position: {'SZ002299': 6184584.0980107365, 'SH600000': 16948, 'SH600018': 3086676, 'SH600028': 334750, 'SH600150': 114560, 'SH600276': 71372, 'SH600703': 1715, 'SH600741': 665129, 'SH600900': 2491433, 'SH601258': 4190669, 'SH601888': 174070, 'SH601939': 1278836, 'SH603993': 3747283, 'SZ000423': 21744, 'SZ000503': 80490, 'SZ000538': 14538, 'SZ000709': 871429, 'SZ002065': 496245, 'SZ002074': 3887, 'SZ002085': 589120, 'SZ002236': 145147, 'SZ002475': 1268582, 'SZ300017': 814}\n",
|
||||
"target weight: {'SH600000': 0.001373124016867567, 'SH600018': 0.02646941123076474, 'SH600028': 0.009458335378810856, 'SH600150': 0.015442533996257352, 'SH600276': 0.09620341387657301, 'SH600649': 0.012613476480118908, 'SH600703': 0.0005280976985716832, 'SH600741': 0.06577156829314017, 'SH600900': 0.12455488881029539, 'SH601258': 0.02270943336842379, 'SH601939': 0.013066707696697587, 'SH603993': 0.0649427819283919, 'SZ000423': 0.0051167756388828005, 'SZ000503': 0.006076486564538039, 'SZ000709': 0.0770418453012855, 'SZ000778': 0.08738918304165759, 'SZ002065': 0.08804613990036694, 'SZ002074': 0.00034315924263262563, 'SZ002085': 0.18241434394629127, 'SZ002475': 0.10035998625624482, 'SZ300017': 7.809604376099223e-05}\n",
|
||||
"target position: {'SZ002299': 6184584.0980107365, 'SH600000': 16935, 'SH600018': 3089469, 'SH600028': 334906, 'SH600150': 114496, 'SH600276': 71430, 'SH600649': 337388, 'SH600703': 1714, 'SH600741': 419916, 'SH600900': 2493978, 'SH601258': 4194599, 'SH601939': 1279661, 'SH603993': 3750968, 'SZ000423': 21734, 'SZ000503': 80440, 'SZ000709': 872293, 'SZ000778': 366855, 'SZ002065': 496756, 'SZ002074': 3880, 'SZ002085': 564610, 'SZ002475': 1269872, 'SZ300017': 812}\n",
|
||||
"target weight: {'SH600000': 0.0013497287789570015, 'SH600018': 0.02647482761554837, 'SH600028': 0.00941080088689994, 'SH600150': 0.01556139303593115, 'SH600276': 0.09732218714743374, 'SH600649': 0.012606184789019243, 'SH600703': 0.0005334649726542859, 'SH600900': 0.12593267687041163, 'SH601258': 0.021199485570796834, 'SH601939': 0.013025993149697816, 'SH603993': 0.06446918682668012, 'SZ000423': 0.005311875734339093, 'SZ000503': 0.006125989728635501, 'SZ000709': 0.0707610058353687, 'SZ000778': 0.14004715956352495, 'SZ002065': 0.08746446321200681, 'SZ002074': 0.00033710686535540885, 'SZ002085': 0.15238971653801253, 'SZ002146': 0.042585776887618575, 'SZ002475': 0.10701429615740456, 'SZ300017': 7.667981013711115e-05}\n",
|
||||
"target position: {'SZ002299': 6184584.0980107365, 'SH600000': 17031, 'SH600018': 3109084, 'SH600028': 336978, 'SH600150': 115126, 'SH600276': 71888, 'SH600649': 339316, 'SH600703': 1724, 'SH600900': 2510148, 'SH601258': 4237748, 'SH601939': 1287810, 'SH603993': 3775382, 'SZ000423': 21853, 'SZ000503': 80885, 'SZ000709': 878077, 'SZ000778': 625157, 'SZ002065': 499988, 'SZ002074': 3901, 'SZ002085': 469624, 'SZ002146': 2000993, 'SZ002475': 1278084, 'SZ300017': 814}\n",
|
||||
"target weight: {'SH600000': 0.0013594926998639766, 'SH600009': 0.021101252574639438, 'SH600028': 0.009528554544265834, 'SH600150': 0.015013601602404225, 'SH600276': 0.09860402207319302, 'SH600649': 0.01292550325031454, 'SH600685': 0.00703471182662378, 'SH600703': 0.0005218767517596246, 'SH600900': 0.12786995199482584, 'SH601258': 0.04401496515184404, 'SH601398': 0.025932829520167643, 'SH601939': 0.0134408200189716, 'SH603993': 0.06319752369639879, 'SZ000423': 0.005221187626834546, 'SZ000503': 0.006085670359590286, 'SZ000568': 0.003081214755480397, 'SZ000709': 0.07061122716452324, 'SZ000778': 0.1379488795662632, 'SZ000839': 0.019142903464547063, 'SZ002065': 0.04714685528331623, 'SZ002074': 0.00033291622875151913, 'SZ002085': 0.11947661465752588, 'SZ002146': 0.043205942689553425, 'SZ002310': 0.0009243182551654129, 'SZ002475': 0.106199974013018, 'SZ300017': 7.709323254732814e-05}\n",
|
||||
"target position: {'SZ002299': 6184584.0980107365, 'SH600000': 16933, 'SH600009': 196068, 'SH600028': 337025, 'SH600150': 115100, 'SH600276': 71926, 'SH600649': 339354, 'SH600685': 75328, 'SH600703': 1713, 'SH600900': 2511928, 'SH601258': 8791935, 'SH601398': 1146896, 'SH601939': 1288215, 'SH603993': 3777819, 'SZ000423': 21728, 'SZ000503': 80869, 'SZ000568': 10375, 'SZ000709': 878683, 'SZ000778': 625604, 'SZ000839': 312116, 'SZ002065': 268413, 'SZ002074': 3860, 'SZ002085': 369761, 'SZ002146': 2002072, 'SZ002310': 40341, 'SZ002475': 1278918, 'SZ300017': 811}\n",
|
||||
"target weight: {'SH600000': 0.0013764694393366029, 'SH600009': 0.021541655860797534, 'SH600028': 0.009752609535237182, 'SH600276': 0.06514222178877259, 'SH600649': 0.01273168785031133, 'SH600685': 0.006989932070614982, 'SH600900': 0.12998548252109676, 'SH601258': 0.13157540821422453, 'SH601398': 0.02641881439805636, 'SH601939': 0.0136141957873422, 'SH603993': 0.0602411123337629, 'SZ000503': 0.006084251045333903, 'SZ000709': 0.06977363144499521, 'SZ000778': 0.1385461140272643, 'SZ000839': 0.018579865431307987, 'SZ002065': 0.046270476942690986, 'SZ002074': 0.00025974854597178115, 'SZ002085': 0.10060756172850334, 'SZ002146': 0.043204792194791966, 'SZ002310': 0.0009022784286642987, 'SZ002466': 0.011748866835406593, 'SZ002475': 0.08457581284822364, 'SZ300017': 7.701070501151889e-05}\n",
|
||||
"target position: {'SZ002299': 6184584.0980107365, 'SH600000': 16938, 'SH600009': 196239, 'SH600028': 337355, 'SH600276': 46535, 'SH600649': 339479, 'SH600685': 75274, 'SH600900': 2514488, 'SH601258': 26730440, 'SH601398': 1148157, 'SH601939': 1289259, 'SH603993': 3781937, 'SZ000503': 80900, 'SZ000709': 879645, 'SZ000778': 626285, 'SZ000839': 312384, 'SZ002065': 268717, 'SZ002074': 3093, 'SZ002085': 309206, 'SZ002146': 2003901, 'SZ002310': 39782, 'SZ002466': 367691, 'SZ002475': 1026389, 'SZ300017': 812}\n",
|
||||
"target weight: {'SH600000': 0.0013689894888766726, 'SH600009': 0.021087495457198752, 'SH600028': 0.009589419355091226, 'SH600276': 0.0644304399184473, 'SH600535': 0.016420787426513667, 'SH600649': 0.0267771761277641, 'SH600900': 0.12784455237901315, 'SH601169': 0.004374459372110214, 'SH601258': 0.13288651981531077, 'SH601398': 0.02615927477879055, 'SH601939': 0.013573361058977978, 'SH603993': 1.157895161672162e-06, 'SZ000503': 0.009069218941980683, 'SZ000709': 0.07014466816191627, 'SZ000778': 0.13956352821962528, 'SZ002065': 0.045206445945654664, 'SZ002085': 0.08649963592018277, 'SZ002146': 0.04234588186007612, 'SZ002310': 0.0008924777422846245, 'SZ002466': 0.07334842360184116, 'SZ002475': 0.08834296814868704, 'SZ300017': 7.311841306821287e-05}\n",
|
||||
"Exception: ('SH601169', Timestamp('2017-04-25 00:00:00'))\n",
|
||||
"target position: {'SZ002299': 6184584.0980107365, 'SH600000': 16929, 'SH600009': 196092, 'SH600028': 337333, 'SH600276': 46571, 'SH600535': 57649, 'SH600649': 731641, 'SH600900': 2515321, 'SH601258': 26740467, 'SH601398': 1148635, 'SH601939': 1289434, 'SH603993': 72, 'SZ000503': 122157, 'SZ000709': 879908, 'SZ000778': 626506, 'SZ002065': 268767, 'SZ002085': 267906, 'SZ002146': 2004576, 'SZ002310': 39745, 'SZ002466': 2332750, 'SZ002475': 1026858, 'SZ300017': 806}\n",
|
||||
"target weight: {'SH600000': 0.0013439859873209908, 'SH600009': 0.02075652616964347, 'SH600028': 0.00939963933310415, 'SH600276': 0.06236017906066887, 'SH600535': 0.016369568294734148, 'SH600649': 0.025541724367766302, 'SH600900': 0.12768966131041845, 'SH601258': 0.1370446945486361, 'SH601398': 0.02601619218529119, 'SH601939': 0.013440958024818669, 'SH603993': 4.144559709761373e-06, 'SZ000503': 0.0084237188568659, 'SZ000568': 0.020576387679160105, 'SZ000709': 0.056783757531829446, 'SZ000778': 0.06920027928808208, 'SZ002008': 0.07943378393922318, 'SZ002065': 0.045339177613740886, 'SZ002085': 0.08505902525865962, 'SZ002146': 0.031624633954490035, 'SZ002310': 0.0008996156348854183, 'SZ002466': 0.0764983539831682, 'SZ002475': 0.086193992434369}\n",
|
||||
"target position: {'SZ002299': 6184584.0980107365, 'SZ300017': 812.4573136217659, 'SH600000': 16923, 'SH600009': 196076, 'SH600028': 337279, 'SH600276': 46567, 'SH600535': 57624, 'SH600649': 731549, 'SH600900': 2515891, 'SH601258': 26747448, 'SH601398': 1148886, 'SH601939': 1289307, 'SH603993': 263, 'SZ000503': 122158, 'SZ000568': 69471, 'SZ000709': 700781, 'SZ000778': 302643, 'SZ002008': 746285, 'SZ002065': 268804, 'SZ002085': 267988, 'SZ002146': 1473970, 'SZ002310': 39739, 'SZ002466': 2333288, 'SZ002475': 1027134}\n",
|
||||
"target weight: {'SH600000': 0.0014508867295425067, 'SH600009': 0.022137935734971876, 'SH600028': 0.01003980705499816, 'SH600276': 0.065554410760754, 'SH600535': 0.017337663954140436, 'SH600649': 0.026752732524884384, 'SH600900': 0.13610376526017787, 'SH601258': 0.14230666244775886, 'SH601398': 0.027847743092481312, 'SH601939': 0.014306563408357105, 'SH603993': 2.7770868647848817e-06, 'SZ000069': 0.10104502775773525, 'SZ000503': 0.009049444347506782, 'SZ000568': 0.005686495401232644, 'SZ000778': 0.0715782861850023, 'SZ002008': 0.08609584908472251, 'SZ002065': 0.04706561122827146, 'SZ002085': 0.09099179117275048, 'SZ002146': 0.03204301334262787, 'SZ002475': 0.09241758644387384, 'SZ300017': 0.00018594702102337797}\n",
|
||||
"target position: {'SZ000709': 700825.0269758024, 'SZ002299': 6184584.0980107365, 'SH600000': 16845, 'SH600009': 195098, 'SH600028': 335689, 'SH600276': 46340, 'SH600535': 57343, 'SH600649': 728078, 'SH600900': 2504242, 'SH601258': 26624542, 'SH601398': 1143577, 'SH601939': 1283067, 'SH603993': 160, 'SZ000069': 367637, 'SZ000503': 121565, 'SZ000568': 17626, 'SZ000778': 301250, 'SZ002008': 742790, 'SZ002065': 267559, 'SZ002085': 266737, 'SZ002146': 1467579, 'SZ002475': 1022346, 'SZ300017': 1776}\n",
|
||||
"target weight: {'SH600000': 0.0013484985106016394, 'SH600009': 0.020750773768622693, 'SH600028': 0.009285673867962157, 'SH600104': 2.9067007814076732e-05, 'SH600196': 0.10012804077099052, 'SH600276': 0.05943563439541343, 'SH600535': 0.015902136087846228, 'SH600649': 0.025189836387314323, 'SH600900': 0.12584805827140388, 'SH601111': 6.857382365314848e-06, 'SH601258': 0.03895938466363849, 'SH601398': 0.025753888553878806, 'SH601939': 0.013275755331575599, 'SH603993': 4.249178615404585e-06, 'SZ000069': 0.09445579375504781, 'SZ000503': 0.008532747266799033, 'SZ000568': 0.0052599046052527266, 'SZ000709': 0.06003418476540357, 'SZ000778': 0.06923031488245988, 'SZ002008': 0.07903025205993618, 'SZ002065': 0.04448484691775433, 'SZ002085': 0.08426354045447453, 'SZ002146': 0.031142767130486235, 'SZ002475': 0.08747938111190227, 'SZ300017': 0.00016841662419817417}\n",
|
||||
"target position: {'SZ002299': 6184584.0980107365, 'SH600000': 16906, 'SH600009': 195107, 'SH600028': 335257, 'SH600104': 197, 'SH600196': 630404, 'SH600276': 46282, 'SH600535': 57311, 'SH600649': 727170, 'SH600900': 2500379, 'SH601111': 203, 'SH601258': 7443096, 'SH601398': 1142014, 'SH601939': 1281361, 'SH603993': 263, 'SZ000069': 366998, 'SZ000503': 121479, 'SZ000568': 17699, 'SZ000709': 699639, 'SZ000778': 300752, 'SZ002008': 741767, 'SZ002065': 267133, 'SZ002085': 266334, 'SZ002146': 1465489, 'SZ002475': 1020693, 'SZ300017': 1756}\n",
|
||||
"target weight: {'SH600000': 0.0012976336004362882, 'SH600009': 0.0204756895024156, 'SH600028': 0.008883617000656601, 'SH600104': 2.592943319382378e-05, 'SH600196': 0.09617041827497698, 'SH600276': 0.05681162545715886, 'SH600535': 0.015294256733040745, 'SH600649': 0.02417676167926707, 'SH600900': 0.12233373885315162, 'SH601398': 0.024531954099214746, 'SH601628': 0.005044154324745466, 'SH601888': 0.09500034426651846, 'SH601939': 0.012657033879067425, 'SH603993': 4.079522960136806e-06, 'SZ000069': 0.09054142453059062, 'SZ000503': 0.008036587259744734, 'SZ000568': 0.0049533657881637655, 'SZ000778': 0.06904486736535222, 'SZ002008': 0.06688985213943154, 'SZ002065': 0.04278977877238287, 'SZ002085': 0.0820368284038888, 'SZ002299': 0.06899317887598991, 'SZ002475': 0.08384652594205952, 'SZ300017': 0.00016035416530955983}\n",
|
||||
"target position: {'SH601258': 7443495.190430395, 'SH600000': 16952, 'SH600009': 195676, 'SH600028': 336044, 'SH600104': 183, 'SH600196': 631454, 'SH600276': 46372, 'SH600535': 57498, 'SH600649': 728582, 'SH600900': 2504660, 'SH601398': 1143938, 'SH601628': 695470, 'SH601888': 2951253, 'SH601939': 1283887, 'SH603993': 255, 'SZ000069': 367641, 'SZ000503': 121875, 'SZ000568': 17775, 'SZ000778': 301255, 'SZ002008': 638620, 'SZ002065': 267645, 'SZ002085': 266802, 'SZ002299': 6194843, 'SZ002475': 1022527, 'SZ300017': 1765}\n",
|
||||
"target weight: {'SH600000': 0.0013469483722729403, 'SH600028': 0.009286467498269333, 'SH600104': 2.368500734977497e-05, 'SH600196': 0.10145424564201923, 'SH600276': 0.06002237364700993, 'SH600535': 0.01588332650422844, 'SH600649': 0.025440421851940002, 'SH600900': 0.1279028471227695, 'SH601258': 0.035917606048396986, 'SH601398': 0.02559318344055778, 'SH601628': 0.005221942888216608, 'SH601888': 0.14928498761757883, 'SH601939': 0.013161430940131148, 'SH603993': 4.350147095904942e-06, 'SZ000069': 0.14038473724819095, 'SZ000503': 0.008556251357999256, 'SZ000568': 0.005243511514392524, 'SZ002008': 0.06824325050397591, 'SZ002065': 0.04420632869308568, 'SZ002085': 0.074424247013131, 'SZ002299': 0.0010812901181988855, 'SZ002475': 0.0871460668952185, 'SZ300017': 0.00017049992832446128}\n",
|
||||
"target position: {'SZ000778': 301254.84776855103, 'SH600000': 16873, 'SH600028': 335064, 'SH600104': 156, 'SH600196': 629613, 'SH600276': 46235, 'SH600535': 57245, 'SH600649': 726346, 'SH600900': 2497776, 'SH601258': 7423462, 'SH601398': 1140689, 'SH601628': 692346, 'SH601888': 4557826, 'SH601939': 1279908, 'SH603993': 261, 'SZ000069': 551887, 'SZ000503': 121344, 'SZ000568': 17697, 'SZ002008': 636943, 'SZ002065': 266904, 'SZ002085': 231781, 'SZ002299': 97527, 'SZ002475': 1019747, 'SZ300017': 1749}\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"output_type": "error",
|
||||
"ename": "KeyboardInterrupt",
|
||||
"evalue": "",
|
||||
"traceback": [
|
||||
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
|
||||
"\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
|
||||
"\u001b[1;32m<ipython-input-49-2e7986244749>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m 30\u001b[0m \u001b[1;31m# backtest & analysis\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 31\u001b[0m \u001b[0mpar\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mPortAnaRecord\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mrecorder\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mport_analysis_config\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 32\u001b[1;33m \u001b[0mpar\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mgenerate\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m",
|
||||
"\u001b[1;32md:\\qlib\\qlib\\workflow\\record_temp.py\u001b[0m in \u001b[0;36mgenerate\u001b[1;34m(self, **kwargs)\u001b[0m\n\u001b[0;32m 230\u001b[0m \u001b[1;31m# custom strategy and get backtest\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 231\u001b[0m \u001b[0mpred_score\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0msuper\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mload\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 232\u001b[1;33m \u001b[0mreport_normal\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mpositions_normal\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mnormal_backtest\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mpred_score\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mstrategy\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mstrategy\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbacktest_config\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 233\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrecorder\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msave_objects\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m**\u001b[0m\u001b[1;33m{\u001b[0m\u001b[1;34m\"report_normal.pkl\"\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mreport_normal\u001b[0m\u001b[1;33m}\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0martifact_path\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mPortAnaRecord\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mget_path\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 234\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrecorder\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msave_objects\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m**\u001b[0m\u001b[1;33m{\u001b[0m\u001b[1;34m\"positions_normal.pkl\"\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mpositions_normal\u001b[0m\u001b[1;33m}\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0martifact_path\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mPortAnaRecord\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mget_path\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[1;32md:\\qlib\\qlib\\contrib\\evaluate.py\u001b[0m in \u001b[0;36mbacktest\u001b[1;34m(pred, account, shift, benchmark, verbose, **kwargs)\u001b[0m\n\u001b[0;32m 269\u001b[0m \u001b[0mverbose\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mverbose\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 270\u001b[0m \u001b[0maccount\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0maccount\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 271\u001b[1;33m \u001b[0mbenchmark\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mbenchmark\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 272\u001b[0m )\n\u001b[0;32m 273\u001b[0m \u001b[1;31m# for compatibility of the old API. return the dict positions\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[1;32md:\\qlib\\qlib\\contrib\\backtest\\backtest.py\u001b[0m in \u001b[0;36mbacktest\u001b[1;34m(pred, strategy, trade_exchange, shift, verbose, account, benchmark)\u001b[0m\n\u001b[0;32m 100\u001b[0m \u001b[0mtrade_exchange\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mtrade_exchange\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 101\u001b[0m \u001b[0mpred_date\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mpred_date\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 102\u001b[1;33m \u001b[0mtrade_date\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mtrade_date\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 103\u001b[0m )\n\u001b[0;32m 104\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[1;32m<ipython-input-46-65beeeee07c0>\u001b[0m in \u001b[0;36mgenerate_order_list\u001b[1;34m(self, score_series, current, trade_exchange, pred_date, trade_date)\u001b[0m\n\u001b[0;32m 76\u001b[0m \u001b[1;31m# optimize target portfolio\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 77\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0minit_weight\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msum\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m>\u001b[0m \u001b[1;36m0\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 78\u001b[1;33m \u001b[0mtarget_weight\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0moptimizer\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mcov\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mscore_series\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minit_weight\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 79\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 80\u001b[0m \u001b[0mtarget_weight\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0moptimizer\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mcov\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mscore_series\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[1;32md:\\qlib\\qlib\\portfolio\\optimizer.py\u001b[0m in \u001b[0;36m__call__\u001b[1;34m(self, S, u, w0)\u001b[0m\n\u001b[0;32m 100\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 101\u001b[0m \u001b[1;31m# optimize\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 102\u001b[1;33m \u001b[0mw\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_optimize\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mS\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mu\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mw0\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 103\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 104\u001b[0m \u001b[1;31m# restore index if needed\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[1;32md:\\qlib\\qlib\\portfolio\\optimizer.py\u001b[0m in \u001b[0;36m_optimize\u001b[1;34m(self, S, u, w0)\u001b[0m\n\u001b[0;32m 126\u001b[0m \u001b[1;31m# mean-variance\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 127\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmethod\u001b[0m \u001b[1;33m==\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mOPT_MVO\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 128\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_optimize_mvo\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mS\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mu\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mw0\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 129\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 130\u001b[0m \u001b[1;31m# risk parity\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[1;32md:\\qlib\\qlib\\portfolio\\optimizer.py\u001b[0m in \u001b[0;36m_optimize_mvo\u001b[1;34m(self, S, u, w0)\u001b[0m\n\u001b[0;32m 162\u001b[0m \u001b[1;32mand\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m`\u001b[0m\u001b[0mlamb\u001b[0m\u001b[0;31m`\u001b[0m \u001b[1;32mis\u001b[0m \u001b[0mthe\u001b[0m \u001b[0mrisk\u001b[0m \u001b[0maversion\u001b[0m \u001b[0mparameter\u001b[0m\u001b[1;33m.\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 163\u001b[0m \"\"\"\n\u001b[1;32m--> 164\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_solve\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mS\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_get_objective_mvo\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mS\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mu\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m*\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_get_constrains\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mw0\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 165\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 166\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m_optimize_rp\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mS\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mndarray\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mw0\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mOptional\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mndarray\u001b[0m\u001b[1;33m]\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m->\u001b[0m \u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mndarray\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[1;32md:\\qlib\\qlib\\portfolio\\optimizer.py\u001b[0m in \u001b[0;36m_solve\u001b[1;34m(self, n, obj, bounds, cons)\u001b[0m\n\u001b[0;32m 252\u001b[0m \u001b[1;31m# solve\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 253\u001b[0m \u001b[0mx0\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mones\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mn\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m/\u001b[0m \u001b[0mn\u001b[0m \u001b[1;31m# init results\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 254\u001b[1;33m \u001b[0msol\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mso\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mminimize\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mwrapped_obj\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mx0\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mbounds\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mbounds\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mconstraints\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mcons\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtol\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtol\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 255\u001b[0m \u001b[1;32mif\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[0msol\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msuccess\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 256\u001b[0m \u001b[0mwarnings\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mwarn\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34mf\"optimization not success ({sol.status})\"\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[1;32m~\\AppData\\Local\\Continuum\\miniconda3\\envs\\qlib\\lib\\site-packages\\scipy\\optimize\\_minimize.py\u001b[0m in \u001b[0;36mminimize\u001b[1;34m(fun, x0, args, method, jac, hess, hessp, bounds, constraints, tol, callback, options)\u001b[0m\n\u001b[0;32m 624\u001b[0m \u001b[1;32melif\u001b[0m \u001b[0mmeth\u001b[0m \u001b[1;33m==\u001b[0m \u001b[1;34m'slsqp'\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 625\u001b[0m return _minimize_slsqp(fun, x0, args, jac, bounds,\n\u001b[1;32m--> 626\u001b[1;33m constraints, callback=callback, **options)\n\u001b[0m\u001b[0;32m 627\u001b[0m \u001b[1;32melif\u001b[0m \u001b[0mmeth\u001b[0m \u001b[1;33m==\u001b[0m \u001b[1;34m'trust-constr'\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 628\u001b[0m return _minimize_trustregion_constr(fun, x0, args, jac, hess, hessp,\n",
|
||||
"\u001b[1;32m~\\AppData\\Local\\Continuum\\miniconda3\\envs\\qlib\\lib\\site-packages\\scipy\\optimize\\slsqp.py\u001b[0m in \u001b[0;36m_minimize_slsqp\u001b[1;34m(func, x0, args, jac, bounds, constraints, maxiter, ftol, iprint, disp, eps, callback, finite_diff_rel_step, **unknown_options)\u001b[0m\n\u001b[0;32m 419\u001b[0m n1, n2, n3)\n\u001b[0;32m 420\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 421\u001b[1;33m \u001b[1;32mif\u001b[0m \u001b[0mmode\u001b[0m \u001b[1;33m==\u001b[0m \u001b[1;36m1\u001b[0m\u001b[1;33m:\u001b[0m \u001b[1;31m# objective and constraint evaluation required\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 422\u001b[0m \u001b[0mfx\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0msf\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfun\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 423\u001b[0m \u001b[0mc\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0m_eval_constraint\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcons\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[1;31mKeyboardInterrupt\u001b[0m: "
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"###################################\n",
|
||||
"# prediction, backtest & analysis\n",
|
||||
"###################################\n",
|
||||
"port_analysis_config = {\n",
|
||||
" \"strategy\": strategy,\n",
|
||||
" \"backtest\": {\n",
|
||||
" \"verbose\": False,\n",
|
||||
" \"limit_threshold\": 0.095,\n",
|
||||
" \"account\": 100000000,\n",
|
||||
" \"benchmark\": benchmark,\n",
|
||||
" \"deal_price\": \"close\",\n",
|
||||
" \"open_cost\": 0.0005,\n",
|
||||
" \"close_cost\": 0.0015,\n",
|
||||
" \"min_cost\": 5,\n",
|
||||
" },\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# backtest and analysis\n",
|
||||
"with R.start(experiment_name=\"backtest_analysis\"):\n",
|
||||
" recorder = R.get_recorder(rid, experiment_name=\"train_model\")\n",
|
||||
" model = recorder.load_object(\"trained_model\")\n",
|
||||
"\n",
|
||||
" # prediction\n",
|
||||
" recorder = R.get_recorder()\n",
|
||||
" ba_rid = recorder.id\n",
|
||||
" sr = SignalRecord(model, dataset, recorder)\n",
|
||||
" sr.generate()\n",
|
||||
"\n",
|
||||
" # backtest & analysis\n",
|
||||
" par = PortAnaRecord(recorder, port_analysis_config)\n",
|
||||
" par.generate()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -4,18 +4,20 @@
|
||||
import os
|
||||
import sys
|
||||
import fire
|
||||
import time
|
||||
import venv
|
||||
import glob
|
||||
import shutil
|
||||
import signal
|
||||
import inspect
|
||||
import tempfile
|
||||
import traceback
|
||||
import functools
|
||||
import statistics
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from operator import xor
|
||||
from subprocess import Popen, PIPE
|
||||
from threading import Thread
|
||||
from pprint import pprint
|
||||
from urllib.parse import urlparse
|
||||
from urllib.request import urlretrieve
|
||||
|
||||
import qlib
|
||||
from qlib.config import REG_CN
|
||||
@@ -23,144 +25,53 @@ from qlib.workflow import R
|
||||
from qlib.workflow.cli import workflow
|
||||
from qlib.utils import exists_qlib_data
|
||||
|
||||
|
||||
# init qlib
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data"
|
||||
exp_folder_name = "run_all_model_records"
|
||||
exp_path = str(Path(os.getcwd()).resolve() / exp_folder_name)
|
||||
exp_manager = {
|
||||
"class": "MLflowExpManager",
|
||||
"module_path": "qlib.workflow.expm",
|
||||
"kwargs": {
|
||||
"uri": "file:" + exp_path,
|
||||
"default_exp_name": "Experiment",
|
||||
},
|
||||
}
|
||||
if not exists_qlib_data(provider_uri):
|
||||
print(f"Qlib data is not found in {provider_uri}")
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
|
||||
from get_data import GetData
|
||||
|
||||
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN, exp_manager=exp_manager)
|
||||
if os.path.isdir(exp_path):
|
||||
shutil.rmtree(exp_path)
|
||||
|
||||
# decorator to check the arguments
|
||||
def only_allow_defined_args(function_to_decorate):
|
||||
@functools.wraps(function_to_decorate)
|
||||
def _return_wrapped(*args, **kwargs):
|
||||
"""Internal wrapper function."""
|
||||
argspec = inspect.getfullargspec(function_to_decorate)
|
||||
valid_names = set(argspec.args + argspec.kwonlyargs)
|
||||
if "self" in valid_names:
|
||||
valid_names.remove("self")
|
||||
for arg_name in kwargs:
|
||||
if arg_name not in valid_names:
|
||||
raise ValueError("Unknown argument seen '%s', expected: [%s]" % (arg_name, ", ".join(valid_names)))
|
||||
return function_to_decorate(*args, **kwargs)
|
||||
|
||||
return _return_wrapped
|
||||
|
||||
|
||||
class ExtendedEnvBuilder(venv.EnvBuilder):
|
||||
"""
|
||||
Thie class is modified based on https://docs.python.org/3/library/venv.html.
|
||||
This builder installs setuptools and pip so that you can pip or
|
||||
easy_install other packages into the created virtual environment.
|
||||
# function to handle ctrl z and ctrl c
|
||||
def handler(signum, frame):
|
||||
os.system("kill -9 %d" % os.getpid())
|
||||
|
||||
:param nodist: If true, setuptools and pip are not installed into the
|
||||
created virtual environment.
|
||||
:param nopip: If true, pip is not installed into the created
|
||||
virtual environment.
|
||||
:param progress: If setuptools or pip are installed, the progress of the
|
||||
installation can be monitored by passing a progress
|
||||
callable. If specified, it is called with two
|
||||
arguments: a string indicating some progress, and a
|
||||
context indicating where the string is coming from.
|
||||
The context argument can have one of three values:
|
||||
'main', indicating that it is called from virtualize()
|
||||
itself, and 'stdout' and 'stderr', which are obtained
|
||||
by reading lines from the output streams of a subprocess
|
||||
which is used to install the app.
|
||||
|
||||
If a callable is not specified, default progress
|
||||
information is output to sys.stderr.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.nodist = kwargs.pop("nodist", False)
|
||||
self.nopip = kwargs.pop("nopip", False)
|
||||
self.progress = kwargs.pop("progress", None)
|
||||
self.verbose = kwargs.pop("verbose", False)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def post_setup(self, context):
|
||||
"""
|
||||
Set up any packages which need to be pre-installed into the
|
||||
virtual environment being created.
|
||||
|
||||
:param context: The information for the virtual environment
|
||||
creation request being processed.
|
||||
"""
|
||||
os.environ["VIRTUAL_ENV"] = context.env_dir
|
||||
if not self.nodist:
|
||||
self.install_setuptools(context)
|
||||
# Can't install pip without setuptools
|
||||
if not self.nopip and not self.nodist:
|
||||
self.install_pip(context)
|
||||
|
||||
def reader(self, stream, context):
|
||||
"""
|
||||
Read lines from a subprocess' output stream and either pass to a progress
|
||||
callable (if specified) or write progress information to sys.stderr.
|
||||
"""
|
||||
progress = self.progress
|
||||
while True:
|
||||
s = stream.readline()
|
||||
if not s:
|
||||
break
|
||||
if progress is not None:
|
||||
progress(s, context)
|
||||
else:
|
||||
if not self.verbose:
|
||||
sys.stderr.write(".")
|
||||
else:
|
||||
sys.stderr.write(s.decode("utf-8"))
|
||||
sys.stderr.flush()
|
||||
stream.close()
|
||||
|
||||
def install_script(self, context, name, url):
|
||||
_, _, path, _, _, _ = urlparse(url)
|
||||
fn = os.path.split(path)[-1]
|
||||
binpath = context.bin_path
|
||||
distpath = os.path.join(binpath, fn)
|
||||
# Download script into the virtual environment's binaries folder
|
||||
urlretrieve(url, distpath)
|
||||
progress = self.progress
|
||||
if self.verbose:
|
||||
term = "\n"
|
||||
else:
|
||||
term = ""
|
||||
if progress is not None:
|
||||
progress("Installing %s ...%s" % (name, term), "main")
|
||||
else:
|
||||
sys.stderr.write("Installing %s ...%s" % (name, term))
|
||||
sys.stderr.flush()
|
||||
# Install in the virtual environment
|
||||
args = [context.env_exe, fn]
|
||||
p = Popen(args, stdout=PIPE, stderr=PIPE, cwd=binpath)
|
||||
t1 = Thread(target=self.reader, args=(p.stdout, "stdout"))
|
||||
t1.start()
|
||||
t2 = Thread(target=self.reader, args=(p.stderr, "stderr"))
|
||||
t2.start()
|
||||
p.wait()
|
||||
t1.join()
|
||||
t2.join()
|
||||
if progress is not None:
|
||||
progress("done.", "main")
|
||||
else:
|
||||
sys.stderr.write("done.\n")
|
||||
# Clean up - no longer needed
|
||||
os.unlink(distpath)
|
||||
|
||||
def install_setuptools(self, context):
|
||||
"""
|
||||
Install setuptools in the virtual environment.
|
||||
|
||||
:param context: The information for the virtual environment
|
||||
creation request being processed.
|
||||
"""
|
||||
url = "https://bootstrap.pypa.io/ez_setup.py"
|
||||
self.install_script(context, "setuptools", url)
|
||||
# clear up the setuptools archive which gets downloaded
|
||||
pred = lambda o: o.startswith("setuptools-") and o.endswith(".tar.gz")
|
||||
files = filter(pred, os.listdir(context.bin_path))
|
||||
for f in files:
|
||||
f = os.path.join(context.bin_path, f)
|
||||
os.unlink(f)
|
||||
|
||||
def install_pip(self, context):
|
||||
"""
|
||||
Install pip in the virtual environment.
|
||||
|
||||
:param context: The information for the virtual environment
|
||||
creation request being processed.
|
||||
"""
|
||||
url = "https://bootstrap.pypa.io/get-pip.py"
|
||||
self.install_script(context, "pip", url)
|
||||
|
||||
signal.signal(signal.SIGTSTP, handler)
|
||||
signal.signal(signal.SIGINT, handler)
|
||||
|
||||
# function to calculate the mean and std of a list in the results dictionary
|
||||
def cal_mean_std(results) -> dict:
|
||||
@@ -174,6 +85,36 @@ def cal_mean_std(results) -> dict:
|
||||
return mean_std
|
||||
|
||||
|
||||
# function to create the environment ofr an anaconda environment
|
||||
def create_env():
|
||||
# create env
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
env_path = Path(temp_dir).absolute()
|
||||
sys.stderr.write(f"Creating Virtual Environment with path: {env_path}...\n")
|
||||
execute(f"conda create --prefix {env_path} python=3.7 -y")
|
||||
python_path = env_path / "bin" / "python" # TODO: FIX ME!
|
||||
sys.stderr.write("\n")
|
||||
# get anaconda activate path
|
||||
conda_activate = Path(os.environ["CONDA_PREFIX"]) / "bin" / "activate" # TODO: FIX ME!
|
||||
return env_path, python_path, conda_activate
|
||||
|
||||
|
||||
# function to execute the cmd
|
||||
def execute(cmd):
|
||||
with subprocess.Popen(cmd, stdout=subprocess.PIPE, bufsize=1, universal_newlines=True, shell=True) as p:
|
||||
for line in p.stdout:
|
||||
sys.stdout.write(line.split("\b")[0])
|
||||
if "\b" in line:
|
||||
sys.stdout.flush()
|
||||
time.sleep(0.1)
|
||||
sys.stdout.write("\b" * 10 + "\b".join(line.split("\b")[1:-1]))
|
||||
|
||||
if p.returncode != 0:
|
||||
return p.stderr
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
# function to get all the folders benchmark folder
|
||||
def get_all_folders(models, exclude) -> dict:
|
||||
folders = dict()
|
||||
@@ -212,11 +153,12 @@ def get_all_results(folders) -> dict:
|
||||
result["information_ratio_with_cost"] = list()
|
||||
result["max_drawdown_with_cost"] = list()
|
||||
for recorder_id in recorders:
|
||||
recorder = R.get_recorder(recorder_id=recorder_id, experiment_name=fn)
|
||||
metrics = recorder.list_metrics()
|
||||
result["annualized_return_with_cost"].append(metrics["excess_return_with_cost.annualized_return"])
|
||||
result["information_ratio_with_cost"].append(metrics["excess_return_with_cost.information_ratio"])
|
||||
result["max_drawdown_with_cost"].append(metrics["excess_return_with_cost.max_drawdown"])
|
||||
if recorders[recorder_id].status == "FINISHED":
|
||||
recorder = R.get_recorder(recorder_id=recorder_id, experiment_name=fn)
|
||||
metrics = recorder.list_metrics()
|
||||
result["annualized_return_with_cost"].append(metrics["excess_return_with_cost.annualized_return"])
|
||||
result["information_ratio_with_cost"].append(metrics["excess_return_with_cost.information_ratio"])
|
||||
result["max_drawdown_with_cost"].append(metrics["excess_return_with_cost.max_drawdown"])
|
||||
results[fn] = result
|
||||
return results
|
||||
|
||||
@@ -237,6 +179,7 @@ def gen_and_save_md_table(metrics):
|
||||
|
||||
|
||||
# function to run the all the models
|
||||
@only_allow_defined_args
|
||||
def run(times=1, models=None, exclude=False):
|
||||
"""
|
||||
Please be aware that this function can only work under Linux. MacOS and Windows will be supported in the future.
|
||||
@@ -275,53 +218,48 @@ def run(times=1, models=None, exclude=False):
|
||||
"""
|
||||
# get all folders
|
||||
folders = get_all_folders(models, exclude)
|
||||
# set up
|
||||
compatible = True
|
||||
if sys.version_info < (3, 3):
|
||||
compatible = False
|
||||
elif not hasattr(sys, "base_prefix"):
|
||||
compatible = False
|
||||
if not compatible:
|
||||
raise ValueError("This script is only for use with " "Python 3.3 or later")
|
||||
if os.name == "nt":
|
||||
use_symlinks = False
|
||||
else:
|
||||
use_symlinks = True
|
||||
builder = ExtendedEnvBuilder(
|
||||
system_site_packages=False,
|
||||
clear=False,
|
||||
symlinks=use_symlinks,
|
||||
upgrade=False,
|
||||
nodist=False,
|
||||
nopip=False,
|
||||
verbose=False,
|
||||
)
|
||||
# init error messages:
|
||||
errors = dict()
|
||||
# run all the model for iterations
|
||||
for fn in folders:
|
||||
# create env
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
env_path = Path(temp_dir).absolute()
|
||||
sys.stderr.write(f"Creating Virtual Environment with path: {env_path}...\n")
|
||||
builder.create(str(env_path))
|
||||
python_path = env_path / "bin" / "python" # TODO: FIX ME!
|
||||
sys.stderr.write("\n")
|
||||
# create env by anaconda
|
||||
env_path, python_path, conda_activate = create_env()
|
||||
# get all files
|
||||
sys.stderr.write("Retrieving files...\n")
|
||||
yaml_path, req_path = get_all_files(folders[fn])
|
||||
sys.stderr.write("\n")
|
||||
# install requirements.txt
|
||||
sys.stderr.write("Installing requirements.txt...\n")
|
||||
os.system(f"{python_path} -m pip install -r {req_path}")
|
||||
execute(f"{python_path} -m pip install -r {req_path}")
|
||||
sys.stderr.write("\n")
|
||||
# setup gpu for tft
|
||||
if fn == "TFT":
|
||||
execute(
|
||||
f"conda install -y --prefix {env_path} anaconda cudatoolkit=10.0 && conda install -y --prefix {env_path} cudnn"
|
||||
)
|
||||
sys.stderr.write("\n")
|
||||
# install qlib
|
||||
sys.stderr.write("Installing qlib...\n")
|
||||
os.system(f"{python_path} -m pip install --upgrade cython") # TODO: FIX ME!
|
||||
os.system(f"{python_path} -m pip install -e git+https://github.com/you-n-g/qlib#egg=pyqlib") # TODO: FIX ME!
|
||||
execute(f"{python_path} -m pip install --upgrade cython") # TODO: FIX ME!
|
||||
if fn == "TFT":
|
||||
execute(
|
||||
f"cd {env_path} && {python_path} -m pip install --upgrade --force-reinstall --ignore-installed PyYAML -e git+https://github.com/you-n-g/qlib#egg=pyqlib"
|
||||
) # TODO: FIX ME!
|
||||
else:
|
||||
execute(
|
||||
f"cd {env_path} && {python_path} -m pip install --upgrade --force-reinstall -e git+https://github.com/you-n-g/qlib#egg=pyqlib"
|
||||
) # TODO: FIX ME!
|
||||
sys.stderr.write("\n")
|
||||
# run workflow_by_config for multiple times
|
||||
for i in range(times):
|
||||
sys.stderr.write(f"Running the model: {fn} for iteration {i+1}...\n")
|
||||
os.system(f"{python_path} {env_path / 'src/pyqlib/qlib/workflow/cli.py'} {yaml_path} {fn}")
|
||||
errs = execute(
|
||||
f"{python_path} {env_path / 'src/pyqlib/qlib/workflow/cli.py'} {yaml_path} {fn} {exp_folder_name}"
|
||||
)
|
||||
if errs is not None:
|
||||
_errs = errors.get(fn, {})
|
||||
_errs.update({i: errs})
|
||||
errors[fn] = _errs
|
||||
sys.stderr.write("\n")
|
||||
# remove env
|
||||
sys.stderr.write(f"Deleting the environment: {env_path}...\n")
|
||||
@@ -335,13 +273,12 @@ def run(times=1, models=None, exclude=False):
|
||||
# generating md table
|
||||
sys.stderr.write(f"Generating markdown table...\n")
|
||||
gen_and_save_md_table(results)
|
||||
sys.stderr.write("\n")
|
||||
# print erros
|
||||
sys.stderr.write(f"Here are some of the errors of the models...\n")
|
||||
pprint(errors)
|
||||
sys.stderr.write("\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
rc = 1
|
||||
try:
|
||||
fire.Fire(run) # run all the model
|
||||
rc = 0
|
||||
except Exception as e:
|
||||
print("Error: %s" % e, file=sys.stderr)
|
||||
sys.exit(rc)
|
||||
fire.Fire(run) # run all the model
|
||||
|
||||
@@ -1,5 +1,12 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"<a href=\"https://colab.research.google.com/github/microsoft/qlib/blob/main/examples/workflow_by_code.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
@@ -10,14 +17,43 @@
|
||||
"# Licensed under the MIT License."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import sys, site\n",
|
||||
"from pathlib import Path\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"try:\n",
|
||||
" import qlib\n",
|
||||
"except ImportError:\n",
|
||||
" # install qlib\n",
|
||||
" ! pip install pyqlib\n",
|
||||
" # reload\n",
|
||||
" site.main()\n",
|
||||
"\n",
|
||||
"scripts_dir = Path.cwd().parent.joinpath(\"scripts\")\n",
|
||||
"if not scripts_dir.joinpath(\"get_data.py\").exists():\n",
|
||||
" # download get_data.py script\n",
|
||||
" scripts_dir = Path(\"~/tmp/qlib_code/scripts\").expanduser().resolve()\n",
|
||||
" scripts_dir.mkdir(parents=True, exist_ok=True)\n",
|
||||
" import requests\n",
|
||||
" with requests.get(\"https://raw.githubusercontent.com/microsoft/qlib/main/scripts/get_data.py\") as resp:\n",
|
||||
" with open(scripts_dir.joinpath(\"get_data.py\"), \"wb\") as fp:\n",
|
||||
" fp.write(resp.content)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import sys\n",
|
||||
"from pathlib import Path\n",
|
||||
"\n",
|
||||
"import qlib\n",
|
||||
"import pandas as pd\n",
|
||||
@@ -32,7 +68,7 @@
|
||||
"from qlib.utils import exists_qlib_data, init_instance_by_config\n",
|
||||
"from qlib.workflow import R\n",
|
||||
"from qlib.workflow.record_temp import SignalRecord, PortAnaRecord\n",
|
||||
"from qlib.utils import flatten_dict"
|
||||
"from qlib.utils import flatten_dict\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -48,7 +84,7 @@
|
||||
"provider_uri = \"~/.qlib/qlib_data/cn_data\" # target_dir\n",
|
||||
"if not exists_qlib_data(provider_uri):\n",
|
||||
" print(f\"Qlib data is not found in {provider_uri}\")\n",
|
||||
" sys.path.append(str(Path.cwd().parent.joinpath(\"scripts\")))\n",
|
||||
" sys.path.append(str(scripts_dir))\n",
|
||||
" from get_data import GetData\n",
|
||||
" GetData().qlib_data(target_dir=provider_uri, region=REG_CN)\n",
|
||||
"qlib.init(provider_uri=provider_uri, region=REG_CN)"
|
||||
@@ -202,7 +238,9 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"scrolled": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from qlib.contrib.report import analysis_model, analysis_position\n",
|
||||
@@ -320,7 +358,8 @@
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3"
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.7.9"
|
||||
},
|
||||
"toc": {
|
||||
"base_numbering": 1,
|
||||
|
||||
@@ -1,138 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import qlib
|
||||
import pandas as pd
|
||||
from qlib.config import REG_CN
|
||||
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
|
||||
from qlib.contrib.evaluate import (
|
||||
backtest as normal_backtest,
|
||||
risk_analysis,
|
||||
)
|
||||
from qlib.utils import exists_qlib_data
|
||||
from qlib.utils import init_instance_by_config
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
# use default data
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
|
||||
if not exists_qlib_data(provider_uri):
|
||||
print(f"Qlib data is not found in {provider_uri}")
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
|
||||
from get_data import GetData
|
||||
|
||||
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
|
||||
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
|
||||
MARKET = "csi300"
|
||||
BENCHMARK = "SH000300"
|
||||
|
||||
###################################
|
||||
# train model
|
||||
###################################
|
||||
DATA_HANDLER_CONFIG = {
|
||||
"start_time": "2008-01-01",
|
||||
"end_time": "2020-08-01",
|
||||
"fit_start_time": "2008-01-01",
|
||||
"fit_end_time": "2014-12-31",
|
||||
"instruments": MARKET,
|
||||
}
|
||||
|
||||
TRAINER_CONFIG = {
|
||||
"train_start_time": "2008-01-01",
|
||||
"train_end_time": "2014-12-31",
|
||||
"validate_start_time": "2015-01-01",
|
||||
"validate_end_time": "2016-12-31",
|
||||
"test_start_time": "2017-01-01",
|
||||
"test_end_time": "2020-08-01",
|
||||
}
|
||||
|
||||
task = {
|
||||
"model": {
|
||||
"class": "ALSTM",
|
||||
"module_path": "qlib.contrib.model.pytorch_alstm",
|
||||
"kwargs": {
|
||||
"d_feat": 6,
|
||||
"hidden_size": 64,
|
||||
"num_layers": 2,
|
||||
"dropout": 0.0,
|
||||
"n_epochs": 200,
|
||||
"lr": 1e-3,
|
||||
"early_stop": 20,
|
||||
"batch_size": 800,
|
||||
"metric": "IC",
|
||||
"loss": "mse",
|
||||
"seed": 0,
|
||||
"GPU": "0",
|
||||
"rnn_type": "GRU",
|
||||
},
|
||||
},
|
||||
"dataset": {
|
||||
"class": "DatasetH",
|
||||
"module_path": "qlib.data.dataset",
|
||||
"kwargs": {
|
||||
"handler": {
|
||||
"class": "ALPHA360_Denoise",
|
||||
"module_path": "qlib.contrib.data.handler",
|
||||
"kwargs": DATA_HANDLER_CONFIG,
|
||||
},
|
||||
"segments": {
|
||||
"train": ("2008-01-01", "2014-12-31"),
|
||||
"valid": ("2015-01-01", "2016-12-31"),
|
||||
"test": ("2017-01-01", "2020-08-01"),
|
||||
},
|
||||
},
|
||||
}
|
||||
# You shoud record the data in specific sequence
|
||||
# "record": ['SignalRecord', 'SigAnaRecord', 'PortAnaRecord'],
|
||||
}
|
||||
|
||||
model = init_instance_by_config(task["model"])
|
||||
dataset = init_instance_by_config(task["dataset"])
|
||||
model.fit(dataset)
|
||||
|
||||
pred_score = model.predict(dataset)
|
||||
|
||||
# save pred_score to file
|
||||
pred_score_path = Path("~/tmp/qlib/pred_score.pkl").expanduser()
|
||||
pred_score_path.parent.mkdir(exist_ok=True, parents=True)
|
||||
pred_score.to_pickle(pred_score_path)
|
||||
|
||||
###################################
|
||||
# backtest
|
||||
###################################
|
||||
STRATEGY_CONFIG = {
|
||||
"topk": 50,
|
||||
"n_drop": 5,
|
||||
}
|
||||
BACKTEST_CONFIG = {
|
||||
"verbose": False,
|
||||
"limit_threshold": 0.095,
|
||||
"account": 100000000,
|
||||
"benchmark": BENCHMARK,
|
||||
"deal_price": "close",
|
||||
"open_cost": 0.0005,
|
||||
"close_cost": 0.0015,
|
||||
"min_cost": 5,
|
||||
}
|
||||
|
||||
# use default strategy
|
||||
# custom Strategy, refer to: TODO: Strategy API url
|
||||
strategy = TopkDropoutStrategy(**STRATEGY_CONFIG)
|
||||
report_normal, positions_normal = normal_backtest(pred_score, strategy=strategy, **BACKTEST_CONFIG)
|
||||
|
||||
###################################
|
||||
# analyze
|
||||
# If need a more detailed analysis, refer to: examples/train_and_bakctest.ipynb
|
||||
###################################
|
||||
analysis = dict()
|
||||
analysis["excess_return_without_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"])
|
||||
analysis["excess_return_with_cost"] = risk_analysis(
|
||||
report_normal["return"] - report_normal["bench"] - report_normal["cost"]
|
||||
)
|
||||
analysis_df = pd.concat(analysis) # type: pd.DataFrame
|
||||
print(analysis_df)
|
||||
@@ -1,140 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import qlib
|
||||
import pandas as pd
|
||||
from qlib.config import REG_CN
|
||||
|
||||
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
|
||||
from qlib.contrib.evaluate import (
|
||||
backtest as normal_backtest,
|
||||
risk_analysis,
|
||||
)
|
||||
from qlib.utils import exists_qlib_data
|
||||
from qlib.utils import init_instance_by_config
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
# use default data
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
|
||||
if not exists_qlib_data(provider_uri):
|
||||
print(f"Qlib data is not found in {provider_uri}")
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
|
||||
from get_data import GetData
|
||||
|
||||
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
|
||||
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
|
||||
MARKET = "csi300"
|
||||
BENCHMARK = "SH000300"
|
||||
|
||||
###################################
|
||||
# train model
|
||||
###################################
|
||||
DATA_HANDLER_CONFIG = {
|
||||
"start_time": "2008-01-01",
|
||||
"end_time": "2020-08-01",
|
||||
"fit_start_time": "2008-01-01",
|
||||
"fit_end_time": "2014-12-31",
|
||||
"instruments": MARKET,
|
||||
}
|
||||
|
||||
TRAINER_CONFIG = {
|
||||
"train_start_time": "2008-01-01",
|
||||
"train_end_time": "2014-12-31",
|
||||
"validate_start_time": "2015-01-01",
|
||||
"validate_end_time": "2016-12-31",
|
||||
"test_start_time": "2017-01-01",
|
||||
"test_end_time": "2020-08-01",
|
||||
}
|
||||
|
||||
task = {
|
||||
"model": {
|
||||
"class": "GAT",
|
||||
"module_path": "qlib.contrib.model.pytorch_gats",
|
||||
"kwargs": {
|
||||
"d_feat": 6,
|
||||
"hidden_size": 64,
|
||||
"num_layers": 2,
|
||||
"dropout": 0.7,
|
||||
"n_epochs": 200,
|
||||
"lr": 1e-4,
|
||||
"early_stop": 20,
|
||||
"metric": "loss",
|
||||
"loss": "mse",
|
||||
"base_model": "LSTM",
|
||||
"with_pretrain": True,
|
||||
"seed": 0,
|
||||
"GPU": "0",
|
||||
},
|
||||
},
|
||||
"dataset": {
|
||||
"class": "DatasetH",
|
||||
"module_path": "qlib.data.dataset",
|
||||
"kwargs": {
|
||||
"handler": {
|
||||
"class": "ALPHA360_Denoise",
|
||||
"module_path": "qlib.contrib.data.handler",
|
||||
"kwargs": DATA_HANDLER_CONFIG,
|
||||
},
|
||||
"segments": {
|
||||
"train": ("2008-01-01", "2014-12-31"),
|
||||
"valid": ("2015-01-01", "2016-12-31"),
|
||||
"test": ("2017-01-01", "2020-08-01"),
|
||||
},
|
||||
},
|
||||
}
|
||||
# You shoud record the data in specific sequence
|
||||
# "record": ['SignalRecord', 'SigAnaRecord', 'PortAnaRecord'],
|
||||
}
|
||||
|
||||
model = init_instance_by_config(task["model"])
|
||||
dataset = init_instance_by_config(task["dataset"])
|
||||
model.fit(dataset)
|
||||
|
||||
pred_score = model.predict(dataset)
|
||||
|
||||
# save pred_score to file
|
||||
pred_score_path = Path("~/tmp/qlib/pred_score.pkl").expanduser()
|
||||
pred_score_path.parent.mkdir(exist_ok=True, parents=True)
|
||||
pred_score.to_pickle(pred_score_path)
|
||||
|
||||
###################################
|
||||
# backtest
|
||||
###################################
|
||||
STRATEGY_CONFIG = {
|
||||
"topk": 50,
|
||||
"n_drop": 5,
|
||||
}
|
||||
BACKTEST_CONFIG = {
|
||||
"verbose": False,
|
||||
"limit_threshold": 0.095,
|
||||
"account": 100000000,
|
||||
"benchmark": BENCHMARK,
|
||||
"deal_price": "close",
|
||||
"open_cost": 0.0005,
|
||||
"close_cost": 0.0015,
|
||||
"min_cost": 5,
|
||||
}
|
||||
|
||||
# use default strategy
|
||||
# custom Strategy, refer to: TODO: Strategy API url
|
||||
strategy = TopkDropoutStrategy(**STRATEGY_CONFIG)
|
||||
report_normal, positions_normal = normal_backtest(pred_score, strategy=strategy, **BACKTEST_CONFIG)
|
||||
|
||||
###################################
|
||||
# analyze
|
||||
# If need a more detailed analysis, refer to: examples/train_and_bakctest.ipynb
|
||||
###################################
|
||||
analysis = dict()
|
||||
analysis["excess_return_without_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"])
|
||||
analysis["excess_return_with_cost"] = risk_analysis(
|
||||
report_normal["return"] - report_normal["bench"] - report_normal["cost"]
|
||||
)
|
||||
analysis_df = pd.concat(analysis) # type: pd.DataFrame
|
||||
print(analysis_df)
|
||||
@@ -1,144 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import qlib
|
||||
import pandas as pd
|
||||
from qlib.config import REG_CN
|
||||
from qlib.contrib.model.pytorch_gru import GRU
|
||||
from qlib.contrib.data.handler import ALPHA360_Denoise
|
||||
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
|
||||
from qlib.contrib.evaluate import (
|
||||
backtest as normal_backtest,
|
||||
risk_analysis,
|
||||
)
|
||||
from qlib.utils import exists_qlib_data
|
||||
|
||||
# from qlib.model.learner import train_model
|
||||
from qlib.utils import init_instance_by_config
|
||||
|
||||
import pickle
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
# use default data
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
|
||||
if not exists_qlib_data(provider_uri):
|
||||
print(f"Qlib data is not found in {provider_uri}")
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
|
||||
from get_data import GetData
|
||||
|
||||
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
|
||||
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
|
||||
MARKET = "csi300"
|
||||
BENCHMARK = "SH000300"
|
||||
|
||||
###################################
|
||||
# train model
|
||||
###################################
|
||||
DATA_HANDLER_CONFIG = {
|
||||
"start_time": "2008-01-01",
|
||||
"end_time": "2020-08-01",
|
||||
"fit_start_time": "2008-01-01",
|
||||
"fit_end_time": "2014-12-31",
|
||||
"instruments": MARKET,
|
||||
}
|
||||
|
||||
TRAINER_CONFIG = {
|
||||
"train_start_time": "2008-01-01",
|
||||
"train_end_time": "2014-12-31",
|
||||
"validate_start_time": "2015-01-01",
|
||||
"validate_end_time": "2016-12-31",
|
||||
"test_start_time": "2017-01-01",
|
||||
"test_end_time": "2020-08-01",
|
||||
}
|
||||
|
||||
task = {
|
||||
"model": {
|
||||
"class": "GRU",
|
||||
"module_path": "qlib.contrib.model.pytorch_gru",
|
||||
"kwargs": {
|
||||
"d_feat": 6,
|
||||
"hidden_size": 64,
|
||||
"num_layers": 2,
|
||||
"dropout": 0.0,
|
||||
"n_epochs": 200,
|
||||
"lr": 1e-3,
|
||||
"early_stop": 20,
|
||||
"batch_size": 800,
|
||||
"metric": "loss",
|
||||
"loss": "mse",
|
||||
"seed": 0,
|
||||
"GPU": 0,
|
||||
},
|
||||
},
|
||||
"dataset": {
|
||||
"class": "DatasetH",
|
||||
"module_path": "qlib.data.dataset",
|
||||
"kwargs": {
|
||||
"handler": {
|
||||
"class": "ALPHA360_Denoise",
|
||||
"module_path": "qlib.contrib.data.handler",
|
||||
"kwargs": DATA_HANDLER_CONFIG,
|
||||
},
|
||||
"segments": {
|
||||
"train": ("2008-01-01", "2014-12-31"),
|
||||
"valid": ("2015-01-01", "2016-12-31"),
|
||||
"test": ("2017-01-01", "2020-08-01"),
|
||||
},
|
||||
},
|
||||
}
|
||||
# You shoud record the data in specific sequence
|
||||
# "record": ['SignalRecord', 'SigAnaRecord', 'PortAnaRecord'],
|
||||
}
|
||||
|
||||
# model = train_model(task)
|
||||
model = init_instance_by_config(task["model"])
|
||||
dataset = init_instance_by_config(task["dataset"])
|
||||
model.fit(dataset)
|
||||
|
||||
pred_score = model.predict(dataset)
|
||||
|
||||
# save pred_score to file
|
||||
pred_score_path = Path("~/tmp/qlib/pred_score.pkl").expanduser()
|
||||
pred_score_path.parent.mkdir(exist_ok=True, parents=True)
|
||||
pred_score.to_pickle(pred_score_path)
|
||||
|
||||
###################################
|
||||
# backtest
|
||||
###################################
|
||||
STRATEGY_CONFIG = {
|
||||
"topk": 50,
|
||||
"n_drop": 5,
|
||||
}
|
||||
BACKTEST_CONFIG = {
|
||||
"verbose": False,
|
||||
"limit_threshold": 0.095,
|
||||
"account": 100000000,
|
||||
"benchmark": BENCHMARK,
|
||||
"deal_price": "close",
|
||||
"open_cost": 0.0005,
|
||||
"close_cost": 0.0015,
|
||||
"min_cost": 5,
|
||||
}
|
||||
|
||||
# use default strategy
|
||||
# custom Strategy, refer to: TODO: Strategy API url
|
||||
strategy = TopkDropoutStrategy(**STRATEGY_CONFIG)
|
||||
report_normal, positions_normal = normal_backtest(pred_score, strategy=strategy, **BACKTEST_CONFIG)
|
||||
|
||||
###################################
|
||||
# analyze
|
||||
# If need a more detailed analysis, refer to: examples/train_and_bakctest.ipynb
|
||||
###################################
|
||||
analysis = dict()
|
||||
analysis["excess_return_without_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"])
|
||||
analysis["excess_return_with_cost"] = risk_analysis(
|
||||
report_normal["return"] - report_normal["bench"] - report_normal["cost"]
|
||||
)
|
||||
analysis_df = pd.concat(analysis) # type: pd.DataFrame
|
||||
print(analysis_df)
|
||||
@@ -1,136 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import qlib
|
||||
import pandas as pd
|
||||
from qlib.config import REG_CN
|
||||
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
|
||||
from qlib.contrib.evaluate import (
|
||||
backtest as normal_backtest,
|
||||
risk_analysis,
|
||||
)
|
||||
from qlib.utils import exists_qlib_data
|
||||
from qlib.utils import init_instance_by_config
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
# use default data
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
|
||||
if not exists_qlib_data(provider_uri):
|
||||
print(f"Qlib data is not found in {provider_uri}")
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
|
||||
from get_data import GetData
|
||||
|
||||
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
|
||||
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
|
||||
MARKET = "csi300"
|
||||
BENCHMARK = "SH000300"
|
||||
|
||||
###################################
|
||||
# train model
|
||||
###################################
|
||||
DATA_HANDLER_CONFIG = {
|
||||
"start_time": "2008-01-01",
|
||||
"end_time": "2020-08-01",
|
||||
"fit_start_time": "2008-01-01",
|
||||
"fit_end_time": "2014-12-31",
|
||||
"instruments": MARKET,
|
||||
}
|
||||
|
||||
TRAINER_CONFIG = {
|
||||
"train_start_time": "2008-01-01",
|
||||
"train_end_time": "2014-12-31",
|
||||
"validate_start_time": "2015-01-01",
|
||||
"validate_end_time": "2016-12-31",
|
||||
"test_start_time": "2017-01-01",
|
||||
"test_end_time": "2020-08-01",
|
||||
}
|
||||
|
||||
task = {
|
||||
"model": {
|
||||
"class": "HATS",
|
||||
"module_path": "qlib.contrib.model.pytorch_hats",
|
||||
"kwargs": {
|
||||
"d_feat": 6,
|
||||
"hidden_size": 64,
|
||||
"num_layers": 2,
|
||||
"dropout": 0.7,
|
||||
"n_epochs": 200,
|
||||
"lr": 1e-4,
|
||||
"early_stop": 20,
|
||||
"metric": "loss",
|
||||
"loss": "mse",
|
||||
"base_model": "LSTM",
|
||||
"seed": 0,
|
||||
"GPU": "2",
|
||||
},
|
||||
},
|
||||
"dataset": {
|
||||
"class": "DatasetH",
|
||||
"module_path": "qlib.data.dataset",
|
||||
"kwargs": {
|
||||
"handler": {
|
||||
"class": "ALPHA360_Denoise",
|
||||
"module_path": "qlib.contrib.data.handler",
|
||||
"kwargs": DATA_HANDLER_CONFIG,
|
||||
},
|
||||
"segments": {
|
||||
"train": ("2008-01-01", "2014-12-31"),
|
||||
"valid": ("2015-01-01", "2016-12-31"),
|
||||
"test": ("2017-01-01", "2020-08-01"),
|
||||
},
|
||||
},
|
||||
}
|
||||
# You shoud record the data in specific sequence
|
||||
# "record": ['SignalRecord', 'SigAnaRecord', 'PortAnaRecord'],
|
||||
}
|
||||
|
||||
model = init_instance_by_config(task["model"])
|
||||
dataset = init_instance_by_config(task["dataset"])
|
||||
model.fit(dataset, save_path="benchmarks/HATS/model_hat.pkl")
|
||||
|
||||
pred_score = model.predict(dataset)
|
||||
|
||||
# save pred_score to file
|
||||
pred_score_path = Path("~/tmp/qlib/pred_score.pkl").expanduser()
|
||||
pred_score_path.parent.mkdir(exist_ok=True, parents=True)
|
||||
pred_score.to_pickle(pred_score_path)
|
||||
|
||||
###################################
|
||||
# backtest
|
||||
###################################
|
||||
STRATEGY_CONFIG = {
|
||||
"topk": 50,
|
||||
"n_drop": 5,
|
||||
}
|
||||
BACKTEST_CONFIG = {
|
||||
"verbose": False,
|
||||
"limit_threshold": 0.095,
|
||||
"account": 100000000,
|
||||
"benchmark": BENCHMARK,
|
||||
"deal_price": "close",
|
||||
"open_cost": 0.0005,
|
||||
"close_cost": 0.0015,
|
||||
"min_cost": 5,
|
||||
}
|
||||
|
||||
# use default strategy
|
||||
# custom Strategy, refer to: TODO: Strategy API url
|
||||
strategy = TopkDropoutStrategy(**STRATEGY_CONFIG)
|
||||
report_normal, positions_normal = normal_backtest(pred_score, strategy=strategy, **BACKTEST_CONFIG)
|
||||
|
||||
###################################
|
||||
# analyze
|
||||
# If need a more detailed analysis, refer to: examples/train_and_bakctest.ipynb
|
||||
###################################
|
||||
analysis = dict()
|
||||
analysis["excess_return_without_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"])
|
||||
analysis["excess_return_with_cost"] = risk_analysis(
|
||||
report_normal["return"] - report_normal["bench"] - report_normal["cost"]
|
||||
)
|
||||
analysis_df = pd.concat(analysis) # type: pd.DataFrame
|
||||
print(analysis_df)
|
||||
@@ -1,144 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import qlib
|
||||
import pandas as pd
|
||||
from qlib.config import REG_CN
|
||||
from qlib.contrib.model.pytorch_lstm import LSTM
|
||||
from qlib.contrib.data.handler import ALPHA360_Denoise
|
||||
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
|
||||
from qlib.contrib.evaluate import (
|
||||
backtest as normal_backtest,
|
||||
risk_analysis,
|
||||
)
|
||||
from qlib.utils import exists_qlib_data
|
||||
|
||||
# from qlib.model.learner import train_model
|
||||
from qlib.utils import init_instance_by_config
|
||||
|
||||
import pickle
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
# use default data
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
|
||||
if not exists_qlib_data(provider_uri):
|
||||
print(f"Qlib data is not found in {provider_uri}")
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
|
||||
from get_data import GetData
|
||||
|
||||
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
|
||||
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
|
||||
MARKET = "csi300"
|
||||
BENCHMARK = "SH000300"
|
||||
|
||||
###################################
|
||||
# train model
|
||||
###################################
|
||||
DATA_HANDLER_CONFIG = {
|
||||
"start_time": "2008-01-01",
|
||||
"end_time": "2020-08-01",
|
||||
"fit_start_time": "2008-01-01",
|
||||
"fit_end_time": "2014-12-31",
|
||||
"instruments": MARKET,
|
||||
}
|
||||
|
||||
TRAINER_CONFIG = {
|
||||
"train_start_time": "2008-01-01",
|
||||
"train_end_time": "2014-12-31",
|
||||
"validate_start_time": "2015-01-01",
|
||||
"validate_end_time": "2016-12-31",
|
||||
"test_start_time": "2017-01-01",
|
||||
"test_end_time": "2020-08-01",
|
||||
}
|
||||
|
||||
task = {
|
||||
"model": {
|
||||
"class": "LSTM",
|
||||
"module_path": "qlib.contrib.model.pytorch_lstm",
|
||||
"kwargs": {
|
||||
"d_feat": 6,
|
||||
"hidden_size": 64,
|
||||
"num_layers": 2,
|
||||
"dropout": 0.0,
|
||||
"n_epochs": 200,
|
||||
"lr": 1e-3,
|
||||
"early_stop": 20,
|
||||
"batch_size": 800,
|
||||
"metric": "IC",
|
||||
"loss": "mse",
|
||||
"seed": 0,
|
||||
"GPU": 0,
|
||||
},
|
||||
},
|
||||
"dataset": {
|
||||
"class": "DatasetH",
|
||||
"module_path": "qlib.data.dataset",
|
||||
"kwargs": {
|
||||
"handler": {
|
||||
"class": "ALPHA360_Denoise",
|
||||
"module_path": "qlib.contrib.data.handler",
|
||||
"kwargs": DATA_HANDLER_CONFIG,
|
||||
},
|
||||
"segments": {
|
||||
"train": ("2008-01-01", "2014-12-31"),
|
||||
"valid": ("2015-01-01", "2016-12-31"),
|
||||
"test": ("2017-01-01", "2020-08-01"),
|
||||
},
|
||||
},
|
||||
}
|
||||
# You shoud record the data in specific sequence
|
||||
# "record": ['SignalRecord', 'SigAnaRecord', 'PortAnaRecord'],
|
||||
}
|
||||
|
||||
# model = train_model(task)
|
||||
model = init_instance_by_config(task["model"])
|
||||
dataset = init_instance_by_config(task["dataset"])
|
||||
model.fit(dataset)
|
||||
|
||||
pred_score = model.predict(dataset)
|
||||
|
||||
# save pred_score to file
|
||||
pred_score_path = Path("~/tmp/qlib/pred_score.pkl").expanduser()
|
||||
pred_score_path.parent.mkdir(exist_ok=True, parents=True)
|
||||
pred_score.to_pickle(pred_score_path)
|
||||
|
||||
###################################
|
||||
# backtest
|
||||
###################################
|
||||
STRATEGY_CONFIG = {
|
||||
"topk": 50,
|
||||
"n_drop": 5,
|
||||
}
|
||||
BACKTEST_CONFIG = {
|
||||
"verbose": False,
|
||||
"limit_threshold": 0.095,
|
||||
"account": 100000000,
|
||||
"benchmark": BENCHMARK,
|
||||
"deal_price": "close",
|
||||
"open_cost": 0.0005,
|
||||
"close_cost": 0.0015,
|
||||
"min_cost": 5,
|
||||
}
|
||||
|
||||
# use default strategy
|
||||
# custom Strategy, refer to: TODO: Strategy API url
|
||||
strategy = TopkDropoutStrategy(**STRATEGY_CONFIG)
|
||||
report_normal, positions_normal = normal_backtest(pred_score, strategy=strategy, **BACKTEST_CONFIG)
|
||||
|
||||
###################################
|
||||
# analyze
|
||||
# If need a more detailed analysis, refer to: examples/train_and_bakctest.ipynb
|
||||
###################################
|
||||
analysis = dict()
|
||||
analysis["excess_return_without_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"])
|
||||
analysis["excess_return_with_cost"] = risk_analysis(
|
||||
report_normal["return"] - report_normal["bench"] - report_normal["cost"]
|
||||
)
|
||||
analysis_df = pd.concat(analysis) # type: pd.DataFrame
|
||||
print(analysis_df)
|
||||
@@ -1,158 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import qlib
|
||||
import pandas as pd
|
||||
from qlib.config import REG_CN
|
||||
from qlib.contrib.model.pytorch_gru import GRU
|
||||
from qlib.contrib.data.handler import ALPHA360_Denoise
|
||||
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
|
||||
from qlib.contrib.evaluate import (
|
||||
backtest as normal_backtest,
|
||||
risk_analysis,
|
||||
)
|
||||
from qlib.utils import exists_qlib_data
|
||||
from qlib.utils import init_instance_by_config
|
||||
|
||||
import pickle
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
# use default data
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
|
||||
if not exists_qlib_data(provider_uri):
|
||||
print(f"Qlib data is not found in {provider_uri}")
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
|
||||
from get_data import GetData
|
||||
|
||||
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
|
||||
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
|
||||
MARKET = "csi300"
|
||||
BENCHMARK = "SH000300"
|
||||
|
||||
###################################
|
||||
# train model
|
||||
###################################
|
||||
DATA_HANDLER_CONFIG = {
|
||||
"start_time": "2008-01-01",
|
||||
"end_time": "2020-08-01",
|
||||
"fit_start_time": "2008-01-01",
|
||||
"fit_end_time": "2014-12-31",
|
||||
"instruments": MARKET,
|
||||
}
|
||||
|
||||
TRAINER_CONFIG = {
|
||||
"train_start_time": "2008-01-01",
|
||||
"train_end_time": "2014-12-31",
|
||||
"validate_start_time": "2015-01-01",
|
||||
"validate_end_time": "2016-12-31",
|
||||
"test_start_time": "2017-01-01",
|
||||
"test_end_time": "2020-08-01",
|
||||
}
|
||||
|
||||
task = {
|
||||
"model": {
|
||||
"class": "SFM",
|
||||
"module_path": "qlib.contrib.model.pytorch_sfm",
|
||||
"kwargs": {
|
||||
"d_feat": 6,
|
||||
"hidden_size": 64,
|
||||
"output_dim": 32,
|
||||
"freq_dim": 25,
|
||||
"dropout_W": 0.5,
|
||||
"dropout_U": 0.5,
|
||||
"n_epochs": 15,
|
||||
"lr": 1e-3,
|
||||
"metric": "",
|
||||
"batch_size": 1600,
|
||||
"early_stop": 20,
|
||||
"eval_steps": 5,
|
||||
"loss": "mse",
|
||||
"lr_decay": 0.96,
|
||||
"lr_decay_steps": 100,
|
||||
"optimizer": "adam",
|
||||
"GPU": 3,
|
||||
"seed": 710,
|
||||
},
|
||||
},
|
||||
"dataset": {
|
||||
"class": "DatasetH",
|
||||
"module_path": "qlib.data.dataset",
|
||||
"kwargs": {
|
||||
"handler": {
|
||||
"class": "ALPHA360_Denoise",
|
||||
"module_path": "qlib.contrib.data.handler",
|
||||
"kwargs": DATA_HANDLER_CONFIG,
|
||||
},
|
||||
"segments": {
|
||||
"train": ("2008-01-01", "2014-12-31"),
|
||||
"valid": ("2015-01-01", "2016-12-31"),
|
||||
"test": ("2017-01-01", "2020-08-01"),
|
||||
},
|
||||
},
|
||||
}
|
||||
# You shoud record the data in specific sequence
|
||||
# "record": ['SignalRecord', 'SigAnaRecord', 'PortAnaRecord'],
|
||||
}
|
||||
|
||||
# model = train_model(task)
|
||||
model = init_instance_by_config(task["model"])
|
||||
dataset = init_instance_by_config(task["dataset"])
|
||||
model.fit(dataset)
|
||||
|
||||
pred_score = model.predict(dataset)
|
||||
|
||||
# save pred_score to file
|
||||
pred_score_path = Path("~/tmp/qlib/pred_score.pkl").expanduser()
|
||||
pred_score_path.parent.mkdir(exist_ok=True, parents=True)
|
||||
pred_score.to_pickle(pred_score_path)
|
||||
|
||||
###################################
|
||||
# backtest
|
||||
###################################
|
||||
STRATEGY_CONFIG = {
|
||||
"topk": 50,
|
||||
"n_drop": 5,
|
||||
}
|
||||
BACKTEST_CONFIG = {
|
||||
"verbose": False,
|
||||
"limit_threshold": 0.095,
|
||||
"account": 100000000,
|
||||
"benchmark": BENCHMARK,
|
||||
"deal_price": "close",
|
||||
"open_cost": 0.0005,
|
||||
"close_cost": 0.0015,
|
||||
"min_cost": 5,
|
||||
}
|
||||
|
||||
# use default strategy
|
||||
# custom Strategy, refer to: TODO: Strategy API url
|
||||
strategy = TopkDropoutStrategy(**STRATEGY_CONFIG)
|
||||
report_normal, positions_normal = normal_backtest(pred_score, strategy=strategy, **BACKTEST_CONFIG)
|
||||
|
||||
###################################
|
||||
# analyze
|
||||
# If need a more detailed analysis, refer to: examples/train_and_bakctest.ipynb
|
||||
###################################
|
||||
analysis = dict()
|
||||
analysis["excess_return_without_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"])
|
||||
analysis["excess_return_with_cost"] = risk_analysis(
|
||||
report_normal["return"] - report_normal["bench"] - report_normal["cost"]
|
||||
)
|
||||
analysis_df = pd.concat(analysis) # type: pd.DataFrame
|
||||
print(analysis_df)
|
||||
@@ -1,142 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import qlib
|
||||
import pandas as pd
|
||||
from qlib.config import REG_CN
|
||||
from qlib.contrib.model.tabnet import TabNetModel
|
||||
from qlib.contrib.data.handler import ALPHA360_Denoise
|
||||
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
|
||||
from qlib.contrib.evaluate import (
|
||||
backtest as normal_backtest,
|
||||
risk_analysis,
|
||||
)
|
||||
from qlib.utils import exists_qlib_data
|
||||
|
||||
# from qlib.model.learner import train_model
|
||||
from qlib.utils import init_instance_by_config
|
||||
|
||||
import pickle
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
# use default data
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
|
||||
if not exists_qlib_data(provider_uri):
|
||||
print(f"Qlib data is not found in {provider_uri}")
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
|
||||
from get_data import GetData
|
||||
|
||||
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
|
||||
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
|
||||
MARKET = "csi300"
|
||||
BENCHMARK = "SH000300"
|
||||
|
||||
###################################
|
||||
# train model
|
||||
###################################
|
||||
DATA_HANDLER_CONFIG = {
|
||||
"start_time": "2008-01-01",
|
||||
"end_time": "2020-08-01",
|
||||
"fit_start_time": "2008-01-01",
|
||||
"fit_end_time": "2014-12-31",
|
||||
"instruments": MARKET,
|
||||
}
|
||||
|
||||
TRAINER_CONFIG = {
|
||||
"train_start_time": "2008-01-01",
|
||||
"train_end_time": "2014-12-31",
|
||||
"validate_start_time": "2015-01-01",
|
||||
"validate_end_time": "2016-12-31",
|
||||
"test_start_time": "2017-01-01",
|
||||
"test_end_time": "2020-08-01",
|
||||
}
|
||||
|
||||
task = {
|
||||
"model": {
|
||||
"class": "TabNetModel",
|
||||
"module_path": "qlib.contrib.model.tabnet",
|
||||
"kwargs": {
|
||||
"n_d": 8,
|
||||
"n_a": 8,
|
||||
"n_steps": 3,
|
||||
"gamma": 1.3,
|
||||
"n_independent": 2,
|
||||
"n_shared": 2,
|
||||
"seed": 0,
|
||||
"momentum": 0.02,
|
||||
"lambda_sparse": 1e-3,
|
||||
"optimizer_params": {"lr": 2e-3},
|
||||
},
|
||||
},
|
||||
"dataset": {
|
||||
"class": "DatasetH",
|
||||
"module_path": "qlib.data.dataset",
|
||||
"kwargs": {
|
||||
"handler": {
|
||||
"class": "ALPHA360_Denoise",
|
||||
"module_path": "qlib.contrib.data.handler",
|
||||
"kwargs": DATA_HANDLER_CONFIG,
|
||||
},
|
||||
"segments": {
|
||||
"train": ("2008-01-01", "2014-12-31"),
|
||||
"valid": ("2015-01-01", "2016-12-31"),
|
||||
"test": ("2017-01-01", "2020-08-01"),
|
||||
},
|
||||
},
|
||||
}
|
||||
# You shoud record the data in specific sequence
|
||||
# "record": ['SignalRecord', 'SigAnaRecord', 'PortAnaRecord'],
|
||||
}
|
||||
|
||||
# model = train_model(task)
|
||||
model = init_instance_by_config(task["model"])
|
||||
dataset = init_instance_by_config(task["dataset"])
|
||||
model.fit(dataset)
|
||||
|
||||
pred_score = model.predict(dataset)
|
||||
|
||||
# save pred_score to file
|
||||
pred_score_path = Path("~/tmp/qlib/pred_score.pkl").expanduser()
|
||||
pred_score_path.parent.mkdir(exist_ok=True, parents=True)
|
||||
pred_score.to_pickle(pred_score_path)
|
||||
|
||||
###################################
|
||||
# backtest
|
||||
###################################
|
||||
STRATEGY_CONFIG = {
|
||||
"topk": 50,
|
||||
"n_drop": 5,
|
||||
}
|
||||
BACKTEST_CONFIG = {
|
||||
"verbose": False,
|
||||
"limit_threshold": 0.095,
|
||||
"account": 100000000,
|
||||
"benchmark": BENCHMARK,
|
||||
"deal_price": "close",
|
||||
"open_cost": 0.0005,
|
||||
"close_cost": 0.0015,
|
||||
"min_cost": 5,
|
||||
}
|
||||
|
||||
# use default strategy
|
||||
# custom Strategy, refer to: TODO: Strategy API url
|
||||
strategy = TopkDropoutStrategy(**STRATEGY_CONFIG)
|
||||
report_normal, positions_normal = normal_backtest(pred_score, strategy=strategy, **BACKTEST_CONFIG)
|
||||
|
||||
###################################
|
||||
# analyze
|
||||
# If need a more detailed analysis, refer to: examples/train_and_bakctest.ipynb
|
||||
###################################
|
||||
analysis = dict()
|
||||
analysis["excess_return_without_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"])
|
||||
analysis["excess_return_with_cost"] = risk_analysis(
|
||||
report_normal["return"] - report_normal["bench"] - report_normal["cost"]
|
||||
)
|
||||
analysis_df = pd.concat(analysis) # type: pd.DataFrame
|
||||
print(analysis_df)
|
||||
@@ -41,7 +41,9 @@ class CatBoostModel(Model):
|
||||
**kwargs
|
||||
):
|
||||
df_train, df_valid = dataset.prepare(
|
||||
["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
|
||||
["train", "valid"],
|
||||
col_set=["feature", "label"],
|
||||
data_key=DataHandlerLP.DK_L,
|
||||
)
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
|
||||
@@ -11,7 +11,12 @@ import pandas as pd
|
||||
import copy
|
||||
from sklearn.metrics import roc_auc_score, mean_squared_error
|
||||
import logging
|
||||
from ...utils import unpack_archive_with_buffer, save_multiple_parts_file, create_save_path, drop_nan_by_y_index
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
create_save_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...log import get_module_logger, TimeInspector
|
||||
|
||||
import torch
|
||||
@@ -109,7 +114,10 @@ class ALSTM(Model):
|
||||
)
|
||||
|
||||
self.ALSTM_model = ALSTMModel(
|
||||
d_feat=self.d_feat, hidden_size=self.hidden_size, num_layers=self.num_layers, dropout=self.dropout
|
||||
d_feat=self.d_feat,
|
||||
hidden_size=self.hidden_size,
|
||||
num_layers=self.num_layers,
|
||||
dropout=self.dropout,
|
||||
)
|
||||
if optimizer.lower() == "adam":
|
||||
self.train_optimizer = optim.Adam(self.ALSTM_model.parameters(), lr=self.lr)
|
||||
@@ -141,7 +149,7 @@ class ALSTM(Model):
|
||||
|
||||
mask = torch.isfinite(label)
|
||||
|
||||
if self.metric == "" or self.metric == "loss": # use loss
|
||||
if self.metric == "" or self.metric == "loss":
|
||||
return -self.loss_fn(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown metric `%s`" % self.metric)
|
||||
@@ -219,7 +227,9 @@ class ALSTM(Model):
|
||||
):
|
||||
|
||||
df_train, df_valid, df_test = dataset.prepare(
|
||||
["train", "valid", "test"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
|
||||
["train", "valid", "test"],
|
||||
col_set=["feature", "label"],
|
||||
data_key=DataHandlerLP.DK_L,
|
||||
)
|
||||
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
@@ -328,10 +338,16 @@ class ALSTMModel(nn.Module):
|
||||
)
|
||||
self.fc_out = nn.Linear(in_features=self.hid_size * 2, out_features=1)
|
||||
self.att_net = nn.Sequential()
|
||||
self.att_net.add_module("att_fc_in", nn.Linear(in_features=self.hid_size, out_features=int(self.hid_size / 2)))
|
||||
self.att_net.add_module(
|
||||
"att_fc_in",
|
||||
nn.Linear(in_features=self.hid_size, out_features=int(self.hid_size / 2)),
|
||||
)
|
||||
self.att_net.add_module("att_dropout", torch.nn.Dropout(self.dropout))
|
||||
self.att_net.add_module("att_act", nn.Tanh())
|
||||
self.att_net.add_module("att_fc_out", nn.Linear(in_features=int(self.hid_size / 2), out_features=1, bias=False))
|
||||
self.att_net.add_module(
|
||||
"att_fc_out",
|
||||
nn.Linear(in_features=int(self.hid_size / 2), out_features=1, bias=False),
|
||||
)
|
||||
self.att_net.add_module("att_softmax", nn.Softmax(dim=1))
|
||||
|
||||
def forward(self, inputs):
|
||||
|
||||
79
qlib/contrib/model/pytorch_gats.py
Executable file → Normal file
79
qlib/contrib/model/pytorch_gats.py
Executable file → Normal file
@@ -12,6 +12,7 @@ import copy
|
||||
from ...utils import create_save_path
|
||||
from ...log import get_module_logger
|
||||
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
@@ -19,10 +20,12 @@ import torch.optim as optim
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
from ...contrib.model.pytorch_lstm import LSTMModel
|
||||
from ...contrib.model.pytorch_gru import GRUModel
|
||||
|
||||
|
||||
class GAT(Model):
|
||||
"""GAT Model
|
||||
class GATs(Model):
|
||||
"""GATs Model
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -57,8 +60,8 @@ class GAT(Model):
|
||||
**kwargs
|
||||
):
|
||||
# Set logger.
|
||||
self.logger = get_module_logger("GAT")
|
||||
self.logger.info("GAT pytorch version...")
|
||||
self.logger = get_module_logger("GATs")
|
||||
self.logger.info("GATs pytorch version...")
|
||||
|
||||
# set hyper-parameters.
|
||||
self.d_feat = d_feat
|
||||
@@ -78,7 +81,7 @@ class GAT(Model):
|
||||
self.seed = seed
|
||||
|
||||
self.logger.info(
|
||||
"GAT parameters setting:"
|
||||
"GATs parameters setting:"
|
||||
"\nd_feat : {}"
|
||||
"\nhidden_size : {}"
|
||||
"\nnum_layers : {}"
|
||||
@@ -149,18 +152,18 @@ class GAT(Model):
|
||||
|
||||
mask = torch.isfinite(label)
|
||||
|
||||
if self.metric == "" or self.metric == "loss": # use loss
|
||||
if self.metric == "" or self.metric == "loss":
|
||||
return -self.loss_fn(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown metric `%s`" % self.metric)
|
||||
|
||||
def get_daily_inter(self, df, shuffle=False):
|
||||
# organize the train data into daily inter as daily batches
|
||||
# organize the train data into daily batches
|
||||
daily_count = df.groupby(level=0).size().values
|
||||
daily_index = np.roll(np.cumsum(daily_count), 1)
|
||||
daily_index[0] = 0
|
||||
if shuffle:
|
||||
# shuffle the daily inter data
|
||||
# shuffle data
|
||||
daily_shuffle = list(zip(daily_index, daily_count))
|
||||
np.random.shuffle(daily_shuffle)
|
||||
daily_index, daily_count = zip(*daily_shuffle)
|
||||
@@ -172,7 +175,7 @@ class GAT(Model):
|
||||
y_train_values = np.squeeze(y_train.values)
|
||||
self.GAT_model.train()
|
||||
|
||||
# organize the train data into daily inter as daily batches
|
||||
# organize the train data into daily batches
|
||||
daily_index, daily_count = self.get_daily_inter(x_train, shuffle=True)
|
||||
|
||||
for idx, count in zip(daily_index, daily_count):
|
||||
@@ -203,7 +206,7 @@ class GAT(Model):
|
||||
scores = []
|
||||
losses = []
|
||||
|
||||
# organize the test data into daily inter as daily batches
|
||||
# organize the test data into daily batches
|
||||
daily_index, daily_count = self.get_daily_inter(data_x, shuffle=False)
|
||||
|
||||
for idx, count in zip(daily_index, daily_count):
|
||||
@@ -233,7 +236,9 @@ class GAT(Model):
|
||||
):
|
||||
|
||||
df_train, df_valid, df_test = dataset.prepare(
|
||||
["train", "valid", "test"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
|
||||
["train", "valid", "test"],
|
||||
col_set=["feature", "label"],
|
||||
data_key=DataHandlerLP.DK_L,
|
||||
)
|
||||
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
@@ -251,15 +256,13 @@ class GAT(Model):
|
||||
if self.with_pretrain:
|
||||
self.logger.info("Loading pretrained model...")
|
||||
if self.base_model == "LSTM":
|
||||
from ...contrib.model.pytorch_lstm import LSTMModel
|
||||
|
||||
pretrained_model = LSTMModel()
|
||||
pretrained_model.load_state_dict(torch.load("benchmarks/LSTM/model_lstm_csi300.pkl"))
|
||||
elif self.base_model == "GRU":
|
||||
from ...contrib.model.pytorch_gru import GRUModel
|
||||
|
||||
elif self.base_model == "GRU":
|
||||
pretrained_model = GRUModel()
|
||||
pretrained_model.load_state_dict(torch.load("benchmarks/GRU/model_gru_csi300.pkl"))
|
||||
|
||||
model_dict = self.GAT_model.state_dict()
|
||||
pretrained_dict = {k: v for k, v in pretrained_model.state_dict().items() if k in model_dict}
|
||||
model_dict.update(pretrained_dict)
|
||||
@@ -269,7 +272,6 @@ class GAT(Model):
|
||||
# train
|
||||
self.logger.info("training...")
|
||||
self._fitted = True
|
||||
# return
|
||||
|
||||
for step in range(self.n_epochs):
|
||||
self.logger.info("Epoch%d:", step)
|
||||
@@ -310,7 +312,7 @@ class GAT(Model):
|
||||
x_values = x_test.values
|
||||
preds = []
|
||||
|
||||
# organize the data into daily inter as daily batches
|
||||
# organize the data into daily batches
|
||||
daily_index, daily_count = self.get_daily_inter(x_test, shuffle=False)
|
||||
|
||||
for idx, count in zip(daily_index, daily_count):
|
||||
@@ -355,22 +357,29 @@ class GATModel(nn.Module):
|
||||
raise ValueError("unknown base model name `%s`" % base_model)
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.bn1 = nn.BatchNorm1d(num_features=hidden_size, track_running_stats=False)
|
||||
self.fc = nn.Linear(hidden_size, hidden_size)
|
||||
self.bn2 = nn.BatchNorm1d(num_features=hidden_size, track_running_stats=False)
|
||||
self.d_feat = d_feat
|
||||
self.transformation = nn.Linear(self.hidden_size, self.hidden_size)
|
||||
self.a = nn.Parameter(torch.randn(self.hidden_size * 2, 1))
|
||||
self.a.requires_grad = True
|
||||
self.fc = nn.Linear(self.hidden_size, self.hidden_size)
|
||||
self.fc_out = nn.Linear(hidden_size, 1)
|
||||
self.leaky_relu = nn.LeakyReLU()
|
||||
self.softmax = nn.Softmax(dim=1)
|
||||
self.d_feat = d_feat
|
||||
|
||||
def cal_convariance(self, x, y): # the 2nd dimension of x and y are the same
|
||||
e_x = torch.mean(x, dim=1).reshape(-1, 1)
|
||||
e_y = torch.mean(y, dim=1).reshape(-1, 1)
|
||||
e_x_e_y = e_x.mm(torch.t(e_y))
|
||||
x_extend = x.reshape(x.shape[0], 1, x.shape[1]).repeat(1, y.shape[0], 1)
|
||||
y_extend = y.reshape(1, y.shape[0], y.shape[1]).repeat(x.shape[0], 1, 1)
|
||||
e_xy = torch.mean(x_extend * y_extend, dim=2)
|
||||
return e_xy - e_x_e_y
|
||||
def cal_attention(self, x, y):
|
||||
x = self.transformation(x)
|
||||
y = self.transformation(y)
|
||||
|
||||
sample_num = x.shape[0]
|
||||
dim = x.shape[1]
|
||||
e_x = x.expand(sample_num, sample_num, dim)
|
||||
e_y = torch.transpose(e_x, 0, 1)
|
||||
attention_in = torch.cat((e_x, e_y), 2).view(-1, dim * 2)
|
||||
self.a_t = torch.t(self.a)
|
||||
attention_out = self.a_t.mm(torch.t(attention_in)).view(sample_num, sample_num)
|
||||
attention_out = self.leaky_relu(attention_out)
|
||||
att_weight = self.softmax(attention_out)
|
||||
return att_weight
|
||||
|
||||
def forward(self, x):
|
||||
# x: [N, F*T]
|
||||
@@ -378,10 +387,8 @@ class GATModel(nn.Module):
|
||||
x = x.permute(0, 2, 1) # [N, T, F]
|
||||
out, _ = self.rnn(x)
|
||||
hidden = out[:, -1, :]
|
||||
hidden = self.bn1(hidden)
|
||||
gamma = self.cal_convariance(hidden, hidden)
|
||||
output = gamma.mm(hidden)
|
||||
output = self.fc(output)
|
||||
output = self.bn2(output)
|
||||
output = self.leaky_relu(output)
|
||||
return self.fc_out(output).squeeze()
|
||||
att_weight = self.cal_attention(hidden, hidden)
|
||||
hidden = att_weight.mm(hidden) + hidden
|
||||
hidden = self.fc(hidden)
|
||||
hidden = self.leaky_relu(hidden)
|
||||
return self.fc_out(hidden).squeeze()
|
||||
|
||||
@@ -11,7 +11,12 @@ import pandas as pd
|
||||
import copy
|
||||
from sklearn.metrics import roc_auc_score, mean_squared_error
|
||||
import logging
|
||||
from ...utils import unpack_archive_with_buffer, save_multiple_parts_file, create_save_path, drop_nan_by_y_index
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
create_save_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...log import get_module_logger, TimeInspector
|
||||
|
||||
import torch
|
||||
@@ -109,7 +114,10 @@ class GRU(Model):
|
||||
)
|
||||
|
||||
self.gru_model = GRUModel(
|
||||
d_feat=self.d_feat, hidden_size=self.hidden_size, num_layers=self.num_layers, dropout=self.dropout
|
||||
d_feat=self.d_feat,
|
||||
hidden_size=self.hidden_size,
|
||||
num_layers=self.num_layers,
|
||||
dropout=self.dropout,
|
||||
)
|
||||
if optimizer.lower() == "adam":
|
||||
self.train_optimizer = optim.Adam(self.gru_model.parameters(), lr=self.lr)
|
||||
@@ -141,7 +149,7 @@ class GRU(Model):
|
||||
|
||||
mask = torch.isfinite(label)
|
||||
|
||||
if self.metric == "" or self.metric == "loss": # use loss
|
||||
if self.metric == "" or self.metric == "loss":
|
||||
return -self.loss_fn(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown metric `%s`" % self.metric)
|
||||
@@ -219,7 +227,9 @@ class GRU(Model):
|
||||
):
|
||||
|
||||
df_train, df_valid, df_test = dataset.prepare(
|
||||
["train", "valid", "test"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
|
||||
["train", "valid", "test"],
|
||||
col_set=["feature", "label"],
|
||||
data_key=DataHandlerLP.DK_L,
|
||||
)
|
||||
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
|
||||
@@ -1,491 +0,0 @@
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import copy
|
||||
from ...utils import create_save_path
|
||||
from ...log import get_module_logger
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
|
||||
|
||||
class HATS(Model):
|
||||
"""HATS Model
|
||||
|
||||
Parameters
|
||||
----------
|
||||
d_feat : int
|
||||
input dimension for each time step
|
||||
metric: str
|
||||
the evaluate metric used in early stop
|
||||
optimizer : str
|
||||
optimizer name
|
||||
GPU : str
|
||||
the GPU ID(s) used for training
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
d_feat=6,
|
||||
hidden_size=64,
|
||||
num_layers=2,
|
||||
dropout=0.5,
|
||||
n_epochs=200,
|
||||
lr=0.01,
|
||||
metric="",
|
||||
early_stop=20,
|
||||
loss="mse",
|
||||
base_model="GRU",
|
||||
with_pretrain=True,
|
||||
optimizer="adam",
|
||||
GPU="0",
|
||||
seed=0,
|
||||
**kwargs
|
||||
):
|
||||
# Set logger.
|
||||
self.logger = get_module_logger("HATS")
|
||||
self.logger.info("HATS pytorch version...")
|
||||
|
||||
# set hyper-parameters.
|
||||
self.d_feat = d_feat
|
||||
self.hidden_size = hidden_size
|
||||
self.num_layers = num_layers
|
||||
self.dropout = dropout
|
||||
self.n_epochs = n_epochs
|
||||
self.lr = lr
|
||||
self.metric = metric
|
||||
self.early_stop = early_stop
|
||||
self.optimizer = optimizer.lower()
|
||||
self.loss = loss
|
||||
self.base_model = base_model
|
||||
self.with_pretrain = with_pretrain
|
||||
self.visible_GPU = GPU
|
||||
self.use_gpu = torch.cuda.is_available()
|
||||
self.seed = seed
|
||||
|
||||
self.logger.info(
|
||||
"HATS parameters setting:"
|
||||
"\nd_feat : {}"
|
||||
"\nhidden_size : {}"
|
||||
"\nnum_layers : {}"
|
||||
"\ndropout : {}"
|
||||
"\nn_epochs : {}"
|
||||
"\nlr : {}"
|
||||
"\nmetric : {}"
|
||||
"\nearly_stop : {}"
|
||||
"\noptimizer : {}"
|
||||
"\nloss_type : {}"
|
||||
"\nbase_model : {}"
|
||||
"\nwith_pretrain : {}"
|
||||
"\nvisible_GPU : {}"
|
||||
"\nuse_GPU : {}"
|
||||
"\nseed : {}".format(
|
||||
d_feat,
|
||||
hidden_size,
|
||||
num_layers,
|
||||
dropout,
|
||||
n_epochs,
|
||||
lr,
|
||||
metric,
|
||||
early_stop,
|
||||
optimizer.lower(),
|
||||
loss,
|
||||
base_model,
|
||||
with_pretrain,
|
||||
GPU,
|
||||
self.use_gpu,
|
||||
seed,
|
||||
)
|
||||
)
|
||||
|
||||
self.HATS_model = HATSModel(
|
||||
d_feat=self.d_feat,
|
||||
hidden_size=self.hidden_size,
|
||||
num_layers=self.num_layers,
|
||||
dropout=self.dropout,
|
||||
base_model=self.base_model,
|
||||
)
|
||||
if optimizer.lower() == "adam":
|
||||
self.train_optimizer = optim.Adam(self.HATS_model.parameters(), lr=self.lr)
|
||||
elif optimizer.lower() == "gd":
|
||||
self.train_optimizer = optim.SGD(self.HATS_model.parameters(), lr=self.lr)
|
||||
else:
|
||||
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
|
||||
|
||||
self._fitted = False
|
||||
if self.use_gpu:
|
||||
self.HATS_model.cuda()
|
||||
# set the visible GPU
|
||||
if self.visible_GPU:
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = self.visible_GPU
|
||||
|
||||
def mse(self, pred, label):
|
||||
loss = (pred - label) ** 2
|
||||
return torch.mean(loss)
|
||||
|
||||
def loss_fn(self, pred, label):
|
||||
mask = ~torch.isnan(label)
|
||||
|
||||
if self.loss == "mse":
|
||||
return self.mse(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown loss `%s`" % self.loss)
|
||||
|
||||
def metric_fn(self, pred, label):
|
||||
mask = torch.isfinite(label)
|
||||
|
||||
if self.metric == "" or self.metric == "loss": # use loss
|
||||
return -self.loss_fn(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown metric `%s`" % self.metric)
|
||||
|
||||
def get_daily_inter(self, df, shuffle=False):
|
||||
# organize the train data into daily inter as daily batches
|
||||
daily_count = df.groupby(level=0).size().values
|
||||
daily_index = np.roll(np.cumsum(daily_count), 1)
|
||||
daily_index[0] = 0
|
||||
if shuffle:
|
||||
# shuffle the daily inter data
|
||||
daily_shuffle = list(zip(daily_index, daily_count))
|
||||
np.random.shuffle(daily_shuffle)
|
||||
daily_index, daily_count = zip(*daily_shuffle)
|
||||
return daily_index, daily_count
|
||||
|
||||
def train_epoch(self, x_train, y_train):
|
||||
|
||||
x_train_values = x_train.values
|
||||
y_train_values = np.squeeze(y_train.values)
|
||||
|
||||
self.HATS_model.train()
|
||||
|
||||
# organize the train data into daily inter as daily batches
|
||||
daily_index, daily_count = self.get_daily_inter(x_train, shuffle=True)
|
||||
|
||||
for idx, count in zip(daily_index, daily_count):
|
||||
batch = slice(idx, idx + count)
|
||||
feature = torch.from_numpy(x_train_values[batch]).float()
|
||||
label = torch.from_numpy(y_train_values[batch]).float()
|
||||
|
||||
if self.use_gpu:
|
||||
feature = feature.cuda()
|
||||
label = label.cuda()
|
||||
|
||||
pred = self.HATS_model(feature)
|
||||
loss = self.loss_fn(pred, label)
|
||||
|
||||
self.train_optimizer.zero_grad()
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_value_(self.HATS_model.parameters(), 3.0)
|
||||
self.train_optimizer.step()
|
||||
|
||||
def test_epoch(self, data_x, data_y):
|
||||
|
||||
# prepare testing data
|
||||
x_values = data_x.values
|
||||
y_values = np.squeeze(data_y.values)
|
||||
|
||||
self.HATS_model.eval()
|
||||
|
||||
scores = []
|
||||
losses = []
|
||||
|
||||
# organize the test data into daily inter as daily batches
|
||||
daily_index, daily_count = self.get_daily_inter(data_x, shuffle=False)
|
||||
|
||||
for idx, count in zip(daily_index, daily_count):
|
||||
batch = slice(idx, idx + count)
|
||||
feature = torch.from_numpy(x_values[batch]).float()
|
||||
label = torch.from_numpy(y_values[batch]).float()
|
||||
|
||||
if self.use_gpu:
|
||||
feature = feature.cuda()
|
||||
label = label.cuda()
|
||||
|
||||
pred = self.HATS_model(feature)
|
||||
loss = self.loss_fn(pred, label)
|
||||
losses.append(loss.item())
|
||||
|
||||
score = self.metric_fn(pred, label)
|
||||
scores.append(score.item())
|
||||
|
||||
return np.mean(losses), np.mean(scores)
|
||||
|
||||
def fit(
|
||||
self,
|
||||
dataset: DatasetH,
|
||||
evals_result=dict(),
|
||||
verbose=True,
|
||||
save_path=None,
|
||||
):
|
||||
|
||||
df_train, df_valid, df_test = dataset.prepare(
|
||||
["train", "valid", "test"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
|
||||
)
|
||||
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
|
||||
if save_path == None:
|
||||
save_path = create_save_path(save_path)
|
||||
stop_steps = 0
|
||||
best_score = -np.inf
|
||||
best_epoch = 0
|
||||
evals_result["train"] = []
|
||||
evals_result["valid"] = []
|
||||
|
||||
# load pretrained base_model
|
||||
if self.with_pretrain:
|
||||
self.logger.info("Loading pretrained model...")
|
||||
if self.base_model == "LSTM":
|
||||
from ...contrib.model.pytorch_lstm import LSTMModel
|
||||
|
||||
pretrained_model = LSTMModel()
|
||||
pretrained_model.load_state_dict(torch.load("benchmarks/LSTM/model_lstm_csi300.pkl"))
|
||||
elif self.base_model == "GRU":
|
||||
from ...contrib.model.pytorch_gru import GRUModel
|
||||
|
||||
pretrained_model = GRUModel()
|
||||
pretrained_model.load_state_dict(torch.load("benchmarks/GRU/model_gru_csi300.pkl"))
|
||||
model_dict = self.HATS_model.state_dict()
|
||||
pretrained_dict = {k: v for k, v in pretrained_model.state_dict().items() if k in model_dict}
|
||||
model_dict.update(pretrained_dict)
|
||||
self.HATS_model.load_state_dict(model_dict)
|
||||
self.logger.info("Loading pretrained model Done...")
|
||||
|
||||
# train
|
||||
self.logger.info("training...")
|
||||
self._fitted = True
|
||||
|
||||
for step in range(self.n_epochs):
|
||||
self.logger.info("Epoch%d:", step)
|
||||
self.logger.info("training...")
|
||||
self.train_epoch(x_train, y_train)
|
||||
self.logger.info("evaluating...")
|
||||
train_loss, train_score = self.test_epoch(x_train, y_train)
|
||||
val_loss, val_score = self.test_epoch(x_valid, y_valid)
|
||||
self.logger.info("train %.6f, valid %.6f" % (train_score, val_score))
|
||||
evals_result["train"].append(train_score)
|
||||
evals_result["valid"].append(val_score)
|
||||
|
||||
if val_score > best_score:
|
||||
best_score = val_score
|
||||
stop_steps = 0
|
||||
best_epoch = step
|
||||
best_param = copy.deepcopy(self.HATS_model.state_dict())
|
||||
else:
|
||||
stop_steps += 1
|
||||
if stop_steps >= self.early_stop:
|
||||
self.logger.info("early stop")
|
||||
break
|
||||
|
||||
self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
|
||||
self.HATS_model.load_state_dict(best_param)
|
||||
torch.save(best_param, save_path)
|
||||
|
||||
if self.use_gpu:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def predict(self, dataset):
|
||||
if not self._fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
x_test = dataset.prepare("test", col_set="feature")
|
||||
index = x_test.index
|
||||
self.HATS_model.eval()
|
||||
x_values = x_test.values
|
||||
sample_num = x_values.shape[0]
|
||||
preds = []
|
||||
|
||||
# organize the data into daily inter as daily batches
|
||||
daily_index, daily_count = self.get_daily_inter(x_test, shuffle=False)
|
||||
|
||||
for idx, count in zip(daily_index, daily_count):
|
||||
batch = slice(idx, idx + count)
|
||||
x_batch = torch.from_numpy(x_values[batch]).float()
|
||||
|
||||
if self.use_gpu:
|
||||
x_batch = x_batch.cuda()
|
||||
|
||||
with torch.no_grad():
|
||||
if self.use_gpu:
|
||||
pred = self.HATS_model(x_batch).detach().cpu().numpy()
|
||||
else:
|
||||
pred = self.HATS_model(x_batch).detach().numpy()
|
||||
|
||||
preds.append(pred)
|
||||
|
||||
return pd.Series(np.concatenate(preds), index=index)
|
||||
|
||||
|
||||
class HATSModel(nn.Module):
|
||||
def __init__(self, d_feat=6, hidden_size=64, num_layers=2, dropout=0.0, base_model="GRU"):
|
||||
super().__init__()
|
||||
|
||||
if base_model == "GRU":
|
||||
self.model = nn.GRU(
|
||||
input_size=d_feat,
|
||||
hidden_size=hidden_size,
|
||||
num_layers=num_layers,
|
||||
batch_first=True,
|
||||
dropout=dropout,
|
||||
)
|
||||
elif base_model == "LSTM":
|
||||
self.model = nn.LSTM(
|
||||
input_size=d_feat,
|
||||
hidden_size=hidden_size,
|
||||
num_layers=num_layers,
|
||||
batch_first=True,
|
||||
dropout=dropout,
|
||||
)
|
||||
else:
|
||||
raise ValueError("unknown base model name `%s`" % base_model)
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.bn1 = nn.BatchNorm1d(num_features=hidden_size, track_running_stats=False)
|
||||
self.fc = nn.Linear(hidden_size, hidden_size)
|
||||
self.bn2 = nn.BatchNorm1d(num_features=hidden_size, track_running_stats=False)
|
||||
self.fc_out = nn.Linear(hidden_size, 1)
|
||||
self.leaky_relu = nn.LeakyReLU()
|
||||
self.softmax = nn.Softmax(dim=1)
|
||||
self.d_feat = d_feat
|
||||
|
||||
num_head_att = [1] * num_layers
|
||||
hidden_dim = [hidden_size] * num_layers
|
||||
dims = [d_feat] + [d * nh for (d, nh) in zip(hidden_dim, num_head_att[:-1])] + [num_head_att[-1]]
|
||||
in_dims = dims[:-1]
|
||||
out_dims = [d // nh for (d, nh) in zip(dims[1:], num_head_att)]
|
||||
self.attn = nn.ModuleList(
|
||||
[GraphAttention(i, o, nh, dropout) for (i, o, nh) in zip(in_dims, out_dims, num_head_att)]
|
||||
)
|
||||
self.bns = nn.ModuleList([nn.BatchNorm1d(dim) for dim in dims[1:-1]])
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.elu = nn.ELU()
|
||||
|
||||
def forward(self, x):
|
||||
x = x.reshape(len(x), self.d_feat, -1) # [N, F, T]
|
||||
x = x.permute(0, 2, 1) # [N, T, F]
|
||||
out, _ = self.model(x)
|
||||
hidden = out[:, -1, :]
|
||||
hidden = self.bn1(hidden)
|
||||
attention = GraphAttention.cal_attention(hidden, hidden)
|
||||
output = attention.mm(hidden)
|
||||
output = self.fc(output)
|
||||
output = self.bn2(output)
|
||||
output = self.leaky_relu(output)
|
||||
return self.fc_out(output).squeeze()
|
||||
|
||||
|
||||
class GraphAttention(nn.Module):
|
||||
def __init__(self, input_dim, output_dim, num_heads, dropout=0.5):
|
||||
|
||||
super().__init__()
|
||||
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
input_dim : int
|
||||
Dimension of input node features.
|
||||
output_dim : int
|
||||
Dimension of output node features.
|
||||
num_heads : list of ints
|
||||
Number of attention heads in each hidden layer and output layer. Must be non empty. Note that len(num_heads) = len(hidden_dims)+1.
|
||||
dropout : float
|
||||
Dropout rate. Default: 0.5.
|
||||
"""
|
||||
|
||||
self.input_dim = input_dim
|
||||
self.output_dim = output_dim
|
||||
self.num_heads = num_heads
|
||||
|
||||
self.fcs = nn.ModuleList([nn.Linear(input_dim, output_dim) for _ in range(num_heads)])
|
||||
self.a = nn.ModuleList([nn.Linear(2 * output_dim, 1) for _ in range(num_heads)])
|
||||
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.softmax = nn.Softmax(dim=0)
|
||||
self.leakyrelu = nn.LeakyReLU()
|
||||
|
||||
def forward(self, features, nodes, mappings, rows):
|
||||
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
features : torch.Tensor
|
||||
An (n' x input_dim) tensor of input node features.
|
||||
nodes : list of numpy array
|
||||
nodes[i] is an array of the nodes in the ith layer of the
|
||||
computation graph.
|
||||
mappings : list of dictionary
|
||||
mappings[i] is a dictionary mappings node v (labelled 0 to |V|-1)
|
||||
in nodes[i] to its position in nodes[i]. For example,
|
||||
if nodes[i] = [2,5], then mappings[i][2] = 0 and
|
||||
mappings[i][5] = 1.
|
||||
rows : numpy array
|
||||
rows[i] is an array of neighbors of node i.
|
||||
Returns
|
||||
-------
|
||||
out : torch.Tensor
|
||||
An (len(node_layers[-1]) x output_dim) tensor of output node features.
|
||||
"""
|
||||
|
||||
nprime = features.shape[0]
|
||||
rows = [np.array([mappings[v] for v in row], dtype=np.int64) for row in rows]
|
||||
sum_degs = np.hstack(([0], np.cumsum([len(row) for row in rows])))
|
||||
mapped_nodes = [mappings[v] for v in nodes]
|
||||
indices = torch.LongTensor([[v, c] for (v, row) in zip(mapped_nodes, rows) for c in row]).t()
|
||||
|
||||
out = []
|
||||
for k in range(self.num_heads):
|
||||
h = self.fcs[k](features)
|
||||
|
||||
nbr_h = torch.cat(tuple([h[row] for row in rows]), dim=0)
|
||||
self_h = torch.cat(
|
||||
tuple([h[mappings[nodes[i]]].repeat(len(row), 1) for (i, row) in enumerate(rows)]), dim=0
|
||||
)
|
||||
cat_h = torch.cat((self_h, nbr_h), dim=1)
|
||||
|
||||
e = self.leakyrelu(self.a[k](cat_h))
|
||||
|
||||
alpha = [self.softmax(e[lo:hi]) for (lo, hi) in zip(sum_degs, sum_degs[1:])]
|
||||
alpha = torch.cat(tuple(alpha), dim=0)
|
||||
alpha = alpha.squeeze(1)
|
||||
alpha = self.dropout(alpha)
|
||||
|
||||
adj = torch.sparse.FloatTensor(indices, alpha, torch.Size([nprime, nprime]))
|
||||
out.append(torch.sparse.mm(adj, h)[mapped_nodes])
|
||||
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def cal_attention(x, y):
|
||||
att_x = torch.mean(x, dim=1).reshape(-1, 1)
|
||||
att_y = torch.mean(y, dim=1).reshape(-1, 1)
|
||||
att = att_x.mm(torch.t(att_y))
|
||||
return (
|
||||
torch.mean(
|
||||
x.reshape(x.shape[0], 1, x.shape[1]).repeat(1, y.shape[0], 1)
|
||||
* y.reshape(1, y.shape[0], y.shape[1]).repeat(x.shape[0], 1, 1),
|
||||
dim=2,
|
||||
)
|
||||
- att
|
||||
)
|
||||
@@ -11,7 +11,12 @@ import pandas as pd
|
||||
import copy
|
||||
from sklearn.metrics import roc_auc_score, mean_squared_error
|
||||
import logging
|
||||
from ...utils import unpack_archive_with_buffer, save_multiple_parts_file, create_save_path, drop_nan_by_y_index
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
create_save_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...log import get_module_logger, TimeInspector
|
||||
|
||||
import torch
|
||||
@@ -109,7 +114,10 @@ class LSTM(Model):
|
||||
)
|
||||
|
||||
self.lstm_model = LSTMModel(
|
||||
d_feat=self.d_feat, hidden_size=self.hidden_size, num_layers=self.num_layers, dropout=self.dropout
|
||||
d_feat=self.d_feat,
|
||||
hidden_size=self.hidden_size,
|
||||
num_layers=self.num_layers,
|
||||
dropout=self.dropout,
|
||||
)
|
||||
if optimizer.lower() == "adam":
|
||||
self.train_optimizer = optim.Adam(self.lstm_model.parameters(), lr=self.lr)
|
||||
@@ -141,7 +149,7 @@ class LSTM(Model):
|
||||
|
||||
mask = torch.isfinite(label)
|
||||
|
||||
if self.metric == "" or self.metric == "loss": # use loss
|
||||
if self.metric == "" or self.metric == "loss":
|
||||
return -self.loss_fn(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown metric `%s`" % self.metric)
|
||||
@@ -219,7 +227,9 @@ class LSTM(Model):
|
||||
):
|
||||
|
||||
df_train, df_valid, df_test = dataset.prepare(
|
||||
["train", "valid", "test"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
|
||||
["train", "valid", "test"],
|
||||
col_set=["feature", "label"],
|
||||
data_key=DataHandlerLP.DK_L,
|
||||
)
|
||||
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
|
||||
@@ -19,7 +19,12 @@ import pandas as pd
|
||||
import copy
|
||||
from sklearn.metrics import roc_auc_score, mean_squared_error
|
||||
import logging
|
||||
from ...utils import unpack_archive_with_buffer, save_multiple_parts_file, create_save_path, drop_nan_by_y_index
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
create_save_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...log import get_module_logger, TimeInspector
|
||||
|
||||
import torch
|
||||
@@ -33,7 +38,16 @@ from ...data.dataset.handler import DataHandlerLP
|
||||
|
||||
|
||||
class SFM_Model(nn.Module):
|
||||
def __init__(self, d_feat=6, output_dim=1, freq_dim=10, hidden_size=64, dropout_W=0.0, dropout_U=0.0, device="cpu"):
|
||||
def __init__(
|
||||
self,
|
||||
d_feat=6,
|
||||
output_dim=1,
|
||||
freq_dim=10,
|
||||
hidden_size=64,
|
||||
dropout_W=0.0,
|
||||
dropout_U=0.0,
|
||||
device="cpu",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.input_dim = d_feat
|
||||
@@ -157,7 +171,16 @@ class SFM_Model(nn.Module):
|
||||
|
||||
init_state_time = torch.tensor(0).to(self.device)
|
||||
|
||||
self.states = [init_state_p, init_state_h, init_state_S_re, init_state_S_im, init_state_time, None, None, None]
|
||||
self.states = [
|
||||
init_state_p,
|
||||
init_state_h,
|
||||
init_state_S_re,
|
||||
init_state_S_im,
|
||||
init_state_time,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
]
|
||||
|
||||
def get_constants(self, x):
|
||||
constants = []
|
||||
@@ -352,7 +375,9 @@ class SFM(Model):
|
||||
):
|
||||
|
||||
df_train, df_valid = dataset.prepare(
|
||||
["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
|
||||
["train", "valid"],
|
||||
col_set=["feature", "label"],
|
||||
data_key=DataHandlerLP.DK_L,
|
||||
)
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
@@ -409,7 +434,7 @@ class SFM(Model):
|
||||
|
||||
mask = torch.isfinite(label)
|
||||
|
||||
if self.metric == "" or self.metric == "loss": # use loss
|
||||
if self.metric == "" or self.metric == "loss":
|
||||
return -self.loss_fn(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown metric `%s`" % self.metric)
|
||||
|
||||
@@ -1,85 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from pytorch_tabnet.tab_model import TabNetRegressor
|
||||
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
|
||||
|
||||
class TabNetModel(Model):
|
||||
"""TabNetModel Model"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
n_d,
|
||||
n_a,
|
||||
n_steps,
|
||||
gamma,
|
||||
n_independent,
|
||||
n_shared,
|
||||
seed,
|
||||
momentum,
|
||||
lambda_sparse,
|
||||
optimizer_params,
|
||||
**kwargs
|
||||
):
|
||||
self.model = None
|
||||
|
||||
self.n_d = n_d
|
||||
self.n_a = n_a
|
||||
self.n_steps = n_steps
|
||||
self.gamma = gamma
|
||||
self.n_independent = n_independent
|
||||
self.n_shared = n_shared
|
||||
self.seed = seed
|
||||
self.momentum = momentum
|
||||
self.lambda_sparse = lambda_sparse
|
||||
self.optimizer_params = optimizer_params
|
||||
|
||||
def fit(
|
||||
self,
|
||||
dataset: DatasetH,
|
||||
n_d=8,
|
||||
n_a=8,
|
||||
n_steps=3,
|
||||
gamma=1.3,
|
||||
n_independent=2,
|
||||
n_shared=2,
|
||||
seed=0,
|
||||
momentum=0.02,
|
||||
lambda_sparse=1e-3,
|
||||
optimizer_params={"lr": 2e-3},
|
||||
**kwargs
|
||||
):
|
||||
|
||||
df_train, df_valid = dataset.prepare(
|
||||
["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
|
||||
)
|
||||
x_train, y_train = df_train["feature"].values, df_train["label"].values * 100
|
||||
x_valid, y_valid = df_valid["feature"].values, df_valid["label"].values * 100
|
||||
|
||||
self.model = TabNetRegressor(
|
||||
n_d=self.n_d,
|
||||
n_a=self.n_a,
|
||||
n_steps=self.n_steps,
|
||||
gamma=self.gamma,
|
||||
n_independent=self.n_independent,
|
||||
n_shared=self.n_shared,
|
||||
seed=self.seed,
|
||||
momentum=self.momentum,
|
||||
lambda_sparse=self.lambda_sparse,
|
||||
optimizer_params=self.optimizer_params,
|
||||
**kwargs
|
||||
)
|
||||
self.model.fit(x_train, y_train, eval_set=[(x_valid, y_valid)])
|
||||
|
||||
def predict(self, dataset):
|
||||
if self.model is None:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
x_test = dataset.prepare("test", col_set="feature")
|
||||
test_pred = self.model.predict(x_test.values)
|
||||
return pd.Series(test_pred.reshape([-1]), index=x_test.index)
|
||||
@@ -38,7 +38,9 @@ class XGBModel(Model):
|
||||
):
|
||||
|
||||
df_train, df_valid = dataset.prepare(
|
||||
["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
|
||||
["train", "valid"],
|
||||
col_set=["feature", "label"],
|
||||
data_key=DataHandlerLP.DK_L,
|
||||
)
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
|
||||
@@ -96,7 +96,19 @@ class BaseGraph(object):
|
||||
"""
|
||||
py.init_notebook_mode()
|
||||
for _fig in figure_list:
|
||||
py.iplot(_fig)
|
||||
# NOTE: displays figures: https://plotly.com/python/renderers/
|
||||
# default: plotly_mimetype+notebook
|
||||
# support renderers: import plotly.io as pio; print(pio.renderers)
|
||||
renderer = None
|
||||
try:
|
||||
# in notebook
|
||||
_ipykernel = str(type(get_ipython()))
|
||||
if "google.colab" in _ipykernel:
|
||||
renderer = "colab"
|
||||
except NameError:
|
||||
pass
|
||||
|
||||
_fig.show(renderer=renderer)
|
||||
|
||||
def _get_layout(self) -> go.Layout:
|
||||
"""
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from qlib.utils import init_instance_by_config, flatten_dict
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.record_temp import SignalRecord
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import sys
|
||||
import sys, os
|
||||
from pathlib import Path
|
||||
|
||||
import qlib
|
||||
import fire
|
||||
import pandas as pd
|
||||
import ruamel.yaml as yaml
|
||||
from qlib.config import C
|
||||
from qlib.model.trainer import task_train
|
||||
|
||||
|
||||
@@ -41,7 +42,7 @@ def sys_config(config, config_path):
|
||||
|
||||
|
||||
# worflow handler function
|
||||
def workflow(config_path, experiment_name="workflow"):
|
||||
def workflow(config_path, experiment_name="workflow", uri_folder="mlruns"):
|
||||
with open(config_path) as fp:
|
||||
config = yaml.load(fp, Loader=yaml.Loader)
|
||||
|
||||
@@ -50,7 +51,9 @@ def workflow(config_path, experiment_name="workflow"):
|
||||
|
||||
provider_uri = config.get("provider_uri")
|
||||
region = config.get("region")
|
||||
qlib.init(provider_uri=provider_uri, region=region)
|
||||
exp_manager = C["exp_manager"]
|
||||
exp_manager["kwargs"]["uri"] = "file:" + str(Path(os.getcwd()).resolve() / uri_folder)
|
||||
qlib.init(provider_uri=provider_uri, region=region, exp_manager=exp_manager)
|
||||
|
||||
task_train(config, experiment_name=experiment_name)
|
||||
|
||||
|
||||
@@ -239,20 +239,17 @@ class MLflowExpManager(ExpManager):
|
||||
return self._client
|
||||
|
||||
def start_exp(self, experiment_name=None, recorder_name=None, uri=None):
|
||||
# set the tracking uri
|
||||
if uri is None:
|
||||
logger.info("No tracking URI is provided. Use the default tracking URI.")
|
||||
else:
|
||||
self.uri = uri
|
||||
# create experiment
|
||||
experiment, _ = self._get_or_create_exp(experiment_name=experiment_name)
|
||||
# set up active experiment
|
||||
self.active_experiment = experiment
|
||||
# start the experiment
|
||||
self.active_experiment.start(recorder_name)
|
||||
# set the tracking uri
|
||||
if uri is None:
|
||||
logger.info(
|
||||
"No tracking URI is provided. The default tracking URI is set as `mlruns` under the working directory."
|
||||
)
|
||||
else:
|
||||
self.uri = uri
|
||||
mlflow.set_tracking_uri(self.uri)
|
||||
|
||||
return self.active_experiment
|
||||
|
||||
|
||||
@@ -224,6 +224,8 @@ class MLflowRecorder(Recorder):
|
||||
)
|
||||
|
||||
def start_run(self):
|
||||
# set the tracking uri
|
||||
mlflow.set_tracking_uri(self._uri)
|
||||
# start the run
|
||||
run = mlflow.start_run(self.id, self.experiment_id, self.name)
|
||||
# save the run id and artifact_uri
|
||||
|
||||
@@ -22,5 +22,4 @@ scikit_learn==0.23.2
|
||||
torch==1.6.0
|
||||
tqdm==4.49.0
|
||||
yahooquery==2.2.7
|
||||
mlflow==1.12.1
|
||||
pytorch-tabnet==2.0.1
|
||||
mlflow==1.12.1
|
||||
@@ -44,6 +44,7 @@ class YahooCollector:
|
||||
delay=0,
|
||||
check_data_length: bool = False,
|
||||
limit_nums: int = None,
|
||||
show_1m_logging: bool = False,
|
||||
):
|
||||
"""
|
||||
|
||||
@@ -67,10 +68,13 @@ class YahooCollector:
|
||||
check data length, by default False
|
||||
limit_nums: int
|
||||
using for debug, by default None
|
||||
show_1m_logging: bool
|
||||
show 1m logging, by default False; if True, there may be many warning logs
|
||||
"""
|
||||
self.save_dir = Path(save_dir).expanduser().resolve()
|
||||
self.save_dir.mkdir(parents=True, exist_ok=True)
|
||||
self._delay = delay
|
||||
self._show_1m_logging = show_1m_logging
|
||||
self.stock_list = sorted(set(self.get_stock_list()))
|
||||
if limit_nums is not None:
|
||||
try:
|
||||
@@ -83,7 +87,7 @@ class YahooCollector:
|
||||
self._interval = interval
|
||||
self._check_small_data = check_data_length
|
||||
self._start_datetime = pd.Timestamp(str(start)) if start else self.START_DATETIME
|
||||
self._end_datetime = pd.Timestamp(str(end)) if end else self.END_DATETIME
|
||||
self._end_datetime = min(pd.Timestamp(str(end)) if end else self.END_DATETIME, self.END_DATETIME)
|
||||
if self._interval == "1m":
|
||||
self._start_datetime = max(self._start_datetime, self.HIGH_FREQ_START_DATETIME)
|
||||
elif self._interval == "1d":
|
||||
@@ -91,8 +95,12 @@ class YahooCollector:
|
||||
else:
|
||||
raise ValueError(f"interval error: {self._interval}")
|
||||
|
||||
# using for 1m
|
||||
self._next_datetime = self.convert_datetime(self._start_datetime.date() + pd.Timedelta(days=1))
|
||||
self._latest_datetime = self.convert_datetime(self._end_datetime.date())
|
||||
|
||||
self._start_datetime = self.convert_datetime(self._start_datetime)
|
||||
self._end_datetime = self.convert_datetime(min(self._end_datetime, self.END_DATETIME))
|
||||
self._end_datetime = self.convert_datetime(self._end_datetime)
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
@@ -100,20 +108,24 @@ class YahooCollector:
|
||||
# daily, one year: 252 / 4
|
||||
# us 1min, a week: 6.5 * 60 * 5
|
||||
# cn 1min, a week: 4 * 60 * 5
|
||||
raise NotImplementedError("rewirte min_numbers_trading")
|
||||
raise NotImplementedError("rewrite min_numbers_trading")
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_stock_list(self):
|
||||
raise NotImplementedError("rewirte get_stock_list")
|
||||
raise NotImplementedError("rewrite get_stock_list")
|
||||
|
||||
@property
|
||||
@abc.abstractclassmethod
|
||||
@abc.abstractmethod
|
||||
def _timezone(self):
|
||||
raise NotImplementedError("rewrite get_timezone")
|
||||
|
||||
def convert_datetime(self, dt: pd.Timestamp):
|
||||
dt = pd.Timestamp(dt, tz=self._timezone).timestamp()
|
||||
return pd.Timestamp(dt, tz=tzlocal(), unit="s")
|
||||
def convert_datetime(self, dt: [pd.Timestamp, datetime.date, str]):
|
||||
try:
|
||||
dt = pd.Timestamp(dt, tz=self._timezone).timestamp()
|
||||
dt = pd.Timestamp(dt, tz=tzlocal(), unit="s")
|
||||
except ValueError as e:
|
||||
pass
|
||||
return dt
|
||||
|
||||
def _sleep(self):
|
||||
time.sleep(self._delay)
|
||||
@@ -136,7 +148,7 @@ class YahooCollector:
|
||||
df["symbol"] = symbol
|
||||
if stock_path.exists():
|
||||
with stock_path.open("a") as fp:
|
||||
df.to_csv(fp, index=False, header=None)
|
||||
df.to_csv(fp, index=False, header=False)
|
||||
else:
|
||||
with stock_path.open("w") as fp:
|
||||
df.to_csv(fp, index=False)
|
||||
@@ -155,34 +167,47 @@ class YahooCollector:
|
||||
def _get_from_remote(self, symbol):
|
||||
def _get_simple(start_, end_):
|
||||
self._sleep()
|
||||
error_msg = f"{symbol}-{self._interval}-{start_}-{end_}"
|
||||
|
||||
def _show_logging_func():
|
||||
if self._interval == "1m" and self._show_1m_logging:
|
||||
logger.warning(f"{error_msg}:{_resp}")
|
||||
|
||||
try:
|
||||
_resp = Ticker(symbol, asynchronous=False).history(interval=self._interval, start=start_, end=end_)
|
||||
if isinstance(_resp, pd.DataFrame):
|
||||
return _resp.reset_index()
|
||||
elif isinstance(_resp, dict):
|
||||
_temp_data = _resp.get(symbol, {})
|
||||
if isinstance(_temp_data, str) or (
|
||||
isinstance(_resp, dict) and _temp_data.get("indicators", {}).get("quote", None) is None
|
||||
):
|
||||
_show_logging_func()
|
||||
else:
|
||||
logger.warning(f"{symbol}-{self._interval}-{start_}-{end_}:{_resp}")
|
||||
_show_logging_func()
|
||||
except Exception as e:
|
||||
logger.warning(f"{symbol}-{self._interval}-{start_}-{end_}:{e}")
|
||||
logger.warning(f"{error_msg}:{e}")
|
||||
|
||||
_result = None
|
||||
if self._interval == "1d":
|
||||
_result = _get_simple(self._start_datetime, self._end_datetime)
|
||||
elif self._interval == "1m":
|
||||
_start_date = self._start_datetime.date() + pd.Timedelta(days=1)
|
||||
_end_date = self._end_datetime.date()
|
||||
if _start_date >= _end_date:
|
||||
if self._next_datetime >= self._latest_datetime:
|
||||
_result = _get_simple(self._start_datetime, self._end_datetime)
|
||||
else:
|
||||
_res = []
|
||||
|
||||
def _get_multi(start_, end_):
|
||||
_resp = _get_simple(start_, end_)
|
||||
if _resp is not None:
|
||||
if _resp is not None and not _resp.empty:
|
||||
_res.append(_resp)
|
||||
|
||||
for _s, _e in ((self._start_datetime, _start_date), (_end_date, self._end_datetime)):
|
||||
for _s, _e in (
|
||||
(self._start_datetime, self._next_datetime),
|
||||
(self._latest_datetime, self._end_datetime),
|
||||
):
|
||||
_get_multi(_s, _e)
|
||||
for _start in pd.date_range(_start_date, _end_date, closed="left"):
|
||||
for _start in pd.date_range(self._next_datetime, self._latest_datetime, closed="left"):
|
||||
_end = _start + pd.Timedelta(days=1)
|
||||
self._sleep()
|
||||
_get_multi(_start, _end)
|
||||
@@ -472,6 +497,7 @@ class Run:
|
||||
interval="1d",
|
||||
check_data_length=False,
|
||||
limit_nums=None,
|
||||
show_1m_logging=False,
|
||||
):
|
||||
"""download data from Internet
|
||||
|
||||
@@ -491,6 +517,9 @@ class Run:
|
||||
check data length, by default False
|
||||
limit_nums: int
|
||||
using for debug, by default None
|
||||
show_1m_logging: bool
|
||||
show 1m logging, by default False; if True, there may be many warning logs
|
||||
|
||||
Examples
|
||||
---------
|
||||
# get daily data
|
||||
@@ -510,6 +539,7 @@ class Run:
|
||||
interval=interval,
|
||||
check_data_length=check_data_length,
|
||||
limit_nums=limit_nums,
|
||||
show_1m_logging=show_1m_logging,
|
||||
).collector_data()
|
||||
|
||||
def normalize_data(self):
|
||||
@@ -531,6 +561,7 @@ class Run:
|
||||
interval="1d",
|
||||
check_data_length=False,
|
||||
limit_nums=None,
|
||||
show_1m_logging=False,
|
||||
):
|
||||
"""download -> normalize
|
||||
|
||||
@@ -550,6 +581,9 @@ class Run:
|
||||
check data length, by default False
|
||||
limit_nums: int
|
||||
using for debug, by default None
|
||||
show_1m_logging: bool
|
||||
show 1m logging, by default False; if True, there may be many warning logs
|
||||
|
||||
Examples
|
||||
-------
|
||||
python collector.py collector_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d
|
||||
@@ -562,6 +596,7 @@ class Run:
|
||||
interval=interval,
|
||||
check_data_length=check_data_length,
|
||||
limit_nums=limit_nums,
|
||||
show_1m_logging=show_1m_logging,
|
||||
)
|
||||
self.normalize_data()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user