mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 14:01:28 +08:00
Compare commits
86 Commits
qlib_monit
...
v0.7.0
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
215f7e0d22 | ||
|
|
dafef0ac08 | ||
|
|
1cb43ea69b | ||
|
|
7ca9cf79f7 | ||
|
|
35f090a6e4 | ||
|
|
ace7484304 | ||
|
|
2d4f0e80f9 | ||
|
|
946c9392a1 | ||
|
|
b523b27d5a | ||
|
|
0b83fb3564 | ||
|
|
d96f7a67c6 | ||
|
|
a7862387a2 | ||
|
|
c4c438249c | ||
|
|
8709dde65b | ||
|
|
d66733c358 | ||
|
|
9cf574b697 | ||
|
|
107e40f3ee | ||
|
|
4837ba8db3 | ||
|
|
2ab4a9adb3 | ||
|
|
8d0b673341 | ||
|
|
8ebdb1e873 | ||
|
|
39340fbf06 | ||
|
|
0e277723a3 | ||
|
|
1418417034 | ||
|
|
b261f7b501 | ||
|
|
bab50e8837 | ||
|
|
0eee4a0f2e | ||
|
|
21eb71d4a9 | ||
|
|
46714adf4c | ||
|
|
99fb49650a | ||
|
|
985fd0816c | ||
|
|
d0f54343c7 | ||
|
|
a3679e6758 | ||
|
|
b6c31540e8 | ||
|
|
a4f6e04199 | ||
|
|
0aee46ee79 | ||
|
|
9c8d423a86 | ||
|
|
b4efbd53b2 | ||
|
|
5a50d7c952 | ||
|
|
0fe8b281ba | ||
|
|
5331ab93f8 | ||
|
|
64582e9d46 | ||
|
|
9e0e2ff736 | ||
|
|
973c4137e4 | ||
|
|
730f6258d6 | ||
|
|
5850490b24 | ||
|
|
d4b36bdab4 | ||
|
|
40416d8c30 | ||
|
|
567e42840c | ||
|
|
65ddca133f | ||
|
|
d199256d34 | ||
|
|
073fe4668e | ||
|
|
89d53853e5 | ||
|
|
bb6c1572ca | ||
|
|
4c4e77b11f | ||
|
|
38c7b7303a | ||
|
|
02d0eedd68 | ||
|
|
5a3dde93a8 | ||
|
|
177f6a59d2 | ||
|
|
492a62a569 | ||
|
|
9a44fbf9c1 | ||
|
|
03eb0882de | ||
|
|
a845a2271b | ||
|
|
ba021f6007 | ||
|
|
7d9544fb91 | ||
|
|
12b7be333d | ||
|
|
ed54f1213c | ||
|
|
554b9c7826 | ||
|
|
6f150f3fd6 | ||
|
|
2a0d991d9b | ||
|
|
1320e53f81 | ||
|
|
8222795ac4 | ||
|
|
616a742db7 | ||
|
|
811d2c975e | ||
|
|
6272ce108f | ||
|
|
64896745d0 | ||
|
|
b2fe2385d5 | ||
|
|
8d05cd2daf | ||
|
|
231bdf8608 | ||
|
|
ab6b88ce14 | ||
|
|
94ab4bbf3f | ||
|
|
ca0363ded8 | ||
|
|
a467e10974 | ||
|
|
6dfbf00a23 | ||
|
|
b24af7fff6 | ||
|
|
45f73361e3 |
44
README.md
44
README.md
@@ -11,6 +11,7 @@
|
||||
Recent released features
|
||||
| Feature | Status |
|
||||
| -- | ------ |
|
||||
| TCTS Model | [Released](https://github.com/microsoft/qlib/pull/491) on July 1, 2021 |
|
||||
| Online serving and automatic model rolling | :star: [Released](https://github.com/microsoft/qlib/pull/290) on May 17, 2021 |
|
||||
| DoubleEnsemble Model | [Released](https://github.com/microsoft/qlib/pull/286) on Mar 2, 2021 |
|
||||
| High-frequency data processing example | [Released](https://github.com/microsoft/qlib/pull/257) on Feb 5, 2021 |
|
||||
@@ -68,7 +69,7 @@ Your feedbacks about the features are very important.
|
||||
# Framework of Qlib
|
||||
|
||||
<div style="align: center">
|
||||
<img src="http://fintech.msra.cn/images_v060/framework.png?v=0.1" />
|
||||
<img src="http://fintech.msra.cn/images_v060/framework.png?v=0.2" />
|
||||
</div>
|
||||
|
||||
|
||||
@@ -159,6 +160,28 @@ Users could create the same dataset with it.
|
||||
*Please pay **ATTENTION** that the data is collected from [Yahoo Finance](https://finance.yahoo.com/lookup), and the data might not be perfect.
|
||||
We recommend users to prepare their own data if they have a high-quality dataset. For more information, users can refer to the [related document](https://qlib.readthedocs.io/en/latest/component/data.html#converting-csv-format-into-qlib-format)*.
|
||||
|
||||
### Automatic update of daily frequency data(from yahoo finance)
|
||||
> It is recommended that users update the data manually once (--trading_date 2021-05-25) and then set it to update automatically.
|
||||
|
||||
> For more information refer to: [yahoo collector](https://github.com/microsoft/qlib/tree/main/scripts/data_collector/yahoo#automatic-update-of-daily-frequency-datafrom-yahoo-finance)
|
||||
|
||||
* Automatic update of data to the "qlib" directory each trading day(Linux)
|
||||
* use *crontab*: `crontab -e`
|
||||
* set up timed tasks:
|
||||
|
||||
```
|
||||
* * * * 1-5 python <script path> update_data_to_bin --qlib_data_1d_dir <user data dir>
|
||||
```
|
||||
* **script path**: *scripts/data_collector/yahoo/collector.py*
|
||||
|
||||
* Manual update of data
|
||||
```
|
||||
python scripts/data_collector/yahoo/collector.py update_data_to_bin --qlib_data_1d_dir <user data dir> --trading_date <start date> --end_date <end date>
|
||||
```
|
||||
* *trading_date*: start of trading day
|
||||
* *end_date*: end of trading day(not included)
|
||||
|
||||
|
||||
<!--
|
||||
- Run the initialization code and get stock data:
|
||||
|
||||
@@ -254,18 +277,19 @@ The automatic workflow may not suit the research workflow of all Quant researche
|
||||
# [Quant Model Zoo](examples/benchmarks)
|
||||
|
||||
Here is a list of models built on `Qlib`.
|
||||
- [GBDT based on XGBoost (Tianqi Chen, et al. 2016)](qlib/contrib/model/xgboost.py)
|
||||
- [GBDT based on LightGBM (Guolin Ke, et al. 2017)](qlib/contrib/model/gbdt.py)
|
||||
- [GBDT based on Catboost (Liudmila Prokhorenkova, et al. 2017)](qlib/contrib/model/catboost_model.py)
|
||||
- [GBDT based on XGBoost (Tianqi Chen, et al. KDD 2016)](qlib/contrib/model/xgboost.py)
|
||||
- [GBDT based on LightGBM (Guolin Ke, et al. NIPS 2017)](qlib/contrib/model/gbdt.py)
|
||||
- [GBDT based on Catboost (Liudmila Prokhorenkova, et al. NIPS 2018)](qlib/contrib/model/catboost_model.py)
|
||||
- [MLP based on pytorch](qlib/contrib/model/pytorch_nn.py)
|
||||
- [LSTM based on pytorch (Sepp Hochreiter, et al. 1997)](qlib/contrib/model/pytorch_lstm.py)
|
||||
- [LSTM based on pytorch (Sepp Hochreiter, et al. Neural omputation 1997)](qlib/contrib/model/pytorch_lstm.py)
|
||||
- [GRU based on pytorch (Kyunghyun Cho, et al. 2014)](qlib/contrib/model/pytorch_gru.py)
|
||||
- [ALSTM based on pytorch (Yao Qin, et al. 2017)](qlib/contrib/model/pytorch_alstm.py)
|
||||
- [ALSTM based on pytorch (Yao Qin, et al. IJCAI 2017)](qlib/contrib/model/pytorch_alstm.py)
|
||||
- [GATs based on pytorch (Petar Velickovic, et al. 2017)](qlib/contrib/model/pytorch_gats.py)
|
||||
- [SFM based on pytorch (Liheng Zhang, et al. 2017)](qlib/contrib/model/pytorch_sfm.py)
|
||||
- [TFT based on tensorflow (Bryan Lim, et al. 2019)](examples/benchmarks/TFT/tft.py)
|
||||
- [TabNet based on pytorch (Sercan O. Arik, et al. 2019)](qlib/contrib/model/pytorch_tabnet.py)
|
||||
- [DoubleEnsemble based on LightGBM (Chuheng Zhang, et al. 2020)](qlib/contrib/model/double_ensemble.py)
|
||||
- [SFM based on pytorch (Liheng Zhang, et al. KDD 2017)](qlib/contrib/model/pytorch_sfm.py)
|
||||
- [TFT based on tensorflow (Bryan Lim, et al. International Journal of Forecasting 2019)](examples/benchmarks/TFT/tft.py)
|
||||
- [TabNet based on pytorch (Sercan O. Arik, et al. AAAI 2019)](qlib/contrib/model/pytorch_tabnet.py)
|
||||
- [DoubleEnsemble based on LightGBM (Chuheng Zhang, et al. ICDM 2020)](qlib/contrib/model/double_ensemble.py)
|
||||
- [TCTS based on pytorch (Xueqing Wu, et al. ICML 2021)](qlib/contrib/model/pytorch_tcts.py)
|
||||
|
||||
Your PR of new Quant models is highly welcomed.
|
||||
|
||||
|
||||
BIN
docs/_static/img/framework.png
vendored
BIN
docs/_static/img/framework.png
vendored
Binary file not shown.
|
Before Width: | Height: | Size: 271 KiB After Width: | Height: | Size: 208 KiB |
@@ -67,6 +67,34 @@ After running the above command, users can find china-stock and us-stock data in
|
||||
|
||||
When ``Qlib`` is initialized with this dataset, users could build and evaluate their own models with it. Please refer to `Initialization <../start/initialization.html>`_ for more details.
|
||||
|
||||
Automatic update of daily frequency data
|
||||
----------------------------------------
|
||||
|
||||
**It is recommended that users update the data manually once (\-\-trading_date 2021-05-25) and then set it to update automatically.**
|
||||
|
||||
For more information refer to: `yahoo collector <https://github.com/microsoft/qlib/tree/main/scripts/data_collector/yahoo#Automatic-update-of-daily-frequency-data>`_
|
||||
|
||||
- Automatic update of data to the "qlib" directory each trading day(Linux)
|
||||
- use *crontab*: `crontab -e`
|
||||
- set up timed tasks:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
* * * * 1-5 python <script path> update_data_to_bin --qlib_data_1d_dir <user data dir>
|
||||
|
||||
- **script path**: *scripts/data_collector/yahoo/collector.py*
|
||||
|
||||
- Manual update of data
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python scripts/data_collector/yahoo/collector.py update_data_to_bin --qlib_data_1d_dir <user data dir> --trading_date <start date> --end_date <end date>
|
||||
|
||||
- *trading_date*: start of trading day
|
||||
- *end_date*: end of trading day(not included)
|
||||
|
||||
|
||||
|
||||
Converting CSV Format into Qlib Format
|
||||
-------------------------------------------
|
||||
|
||||
|
||||
@@ -90,12 +90,12 @@ Below is a typical config file of ``qrun``.
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
|
||||
After saving the config into `configuration.yaml`, users could start the workflow and test their ideas with a single command below.
|
||||
|
||||
@@ -142,7 +142,7 @@ The meaning of each field is as follows:
|
||||
|
||||
- `region`
|
||||
- If `region` == "us", ``Qlib`` will be initialized in US-stock mode.
|
||||
- If `region` == "cn", ``Qlib`` will be initialized in china-stock mode.
|
||||
- If `region` == "cn", ``Qlib`` will be initialized in China-stock mode.
|
||||
|
||||
.. note::
|
||||
|
||||
|
||||
@@ -61,7 +61,6 @@ task:
|
||||
metric: loss
|
||||
loss: mse
|
||||
base_model: LSTM
|
||||
with_pretrain: True
|
||||
model_path: "benchmarks/LSTM/csi300_lstm_ts.pkl"
|
||||
GPU: 0
|
||||
dataset:
|
||||
|
||||
@@ -54,7 +54,6 @@ task:
|
||||
metric: loss
|
||||
loss: mse
|
||||
base_model: LSTM
|
||||
with_pretrain: True
|
||||
model_path: "benchmarks/LSTM/model_lstm_csi300.pkl"
|
||||
GPU: 0
|
||||
dataset:
|
||||
@@ -81,4 +80,4 @@ task:
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
config: *port_analysis_config
|
||||
|
||||
@@ -4,6 +4,10 @@ Here are the results of each benchmark model running on Qlib's `Alpha360` and `A
|
||||
|
||||
The numbers shown below demonstrate the performance of the entire `workflow` of each model. We will update the `workflow` as well as models in the near future for better results.
|
||||
|
||||
> If you need to reproduce the results below, please use the **v1** dataset: `python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/qlib_cn_1d --region cn --version v1`
|
||||
>
|
||||
> In the new version of qlib, the default dataset is **v2**. Since the data is collected from the YahooFinance API (which is not very stable), the results of *v2* and *v1* may differ
|
||||
|
||||
## Alpha360 dataset
|
||||
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |
|
||||
|---|---|---|---|---|---|---|---|---|
|
||||
@@ -18,6 +22,7 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
|
||||
| GATs (Petar Velickovic, et al.) | Alpha360 | 0.0475±0.00 | 0.3515±0.02| 0.0592±0.00 | 0.4585±0.01 | 0.0876±0.02 | 1.1513±0.27| -0.0795±0.02 |
|
||||
| DoubleEnsemble (Chuheng Zhang, et al.) | Alpha360 | 0.0407±0.00| 0.3053±0.00 | 0.0490±0.00 | 0.3840±0.00 | 0.0380±0.02 | 0.5000±0.21 | -0.0984±0.02 |
|
||||
| TabNet (Sercan O. Arik, et al.)| Alpha360 | 0.0192±0.00 | 0.1401±0.00| 0.0291±0.00 | 0.2163±0.00 | -0.0258±0.00 | -0.2961±0.00| -0.1429±0.00 |
|
||||
| TCTS (Xueqing Wu, et al.)| Alpha360 | 0.0485±0.00 | 0.3689±0.04| 0.0586±0.00 | 0.4669±0.02 | 0.0816±0.02 | 1.1572±0.30| -0.0689±0.02 |
|
||||
|
||||
## Alpha158 dataset
|
||||
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |
|
||||
|
||||
52
examples/benchmarks/TCTS/README.md
Normal file
52
examples/benchmarks/TCTS/README.md
Normal file
@@ -0,0 +1,52 @@
|
||||
# Temporally Correlated Task Scheduling for Sequence Learning
|
||||
We provide the [code](https://github.com/microsoft/qlib/blob/main/qlib/contrib/model/pytorch_tcts.py) for reproducing the stock trend forecasting experiments.
|
||||
|
||||
### Background
|
||||
Sequence learning has attracted much research attention from the machine learning community in recent years. In many applications, a sequence learning task is usually associated with multiple temporally correlated auxiliary tasks, which are different in terms of how much input information to use or which future step to predict. In stock trend forecasting, as demonstrated in Figure1, one can predict the price of a stock in different future days (e.g., tomorrow, the day after tomorrow). In this paper, we propose a framework to make use of those temporally correlated tasks to help each other.
|
||||
|
||||
<p align="center">
|
||||
<img src="task_description.png" width="600" height="200"/>
|
||||
</p>
|
||||
|
||||
|
||||
### Method
|
||||
Given that there are usually multiple temporally correlated tasks, the key challenge lies in which tasks to use and when to use them in the training process. In this work, we introduce a learnable task scheduler for sequence learning, which adaptively selects temporally correlated tasks during the training process. The scheduler accesses the model status and the current training data (e.g., in current minibatch), and selects the best auxiliary task to help the training of the main task. The scheduler and the model for the main task are jointly trained through bi-level optimization: the scheduler is trained to maximize the validation performance of the model, and the model is trained to minimize the training loss guided by the scheduler. The process is demonstrated in Figure2.
|
||||
|
||||
<p align="center">
|
||||
<img src="workflow.png"/>
|
||||
</p>
|
||||
|
||||
At step <img src="https://render.githubusercontent.com/render/math?math=s">, with training data <img src="https://render.githubusercontent.com/render/math?math=x_s,y_s">, the scheduler <img src="https://render.githubusercontent.com/render/math?math=\varphi"> chooses a suitable task <img src="https://render.githubusercontent.com/render/math?math=T_{i_s}"> (green solid lines) to update the model <img src="https://render.githubusercontent.com/render/math?math=f"> (blue solid lines). After <img src="https://render.githubusercontent.com/render/math?math=S"> steps, we evaluate the model <img src="https://render.githubusercontent.com/render/math?math=f"> on the validation set and update the scheduler <img src="https://render.githubusercontent.com/render/math?math=\varphi"> (green dashed lines).
|
||||
|
||||
### DataSet
|
||||
* We use the historical transaction data for 300 stocks on [CSI300](http://www.csindex.com.cn/en/indices/index-detail/000300) from 01/01/2008 to 08/01/2020.
|
||||
* We split the data into training (01/01/2008-12/31/2013), validation (01/01/2014-12/31/2015), and test sets (01/01/2016-08/01/2020) based on the transaction time.
|
||||
|
||||
### Experiments
|
||||
#### Task Description
|
||||
* The main tasks <img src="https://render.githubusercontent.com/render/math?math=T_k"> (<img src="https://render.githubusercontent.com/render/math?math=task_k"> in Figure1) refers to forecasting return of stock <img src="https://render.githubusercontent.com/render/math?math=i"> as following,
|
||||
<div align=center>
|
||||
<img src="https://render.githubusercontent.com/render/math?math=r_{i}^k = \frac{\price_i^{t+k}}{\price_i^{t+k-1}} - 1">
|
||||
</div>
|
||||
|
||||
* Temporally correlated task sets <img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_k = \{T_1, T_2, ... , T_k\}">, in this paper, <img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_3">, <img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_5"> and <img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_10"> are used.
|
||||
#### Baselines
|
||||
* GRU/MLP/LightGBM (LGB)/Graph Attention Networks (GAT)
|
||||
* Multi-task learning (MTL): In multi-task learning, multiple tasks are jointly trained and mutually boosted. Each task is treated equally, while in our setting, we focus on the main task.
|
||||
* Curriculum transfer learning (CL): Transfer learning also leverages auxiliary tasks to boost the main task. [Curriculum transfer learning](https://arxiv.org/pdf/1804.00810.pdf) is one kind of transfer learning which schedules auxiliary tasks according to certain rules. Our problem can also be regarded as a special kind of transfer learning, where the auxiliary tasks are temporally correlated with the main task. Our learning process is dynamically controlled by a scheduler rather than some pre-defined rules. In the CL baseline, we start from the task <img src="https://render.githubusercontent.com/render/math?math=T_1" >, then <img src="https://render.githubusercontent.com/render/math?math=T_2" >, and gradually move to the last one.
|
||||
#### Result
|
||||
| Methods | <img src="https://render.githubusercontent.com/render/math?math=T_1" > | <img src="https://render.githubusercontent.com/render/math?math=T_2"> | <img src="https://render.githubusercontent.com/render/math?math=T_3"> |
|
||||
| :----: | :----: | :----: | :----: |
|
||||
| GRU | 0.049 / 1.903 | 0.018 / 1.972 | 0.014 / 1.989 |
|
||||
| MLP | 0.023 / 1.961 | 0.022 / 1.962 | 0.015 / 1.978 |
|
||||
| LGB | 0.038 / 1.883 | 0.023 / 1.952 | 0.007 / 1.987 |
|
||||
| GAT | 0.052 / 1.898 | 0.024 / 1.954 | 0.015 / 1.973 |
|
||||
| MTL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_3">) | 0.061 / 1.862 | 0.023 / 1.942 | 0.012 / 1.956 |
|
||||
| CL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_3">) | 0.051 / 1.880 | 0.028 / 1.941 | 0.016 / 1.962 |
|
||||
| Ours(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_3">) | 0.071 / 1.851 | 0.030 / 1.939 | 0.017 / 1.963 |
|
||||
| MTL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_5">) | 0.057 / 1.875 | 0.021 / 1.939 | 0.017 / 1.959 |
|
||||
| CL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_5">) | 0.056 / 1.877 | 0.028 / 1.942 | 0.015 / 1.962 |
|
||||
| Ours(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_5">) | 0.075 / 1.849 | 0.032 /1.939 | 0.021 / 1.955 |
|
||||
| MTL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_{10}">) | 0.052 / 1.882 | 0.020 / 1.947 | 0.019 / 1.952 |
|
||||
| CL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_{10}">) | 0.051 / 1.882 | 0.028 / 1.950 | 0.016 / 1.961 |
|
||||
| Ours(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_{10}">) | 0.067 / 1.867 | 0.030 / 1.960 | 0.022 / 1.942|
|
||||
BIN
examples/benchmarks/TCTS/task_description.png
Normal file
BIN
examples/benchmarks/TCTS/task_description.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 25 KiB |
BIN
examples/benchmarks/TCTS/workflow.png
Normal file
BIN
examples/benchmarks/TCTS/workflow.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 29 KiB |
93
examples/benchmarks/TCTS/workflow_config_tcts_Alpha360.yaml
Normal file
93
examples/benchmarks/TCTS/workflow_config_tcts_Alpha360.yaml
Normal file
@@ -0,0 +1,93 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -1) / $close - 1",
|
||||
"Ref($close, -2) / Ref($close, -1) - 1",
|
||||
"Ref($close, -3) / Ref($close, -2) - 1"]
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: TCTS
|
||||
module_path: qlib.contrib.model.pytorch_tcts
|
||||
kwargs:
|
||||
d_feat: 6
|
||||
hidden_size: 64
|
||||
num_layers: 2
|
||||
dropout: 0.0
|
||||
n_epochs: 200
|
||||
lr: 1e-3
|
||||
early_stop: 20
|
||||
batch_size: 800
|
||||
metric: loss
|
||||
loss: mse
|
||||
GPU: 0
|
||||
fore_optimizer: adam
|
||||
weight_optimizer: adam
|
||||
output_dim: 3
|
||||
fore_lr: 5e-4
|
||||
weight_lr: 5e-4
|
||||
steps: 3
|
||||
target_label: 1
|
||||
lowest_valid_performance: 0.993
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha360
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
label_col: 1
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
81
examples/benchmarks/TRA/README.md
Normal file
81
examples/benchmarks/TRA/README.md
Normal file
@@ -0,0 +1,81 @@
|
||||
# Learning Multiple Stock Trading Patterns with Temporal Routing Adaptor and Optimal Transport
|
||||
|
||||
This code provides a PyTorch implementation for TRA (Temporal Routing Adaptor), as described in the paper [Learning Multiple Stock Trading Patterns with Temporal Routing Adaptor and Optimal Transport](http://arxiv.org/abs/2106.12950).
|
||||
|
||||
* TRA (Temporal Routing Adaptor) is a lightweight module that consists of a set of independent predictors for learning multiple patterns as well as a router to dispatch samples to different predictors.
|
||||
* We also design a learning algorithm based on Optimal Transport (OT) to obtain the optimal sample to predictor assignment and effectively optimize the router with such assignment through an auxiliary loss term.
|
||||
|
||||
|
||||
# Running TRA
|
||||
|
||||
## Requirements
|
||||
- Install `Qlib` main branch
|
||||
|
||||
## Running
|
||||
|
||||
We attach our running scripts for the paper in `run.sh`.
|
||||
|
||||
And here are two ways to run the model:
|
||||
|
||||
* Running from scripts with default parameters
|
||||
You can directly run from Qlib command `qrun`:
|
||||
```
|
||||
qrun configs/config_alstm.yaml
|
||||
```
|
||||
|
||||
* Running from code with self-defined parameters
|
||||
Setting different parameters is also allowed. See codes in `example.py`:
|
||||
```
|
||||
python example.py --config_file configs/config_alstm.yaml
|
||||
```
|
||||
|
||||
Here we trained TRA on a pretrained backbone model. Therefore we run `*_init.yaml` before TRA's scipts.
|
||||
|
||||
# Results
|
||||
|
||||
## Outputs
|
||||
|
||||
After running the scripts, you can find result files in path `./output`:
|
||||
|
||||
`info.json` - config settings and result metrics.
|
||||
|
||||
`log.csv` - running logs.
|
||||
|
||||
`model.bin` - the model parameter dictionary.
|
||||
|
||||
`pred.pkl` - the prediction scores and output for inference.
|
||||
|
||||
## Our Results
|
||||
| Methods | MSE| MAE| IC | ICIR | AR | AV | SR | MDD |
|
||||
|-------------------|-------------------|---------------------|--------------------|--------------------|--------------------|--------------------|--------------------|--------------------|
|
||||
|Linear|0.163|0.327|0.020|0.132|-3.2%|16.8%|-0.191|32.1%|
|
||||
|LightGBM|0.160(0.000)|0.323(0.000)|0.041|0.292|7.8%|15.5%|0.503|25.7%|
|
||||
|MLP|0.160(0.002)|0.323(0.003)|0.037|0.273|3.7%|15.3%|0.264|26.2%|
|
||||
|SFM|0.159(0.001) |0.321(0.001) |0.047 |0.381 |7.1% |14.3% |0.497 |22.9%|
|
||||
|ALSTM|0.158(0.001) |0.320(0.001) |0.053 |0.419 |12.3% |13.7% |0.897 |20.2%|
|
||||
|Trans.|0.158(0.001) |0.322(0.001) |0.051 |0.400 |14.5% |14.2% |1.028 |22.5%|
|
||||
|ALSTM+TS|0.160(0.002) |0.321(0.002) |0.039 |0.291 |6.7% |14.6% |0.480|22.3%|
|
||||
|Trans.+TS|0.160(0.004) |0.324(0.005) |0.037 |0.278 |10.4% |14.7% |0.722 |23.7%|
|
||||
|ALSTM+TRA(Ours)|0.157(0.000) |0.318(0.000) |0.059 |0.460 |12.4% |14.0% |0.885 |20.4%|
|
||||
|Trans.+TRA(Ours)|0.157(0.000) |0.320(0.000) |0.056 |0.442 |16.1% |14.2% |1.133 |23.1%|
|
||||
|
||||
A more detailed demo for our experiment results in the paper can be found in `Report.ipynb`.
|
||||
|
||||
# Common Issues
|
||||
|
||||
For help or issues using TRA, please submit a GitHub issue.
|
||||
|
||||
Sometimes we might encounter situation where the loss is `NaN`, please check the `epsilon` parameter in the sinkhorn algorithm, adjusting the `epsilon` according to input's scale is important.
|
||||
|
||||
# Citation
|
||||
If you find this repository useful in your research, please cite:
|
||||
```
|
||||
@inproceedings{HengxuKDD2021,
|
||||
author = {Hengxu Lin and Dong Zhou and Weiqing Liu and Jiang Bian},
|
||||
title = {Learning Multiple Stock Trading Patterns with Temporal Routing Adaptor and Optimal Transport},
|
||||
booktitle = {Proceedings of the 27th ACM SIGKDD Conference on Knowledge Discovery \& Data Mining},
|
||||
series = {KDD '21},
|
||||
year = {2021},
|
||||
publisher = {ACM},
|
||||
}
|
||||
```
|
||||
796
examples/benchmarks/TRA/Reports.ipynb
Normal file
796
examples/benchmarks/TRA/Reports.ipynb
Normal file
File diff suppressed because one or more lines are too long
63
examples/benchmarks/TRA/configs/config_alstm.yaml
Normal file
63
examples/benchmarks/TRA/configs/config_alstm.yaml
Normal file
@@ -0,0 +1,63 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
|
||||
data_loader_config: &data_loader_config
|
||||
class: StaticDataLoader
|
||||
module_path: qlib.data.dataset.loader
|
||||
kwargs:
|
||||
config:
|
||||
feature: data/feature.pkl
|
||||
label: data/label.pkl
|
||||
|
||||
model_config: &model_config
|
||||
input_size: 16
|
||||
hidden_size: 256
|
||||
num_layers: 2
|
||||
num_heads: 2
|
||||
use_attn: True
|
||||
dropout: 0.1
|
||||
|
||||
num_states: &num_states 1
|
||||
|
||||
tra_config: &tra_config
|
||||
num_states: *num_states
|
||||
hidden_size: 16
|
||||
tau: 1.0
|
||||
src_info: LR_TPE
|
||||
|
||||
task:
|
||||
model:
|
||||
class: TRAModel
|
||||
module_path: src/model.py
|
||||
kwargs:
|
||||
lr: 0.0002
|
||||
n_epochs: 500
|
||||
max_steps_per_epoch: 100
|
||||
early_stop: 20
|
||||
seed: 1000
|
||||
logdir: output/test/alstm
|
||||
model_type: LSTM
|
||||
model_config: *model_config
|
||||
tra_config: *tra_config
|
||||
lamb: 1.0
|
||||
rho: 0.99
|
||||
freeze_model: False
|
||||
model_init_state:
|
||||
dataset:
|
||||
class: MTSDatasetH
|
||||
module_path: src/dataset.py
|
||||
kwargs:
|
||||
handler:
|
||||
class: DataHandler
|
||||
module_path: qlib.data.dataset.handler
|
||||
kwargs:
|
||||
data_loader: *data_loader_config
|
||||
segments:
|
||||
train: [2007-10-30, 2016-05-27]
|
||||
valid: [2016-09-26, 2018-05-29]
|
||||
test: [2018-09-21, 2020-06-30]
|
||||
seq_len: 60
|
||||
horizon: 21
|
||||
num_states: *num_states
|
||||
batch_size: 1024
|
||||
63
examples/benchmarks/TRA/configs/config_alstm_tra.yaml
Normal file
63
examples/benchmarks/TRA/configs/config_alstm_tra.yaml
Normal file
@@ -0,0 +1,63 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
|
||||
data_loader_config: &data_loader_config
|
||||
class: StaticDataLoader
|
||||
module_path: qlib.data.dataset.loader
|
||||
kwargs:
|
||||
config:
|
||||
feature: data/feature.pkl
|
||||
label: data/label.pkl
|
||||
|
||||
model_config: &model_config
|
||||
input_size: 16
|
||||
hidden_size: 256
|
||||
num_layers: 2
|
||||
num_heads: 2
|
||||
use_attn: True
|
||||
dropout: 0.1
|
||||
|
||||
num_states: &num_states 10
|
||||
|
||||
tra_config: &tra_config
|
||||
num_states: *num_states
|
||||
hidden_size: 16
|
||||
tau: 1.0
|
||||
src_info: LR_TPE
|
||||
|
||||
task:
|
||||
model:
|
||||
class: TRAModel
|
||||
module_path: src/model.py
|
||||
kwargs:
|
||||
lr: 0.0001
|
||||
n_epochs: 500
|
||||
max_steps_per_epoch: 100
|
||||
early_stop: 20
|
||||
seed: 1000
|
||||
logdir: output/test/alstm_tra
|
||||
model_type: LSTM
|
||||
model_config: *model_config
|
||||
tra_config: *tra_config
|
||||
lamb: 2.0
|
||||
rho: 0.99
|
||||
freeze_model: True
|
||||
model_init_state: output/test/alstm_tra_init/model.bin
|
||||
dataset:
|
||||
class: MTSDatasetH
|
||||
module_path: src/dataset.py
|
||||
kwargs:
|
||||
handler:
|
||||
class: DataHandler
|
||||
module_path: qlib.data.dataset.handler
|
||||
kwargs:
|
||||
data_loader: *data_loader_config
|
||||
segments:
|
||||
train: [2007-10-30, 2016-05-27]
|
||||
valid: [2016-09-26, 2018-05-29]
|
||||
test: [2018-09-21, 2020-06-30]
|
||||
seq_len: 60
|
||||
horizon: 21
|
||||
num_states: *num_states
|
||||
batch_size: 1024
|
||||
63
examples/benchmarks/TRA/configs/config_alstm_tra_init.yaml
Normal file
63
examples/benchmarks/TRA/configs/config_alstm_tra_init.yaml
Normal file
@@ -0,0 +1,63 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
|
||||
data_loader_config: &data_loader_config
|
||||
class: StaticDataLoader
|
||||
module_path: qlib.data.dataset.loader
|
||||
kwargs:
|
||||
config:
|
||||
feature: data/feature.pkl
|
||||
label: data/label.pkl
|
||||
|
||||
model_config: &model_config
|
||||
input_size: 16
|
||||
hidden_size: 256
|
||||
num_layers: 2
|
||||
num_heads: 2
|
||||
use_attn: True
|
||||
dropout: 0.1
|
||||
|
||||
num_states: &num_states 3
|
||||
|
||||
tra_config: &tra_config
|
||||
num_states: *num_states
|
||||
hidden_size: 16
|
||||
tau: 1.0
|
||||
src_info: LR_TPE
|
||||
|
||||
task:
|
||||
model:
|
||||
class: TRAModel
|
||||
module_path: src/model.py
|
||||
kwargs:
|
||||
lr: 0.0002
|
||||
n_epochs: 500
|
||||
max_steps_per_epoch: 100
|
||||
early_stop: 20
|
||||
seed: 1000
|
||||
logdir: output/test/alstm_tra_init
|
||||
model_type: LSTM
|
||||
model_config: *model_config
|
||||
tra_config: *tra_config
|
||||
lamb: 1.0
|
||||
rho: 0.99
|
||||
freeze_model: False
|
||||
model_init_state:
|
||||
dataset:
|
||||
class: MTSDatasetH
|
||||
module_path: src/dataset.py
|
||||
kwargs:
|
||||
handler:
|
||||
class: DataHandler
|
||||
module_path: qlib.data.dataset.handler
|
||||
kwargs:
|
||||
data_loader: *data_loader_config
|
||||
segments:
|
||||
train: [2007-10-30, 2016-05-27]
|
||||
valid: [2016-09-26, 2018-05-29]
|
||||
test: [2018-09-21, 2020-06-30]
|
||||
seq_len: 60
|
||||
horizon: 21
|
||||
num_states: *num_states
|
||||
batch_size: 512
|
||||
63
examples/benchmarks/TRA/configs/config_transformer.yaml
Normal file
63
examples/benchmarks/TRA/configs/config_transformer.yaml
Normal file
@@ -0,0 +1,63 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
|
||||
data_loader_config: &data_loader_config
|
||||
class: StaticDataLoader
|
||||
module_path: qlib.data.dataset.loader
|
||||
kwargs:
|
||||
config:
|
||||
feature: data/feature.pkl
|
||||
label: data/label.pkl
|
||||
|
||||
model_config: &model_config
|
||||
input_size: 16
|
||||
hidden_size: 64
|
||||
num_layers: 2
|
||||
num_heads: 4
|
||||
use_attn: False
|
||||
dropout: 0.1
|
||||
|
||||
num_states: &num_states 1
|
||||
|
||||
tra_config: &tra_config
|
||||
num_states: *num_states
|
||||
hidden_size: 16
|
||||
tau: 1.0
|
||||
src_info: LR_TPE
|
||||
|
||||
task:
|
||||
model:
|
||||
class: TRAModel
|
||||
module_path: src/model.py
|
||||
kwargs:
|
||||
lr: 0.0002
|
||||
n_epochs: 500
|
||||
max_steps_per_epoch: 100
|
||||
early_stop: 20
|
||||
seed: 1000
|
||||
logdir: output/test/transformer
|
||||
model_type: Transformer
|
||||
model_config: *model_config
|
||||
tra_config: *tra_config
|
||||
lamb: 1.0
|
||||
rho: 0.99
|
||||
freeze_model: False
|
||||
model_init_state:
|
||||
dataset:
|
||||
class: MTSDatasetH
|
||||
module_path: src/dataset.py
|
||||
kwargs:
|
||||
handler:
|
||||
class: DataHandler
|
||||
module_path: qlib.data.dataset.handler
|
||||
kwargs:
|
||||
data_loader: *data_loader_config
|
||||
segments:
|
||||
train: [2007-10-30, 2016-05-27]
|
||||
valid: [2016-09-26, 2018-05-29]
|
||||
test: [2018-09-21, 2020-06-30]
|
||||
seq_len: 60
|
||||
horizon: 21
|
||||
num_states: *num_states
|
||||
batch_size: 1024
|
||||
63
examples/benchmarks/TRA/configs/config_transformer_tra.yaml
Normal file
63
examples/benchmarks/TRA/configs/config_transformer_tra.yaml
Normal file
@@ -0,0 +1,63 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
|
||||
data_loader_config: &data_loader_config
|
||||
class: StaticDataLoader
|
||||
module_path: qlib.data.dataset.loader
|
||||
kwargs:
|
||||
config:
|
||||
feature: data/feature.pkl
|
||||
label: data/label.pkl
|
||||
|
||||
model_config: &model_config
|
||||
input_size: 16
|
||||
hidden_size: 64
|
||||
num_layers: 2
|
||||
num_heads: 4
|
||||
use_attn: False
|
||||
dropout: 0.1
|
||||
|
||||
num_states: &num_states 3
|
||||
|
||||
tra_config: &tra_config
|
||||
num_states: *num_states
|
||||
hidden_size: 16
|
||||
tau: 1.0
|
||||
src_info: LR_TPE
|
||||
|
||||
task:
|
||||
model:
|
||||
class: TRAModel
|
||||
module_path: src/model.py
|
||||
kwargs:
|
||||
lr: 0.0005
|
||||
n_epochs: 500
|
||||
max_steps_per_epoch: 100
|
||||
early_stop: 20
|
||||
seed: 1000
|
||||
logdir: output/test/transformer_tra
|
||||
model_type: Transformer
|
||||
model_config: *model_config
|
||||
tra_config: *tra_config
|
||||
lamb: 1.0
|
||||
rho: 0.99
|
||||
freeze_model: True
|
||||
model_init_state: output/test/transformer_tra_init/model.bin
|
||||
dataset:
|
||||
class: MTSDatasetH
|
||||
module_path: src/dataset.py
|
||||
kwargs:
|
||||
handler:
|
||||
class: DataHandler
|
||||
module_path: qlib.data.dataset.handler
|
||||
kwargs:
|
||||
data_loader: *data_loader_config
|
||||
segments:
|
||||
train: [2007-10-30, 2016-05-27]
|
||||
valid: [2016-09-26, 2018-05-29]
|
||||
test: [2018-09-21, 2020-06-30]
|
||||
seq_len: 60
|
||||
horizon: 21
|
||||
num_states: *num_states
|
||||
batch_size: 512
|
||||
@@ -0,0 +1,63 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
|
||||
data_loader_config: &data_loader_config
|
||||
class: StaticDataLoader
|
||||
module_path: qlib.data.dataset.loader
|
||||
kwargs:
|
||||
config:
|
||||
feature: data/feature.pkl
|
||||
label: data/label.pkl
|
||||
|
||||
model_config: &model_config
|
||||
input_size: 16
|
||||
hidden_size: 64
|
||||
num_layers: 2
|
||||
num_heads: 4
|
||||
use_attn: False
|
||||
dropout: 0.1
|
||||
|
||||
num_states: &num_states 3
|
||||
|
||||
tra_config: &tra_config
|
||||
num_states: *num_states
|
||||
hidden_size: 16
|
||||
tau: 1.0
|
||||
src_info: LR_TPE
|
||||
|
||||
task:
|
||||
model:
|
||||
class: TRAModel
|
||||
module_path: src/model.py
|
||||
kwargs:
|
||||
lr: 0.0002
|
||||
n_epochs: 500
|
||||
max_steps_per_epoch: 100
|
||||
early_stop: 20
|
||||
seed: 1000
|
||||
logdir: output/test/transformer_tra_init
|
||||
model_type: Transformer
|
||||
model_config: *model_config
|
||||
tra_config: *tra_config
|
||||
lamb: 1.0
|
||||
rho: 0.99
|
||||
freeze_model: False
|
||||
model_init_state:
|
||||
dataset:
|
||||
class: MTSDatasetH
|
||||
module_path: src/dataset.py
|
||||
kwargs:
|
||||
handler:
|
||||
class: DataHandler
|
||||
module_path: qlib.data.dataset.handler
|
||||
kwargs:
|
||||
data_loader: *data_loader_config
|
||||
segments:
|
||||
train: [2007-10-30, 2016-05-27]
|
||||
valid: [2016-09-26, 2018-05-29]
|
||||
test: [2018-09-21, 2020-06-30]
|
||||
seq_len: 60
|
||||
horizon: 21
|
||||
num_states: *num_states
|
||||
batch_size: 512
|
||||
1
examples/benchmarks/TRA/data/README.md
Normal file
1
examples/benchmarks/TRA/data/README.md
Normal file
@@ -0,0 +1 @@
|
||||
Data Link: https://drive.google.com/drive/folders/1fMqZYSeLyrHiWmVzygeI4sw3vp5Gt8cY?usp=sharing
|
||||
39
examples/benchmarks/TRA/example.py
Normal file
39
examples/benchmarks/TRA/example.py
Normal file
@@ -0,0 +1,39 @@
|
||||
import argparse
|
||||
|
||||
import qlib
|
||||
import ruamel.yaml as yaml
|
||||
from qlib.utils import init_instance_by_config
|
||||
|
||||
|
||||
def main(seed, config_file="configs/config_alstm.yaml"):
|
||||
|
||||
# set random seed
|
||||
with open(config_file) as f:
|
||||
config = yaml.safe_load(f)
|
||||
|
||||
# seed_suffix = "/seed1000" if "init" in config_file else f"/seed{seed}"
|
||||
seed_suffix = ""
|
||||
config["task"]["model"]["kwargs"].update(
|
||||
{"seed": seed, "logdir": config["task"]["model"]["kwargs"]["logdir"] + seed_suffix}
|
||||
)
|
||||
|
||||
# initialize workflow
|
||||
qlib.init(
|
||||
provider_uri=config["qlib_init"]["provider_uri"],
|
||||
region=config["qlib_init"]["region"],
|
||||
)
|
||||
dataset = init_instance_by_config(config["task"]["dataset"])
|
||||
model = init_instance_by_config(config["task"]["model"])
|
||||
|
||||
# train model
|
||||
model.fit(dataset)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
# set params from cmd
|
||||
parser = argparse.ArgumentParser(allow_abbrev=False)
|
||||
parser.add_argument("--seed", type=int, default=1000, help="random seed")
|
||||
parser.add_argument("--config_file", type=str, default="configs/config_alstm.yaml", help="config file")
|
||||
args = parser.parse_args()
|
||||
main(**vars(args))
|
||||
29
examples/benchmarks/TRA/run.sh
Normal file
29
examples/benchmarks/TRA/run.sh
Normal file
@@ -0,0 +1,29 @@
|
||||
#!/bin/bash
|
||||
|
||||
# we used random seed(1 1000 2000 3000 4000 5000) in our experiments
|
||||
|
||||
# Directly run from Qlib command `qrun`
|
||||
qrun configs/config_alstm.yaml
|
||||
|
||||
qrun configs/config_transformer.yaml
|
||||
|
||||
qrun configs/config_transformer_tra_init.yaml
|
||||
qrun configs/config_transformer_tra.yaml
|
||||
|
||||
qrun configs/config_alstm_tra_init.yaml
|
||||
qrun configs/config_alstm_tra.yaml
|
||||
|
||||
|
||||
# Or setting different parameters with example.py
|
||||
python example.py --config_file configs/config_alstm.yaml
|
||||
|
||||
python example.py --config_file configs/config_transformer.yaml
|
||||
|
||||
python example.py --config_file configs/config_transformer_tra_init.yaml
|
||||
python example.py --config_file configs/config_transformer_tra.yaml
|
||||
|
||||
python example.py --config_file configs/config_alstm_tra_init.yaml
|
||||
python example.py --config_file configs/config_alstm_tra.yaml
|
||||
|
||||
|
||||
|
||||
253
examples/benchmarks/TRA/src/dataset.py
Normal file
253
examples/benchmarks/TRA/src/dataset.py
Normal file
@@ -0,0 +1,253 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import copy
|
||||
import torch
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from qlib.utils import init_instance_by_config
|
||||
from qlib.data.dataset import DatasetH, DataHandler
|
||||
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
|
||||
def _to_tensor(x):
|
||||
if not isinstance(x, torch.Tensor):
|
||||
return torch.tensor(x, dtype=torch.float, device=device)
|
||||
return x
|
||||
|
||||
|
||||
def _create_ts_slices(index, seq_len):
|
||||
"""
|
||||
create time series slices from pandas index
|
||||
|
||||
Args:
|
||||
index (pd.MultiIndex): pandas multiindex with <instrument, datetime> order
|
||||
seq_len (int): sequence length
|
||||
"""
|
||||
assert index.is_lexsorted(), "index should be sorted"
|
||||
|
||||
# number of dates for each code
|
||||
sample_count_by_codes = pd.Series(0, index=index).groupby(level=0).size().values
|
||||
|
||||
# start_index for each code
|
||||
start_index_of_codes = np.roll(np.cumsum(sample_count_by_codes), 1)
|
||||
start_index_of_codes[0] = 0
|
||||
|
||||
# all the [start, stop) indices of features
|
||||
# features btw [start, stop) are used to predict the `stop - 1` label
|
||||
slices = []
|
||||
for cur_loc, cur_cnt in zip(start_index_of_codes, sample_count_by_codes):
|
||||
for stop in range(1, cur_cnt + 1):
|
||||
end = cur_loc + stop
|
||||
start = max(end - seq_len, 0)
|
||||
slices.append(slice(start, end))
|
||||
slices = np.array(slices)
|
||||
|
||||
return slices
|
||||
|
||||
|
||||
def _get_date_parse_fn(target):
|
||||
"""get date parse function
|
||||
|
||||
This method is used to parse date arguments as target type.
|
||||
|
||||
Example:
|
||||
get_date_parse_fn('20120101')('2017-01-01') => '20170101'
|
||||
get_date_parse_fn(20120101)('2017-01-01') => 20170101
|
||||
"""
|
||||
if isinstance(target, pd.Timestamp):
|
||||
_fn = lambda x: pd.Timestamp(x) # Timestamp('2020-01-01')
|
||||
elif isinstance(target, str) and len(target) == 8:
|
||||
_fn = lambda x: str(x).replace("-", "")[:8] # '20200201'
|
||||
elif isinstance(target, int):
|
||||
_fn = lambda x: int(str(x).replace("-", "")[:8]) # 20200201
|
||||
else:
|
||||
_fn = lambda x: x
|
||||
return _fn
|
||||
|
||||
|
||||
class MTSDatasetH(DatasetH):
|
||||
"""Memory Augmented Time Series Dataset
|
||||
|
||||
Args:
|
||||
handler (DataHandler): data handler
|
||||
segments (dict): data split segments
|
||||
seq_len (int): time series sequence length
|
||||
horizon (int): label horizon (to mask historical loss for TRA)
|
||||
num_states (int): how many memory states to be added (for TRA)
|
||||
batch_size (int): batch size (<0 means daily batch)
|
||||
shuffle (bool): whether shuffle data
|
||||
pin_memory (bool): whether pin data to gpu memory
|
||||
drop_last (bool): whether drop last batch < batch_size
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
handler,
|
||||
segments,
|
||||
seq_len=60,
|
||||
horizon=0,
|
||||
num_states=1,
|
||||
batch_size=-1,
|
||||
shuffle=True,
|
||||
pin_memory=False,
|
||||
drop_last=False,
|
||||
**kwargs
|
||||
):
|
||||
|
||||
assert horizon > 0, "please specify `horizon` to avoid data leakage"
|
||||
|
||||
self.seq_len = seq_len
|
||||
self.horizon = horizon
|
||||
self.num_states = num_states
|
||||
self.batch_size = batch_size
|
||||
self.shuffle = shuffle
|
||||
self.drop_last = drop_last
|
||||
self.pin_memory = pin_memory
|
||||
self.params = (batch_size, drop_last, shuffle) # for train/eval switch
|
||||
|
||||
super().__init__(handler, segments, **kwargs)
|
||||
|
||||
def setup_data(self, handler_kwargs: dict = None, **kwargs):
|
||||
|
||||
super().setup_data()
|
||||
|
||||
# change index to <code, date>
|
||||
# NOTE: we will use inplace sort to reduce memory use
|
||||
df = self.handler._data
|
||||
df.index = df.index.swaplevel()
|
||||
df.sort_index(inplace=True)
|
||||
|
||||
self._data = df["feature"].values.astype("float32")
|
||||
self._label = df["label"].squeeze().astype("float32")
|
||||
self._index = df.index
|
||||
|
||||
# add memory to feature
|
||||
self._data = np.c_[self._data, np.zeros((len(self._data), self.num_states), dtype=np.float32)]
|
||||
|
||||
# padding tensor
|
||||
self.zeros = np.zeros((self.seq_len, self._data.shape[1]), dtype=np.float32)
|
||||
|
||||
# pin memory
|
||||
if self.pin_memory:
|
||||
self._data = _to_tensor(self._data)
|
||||
self._label = _to_tensor(self._label)
|
||||
self.zeros = _to_tensor(self.zeros)
|
||||
|
||||
# create batch slices
|
||||
self.batch_slices = _create_ts_slices(self._index, self.seq_len)
|
||||
|
||||
# create daily slices
|
||||
index = [slc.stop - 1 for slc in self.batch_slices]
|
||||
act_index = self.restore_index(index)
|
||||
daily_slices = {date: [] for date in sorted(act_index.unique(level=1))}
|
||||
for i, (code, date) in enumerate(act_index):
|
||||
daily_slices[date].append(self.batch_slices[i])
|
||||
self.daily_slices = list(daily_slices.values())
|
||||
|
||||
def _prepare_seg(self, slc, **kwargs):
|
||||
fn = _get_date_parse_fn(self._index[0][1])
|
||||
start_date = fn(slc.start)
|
||||
end_date = fn(slc.stop)
|
||||
obj = copy.copy(self) # shallow copy
|
||||
# NOTE: Seriable will disable copy `self._data` so we manually assign them here
|
||||
obj._data = self._data
|
||||
obj._label = self._label
|
||||
obj._index = self._index
|
||||
new_batch_slices = []
|
||||
for batch_slc in self.batch_slices:
|
||||
date = self._index[batch_slc.stop - 1][1]
|
||||
if start_date <= date <= end_date:
|
||||
new_batch_slices.append(batch_slc)
|
||||
obj.batch_slices = np.array(new_batch_slices)
|
||||
new_daily_slices = []
|
||||
for daily_slc in self.daily_slices:
|
||||
date = self._index[daily_slc[0].stop - 1][1]
|
||||
if start_date <= date <= end_date:
|
||||
new_daily_slices.append(daily_slc)
|
||||
obj.daily_slices = new_daily_slices
|
||||
return obj
|
||||
|
||||
def restore_index(self, index):
|
||||
if isinstance(index, torch.Tensor):
|
||||
index = index.cpu().numpy()
|
||||
return self._index[index]
|
||||
|
||||
def assign_data(self, index, vals):
|
||||
if isinstance(self._data, torch.Tensor):
|
||||
vals = _to_tensor(vals)
|
||||
elif isinstance(vals, torch.Tensor):
|
||||
vals = vals.detach().cpu().numpy()
|
||||
index = index.detach().cpu().numpy()
|
||||
self._data[index, -self.num_states :] = vals
|
||||
|
||||
def clear_memory(self):
|
||||
self._data[:, -self.num_states :] = 0
|
||||
|
||||
# TODO: better train/eval mode design
|
||||
def train(self):
|
||||
"""enable traning mode"""
|
||||
self.batch_size, self.drop_last, self.shuffle = self.params
|
||||
|
||||
def eval(self):
|
||||
"""enable evaluation mode"""
|
||||
self.batch_size = -1
|
||||
self.drop_last = False
|
||||
self.shuffle = False
|
||||
|
||||
def _get_slices(self):
|
||||
if self.batch_size < 0:
|
||||
slices = self.daily_slices.copy()
|
||||
batch_size = -1 * self.batch_size
|
||||
else:
|
||||
slices = self.batch_slices.copy()
|
||||
batch_size = self.batch_size
|
||||
return slices, batch_size
|
||||
|
||||
def __len__(self):
|
||||
slices, batch_size = self._get_slices()
|
||||
if self.drop_last:
|
||||
return len(slices) // batch_size
|
||||
return (len(slices) + batch_size - 1) // batch_size
|
||||
|
||||
def __iter__(self):
|
||||
slices, batch_size = self._get_slices()
|
||||
if self.shuffle:
|
||||
np.random.shuffle(slices)
|
||||
|
||||
for i in range(len(slices))[::batch_size]:
|
||||
if self.drop_last and i + batch_size > len(slices):
|
||||
break
|
||||
# get slices for this batch
|
||||
slices_subset = slices[i : i + batch_size]
|
||||
if self.batch_size < 0:
|
||||
slices_subset = np.concatenate(slices_subset)
|
||||
# collect data
|
||||
data = []
|
||||
label = []
|
||||
index = []
|
||||
for slc in slices_subset:
|
||||
_data = self._data[slc].clone() if self.pin_memory else self._data[slc].copy()
|
||||
if len(_data) != self.seq_len:
|
||||
if self.pin_memory:
|
||||
_data = torch.cat([self.zeros[: self.seq_len - len(_data)], _data], axis=0)
|
||||
else:
|
||||
_data = np.concatenate([self.zeros[: self.seq_len - len(_data)], _data], axis=0)
|
||||
if self.num_states > 0:
|
||||
_data[-self.horizon :, -self.num_states :] = 0
|
||||
data.append(_data)
|
||||
label.append(self._label[slc.stop - 1])
|
||||
index.append(slc.stop - 1)
|
||||
# concate
|
||||
index = torch.tensor(index, device=device)
|
||||
if isinstance(data[0], torch.Tensor):
|
||||
data = torch.stack(data)
|
||||
label = torch.stack(label)
|
||||
else:
|
||||
data = _to_tensor(np.stack(data))
|
||||
label = _to_tensor(np.stack(label))
|
||||
# yield -> generator
|
||||
yield {"data": data, "label": label, "index": index}
|
||||
603
examples/benchmarks/TRA/src/model.py
Normal file
603
examples/benchmarks/TRA/src/model.py
Normal file
@@ -0,0 +1,603 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
import copy
|
||||
import math
|
||||
import json
|
||||
import collections
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import torch.nn.functional as F
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from qlib.utils import get_or_create_path
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.model.base import Model
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
|
||||
class TRAModel(Model):
|
||||
def __init__(
|
||||
self,
|
||||
model_config,
|
||||
tra_config,
|
||||
model_type="LSTM",
|
||||
lr=1e-3,
|
||||
n_epochs=500,
|
||||
early_stop=50,
|
||||
smooth_steps=5,
|
||||
max_steps_per_epoch=None,
|
||||
freeze_model=False,
|
||||
model_init_state=None,
|
||||
lamb=0.0,
|
||||
rho=0.99,
|
||||
seed=0,
|
||||
logdir=None,
|
||||
eval_train=True,
|
||||
eval_test=False,
|
||||
avg_params=True,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
|
||||
self.logger = get_module_logger("TRA")
|
||||
self.logger.info("TRA Model...")
|
||||
|
||||
self.model = eval(model_type)(**model_config).to(device)
|
||||
if model_init_state:
|
||||
self.model.load_state_dict(torch.load(model_init_state, map_location="cpu")["model"])
|
||||
if freeze_model:
|
||||
for param in self.model.parameters():
|
||||
param.requires_grad_(False)
|
||||
else:
|
||||
self.logger.info("# model params: %d" % sum([p.numel() for p in self.model.parameters()]))
|
||||
|
||||
self.tra = TRA(self.model.output_size, **tra_config).to(device)
|
||||
self.logger.info("# tra params: %d" % sum([p.numel() for p in self.tra.parameters()]))
|
||||
|
||||
self.optimizer = optim.Adam(list(self.model.parameters()) + list(self.tra.parameters()), lr=lr)
|
||||
|
||||
self.model_config = model_config
|
||||
self.tra_config = tra_config
|
||||
self.lr = lr
|
||||
self.n_epochs = n_epochs
|
||||
self.early_stop = early_stop
|
||||
self.smooth_steps = smooth_steps
|
||||
self.max_steps_per_epoch = max_steps_per_epoch
|
||||
self.lamb = lamb
|
||||
self.rho = rho
|
||||
self.seed = seed
|
||||
self.logdir = logdir
|
||||
self.eval_train = eval_train
|
||||
self.eval_test = eval_test
|
||||
self.avg_params = avg_params
|
||||
|
||||
if self.tra.num_states > 1 and not self.eval_train:
|
||||
self.logger.warn("`eval_train` will be ignored when using TRA")
|
||||
|
||||
if self.logdir is not None:
|
||||
if os.path.exists(self.logdir):
|
||||
self.logger.warn(f"logdir {self.logdir} is not empty")
|
||||
os.makedirs(self.logdir, exist_ok=True)
|
||||
|
||||
self.fitted = False
|
||||
self.global_step = -1
|
||||
|
||||
def train_epoch(self, data_set):
|
||||
|
||||
self.model.train()
|
||||
self.tra.train()
|
||||
|
||||
data_set.train()
|
||||
|
||||
max_steps = self.n_epochs
|
||||
if self.max_steps_per_epoch is not None:
|
||||
max_steps = min(self.max_steps_per_epoch, self.n_epochs)
|
||||
|
||||
count = 0
|
||||
total_loss = 0
|
||||
total_count = 0
|
||||
for batch in tqdm(data_set, total=max_steps):
|
||||
count += 1
|
||||
if count > max_steps:
|
||||
break
|
||||
|
||||
self.global_step += 1
|
||||
|
||||
data, label, index = batch["data"], batch["label"], batch["index"]
|
||||
|
||||
feature = data[:, :, : -self.tra.num_states]
|
||||
hist_loss = data[:, : -data_set.horizon, -self.tra.num_states :]
|
||||
|
||||
hidden = self.model(feature)
|
||||
pred, all_preds, prob = self.tra(hidden, hist_loss)
|
||||
|
||||
loss = (pred - label).pow(2).mean()
|
||||
|
||||
L = (all_preds.detach() - label[:, None]).pow(2)
|
||||
L -= L.min(dim=-1, keepdim=True).values # normalize & ensure postive input
|
||||
|
||||
data_set.assign_data(index, L) # save loss to memory
|
||||
|
||||
if prob is not None:
|
||||
P = sinkhorn(-L, epsilon=0.01) # sample assignment matrix
|
||||
lamb = self.lamb * (self.rho ** self.global_step)
|
||||
reg = prob.log().mul(P).sum(dim=-1).mean()
|
||||
loss = loss - lamb * reg
|
||||
|
||||
loss.backward()
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
total_loss += loss.item()
|
||||
total_count += len(pred)
|
||||
|
||||
total_loss /= total_count
|
||||
|
||||
return total_loss
|
||||
|
||||
def test_epoch(self, data_set, return_pred=False):
|
||||
|
||||
self.model.eval()
|
||||
self.tra.eval()
|
||||
data_set.eval()
|
||||
|
||||
preds = []
|
||||
metrics = []
|
||||
for batch in tqdm(data_set):
|
||||
data, label, index = batch["data"], batch["label"], batch["index"]
|
||||
|
||||
feature = data[:, :, : -self.tra.num_states]
|
||||
hist_loss = data[:, : -data_set.horizon, -self.tra.num_states :]
|
||||
|
||||
with torch.no_grad():
|
||||
hidden = self.model(feature)
|
||||
pred, all_preds, prob = self.tra(hidden, hist_loss)
|
||||
|
||||
L = (all_preds - label[:, None]).pow(2)
|
||||
|
||||
L -= L.min(dim=-1, keepdim=True).values # normalize & ensure postive input
|
||||
|
||||
data_set.assign_data(index, L) # save loss to memory
|
||||
|
||||
X = np.c_[
|
||||
pred.cpu().numpy(),
|
||||
label.cpu().numpy(),
|
||||
]
|
||||
columns = ["score", "label"]
|
||||
if prob is not None:
|
||||
X = np.c_[X, all_preds.cpu().numpy(), prob.cpu().numpy()]
|
||||
columns += ["score_%d" % d for d in range(all_preds.shape[1])] + [
|
||||
"prob_%d" % d for d in range(all_preds.shape[1])
|
||||
]
|
||||
|
||||
pred = pd.DataFrame(X, index=index.cpu().numpy(), columns=columns)
|
||||
|
||||
metrics.append(evaluate(pred))
|
||||
|
||||
if return_pred:
|
||||
preds.append(pred)
|
||||
|
||||
metrics = pd.DataFrame(metrics)
|
||||
metrics = {
|
||||
"MSE": metrics.MSE.mean(),
|
||||
"MAE": metrics.MAE.mean(),
|
||||
"IC": metrics.IC.mean(),
|
||||
"ICIR": metrics.IC.mean() / metrics.IC.std(),
|
||||
}
|
||||
|
||||
if return_pred:
|
||||
preds = pd.concat(preds, axis=0)
|
||||
preds.index = data_set.restore_index(preds.index)
|
||||
preds.index = preds.index.swaplevel()
|
||||
preds.sort_index(inplace=True)
|
||||
|
||||
return metrics, preds
|
||||
|
||||
def fit(self, dataset, evals_result=dict()):
|
||||
|
||||
train_set, valid_set, test_set = dataset.prepare(["train", "valid", "test"])
|
||||
|
||||
best_score = -1
|
||||
best_epoch = 0
|
||||
stop_rounds = 0
|
||||
best_params = {
|
||||
"model": copy.deepcopy(self.model.state_dict()),
|
||||
"tra": copy.deepcopy(self.tra.state_dict()),
|
||||
}
|
||||
params_list = {
|
||||
"model": collections.deque(maxlen=self.smooth_steps),
|
||||
"tra": collections.deque(maxlen=self.smooth_steps),
|
||||
}
|
||||
evals_result["train"] = []
|
||||
evals_result["valid"] = []
|
||||
evals_result["test"] = []
|
||||
|
||||
# train
|
||||
self.fitted = True
|
||||
self.global_step = -1
|
||||
|
||||
if self.tra.num_states > 1:
|
||||
self.logger.info("init memory...")
|
||||
self.test_epoch(train_set)
|
||||
|
||||
for epoch in range(self.n_epochs):
|
||||
self.logger.info("Epoch %d:", epoch)
|
||||
|
||||
self.logger.info("training...")
|
||||
self.train_epoch(train_set)
|
||||
|
||||
self.logger.info("evaluating...")
|
||||
# average params for inference
|
||||
params_list["model"].append(copy.deepcopy(self.model.state_dict()))
|
||||
params_list["tra"].append(copy.deepcopy(self.tra.state_dict()))
|
||||
self.model.load_state_dict(average_params(params_list["model"]))
|
||||
self.tra.load_state_dict(average_params(params_list["tra"]))
|
||||
|
||||
# NOTE: during evaluating, the whole memory will be refreshed
|
||||
if self.tra.num_states > 1 or self.eval_train:
|
||||
train_set.clear_memory() # NOTE: clear the shared memory
|
||||
train_metrics = self.test_epoch(train_set)[0]
|
||||
evals_result["train"].append(train_metrics)
|
||||
self.logger.info("\ttrain metrics: %s" % train_metrics)
|
||||
|
||||
valid_metrics = self.test_epoch(valid_set)[0]
|
||||
evals_result["valid"].append(valid_metrics)
|
||||
self.logger.info("\tvalid metrics: %s" % valid_metrics)
|
||||
|
||||
if self.eval_test:
|
||||
test_metrics = self.test_epoch(test_set)[0]
|
||||
evals_result["test"].append(test_metrics)
|
||||
self.logger.info("\ttest metrics: %s" % test_metrics)
|
||||
|
||||
if valid_metrics["IC"] > best_score:
|
||||
best_score = valid_metrics["IC"]
|
||||
stop_rounds = 0
|
||||
best_epoch = epoch
|
||||
best_params = {
|
||||
"model": copy.deepcopy(self.model.state_dict()),
|
||||
"tra": copy.deepcopy(self.tra.state_dict()),
|
||||
}
|
||||
else:
|
||||
stop_rounds += 1
|
||||
if stop_rounds >= self.early_stop:
|
||||
self.logger.info("early stop @ %s" % epoch)
|
||||
break
|
||||
|
||||
# restore parameters
|
||||
self.model.load_state_dict(params_list["model"][-1])
|
||||
self.tra.load_state_dict(params_list["tra"][-1])
|
||||
|
||||
self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
|
||||
self.model.load_state_dict(best_params["model"])
|
||||
self.tra.load_state_dict(best_params["tra"])
|
||||
|
||||
metrics, preds = self.test_epoch(test_set, return_pred=True)
|
||||
self.logger.info("test metrics: %s" % metrics)
|
||||
|
||||
if self.logdir:
|
||||
self.logger.info("save model & pred to local directory")
|
||||
|
||||
pd.concat({name: pd.DataFrame(evals_result[name]) for name in evals_result}, axis=1).to_csv(
|
||||
self.logdir + "/logs.csv", index=False
|
||||
)
|
||||
|
||||
torch.save(best_params, self.logdir + "/model.bin")
|
||||
|
||||
preds.to_pickle(self.logdir + "/pred.pkl")
|
||||
|
||||
info = {
|
||||
"config": {
|
||||
"model_config": self.model_config,
|
||||
"tra_config": self.tra_config,
|
||||
"lr": self.lr,
|
||||
"n_epochs": self.n_epochs,
|
||||
"early_stop": self.early_stop,
|
||||
"smooth_steps": self.smooth_steps,
|
||||
"max_steps_per_epoch": self.max_steps_per_epoch,
|
||||
"lamb": self.lamb,
|
||||
"rho": self.rho,
|
||||
"seed": self.seed,
|
||||
"logdir": self.logdir,
|
||||
},
|
||||
"best_eval_metric": -best_score, # NOTE: minux -1 for minimize
|
||||
"metric": metrics,
|
||||
}
|
||||
with open(self.logdir + "/info.json", "w") as f:
|
||||
json.dump(info, f)
|
||||
|
||||
def predict(self, dataset, segment="test"):
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
test_set = dataset.prepare(segment)
|
||||
|
||||
metrics, preds = self.test_epoch(test_set, return_pred=True)
|
||||
self.logger.info("test metrics: %s" % metrics)
|
||||
|
||||
return preds
|
||||
|
||||
|
||||
class LSTM(nn.Module):
|
||||
|
||||
"""LSTM Model
|
||||
|
||||
Args:
|
||||
input_size (int): input size (# features)
|
||||
hidden_size (int): hidden size
|
||||
num_layers (int): number of hidden layers
|
||||
use_attn (bool): whether use attention layer.
|
||||
we use concat attention as https://github.com/fulifeng/Adv-ALSTM/
|
||||
dropout (float): dropout rate
|
||||
input_drop (float): input dropout for data augmentation
|
||||
noise_level (float): add gaussian noise to input for data augmentation
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size=16,
|
||||
hidden_size=64,
|
||||
num_layers=2,
|
||||
use_attn=True,
|
||||
dropout=0.0,
|
||||
input_drop=0.0,
|
||||
noise_level=0.0,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.input_size = input_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_layers = num_layers
|
||||
self.use_attn = use_attn
|
||||
self.noise_level = noise_level
|
||||
|
||||
self.input_drop = nn.Dropout(input_drop)
|
||||
|
||||
self.rnn = nn.LSTM(
|
||||
input_size=input_size,
|
||||
hidden_size=hidden_size,
|
||||
num_layers=num_layers,
|
||||
batch_first=True,
|
||||
dropout=dropout,
|
||||
)
|
||||
|
||||
if self.use_attn:
|
||||
self.W = nn.Linear(hidden_size, hidden_size)
|
||||
self.u = nn.Linear(hidden_size, 1, bias=False)
|
||||
self.output_size = hidden_size * 2
|
||||
else:
|
||||
self.output_size = hidden_size
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
x = self.input_drop(x)
|
||||
|
||||
if self.training and self.noise_level > 0:
|
||||
noise = torch.randn_like(x).to(x)
|
||||
x = x + noise * self.noise_level
|
||||
|
||||
rnn_out, _ = self.rnn(x)
|
||||
last_out = rnn_out[:, -1]
|
||||
|
||||
if self.use_attn:
|
||||
laten = self.W(rnn_out).tanh()
|
||||
scores = self.u(laten).softmax(dim=1)
|
||||
att_out = (rnn_out * scores).sum(dim=1).squeeze()
|
||||
last_out = torch.cat([last_out, att_out], dim=1)
|
||||
|
||||
return last_out
|
||||
|
||||
|
||||
class PositionalEncoding(nn.Module):
|
||||
# reference: https://pytorch.org/tutorials/beginner/transformer_tutorial.html
|
||||
def __init__(self, d_model, dropout=0.1, max_len=5000):
|
||||
super(PositionalEncoding, self).__init__()
|
||||
self.dropout = nn.Dropout(p=dropout)
|
||||
|
||||
pe = torch.zeros(max_len, d_model)
|
||||
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
|
||||
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
|
||||
pe[:, 0::2] = torch.sin(position * div_term)
|
||||
pe[:, 1::2] = torch.cos(position * div_term)
|
||||
pe = pe.unsqueeze(0).transpose(0, 1)
|
||||
self.register_buffer("pe", pe)
|
||||
|
||||
def forward(self, x):
|
||||
x = x + self.pe[: x.size(0), :]
|
||||
return self.dropout(x)
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
|
||||
"""Transformer Model
|
||||
|
||||
Args:
|
||||
input_size (int): input size (# features)
|
||||
hidden_size (int): hidden size
|
||||
num_layers (int): number of transformer layers
|
||||
num_heads (int): number of heads in transformer
|
||||
dropout (float): dropout rate
|
||||
input_drop (float): input dropout for data augmentation
|
||||
noise_level (float): add gaussian noise to input for data augmentation
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size=16,
|
||||
hidden_size=64,
|
||||
num_layers=2,
|
||||
num_heads=2,
|
||||
dropout=0.0,
|
||||
input_drop=0.0,
|
||||
noise_level=0.0,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.input_size = input_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_layers = num_layers
|
||||
self.num_heads = num_heads
|
||||
self.noise_level = noise_level
|
||||
|
||||
self.input_drop = nn.Dropout(input_drop)
|
||||
|
||||
self.input_proj = nn.Linear(input_size, hidden_size)
|
||||
|
||||
self.pe = PositionalEncoding(input_size, dropout)
|
||||
layer = nn.TransformerEncoderLayer(
|
||||
nhead=num_heads, dropout=dropout, d_model=hidden_size, dim_feedforward=hidden_size * 4
|
||||
)
|
||||
self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)
|
||||
|
||||
self.output_size = hidden_size
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
x = self.input_drop(x)
|
||||
|
||||
if self.training and self.noise_level > 0:
|
||||
noise = torch.randn_like(x).to(x)
|
||||
x = x + noise * self.noise_level
|
||||
|
||||
x = x.permute(1, 0, 2).contiguous() # the first dim need to be sequence
|
||||
x = self.pe(x)
|
||||
|
||||
x = self.input_proj(x)
|
||||
out = self.encoder(x)
|
||||
|
||||
return out[-1]
|
||||
|
||||
|
||||
class TRA(nn.Module):
|
||||
|
||||
"""Temporal Routing Adaptor (TRA)
|
||||
|
||||
TRA takes historical prediction erros & latent representation as inputs,
|
||||
then routes the input sample to a specific predictor for training & inference.
|
||||
|
||||
Args:
|
||||
input_size (int): input size (RNN/Transformer's hidden size)
|
||||
num_states (int): number of latent states (i.e., trading patterns)
|
||||
If `num_states=1`, then TRA falls back to traditional methods
|
||||
hidden_size (int): hidden size of the router
|
||||
tau (float): gumbel softmax temperature
|
||||
"""
|
||||
|
||||
def __init__(self, input_size, num_states=1, hidden_size=8, tau=1.0, src_info="LR_TPE"):
|
||||
super().__init__()
|
||||
|
||||
self.num_states = num_states
|
||||
self.tau = tau
|
||||
self.src_info = src_info
|
||||
|
||||
if num_states > 1:
|
||||
self.router = nn.LSTM(
|
||||
input_size=num_states,
|
||||
hidden_size=hidden_size,
|
||||
num_layers=1,
|
||||
batch_first=True,
|
||||
)
|
||||
self.fc = nn.Linear(hidden_size + input_size, num_states)
|
||||
|
||||
self.predictors = nn.Linear(input_size, num_states)
|
||||
|
||||
def forward(self, hidden, hist_loss):
|
||||
|
||||
preds = self.predictors(hidden)
|
||||
|
||||
if self.num_states == 1:
|
||||
return preds.squeeze(-1), preds, None
|
||||
|
||||
# information type
|
||||
router_out, _ = self.router(hist_loss)
|
||||
if "LR" in self.src_info:
|
||||
latent_representation = hidden
|
||||
else:
|
||||
latent_representation = torch.randn(hidden.shape).to(hidden)
|
||||
if "TPE" in self.src_info:
|
||||
temporal_pred_error = router_out[:, -1]
|
||||
else:
|
||||
temporal_pred_error = torch.randn(router_out[:, -1].shape).to(hidden)
|
||||
|
||||
out = self.fc(torch.cat([temporal_pred_error, latent_representation], dim=-1))
|
||||
prob = F.gumbel_softmax(out, dim=-1, tau=self.tau, hard=False)
|
||||
|
||||
if self.training:
|
||||
final_pred = (preds * prob).sum(dim=-1)
|
||||
else:
|
||||
final_pred = preds[range(len(preds)), prob.argmax(dim=-1)]
|
||||
|
||||
return final_pred, preds, prob
|
||||
|
||||
|
||||
def evaluate(pred):
|
||||
pred = pred.rank(pct=True) # transform into percentiles
|
||||
score = pred.score
|
||||
label = pred.label
|
||||
diff = score - label
|
||||
MSE = (diff ** 2).mean()
|
||||
MAE = (diff.abs()).mean()
|
||||
IC = score.corr(label)
|
||||
return {"MSE": MSE, "MAE": MAE, "IC": IC}
|
||||
|
||||
|
||||
def average_params(params_list):
|
||||
assert isinstance(params_list, (tuple, list, collections.deque))
|
||||
n = len(params_list)
|
||||
if n == 1:
|
||||
return params_list[0]
|
||||
new_params = collections.OrderedDict()
|
||||
keys = None
|
||||
for i, params in enumerate(params_list):
|
||||
if keys is None:
|
||||
keys = params.keys()
|
||||
for k, v in params.items():
|
||||
if k not in keys:
|
||||
raise ValueError("the %d-th model has different params" % i)
|
||||
if k not in new_params:
|
||||
new_params[k] = v / n
|
||||
else:
|
||||
new_params[k] += v / n
|
||||
return new_params
|
||||
|
||||
|
||||
def shoot_infs(inp_tensor):
|
||||
"""Replaces inf by maximum of tensor"""
|
||||
mask_inf = torch.isinf(inp_tensor)
|
||||
ind_inf = torch.nonzero(mask_inf, as_tuple=False)
|
||||
if len(ind_inf) > 0:
|
||||
for ind in ind_inf:
|
||||
if len(ind) == 2:
|
||||
inp_tensor[ind[0], ind[1]] = 0
|
||||
elif len(ind) == 1:
|
||||
inp_tensor[ind[0]] = 0
|
||||
m = torch.max(inp_tensor)
|
||||
for ind in ind_inf:
|
||||
if len(ind) == 2:
|
||||
inp_tensor[ind[0], ind[1]] = m
|
||||
elif len(ind) == 1:
|
||||
inp_tensor[ind[0]] = m
|
||||
return inp_tensor
|
||||
|
||||
|
||||
def sinkhorn(Q, n_iters=3, epsilon=0.01):
|
||||
# epsilon should be adjusted according to logits value's scale
|
||||
with torch.no_grad():
|
||||
Q = shoot_infs(Q)
|
||||
Q = torch.exp(Q / epsilon)
|
||||
for i in range(n_iters):
|
||||
Q /= Q.sum(dim=0, keepdim=True)
|
||||
Q /= Q.sum(dim=1, keepdim=True)
|
||||
return Q
|
||||
@@ -1,7 +1,5 @@
|
||||
from qlib.data.dataset.handler import DataHandler, DataHandlerLP
|
||||
from qlib.data.dataset.processor import Processor
|
||||
from qlib.utils import get_cls_kwargs
|
||||
from qlib.log import TimeInspector
|
||||
from qlib.contrib.data.handler import check_transform_proc
|
||||
|
||||
|
||||
class HighFreqHandler(DataHandlerLP):
|
||||
@@ -16,20 +14,9 @@ class HighFreqHandler(DataHandlerLP):
|
||||
fit_end_time=None,
|
||||
drop_raw=True,
|
||||
):
|
||||
def check_transform_proc(proc_l):
|
||||
new_l = []
|
||||
for p in proc_l:
|
||||
p["kwargs"].update(
|
||||
{
|
||||
"fit_start_time": fit_start_time,
|
||||
"fit_end_time": fit_end_time,
|
||||
}
|
||||
)
|
||||
new_l.append(p)
|
||||
return new_l
|
||||
|
||||
infer_processors = check_transform_proc(infer_processors)
|
||||
learn_processors = check_transform_proc(learn_processors)
|
||||
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
|
||||
learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)
|
||||
|
||||
data_loader = {
|
||||
"class": "QlibDataLoader",
|
||||
|
||||
@@ -26,7 +26,7 @@ def get_calendar_day(freq="day", future=False):
|
||||
if flag in H["c"]:
|
||||
_calendar = H["c"][flag]
|
||||
else:
|
||||
_calendar = np.array(list(map(lambda x: x.date(), Cal.load_calendar(freq, future))))
|
||||
_calendar = np.array(list(map(lambda x: pd.Timestamp(x.date()), Cal.load_calendar(freq, future))))
|
||||
H["c"][flag] = _calendar
|
||||
return _calendar
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@ class HighfreqWorkflow:
|
||||
"fit_start_time": start_time,
|
||||
"fit_end_time": train_end_time,
|
||||
"instruments": MARKET,
|
||||
"infer_processors": [{"class": "HighFreqNorm", "module_path": "highfreq_processor", "kwargs": {}}],
|
||||
"infer_processors": [{"class": "HighFreqNorm", "module_path": "highfreq_processor"}],
|
||||
}
|
||||
DATA_HANDLER_CONFIG1 = {
|
||||
"start_time": start_time,
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
"""
|
||||
This example shows how a TrainerRM works based on TaskManager with rolling tasks.
|
||||
After training, how to collect the rolling results will be shown in task_collecting.
|
||||
Based on the ability of TaskManager, `worker` method offer a simple way for multiprocessing.
|
||||
"""
|
||||
|
||||
from pprint import pprint
|
||||
@@ -13,7 +14,7 @@ import qlib
|
||||
from qlib.config import REG_CN
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.task.gen import RollingGen, task_generator
|
||||
from qlib.workflow.task.manage import TaskManager
|
||||
from qlib.workflow.task.manage import TaskManager, run_task
|
||||
from qlib.workflow.task.collect import RecorderCollector
|
||||
from qlib.model.ens.group import RollingGroup
|
||||
from qlib.model.trainer import TrainerRM
|
||||
@@ -68,6 +69,11 @@ class RollingTaskExample:
|
||||
trainer = TrainerRM(self.experiment_name, self.task_pool)
|
||||
trainer.train(tasks)
|
||||
|
||||
def worker(self):
|
||||
# train tasks by other progress or machines for multiprocessing. It is same as TrainerRM.worker.
|
||||
print("========== worker ==========")
|
||||
run_task(task_train, self.task_pool, experiment_name=self.experiment_name)
|
||||
|
||||
def task_collecting(self):
|
||||
print("========== task_collecting ==========")
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
This example is about how can simulate the OnlineManager based on rolling tasks.
|
||||
"""
|
||||
|
||||
from pprint import pprint
|
||||
import fire
|
||||
import qlib
|
||||
from qlib.model.trainer import DelayTrainerR, DelayTrainerRM, TrainerR, TrainerRM
|
||||
@@ -13,7 +14,7 @@ from qlib.workflow.online.manager import OnlineManager
|
||||
from qlib.workflow.online.strategy import RollingStrategy
|
||||
from qlib.workflow.task.gen import RollingGen
|
||||
from qlib.workflow.task.manage import TaskManager
|
||||
from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG, CSI100_RECORD_XGBOOST_TASK_CONFIG
|
||||
from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG_ONLINE, CSI100_RECORD_XGBOOST_TASK_CONFIG_ONLINE
|
||||
|
||||
|
||||
class OnlineSimulationExample:
|
||||
@@ -22,8 +23,8 @@ class OnlineSimulationExample:
|
||||
provider_uri="~/.qlib/qlib_data/cn_data",
|
||||
region="cn",
|
||||
exp_name="rolling_exp",
|
||||
task_url="mongodb://10.0.0.4:27017/",
|
||||
task_db_name="rolling_db",
|
||||
task_url="mongodb://10.0.0.4:27017/", # not necessary when using TrainerR or DelayTrainerR
|
||||
task_db_name="rolling_db", # not necessary when using TrainerR or DelayTrainerR
|
||||
task_pool="rolling_task",
|
||||
rolling_step=80,
|
||||
start_time="2018-09-10",
|
||||
@@ -46,7 +47,7 @@ class OnlineSimulationExample:
|
||||
tasks (dict or list[dict]): a set of the task config waiting for rolling and training
|
||||
"""
|
||||
if tasks is None:
|
||||
tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG, CSI100_RECORD_LGB_TASK_CONFIG]
|
||||
tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG_ONLINE, CSI100_RECORD_LGB_TASK_CONFIG_ONLINE]
|
||||
self.exp_name = exp_name
|
||||
self.task_pool = task_pool
|
||||
self.start_time = start_time
|
||||
@@ -59,7 +60,7 @@ class OnlineSimulationExample:
|
||||
self.rolling_gen = RollingGen(
|
||||
step=rolling_step, rtype=RollingGen.ROLL_SD, ds_extra_mod_func=None
|
||||
) # The rolling tasks generator, ds_extra_mod_func is None because we just need to simulate to 2018-10-31 and needn't change the handler end time.
|
||||
self.trainer = DelayTrainerRM(self.exp_name, self.task_pool) # Also can be TrainerR, TrainerRM, DelayTrainerR
|
||||
self.trainer = TrainerRM(self.exp_name, self.task_pool) # Also can be TrainerR, TrainerRM, DelayTrainerR
|
||||
self.rolling_online_manager = OnlineManager(
|
||||
RollingStrategy(exp_name, task_template=tasks, rolling_gen=self.rolling_gen),
|
||||
trainer=self.trainer,
|
||||
@@ -85,6 +86,15 @@ class OnlineSimulationExample:
|
||||
print("========== signals ==========")
|
||||
print(self.rolling_online_manager.get_signals())
|
||||
|
||||
def worker(self):
|
||||
# train tasks by other progress or machines for multiprocessing
|
||||
# FIXME: only can call after finishing simulation when using DelayTrainerRM, or there will be some exception.
|
||||
print("========== worker ==========")
|
||||
if isinstance(self.trainer, TrainerRM):
|
||||
self.trainer.worker()
|
||||
else:
|
||||
print(f"{type(self.trainer)} is not supported for worker.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
## to run all workflow automatically with your own parameters, use the command below
|
||||
|
||||
@@ -13,11 +13,13 @@ Finally, the OnlineManager will finish second routine and update all strategies.
|
||||
import os
|
||||
import fire
|
||||
import qlib
|
||||
from qlib.model.trainer import DelayTrainerR, DelayTrainerRM, TrainerR, TrainerRM, end_task_train, task_train
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.online.strategy import RollingStrategy
|
||||
from qlib.workflow.task.gen import RollingGen
|
||||
from qlib.workflow.online.manager import OnlineManager
|
||||
from qlib.tests.config import CSI100_RECORD_XGBOOST_TASK_CONFIG, CSI100_RECORD_LGB_TASK_CONFIG
|
||||
from qlib.tests.config import CSI100_RECORD_XGBOOST_TASK_CONFIG_ROLLING, CSI100_RECORD_LGB_TASK_CONFIG_ROLLING
|
||||
from qlib.workflow.task.manage import TaskManager
|
||||
|
||||
|
||||
class RollingOnlineExample:
|
||||
@@ -25,16 +27,17 @@ class RollingOnlineExample:
|
||||
self,
|
||||
provider_uri="~/.qlib/qlib_data/cn_data",
|
||||
region="cn",
|
||||
task_url="mongodb://10.0.0.4:27017/",
|
||||
task_db_name="rolling_db",
|
||||
trainer=DelayTrainerRM(), # you can choose from TrainerR, TrainerRM, DelayTrainerR, DelayTrainerRM
|
||||
task_url="mongodb://10.0.0.4:27017/", # not necessary when using TrainerR or DelayTrainerR
|
||||
task_db_name="rolling_db", # not necessary when using TrainerR or DelayTrainerR
|
||||
rolling_step=550,
|
||||
tasks=None,
|
||||
add_tasks=None,
|
||||
):
|
||||
if add_tasks is None:
|
||||
add_tasks = [CSI100_RECORD_LGB_TASK_CONFIG]
|
||||
add_tasks = [CSI100_RECORD_LGB_TASK_CONFIG_ROLLING]
|
||||
if tasks is None:
|
||||
tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG]
|
||||
tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG_ROLLING]
|
||||
mongo_conf = {
|
||||
"task_url": task_url, # your MongoDB url
|
||||
"task_db_name": task_db_name, # database name
|
||||
@@ -53,17 +56,28 @@ class RollingOnlineExample:
|
||||
RollingGen(step=rolling_step, rtype=RollingGen.ROLL_SD),
|
||||
)
|
||||
)
|
||||
|
||||
self.rolling_online_manager = OnlineManager(strategies)
|
||||
self.trainer = trainer
|
||||
self.rolling_online_manager = OnlineManager(strategies, trainer=self.trainer)
|
||||
|
||||
_ROLLING_MANAGER_PATH = (
|
||||
".RollingOnlineExample" # the OnlineManager will dump to this file, for it can be loaded when calling routine.
|
||||
)
|
||||
|
||||
def worker(self):
|
||||
# train tasks by other progress or machines for multiprocessing
|
||||
print("========== worker ==========")
|
||||
if isinstance(self.trainer, TrainerRM):
|
||||
for task in self.tasks + self.add_tasks:
|
||||
name_id = task["model"]["class"]
|
||||
self.trainer.worker(experiment_name=name_id)
|
||||
else:
|
||||
print(f"{type(self.trainer)} is not supported for worker.")
|
||||
|
||||
# Reset all things to the first status, be careful to save important data
|
||||
def reset(self):
|
||||
for task in self.tasks + self.add_tasks:
|
||||
name_id = task["model"]["class"]
|
||||
TaskManager(task_pool=name_id).remove()
|
||||
exp = R.get_exp(experiment_name=name_id)
|
||||
for rid in exp.list_recorders():
|
||||
exp.delete_recorder(rid)
|
||||
|
||||
@@ -220,7 +220,7 @@
|
||||
"\n",
|
||||
"# backtest and analysis\n",
|
||||
"with R.start(experiment_name=\"backtest_analysis\"):\n",
|
||||
" recorder = R.get_recorder(rid, experiment_name=\"train_model\")\n",
|
||||
" recorder = R.get_recorder(recorder_id=rid, experiment_name=\"train_model\")\n",
|
||||
" model = recorder.load_object(\"trained_model\")\n",
|
||||
"\n",
|
||||
" # prediction\n",
|
||||
@@ -249,7 +249,7 @@
|
||||
"source": [
|
||||
"from qlib.contrib.report import analysis_model, analysis_position\n",
|
||||
"from qlib.data import D\n",
|
||||
"recorder = R.get_recorder(ba_rid, experiment_name=\"backtest_analysis\")\n",
|
||||
"recorder = R.get_recorder(recorder_id=ba_rid, experiment_name=\"backtest_analysis\")\n",
|
||||
"pred_df = recorder.load_object(\"pred.pkl\")\n",
|
||||
"pred_df_dates = pred_df.index.get_level_values(level='datetime')\n",
|
||||
"report_normal_df = recorder.load_object(\"portfolio_analysis/report_normal.pkl\")\n",
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
__version__ = "0.6.3.99"
|
||||
__version__ = "0.7.0"
|
||||
__version__bak = __version__ # This version is backup for QlibConfig.reset_qlib_version
|
||||
|
||||
|
||||
@@ -20,11 +20,17 @@ def init(default_conf="client", **kwargs):
|
||||
from .config import C
|
||||
from .data.cache import H
|
||||
|
||||
H.clear()
|
||||
|
||||
# FIXME: this logger ignored the level in config
|
||||
logger = get_module_logger("Initialization", level=logging.INFO)
|
||||
|
||||
skip_if_reg = kwargs.pop("skip_if_reg", False)
|
||||
if skip_if_reg and C.registered:
|
||||
# if we reinitialize Qlib during running an experiment `R.start`.
|
||||
# it will result in loss of the recorder
|
||||
logger.warning("Skip initialization because `skip_if_reg is True`")
|
||||
return
|
||||
|
||||
H.clear()
|
||||
C.set(default_conf, **kwargs)
|
||||
|
||||
# check path if server/local
|
||||
@@ -197,14 +203,15 @@ def auto_init(**kwargs):
|
||||
- Find the project configuration and init qlib
|
||||
- The parsing process will be affected by the `conf_type` of the configuration file
|
||||
- Init qlib with default config
|
||||
- Skip initialization if already initialized
|
||||
"""
|
||||
kwargs["skip_if_reg"] = kwargs.get("skip_if_reg", True)
|
||||
|
||||
try:
|
||||
pp = get_project_path(cur_path=kwargs.pop("cur_path", None))
|
||||
except FileNotFoundError:
|
||||
init(**kwargs)
|
||||
else:
|
||||
|
||||
conf_pp = pp / "config.yaml"
|
||||
with conf_pp.open() as f:
|
||||
conf = yaml.safe_load(f)
|
||||
|
||||
@@ -195,7 +195,10 @@ MODE_CONF = {
|
||||
"timeout": 100,
|
||||
"logging_level": logging.INFO,
|
||||
"region": REG_CN,
|
||||
## Custom Operator
|
||||
# custom operator
|
||||
# each element of custom_ops should be Type[ExpressionOps] or dict
|
||||
# if element of custom_ops is Type[ExpressionOps], it represents the custom operator class
|
||||
# if element of custom_ops is dict, it represents the config of custom operator and should include `class` and `module_path` keys.
|
||||
"custom_ops": [],
|
||||
},
|
||||
}
|
||||
|
||||
@@ -26,8 +26,10 @@ def check_transform_proc(proc_l, fit_start_time, fit_end_time):
|
||||
"fit_end_time": fit_end_time,
|
||||
}
|
||||
)
|
||||
# FIXME: the `module_path` parameter is missed.
|
||||
new_l.append({"class": klass.__name__, "kwargs": pkwargs})
|
||||
proc_config = {"class": klass.__name__, "kwargs": pkwargs}
|
||||
if isinstance(p, dict) and "module_path" in p:
|
||||
proc_config["module_path"] = p["module_path"]
|
||||
new_l.append(proc_config)
|
||||
else:
|
||||
new_l.append(p)
|
||||
return new_l
|
||||
|
||||
@@ -53,7 +53,6 @@ class GATs(Model):
|
||||
early_stop=20,
|
||||
loss="mse",
|
||||
base_model="GRU",
|
||||
with_pretrain=True,
|
||||
model_path=None,
|
||||
optimizer="adam",
|
||||
GPU=0,
|
||||
@@ -76,7 +75,6 @@ class GATs(Model):
|
||||
self.optimizer = optimizer.lower()
|
||||
self.loss = loss
|
||||
self.base_model = base_model
|
||||
self.with_pretrain = with_pretrain
|
||||
self.model_path = model_path
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.seed = seed
|
||||
@@ -94,7 +92,6 @@ class GATs(Model):
|
||||
"\noptimizer : {}"
|
||||
"\nloss_type : {}"
|
||||
"\nbase_model : {}"
|
||||
"\nwith_pretrain : {}"
|
||||
"\nmodel_path : {}"
|
||||
"\ndevice : {}"
|
||||
"\nuse_GPU : {}"
|
||||
@@ -110,7 +107,6 @@ class GATs(Model):
|
||||
optimizer.lower(),
|
||||
loss,
|
||||
base_model,
|
||||
with_pretrain,
|
||||
model_path,
|
||||
self.device,
|
||||
self.use_gpu,
|
||||
@@ -253,24 +249,22 @@ class GATs(Model):
|
||||
evals_result["valid"] = []
|
||||
|
||||
# load pretrained base_model
|
||||
if self.with_pretrain:
|
||||
if self.model_path == None:
|
||||
raise ValueError("the path of the pretrained model should be given first!")
|
||||
self.logger.info("Loading pretrained model...")
|
||||
if self.base_model == "LSTM":
|
||||
pretrained_model = LSTMModel()
|
||||
pretrained_model.load_state_dict(torch.load(self.model_path))
|
||||
elif self.base_model == "GRU":
|
||||
pretrained_model = GRUModel()
|
||||
pretrained_model.load_state_dict(torch.load(self.model_path))
|
||||
else:
|
||||
raise ValueError("unknown base model name `%s`" % self.base_model)
|
||||
if self.base_model == "LSTM":
|
||||
pretrained_model = LSTMModel()
|
||||
elif self.base_model == "GRU":
|
||||
pretrained_model = GRUModel()
|
||||
else:
|
||||
raise ValueError("unknown base model name `%s`" % self.base_model)
|
||||
|
||||
model_dict = self.GAT_model.state_dict()
|
||||
pretrained_dict = {k: v for k, v in pretrained_model.state_dict().items() if k in model_dict}
|
||||
model_dict.update(pretrained_dict)
|
||||
self.GAT_model.load_state_dict(model_dict)
|
||||
self.logger.info("Loading pretrained model Done...")
|
||||
if self.model_path is not None:
|
||||
self.logger.info("Loading pretrained model...")
|
||||
pretrained_model.load_state_dict(torch.load(self.model_path))
|
||||
|
||||
model_dict = self.GAT_model.state_dict()
|
||||
pretrained_dict = {k: v for k, v in pretrained_model.state_dict().items() if k in model_dict}
|
||||
model_dict.update(pretrained_dict)
|
||||
self.GAT_model.load_state_dict(model_dict)
|
||||
self.logger.info("Loading pretrained model Done...")
|
||||
|
||||
# train
|
||||
self.logger.info("training...")
|
||||
|
||||
@@ -29,8 +29,8 @@ class DailyBatchSampler(Sampler):
|
||||
def __init__(self, data_source):
|
||||
|
||||
self.data_source = data_source
|
||||
self.data = self.data_source.data.loc[self.data_source.get_index()]
|
||||
self.daily_count = self.data.groupby(level=0).size().values # calculate number of samples in each batch
|
||||
# calculate number of samples in each batch
|
||||
self.daily_count = pd.Series(index=self.data_source.get_index()).groupby("datetime").size().values
|
||||
self.daily_index = np.roll(np.cumsum(self.daily_count), 1) # calculate begin index of each batch
|
||||
self.daily_index[0] = 0
|
||||
|
||||
@@ -72,7 +72,6 @@ class GATs(Model):
|
||||
early_stop=20,
|
||||
loss="mse",
|
||||
base_model="GRU",
|
||||
with_pretrain=True,
|
||||
model_path=None,
|
||||
optimizer="adam",
|
||||
GPU="0",
|
||||
@@ -96,7 +95,6 @@ class GATs(Model):
|
||||
self.optimizer = optimizer.lower()
|
||||
self.loss = loss
|
||||
self.base_model = base_model
|
||||
self.with_pretrain = with_pretrain
|
||||
self.model_path = model_path
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.n_jobs = n_jobs
|
||||
@@ -115,7 +113,6 @@ class GATs(Model):
|
||||
"\noptimizer : {}"
|
||||
"\nloss_type : {}"
|
||||
"\nbase_model : {}"
|
||||
"\nwith_pretrain : {}"
|
||||
"\nmodel_path : {}"
|
||||
"\nvisible_GPU : {}"
|
||||
"\nuse_GPU : {}"
|
||||
@@ -131,7 +128,6 @@ class GATs(Model):
|
||||
optimizer.lower(),
|
||||
loss,
|
||||
base_model,
|
||||
with_pretrain,
|
||||
model_path,
|
||||
GPU,
|
||||
self.use_gpu,
|
||||
@@ -270,28 +266,22 @@ class GATs(Model):
|
||||
evals_result["valid"] = []
|
||||
|
||||
# load pretrained base_model
|
||||
if self.with_pretrain:
|
||||
if self.model_path == None:
|
||||
raise ValueError("the path of the pretrained model should be given first!")
|
||||
self.logger.info("Loading pretrained model...")
|
||||
if self.base_model == "LSTM":
|
||||
pretrained_model = LSTMModel(
|
||||
d_feat=self.d_feat, hidden_size=self.hidden_size, num_layers=self.num_layers
|
||||
)
|
||||
pretrained_model.load_state_dict(torch.load(self.model_path))
|
||||
elif self.base_model == "GRU":
|
||||
pretrained_model = GRUModel(
|
||||
d_feat=self.d_feat, hidden_size=self.hidden_size, num_layers=self.num_layers
|
||||
)
|
||||
pretrained_model.load_state_dict(torch.load(self.model_path))
|
||||
else:
|
||||
raise ValueError("unknown base model name `%s`" % self.base_model)
|
||||
if self.base_model == "LSTM":
|
||||
pretrained_model = LSTMModel(d_feat=self.d_feat, hidden_size=self.hidden_size, num_layers=self.num_layers)
|
||||
elif self.base_model == "GRU":
|
||||
pretrained_model = GRUModel(d_feat=self.d_feat, hidden_size=self.hidden_size, num_layers=self.num_layers)
|
||||
else:
|
||||
raise ValueError("unknown base model name `%s`" % self.base_model)
|
||||
|
||||
model_dict = self.GAT_model.state_dict()
|
||||
pretrained_dict = {k: v for k, v in pretrained_model.state_dict().items() if k in model_dict}
|
||||
model_dict.update(pretrained_dict)
|
||||
self.GAT_model.load_state_dict(model_dict)
|
||||
self.logger.info("Loading pretrained model Done...")
|
||||
if self.model_path is not None:
|
||||
self.logger.info("Loading pretrained model...")
|
||||
pretrained_model.load_state_dict(torch.load(self.model_path))
|
||||
|
||||
model_dict = self.GAT_model.state_dict()
|
||||
pretrained_dict = {k: v for k, v in pretrained_model.state_dict().items() if k in model_dict}
|
||||
model_dict.update(pretrained_dict)
|
||||
self.GAT_model.load_state_dict(model_dict)
|
||||
self.logger.info("Loading pretrained model Done...")
|
||||
|
||||
# train
|
||||
self.logger.info("training...")
|
||||
|
||||
@@ -297,7 +297,7 @@ class DNNModelPytorch(Model):
|
||||
_model_path = os.path.join(model_dir, _model_name)
|
||||
# Load model
|
||||
self.dnn_model.load_state_dict(torch.load(_model_path))
|
||||
self._fitted = True
|
||||
self.fitted = True
|
||||
|
||||
|
||||
class AverageMeter:
|
||||
|
||||
420
qlib/contrib/model/pytorch_tcts.py
Normal file
420
qlib/contrib/model/pytorch_tcts.py
Normal file
@@ -0,0 +1,420 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import copy
|
||||
import random
|
||||
from sklearn.metrics import roc_auc_score, mean_squared_error
|
||||
import logging
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
get_or_create_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...log import get_module_logger, TimeInspector
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
|
||||
|
||||
class TCTS(Model):
|
||||
"""TCTS Model
|
||||
|
||||
Parameters
|
||||
----------
|
||||
d_feat : int
|
||||
input dimension for each time step
|
||||
metric: str
|
||||
the evaluate metric used in early stop
|
||||
optimizer : str
|
||||
optimizer name
|
||||
GPU : str
|
||||
the GPU ID(s) used for training
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
d_feat=6,
|
||||
hidden_size=64,
|
||||
num_layers=2,
|
||||
dropout=0.0,
|
||||
n_epochs=200,
|
||||
batch_size=2000,
|
||||
early_stop=20,
|
||||
loss="mse",
|
||||
fore_optimizer="adam",
|
||||
weight_optimizer="adam",
|
||||
output_dim=5,
|
||||
fore_lr=5e-7,
|
||||
weight_lr=5e-7,
|
||||
steps=3,
|
||||
GPU=0,
|
||||
seed=0,
|
||||
target_label=0,
|
||||
lowest_valid_performance=0.993,
|
||||
**kwargs
|
||||
):
|
||||
# Set logger.
|
||||
self.logger = get_module_logger("TCTS")
|
||||
self.logger.info("TCTS pytorch version...")
|
||||
|
||||
# set hyper-parameters.
|
||||
self.d_feat = d_feat
|
||||
self.hidden_size = hidden_size
|
||||
self.num_layers = num_layers
|
||||
self.dropout = dropout
|
||||
self.n_epochs = n_epochs
|
||||
self.batch_size = batch_size
|
||||
self.early_stop = early_stop
|
||||
self.loss = loss
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() else "cpu")
|
||||
self.use_gpu = torch.cuda.is_available()
|
||||
self.seed = seed
|
||||
self.output_dim = output_dim
|
||||
self.fore_lr = fore_lr
|
||||
self.weight_lr = weight_lr
|
||||
self.steps = steps
|
||||
self.target_label = target_label
|
||||
self.lowest_valid_performance = lowest_valid_performance
|
||||
self._fore_optimizer = fore_optimizer
|
||||
self._weight_optimizer = weight_optimizer
|
||||
|
||||
self.logger.info(
|
||||
"TCTS parameters setting:"
|
||||
"\nd_feat : {}"
|
||||
"\nhidden_size : {}"
|
||||
"\nnum_layers : {}"
|
||||
"\ndropout : {}"
|
||||
"\nn_epochs : {}"
|
||||
"\nbatch_size : {}"
|
||||
"\nearly_stop : {}"
|
||||
"\nloss_type : {}"
|
||||
"\nvisible_GPU : {}"
|
||||
"\nuse_GPU : {}"
|
||||
"\nseed : {}".format(
|
||||
d_feat,
|
||||
hidden_size,
|
||||
num_layers,
|
||||
dropout,
|
||||
n_epochs,
|
||||
batch_size,
|
||||
early_stop,
|
||||
loss,
|
||||
GPU,
|
||||
self.use_gpu,
|
||||
seed,
|
||||
)
|
||||
)
|
||||
|
||||
def loss_fn(self, pred, label, weight):
|
||||
|
||||
loc = torch.argmax(weight, 1)
|
||||
loss = (pred - label[np.arange(weight.shape[0]), loc]) ** 2
|
||||
return torch.mean(loss)
|
||||
|
||||
def train_epoch(self, x_train, y_train, x_valid, y_valid):
|
||||
|
||||
x_train_values = x_train.values
|
||||
y_train_values = np.squeeze(y_train.values)
|
||||
|
||||
indices = np.arange(len(x_train_values))
|
||||
np.random.shuffle(indices)
|
||||
|
||||
init_fore_model = copy.deepcopy(self.fore_model)
|
||||
for p in init_fore_model.parameters():
|
||||
p.init_fore_model = False
|
||||
|
||||
self.fore_model.train()
|
||||
self.weight_model.train()
|
||||
|
||||
for p in self.weight_model.parameters():
|
||||
p.requires_grad = False
|
||||
for p in self.fore_model.parameters():
|
||||
p.requires_grad = True
|
||||
|
||||
for i in range(self.steps):
|
||||
for i in range(len(indices))[:: self.batch_size]:
|
||||
|
||||
if len(indices) - i < self.batch_size:
|
||||
break
|
||||
|
||||
feature = torch.from_numpy(x_train_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
label = torch.from_numpy(y_train_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
|
||||
init_pred = init_fore_model(feature)
|
||||
pred = self.fore_model(feature)
|
||||
|
||||
dis = init_pred - label.transpose(0, 1)
|
||||
weight_feature = torch.cat((feature, dis.transpose(0, 1), label, init_pred.view(-1, 1)), 1)
|
||||
weight = self.weight_model(weight_feature)
|
||||
|
||||
loss = self.loss_fn(pred, label, weight) # hard
|
||||
|
||||
self.fore_optimizer.zero_grad()
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_value_(self.fore_model.parameters(), 3.0)
|
||||
self.fore_optimizer.step()
|
||||
|
||||
x_valid_values = x_valid.values
|
||||
y_valid_values = np.squeeze(y_valid.values)
|
||||
|
||||
indices = np.arange(len(x_valid_values))
|
||||
np.random.shuffle(indices)
|
||||
for p in self.weight_model.parameters():
|
||||
p.requires_grad = True
|
||||
for p in self.fore_model.parameters():
|
||||
p.requires_grad = False
|
||||
|
||||
# fix forecasting model and valid weight model
|
||||
for i in range(len(indices))[:: self.batch_size]:
|
||||
|
||||
if len(indices) - i < self.batch_size:
|
||||
break
|
||||
|
||||
feature = torch.from_numpy(x_valid_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
label = torch.from_numpy(y_valid_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
|
||||
pred = self.fore_model(feature)
|
||||
dis = pred - label.transpose(0, 1)
|
||||
weight_feature = torch.cat((feature, dis.transpose(0, 1), label, pred.view(-1, 1)), 1)
|
||||
weight = self.weight_model(weight_feature)
|
||||
loc = torch.argmax(weight, 1)
|
||||
valid_loss = torch.mean((pred - label[:, 0]) ** 2)
|
||||
loss = torch.mean(-valid_loss * torch.log(weight[np.arange(weight.shape[0]), loc]))
|
||||
|
||||
self.weight_optimizer.zero_grad()
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_value_(self.weight_model.parameters(), 3.0)
|
||||
self.weight_optimizer.step()
|
||||
|
||||
def test_epoch(self, data_x, data_y):
|
||||
|
||||
# prepare training data
|
||||
x_values = data_x.values
|
||||
y_values = np.squeeze(data_y.values)
|
||||
|
||||
self.fore_model.eval()
|
||||
|
||||
scores = []
|
||||
losses = []
|
||||
|
||||
indices = np.arange(len(x_values))
|
||||
|
||||
for i in range(len(indices))[:: self.batch_size]:
|
||||
|
||||
if len(indices) - i < self.batch_size:
|
||||
break
|
||||
|
||||
feature = torch.from_numpy(x_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
label = torch.from_numpy(y_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
|
||||
pred = self.fore_model(feature)
|
||||
loss = torch.mean((pred - label[:, abs(self.target_label)]) ** 2)
|
||||
losses.append(loss.item())
|
||||
|
||||
return np.mean(losses)
|
||||
|
||||
def fit(
|
||||
self,
|
||||
dataset: DatasetH,
|
||||
verbose=True,
|
||||
save_path=None,
|
||||
):
|
||||
df_train, df_valid, df_test = dataset.prepare(
|
||||
["train", "valid", "test"],
|
||||
col_set=["feature", "label"],
|
||||
data_key=DataHandlerLP.DK_L,
|
||||
)
|
||||
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
x_test, y_test = df_test["feature"], df_test["label"]
|
||||
|
||||
if save_path == None:
|
||||
save_path = get_or_create_path(save_path)
|
||||
best_loss = np.inf
|
||||
while best_loss > self.lowest_valid_performance:
|
||||
if best_loss < np.inf:
|
||||
print("Failed! Start retraining.")
|
||||
self.seed = random.randint(0, 1000) # reset random seed
|
||||
|
||||
if self.seed is not None:
|
||||
np.random.seed(self.seed)
|
||||
torch.manual_seed(self.seed)
|
||||
|
||||
best_loss = self.training(
|
||||
x_train, y_train, x_valid, y_valid, x_test, y_test, verbose=verbose, save_path=save_path
|
||||
)
|
||||
|
||||
def training(
|
||||
self,
|
||||
x_train,
|
||||
y_train,
|
||||
x_valid,
|
||||
y_valid,
|
||||
x_test,
|
||||
y_test,
|
||||
verbose=True,
|
||||
save_path=None,
|
||||
):
|
||||
|
||||
self.fore_model = GRUModel(
|
||||
d_feat=self.d_feat,
|
||||
hidden_size=self.hidden_size,
|
||||
num_layers=self.num_layers,
|
||||
dropout=self.dropout,
|
||||
)
|
||||
self.weight_model = MLPModel(
|
||||
d_feat=360 + 2 * self.output_dim + 1,
|
||||
hidden_size=self.hidden_size,
|
||||
num_layers=self.num_layers,
|
||||
dropout=self.dropout,
|
||||
output_dim=self.output_dim,
|
||||
)
|
||||
if self._fore_optimizer.lower() == "adam":
|
||||
self.fore_optimizer = optim.Adam(self.fore_model.parameters(), lr=self.fore_lr)
|
||||
elif self._fore_optimizer.lower() == "gd":
|
||||
self.fore_optimizer = optim.SGD(self.fore_model.parameters(), lr=self.fore_lr)
|
||||
else:
|
||||
raise NotImplementedError("optimizer {} is not supported!".format(self._fore_optimizer))
|
||||
if self._weight_optimizer.lower() == "adam":
|
||||
self.weight_optimizer = optim.Adam(self.weight_model.parameters(), lr=self.weight_lr)
|
||||
elif self._weight_optimizer.lower() == "gd":
|
||||
self.weight_optimizer = optim.SGD(self.weight_model.parameters(), lr=self.weight_lr)
|
||||
else:
|
||||
raise NotImplementedError("optimizer {} is not supported!".format(self._weight_optimizer))
|
||||
|
||||
self.fitted = False
|
||||
self.fore_model.to(self.device)
|
||||
self.weight_model.to(self.device)
|
||||
|
||||
best_loss = np.inf
|
||||
best_epoch = 0
|
||||
stop_round = 0
|
||||
fore_best_param = copy.deepcopy(self.fore_optimizer.state_dict())
|
||||
weight_best_param = copy.deepcopy(self.weight_optimizer.state_dict())
|
||||
|
||||
for epoch in range(self.n_epochs):
|
||||
print("Epoch:", epoch)
|
||||
|
||||
print("training...")
|
||||
self.train_epoch(x_train, y_train, x_valid, y_valid)
|
||||
print("evaluating...")
|
||||
val_loss = self.test_epoch(x_valid, y_valid)
|
||||
test_loss = self.test_epoch(x_test, y_test)
|
||||
|
||||
if verbose:
|
||||
print("valid %.6f, test %.6f" % (val_loss, test_loss))
|
||||
|
||||
if val_loss < best_loss:
|
||||
best_loss = val_loss
|
||||
stop_round = 0
|
||||
best_epoch = epoch
|
||||
torch.save(copy.deepcopy(self.fore_model.state_dict()), save_path + "_fore_model.bin")
|
||||
torch.save(copy.deepcopy(self.weight_model.state_dict()), save_path + "_weight_model.bin")
|
||||
|
||||
else:
|
||||
stop_round += 1
|
||||
if stop_round >= self.early_stop:
|
||||
print("early stop")
|
||||
break
|
||||
|
||||
print("best loss:", best_loss, "@", best_epoch)
|
||||
best_param = torch.load(save_path + "_fore_model.bin")
|
||||
self.fore_model.load_state_dict(best_param)
|
||||
best_param = torch.load(save_path + "_weight_model.bin")
|
||||
self.weight_model.load_state_dict(best_param)
|
||||
self.fitted = True
|
||||
|
||||
if self.use_gpu:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return best_loss
|
||||
|
||||
def predict(self, dataset):
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
x_test = dataset.prepare("test", col_set="feature")
|
||||
index = x_test.index
|
||||
self.fore_model.eval()
|
||||
x_values = x_test.values
|
||||
sample_num = x_values.shape[0]
|
||||
preds = []
|
||||
|
||||
for begin in range(sample_num)[:: self.batch_size]:
|
||||
|
||||
if sample_num - begin < self.batch_size:
|
||||
end = sample_num
|
||||
else:
|
||||
end = begin + self.batch_size
|
||||
|
||||
x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
if self.use_gpu:
|
||||
pred = self.fore_model(x_batch).detach().cpu().numpy()
|
||||
else:
|
||||
pred = self.fore_model(x_batch).detach().numpy()
|
||||
|
||||
preds.append(pred)
|
||||
|
||||
return pd.Series(np.concatenate(preds), index=index)
|
||||
|
||||
|
||||
class MLPModel(nn.Module):
|
||||
def __init__(self, d_feat, hidden_size=256, num_layers=3, dropout=0.0, output_dim=1):
|
||||
super().__init__()
|
||||
|
||||
self.mlp = nn.Sequential()
|
||||
self.softmax = nn.Softmax(dim=1)
|
||||
|
||||
for i in range(num_layers):
|
||||
if i > 0:
|
||||
self.mlp.add_module("drop_%d" % i, nn.Dropout(dropout))
|
||||
self.mlp.add_module("fc_%d" % i, nn.Linear(d_feat if i == 0 else hidden_size, hidden_size))
|
||||
self.mlp.add_module("relu_%d" % i, nn.ReLU())
|
||||
|
||||
self.mlp.add_module("fc_out", nn.Linear(hidden_size, output_dim))
|
||||
|
||||
def forward(self, x):
|
||||
# feature
|
||||
# [N, F]
|
||||
out = self.mlp(x).squeeze()
|
||||
out = self.softmax(out)
|
||||
return out
|
||||
|
||||
|
||||
class GRUModel(nn.Module):
|
||||
def __init__(self, d_feat=6, hidden_size=64, num_layers=2, dropout=0.0):
|
||||
super().__init__()
|
||||
|
||||
self.rnn = nn.GRU(
|
||||
input_size=d_feat,
|
||||
hidden_size=hidden_size,
|
||||
num_layers=num_layers,
|
||||
batch_first=True,
|
||||
dropout=dropout,
|
||||
)
|
||||
self.fc_out = nn.Linear(hidden_size, 1)
|
||||
|
||||
self.d_feat = d_feat
|
||||
|
||||
def forward(self, x):
|
||||
# x: [N, F*T]
|
||||
x = x.reshape(len(x), self.d_feat, -1) # [N, F, T]
|
||||
x = x.permute(0, 2, 1) # [N, T, F]
|
||||
out, _ = self.rnn(x)
|
||||
return self.fc_out(out[:, -1, :]).squeeze()
|
||||
@@ -62,7 +62,7 @@ class XGBModel(Model, FeatureInt):
|
||||
if self.model is None:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
return pd.Series(self.model.predict(xgb.DMatrix(x_test.values)), index=x_test.index)
|
||||
return pd.Series(self.model.predict(xgb.DMatrix(x_test)), index=x_test.index)
|
||||
|
||||
def get_feature_importance(self, *args, **kwargs) -> pd.Series:
|
||||
"""get feature importance
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
|
||||
import pandas as pd
|
||||
|
||||
import plotly.tools as tls
|
||||
import plotly.graph_objs as go
|
||||
|
||||
import statsmodels.api as sm
|
||||
@@ -80,9 +79,35 @@ def _plot_qq(data: pd.Series = None, dist=stats.norm) -> go.Figure:
|
||||
:param dist:
|
||||
:return:
|
||||
"""
|
||||
fig, ax = plt.subplots(figsize=(8, 5))
|
||||
_mpl_fig = sm.qqplot(data.dropna(), dist, fit=True, line="45", ax=ax)
|
||||
return tls.mpl_to_plotly(_mpl_fig)
|
||||
# NOTE: plotly.tools.mpl_to_plotly not actively maintained, resulting in errors in the new version of matplotlib,
|
||||
# ref: https://github.com/plotly/plotly.py/issues/2913#issuecomment-730071567
|
||||
# removing plotly.tools.mpl_to_plotly for greater compatibility with matplotlib versions
|
||||
_plt_fig = sm.qqplot(data.dropna(), dist=dist, fit=True, line="45")
|
||||
plt.close(_plt_fig)
|
||||
qqplot_data = _plt_fig.gca().lines
|
||||
fig = go.Figure()
|
||||
|
||||
fig.add_trace(
|
||||
{
|
||||
"type": "scatter",
|
||||
"x": qqplot_data[0].get_xdata(),
|
||||
"y": qqplot_data[0].get_ydata(),
|
||||
"mode": "markers",
|
||||
"marker": {"color": "#19d3f3"},
|
||||
}
|
||||
)
|
||||
|
||||
fig.add_trace(
|
||||
{
|
||||
"type": "scatter",
|
||||
"x": qqplot_data[1].get_xdata(),
|
||||
"y": qqplot_data[1].get_ydata(),
|
||||
"mode": "lines",
|
||||
"line": {"color": "#636efa"},
|
||||
}
|
||||
)
|
||||
del qqplot_data
|
||||
return fig
|
||||
|
||||
|
||||
def _pred_ic(pred_label: pd.DataFrame = None, rank: bool = False, **kwargs) -> tuple:
|
||||
|
||||
@@ -148,7 +148,6 @@ class WeightStrategyBase(BaseStrategy, AdjustTimer):
|
||||
pred score for this trade date, index is stock_id, contain 'score' column.
|
||||
current : Position()
|
||||
current position.
|
||||
trade_exchange : Exchange()
|
||||
trade_date : pd.Timestamp
|
||||
trade date.
|
||||
"""
|
||||
|
||||
@@ -237,7 +237,7 @@ class CacheUtils:
|
||||
lock.acquire()
|
||||
except redis_lock.AlreadyAcquired:
|
||||
raise QlibCacheException(
|
||||
f"""It sees the key(lock:{repr(lock_name)[1:-1]}-wlock) of the redis lock has existed in your redis db now.
|
||||
f"""It sees the key(lock:{repr(lock_name)[1:-1]}-wlock) of the redis lock has existed in your redis db now.
|
||||
You can use the following command to clear your redis keys and rerun your commands:
|
||||
$ redis-cli
|
||||
> select {C.redis_task_db}
|
||||
@@ -784,10 +784,10 @@ class DiskDatasetCache(DatasetCache):
|
||||
def build_index_from_data(data, start_index=0):
|
||||
if data.empty:
|
||||
return pd.DataFrame()
|
||||
line_data = data.iloc[:, 0].fillna(0).groupby("datetime").count()
|
||||
line_data = data.groupby("datetime").size()
|
||||
line_data.sort_index(inplace=True)
|
||||
index_end = line_data.cumsum()
|
||||
index_start = index_end.shift(1).fillna(0)
|
||||
index_start = index_end.shift(1, fill_value=0)
|
||||
|
||||
index_data = pd.DataFrame()
|
||||
index_data["start"] = index_start
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from ...utils.serial import Serializable
|
||||
from typing import Union, List, Tuple, Dict, Text, Optional
|
||||
from ...utils import init_instance_by_config, np_ffill
|
||||
from ...utils import init_instance_by_config, np_ffill, time_to_slc_point
|
||||
from ...log import get_module_logger
|
||||
from .handler import DataHandler, DataHandlerLP
|
||||
from copy import deepcopy
|
||||
@@ -243,6 +243,8 @@ class TSDataSampler:
|
||||
|
||||
It works like `torch.data.utils.Dataset`, it provides a very convenient interface for constructing time-series
|
||||
dataset based on tabular data.
|
||||
- On time step dimension, the smaller index indicates the historical data and the larger index indicates the future
|
||||
data.
|
||||
|
||||
If user have further requirements for processing data, user could process them based on `TSDataSampler` or create
|
||||
more powerful subclasses.
|
||||
@@ -309,11 +311,19 @@ class TSDataSampler:
|
||||
self.data_index = deepcopy(self.data.index)
|
||||
|
||||
if flt_data is not None:
|
||||
self.flt_data = np.array(flt_data.reindex(self.data_index)).reshape(-1)
|
||||
if isinstance(flt_data, pd.DataFrame):
|
||||
assert len(flt_data.columns) == 1
|
||||
flt_data = flt_data.iloc[:, 0]
|
||||
# NOTE: bool(np.nan) is True !!!!!!!!
|
||||
# make sure reindex comes first. Otherwise extra NaN may appear.
|
||||
flt_data = flt_data.reindex(self.data_index).fillna(False).astype(np.bool)
|
||||
self.flt_data = flt_data.values
|
||||
self.idx_map = self.flt_idx_map(self.flt_data, self.idx_map)
|
||||
self.data_index = self.data_index[np.where(self.flt_data == True)[0]]
|
||||
|
||||
self.start_idx, self.end_idx = self.data_index.slice_locs(start=pd.Timestamp(start), end=pd.Timestamp(end))
|
||||
self.start_idx, self.end_idx = self.data_index.slice_locs(
|
||||
start=time_to_slc_point(start), end=time_to_slc_point(end)
|
||||
)
|
||||
self.idx_arr = np.array(self.idx_df.values, dtype=np.float64) # for better performance
|
||||
|
||||
del self.data # save memory
|
||||
@@ -341,7 +351,7 @@ class TSDataSampler:
|
||||
setattr(self, k, v)
|
||||
|
||||
@staticmethod
|
||||
def build_index(data: pd.DataFrame) -> dict:
|
||||
def build_index(data: pd.DataFrame) -> Tuple[pd.DataFrame, dict]:
|
||||
"""
|
||||
The relation of the data
|
||||
|
||||
@@ -352,9 +362,15 @@ class TSDataSampler:
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict:
|
||||
{<index>: <prev_index or None>}
|
||||
# get the previous index of a line given index
|
||||
Tuple[pd.DataFrame, dict]:
|
||||
1) the first element: reshape the original index into a <datetime(row), instrument(column)> 2D dataframe
|
||||
instrument SH600000 SH600004 SH600006 SH600007 SH600008 SH600009 ...
|
||||
datetime
|
||||
2021-01-11 0 1 2 3 4 5 ...
|
||||
2021-01-12 4146 4147 4148 4149 4150 4151 ...
|
||||
2021-01-13 8293 8294 8295 8296 8297 8298 ...
|
||||
2021-01-14 12441 12442 12443 12444 12445 12446 ...
|
||||
2) the second element: {<original index>: <row, col>}
|
||||
"""
|
||||
# object incase of pandas converting int to flaot
|
||||
idx_df = pd.Series(range(data.shape[0]), index=data.index, dtype=object)
|
||||
@@ -491,7 +507,9 @@ class TSDatasetH(DatasetH):
|
||||
- The dimension of a batch of data <batch_idx, feature, timestep>
|
||||
"""
|
||||
|
||||
def __init__(self, step_len=30, **kwargs):
|
||||
DEFAULT_STEP_LEN = 30
|
||||
|
||||
def __init__(self, step_len=DEFAULT_STEP_LEN, **kwargs):
|
||||
self.step_len = step_len
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ from typing import Tuple, Union
|
||||
from qlib.data import D
|
||||
from qlib.data import filter as filter_module
|
||||
from qlib.data.filter import BaseDFilter
|
||||
from qlib.utils import load_dataset, init_instance_by_config
|
||||
from qlib.utils import load_dataset, init_instance_by_config, time_to_slc_point
|
||||
from qlib.log import get_module_logger
|
||||
|
||||
|
||||
@@ -207,7 +207,10 @@ class StaticDataLoader(DataLoader):
|
||||
df = self._data.loc(axis=0)[:, instruments]
|
||||
if start_time is None and end_time is None:
|
||||
return df # NOTE: avoid copy by loc
|
||||
return df.loc[pd.Timestamp(start_time) : pd.Timestamp(end_time)]
|
||||
# pd.Timestamp(None) == NaT, use NaT as index can not fetch correct thing, so do not change None.
|
||||
start_time = time_to_slc_point(start_time)
|
||||
end_time = time_to_slc_point(end_time)
|
||||
return df.loc[start_time:end_time]
|
||||
|
||||
def _maybe_load_raw_data(self):
|
||||
if self._data is not None:
|
||||
|
||||
@@ -10,10 +10,12 @@ import abc
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from typing import Union, List, Type
|
||||
from scipy.stats import percentileofscore
|
||||
|
||||
from .base import Expression, ExpressionOps
|
||||
from ..log import get_module_logger
|
||||
from ..utils import get_cls_kwargs
|
||||
|
||||
try:
|
||||
from ._libs.rolling import rolling_slope, rolling_rsquare, rolling_resi
|
||||
@@ -1495,16 +1497,34 @@ class OpsWrapper:
|
||||
def reset(self):
|
||||
self._ops = {}
|
||||
|
||||
def register(self, ops_list):
|
||||
for operator in ops_list:
|
||||
if not issubclass(operator, ExpressionOps):
|
||||
raise TypeError("operator must be subclass of ExpressionOps, not {}".format(operator))
|
||||
def register(self, ops_list: List[Union[Type[ExpressionOps], dict]]):
|
||||
"""register operator
|
||||
|
||||
if operator.__name__ in self._ops:
|
||||
Parameters
|
||||
----------
|
||||
ops_list : List[Union[Type[ExpressionOps], dict]]
|
||||
- if type(ops_list) is List[Type[ExpressionOps]], each element of ops_list represents the operator class, which should be the subclass of `ExpressionOps`.
|
||||
- if type(ops_list) is List[dict], each element of ops_list represents the config of operator, which has the following format:
|
||||
{
|
||||
"class": class_name,
|
||||
"module_path": path,
|
||||
}
|
||||
Note: `class` should be the class name of operator, `module_path` should be a python module or path of file.
|
||||
"""
|
||||
for _operator in ops_list:
|
||||
if isinstance(_operator, dict):
|
||||
_ops_class, _ = get_cls_kwargs(_operator)
|
||||
else:
|
||||
_ops_class = _operator
|
||||
|
||||
if not issubclass(_ops_class, ExpressionOps):
|
||||
raise TypeError("operator must be subclass of ExpressionOps, not {}".format(_ops_class))
|
||||
|
||||
if _ops_class.__name__ in self._ops:
|
||||
get_module_logger(self.__class__.__name__).warning(
|
||||
"The custom operator [{}] will override the qlib default definition".format(operator.__name__)
|
||||
"The custom operator [{}] will override the qlib default definition".format(_ops_class.__name__)
|
||||
)
|
||||
self._ops[operator.__name__] = operator
|
||||
self._ops[_ops_class.__name__] = _ops_class
|
||||
|
||||
def __getattr__(self, key):
|
||||
if key not in self._ops:
|
||||
|
||||
10
qlib/log.py
10
qlib/log.py
@@ -28,16 +28,18 @@ class QlibLogger(metaclass=MetaLogger):
|
||||
|
||||
def __init__(self, module_name):
|
||||
self.module_name = module_name
|
||||
self.level = 0
|
||||
# this feature name conflicts with the attribute with Logger
|
||||
# rename it to avoid some corner cases that result in comparing `str` and `int`
|
||||
self.__level = 0
|
||||
|
||||
@property
|
||||
def logger(self):
|
||||
logger = logging.getLogger(self.module_name)
|
||||
logger.setLevel(self.level)
|
||||
logger.setLevel(self.__level)
|
||||
return logger
|
||||
|
||||
def setLevel(self, level):
|
||||
self.level = level
|
||||
self.__level = level
|
||||
|
||||
def __getattr__(self, name):
|
||||
# During unpickling, python will call __getattr__. Use this line to avoid maximum recursion error.
|
||||
@@ -68,7 +70,7 @@ def get_module_logger(module_name, level: Optional[int] = None) -> logging.Logge
|
||||
|
||||
class TimeInspector:
|
||||
|
||||
timer_logger = get_module_logger("timer", level=logging.WARNING)
|
||||
timer_logger = get_module_logger("timer", level=logging.INFO)
|
||||
|
||||
time_marks = []
|
||||
|
||||
|
||||
@@ -97,7 +97,7 @@ class ModelFT(Model):
|
||||
|
||||
# Finetune model based on previous trained model
|
||||
with R.start(experiment_name="finetune model"):
|
||||
recorder = R.get_recorder(rid, experiment_name="init models")
|
||||
recorder = R.get_recorder(recorder_id=rid, experiment_name="init models")
|
||||
model = recorder.load_object("init_model")
|
||||
model.finetune(dataset, num_boost_round=10)
|
||||
|
||||
|
||||
@@ -8,13 +8,15 @@ There are two steps in each Trainer including ``train``(make model recorder) and
|
||||
This is a concept called ``DelayTrainer``, which can be used in online simulating for parallel training.
|
||||
In ``DelayTrainer``, the first step is only to save some necessary info to model recorders, and the second step which will be finished in the end can do some concurrent and time-consuming operations such as model fitting.
|
||||
|
||||
``Qlib`` offer two kinds of Trainer, ``TrainerR`` is the simplest way and ``TrainerRM`` is based on TaskManager to help manager tasks lifecycle automatically.
|
||||
``Qlib`` offer two kinds of Trainer, ``TrainerR`` is the simplest way and ``TrainerRM`` is based on TaskManager to help manager tasks lifecycle automatically.
|
||||
"""
|
||||
|
||||
import socket
|
||||
import time
|
||||
from typing import Callable, List
|
||||
|
||||
from qlib.data.dataset import Dataset
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.model.base import Model
|
||||
from qlib.utils import flatten_dict, get_cls_kwargs, init_instance_by_config
|
||||
from qlib.workflow import R
|
||||
@@ -151,6 +153,9 @@ class Trainer:
|
||||
"""
|
||||
return self.delay
|
||||
|
||||
def __call__(self, *args, **kwargs) -> list:
|
||||
return self.end_train(self.train(*args, **kwargs))
|
||||
|
||||
|
||||
class TrainerR(Trainer):
|
||||
"""
|
||||
@@ -190,6 +195,8 @@ class TrainerR(Trainer):
|
||||
Returns:
|
||||
List[Recorder]: a list of Recorders
|
||||
"""
|
||||
if isinstance(tasks, dict):
|
||||
tasks = [tasks]
|
||||
if len(tasks) == 0:
|
||||
return []
|
||||
if train_func is None:
|
||||
@@ -213,6 +220,8 @@ class TrainerR(Trainer):
|
||||
Returns:
|
||||
List[Recorder]: the same list as the param.
|
||||
"""
|
||||
if isinstance(recs, Recorder):
|
||||
recs = [recs]
|
||||
for rec in recs:
|
||||
rec.set_tags(**{self.STATUS_KEY: self.STATUS_END})
|
||||
return recs
|
||||
@@ -250,6 +259,8 @@ class DelayTrainerR(TrainerR):
|
||||
Returns:
|
||||
List[Recorder]: a list of Recorders
|
||||
"""
|
||||
if isinstance(recs, Recorder):
|
||||
recs = [recs]
|
||||
if end_train_func is None:
|
||||
end_train_func = self.end_train_func
|
||||
if experiment_name is None:
|
||||
@@ -275,7 +286,12 @@ class TrainerRM(Trainer):
|
||||
STATUS_BEGIN = "begin_task_train"
|
||||
STATUS_END = "end_task_train"
|
||||
|
||||
def __init__(self, experiment_name: str = None, task_pool: str = None, train_func=task_train):
|
||||
# This tag is the _id in TaskManager to distinguish tasks.
|
||||
TM_ID = "_id in TaskManager"
|
||||
|
||||
def __init__(
|
||||
self, experiment_name: str = None, task_pool: str = None, train_func=task_train, skip_run_task: bool = False
|
||||
):
|
||||
"""
|
||||
Init TrainerR.
|
||||
|
||||
@@ -283,11 +299,16 @@ class TrainerRM(Trainer):
|
||||
experiment_name (str): the default name of experiment.
|
||||
task_pool (str): task pool name in TaskManager. None for use same name as experiment_name.
|
||||
train_func (Callable, optional): default training method. Defaults to `task_train`.
|
||||
skip_run_task (bool):
|
||||
If skip_run_task == True:
|
||||
Only run_task in the worker. Otherwise skip run_task.
|
||||
"""
|
||||
|
||||
super().__init__()
|
||||
self.experiment_name = experiment_name
|
||||
self.task_pool = task_pool
|
||||
self.train_func = train_func
|
||||
self.skip_run_task = skip_run_task
|
||||
|
||||
def train(
|
||||
self,
|
||||
@@ -315,6 +336,8 @@ class TrainerRM(Trainer):
|
||||
Returns:
|
||||
List[Recorder]: a list of Recorders
|
||||
"""
|
||||
if isinstance(tasks, dict):
|
||||
tasks = [tasks]
|
||||
if len(tasks) == 0:
|
||||
return []
|
||||
if train_func is None:
|
||||
@@ -326,19 +349,26 @@ class TrainerRM(Trainer):
|
||||
task_pool = experiment_name
|
||||
tm = TaskManager(task_pool=task_pool)
|
||||
_id_list = tm.create_task(tasks) # all tasks will be saved to MongoDB
|
||||
run_task(
|
||||
train_func,
|
||||
task_pool,
|
||||
experiment_name=experiment_name,
|
||||
before_status=before_status,
|
||||
after_status=after_status,
|
||||
**kwargs,
|
||||
)
|
||||
query = {"_id": {"$in": _id_list}}
|
||||
if not self.skip_run_task:
|
||||
run_task(
|
||||
train_func,
|
||||
task_pool,
|
||||
query=query, # only train these tasks
|
||||
experiment_name=experiment_name,
|
||||
before_status=before_status,
|
||||
after_status=after_status,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if not self.is_delay():
|
||||
tm.wait(query=query)
|
||||
|
||||
recs = []
|
||||
for _id in _id_list:
|
||||
rec = tm.re_query(_id)["res"]
|
||||
rec.set_tags(**{self.STATUS_KEY: self.STATUS_BEGIN})
|
||||
rec.set_tags(**{self.TM_ID: _id})
|
||||
recs.append(rec)
|
||||
return recs
|
||||
|
||||
@@ -352,10 +382,33 @@ class TrainerRM(Trainer):
|
||||
Returns:
|
||||
List[Recorder]: the same list as the param.
|
||||
"""
|
||||
if isinstance(recs, Recorder):
|
||||
recs = [recs]
|
||||
for rec in recs:
|
||||
rec.set_tags(**{self.STATUS_KEY: self.STATUS_END})
|
||||
return recs
|
||||
|
||||
def worker(
|
||||
self,
|
||||
train_func: Callable = None,
|
||||
experiment_name: str = None,
|
||||
):
|
||||
"""
|
||||
The multiprocessing method for `train`. It can share a same task_pool with `train` and can run in other progress or other machines.
|
||||
|
||||
Args:
|
||||
train_func (Callable): the training method which needs at least `task`s and `experiment_name`. None for the default training method.
|
||||
experiment_name (str): the experiment name, None for use default name.
|
||||
"""
|
||||
if train_func is None:
|
||||
train_func = self.train_func
|
||||
if experiment_name is None:
|
||||
experiment_name = self.experiment_name
|
||||
task_pool = self.task_pool
|
||||
if task_pool is None:
|
||||
task_pool = experiment_name
|
||||
run_task(train_func, task_pool=task_pool, experiment_name=experiment_name)
|
||||
|
||||
|
||||
class DelayTrainerRM(TrainerRM):
|
||||
"""
|
||||
@@ -369,6 +422,7 @@ class DelayTrainerRM(TrainerRM):
|
||||
task_pool: str = None,
|
||||
train_func=begin_task_train,
|
||||
end_train_func=end_task_train,
|
||||
skip_run_task: bool = False,
|
||||
):
|
||||
"""
|
||||
Init DelayTrainerRM.
|
||||
@@ -378,10 +432,15 @@ class DelayTrainerRM(TrainerRM):
|
||||
task_pool (str): task pool name in TaskManager. None for use same name as experiment_name.
|
||||
train_func (Callable, optional): default train method. Defaults to `begin_task_train`.
|
||||
end_train_func (Callable, optional): default end_train method. Defaults to `end_task_train`.
|
||||
skip_run_task (bool):
|
||||
If skip_run_task == True:
|
||||
Only run_task in the worker. Otherwise skip run_task.
|
||||
E.g. Starting trainer on a CPU VM and then waiting tasks to be finished on GPU VMs.
|
||||
"""
|
||||
super().__init__(experiment_name, task_pool, train_func)
|
||||
self.end_train_func = end_train_func
|
||||
self.delay = True
|
||||
self.skip_run_task = skip_run_task
|
||||
|
||||
def train(self, tasks: list, train_func=None, experiment_name: str = None, **kwargs) -> List[Recorder]:
|
||||
"""
|
||||
@@ -395,6 +454,8 @@ class DelayTrainerRM(TrainerRM):
|
||||
Returns:
|
||||
List[Recorder]: a list of Recorders
|
||||
"""
|
||||
if isinstance(tasks, dict):
|
||||
tasks = [tasks]
|
||||
if len(tasks) == 0:
|
||||
return []
|
||||
return super().train(
|
||||
@@ -410,8 +471,6 @@ class DelayTrainerRM(TrainerRM):
|
||||
Given a list of Recorder and return a list of trained Recorder.
|
||||
This class will finish real data loading and model fitting.
|
||||
|
||||
NOTE: This method will train all STATUS_PART_DONE tasks in the task pool, not only the ``recs``.
|
||||
|
||||
Args:
|
||||
recs (list): a list of Recorder, the tasks have been saved to them.
|
||||
end_train_func (Callable, optional): the end_train method which need at least `recorder`s and `experiment_name`. Defaults to None for using self.end_train_func.
|
||||
@@ -421,7 +480,8 @@ class DelayTrainerRM(TrainerRM):
|
||||
Returns:
|
||||
List[Recorder]: a list of Recorders
|
||||
"""
|
||||
|
||||
if isinstance(recs, Recorder):
|
||||
recs = [recs]
|
||||
if end_train_func is None:
|
||||
end_train_func = self.end_train_func
|
||||
if experiment_name is None:
|
||||
@@ -429,18 +489,45 @@ class DelayTrainerRM(TrainerRM):
|
||||
task_pool = self.task_pool
|
||||
if task_pool is None:
|
||||
task_pool = experiment_name
|
||||
tasks = []
|
||||
_id_list = []
|
||||
for rec in recs:
|
||||
tasks.append(rec.load_object("task"))
|
||||
_id_list.append(rec.list_tags()[self.TM_ID])
|
||||
|
||||
query = {"_id": {"$in": _id_list}}
|
||||
if not self.skip_run_task:
|
||||
run_task(
|
||||
end_train_func,
|
||||
task_pool,
|
||||
query=query, # only train these tasks
|
||||
experiment_name=experiment_name,
|
||||
before_status=TaskManager.STATUS_PART_DONE,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
TaskManager(task_pool=task_pool).wait(query=query)
|
||||
|
||||
run_task(
|
||||
end_train_func,
|
||||
task_pool,
|
||||
query={"filter": {"$in": tasks}}, # only train these tasks
|
||||
experiment_name=experiment_name,
|
||||
before_status=TaskManager.STATUS_PART_DONE,
|
||||
**kwargs,
|
||||
)
|
||||
for rec in recs:
|
||||
rec.set_tags(**{self.STATUS_KEY: self.STATUS_END})
|
||||
return recs
|
||||
|
||||
def worker(self, end_train_func=None, experiment_name: str = None):
|
||||
"""
|
||||
The multiprocessing method for `end_train`. It can share a same task_pool with `end_train` and can run in other progress or other machines.
|
||||
|
||||
Args:
|
||||
end_train_func (Callable, optional): the end_train method which need at least `recorder`s and `experiment_name`. Defaults to None for using self.end_train_func.
|
||||
experiment_name (str): the experiment name, None for use default name.
|
||||
"""
|
||||
if end_train_func is None:
|
||||
end_train_func = self.end_train_func
|
||||
if experiment_name is None:
|
||||
experiment_name = self.experiment_name
|
||||
task_pool = self.task_pool
|
||||
if task_pool is None:
|
||||
task_pool = experiment_name
|
||||
run_task(
|
||||
end_train_func,
|
||||
task_pool=task_pool,
|
||||
experiment_name=experiment_name,
|
||||
before_status=TaskManager.STATUS_PART_DONE,
|
||||
)
|
||||
|
||||
@@ -43,17 +43,29 @@ RECORD_CONFIG = [
|
||||
]
|
||||
|
||||
|
||||
def get_data_handler_config(market=CSI300_MARKET):
|
||||
def get_data_handler_config(
|
||||
start_time="2008-01-01",
|
||||
end_time="2020-08-01",
|
||||
fit_start_time="2008-01-01",
|
||||
fit_end_time="2014-12-31",
|
||||
instruments=CSI300_MARKET,
|
||||
):
|
||||
return {
|
||||
"start_time": "2008-01-01",
|
||||
"end_time": "2020-08-01",
|
||||
"fit_start_time": "2008-01-01",
|
||||
"fit_end_time": "2014-12-31",
|
||||
"instruments": market,
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
"fit_start_time": fit_start_time,
|
||||
"fit_end_time": fit_end_time,
|
||||
"instruments": instruments,
|
||||
}
|
||||
|
||||
|
||||
def get_dataset_config(market=CSI300_MARKET, dataset_class=DATASET_ALPHA158_CLASS):
|
||||
def get_dataset_config(
|
||||
dataset_class=DATASET_ALPHA158_CLASS,
|
||||
train=("2008-01-01", "2014-12-31"),
|
||||
valid=("2015-01-01", "2016-12-31"),
|
||||
test=("2017-01-01", "2020-08-01"),
|
||||
handler_kwargs={"instruments": CSI300_MARKET},
|
||||
):
|
||||
return {
|
||||
"class": "DatasetH",
|
||||
"module_path": "qlib.data.dataset",
|
||||
@@ -61,48 +73,88 @@ def get_dataset_config(market=CSI300_MARKET, dataset_class=DATASET_ALPHA158_CLAS
|
||||
"handler": {
|
||||
"class": dataset_class,
|
||||
"module_path": "qlib.contrib.data.handler",
|
||||
"kwargs": get_data_handler_config(market),
|
||||
"kwargs": get_data_handler_config(**handler_kwargs),
|
||||
},
|
||||
"segments": {
|
||||
"train": ("2008-01-01", "2014-12-31"),
|
||||
"valid": ("2015-01-01", "2016-12-31"),
|
||||
"test": ("2017-01-01", "2020-08-01"),
|
||||
"train": train,
|
||||
"valid": valid,
|
||||
"test": test,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def get_gbdt_task(market=CSI300_MARKET):
|
||||
def get_gbdt_task(dataset_kwargs={}, handler_kwargs={"instruments": CSI300_MARKET}):
|
||||
return {
|
||||
"model": GBDT_MODEL,
|
||||
"dataset": get_dataset_config(market),
|
||||
"dataset": get_dataset_config(**dataset_kwargs, handler_kwargs=handler_kwargs),
|
||||
}
|
||||
|
||||
|
||||
def get_record_lgb_config(market=CSI300_MARKET):
|
||||
def get_record_lgb_config(dataset_kwargs={}, handler_kwargs={"instruments": CSI300_MARKET}):
|
||||
return {
|
||||
"model": {
|
||||
"class": "LGBModel",
|
||||
"module_path": "qlib.contrib.model.gbdt",
|
||||
},
|
||||
"dataset": get_dataset_config(market),
|
||||
"dataset": get_dataset_config(**dataset_kwargs, handler_kwargs=handler_kwargs),
|
||||
"record": RECORD_CONFIG,
|
||||
}
|
||||
|
||||
|
||||
def get_record_xgboost_config(market=CSI300_MARKET):
|
||||
def get_record_xgboost_config(dataset_kwargs={}, handler_kwargs={"instruments": CSI300_MARKET}):
|
||||
return {
|
||||
"model": {
|
||||
"class": "XGBModel",
|
||||
"module_path": "qlib.contrib.model.xgboost",
|
||||
},
|
||||
"dataset": get_dataset_config(market),
|
||||
"dataset": get_dataset_config(**dataset_kwargs, handler_kwargs=handler_kwargs),
|
||||
"record": RECORD_CONFIG,
|
||||
}
|
||||
|
||||
|
||||
CSI300_DATASET_CONFIG = get_dataset_config(market=CSI300_MARKET)
|
||||
CSI300_GBDT_TASK = get_gbdt_task(market=CSI300_MARKET)
|
||||
CSI300_DATASET_CONFIG = get_dataset_config(handler_kwargs={"instruments": CSI300_MARKET})
|
||||
CSI300_GBDT_TASK = get_gbdt_task(handler_kwargs={"instruments": CSI300_MARKET})
|
||||
|
||||
CSI100_RECORD_XGBOOST_TASK_CONFIG = get_record_xgboost_config(market=CSI100_MARKET)
|
||||
CSI100_RECORD_LGB_TASK_CONFIG = get_record_lgb_config(market=CSI100_MARKET)
|
||||
CSI100_RECORD_XGBOOST_TASK_CONFIG = get_record_xgboost_config(handler_kwargs={"instruments": CSI100_MARKET})
|
||||
CSI100_RECORD_LGB_TASK_CONFIG = get_record_lgb_config(handler_kwargs={"instruments": CSI100_MARKET})
|
||||
|
||||
# use for rolling_online_managment.py
|
||||
ROLLING_HANDLER_CONFIG = {
|
||||
"start_time": "2013-01-01",
|
||||
"end_time": "2020-09-25",
|
||||
"fit_start_time": "2013-01-01",
|
||||
"fit_end_time": "2014-12-31",
|
||||
"instruments": CSI100_MARKET,
|
||||
}
|
||||
ROLLING_DATASET_CONFIG = {
|
||||
"train": ("2013-01-01", "2014-12-31"),
|
||||
"valid": ("2015-01-01", "2015-12-31"),
|
||||
"test": ("2016-01-01", "2020-07-10"),
|
||||
}
|
||||
CSI100_RECORD_XGBOOST_TASK_CONFIG_ROLLING = get_record_xgboost_config(
|
||||
dataset_kwargs=ROLLING_DATASET_CONFIG, handler_kwargs=ROLLING_HANDLER_CONFIG
|
||||
)
|
||||
CSI100_RECORD_LGB_TASK_CONFIG_ROLLING = get_record_lgb_config(
|
||||
dataset_kwargs=ROLLING_DATASET_CONFIG, handler_kwargs=ROLLING_HANDLER_CONFIG
|
||||
)
|
||||
|
||||
# use for online_management_simulate.py
|
||||
ONLINE_HANDLER_CONFIG = {
|
||||
"start_time": "2018-01-01",
|
||||
"end_time": "2018-10-31",
|
||||
"fit_start_time": "2018-01-01",
|
||||
"fit_end_time": "2018-03-31",
|
||||
"instruments": CSI100_MARKET,
|
||||
}
|
||||
ONLINE_DATASET_CONFIG = {
|
||||
"train": ("2018-01-01", "2018-03-31"),
|
||||
"valid": ("2018-04-01", "2018-05-31"),
|
||||
"test": ("2018-06-01", "2018-09-10"),
|
||||
}
|
||||
CSI100_RECORD_XGBOOST_TASK_CONFIG_ONLINE = get_record_xgboost_config(
|
||||
dataset_kwargs=ONLINE_DATASET_CONFIG, handler_kwargs=ONLINE_HANDLER_CONFIG
|
||||
)
|
||||
CSI100_RECORD_LGB_TASK_CONFIG_ONLINE = get_record_lgb_config(
|
||||
dataset_kwargs=ONLINE_DATASET_CONFIG, handler_kwargs=ONLINE_HANDLER_CONFIG
|
||||
)
|
||||
|
||||
@@ -642,6 +642,28 @@ def split_pred(pred, number=None, split_date=None):
|
||||
return pred_left, pred_right
|
||||
|
||||
|
||||
def time_to_slc_point(t: Union[None, str, pd.Timestamp]) -> Union[None, pd.Timestamp]:
|
||||
"""
|
||||
Time slicing in Qlib or Pandas is a frequently-used action.
|
||||
However, user often input all kinds of data format to represent time.
|
||||
This function will help user to convert these inputs into a uniform format which is friendly to time slicing.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
t : Union[None, str, pd.Timestamp]
|
||||
original time
|
||||
|
||||
Returns
|
||||
-------
|
||||
Union[None, pd.Timestamp]:
|
||||
"""
|
||||
if t is None:
|
||||
# None represents unbounded in Qlib or Pandas(e.g. df.loc[slice(None, "20210303")]).
|
||||
return t
|
||||
else:
|
||||
return pd.Timestamp(t)
|
||||
|
||||
|
||||
def can_use_cache():
|
||||
res = True
|
||||
r = get_redis_connection()
|
||||
|
||||
17
qlib/utils/exceptions.py
Normal file
17
qlib/utils/exceptions.py
Normal file
@@ -0,0 +1,17 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
# Base exception class
|
||||
class QlibException(Exception):
|
||||
def __init__(self, message):
|
||||
super(QlibException, self).__init__(message)
|
||||
|
||||
|
||||
# Error type for reinitialization when starting an experiment
|
||||
class RecorderInitializationError(QlibException):
|
||||
pass
|
||||
|
||||
|
||||
# Error type for Recorder when can not load object
|
||||
class LoadObjectError(QlibException):
|
||||
pass
|
||||
@@ -92,16 +92,16 @@ class Serializable:
|
||||
@classmethod
|
||||
def load(cls, filepath):
|
||||
"""
|
||||
Load the collector from a filepath.
|
||||
Load the serializable class from a filepath.
|
||||
|
||||
Args:
|
||||
filepath (str): the path of file
|
||||
|
||||
Raises:
|
||||
TypeError: the pickled file must be `Collector`
|
||||
TypeError: the pickled file must be `type(cls)`
|
||||
|
||||
Returns:
|
||||
Collector: the instance of Collector
|
||||
`type(cls)`: the instance of `type(cls)`
|
||||
"""
|
||||
with open(filepath, "rb") as f:
|
||||
object = cls.get_backend().load(f)
|
||||
|
||||
@@ -7,6 +7,7 @@ from .expm import MLflowExpManager
|
||||
from .exp import Experiment
|
||||
from .recorder import Recorder
|
||||
from ..utils import Wrapper
|
||||
from ..utils.exceptions import RecorderInitializationError
|
||||
|
||||
|
||||
class QlibRecorder:
|
||||
@@ -215,9 +216,9 @@ class QlibRecorder:
|
||||
-------
|
||||
A dictionary (id -> recorder) of recorder information that being stored.
|
||||
"""
|
||||
return self.get_exp(experiment_id, experiment_name).list_recorders()
|
||||
return self.get_exp(experiment_id=experiment_id, experiment_name=experiment_name).list_recorders()
|
||||
|
||||
def get_exp(self, experiment_id=None, experiment_name=None, create: bool = True) -> Experiment:
|
||||
def get_exp(self, *, experiment_id=None, experiment_name=None, create: bool = True) -> Experiment:
|
||||
"""
|
||||
Method for retrieving an experiment with given id or name. Once the `create` argument is set to
|
||||
True, if no valid experiment is found, this method will create one for you. Otherwise, it will
|
||||
@@ -262,7 +263,7 @@ class QlibRecorder:
|
||||
|
||||
# Case 2
|
||||
with R.start('test'):
|
||||
exp = R.get_exp('test1')
|
||||
exp = R.get_exp(experiment_name='test1')
|
||||
|
||||
# Case 3
|
||||
exp = R.get_exp() -> a default experiment.
|
||||
@@ -287,7 +288,9 @@ class QlibRecorder:
|
||||
-------
|
||||
An experiment instance with given id or name.
|
||||
"""
|
||||
return self.exp_manager.get_exp(experiment_id, experiment_name, create, start=False)
|
||||
return self.exp_manager.get_exp(
|
||||
experiment_id=experiment_id, experiment_name=experiment_name, create=create, start=False
|
||||
)
|
||||
|
||||
def delete_exp(self, experiment_id=None, experiment_name=None):
|
||||
"""
|
||||
@@ -331,7 +334,9 @@ class QlibRecorder:
|
||||
"""
|
||||
self.exp_manager.set_uri(uri)
|
||||
|
||||
def get_recorder(self, recorder_id=None, recorder_name=None, experiment_name=None) -> Recorder:
|
||||
def get_recorder(
|
||||
self, *, recorder_id=None, recorder_name=None, experiment_id=None, experiment_name=None
|
||||
) -> Recorder:
|
||||
"""
|
||||
Method for retrieving a recorder.
|
||||
|
||||
@@ -384,7 +389,7 @@ class QlibRecorder:
|
||||
-------
|
||||
A recorder instance.
|
||||
"""
|
||||
return self.get_exp(experiment_name=experiment_name, create=False).get_recorder(
|
||||
return self.get_exp(experiment_name=experiment_name, experiment_id=experiment_id, create=False).get_recorder(
|
||||
recorder_id, recorder_name, create=False, start=False
|
||||
)
|
||||
|
||||
@@ -525,14 +530,29 @@ class QlibRecorder:
|
||||
self.get_exp().get_recorder().set_tags(**kwargs)
|
||||
|
||||
|
||||
class RecorderWrapper(Wrapper):
|
||||
"""
|
||||
Wrapper class for QlibRecorder, which detects whether users reinitialize qlib when already starting an experiment.
|
||||
"""
|
||||
|
||||
def register(self, provider):
|
||||
if self._provider is not None:
|
||||
expm = getattr(self._provider, "exp_manager")
|
||||
if expm.active_experiment is not None:
|
||||
raise RecorderInitializationError(
|
||||
"Please don't reinitialize Qlib if QlibRecorder is already acivated. Otherwise, the experiment stored location will be modified."
|
||||
)
|
||||
self._provider = provider
|
||||
|
||||
|
||||
import sys
|
||||
|
||||
if sys.version_info >= (3, 9):
|
||||
from typing import Annotated
|
||||
|
||||
QlibRecorderWrapper = Annotated[QlibRecorder, Wrapper]
|
||||
QlibRecorderWrapper = Annotated[QlibRecorder, RecorderWrapper]
|
||||
else:
|
||||
QlibRecorderWrapper = QlibRecorder
|
||||
|
||||
# global record
|
||||
R: QlibRecorderWrapper = Wrapper()
|
||||
R: QlibRecorderWrapper = RecorderWrapper()
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from typing import Union
|
||||
import mlflow, logging
|
||||
from mlflow.entities import ViewType
|
||||
from mlflow.exceptions import MlflowException
|
||||
@@ -213,11 +214,15 @@ class Experiment:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `_get_recorder` method")
|
||||
|
||||
def list_recorders(self):
|
||||
def list_recorders(self, **flt_kwargs):
|
||||
"""
|
||||
List all the existing recorders of this experiment. Please first get the experiment instance before calling this method.
|
||||
If user want to use the method `R.list_recorders()`, please refer to the related API document in `QlibRecorder`.
|
||||
|
||||
flt_kwargs : dict
|
||||
filter recorders by conditions
|
||||
e.g. list_recorders(status=Recorder.STATUS_FI)
|
||||
|
||||
Returns
|
||||
-------
|
||||
A dictionary (id -> recorder) of recorder information that being stored.
|
||||
@@ -320,11 +325,21 @@ class MLflowExperiment(Experiment):
|
||||
|
||||
UNLIMITED = 50000 # FIXME: Mlflow can only list 50000 records at most!!!!!!!
|
||||
|
||||
def list_recorders(self, max_results=UNLIMITED):
|
||||
def list_recorders(self, max_results: int = UNLIMITED, status: Union[str, None] = None):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
max_results : int
|
||||
the number limitation of the results
|
||||
status : str
|
||||
the criteria based on status to filter results.
|
||||
`None` indicates no filtering.
|
||||
"""
|
||||
runs = self._client.search_runs(self.id, run_view_type=ViewType.ACTIVE_ONLY, max_results=max_results)
|
||||
recorders = dict()
|
||||
for i in range(len(runs)):
|
||||
recorder = MLflowRecorder(self.id, self._uri, mlflow_run=runs[i])
|
||||
recorders[runs[i].info.run_id] = recorder
|
||||
if status is None or recorder.status == status:
|
||||
recorders[runs[i].info.run_id] = recorder
|
||||
|
||||
return recorders
|
||||
|
||||
@@ -109,7 +109,7 @@ class ExpManager:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `search_records` method.")
|
||||
|
||||
def get_exp(self, experiment_id=None, experiment_name=None, create: bool = True, start: bool = False):
|
||||
def get_exp(self, *, experiment_id=None, experiment_name=None, create: bool = True, start: bool = False):
|
||||
"""
|
||||
Retrieve an experiment. This method includes getting an active experiment, and get_or_create a specific experiment.
|
||||
|
||||
@@ -190,7 +190,7 @@ class ExpManager:
|
||||
except ValueError:
|
||||
if experiment_name is None:
|
||||
experiment_name = self._default_exp_name
|
||||
logger.info(f"No valid experiment found. Create a new experiment with name {experiment_name}.")
|
||||
logger.warning(f"No valid experiment found. Create a new experiment with name {experiment_name}.")
|
||||
return self.create_exp(experiment_name), True
|
||||
|
||||
def _get_exp(self, experiment_id=None, experiment_name=None) -> Experiment:
|
||||
@@ -352,6 +352,8 @@ class MLflowExpManager(ExpManager):
|
||||
), "Please input at least one of experiment/recorder id or name before retrieving experiment/recorder."
|
||||
if experiment_id is not None:
|
||||
try:
|
||||
# NOTE: the mlflow's experiment_id must be str type...
|
||||
# https://www.mlflow.org/docs/latest/python_api/mlflow.tracking.html#mlflow.tracking.MlflowClient.get_experiment
|
||||
exp = self.client.get_experiment(experiment_id)
|
||||
if exp.lifecycle_stage.upper() == "DELETED":
|
||||
raise MlflowException("No valid experiment has been found.")
|
||||
|
||||
@@ -6,7 +6,7 @@ OnlineManager can manage a set of `Online Strategy <#Online Strategy>`_ and run
|
||||
|
||||
With the change of time, the decisive models will be also changed. In this module, we call those contributing models `online` models.
|
||||
In every routine(such as every day or every minute), the `online` models may be changed and the prediction of them needs to be updated.
|
||||
So this module provides a series of methods to control this process.
|
||||
So this module provides a series of methods to control this process.
|
||||
|
||||
This module also provides a method to simulate `Online Strategy <#Online Strategy>`_ in history.
|
||||
Which means you can verify your strategy or find a better one.
|
||||
@@ -18,10 +18,12 @@ There are 4 total situations for using different trainers in different situation
|
||||
========================= ===================================================================================
|
||||
Situations Description
|
||||
========================= ===================================================================================
|
||||
Online + Trainer When you REAL want to do a routine, the Trainer will help you train the models.
|
||||
Online + Trainer When you want to do a REAL routine, the Trainer will help you train the models. It
|
||||
will train models task by task and strategy by strategy.
|
||||
|
||||
Online + DelayTrainer In normal online routine, whether Trainer or DelayTrainer will REAL train models
|
||||
in this routine. So it is not necessary to use DelayTrainer when do a REAL routine.
|
||||
Online + DelayTrainer When your models don't have any temporal dependence, the DelayTrainer will train
|
||||
nothing until all tasks have been prepared. It makes user can train all tasks in
|
||||
the end of `routine` or `first_train`.
|
||||
|
||||
Simulation + Trainer When your models have some temporal dependence on the previous models, then you
|
||||
need to consider using Trainer. This means it will REAL train your models in
|
||||
@@ -29,7 +31,7 @@ Simulation + Trainer When your models have some temporal dependence on the
|
||||
|
||||
Simulation + DelayTrainer When your models don't have any temporal dependence, you can use DelayTrainer
|
||||
for the ability to multitasking. It means all tasks in all routines
|
||||
can be REAL trained at the end of simulating. The signals will be prepared well at
|
||||
can be REAL trained at the end of simulating. The signals will be prepared well at
|
||||
different time segments (based on whether or not any new model is online).
|
||||
========================= ===================================================================================
|
||||
"""
|
||||
@@ -103,17 +105,23 @@ class OnlineManager(Serializable):
|
||||
"""
|
||||
if strategies is None:
|
||||
strategies = self.strategies
|
||||
for strategy in strategies:
|
||||
|
||||
models_list = []
|
||||
for strategy in strategies:
|
||||
self.logger.info(f"Strategy `{strategy.name_id}` begins first training...")
|
||||
tasks = strategy.first_tasks()
|
||||
models = self.trainer.train(tasks, experiment_name=strategy.name_id)
|
||||
models = self.trainer.end_train(models, experiment_name=strategy.name_id)
|
||||
models_list.append(models)
|
||||
self.logger.info(f"Finished training {len(models)} models.")
|
||||
|
||||
# FIXME: Traing multiple online models at `first_train` will result in getting too much online models at the
|
||||
# start.
|
||||
online_models = strategy.prepare_online_models(models, **model_kwargs)
|
||||
self.history.setdefault(self.cur_time, {})[strategy] = online_models
|
||||
|
||||
if not self.status == self.STATUS_SIMULATING or not self.trainer.is_delay():
|
||||
for strategy, models in zip(strategies, models_list):
|
||||
models = self.trainer.end_train(models, experiment_name=strategy.name_id)
|
||||
|
||||
def routine(
|
||||
self,
|
||||
cur_time: Union[str, pd.Timestamp] = None,
|
||||
@@ -139,33 +147,41 @@ class OnlineManager(Serializable):
|
||||
cur_time = D.calendar(freq=self.freq).max()
|
||||
self.cur_time = pd.Timestamp(cur_time) # None for latest date
|
||||
|
||||
models_list = []
|
||||
for strategy in self.strategies:
|
||||
self.logger.info(f"Strategy `{strategy.name_id}` begins routine...")
|
||||
if self.status == self.STATUS_NORMAL:
|
||||
strategy.tool.update_online_pred()
|
||||
|
||||
tasks = strategy.prepare_tasks(self.cur_time, **task_kwargs)
|
||||
models = self.trainer.train(tasks)
|
||||
if self.status == self.STATUS_NORMAL or not self.trainer.is_delay():
|
||||
models = self.trainer.end_train(models, experiment_name=strategy.name_id)
|
||||
models = self.trainer.train(tasks, experiment_name=strategy.name_id)
|
||||
models_list.append(models)
|
||||
self.logger.info(f"Finished training {len(models)} models.")
|
||||
online_models = strategy.prepare_online_models(models, **model_kwargs)
|
||||
self.history.setdefault(self.cur_time, {})[strategy] = online_models
|
||||
|
||||
if not self.trainer.is_delay():
|
||||
# The online model may changes in the above processes
|
||||
# So updating the predictions of online models should be the last step
|
||||
if self.status == self.STATUS_NORMAL:
|
||||
strategy.tool.update_online_pred()
|
||||
|
||||
if not self.status == self.STATUS_SIMULATING or not self.trainer.is_delay():
|
||||
for strategy, models in zip(self.strategies, models_list):
|
||||
models = self.trainer.end_train(models, experiment_name=strategy.name_id)
|
||||
self.prepare_signals(**signal_kwargs)
|
||||
|
||||
def get_collector(self) -> MergeCollector:
|
||||
def get_collector(self, **kwargs) -> MergeCollector:
|
||||
"""
|
||||
Get the instance of `Collector <../advanced/task_management.html#Task Collecting>`_ to collect results from every strategy.
|
||||
This collector can be a basis as the signals preparation.
|
||||
|
||||
Args:
|
||||
**kwargs: the params for get_collector.
|
||||
|
||||
Returns:
|
||||
MergeCollector: the collector to merge other collectors.
|
||||
"""
|
||||
collector_dict = {}
|
||||
for strategy in self.strategies:
|
||||
collector_dict[strategy.name_id] = strategy.get_collector()
|
||||
collector_dict[strategy.name_id] = strategy.get_collector(**kwargs)
|
||||
return MergeCollector(collector_dict, process_list=[])
|
||||
|
||||
def add_strategy(self, strategies: Union[OnlineStrategy, List[OnlineStrategy]]):
|
||||
@@ -225,7 +241,7 @@ class OnlineManager(Serializable):
|
||||
SIM_LOG_NAME = "SIMULATE_INFO"
|
||||
|
||||
def simulate(
|
||||
self, end_time, frequency="day", task_kwargs={}, model_kwargs={}, signal_kwargs={}
|
||||
self, end_time=None, frequency="day", task_kwargs={}, model_kwargs={}, signal_kwargs={}
|
||||
) -> Union[pd.Series, pd.DataFrame]:
|
||||
"""
|
||||
Starting from the current time, this method will simulate every routine in OnlineManager until the end time.
|
||||
@@ -297,6 +313,7 @@ class OnlineManager(Serializable):
|
||||
# NOTE: Assumption: the predictions of online models need less than next cur_time, or this method will work in a wrong way.
|
||||
self.prepare_signals(**signal_kwargs)
|
||||
if signals_time > cur_time:
|
||||
# FIXME: if use DelayTrainer and worker (and worker is faster than main progress), there are some possibilities of showing this warning.
|
||||
self.logger.warn(
|
||||
f"The signals have already parpred to {signals_time} by last preparation, but current time is only {cur_time}. This may be because the online models predict more than they should, which can cause signals to be contaminated by the offline models."
|
||||
)
|
||||
|
||||
@@ -52,6 +52,12 @@ class OnlineStrategy:
|
||||
|
||||
NOTE: Reset all online models to trained models. If there are no trained models, then do nothing.
|
||||
|
||||
**NOTE**:
|
||||
Current implementation is very naive. Here is a more complex situation which is more closer to the
|
||||
practical scenarios.
|
||||
1. Train new models at the day before `test_start` (at time stamp `T`)
|
||||
2. Switch models at the `test_start` (at time timestamp `T + 1` typically)
|
||||
|
||||
Args:
|
||||
models (list): a list of models.
|
||||
cur_time (pd.Dataframe): current time from OnlineManger. None for the latest.
|
||||
|
||||
@@ -135,10 +135,9 @@ class PredUpdater(RecordUpdater):
|
||||
# RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.
|
||||
# https://github.com/pytorch/pytorch/issues/16797
|
||||
|
||||
start_time = get_date_by_shift(self.last_end, 1, freq=self.freq)
|
||||
if start_time >= self.to_date:
|
||||
if self.last_end >= self.to_date:
|
||||
self.logger.info(
|
||||
f"The prediction in {self.record.info['id']} are latest ({start_time}). No need to update to {self.to_date}."
|
||||
f"The prediction in {self.record.info['id']} are latest ({self.last_end}). No need to update to {self.to_date}."
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
@@ -8,8 +8,11 @@ This allows us to use efficient submodels as the market-style changing.
|
||||
"""
|
||||
|
||||
from typing import List, Union
|
||||
from qlib.data.dataset import TSDatasetH
|
||||
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.utils import get_cls_kwargs
|
||||
from qlib.utils.exceptions import LoadObjectError
|
||||
from qlib.workflow.online.update import PredUpdater
|
||||
from qlib.workflow.recorder import Recorder
|
||||
from qlib.workflow.task.utils import list_recorders
|
||||
@@ -88,15 +91,15 @@ class OnlineToolR(OnlineTool):
|
||||
The implementation of OnlineTool based on (R)ecorder.
|
||||
"""
|
||||
|
||||
def __init__(self, experiment_name: str):
|
||||
def __init__(self, default_exp_name: str = None):
|
||||
"""
|
||||
Init OnlineToolR.
|
||||
|
||||
Args:
|
||||
experiment_name (str): the experiment name.
|
||||
default_exp_name (str): the default experiment name.
|
||||
"""
|
||||
super().__init__()
|
||||
self.exp_name = experiment_name
|
||||
self.default_exp_name = default_exp_name
|
||||
|
||||
def set_online_tag(self, tag, recorder: Union[Recorder, List]):
|
||||
"""
|
||||
@@ -125,44 +128,68 @@ class OnlineToolR(OnlineTool):
|
||||
tags = recorder.list_tags()
|
||||
return tags.get(self.ONLINE_KEY, self.OFFLINE_TAG)
|
||||
|
||||
def reset_online_tag(self, recorder: Union[Recorder, List]):
|
||||
def reset_online_tag(self, recorder: Union[Recorder, List], exp_name: str = None):
|
||||
"""
|
||||
Offline all models and set the recorders to 'online'.
|
||||
|
||||
Args:
|
||||
recorder (Union[Recorder, List]):
|
||||
the recorder you want to reset to 'online'.
|
||||
exp_name (str): the experiment name. If None, then use default_exp_name.
|
||||
|
||||
"""
|
||||
exp_name = self._get_exp_name(exp_name)
|
||||
if isinstance(recorder, Recorder):
|
||||
recorder = [recorder]
|
||||
recs = list_recorders(self.exp_name)
|
||||
recs = list_recorders(exp_name)
|
||||
self.set_online_tag(self.OFFLINE_TAG, list(recs.values()))
|
||||
self.set_online_tag(self.ONLINE_TAG, recorder)
|
||||
|
||||
def online_models(self) -> list:
|
||||
def online_models(self, exp_name: str = None) -> list:
|
||||
"""
|
||||
Get current `online` models
|
||||
|
||||
Args:
|
||||
exp_name (str): the experiment name. If None, then use default_exp_name.
|
||||
|
||||
Returns:
|
||||
list: a list of `online` models.
|
||||
"""
|
||||
return list(list_recorders(self.exp_name, lambda rec: self.get_online_tag(rec) == self.ONLINE_TAG).values())
|
||||
exp_name = self._get_exp_name(exp_name)
|
||||
return list(list_recorders(exp_name, lambda rec: self.get_online_tag(rec) == self.ONLINE_TAG).values())
|
||||
|
||||
def update_online_pred(self, to_date=None):
|
||||
def update_online_pred(self, to_date=None, exp_name: str = None):
|
||||
"""
|
||||
Update the predictions of online models to to_date.
|
||||
|
||||
Args:
|
||||
to_date (pd.Timestamp): the pred before this date will be updated. None for updating to latest time in Calendar.
|
||||
exp_name (str): the experiment name. If None, then use default_exp_name.
|
||||
"""
|
||||
online_models = self.online_models()
|
||||
exp_name = self._get_exp_name(exp_name)
|
||||
online_models = self.online_models(exp_name=exp_name)
|
||||
for rec in online_models:
|
||||
hist_ref = 0
|
||||
task = rec.load_object("task")
|
||||
# Special treatment of historical dependencies
|
||||
if task["dataset"]["class"] == "TSDatasetH":
|
||||
hist_ref = task["dataset"]["kwargs"]["step_len"]
|
||||
PredUpdater(rec, to_date=to_date, hist_ref=hist_ref).update()
|
||||
cls, kwargs = get_cls_kwargs(task["dataset"], default_module="qlib.data.dataset")
|
||||
if issubclass(cls, TSDatasetH):
|
||||
hist_ref = kwargs.get("step_len", TSDatasetH.DEFAULT_STEP_LEN)
|
||||
try:
|
||||
updater = PredUpdater(rec, to_date=to_date, hist_ref=hist_ref)
|
||||
except LoadObjectError as e:
|
||||
# skip the recorder without pred
|
||||
self.logger.warn(f"An exception `{str(e)}` happened when load `pred.pkl`, skip it.")
|
||||
continue
|
||||
updater.update()
|
||||
|
||||
self.logger.info(f"Finished updating {len(online_models)} online model predictions of {self.exp_name}.")
|
||||
self.logger.info(f"Finished updating {len(online_models)} online model predictions of {exp_name}.")
|
||||
|
||||
def _get_exp_name(self, exp_name):
|
||||
if exp_name is None:
|
||||
if self.default_exp_name is None:
|
||||
raise ValueError(
|
||||
"Both default_exp_name and exp_name are None. OnlineToolR needs a specific experiment."
|
||||
)
|
||||
exp_name = self.default_exp_name
|
||||
return exp_name
|
||||
|
||||
@@ -227,10 +227,11 @@ class SigAnaRecord(SignalRecord):
|
||||
|
||||
artifact_path = "sig_analysis"
|
||||
|
||||
def __init__(self, recorder, ana_long_short=False, ann_scaler=252, **kwargs):
|
||||
def __init__(self, recorder, ana_long_short=False, ann_scaler=252, label_col=0, **kwargs):
|
||||
super().__init__(recorder=recorder, **kwargs)
|
||||
self.ana_long_short = ana_long_short
|
||||
self.ann_scaler = ann_scaler
|
||||
self.label_col = label_col
|
||||
|
||||
def generate(self, **kwargs):
|
||||
try:
|
||||
@@ -243,7 +244,7 @@ class SigAnaRecord(SignalRecord):
|
||||
if label is None or not isinstance(label, pd.DataFrame) or label.empty:
|
||||
logger.warn(f"Empty label.")
|
||||
return
|
||||
ic, ric = calc_ic(pred.iloc[:, 0], label.iloc[:, 0])
|
||||
ic, ric = calc_ic(pred.iloc[:, 0], label.iloc[:, self.label_col])
|
||||
metrics = {
|
||||
"IC": ic.mean(),
|
||||
"ICIR": ic.mean() / ic.std(),
|
||||
@@ -252,7 +253,7 @@ class SigAnaRecord(SignalRecord):
|
||||
}
|
||||
objects = {"ic.pkl": ic, "ric.pkl": ric}
|
||||
if self.ana_long_short:
|
||||
long_short_r, long_avg_r = calc_long_short_return(pred.iloc[:, 0], label.iloc[:, 0])
|
||||
long_short_r, long_avg_r = calc_long_short_return(pred.iloc[:, 0], label.iloc[:, self.label_col])
|
||||
metrics.update(
|
||||
{
|
||||
"Long-Short Ann Return": long_short_r.mean() * self.ann_scaler,
|
||||
|
||||
@@ -5,6 +5,8 @@ import mlflow, logging
|
||||
import shutil, os, pickle, tempfile, codecs, pickle
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
|
||||
from qlib.utils.exceptions import LoadObjectError
|
||||
from ..utils.objm import FileManager
|
||||
from ..log import get_module_logger
|
||||
|
||||
@@ -307,10 +309,26 @@ class MLflowRecorder(Recorder):
|
||||
shutil.rmtree(temp_dir)
|
||||
|
||||
def load_object(self, name):
|
||||
"""
|
||||
Load object such as prediction file or model checkpoint in mlflow.
|
||||
|
||||
Args:
|
||||
name (str): the object name
|
||||
|
||||
Raises:
|
||||
LoadObjectError: if raise some exceptions when load the object
|
||||
|
||||
Returns:
|
||||
object: the saved object in mlflow.
|
||||
"""
|
||||
assert self.uri is not None, "Please start the experiment and recorder first before using recorder directly."
|
||||
path = self.client.download_artifacts(self.id, name)
|
||||
with Path(path).open("rb") as f:
|
||||
return pickle.load(f)
|
||||
|
||||
try:
|
||||
path = self.client.download_artifacts(self.id, name)
|
||||
with Path(path).open("rb") as f:
|
||||
return pickle.load(f)
|
||||
except Exception as e:
|
||||
raise LoadObjectError(message=str(e))
|
||||
|
||||
def log_params(self, **kwargs):
|
||||
for name, data in kwargs.items():
|
||||
|
||||
@@ -6,6 +6,7 @@ Collector module can collect objects from everywhere and process them such as me
|
||||
"""
|
||||
|
||||
from typing import Callable, Dict, List
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.utils.serial import Serializable
|
||||
from qlib.workflow import R
|
||||
|
||||
@@ -192,6 +193,7 @@ class RecorderCollector(Collector):
|
||||
if rec_filter_func is None or rec_filter_func(rec):
|
||||
recs_flt[rid] = rec
|
||||
|
||||
logger = get_module_logger("RecorderCollector")
|
||||
for _, rec in recs_flt.items():
|
||||
rec_key = self.rec_key_func(rec)
|
||||
for key in artifacts_key:
|
||||
@@ -205,7 +207,13 @@ class RecorderCollector(Collector):
|
||||
# only collect existing artifact
|
||||
continue
|
||||
raise e
|
||||
collect_dict.setdefault(key, {})[rec_key] = artifact
|
||||
# give user some warning if the values are overridden
|
||||
cdd = collect_dict.setdefault(key, {})
|
||||
if rec_key in cdd:
|
||||
logger.warning(
|
||||
f"key '{rec_key}' is duplicated. Previous value will be overrides. Please check you `rec_key_func`"
|
||||
)
|
||||
cdd[rec_key] = artifact
|
||||
|
||||
return collect_dict
|
||||
|
||||
|
||||
@@ -6,6 +6,8 @@ TaskGenerator module can generate many tasks based on TaskGen and some task temp
|
||||
import abc
|
||||
import copy
|
||||
from typing import List, Union, Callable
|
||||
|
||||
from qlib.utils import transform_end_date
|
||||
from .utils import TimeAdjuster
|
||||
|
||||
|
||||
@@ -199,7 +201,7 @@ class RollingGen(TaskGen):
|
||||
# First rolling
|
||||
# 1) prepare the end point
|
||||
segments: dict = copy.deepcopy(self.ta.align_seg(t["dataset"]["kwargs"]["segments"]))
|
||||
test_end = self.ta.max() if segments[self.test_key][1] is None else segments[self.test_key][1]
|
||||
test_end = transform_end_date(segments[self.test_key][1])
|
||||
# 2) and init test segments
|
||||
test_start_idx = self.ta.align_idx(segments[self.test_key][0])
|
||||
segments[self.test_key] = (self.ta.get(test_start_idx), self.ta.get(test_start_idx + self.step - 1))
|
||||
|
||||
@@ -69,28 +69,29 @@ class TaskManager:
|
||||
|
||||
ENCODE_FIELDS_PREFIX = ["def", "res"]
|
||||
|
||||
def __init__(self, task_pool: str = None):
|
||||
def __init__(self, task_pool: str):
|
||||
"""
|
||||
Init Task Manager, remember to make the statement of MongoDB url and database name firstly.
|
||||
A TaskManager instance serves a specific task pool.
|
||||
The static method of this module serves the whole MongoDB.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
task_pool: str
|
||||
the name of Collection in MongoDB
|
||||
"""
|
||||
self.mdb = get_mongodb()
|
||||
if task_pool is not None:
|
||||
self.task_pool = getattr(self.mdb, task_pool)
|
||||
self.task_pool = getattr(get_mongodb(), task_pool)
|
||||
self.logger = get_module_logger(self.__class__.__name__)
|
||||
|
||||
def list(self) -> list:
|
||||
@staticmethod
|
||||
def list() -> list:
|
||||
"""
|
||||
List the all collection(task_pool) of the db
|
||||
List the all collection(task_pool) of the db.
|
||||
|
||||
Returns:
|
||||
list
|
||||
"""
|
||||
return self.mdb.list_collection_names()
|
||||
return get_mongodb().list_collection_names()
|
||||
|
||||
def _encode_task(self, task):
|
||||
for prefix in self.ENCODE_FIELDS_PREFIX:
|
||||
@@ -109,6 +110,25 @@ class TaskManager:
|
||||
def _dict_to_str(self, flt):
|
||||
return {k: str(v) for k, v in flt.items()}
|
||||
|
||||
def _decode_query(self, query):
|
||||
"""
|
||||
If the query includes any `_id`, then it needs `ObjectId` to decode.
|
||||
For example, when using TrainerRM, it needs query `{"_id": {"$in": _id_list}}`. Then we need to `ObjectId` every `_id` in `_id_list`.
|
||||
|
||||
Args:
|
||||
query (dict): query dict. Defaults to {}.
|
||||
|
||||
Returns:
|
||||
dict: the query after decoding.
|
||||
"""
|
||||
if "_id" in query:
|
||||
if isinstance(query["_id"], dict):
|
||||
for key in query["_id"]:
|
||||
query["_id"][key] = [ObjectId(i) for i in query["_id"][key]]
|
||||
else:
|
||||
query["_id"] = ObjectId(query["_id"])
|
||||
return query
|
||||
|
||||
def replace_task(self, task, new_task):
|
||||
"""
|
||||
Use a new task to replace a old one
|
||||
@@ -224,8 +244,7 @@ class TaskManager:
|
||||
dict: a task(document in collection) after decoding
|
||||
"""
|
||||
query = query.copy()
|
||||
if "_id" in query:
|
||||
query["_id"] = ObjectId(query["_id"])
|
||||
query = self._decode_query(query)
|
||||
query.update({"status": status})
|
||||
task = self.task_pool.find_one_and_update(
|
||||
query, {"$set": {"status": self.STATUS_RUNNING}}, sort=[("priority", pymongo.DESCENDING)]
|
||||
@@ -253,10 +272,10 @@ class TaskManager:
|
||||
task = self.fetch_task(query=query, status=status)
|
||||
try:
|
||||
yield task
|
||||
except Exception:
|
||||
except (Exception, KeyboardInterrupt): # KeyboardInterrupt is not a subclass of Exception
|
||||
if task is not None:
|
||||
self.logger.info("Returning task before raising error")
|
||||
self.return_task(task)
|
||||
self.return_task(task, status=status) # return task as the original status
|
||||
self.logger.info("Task returned")
|
||||
raise
|
||||
|
||||
@@ -283,12 +302,11 @@ class TaskManager:
|
||||
dict: a task(document in collection) after decoding
|
||||
"""
|
||||
query = query.copy()
|
||||
if "_id" in query:
|
||||
query["_id"] = ObjectId(query["_id"])
|
||||
query = self._decode_query(query)
|
||||
for t in self.task_pool.find(query):
|
||||
yield self._decode_task(t)
|
||||
|
||||
def re_query(self, _id):
|
||||
def re_query(self, _id) -> dict:
|
||||
"""
|
||||
Use _id to query task.
|
||||
|
||||
@@ -339,8 +357,7 @@ class TaskManager:
|
||||
|
||||
"""
|
||||
query = query.copy()
|
||||
if "_id" in query:
|
||||
query["_id"] = ObjectId(query["_id"])
|
||||
query = self._decode_query(query)
|
||||
self.task_pool.delete_many(query)
|
||||
|
||||
def task_stat(self, query={}) -> dict:
|
||||
@@ -354,8 +371,7 @@ class TaskManager:
|
||||
dict
|
||||
"""
|
||||
query = query.copy()
|
||||
if "_id" in query:
|
||||
query["_id"] = ObjectId(query["_id"])
|
||||
query = self._decode_query(query)
|
||||
tasks = self.query(query=query, decode=False)
|
||||
status_stat = {}
|
||||
for t in tasks:
|
||||
@@ -377,8 +393,7 @@ class TaskManager:
|
||||
|
||||
def reset_status(self, query, status):
|
||||
query = query.copy()
|
||||
if "_id" in query:
|
||||
query["_id"] = ObjectId(query["_id"])
|
||||
query = self._decode_query(query)
|
||||
print(self.task_pool.update_many(query, {"$set": {"status": status}}))
|
||||
|
||||
def prioritize(self, task, priority: int):
|
||||
@@ -396,15 +411,29 @@ class TaskManager:
|
||||
self.task_pool.update_one({"_id": task["_id"]}, update_dict)
|
||||
|
||||
def _get_undone_n(self, task_stat):
|
||||
return task_stat.get(self.STATUS_WAITING, 0) + task_stat.get(self.STATUS_RUNNING, 0)
|
||||
return (
|
||||
task_stat.get(self.STATUS_WAITING, 0)
|
||||
+ task_stat.get(self.STATUS_RUNNING, 0)
|
||||
+ task_stat.get(self.STATUS_PART_DONE, 0)
|
||||
)
|
||||
|
||||
def _get_total(self, task_stat):
|
||||
return sum(task_stat.values())
|
||||
|
||||
def wait(self, query={}):
|
||||
"""
|
||||
When multiprocessing, the main progress may fetch nothing from TaskManager because there are still some running tasks.
|
||||
So main progress should wait until all tasks are trained well by other progress or machines.
|
||||
|
||||
Args:
|
||||
query (dict, optional): the query dict. Defaults to {}.
|
||||
"""
|
||||
task_stat = self.task_stat(query)
|
||||
total = self._get_total(task_stat)
|
||||
last_undone_n = self._get_undone_n(task_stat)
|
||||
if last_undone_n == 0:
|
||||
return
|
||||
self.logger.warning(f"Waiting for {last_undone_n} undone tasks. Please make sure they are running.")
|
||||
with tqdm(total=total, initial=total - last_undone_n) as pbar:
|
||||
while True:
|
||||
time.sleep(10)
|
||||
|
||||
@@ -17,7 +17,6 @@ def experiment_exit_handler():
|
||||
Thus, if any exception or user interuption occurs beforehead, we should handle them first. Once `R` is
|
||||
ended, another call of `R.end_exp` will not take effect.
|
||||
"""
|
||||
signal.signal(signal.SIGINT, experiment_kill_signal_handler) # handle user keyboard interupt
|
||||
sys.excepthook = experiment_exception_hook # handle uncaught exception
|
||||
atexit.register(R.end_exp, recorder_status=Recorder.STATUS_FI) # will not take effect if experiment ends
|
||||
|
||||
@@ -39,10 +38,3 @@ def experiment_exception_hook(type, value, tb):
|
||||
print(f"{type.__name__}: {value}")
|
||||
|
||||
R.end_exp(recorder_status=Recorder.STATUS_FA)
|
||||
|
||||
|
||||
def experiment_kill_signal_handler(signum, frame):
|
||||
"""
|
||||
End an experiment when user kill the program through keyboard (CTRL+C, etc.).
|
||||
"""
|
||||
R.end_exp(recorder_status=Recorder.STATUS_FA)
|
||||
|
||||
@@ -7,12 +7,13 @@ import time
|
||||
import datetime
|
||||
import importlib
|
||||
from pathlib import Path
|
||||
from typing import Type
|
||||
from typing import Type, Iterable
|
||||
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
|
||||
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
from loguru import logger
|
||||
from joblib import Parallel, delayed
|
||||
from qlib.utils import code_to_fname
|
||||
|
||||
|
||||
@@ -22,9 +23,9 @@ class BaseCollector(abc.ABC):
|
||||
NORMAL_FLAG = "NORMAL"
|
||||
|
||||
DEFAULT_START_DATETIME_1D = pd.Timestamp("2000-01-01")
|
||||
DEFAULT_START_DATETIME_1MIN = pd.Timestamp(datetime.datetime.now() - pd.Timedelta(days=5 * 6))
|
||||
DEFAULT_END_DATETIME_1D = pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))
|
||||
DEFAULT_END_DATETIME_1MIN = pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))
|
||||
DEFAULT_START_DATETIME_1MIN = pd.Timestamp(datetime.datetime.now() - pd.Timedelta(days=5 * 6 - 1)).date()
|
||||
DEFAULT_END_DATETIME_1D = pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1)).date()
|
||||
DEFAULT_END_DATETIME_1MIN = DEFAULT_END_DATETIME_1D
|
||||
|
||||
INTERVAL_1min = "1min"
|
||||
INTERVAL_1d = "1d"
|
||||
@@ -35,10 +36,10 @@ class BaseCollector(abc.ABC):
|
||||
start=None,
|
||||
end=None,
|
||||
interval="1d",
|
||||
max_workers=4,
|
||||
max_workers=1,
|
||||
max_collector_count=2,
|
||||
delay=0,
|
||||
check_data_length: bool = False,
|
||||
check_data_length: int = None,
|
||||
limit_nums: int = None,
|
||||
):
|
||||
"""
|
||||
@@ -48,7 +49,7 @@ class BaseCollector(abc.ABC):
|
||||
save_dir: str
|
||||
instrument save dir
|
||||
max_workers: int
|
||||
workers, default 4
|
||||
workers, default 1; Concurrent number, default is 1; when collecting data, it is recommended that max_workers be set to 1
|
||||
max_collector_count: int
|
||||
default 2
|
||||
delay: float
|
||||
@@ -59,8 +60,8 @@ class BaseCollector(abc.ABC):
|
||||
start datetime, default None
|
||||
end: str
|
||||
end datetime, default None
|
||||
check_data_length: bool
|
||||
check data length, by default False
|
||||
check_data_length: int
|
||||
check data length, if not None and greater than 0, each symbol will be considered complete if its data length is greater than or equal to this value, otherwise it will be fetched again, the maximum number of fetches being (max_collector_count). By default None.
|
||||
limit_nums: int
|
||||
using for debug, by default None
|
||||
"""
|
||||
@@ -72,7 +73,7 @@ class BaseCollector(abc.ABC):
|
||||
self.max_collector_count = max_collector_count
|
||||
self.mini_symbol_map = {}
|
||||
self.interval = interval
|
||||
self.check_small_data = check_data_length
|
||||
self.check_data_length = max(int(check_data_length) if check_data_length is not None else 0, 0)
|
||||
|
||||
self.start_datetime = self.normalize_start_datetime(start)
|
||||
self.end_datetime = self.normalize_end_datetime(end)
|
||||
@@ -99,14 +100,6 @@ class BaseCollector(abc.ABC):
|
||||
else getattr(self, f"DEFAULT_END_DATETIME_{self.interval.upper()}")
|
||||
)
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def min_numbers_trading(self):
|
||||
# daily, one year: 252 / 4
|
||||
# us 1min, a week: 6.5 * 60 * 5
|
||||
# cn 1min, a week: 4 * 60 * 5
|
||||
raise NotImplementedError("rewrite min_numbers_trading")
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_instrument_list(self):
|
||||
raise NotImplementedError("rewrite get_instrument_list")
|
||||
@@ -132,7 +125,7 @@ class BaseCollector(abc.ABC):
|
||||
|
||||
Returns
|
||||
---------
|
||||
pd.DataFrame, "symbol" in pd.columns
|
||||
pd.DataFrame, "symbol" and "date"in pd.columns
|
||||
|
||||
"""
|
||||
raise NotImplementedError("rewrite get_timezone")
|
||||
@@ -151,7 +144,7 @@ class BaseCollector(abc.ABC):
|
||||
self.sleep()
|
||||
df = self.get_data(symbol, self.interval, self.start_datetime, self.end_datetime)
|
||||
_result = self.NORMAL_FLAG
|
||||
if self.check_small_data:
|
||||
if self.check_data_length > 0:
|
||||
_result = self.cache_small_data(symbol, df)
|
||||
if _result == self.NORMAL_FLAG:
|
||||
self.save_instrument(symbol, df)
|
||||
@@ -181,8 +174,8 @@ class BaseCollector(abc.ABC):
|
||||
df.to_csv(instrument_path, index=False)
|
||||
|
||||
def cache_small_data(self, symbol, df):
|
||||
if len(df) <= self.min_numbers_trading:
|
||||
logger.warning(f"the number of trading days of {symbol} is less than {self.min_numbers_trading}!")
|
||||
if len(df) < self.check_data_length:
|
||||
logger.warning(f"the number of trading days of {symbol} is less than {self.check_data_length}!")
|
||||
_temp = self.mini_symbol_map.setdefault(symbol, [])
|
||||
_temp.append(df.copy())
|
||||
return self.CACHE_FLAG
|
||||
@@ -194,12 +187,12 @@ class BaseCollector(abc.ABC):
|
||||
def _collector(self, instrument_list):
|
||||
|
||||
error_symbol = []
|
||||
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
|
||||
with tqdm(total=len(instrument_list)) as p_bar:
|
||||
for _symbol, _result in zip(instrument_list, executor.map(self._simple_collector, instrument_list)):
|
||||
if _result != self.NORMAL_FLAG:
|
||||
error_symbol.append(_symbol)
|
||||
p_bar.update()
|
||||
res = Parallel(n_jobs=self.max_workers)(
|
||||
delayed(self._simple_collector)(_inst) for _inst in tqdm(instrument_list)
|
||||
)
|
||||
for _symbol, _result in zip(instrument_list, res):
|
||||
if _result != self.NORMAL_FLAG:
|
||||
error_symbol.append(_symbol)
|
||||
print(error_symbol)
|
||||
logger.info(f"error symbol nums: {len(error_symbol)}")
|
||||
logger.info(f"current get symbol nums: {len(instrument_list)}")
|
||||
@@ -217,20 +210,16 @@ class BaseCollector(abc.ABC):
|
||||
instrument_list = self._collector(instrument_list)
|
||||
logger.info(f"{i+1} finish.")
|
||||
for _symbol, _df_list in self.mini_symbol_map.items():
|
||||
self.save_instrument(
|
||||
_symbol, pd.concat(_df_list, sort=False).drop_duplicates(["date"]).sort_values(["date"])
|
||||
)
|
||||
_df = pd.concat(_df_list, sort=False)
|
||||
if not _df.empty:
|
||||
self.save_instrument(_symbol, _df.drop_duplicates(["date"]).sort_values(["date"]))
|
||||
if self.mini_symbol_map:
|
||||
logger.warning(f"less than {self.min_numbers_trading} instrument list: {list(self.mini_symbol_map.keys())}")
|
||||
logger.warning(f"less than {self.check_data_length} instrument list: {list(self.mini_symbol_map.keys())}")
|
||||
logger.info(f"total {len(self.instrument_list)}, error: {len(set(instrument_list))}")
|
||||
|
||||
|
||||
class BaseNormalize(abc.ABC):
|
||||
def __init__(
|
||||
self,
|
||||
date_field_name: str = "date",
|
||||
symbol_field_name: str = "symbol",
|
||||
):
|
||||
def __init__(self, date_field_name: str = "date", symbol_field_name: str = "symbol", **kwargs):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
@@ -242,7 +231,7 @@ class BaseNormalize(abc.ABC):
|
||||
"""
|
||||
self._date_field_name = date_field_name
|
||||
self._symbol_field_name = symbol_field_name
|
||||
|
||||
self.kwargs = kwargs
|
||||
self._calendar_list = self._get_calendar_list()
|
||||
|
||||
@abc.abstractmethod
|
||||
@@ -251,7 +240,7 @@ class BaseNormalize(abc.ABC):
|
||||
raise NotImplementedError("")
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_calendar_list(self):
|
||||
def _get_calendar_list(self) -> Iterable[pd.Timestamp]:
|
||||
"""Get benchmark calendar"""
|
||||
raise NotImplementedError("")
|
||||
|
||||
@@ -265,6 +254,7 @@ class Normalize:
|
||||
max_workers: int = 16,
|
||||
date_field_name: str = "date",
|
||||
symbol_field_name: str = "symbol",
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
|
||||
@@ -288,16 +278,23 @@ class Normalize:
|
||||
self._source_dir = Path(source_dir).expanduser()
|
||||
self._target_dir = Path(target_dir).expanduser()
|
||||
self._target_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self._date_field_name = date_field_name
|
||||
self._symbol_field_name = symbol_field_name
|
||||
self._end_date = kwargs.get("end_date", None)
|
||||
self._max_workers = max_workers
|
||||
|
||||
self._normalize_obj = normalize_class(date_field_name=date_field_name, symbol_field_name=symbol_field_name)
|
||||
self._normalize_obj = normalize_class(
|
||||
date_field_name=date_field_name, symbol_field_name=symbol_field_name, **kwargs
|
||||
)
|
||||
|
||||
def _executor(self, file_path: Path):
|
||||
file_path = Path(file_path)
|
||||
df = pd.read_csv(file_path)
|
||||
df = self._normalize_obj.normalize(df)
|
||||
if not df.empty:
|
||||
if df is not None and not df.empty:
|
||||
if self._end_date is not None:
|
||||
_mask = pd.to_datetime(df[self._date_field_name]) <= pd.Timestamp(self._end_date)
|
||||
df = df[_mask]
|
||||
df.to_csv(self._target_dir.joinpath(file_path.name), index=False)
|
||||
|
||||
def normalize(self):
|
||||
@@ -311,7 +308,7 @@ class Normalize:
|
||||
|
||||
|
||||
class BaseRun(abc.ABC):
|
||||
def __init__(self, source_dir=None, normalize_dir=None, max_workers=4, interval="1d"):
|
||||
def __init__(self, source_dir=None, normalize_dir=None, max_workers=1, interval="1d"):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
@@ -321,7 +318,7 @@ class BaseRun(abc.ABC):
|
||||
normalize_dir: str
|
||||
Directory for normalize data, default "Path(__file__).parent/normalize"
|
||||
max_workers: int
|
||||
Concurrent number, default is 4
|
||||
Concurrent number, default is 1; Concurrent number, default is 1; when collecting data, it is recommended that max_workers be set to 1
|
||||
interval: str
|
||||
freq, value from [1min, 1d], default 1d
|
||||
"""
|
||||
@@ -361,7 +358,7 @@ class BaseRun(abc.ABC):
|
||||
start=None,
|
||||
end=None,
|
||||
interval="1d",
|
||||
check_data_length=False,
|
||||
check_data_length: int = None,
|
||||
limit_nums=None,
|
||||
):
|
||||
"""download data from Internet
|
||||
@@ -378,8 +375,8 @@ class BaseRun(abc.ABC):
|
||||
start datetime, default "2000-01-01"
|
||||
end: str
|
||||
end datetime, default ``pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))``
|
||||
check_data_length: bool
|
||||
check data length, by default False
|
||||
check_data_length: int
|
||||
check data length, if not None and greater than 0, each symbol will be considered complete if its data length is greater than or equal to this value, otherwise it will be fetched again, the maximum number of fetches being (max_collector_count). By default None.
|
||||
limit_nums: int
|
||||
using for debug, by default None
|
||||
|
||||
@@ -404,7 +401,7 @@ class BaseRun(abc.ABC):
|
||||
limit_nums=limit_nums,
|
||||
).collector_data()
|
||||
|
||||
def normalize_data(self, date_field_name: str = "date", symbol_field_name: str = "symbol"):
|
||||
def normalize_data(self, date_field_name: str = "date", symbol_field_name: str = "symbol", **kwargs):
|
||||
"""normalize data
|
||||
|
||||
Parameters
|
||||
@@ -426,5 +423,6 @@ class BaseRun(abc.ABC):
|
||||
max_workers=self.max_workers,
|
||||
date_field_name=date_field_name,
|
||||
symbol_field_name=symbol_field_name,
|
||||
**kwargs,
|
||||
)
|
||||
yc.normalize()
|
||||
|
||||
@@ -19,12 +19,31 @@ CUR_DIR = Path(__file__).resolve().parent
|
||||
sys.path.append(str(CUR_DIR.parent.parent))
|
||||
|
||||
from data_collector.index import IndexBase
|
||||
from data_collector.utils import get_calendar_list, get_trading_date_by_shift
|
||||
from data_collector.utils import get_calendar_list, get_trading_date_by_shift, deco_retry
|
||||
|
||||
|
||||
NEW_COMPANIES_URL = "http://www.csindex.com.cn/uploads/file/autofile/cons/{index_code}cons.xls"
|
||||
|
||||
INDEX_CHANGES_URL = "http://www.csindex.com.cn/zh-CN/search/total?key=%E5%85%B3%E4%BA%8E%E8%B0%83%E6%95%B4%E6%B2%AA%E6%B7%B1300%E5%92%8C%E4%B8%AD%E8%AF%81%E9%A6%99%E6%B8%AF100%E7%AD%89%E6%8C%87%E6%95%B0%E6%A0%B7%E6%9C%AC%E8%82%A1%E7%9A%84%E5%85%AC%E5%91%8A"
|
||||
|
||||
# INDEX_CHANGES_URL = "http://www.csindex.com.cn/zh-CN/search/total?key=%E5%85%B3%E4%BA%8E%E8%B0%83%E6%95%B4%E6%B2%AA%E6%B7%B1300%E5%92%8C%E4%B8%AD%E8%AF%81%E9%A6%99%E6%B8%AF100%E7%AD%89%E6%8C%87%E6%95%B0%E6%A0%B7%E6%9C%AC%E8%82%A1%E7%9A%84%E5%85%AC%E5%91%8A"
|
||||
# 2020-11-27 Announcement title change
|
||||
INDEX_CHANGES_URL = "http://www.csindex.com.cn/zh-CN/search/total?key=%E5%85%B3%E4%BA%8E%E8%B0%83%E6%95%B4%E6%B2%AA%E6%B7%B1300%E5%92%8C%E4%B8%AD%E8%AF%81%E9%A6%99%E6%B8%AF100%E7%AD%89"
|
||||
|
||||
REQ_HEADERS = {
|
||||
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.101 Safari/537.36 Edg/91.0.864.48"
|
||||
}
|
||||
|
||||
|
||||
@deco_retry
|
||||
def retry_request(url: str, method: str = "get", exclude_status: List = None):
|
||||
if exclude_status is None:
|
||||
exclude_status = []
|
||||
method_func = getattr(requests, method)
|
||||
_resp = method_func(url, headers=REQ_HEADERS)
|
||||
_status = _resp.status_code
|
||||
if _status not in exclude_status and _status != 200:
|
||||
raise ValueError(f"response status: {_status}, url={url}")
|
||||
return _resp
|
||||
|
||||
|
||||
class CSIIndex(IndexBase):
|
||||
@@ -134,9 +153,8 @@ class CSIIndex(IndexBase):
|
||||
date: pd.Timestamp
|
||||
type: str, value from ["add", "remove"]
|
||||
"""
|
||||
resp = requests.get(url)
|
||||
resp = retry_request(url)
|
||||
_text = resp.text
|
||||
|
||||
date_list = re.findall(r"(\d{4}).*?年.*?(\d+).*?月.*?(\d+).*?日", _text)
|
||||
if len(date_list) >= 2:
|
||||
add_date = pd.Timestamp("-".join(date_list[0]))
|
||||
@@ -147,7 +165,7 @@ class CSIIndex(IndexBase):
|
||||
logger.info(f"get {add_date} changes")
|
||||
try:
|
||||
excel_url = re.findall('.*href="(.*?xls.*?)".*', _text)[0]
|
||||
content = requests.get(f"http://www.csindex.com.cn{excel_url}").content
|
||||
content = retry_request(f"http://www.csindex.com.cn{excel_url}", exclude_status=[404]).content
|
||||
_io = BytesIO(content)
|
||||
df_map = pd.read_excel(_io, sheet_name=None)
|
||||
with self.cache_dir.joinpath(
|
||||
@@ -201,7 +219,7 @@ class CSIIndex(IndexBase):
|
||||
-------
|
||||
[url1, url2]
|
||||
"""
|
||||
resp = requests.get(self.changes_url)
|
||||
resp = retry_request(self.changes_url)
|
||||
html = etree.HTML(resp.text)
|
||||
return html.xpath("//*[@id='itemContainer']//li/a/@href")
|
||||
|
||||
@@ -221,7 +239,7 @@ class CSIIndex(IndexBase):
|
||||
end_date: pd.Timestamp
|
||||
"""
|
||||
logger.info("get new companies......")
|
||||
context = requests.get(self.new_companies_url).content
|
||||
context = retry_request(self.new_companies_url).content
|
||||
with self.cache_dir.joinpath(
|
||||
f"{self.index_name.lower()}_new_companies.{self.new_companies_url.split('.')[-1]}"
|
||||
).open("wb") as fp:
|
||||
@@ -292,7 +310,7 @@ def get_instruments(
|
||||
$ python collector.py --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data --method save_new_companies
|
||||
|
||||
"""
|
||||
_cur_module = importlib.import_module("collector")
|
||||
_cur_module = importlib.import_module("data_collector.cn_index.collector")
|
||||
obj = getattr(_cur_module, f"{index_name.upper()}")(
|
||||
qlib_dir=qlib_dir, index_name=index_name, request_retry=request_retry, retry_sleep=retry_sleep
|
||||
)
|
||||
|
||||
23
scripts/data_collector/contrib/fill_cn_1min_data/README.md
Normal file
23
scripts/data_collector/contrib/fill_cn_1min_data/README.md
Normal file
@@ -0,0 +1,23 @@
|
||||
# Use 1d data to fill in the missing symbols relative to 1min
|
||||
|
||||
|
||||
## Requirements
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
## fill 1min data
|
||||
|
||||
```bash
|
||||
python fill_1min_using_1d.py --data_1min_dir ~/.qlib/csv_data/cn_data_1min --qlib_data_1d_dir ~/.qlib/qlib_data/cn_data
|
||||
```
|
||||
|
||||
## Parameters
|
||||
|
||||
- ata_1min_dir: csv data
|
||||
- qlib_data_1d_dir: qlib data directory
|
||||
- max_workers: `ThreadPoolExecutor(max_workers=max_workers)`, by default *16*
|
||||
- date_field_name: date field name, by default *date*
|
||||
- symbol_field_name: symbol field name, by default *symbol*
|
||||
|
||||
@@ -0,0 +1,100 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
import fire
|
||||
import qlib
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
from qlib.data import D
|
||||
from loguru import logger
|
||||
|
||||
CUR_DIR = Path(__file__).resolve().parent
|
||||
sys.path.append(str(CUR_DIR.parent.parent.parent))
|
||||
from data_collector.utils import generate_minutes_calendar_from_daily
|
||||
|
||||
|
||||
def get_date_range(data_1min_dir: Path, max_workers: int = 16, date_field_name: str = "date"):
|
||||
csv_files = list(data_1min_dir.glob("*.csv"))
|
||||
min_date = None
|
||||
max_date = None
|
||||
with tqdm(total=len(csv_files)) as p_bar:
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
for _file, _result in zip(csv_files, executor.map(pd.read_csv, csv_files)):
|
||||
if not _result.empty:
|
||||
_dates = pd.to_datetime(_result[date_field_name])
|
||||
|
||||
_tmp_min = _dates.min()
|
||||
min_date = min(min_date, _tmp_min) if min_date is not None else _tmp_min
|
||||
_tmp_max = _dates.max()
|
||||
max_date = max(max_date, _tmp_max) if max_date is not None else _tmp_max
|
||||
p_bar.update()
|
||||
return min_date, max_date
|
||||
|
||||
|
||||
def get_symbols(data_1min_dir: Path):
|
||||
return list(map(lambda x: x.name[:-4].upper(), data_1min_dir.glob("*.csv")))
|
||||
|
||||
|
||||
def fill_1min_using_1d(
|
||||
data_1min_dir: [str, Path],
|
||||
qlib_data_1d_dir: [str, Path],
|
||||
max_workers: int = 16,
|
||||
date_field_name: str = "date",
|
||||
symbol_field_name: str = "symbol",
|
||||
):
|
||||
"""Use 1d data to fill in the missing symbols relative to 1min
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data_1min_dir: str
|
||||
1min data dir
|
||||
qlib_data_1d_dir: str
|
||||
1d qlib data(bin data) dir, from: https://qlib.readthedocs.io/en/latest/component/data.html#converting-csv-format-into-qlib-format
|
||||
max_workers: int
|
||||
ThreadPoolExecutor(max_workers), by default 16
|
||||
date_field_name: str
|
||||
date field name, by default date
|
||||
symbol_field_name: str
|
||||
symbol field name, by default symbol
|
||||
|
||||
"""
|
||||
data_1min_dir = Path(data_1min_dir).expanduser().resolve()
|
||||
qlib_data_1d_dir = Path(qlib_data_1d_dir).expanduser().resolve()
|
||||
|
||||
min_date, max_date = get_date_range(data_1min_dir, max_workers, date_field_name)
|
||||
symbols_1min = get_symbols(data_1min_dir)
|
||||
|
||||
qlib.init(provider_uri=str(qlib_data_1d_dir))
|
||||
data_1d = D.features(D.instruments("all"), ["$close"], min_date, max_date, freq="day")
|
||||
|
||||
miss_symbols = set(data_1d.index.get_level_values(level="instrument").unique()) - set(symbols_1min)
|
||||
if not miss_symbols:
|
||||
logger.warning("More symbols in 1min than 1d, no padding required")
|
||||
return
|
||||
|
||||
logger.info(f"miss_symbols {len(miss_symbols)}: {miss_symbols}")
|
||||
tmp_df = pd.read_csv(list(data_1min_dir.glob("*.csv"))[0])
|
||||
columns = tmp_df.columns
|
||||
_si = tmp_df[symbol_field_name].first_valid_index()
|
||||
is_lower = tmp_df.loc[_si][symbol_field_name].islower()
|
||||
for symbol in tqdm(miss_symbols):
|
||||
if is_lower:
|
||||
symbol = symbol.lower()
|
||||
index_1d = data_1d.loc(axis=0)[symbol.upper()].index
|
||||
index_1min = generate_minutes_calendar_from_daily(index_1d)
|
||||
index_1min.name = date_field_name
|
||||
_df = pd.DataFrame(columns=columns, index=index_1min)
|
||||
if date_field_name in _df.columns:
|
||||
del _df[date_field_name]
|
||||
_df.reset_index(inplace=True)
|
||||
_df[symbol_field_name] = symbol
|
||||
_df["paused_num"] = 0
|
||||
_df.to_csv(data_1min_dir.joinpath(f"{symbol}.csv"), index=False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(fill_1min_using_1d)
|
||||
@@ -0,0 +1,5 @@
|
||||
fire
|
||||
pandas
|
||||
loguru
|
||||
tqdm
|
||||
pyqlib
|
||||
@@ -14,7 +14,7 @@ from loguru import logger
|
||||
import baostock as bs
|
||||
|
||||
CUR_DIR = Path(__file__).resolve().parent
|
||||
sys.path.append(str(CUR_DIR.parent.parent))
|
||||
sys.path.append(str(CUR_DIR.parent.parent.parent))
|
||||
|
||||
|
||||
from data_collector.utils import generate_minutes_calendar_from_daily
|
||||
@@ -3,18 +3,13 @@
|
||||
|
||||
import abc
|
||||
import sys
|
||||
import copy
|
||||
import time
|
||||
import datetime
|
||||
import importlib
|
||||
import json
|
||||
from abc import ABC
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Type
|
||||
|
||||
import fire
|
||||
import requests
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from loguru import logger
|
||||
from dateutil.tz import tzlocal
|
||||
@@ -38,7 +33,7 @@ class FundCollector(BaseCollector):
|
||||
max_workers=4,
|
||||
max_collector_count=2,
|
||||
delay=0,
|
||||
check_data_length: bool = False,
|
||||
check_data_length: int = None,
|
||||
limit_nums: int = None,
|
||||
):
|
||||
"""
|
||||
@@ -59,8 +54,8 @@ class FundCollector(BaseCollector):
|
||||
start datetime, default None
|
||||
end: str
|
||||
end datetime, default None
|
||||
check_data_length: bool
|
||||
check data length, by default False
|
||||
check_data_length: int
|
||||
check data length, if not None and greater than 0, each symbol will be considered complete if its data length is greater than or equal to this value, otherwise it will be fetched again, the maximum number of fetches being (max_collector_count). By default None.
|
||||
limit_nums: int
|
||||
using for debug, by default None
|
||||
"""
|
||||
@@ -168,9 +163,7 @@ class FundollectorCN(FundCollector, ABC):
|
||||
|
||||
|
||||
class FundCollectorCN1d(FundollectorCN):
|
||||
@property
|
||||
def min_numbers_trading(self):
|
||||
return 252 / 4
|
||||
pass
|
||||
|
||||
|
||||
class FundNormalize(BaseNormalize):
|
||||
@@ -261,7 +254,7 @@ class Run(BaseRun):
|
||||
start=None,
|
||||
end=None,
|
||||
interval="1d",
|
||||
check_data_length=False,
|
||||
check_data_length: int = None,
|
||||
limit_nums=None,
|
||||
):
|
||||
"""download data from Internet
|
||||
@@ -278,8 +271,8 @@ class Run(BaseRun):
|
||||
start datetime, default "2000-01-01"
|
||||
end: str
|
||||
end datetime, default ``pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))``
|
||||
check_data_length: bool # if this param useful?
|
||||
check data length, by default False
|
||||
check_data_length: int # if this param useful?
|
||||
check data length, if not None and greater than 0, each symbol will be considered complete if its data length is greater than or equal to this value, otherwise it will be fetched again, the maximum number of fetches being (max_collector_count). By default None.
|
||||
limit_nums: int
|
||||
using for debug, by default None
|
||||
|
||||
|
||||
@@ -271,7 +271,7 @@ def get_instruments(
|
||||
$ python collector.py --index_name SP500 --qlib_dir ~/.qlib/qlib_data/cn_data --method save_new_companies
|
||||
|
||||
"""
|
||||
_cur_module = importlib.import_module("collector")
|
||||
_cur_module = importlib.import_module("data_collector.us_index.collector")
|
||||
obj = getattr(_cur_module, f"{index_name.upper()}Index")(
|
||||
qlib_dir=qlib_dir, index_name=index_name, request_retry=request_retry, retry_sleep=retry_sleep
|
||||
)
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import re
|
||||
import os
|
||||
import time
|
||||
import bisect
|
||||
import pickle
|
||||
@@ -10,7 +9,7 @@ import random
|
||||
import requests
|
||||
import functools
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Tuple
|
||||
from typing import Iterable, Tuple, List
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
@@ -47,7 +46,7 @@ _CALENDAR_MAP = {}
|
||||
MINIMUM_SYMBOLS_NUM = 3900
|
||||
|
||||
|
||||
def get_calendar_list(bench_code="CSI300") -> list:
|
||||
def get_calendar_list(bench_code="CSI300") -> List[pd.Timestamp]:
|
||||
"""get SH/SZ history calendar list
|
||||
|
||||
Parameters
|
||||
|
||||
@@ -1,3 +1,11 @@
|
||||
|
||||
- [Collector Data](#collector-data)
|
||||
- [Get Qlib data](#get-qlib-databin-file)
|
||||
- [Collector *YahooFinance* data to qlib](#collector-yahoofinance-data-to-qlib)
|
||||
- [Automatic update of daily frequency data](#automatic-update-of-daily-frequency-datafrom-yahoo-finance)
|
||||
- [Using qlib data](#using-qlib-data)
|
||||
|
||||
|
||||
# Collect Data From Yahoo Finance
|
||||
|
||||
> *Please pay **ATTENTION** that the data is collected from [Yahoo Finance](https://finance.yahoo.com/lookup) and the data might not be perfect. We recommend users to prepare their own data if they have high-quality dataset. For more information, users can refer to the [related document](https://qlib.readthedocs.io/en/latest/component/data.html#converting-csv-format-into-qlib-format)*
|
||||
@@ -18,113 +26,170 @@ pip install -r requirements.txt
|
||||
|
||||
## Collector Data
|
||||
|
||||
### Get Qlib data(`bin file`)
|
||||
> `qlib-data` from *YahooFinance*, is the data that has been dumped and can be used directly in `qlib`
|
||||
|
||||
### CN Data
|
||||
- get data: `python scripts/get_data.py qlib_data`
|
||||
- parameters:
|
||||
- `target_dir`: save dir, by default *~/.qlib/qlib_data/cn_data*
|
||||
- `version`: dataset version, value from [`v1`, `v2`], by default `v1`
|
||||
- `v2` end date is *2021-06*, `v1` end date is *2020-09*
|
||||
- user can append data to `v2`: [automatic update of daily frequency data](#automatic-update-of-daily-frequency-datafrom-yahoo-finance)
|
||||
- **the [benchmarks](https://github.com/microsoft/qlib/tree/main/examples/benchmarks) for qlib use `v1`**, *due to the unstable access to historical data by YahooFinance, there are some differences between `v2` and `v1`*
|
||||
- `interval`: `1d` or `1min`, by default `1d`
|
||||
- `region`: `cn` or `us`, by default `cn`
|
||||
- `delete_old`: delete existing data from `target_dir`(*features, calendars, instruments, dataset_cache, features_cache*), value from [`True`, `False`], by default `True`
|
||||
- `exists_skip`: traget_dir data already exists, skip `get_data`, value from [`True`, `False`], by default `False`
|
||||
- examples:
|
||||
```bash
|
||||
# cn 1d
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/qlib_cn_1d --region cn
|
||||
# cn 1min
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/qlib_cn_1min --region cn --interval 1min
|
||||
# us 1d
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/qlib_us_1d --region us --interval 1d
|
||||
# us 1min
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/qlib_us_1min --region us --interval 1min
|
||||
```
|
||||
|
||||
#### 1d from yahoo
|
||||
### Collector *YahooFinance* data to qlib
|
||||
> collector *YahooFinance* data and *dump* into `qlib` format
|
||||
1. download data to csv: `python scripts/data_collector/yahoo/collector.py download_data`
|
||||
|
||||
```bash
|
||||
- parameters:
|
||||
- `source_dir`: save the directory
|
||||
- `interval`: `1d` or `1min`, by default `1d`
|
||||
> **due to the limitation of the *YahooFinance API*, only the last month's data is available in `1min`**
|
||||
- `region`: `CN` or `US`, by default `CN`
|
||||
- `delay`: `time.sleep(delay)`, by default *0.5*
|
||||
- `start`: start datetime, by default *"2000-01-01"*; *closed interval(including start)*
|
||||
- `end`: end datetime, by default `pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))`; *open interval(excluding end)*
|
||||
- `max_workers`: get the number of concurrent symbols, it is not recommended to change this parameter in order to maintain the integrity of the symbol data, by default *1*
|
||||
- `check_data_length`: check the number of rows per *symbol*, by default `None`
|
||||
> if `len(symbol_df) < check_data_length`, it will be re-fetched, with the number of re-fetches coming from the `max_collector_count` parameter
|
||||
- `max_collector_count`: number of *"failed"* symbol retries, by default 2
|
||||
- examples:
|
||||
```bash
|
||||
# cn 1d data
|
||||
python collector.py download_data --source_dir ~/.qlib/stock_data/source/cn_1d --start 2020-01-01 --end 2020-12-31 --delay 1 --interval 1d --region US
|
||||
# cn 1min data
|
||||
python collector.py download_data --source_dir ~/.qlib/stock_data/source/cn_1min --delay 1 --interval 1min --region CN
|
||||
# us 1d data
|
||||
python collector.py download_data --source_dir ~/.qlib/stock_data/source/us_1d --start 2020-01-01 --end 2020-12-31 --delay 1 --interval 1d --region US
|
||||
# us 1min data
|
||||
python collector.py download_data --source_dir ~/.qlib/stock_data/source/us_1min --delay 1 --interval 1min --region US
|
||||
```
|
||||
2. normalize data: `python scripts/data_collector/yahoo/collector.py normalize_data`
|
||||
|
||||
- parameters:
|
||||
- `source_dir`: csv directory
|
||||
- `normalize_dir`: result directory
|
||||
- `max_workers`: number of concurrent, by default *1*
|
||||
- `interval`: `1d` or `1min`, by default `1d`
|
||||
> if **`interval == 1min`**, `qlib_data_1d_dir` cannot be `None`
|
||||
- `region`: `CN` or `US`, by default `CN`
|
||||
- `date_field_name`: column *name* identifying time in csv files, by default `date`
|
||||
- `symbol_field_name`: column *name* identifying symbol in csv files, by default `symbol`
|
||||
- `end_date`: if not `None`, normalize the last date saved (*including end_date*); if `None`, it will ignore this parameter; by default `None`
|
||||
- `qlib_data_1d_dir`: qlib directory(1d data)
|
||||
```
|
||||
if interval==1min, qlib_data_1d_dir cannot be None, normalize 1min needs to use 1d data;
|
||||
|
||||
qlib_data_1d can be obtained like this:
|
||||
$ python scripts/get_data.py qlib_data --target_dir <qlib_data_1d_dir> --interval 1d
|
||||
$ python scripts/data_collector/yahoo/collector.py update_data_to_bin --qlib_data_1d_dir <qlib_data_1d_dir> --trading_date 2021-06-01
|
||||
or:
|
||||
download 1d data from YahooFinance
|
||||
|
||||
```
|
||||
- examples:
|
||||
```bash
|
||||
# normalize 1d cn
|
||||
python collector.py normalize_data --source_dir ~/.qlib/stock_data/source/cn_1d --normalize_dir ~/.qlib/stock_data/source/cn_1d_nor --region CN --interval 1d
|
||||
# normalize 1min cn
|
||||
python collector.py normalize_data --qlib_data_1d_dir ~/.qlib/qlib_data/qlib_cn_1d --source_dir ~/.qlib/stock_data/source/cn_1min --normalize_dir ~/.qlib/stock_data/source/cn_1min_nor --region CN --interval 1min
|
||||
```
|
||||
3. dump data: `python scripts/dump_bin.py dump_all`
|
||||
|
||||
- parameters:
|
||||
- `csv_path`: stock data path or directory, **normalize result(normalize_dir)**
|
||||
- `qlib_dir`: qlib(dump) data director
|
||||
- `freq`: transaction frequency, by default `day`
|
||||
> `freq_map = {1d:day, 1mih: 1min}`
|
||||
- `max_workers`: number of threads, by default *16*
|
||||
- `include_fields`: dump fields, by default `""`
|
||||
- `exclude_fields`: fields not dumped, by default `"""
|
||||
> dump_fields = `include_fields if include_fields else set(symbol_df.columns) - set(exclude_fields) exclude_fields else symbol_df.columns`
|
||||
- `symbol_field_name`: column *name* identifying symbol in csv files, by default `symbol`
|
||||
- `date_field_name`: column *name* identifying time in csv files, by default `date`
|
||||
- examples:
|
||||
```bash
|
||||
# dump 1d cn
|
||||
python dump_bin.py dump_all --csv_path ~/.qlib/stock_data/source/cn_1d_nor --qlib_dir ~/.qlib/qlib_data/qlib_cn_1d --freq day --exclude_fields date,symbol
|
||||
# dump 1min cn
|
||||
python dump_bin.py dump_all --csv_path ~/.qlib/stock_data/source/cn_1min_nor --qlib_dir ~/.qlib/qlib_data/qlib_cn_1min --freq 1min --exclude_fields date,symbol
|
||||
```
|
||||
|
||||
# download from yahoo finance
|
||||
python collector.py download_data --source_dir ~/.qlib/stock_data/source/cn_1d --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d
|
||||
### Automatic update of daily frequency data(from yahoo finance)
|
||||
> It is recommended that users update the data manually once (--trading_date 2021-05-25) and then set it to update automatically.
|
||||
|
||||
# normalize
|
||||
python collector.py normalize_data --source_dir ~/.qlib/stock_data/source/cn_1d --normalize_dir ~/.qlib/stock_data/source/cn_1d_nor --region CN --interval 1d
|
||||
* Automatic update of data to the "qlib" directory each trading day(Linux)
|
||||
* use *crontab*: `crontab -e`
|
||||
* set up timed tasks:
|
||||
|
||||
# dump data
|
||||
cd qlib/scripts
|
||||
python dump_bin.py dump_all --csv_path ~/.qlib/stock_data/source/cn_1d_nor --qlib_dir ~/.qlib/qlib_data/qlib_cn_1d --freq day --exclude_fields date,adjclose,dividends,splits,symbol
|
||||
```
|
||||
* * * * 1-5 python <script path> update_data_to_bin --qlib_data_1d_dir <user data dir>
|
||||
```
|
||||
* **script path**: *scripts/data_collector/yahoo/collector.py*
|
||||
|
||||
```
|
||||
* Manual update of data
|
||||
```
|
||||
python scripts/data_collector/yahoo/collector.py update_data_to_bin --qlib_data_1d_dir <user data dir> --trading_date <start date> --end_date <end date>
|
||||
```
|
||||
* `trading_date`: start of trading day
|
||||
* `end_date`: end of trading day(not included)
|
||||
* `check_data_length`: check the number of rows per *symbol*, by default `None`
|
||||
> if `len(symbol_df) < check_data_length`, it will be re-fetched, with the number of re-fetches coming from the `max_collector_count` parameter
|
||||
|
||||
### 1d from qlib
|
||||
```bash
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/qlib_cn_1d --region cn
|
||||
```
|
||||
|
||||
### using data
|
||||
|
||||
```python
|
||||
import qlib
|
||||
from qlib.data import D
|
||||
|
||||
qlib.init(provider_uri="~/.qlib/qlib_data/qlib_cn_1d", region="cn")
|
||||
df = D.features(D.instruments("all"), ["$close"], freq="day")
|
||||
```
|
||||
|
||||
#### 1min from yahoo
|
||||
|
||||
```bash
|
||||
|
||||
# download from yahoo finance
|
||||
python collector.py download_data --source_dir ~/.qlib/stock_data/source/cn_1min --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1min
|
||||
|
||||
# normalize
|
||||
python collector.py normalize_data --source_dir ~/.qlib/stock_data/source/cn_1min --normalize_dir ~/.qlib/stock_data/source/cn_1min_nor --region CN --interval 1min
|
||||
|
||||
# dump data
|
||||
cd qlib/scripts
|
||||
python dump_bin.py dump_all --csv_path ~/.qlib/stock_data/source/cn_1min_nor --qlib_dir ~/.qlib/qlib_data/qlib_cn_1min --freq 1min --exclude_fields date,adjclose,dividends,splits,symbol
|
||||
```
|
||||
|
||||
### 1min from qlib
|
||||
```bash
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/qlib_cn_1min --interval 1min --region cn
|
||||
```
|
||||
|
||||
### using data
|
||||
|
||||
```python
|
||||
import qlib
|
||||
from qlib.data import D
|
||||
|
||||
qlib.init(provider_uri="~/.qlib/qlib_data/qlib_cn_1min", region="cn")
|
||||
df = D.features(D.instruments("all"), ["$close"], freq="1min")
|
||||
|
||||
```
|
||||
|
||||
### US Data
|
||||
|
||||
#### 1d from yahoo
|
||||
|
||||
```bash
|
||||
|
||||
# download from yahoo finance
|
||||
python collector.py download_data --source_dir ~/.qlib/stock_data/source/us_1d --region US --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d
|
||||
|
||||
# normalize
|
||||
python collector.py normalize_data --source_dir ~/.qlib/stock_data/source/us_1d --normalize_dir ~/.qlib/stock_data/source/us_1d_nor --region US --interval 1d
|
||||
|
||||
# dump data
|
||||
cd qlib/scripts
|
||||
python dump_bin.py dump_all --csv_path ~/.qlib/stock_data/source/us_1d_nor --qlib_dir ~/.qlib/stock_data/source/qlib_us_1d --freq day --exclude_fields date,adjclose,dividends,splits,symbol
|
||||
```
|
||||
|
||||
#### 1d from qlib
|
||||
|
||||
```bash
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/qlib_us_1d --region us
|
||||
```
|
||||
|
||||
### using data
|
||||
|
||||
```python
|
||||
# using
|
||||
import qlib
|
||||
from qlib.data import D
|
||||
|
||||
qlib.init(provider_uri="~/.qlib/qlib_data/qlib_us_1d", region="us")
|
||||
df = D.features(D.instruments("all"), ["$close"], freq="day")
|
||||
|
||||
```
|
||||
* `scripts/data_collector/yahoo/collector.py update_data_to_bin` parameters:
|
||||
* `source_dir`: The directory where the raw data collected from the Internet is saved, default "Path(__file__).parent/source"
|
||||
* `normalize_dir`: Directory for normalize data, default "Path(__file__).parent/normalize"
|
||||
* `qlib_data_1d_dir`: the qlib data to be updated for yahoo, usually from: [download qlib data](https://github.com/microsoft/qlib/tree/main/scripts#download-cn-data)
|
||||
* `trading_date`: trading days to be updated, by default ``datetime.datetime.now().strftime("%Y-%m-%d")``
|
||||
* `end_date`: end datetime, default ``pd.Timestamp(trading_date + pd.Timedelta(days=1))``; open interval(excluding end)
|
||||
* `region`: region, value from ["CN", "US"], default "CN"
|
||||
|
||||
|
||||
### Help
|
||||
```bash
|
||||
python collector.py collector_data --help
|
||||
```
|
||||
## Using qlib data
|
||||
|
||||
## Parameters
|
||||
```python
|
||||
import qlib
|
||||
from qlib.data import D
|
||||
|
||||
# 1d data cn
|
||||
# freq=day, freq default day
|
||||
qlib.init(provider_uri="~/.qlib/qlib_data/qlib_cn_1d", region="cn")
|
||||
df = D.features(D.instruments("all"), ["$close"], freq="day")
|
||||
|
||||
# 1min data cn
|
||||
# freq=1min
|
||||
qlib.init(provider_uri="~/.qlib/qlib_data/qlib_cn_1min", region="cn")
|
||||
inst = D.list_instruments(D.instruments("all"), freq="1min", as_list=True)
|
||||
# get 100 symbols
|
||||
df = D.features(inst[:100], ["$close"], freq="1min")
|
||||
# get all symbol data
|
||||
# df = D.features(D.instruments("all"), ["$close"], freq="1min")
|
||||
|
||||
# 1d data us
|
||||
qlib.init(provider_uri="~/.qlib/qlib_data/qlib_us_1d", region="us")
|
||||
df = D.features(D.instruments("all"), ["$close"], freq="day")
|
||||
|
||||
# 1min data us
|
||||
qlib.init(provider_uri="~/.qlib/qlib_data/qlib_us_1min", region="cn")
|
||||
inst = D.list_instruments(D.instruments("all"), freq="1min", as_list=True)
|
||||
# get 100 symbols
|
||||
df = D.features(inst[:100], ["$close"], freq="1min")
|
||||
# get all symbol data
|
||||
# df = D.features(D.instruments("all"), ["$close"], freq="1min")
|
||||
```
|
||||
|
||||
- interval: 1min or 1d
|
||||
- region: CN or US
|
||||
|
||||
@@ -8,8 +8,9 @@ import time
|
||||
import datetime
|
||||
import importlib
|
||||
from abc import ABC
|
||||
import multiprocessing
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Type
|
||||
from typing import Iterable
|
||||
|
||||
import fire
|
||||
import requests
|
||||
@@ -18,13 +19,18 @@ import pandas as pd
|
||||
from loguru import logger
|
||||
from yahooquery import Ticker
|
||||
from dateutil.tz import tzlocal
|
||||
from qlib.utils import code_to_fname, fname_to_code
|
||||
|
||||
from qlib.tests.data import GetData
|
||||
from qlib.utils import code_to_fname, fname_to_code, exists_qlib_data
|
||||
from qlib.config import REG_CN as REGION_CN
|
||||
|
||||
CUR_DIR = Path(__file__).resolve().parent
|
||||
sys.path.append(str(CUR_DIR.parent.parent))
|
||||
from data_collector.base import BaseCollector, BaseNormalize, BaseRun
|
||||
|
||||
from dump_bin import DumpDataUpdate
|
||||
from data_collector.base import BaseCollector, BaseNormalize, BaseRun, Normalize
|
||||
from data_collector.utils import (
|
||||
deco_retry,
|
||||
get_calendar_list,
|
||||
get_hs_stock_symbols,
|
||||
get_us_stock_symbols,
|
||||
@@ -44,7 +50,7 @@ class YahooCollector(BaseCollector):
|
||||
max_workers=4,
|
||||
max_collector_count=2,
|
||||
delay=0,
|
||||
check_data_length: bool = False,
|
||||
check_data_length: int = None,
|
||||
limit_nums: int = None,
|
||||
):
|
||||
"""
|
||||
@@ -65,8 +71,8 @@ class YahooCollector(BaseCollector):
|
||||
start datetime, default None
|
||||
end: str
|
||||
end datetime, default None
|
||||
check_data_length: bool
|
||||
check data length, by default False
|
||||
check_data_length: int
|
||||
check data length, by default None
|
||||
limit_nums: int
|
||||
using for debug, by default None
|
||||
"""
|
||||
@@ -92,10 +98,6 @@ class YahooCollector(BaseCollector):
|
||||
else:
|
||||
raise ValueError(f"interval error: {self.interval}")
|
||||
|
||||
# using for 1min
|
||||
self._next_datetime = self.convert_datetime(self.start_datetime.date() + pd.Timedelta(days=1), self._timezone)
|
||||
self._latest_datetime = self.convert_datetime(self.end_datetime.date(), self._timezone)
|
||||
|
||||
self.start_datetime = self.convert_datetime(self.start_datetime, self._timezone)
|
||||
self.end_datetime = self.convert_datetime(self.end_datetime, self._timezone)
|
||||
|
||||
@@ -140,40 +142,39 @@ class YahooCollector(BaseCollector):
|
||||
def get_data(
|
||||
self, symbol: str, interval: str, start_datetime: pd.Timestamp, end_datetime: pd.Timestamp
|
||||
) -> pd.DataFrame:
|
||||
@deco_retry(retry_sleep=self.delay)
|
||||
def _get_simple(start_, end_):
|
||||
self.sleep()
|
||||
_remote_interval = "1m" if interval == self.INTERVAL_1min else interval
|
||||
return self.get_data_from_remote(
|
||||
resp = self.get_data_from_remote(
|
||||
symbol,
|
||||
interval=_remote_interval,
|
||||
start=start_,
|
||||
end=end_,
|
||||
)
|
||||
if resp is None or resp.empty:
|
||||
raise ValueError(f"get data error: {symbol}--{start_}--{end_}")
|
||||
return resp
|
||||
|
||||
_result = None
|
||||
if interval == self.INTERVAL_1d:
|
||||
_result = _get_simple(start_datetime, end_datetime)
|
||||
elif interval == self.INTERVAL_1min:
|
||||
if self._next_datetime >= self._latest_datetime:
|
||||
try:
|
||||
_result = _get_simple(start_datetime, end_datetime)
|
||||
else:
|
||||
_res = []
|
||||
|
||||
def _get_multi(start_, end_):
|
||||
_resp = _get_simple(start_, end_)
|
||||
if _resp is not None and not _resp.empty:
|
||||
_res.append(_resp)
|
||||
|
||||
for _s, _e in (
|
||||
(self.start_datetime, self._next_datetime),
|
||||
(self._latest_datetime, self.end_datetime),
|
||||
):
|
||||
_get_multi(_s, _e)
|
||||
for _start in pd.date_range(self._next_datetime, self._latest_datetime, closed="left"):
|
||||
_end = _start + pd.Timedelta(days=1)
|
||||
_get_multi(_start, _end)
|
||||
if _res:
|
||||
_result = pd.concat(_res, sort=False).sort_values(["symbol", "date"])
|
||||
except ValueError as e:
|
||||
pass
|
||||
elif interval == self.INTERVAL_1min:
|
||||
_res = []
|
||||
_start = self.start_datetime
|
||||
while _start < self.end_datetime:
|
||||
_tmp_end = min(_start + pd.Timedelta(days=7), self.end_datetime)
|
||||
try:
|
||||
_resp = _get_simple(_start, _tmp_end)
|
||||
_res.append(_resp)
|
||||
except ValueError as e:
|
||||
pass
|
||||
_start = _tmp_end
|
||||
if _res:
|
||||
_result = pd.concat(_res, sort=False).sort_values(["symbol", "date"])
|
||||
else:
|
||||
raise ValueError(f"cannot support {self.interval}")
|
||||
return pd.DataFrame() if _result is None else _result
|
||||
@@ -207,10 +208,6 @@ class YahooCollectorCN(YahooCollector, ABC):
|
||||
|
||||
|
||||
class YahooCollectorCN1d(YahooCollectorCN):
|
||||
@property
|
||||
def min_numbers_trading(self):
|
||||
return 252 / 4
|
||||
|
||||
def download_index_data(self):
|
||||
# TODO: from MSN
|
||||
_format = "%Y%m%d"
|
||||
@@ -244,13 +241,12 @@ class YahooCollectorCN1d(YahooCollectorCN):
|
||||
|
||||
|
||||
class YahooCollectorCN1min(YahooCollectorCN):
|
||||
@property
|
||||
def min_numbers_trading(self):
|
||||
return 60 * 4 * 5
|
||||
def get_instrument_list(self):
|
||||
symbols = super(YahooCollectorCN1min, self).get_instrument_list()
|
||||
return symbols + ["000300.ss", "000905.ss", "000903.ss"]
|
||||
|
||||
def download_index_data(self):
|
||||
# TODO: 1m
|
||||
logger.warning(f"{self.__class__.__name__} {self.interval} does not support: download_index_data")
|
||||
pass
|
||||
|
||||
|
||||
class YahooCollectorUS(YahooCollector, ABC):
|
||||
@@ -276,15 +272,11 @@ class YahooCollectorUS(YahooCollector, ABC):
|
||||
|
||||
|
||||
class YahooCollectorUS1d(YahooCollectorUS):
|
||||
@property
|
||||
def min_numbers_trading(self):
|
||||
return 252 / 4
|
||||
pass
|
||||
|
||||
|
||||
class YahooCollectorUS1min(YahooCollectorUS):
|
||||
@property
|
||||
def min_numbers_trading(self):
|
||||
return 60 * 6.5 * 5
|
||||
pass
|
||||
|
||||
|
||||
class YahooNormalize(BaseNormalize):
|
||||
@@ -297,6 +289,7 @@ class YahooNormalize(BaseNormalize):
|
||||
calendar_list: list = None,
|
||||
date_field_name: str = "date",
|
||||
symbol_field_name: str = "symbol",
|
||||
last_close: float = None,
|
||||
):
|
||||
if df.empty:
|
||||
return df
|
||||
@@ -318,7 +311,10 @@ class YahooNormalize(BaseNormalize):
|
||||
df.sort_index(inplace=True)
|
||||
df.loc[(df["volume"] <= 0) | np.isnan(df["volume"]), set(df.columns) - {symbol_field_name}] = np.nan
|
||||
_tmp_series = df["close"].fillna(method="ffill")
|
||||
df["change"] = _tmp_series / _tmp_series.shift(1) - 1
|
||||
_tmp_shift_series = _tmp_series.shift(1)
|
||||
if last_close is not None:
|
||||
_tmp_shift_series.iloc[0] = float(last_close)
|
||||
df["change"] = _tmp_series / _tmp_shift_series - 1
|
||||
columns += ["change"]
|
||||
df.loc[(df["volume"] <= 0) | np.isnan(df["volume"]), columns] = np.nan
|
||||
|
||||
@@ -367,6 +363,17 @@ class YahooNormalize1d(YahooNormalize, ABC):
|
||||
df = self._manual_adj_data(df)
|
||||
return df
|
||||
|
||||
def _get_first_close(self, df: pd.DataFrame) -> float:
|
||||
"""get first close value
|
||||
|
||||
Notes
|
||||
-----
|
||||
For incremental updates(append) to Yahoo 1D data, user need to use a close that is not 0 on the first trading day of the existing data
|
||||
"""
|
||||
df = df.loc[df["close"].first_valid_index() :]
|
||||
_close = df["close"].iloc[0]
|
||||
return _close
|
||||
|
||||
def _manual_adj_data(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""manual adjust data: All fields (except change) are standardized according to the close of the first day"""
|
||||
if df.empty:
|
||||
@@ -374,45 +381,112 @@ class YahooNormalize1d(YahooNormalize, ABC):
|
||||
df = df.copy()
|
||||
df.sort_values(self._date_field_name, inplace=True)
|
||||
df = df.set_index(self._date_field_name)
|
||||
df = df.loc[df["close"].first_valid_index() :]
|
||||
_close = df["close"].iloc[0]
|
||||
_close = self._get_first_close(df)
|
||||
for _col in df.columns:
|
||||
if _col == self._symbol_field_name:
|
||||
# NOTE: retain original adjclose, required for incremental updates
|
||||
if _col in [self._symbol_field_name, "adjclose", "change"]:
|
||||
continue
|
||||
if _col == "volume":
|
||||
df[_col] = df[_col] * _close
|
||||
elif _col != "change":
|
||||
df[_col] = df[_col] / _close
|
||||
else:
|
||||
pass
|
||||
df[_col] = df[_col] / _close
|
||||
return df.reset_index()
|
||||
|
||||
|
||||
class YahooNormalize1dExtend(YahooNormalize1d):
|
||||
def __init__(
|
||||
self, old_qlib_data_dir: [str, Path], date_field_name: str = "date", symbol_field_name: str = "symbol", **kwargs
|
||||
):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
old_qlib_data_dir: str, Path
|
||||
the qlib data to be updated for yahoo, usually from: https://github.com/microsoft/qlib/tree/main/scripts#download-cn-data
|
||||
date_field_name: str
|
||||
date field name, default is date
|
||||
symbol_field_name: str
|
||||
symbol field name, default is symbol
|
||||
"""
|
||||
super(YahooNormalize1dExtend, self).__init__(date_field_name, symbol_field_name)
|
||||
self._first_close_field = "first_close"
|
||||
self._ori_close_field = "ori_close"
|
||||
self.old_qlib_data = self._get_old_data(old_qlib_data_dir)
|
||||
|
||||
def _get_old_data(self, qlib_data_dir: [str, Path]):
|
||||
import qlib
|
||||
from qlib.data import D
|
||||
|
||||
qlib_data_dir = str(Path(qlib_data_dir).expanduser().resolve())
|
||||
qlib.init(provider_uri=qlib_data_dir, expression_cache=None, dataset_cache=None)
|
||||
df = D.features(D.instruments("all"), ["$close/$factor", "$adjclose/$close"])
|
||||
df.columns = [self._ori_close_field, self._first_close_field]
|
||||
return df
|
||||
|
||||
def _get_close(self, df: pd.DataFrame, field_name: str):
|
||||
_symbol = df.loc[df[self._symbol_field_name].first_valid_index()][self._symbol_field_name].upper()
|
||||
_df = self.old_qlib_data.loc(axis=0)[_symbol]
|
||||
_close = _df.loc[_df.last_valid_index()][field_name]
|
||||
return _close
|
||||
|
||||
def _get_first_close(self, df: pd.DataFrame) -> float:
|
||||
try:
|
||||
_close = self._get_close(df, field_name=self._first_close_field)
|
||||
except KeyError:
|
||||
_close = super(YahooNormalize1dExtend, self)._get_first_close(df)
|
||||
return _close
|
||||
|
||||
def _get_last_close(self, df: pd.DataFrame) -> float:
|
||||
try:
|
||||
_close = self._get_close(df, field_name=self._ori_close_field)
|
||||
except KeyError:
|
||||
_close = None
|
||||
return _close
|
||||
|
||||
def _get_last_date(self, df: pd.DataFrame) -> pd.Timestamp:
|
||||
_symbol = df.loc[df[self._symbol_field_name].first_valid_index()][self._symbol_field_name].upper()
|
||||
try:
|
||||
_df = self.old_qlib_data.loc(axis=0)[_symbol]
|
||||
_date = _df.index.max()
|
||||
except KeyError:
|
||||
_date = None
|
||||
return _date
|
||||
|
||||
def normalize(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
_last_close = self._get_last_close(df)
|
||||
# reindex
|
||||
_last_date = self._get_last_date(df)
|
||||
if _last_date is not None:
|
||||
df = df.set_index(self._date_field_name)
|
||||
df.index = pd.to_datetime(df.index)
|
||||
df = df[~df.index.duplicated(keep="first")]
|
||||
_max_date = df.index.max()
|
||||
df = df.reindex(self._calendar_list).loc[:_max_date].reset_index()
|
||||
df = df[df[self._date_field_name] > _last_date]
|
||||
if df.empty:
|
||||
return pd.DataFrame()
|
||||
_si = df["close"].first_valid_index()
|
||||
if _si > df.index[0]:
|
||||
logger.warning(
|
||||
f"{df.loc[_si][self._symbol_field_name]} missing data: {df.loc[:_si - 1][self._date_field_name].to_list()}"
|
||||
)
|
||||
# normalize
|
||||
df = self.normalize_yahoo(
|
||||
df, self._calendar_list, self._date_field_name, self._symbol_field_name, last_close=_last_close
|
||||
)
|
||||
# adjusted price
|
||||
df = self.adjusted_price(df)
|
||||
df = self._manual_adj_data(df)
|
||||
return df
|
||||
|
||||
|
||||
class YahooNormalize1min(YahooNormalize, ABC):
|
||||
AM_RANGE = None # type: tuple # eg: ("09:30:00", "11:29:00")
|
||||
PM_RANGE = None # type: tuple # eg: ("13:00:00", "14:59:00")
|
||||
|
||||
# Whether the trading day of 1min data is consistent with 1d
|
||||
CONSISTENT_1d = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
date_field_name: str = "date",
|
||||
symbol_field_name: str = "symbol",
|
||||
):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
date_field_name: str
|
||||
date field name, default is date
|
||||
symbol_field_name: str
|
||||
symbol field name, default is symbol
|
||||
"""
|
||||
super(YahooNormalize1min, self).__init__(date_field_name, symbol_field_name)
|
||||
_class_name = self.__class__.__name__.replace("min", "d")
|
||||
_class = getattr(importlib.import_module("collector"), _class_name) # type: Type[YahooNormalize]
|
||||
self.data_1d_obj = _class(self._date_field_name, self._symbol_field_name)
|
||||
CONSISTENT_1d = True
|
||||
CALC_PAUSED_NUM = True
|
||||
|
||||
@property
|
||||
def calendar_list_1d(self):
|
||||
@@ -427,24 +501,40 @@ class YahooNormalize1min(YahooNormalize, ABC):
|
||||
calendars, freq="1min", am_range=self.AM_RANGE, pm_range=self.PM_RANGE
|
||||
)
|
||||
|
||||
def get_1d_data(self, symbol: str, start: str, end: str) -> pd.DataFrame:
|
||||
"""get 1d data
|
||||
|
||||
Returns
|
||||
------
|
||||
data_1d: pd.DataFrame
|
||||
data_1d.columns = [self._date_field_name, self._symbol_field_name, "paused", "volume", "factor", "close"]
|
||||
|
||||
"""
|
||||
data_1d = YahooCollector.get_data_from_remote(self.symbol_to_yahoo(symbol), interval="1d", start=start, end=end)
|
||||
if not (data_1d is None or data_1d.empty):
|
||||
_class_name = self.__class__.__name__.replace("min", "d")
|
||||
_class: type(YahooNormalize) = getattr(importlib.import_module("collector"), _class_name)
|
||||
data_1d_obj = _class(self._date_field_name, self._symbol_field_name)
|
||||
data_1d = data_1d_obj.normalize(data_1d)
|
||||
return data_1d
|
||||
|
||||
def adjusted_price(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
# TODO: using daily data factor
|
||||
if df.empty:
|
||||
return df
|
||||
df = df.copy()
|
||||
df = df.sort_values(self._date_field_name)
|
||||
symbol = df.iloc[0][self._symbol_field_name]
|
||||
# get 1d data from yahoo
|
||||
_start = pd.Timestamp(df[self._date_field_name].min()).strftime(self.DAILY_FORMAT)
|
||||
_end = (pd.Timestamp(df[self._date_field_name].max()) + pd.Timedelta(days=1)).strftime(self.DAILY_FORMAT)
|
||||
data_1d = YahooCollector.get_data_from_remote(
|
||||
self.symbol_to_yahoo(symbol), interval="1d", start=_start, end=_end
|
||||
)
|
||||
data_1d: pd.DataFrame = self.get_1d_data(symbol, _start, _end)
|
||||
data_1d = data_1d.copy()
|
||||
if data_1d is None or data_1d.empty:
|
||||
df["factor"] = 1
|
||||
df["factor"] = 1 / df.loc[df["close"].first_valid_index()]["close"]
|
||||
# TODO: np.nan or 1 or 0
|
||||
df["paused"] = np.nan
|
||||
else:
|
||||
data_1d = self.data_1d_obj.normalize(data_1d) # type: pd.DataFrame
|
||||
# NOTE: volume is np.nan or volume <= 0, paused = 1
|
||||
# FIXME: find a more accurate data source
|
||||
data_1d["paused"] = 0
|
||||
@@ -452,9 +542,13 @@ class YahooNormalize1min(YahooNormalize, ABC):
|
||||
data_1d = data_1d.set_index(self._date_field_name)
|
||||
|
||||
# add factor from 1d data
|
||||
# NOTE: yahoo 1d data info:
|
||||
# - Close price adjusted for splits. Adjusted close price adjusted for both dividends and splits.
|
||||
# - data_1d.adjclose: Adjusted close price adjusted for both dividends and splits.
|
||||
# - data_1d.close: `data_1d.adjclose / (close for the first trading day that is not np.nan)`
|
||||
df["date_tmp"] = df[self._date_field_name].apply(lambda x: pd.Timestamp(x).date())
|
||||
df.set_index("date_tmp", inplace=True)
|
||||
df.loc[:, "factor"] = data_1d["factor"]
|
||||
df.loc[:, "factor"] = data_1d["close"] / df["close"]
|
||||
df.loc[:, "paused"] = data_1d["paused"]
|
||||
df.reset_index("date_tmp", drop=True, inplace=True)
|
||||
|
||||
@@ -478,6 +572,54 @@ class YahooNormalize1min(YahooNormalize, ABC):
|
||||
df[_col] = df[_col] / df["factor"]
|
||||
else:
|
||||
df[_col] = df[_col] * df["factor"]
|
||||
|
||||
if self.CALC_PAUSED_NUM:
|
||||
df = self.calc_paused_num(df)
|
||||
return df
|
||||
|
||||
def calc_paused_num(self, df: pd.DataFrame):
|
||||
_symbol = df.iloc[0][self._symbol_field_name]
|
||||
df = df.copy()
|
||||
df["_tmp_date"] = df[self._date_field_name].apply(lambda x: pd.Timestamp(x).date())
|
||||
# remove data that starts and ends with `np.nan` all day
|
||||
all_data = []
|
||||
# Record the number of consecutive trading days where the whole day is nan, to remove the last trading day where the whole day is nan
|
||||
all_nan_nums = 0
|
||||
# Record the number of consecutive occurrences of trading days that are not nan throughout the day
|
||||
not_nan_nums = 0
|
||||
for _date, _df in df.groupby("_tmp_date"):
|
||||
_df["paused"] = 0
|
||||
if not _df.loc[_df["volume"] < 0].empty:
|
||||
logger.warning(f"volume < 0, will fill np.nan: {_date} {_symbol}")
|
||||
_df.loc[_df["volume"] < 0, "volume"] = np.nan
|
||||
|
||||
check_fields = set(_df.columns) - {
|
||||
"_tmp_date",
|
||||
"paused",
|
||||
"factor",
|
||||
self._date_field_name,
|
||||
self._symbol_field_name,
|
||||
}
|
||||
if _df.loc[:, check_fields].isna().values.all() or (_df["volume"] == 0).all():
|
||||
all_nan_nums += 1
|
||||
not_nan_nums = 0
|
||||
_df["paused"] = 1
|
||||
if all_data:
|
||||
_df["paused_num"] = not_nan_nums
|
||||
all_data.append(_df)
|
||||
else:
|
||||
all_nan_nums = 0
|
||||
not_nan_nums += 1
|
||||
_df["paused_num"] = not_nan_nums
|
||||
all_data.append(_df)
|
||||
all_data = all_data[: len(all_data) - all_nan_nums]
|
||||
if all_data:
|
||||
df = pd.concat(all_data, sort=False)
|
||||
else:
|
||||
logger.warning(f"data is empty: {_symbol}")
|
||||
df = pd.DataFrame()
|
||||
return df
|
||||
del df["_tmp_date"]
|
||||
return df
|
||||
|
||||
@abc.abstractmethod
|
||||
@@ -485,12 +627,67 @@ class YahooNormalize1min(YahooNormalize, ABC):
|
||||
raise NotImplementedError("rewrite symbol_to_yahoo")
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_1d_calendar_list(self):
|
||||
def _get_1d_calendar_list(self) -> Iterable[pd.Timestamp]:
|
||||
raise NotImplementedError("rewrite _get_1d_calendar_list")
|
||||
|
||||
|
||||
class YahooNormalize1minOffline(YahooNormalize1min):
|
||||
"""Normalised to 1min using local 1d data"""
|
||||
|
||||
def __init__(
|
||||
self, qlib_data_1d_dir: [str, Path], date_field_name: str = "date", symbol_field_name: str = "symbol", **kwargs
|
||||
):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
qlib_data_1d_dir: str, Path
|
||||
the qlib data to be updated for yahoo, usually from: Normalised to 1min using local 1d data
|
||||
date_field_name: str
|
||||
date field name, default is date
|
||||
symbol_field_name: str
|
||||
symbol field name, default is symbol
|
||||
"""
|
||||
self.qlib_data_1d_dir = qlib_data_1d_dir
|
||||
super(YahooNormalize1minOffline, self).__init__(date_field_name, symbol_field_name)
|
||||
self._all_1d_data = self._get_all_1d_data()
|
||||
|
||||
def _get_1d_calendar_list(self) -> Iterable[pd.Timestamp]:
|
||||
import qlib
|
||||
from qlib.data import D
|
||||
|
||||
qlib.init(provider_uri=self.qlib_data_1d_dir)
|
||||
return list(D.calendar(freq="day"))
|
||||
|
||||
def _get_all_1d_data(self):
|
||||
import qlib
|
||||
from qlib.data import D
|
||||
|
||||
qlib.init(provider_uri=self.qlib_data_1d_dir)
|
||||
df = D.features(D.instruments("all"), ["$paused", "$volume", "$factor", "$close"], freq="day")
|
||||
df.reset_index(inplace=True)
|
||||
df.rename(columns={"datetime": self._date_field_name, "instrument": self._symbol_field_name}, inplace=True)
|
||||
df.columns = list(map(lambda x: x[1:] if x.startswith("$") else x, df.columns))
|
||||
return df
|
||||
|
||||
def get_1d_data(self, symbol: str, start: str, end: str) -> pd.DataFrame:
|
||||
"""get 1d data
|
||||
|
||||
Returns
|
||||
------
|
||||
data_1d: pd.DataFrame
|
||||
data_1d.columns = [self._date_field_name, self._symbol_field_name, "paused", "volume", "factor", "close"]
|
||||
|
||||
"""
|
||||
return self._all_1d_data[
|
||||
(self._all_1d_data[self._symbol_field_name] == symbol.upper())
|
||||
& (self._all_1d_data[self._date_field_name] >= pd.Timestamp(start))
|
||||
& (self._all_1d_data[self._date_field_name] < pd.Timestamp(end))
|
||||
]
|
||||
|
||||
|
||||
class YahooNormalizeUS:
|
||||
def _get_calendar_list(self):
|
||||
def _get_calendar_list(self) -> Iterable[pd.Timestamp]:
|
||||
# TODO: from MSN
|
||||
return get_calendar_list("US_ALL")
|
||||
|
||||
@@ -499,10 +696,10 @@ class YahooNormalizeUS1d(YahooNormalizeUS, YahooNormalize1d):
|
||||
pass
|
||||
|
||||
|
||||
class YahooNormalizeUS1min(YahooNormalizeUS, YahooNormalize1min):
|
||||
CONSISTENT_1d = False
|
||||
class YahooNormalizeUS1min(YahooNormalizeUS, YahooNormalize1minOffline):
|
||||
CALC_PAUSED_NUM = False
|
||||
|
||||
def _get_calendar_list(self):
|
||||
def _get_calendar_list(self) -> Iterable[pd.Timestamp]:
|
||||
# TODO: support 1min
|
||||
raise ValueError("Does not support 1min")
|
||||
|
||||
@@ -514,7 +711,7 @@ class YahooNormalizeUS1min(YahooNormalizeUS, YahooNormalize1min):
|
||||
|
||||
|
||||
class YahooNormalizeCN:
|
||||
def _get_calendar_list(self):
|
||||
def _get_calendar_list(self) -> Iterable[pd.Timestamp]:
|
||||
# TODO: from MSN
|
||||
return get_calendar_list("ALL")
|
||||
|
||||
@@ -523,28 +720,30 @@ class YahooNormalizeCN1d(YahooNormalizeCN, YahooNormalize1d):
|
||||
pass
|
||||
|
||||
|
||||
class YahooNormalizeCN1min(YahooNormalizeCN, YahooNormalize1min):
|
||||
class YahooNormalizeCN1dExtend(YahooNormalizeCN, YahooNormalize1dExtend):
|
||||
pass
|
||||
|
||||
|
||||
class YahooNormalizeCN1min(YahooNormalizeCN, YahooNormalize1minOffline):
|
||||
AM_RANGE = ("09:30:00", "11:29:00")
|
||||
PM_RANGE = ("13:00:00", "14:59:00")
|
||||
|
||||
CONSISTENT_1d = True
|
||||
|
||||
def _get_calendar_list(self):
|
||||
def _get_calendar_list(self) -> Iterable[pd.Timestamp]:
|
||||
return self.generate_1min_from_daily(self.calendar_list_1d)
|
||||
|
||||
def symbol_to_yahoo(self, symbol):
|
||||
if "." not in symbol:
|
||||
_exchange = symbol[:2]
|
||||
_exchange = "ss" if _exchange == "sh" else _exchange
|
||||
_exchange = ("ss" if _exchange.islower() else "SS") if _exchange.lower() == "sh" else _exchange
|
||||
symbol = symbol[2:] + "." + _exchange
|
||||
return symbol
|
||||
|
||||
def _get_1d_calendar_list(self):
|
||||
def _get_1d_calendar_list(self) -> Iterable[pd.Timestamp]:
|
||||
return get_calendar_list("ALL")
|
||||
|
||||
|
||||
class Run(BaseRun):
|
||||
def __init__(self, source_dir=None, normalize_dir=None, max_workers=4, interval="1d", region=REGION_CN):
|
||||
def __init__(self, source_dir=None, normalize_dir=None, max_workers=1, interval="1d", region=REGION_CN):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
@@ -554,7 +753,7 @@ class Run(BaseRun):
|
||||
normalize_dir: str
|
||||
Directory for normalize data, default "Path(__file__).parent/normalize"
|
||||
max_workers: int
|
||||
Concurrent number, default is 4
|
||||
Concurrent number, default is 1; when collecting data, it is recommended that max_workers be set to 1
|
||||
interval: str
|
||||
freq, value from [1min, 1d], default 1d
|
||||
region: str
|
||||
@@ -578,10 +777,10 @@ class Run(BaseRun):
|
||||
def download_data(
|
||||
self,
|
||||
max_collector_count=2,
|
||||
delay=0,
|
||||
delay=0.5,
|
||||
start=None,
|
||||
end=None,
|
||||
check_data_length=False,
|
||||
check_data_length=None,
|
||||
limit_nums=None,
|
||||
):
|
||||
"""download data from Internet
|
||||
@@ -591,16 +790,23 @@ class Run(BaseRun):
|
||||
max_collector_count: int
|
||||
default 2
|
||||
delay: float
|
||||
time.sleep(delay), default 0
|
||||
time.sleep(delay), default 0.5
|
||||
start: str
|
||||
start datetime, default "2000-01-01"
|
||||
start datetime, default "2000-01-01"; closed interval(including start)
|
||||
end: str
|
||||
end datetime, default ``pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))``
|
||||
check_data_length: bool
|
||||
check data length, by default False
|
||||
end datetime, default ``pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))``; open interval(excluding end)
|
||||
check_data_length: int
|
||||
check data length, if not None and greater than 0, each symbol will be considered complete if its data length is greater than or equal to this value, otherwise it will be fetched again, the maximum number of fetches being (max_collector_count). By default None.
|
||||
limit_nums: int
|
||||
using for debug, by default None
|
||||
|
||||
Notes
|
||||
-----
|
||||
check_data_length, example:
|
||||
daily, one year: 252 // 4
|
||||
us 1min, a week: 6.5 * 60 * 5
|
||||
cn 1min, a week: 4 * 60 * 5
|
||||
|
||||
Examples
|
||||
---------
|
||||
# get daily data
|
||||
@@ -612,7 +818,13 @@ class Run(BaseRun):
|
||||
max_collector_count, delay, start, end, self.interval, check_data_length, limit_nums
|
||||
)
|
||||
|
||||
def normalize_data(self, date_field_name: str = "date", symbol_field_name: str = "symbol"):
|
||||
def normalize_data(
|
||||
self,
|
||||
date_field_name: str = "date",
|
||||
symbol_field_name: str = "symbol",
|
||||
end_date: str = None,
|
||||
qlib_data_1d_dir: str = None,
|
||||
):
|
||||
"""normalize data
|
||||
|
||||
Parameters
|
||||
@@ -621,12 +833,205 @@ class Run(BaseRun):
|
||||
date field name, default date
|
||||
symbol_field_name: str
|
||||
symbol field name, default symbol
|
||||
end_date: str
|
||||
if not None, normalize the last date saved (including end_date); if None, it will ignore this parameter; by default None
|
||||
qlib_data_1d_dir: str
|
||||
if interval==1min, qlib_data_1d_dir cannot be None, normalize 1min needs to use 1d data;
|
||||
|
||||
qlib_data_1d can be obtained like this:
|
||||
$ python scripts/get_data.py qlib_data --target_dir <qlib_data_1d_dir> --interval 1d
|
||||
$ python scripts/data_collector/yahoo/collector.py update_data_to_bin --qlib_data_1d_dir <qlib_data_1d_dir> --trading_date 2021-06-01
|
||||
or:
|
||||
download 1d data, reference: https://github.com/microsoft/qlib/tree/main/scripts/data_collector/yahoo#1d-from-yahoo
|
||||
|
||||
Examples
|
||||
---------
|
||||
$ python collector.py normalize_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region CN --interval 1d
|
||||
$ python collector.py normalize_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region cn --interval 1d
|
||||
$ python collector.py normalize_data --qlib_data_1d_dir ~/.qlib/qlib_data/cn_1d --source_dir ~/.qlib/stock_data/source_cn_1min --normalize_dir ~/.qlib/stock_data/normalize_cn_1min --region CN --interval 1min
|
||||
"""
|
||||
super(Run, self).normalize_data(date_field_name, symbol_field_name)
|
||||
if self.interval.lower() == "1min":
|
||||
if qlib_data_1d_dir is None or not Path(qlib_data_1d_dir).expanduser().exists():
|
||||
raise ValueError(
|
||||
"If normalize 1min, the qlib_data_1d_dir parameter must be set: --qlib_data_1d_dir <user qlib 1d data >, Reference: https://github.com/zhupr/qlib/tree/support_extend_data/scripts/data_collector/yahoo#automatic-update-of-daily-frequency-datafrom-yahoo-finance"
|
||||
)
|
||||
super(Run, self).normalize_data(
|
||||
date_field_name, symbol_field_name, end_date=end_date, qlib_data_1d_dir=qlib_data_1d_dir
|
||||
)
|
||||
|
||||
def normalize_data_1d_extend(
|
||||
self, old_qlib_data_dir, date_field_name: str = "date", symbol_field_name: str = "symbol"
|
||||
):
|
||||
"""normalize data extend; extending yahoo qlib data(from: https://github.com/microsoft/qlib/tree/main/scripts#download-cn-data)
|
||||
|
||||
Notes
|
||||
-----
|
||||
Steps to extend yahoo qlib data:
|
||||
|
||||
1. download qlib data: https://github.com/microsoft/qlib/tree/main/scripts#download-cn-data; save to <dir1>
|
||||
|
||||
2. collector source data: https://github.com/microsoft/qlib/tree/main/scripts/data_collector/yahoo#collector-data; save to <dir2>
|
||||
|
||||
3. normalize new source data(from step 2): python scripts/data_collector/yahoo/collector.py normalize_data_1d_extend --old_qlib_dir <dir1> --source_dir <dir2> --normalize_dir <dir3> --region CN --interval 1d
|
||||
|
||||
4. dump data: python scripts/dump_bin.py dump_update --csv_path <dir3> --qlib_dir <dir1> --freq day --date_field_name date --symbol_field_name symbol --exclude_fields symbol,date
|
||||
|
||||
5. update instrument(eg. csi300): python python scripts/data_collector/cn_index/collector.py --index_name CSI300 --qlib_dir <dir1> --method parse_instruments
|
||||
|
||||
Parameters
|
||||
----------
|
||||
old_qlib_data_dir: str
|
||||
the qlib data to be updated for yahoo, usually from: https://github.com/microsoft/qlib/tree/main/scripts#download-cn-data
|
||||
date_field_name: str
|
||||
date field name, default date
|
||||
symbol_field_name: str
|
||||
symbol field name, default symbol
|
||||
|
||||
Examples
|
||||
---------
|
||||
$ python collector.py normalize_data_1d_extend --old_qlib_dir ~/.qlib/qlib_data/cn_1d --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region CN --interval 1d
|
||||
"""
|
||||
_class = getattr(self._cur_module, f"{self.normalize_class_name}Extend")
|
||||
yc = Normalize(
|
||||
source_dir=self.source_dir,
|
||||
target_dir=self.normalize_dir,
|
||||
normalize_class=_class,
|
||||
max_workers=self.max_workers,
|
||||
date_field_name=date_field_name,
|
||||
symbol_field_name=symbol_field_name,
|
||||
old_qlib_data_dir=old_qlib_data_dir,
|
||||
)
|
||||
yc.normalize()
|
||||
|
||||
def download_today_data(
|
||||
self,
|
||||
max_collector_count=2,
|
||||
delay=0.5,
|
||||
check_data_length=None,
|
||||
limit_nums=None,
|
||||
):
|
||||
"""download today data from Internet
|
||||
|
||||
Parameters
|
||||
----------
|
||||
max_collector_count: int
|
||||
default 2
|
||||
delay: float
|
||||
time.sleep(delay), default 0.5
|
||||
check_data_length: int
|
||||
check data length, if not None and greater than 0, each symbol will be considered complete if its data length is greater than or equal to this value, otherwise it will be fetched again, the maximum number of fetches being (max_collector_count). By default None.
|
||||
limit_nums: int
|
||||
using for debug, by default None
|
||||
|
||||
Notes
|
||||
-----
|
||||
Download today's data:
|
||||
start_time = datetime.datetime.now().date(); closed interval(including start)
|
||||
end_time = pd.Timestamp(start_time + pd.Timedelta(days=1)).date(); open interval(excluding end)
|
||||
|
||||
check_data_length, example:
|
||||
daily, one year: 252 // 4
|
||||
us 1min, a week: 6.5 * 60 * 5
|
||||
cn 1min, a week: 4 * 60 * 5
|
||||
|
||||
Examples
|
||||
---------
|
||||
# get daily data
|
||||
$ python collector.py download_today_data --source_dir ~/.qlib/stock_data/source --region CN --delay 0.1 --interval 1d
|
||||
# get 1m data
|
||||
$ python collector.py download_today_data --source_dir ~/.qlib/stock_data/source --region CN --delay 0.1 --interval 1m
|
||||
"""
|
||||
start = datetime.datetime.now().date()
|
||||
end = pd.Timestamp(start + pd.Timedelta(days=1)).date()
|
||||
self.download_data(
|
||||
max_collector_count,
|
||||
delay,
|
||||
start.strftime("%Y-%m-%d"),
|
||||
end.strftime("%Y-%m-%d"),
|
||||
check_data_length,
|
||||
limit_nums,
|
||||
)
|
||||
|
||||
def update_data_to_bin(
|
||||
self,
|
||||
qlib_data_1d_dir: str,
|
||||
trading_date: str = None,
|
||||
end_date: str = None,
|
||||
check_data_length: int = None,
|
||||
delay: float = 1,
|
||||
):
|
||||
"""update yahoo data to bin
|
||||
|
||||
Parameters
|
||||
----------
|
||||
qlib_data_1d_dir: str
|
||||
the qlib data to be updated for yahoo, usually from: https://github.com/microsoft/qlib/tree/main/scripts#download-cn-data
|
||||
|
||||
trading_date: str
|
||||
trading days to be updated, by default ``datetime.datetime.now().strftime("%Y-%m-%d")``
|
||||
end_date: str
|
||||
end datetime, default ``pd.Timestamp(trading_date + pd.Timedelta(days=1))``; open interval(excluding end)
|
||||
check_data_length: int
|
||||
check data length, if not None and greater than 0, each symbol will be considered complete if its data length is greater than or equal to this value, otherwise it will be fetched again, the maximum number of fetches being (max_collector_count). By default None.
|
||||
delay: float
|
||||
time.sleep(delay), default 1
|
||||
Notes
|
||||
-----
|
||||
If the data in qlib_data_dir is incomplete, np.nan will be populated to trading_date for the previous trading day
|
||||
|
||||
Examples
|
||||
-------
|
||||
$ python collector.py update_data_to_bin --qlib_data_1d_dir <user data dir> --trading_date <start date> --end_date <end date>
|
||||
# get 1m data
|
||||
"""
|
||||
|
||||
if self.interval.lower() != "1d":
|
||||
logger.warning(f"currently supports 1d data updates: --interval 1d")
|
||||
|
||||
# start/end date
|
||||
if trading_date is None:
|
||||
trading_date = datetime.datetime.now().strftime("%Y-%m-%d")
|
||||
logger.warning(f"trading_date is None, use the current date: {trading_date}")
|
||||
|
||||
if end_date is None:
|
||||
end_date = (pd.Timestamp(trading_date) + pd.Timedelta(days=1)).strftime("%Y-%m-%d")
|
||||
|
||||
# download qlib 1d data
|
||||
qlib_data_1d_dir = str(Path(qlib_data_1d_dir).expanduser().resolve())
|
||||
if not exists_qlib_data(qlib_data_1d_dir):
|
||||
GetData().qlib_data(target_dir=qlib_data_1d_dir, interval=self.interval, region=self.region)
|
||||
|
||||
# download data from yahoo
|
||||
# NOTE: when downloading data from YahooFinance, max_workers is recommended to be 1
|
||||
self.download_data(delay=delay, start=trading_date, end=end_date, check_data_length=check_data_length)
|
||||
# NOTE: a larger max_workers setting here would be faster
|
||||
self.max_workers = (
|
||||
max(multiprocessing.cpu_count() - 2, 1)
|
||||
if self.max_workers is None or self.max_workers <= 1
|
||||
else self.max_workers
|
||||
)
|
||||
# normalize data
|
||||
self.normalize_data_1d_extend(qlib_data_1d_dir)
|
||||
|
||||
# dump bin
|
||||
_dump = DumpDataUpdate(
|
||||
csv_path=self.normalize_dir,
|
||||
qlib_dir=qlib_data_1d_dir,
|
||||
exclude_fields="symbol,date",
|
||||
max_workers=self.max_workers,
|
||||
)
|
||||
_dump.dump()
|
||||
|
||||
# parse index
|
||||
_region = self.region.lower()
|
||||
if _region not in ["cn", "us"]:
|
||||
logger.warning(f"Unsupported region: region={_region}, component downloads will be ignored")
|
||||
return
|
||||
index_list = ["CSI100", "CSI300"] if _region == "cn" else ["SP500", "NASDAQ100", "DJIA", "SP400"]
|
||||
get_instruments = getattr(
|
||||
importlib.import_module(f"data_collector.{_region}_index.collector"), "get_instruments"
|
||||
)
|
||||
for _index in index_list:
|
||||
get_instruments(str(qlib_data_1d_dir), _index)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -6,3 +6,4 @@ pandas
|
||||
tqdm
|
||||
lxml
|
||||
yahooquery
|
||||
joblib
|
||||
|
||||
@@ -401,6 +401,8 @@ class DumpDataUpdate(DumpDataBase):
|
||||
)
|
||||
self._mode = self.UPDATE_MODE
|
||||
self._old_calendar_list = self._read_calendars(self._calendars_dir.joinpath(f"{self.freq}.txt"))
|
||||
# NOTE: all.txt only exists once for each stock
|
||||
# NOTE: if a stock corresponds to multiple different time ranges, user need to modify self._update_instruments
|
||||
self._update_instruments = (
|
||||
self._read_instruments(self._instruments_dir.joinpath(self.INSTRUMENTS_FILE_NAME))
|
||||
.set_index([self.symbol_field_name])
|
||||
@@ -409,10 +411,9 @@ class DumpDataUpdate(DumpDataBase):
|
||||
|
||||
# load all csv files
|
||||
self._all_data = self._load_all_source_data() # type: pd.DataFrame
|
||||
self._update_calendars = sorted(
|
||||
self._new_calendar_list = self._old_calendar_list + sorted(
|
||||
filter(lambda x: x > self._old_calendar_list[-1], self._all_data[self.date_field_name].unique())
|
||||
)
|
||||
self._new_calendar_list = self._old_calendar_list + self._update_calendars
|
||||
|
||||
def _load_all_source_data(self):
|
||||
# NOTE: Need more memory
|
||||
@@ -452,8 +453,16 @@ class DumpDataUpdate(DumpDataBase):
|
||||
if not (isinstance(_start, pd.Timestamp) and isinstance(_end, pd.Timestamp)):
|
||||
continue
|
||||
if _code in self._update_instruments:
|
||||
# exists stock, will append data
|
||||
_update_calendars = (
|
||||
_df[_df[self.date_field_name] > self._update_instruments[_code][self.INSTRUMENTS_START_FIELD]][
|
||||
self.date_field_name
|
||||
]
|
||||
.sort_values()
|
||||
.to_list()
|
||||
)
|
||||
self._update_instruments[_code][self.INSTRUMENTS_END_FIELD] = self._format_datetime(_end)
|
||||
futures[executor.submit(self._dump_bin, _df, self._update_calendars)] = _code
|
||||
futures[executor.submit(self._dump_bin, _df, _update_calendars)] = _code
|
||||
else:
|
||||
# new stock
|
||||
_dt_range = self._update_instruments.setdefault(_code, dict())
|
||||
|
||||
Reference in New Issue
Block a user