1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 14:01:28 +08:00

Compare commits

...

86 Commits

Author SHA1 Message Date
Young
215f7e0d22 update version for release 0.7.0 2021-07-11 14:34:44 +00:00
xiaowuhu
dafef0ac08 Update workflow.rst
should be China instead of china
2021-07-06 09:22:11 +08:00
xiaowuhu
1cb43ea69b Update workflow.rst
remove 空格 before module_path, kwargs, etc, otherwise, yaml parser will report error: ruamel.yaml.scanner.ScannerError: mapping values are not allowed here
2021-07-06 09:21:14 +08:00
you-n-g
7ca9cf79f7 Update README.md 2021-07-05 19:47:49 +08:00
you-n-g
35f090a6e4 Update what's new 2021-07-04 16:47:33 +08:00
Lewen Wang
ace7484304 Update TCTS. (#495)
* Update TCTS Model.

Co-authored-by: lewwang <lwwang@microsoft.com>
2021-07-04 16:45:05 +08:00
bxdd
2d4f0e80f9 black format 2021-07-02 08:47:52 +08:00
bxdd
946c9392a1 support check_transform_proc module_path 2021-07-02 08:47:52 +08:00
lzh222333
b523b27d5a add docstring 2021-06-30 10:59:34 +08:00
lzh222333
0b83fb3564 more general exception 2021-06-30 10:59:34 +08:00
lzh222333
d96f7a67c6 bug & docs fixed 2021-06-30 10:59:34 +08:00
lzh222333
a7862387a2 fixed update bugs 2021-06-30 10:59:34 +08:00
lzh222333
c4c438249c modify OnlineToolR 2021-06-30 10:59:34 +08:00
you-n-g
8709dde65b Merge pull request #481 from ai4stocks/working_workflow_fix_ipynb
examples/workflow_by_code.ipynd: fix an error in R.get_recorder() par…
2021-06-26 19:57:24 +08:00
Guodong Xu
d66733c358 examples/workflow_by_code.ipynd: fix an error in R.get_recorder() parameters
get_recorder() needs specify 'recorder_id='. However workflow_by_code.ipynd
didn't. This patch fixes it.

Without this fix, here is the error message jupyter-notebook reports:

"---------------------------------------------------------------------------
TypeError Traceback (most recent call last)

<ipython-input-7-e6a7b5f4da00> in <module>
26 # backtest and analysis
27 with R.start(experiment_name="backtest_analysis"):
---> 28 recorder = R.get_recorder(rid, experiment_name="train_model")
29 model = recorder.load_object("trained_model")
30

TypeError: get_recorder() takes 1 positional argument but 2 positional arguments (and 1 keyword-only argument) were given"

Signed-off-by: Guodong Xu <guodong.xu@linaro.org>
2021-06-26 18:25:47 +08:00
Dong Zhou
9cf574b697 Merge pull request #479 from linhx25/main
Add TRA Model
2021-06-25 18:08:23 +08:00
linhx25
107e40f3ee Add TRA Model 2021-06-25 16:12:50 +08:00
you-n-g
4837ba8db3 Merge pull request #476 from bxdd/qlib_ops_config
Support using config to register custom operators
2021-06-24 20:41:48 +08:00
Qian Chen
2ab4a9adb3 Set self.fitted = True instead of self._fitted. 2021-06-24 20:40:59 +08:00
bxdd
8d0b673341 add custom_ops docstring 2021-06-24 15:00:45 +08:00
you-n-g
8ebdb1e873 Merge pull request #463 from zhupr/support_extend_data
Support extend data
2021-06-24 13:53:30 +08:00
zhupr
39340fbf06 fix: typo 2021-06-24 11:07:40 +08:00
zhupr
0e277723a3 Merge remote-tracking branch 'qlib/main' into qlib_main
# Conflicts:
#	scripts/data_collector/yahoo/README.md
2021-06-24 00:09:54 +08:00
zhupr
1418417034 fix automatic update of daily frequency data 2021-06-23 23:59:59 +08:00
you-n-g
b261f7b501 Update README.md 2021-06-23 20:51:21 +08:00
zhupr
bab50e8837 fix YahooNormalize1min && update docs 2021-06-23 16:13:26 +08:00
bxdd
0eee4a0f2e support config custom_ops 2021-06-23 15:56:36 +08:00
Young
21eb71d4a9 update framework for online serving 2021-06-23 02:05:38 +00:00
zhupr
46714adf4c modify the YahooNormalize1min factor calculation 2021-06-22 11:15:09 +08:00
zhupr
99fb49650a add end_date parameter to collector.normalize_data 2021-06-21 17:20:37 +08:00
zhupr
985fd0816c Fix cn_index.collector network error 2021-06-21 17:18:04 +08:00
Young
d0f54343c7 support subclass of TSDatasetH 2021-06-21 00:24:31 +08:00
Young
a3679e6758 simplify the code and prevent float when shifting 2021-06-21 00:24:31 +08:00
zhupr
b6c31540e8 add function to automatically update daily frequency data 2021-06-17 23:07:56 +08:00
zhupr
a4f6e04199 modify dump_update starts with the last end date of each symbol 2021-06-17 22:33:31 +08:00
you-n-g
0aee46ee79 Merge pull request #466 from you-n-g/online_hotfix
Online bug fix, enhancement &  docs for dataset, workflow, trainer ...
2021-06-17 11:38:44 +08:00
Young
9c8d423a86 fix ModelUpdater 2021-06-16 14:10:51 +00:00
zhupr
b4efbd53b2 Fix 'report' compatibility with matplotlib versions 2021-06-16 22:00:43 +08:00
you-n-g
5a50d7c952 Merge pull request #471 from Derek-Wds/main
Update Recorder Wrapper to prevent reinitialization
2021-06-16 17:46:31 +08:00
Jactus
0fe8b281ba Update R wrapper logic 2021-06-16 12:28:20 +08:00
lewwang
5331ab93f8 Update TCTS README. 2021-06-16 12:23:22 +08:00
Jactus
64582e9d46 Add QlibException 2021-06-15 15:02:11 +08:00
Jactus
9e0e2ff736 Update QlibRecorder wrapper 2021-06-15 14:46:31 +08:00
Young
973c4137e4 fix mlflow & task bug 2021-06-12 13:54:26 +00:00
Young
730f6258d6 add warning and * 2021-06-11 10:40:56 +00:00
Young
5850490b24 simplify the code and add docs 2021-06-11 08:29:10 +00:00
Young
d4b36bdab4 Online fix
- Skip duplicated qlib.auto_init()
- Fix TSDatasetH flt_col bug!
- Resolve qlib log attribute confliction
- Trainer API enhancement
- More docs and user-friendly warning
2021-06-11 02:06:07 +00:00
you-n-g
40416d8c30 Merge pull request #464 from lwwang1995/main
Add TCTS baseline.
2021-06-10 10:18:20 +08:00
lewwang
567e42840c asdf 2021-06-09 18:37:25 +08:00
lewwang
65ddca133f asdf 2021-06-09 18:36:12 +08:00
lewwang
d199256d34 asdf 2021-06-09 18:35:14 +08:00
lewwang
073fe4668e asdf 2021-06-09 18:34:31 +08:00
lewwang
89d53853e5 asdf 2021-06-09 18:30:42 +08:00
lewwang
bb6c1572ca asdf 2021-06-09 18:29:55 +08:00
lewwang
4c4e77b11f asdf 2021-06-09 18:28:31 +08:00
lewwang
38c7b7303a dsaf 2021-06-09 18:26:50 +08:00
lewwang
02d0eedd68 update 2021-06-09 18:21:16 +08:00
lewwang
5a3dde93a8 update 2021-06-09 18:15:06 +08:00
lewwang
177f6a59d2 asdf 2021-06-09 17:47:24 +08:00
lewwang
492a62a569 tcts demo page 2021-06-09 17:32:24 +08:00
zhupr
9a44fbf9c1 fix PEP8: qlib/scripts/data_collector/fund/collector.py 2021-06-08 22:52:31 +08:00
zhupr
03eb0882de fix YahooNormalizeCN1minOffline bugs 2021-06-08 22:23:05 +08:00
zhupr
a845a2271b add normalize 1min to use local data && change the default parameters for collecting 1min 2021-06-08 14:45:20 +08:00
you-n-g
ba021f6007 Merge pull request #462 from arisliang/patch-1
Remove non-existing parameter description
2021-06-08 13:03:43 +08:00
al
7d9544fb91 Remove non-existing parameter from doc
Remove non-existing TradeExchange parameter from generate_target_weight_position doc
2021-06-08 09:35:36 +08:00
you-n-g
12b7be333d Merge pull request #461 from Derek-Wds/main
Fix exception hook bug
2021-06-07 21:07:33 +08:00
Jactus
ed54f1213c Fix exception hook bug 2021-06-07 17:13:36 +08:00
zhupr
554b9c7826 fix YahooCollector getting 1min data occasionally missing 2021-06-05 23:43:48 +08:00
zhupr
6f150f3fd6 Add YahooCollector support for extend data 2021-06-04 22:28:42 +08:00
you-n-g
2a0d991d9b Merge pull request #459 from you-n-g/online_srv
fix DelayTrainerRM
2021-06-03 15:55:11 +08:00
lzh222333
1320e53f81 fix DelayTrainerRM 2021-06-03 03:23:48 +00:00
Young
8222795ac4 fix format with black 2021-06-02 09:16:46 +00:00
you-n-g
616a742db7 Merge pull request #435 from you-n-g/online_srv
Multiprocessing support for Online Serving
2021-06-02 17:12:19 +08:00
lzh222333
811d2c975e update & fix 2021-06-02 08:56:15 +00:00
lzh222333
6272ce108f Merge remote-tracking branch 'microsoft/main' into online_srv 2021-06-02 08:32:12 +00:00
you-n-g
64896745d0 Merge pull request #457 from zhupr/fix_XGBoost_predict_error
fix XGBoost predict error
2021-06-02 16:14:18 +08:00
zhupr
b2fe2385d5 fix XGBoost predict error 2021-06-01 21:02:32 +08:00
lzh222333
8d05cd2daf modify tests.config.py 2021-06-01 09:40:53 +00:00
lzh222333
231bdf8608 Merge remote-tracking branch 'microsoft/main' into online_srv 2021-06-01 08:29:02 +00:00
lzh222333
ab6b88ce14 delete useless import 2021-06-01 07:48:14 +00:00
lzh222333
94ab4bbf3f add docs 2021-06-01 07:45:39 +00:00
lzh222333
ca0363ded8 update trainer and manage 2021-05-27 06:04:46 +00:00
lzh222333
a467e10974 Merge remote-tracking branch 'microsoft/main' into online_srv 2021-05-24 05:10:15 +00:00
lzh222333
6dfbf00a23 Merge branch 'microsoft_main' into online_srv 2021-05-24 05:07:53 +00:00
lzh222333
b24af7fff6 multiprocessing support 2021-05-24 05:07:38 +00:00
lwwang1995
45f73361e3 add tcts baseline 2021-03-18 11:17:42 +08:00
81 changed files with 4360 additions and 550 deletions

View File

@@ -11,6 +11,7 @@
Recent released features
| Feature | Status |
| -- | ------ |
| TCTS Model | [Released](https://github.com/microsoft/qlib/pull/491) on July 1, 2021 |
| Online serving and automatic model rolling | :star: [Released](https://github.com/microsoft/qlib/pull/290) on May 17, 2021 |
| DoubleEnsemble Model | [Released](https://github.com/microsoft/qlib/pull/286) on Mar 2, 2021 |
| High-frequency data processing example | [Released](https://github.com/microsoft/qlib/pull/257) on Feb 5, 2021 |
@@ -68,7 +69,7 @@ Your feedbacks about the features are very important.
# Framework of Qlib
<div style="align: center">
<img src="http://fintech.msra.cn/images_v060/framework.png?v=0.1" />
<img src="http://fintech.msra.cn/images_v060/framework.png?v=0.2" />
</div>
@@ -159,6 +160,28 @@ Users could create the same dataset with it.
*Please pay **ATTENTION** that the data is collected from [Yahoo Finance](https://finance.yahoo.com/lookup), and the data might not be perfect.
We recommend users to prepare their own data if they have a high-quality dataset. For more information, users can refer to the [related document](https://qlib.readthedocs.io/en/latest/component/data.html#converting-csv-format-into-qlib-format)*.
### Automatic update of daily frequency data(from yahoo finance)
> It is recommended that users update the data manually once (--trading_date 2021-05-25) and then set it to update automatically.
> For more information refer to: [yahoo collector](https://github.com/microsoft/qlib/tree/main/scripts/data_collector/yahoo#automatic-update-of-daily-frequency-datafrom-yahoo-finance)
* Automatic update of data to the "qlib" directory each trading day(Linux)
* use *crontab*: `crontab -e`
* set up timed tasks:
```
* * * * 1-5 python <script path> update_data_to_bin --qlib_data_1d_dir <user data dir>
```
* **script path**: *scripts/data_collector/yahoo/collector.py*
* Manual update of data
```
python scripts/data_collector/yahoo/collector.py update_data_to_bin --qlib_data_1d_dir <user data dir> --trading_date <start date> --end_date <end date>
```
* *trading_date*: start of trading day
* *end_date*: end of trading day(not included)
<!--
- Run the initialization code and get stock data:
@@ -254,18 +277,19 @@ The automatic workflow may not suit the research workflow of all Quant researche
# [Quant Model Zoo](examples/benchmarks)
Here is a list of models built on `Qlib`.
- [GBDT based on XGBoost (Tianqi Chen, et al. 2016)](qlib/contrib/model/xgboost.py)
- [GBDT based on LightGBM (Guolin Ke, et al. 2017)](qlib/contrib/model/gbdt.py)
- [GBDT based on Catboost (Liudmila Prokhorenkova, et al. 2017)](qlib/contrib/model/catboost_model.py)
- [GBDT based on XGBoost (Tianqi Chen, et al. KDD 2016)](qlib/contrib/model/xgboost.py)
- [GBDT based on LightGBM (Guolin Ke, et al. NIPS 2017)](qlib/contrib/model/gbdt.py)
- [GBDT based on Catboost (Liudmila Prokhorenkova, et al. NIPS 2018)](qlib/contrib/model/catboost_model.py)
- [MLP based on pytorch](qlib/contrib/model/pytorch_nn.py)
- [LSTM based on pytorch (Sepp Hochreiter, et al. 1997)](qlib/contrib/model/pytorch_lstm.py)
- [LSTM based on pytorch (Sepp Hochreiter, et al. Neural Computation 1997)](qlib/contrib/model/pytorch_lstm.py)
- [GRU based on pytorch (Kyunghyun Cho, et al. 2014)](qlib/contrib/model/pytorch_gru.py)
- [ALSTM based on pytorch (Yao Qin, et al. 2017)](qlib/contrib/model/pytorch_alstm.py)
- [ALSTM based on pytorch (Yao Qin, et al. IJCAI 2017)](qlib/contrib/model/pytorch_alstm.py)
- [GATs based on pytorch (Petar Velickovic, et al. 2017)](qlib/contrib/model/pytorch_gats.py)
- [SFM based on pytorch (Liheng Zhang, et al. 2017)](qlib/contrib/model/pytorch_sfm.py)
- [TFT based on tensorflow (Bryan Lim, et al. 2019)](examples/benchmarks/TFT/tft.py)
- [TabNet based on pytorch (Sercan O. Arik, et al. 2019)](qlib/contrib/model/pytorch_tabnet.py)
- [DoubleEnsemble based on LightGBM (Chuheng Zhang, et al. 2020)](qlib/contrib/model/double_ensemble.py)
- [SFM based on pytorch (Liheng Zhang, et al. KDD 2017)](qlib/contrib/model/pytorch_sfm.py)
- [TFT based on tensorflow (Bryan Lim, et al. International Journal of Forecasting 2019)](examples/benchmarks/TFT/tft.py)
- [TabNet based on pytorch (Sercan O. Arik, et al. AAAI 2019)](qlib/contrib/model/pytorch_tabnet.py)
- [DoubleEnsemble based on LightGBM (Chuheng Zhang, et al. ICDM 2020)](qlib/contrib/model/double_ensemble.py)
- [TCTS based on pytorch (Xueqing Wu, et al. ICML 2021)](qlib/contrib/model/pytorch_tcts.py)
Your PR of new Quant models is highly welcomed.

Binary file not shown.

Before

Width:  |  Height:  |  Size: 271 KiB

After

Width:  |  Height:  |  Size: 208 KiB

View File

@@ -67,6 +67,34 @@ After running the above command, users can find china-stock and us-stock data in
When ``Qlib`` is initialized with this dataset, users could build and evaluate their own models with it. Please refer to `Initialization <../start/initialization.html>`_ for more details.
Automatic update of daily frequency data
----------------------------------------
**It is recommended that users update the data manually once (\-\-trading_date 2021-05-25) and then set it to update automatically.**
For more information refer to: `yahoo collector <https://github.com/microsoft/qlib/tree/main/scripts/data_collector/yahoo#automatic-update-of-daily-frequency-datafrom-yahoo-finance>`_
- Automatic update of data to the "qlib" directory each trading day(Linux)
- use *crontab*: `crontab -e`
- set up timed tasks:
.. code-block:: bash
* * * * 1-5 python <script path> update_data_to_bin --qlib_data_1d_dir <user data dir>
- **script path**: *scripts/data_collector/yahoo/collector.py*
- Manual update of data
.. code-block:: bash
python scripts/data_collector/yahoo/collector.py update_data_to_bin --qlib_data_1d_dir <user data dir> --trading_date <start date> --end_date <end date>
- *trading_date*: start of trading day
- *end_date*: end of trading day(not included)
Converting CSV Format into Qlib Format
-------------------------------------------

View File

@@ -90,12 +90,12 @@ Below is a typical config file of ``qrun``.
test: [2017-01-01, 2020-08-01]
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs: {}
module_path: qlib.workflow.record_temp
kwargs: {}
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config
After saving the config into `configuration.yaml`, users could start the workflow and test their ideas with a single command below.
@@ -142,7 +142,7 @@ The meaning of each field is as follows:
- `region`
- If `region` == "us", ``Qlib`` will be initialized in US-stock mode.
- If `region` == "cn", ``Qlib`` will be initialized in china-stock mode.
- If `region` == "cn", ``Qlib`` will be initialized in China-stock mode.
.. note::

View File

@@ -61,7 +61,6 @@ task:
metric: loss
loss: mse
base_model: LSTM
with_pretrain: True
model_path: "benchmarks/LSTM/csi300_lstm_ts.pkl"
GPU: 0
dataset:

View File

@@ -54,7 +54,6 @@ task:
metric: loss
loss: mse
base_model: LSTM
with_pretrain: True
model_path: "benchmarks/LSTM/model_lstm_csi300.pkl"
GPU: 0
dataset:
@@ -81,4 +80,4 @@ task:
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config
config: *port_analysis_config

View File

@@ -4,6 +4,10 @@ Here are the results of each benchmark model running on Qlib's `Alpha360` and `A
The numbers shown below demonstrate the performance of the entire `workflow` of each model. We will update the `workflow` as well as models in the near future for better results.
> If you need to reproduce the results below, please use the **v1** dataset: `python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/qlib_cn_1d --region cn --version v1`
>
> In the new version of qlib, the default dataset is **v2**. Since the data is collected from the YahooFinance API (which is not very stable), the results of *v2* and *v1* may differ.
## Alpha360 dataset
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |
|---|---|---|---|---|---|---|---|---|
@@ -18,6 +22,7 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
| GATs (Petar Velickovic, et al.) | Alpha360 | 0.0475±0.00 | 0.3515±0.02| 0.0592±0.00 | 0.4585±0.01 | 0.0876±0.02 | 1.1513±0.27| -0.0795±0.02 |
| DoubleEnsemble (Chuheng Zhang, et al.) | Alpha360 | 0.0407±0.00| 0.3053±0.00 | 0.0490±0.00 | 0.3840±0.00 | 0.0380±0.02 | 0.5000±0.21 | -0.0984±0.02 |
| TabNet (Sercan O. Arik, et al.)| Alpha360 | 0.0192±0.00 | 0.1401±0.00| 0.0291±0.00 | 0.2163±0.00 | -0.0258±0.00 | -0.2961±0.00| -0.1429±0.00 |
| TCTS (Xueqing Wu, et al.)| Alpha360 | 0.0485±0.00 | 0.3689±0.04| 0.0586±0.00 | 0.4669±0.02 | 0.0816±0.02 | 1.1572±0.30| -0.0689±0.02 |
## Alpha158 dataset
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |

View File

@@ -0,0 +1,52 @@
# Temporally Correlated Task Scheduling for Sequence Learning
We provide the [code](https://github.com/microsoft/qlib/blob/main/qlib/contrib/model/pytorch_tcts.py) for reproducing the stock trend forecasting experiments.
### Background
Sequence learning has attracted much research attention from the machine learning community in recent years. In many applications, a sequence learning task is usually associated with multiple temporally correlated auxiliary tasks, which are different in terms of how much input information to use or which future step to predict. In stock trend forecasting, as demonstrated in Figure1, one can predict the price of a stock in different future days (e.g., tomorrow, the day after tomorrow). In this paper, we propose a framework to make use of those temporally correlated tasks to help each other.
<p align="center">
<img src="task_description.png" width="600" height="200"/>
</p>
### Method
Given that there are usually multiple temporally correlated tasks, the key challenge lies in which tasks to use and when to use them in the training process. In this work, we introduce a learnable task scheduler for sequence learning, which adaptively selects temporally correlated tasks during the training process. The scheduler accesses the model status and the current training data (e.g., in current minibatch), and selects the best auxiliary task to help the training of the main task. The scheduler and the model for the main task are jointly trained through bi-level optimization: the scheduler is trained to maximize the validation performance of the model, and the model is trained to minimize the training loss guided by the scheduler. The process is demonstrated in Figure2.
<p align="center">
<img src="workflow.png"/>
</p>
At step <img src="https://render.githubusercontent.com/render/math?math=s">, with training data <img src="https://render.githubusercontent.com/render/math?math=x_s,y_s">, the scheduler <img src="https://render.githubusercontent.com/render/math?math=\varphi"> chooses a suitable task <img src="https://render.githubusercontent.com/render/math?math=T_{i_s}"> (green solid lines) to update the model <img src="https://render.githubusercontent.com/render/math?math=f"> (blue solid lines). After <img src="https://render.githubusercontent.com/render/math?math=S"> steps, we evaluate the model <img src="https://render.githubusercontent.com/render/math?math=f"> on the validation set and update the scheduler <img src="https://render.githubusercontent.com/render/math?math=\varphi"> (green dashed lines).
### DataSet
* We use the historical transaction data for 300 stocks on [CSI300](http://www.csindex.com.cn/en/indices/index-detail/000300) from 01/01/2008 to 08/01/2020.
* We split the data into training (01/01/2008-12/31/2013), validation (01/01/2014-12/31/2015), and test sets (01/01/2016-08/01/2020) based on the transaction time.
### Experiments
#### Task Description
* The main task <img src="https://render.githubusercontent.com/render/math?math=T_k"> (<img src="https://render.githubusercontent.com/render/math?math=task_k"> in Figure1) refers to forecasting the return of stock <img src="https://render.githubusercontent.com/render/math?math=i"> as follows:
<div align=center>
<img src="https://render.githubusercontent.com/render/math?math=r_{i}^k = \frac{price_i^{t+k}}{price_i^{t+k-1}} - 1">
</div>
* Temporally correlated task sets <img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_k = \{T_1, T_2, ... , T_k\}">, in this paper, <img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_3">, <img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_5"> and <img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_{10}"> are used.
#### Baselines
* GRU/MLP/LightGBM (LGB)/Graph Attention Networks (GAT)
* Multi-task learning (MTL): In multi-task learning, multiple tasks are jointly trained and mutually boosted. Each task is treated equally, while in our setting, we focus on the main task.
* Curriculum transfer learning (CL): Transfer learning also leverages auxiliary tasks to boost the main task. [Curriculum transfer learning](https://arxiv.org/pdf/1804.00810.pdf) is one kind of transfer learning which schedules auxiliary tasks according to certain rules. Our problem can also be regarded as a special kind of transfer learning, where the auxiliary tasks are temporally correlated with the main task. Our learning process is dynamically controlled by a scheduler rather than some pre-defined rules. In the CL baseline, we start from the task <img src="https://render.githubusercontent.com/render/math?math=T_1" >, then <img src="https://render.githubusercontent.com/render/math?math=T_2" >, and gradually move to the last one.
#### Result
| Methods | <img src="https://render.githubusercontent.com/render/math?math=T_1" > | <img src="https://render.githubusercontent.com/render/math?math=T_2"> | <img src="https://render.githubusercontent.com/render/math?math=T_3"> |
| :----: | :----: | :----: | :----: |
| GRU | 0.049 / 1.903 | 0.018 / 1.972 | 0.014 / 1.989 |
| MLP | 0.023 / 1.961 | 0.022 / 1.962 | 0.015 / 1.978 |
| LGB | 0.038 / 1.883 | 0.023 / 1.952 | 0.007 / 1.987 |
| GAT | 0.052 / 1.898 | 0.024 / 1.954 | 0.015 / 1.973 |
| MTL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_3">) | 0.061 / 1.862 | 0.023 / 1.942 | 0.012 / 1.956 |
| CL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_3">) | 0.051 / 1.880 | 0.028 / 1.941 | 0.016 / 1.962 |
| Ours(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_3">) | 0.071 / 1.851 | 0.030 / 1.939 | 0.017 / 1.963 |
| MTL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_5">) | 0.057 / 1.875 | 0.021 / 1.939 | 0.017 / 1.959 |
| CL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_5">) | 0.056 / 1.877 | 0.028 / 1.942 | 0.015 / 1.962 |
| Ours(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_5">) | 0.075 / 1.849 | 0.032 / 1.939 | 0.021 / 1.955 |
| MTL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_{10}">) | 0.052 / 1.882 | 0.020 / 1.947 | 0.019 / 1.952 |
| CL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_{10}">) | 0.051 / 1.882 | 0.028 / 1.950 | 0.016 / 1.961 |
| Ours(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_{10}">) | 0.067 / 1.867 | 0.030 / 1.960 | 0.022 / 1.942|

Binary file not shown.

After

Width:  |  Height:  |  Size: 25 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 29 KiB

View File

@@ -0,0 +1,93 @@
qlib_init:
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
market: &market csi300
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
start_time: 2008-01-01
end_time: 2020-08-01
fit_start_time: 2008-01-01
fit_end_time: 2014-12-31
instruments: *market
infer_processors:
- class: RobustZScoreNorm
kwargs:
fields_group: feature
clip_outlier: true
- class: Fillna
kwargs:
fields_group: feature
learn_processors:
- class: DropnaLabel
- class: CSRankNorm
kwargs:
fields_group: label
label: ["Ref($close, -1) / $close - 1",
"Ref($close, -2) / Ref($close, -1) - 1",
"Ref($close, -3) / Ref($close, -2) - 1"]
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy.strategy
kwargs:
topk: 50
n_drop: 5
backtest:
verbose: False
limit_threshold: 0.095
account: 100000000
benchmark: *benchmark
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
task:
model:
class: TCTS
module_path: qlib.contrib.model.pytorch_tcts
kwargs:
d_feat: 6
hidden_size: 64
num_layers: 2
dropout: 0.0
n_epochs: 200
lr: 1e-3
early_stop: 20
batch_size: 800
metric: loss
loss: mse
GPU: 0
fore_optimizer: adam
weight_optimizer: adam
output_dim: 3
fore_lr: 5e-4
weight_lr: 5e-4
steps: 3
target_label: 1
lowest_valid_performance: 0.993
dataset:
class: DatasetH
module_path: qlib.data.dataset
kwargs:
handler:
class: Alpha360
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:
train: [2008-01-01, 2014-12-31]
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs: {}
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
ana_long_short: False
ann_scaler: 252
label_col: 1
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config

View File

@@ -0,0 +1,81 @@
# Learning Multiple Stock Trading Patterns with Temporal Routing Adaptor and Optimal Transport
This code provides a PyTorch implementation for TRA (Temporal Routing Adaptor), as described in the paper [Learning Multiple Stock Trading Patterns with Temporal Routing Adaptor and Optimal Transport](http://arxiv.org/abs/2106.12950).
* TRA (Temporal Routing Adaptor) is a lightweight module that consists of a set of independent predictors for learning multiple patterns as well as a router to dispatch samples to different predictors.
* We also design a learning algorithm based on Optimal Transport (OT) to obtain the optimal sample to predictor assignment and effectively optimize the router with such assignment through an auxiliary loss term.
# Running TRA
## Requirements
- Install `Qlib` main branch
## Running
We attach our running scripts for the paper in `run.sh`.
And here are two ways to run the model:
* Running from scripts with default parameters
You can directly run from Qlib command `qrun`:
```
qrun configs/config_alstm.yaml
```
* Running from code with self-defined parameters
Setting different parameters is also allowed. See codes in `example.py`:
```
python example.py --config_file configs/config_alstm.yaml
```
Here we trained TRA on a pretrained backbone model. Therefore we run `*_init.yaml` before TRA's scripts.
# Results
## Outputs
After running the scripts, you can find result files in path `./output`:
`info.json` - config settings and result metrics.
`log.csv` - running logs.
`model.bin` - the model parameter dictionary.
`pred.pkl` - the prediction scores and output for inference.
## Our Results
| Methods | MSE| MAE| IC | ICIR | AR | AV | SR | MDD |
|-------------------|-------------------|---------------------|--------------------|--------------------|--------------------|--------------------|--------------------|--------------------|
|Linear|0.163|0.327|0.020|0.132|-3.2%|16.8%|-0.191|32.1%|
|LightGBM|0.160(0.000)|0.323(0.000)|0.041|0.292|7.8%|15.5%|0.503|25.7%|
|MLP|0.160(0.002)|0.323(0.003)|0.037|0.273|3.7%|15.3%|0.264|26.2%|
|SFM|0.159(0.001) |0.321(0.001) |0.047 |0.381 |7.1% |14.3% |0.497 |22.9%|
|ALSTM|0.158(0.001) |0.320(0.001) |0.053 |0.419 |12.3% |13.7% |0.897 |20.2%|
|Trans.|0.158(0.001) |0.322(0.001) |0.051 |0.400 |14.5% |14.2% |1.028 |22.5%|
|ALSTM+TS|0.160(0.002) |0.321(0.002) |0.039 |0.291 |6.7% |14.6% |0.480|22.3%|
|Trans.+TS|0.160(0.004) |0.324(0.005) |0.037 |0.278 |10.4% |14.7% |0.722 |23.7%|
|ALSTM+TRA(Ours)|0.157(0.000) |0.318(0.000) |0.059 |0.460 |12.4% |14.0% |0.885 |20.4%|
|Trans.+TRA(Ours)|0.157(0.000) |0.320(0.000) |0.056 |0.442 |16.1% |14.2% |1.133 |23.1%|
A more detailed demo for our experiment results in the paper can be found in `Report.ipynb`.
# Common Issues
For help or issues using TRA, please submit a GitHub issue.
Sometimes we might encounter a situation where the loss is `NaN`. In that case, please check the `epsilon` parameter in the sinkhorn algorithm; adjusting `epsilon` according to the input's scale is important.
# Citation
If you find this repository useful in your research, please cite:
```
@inproceedings{HengxuKDD2021,
author = {Hengxu Lin and Dong Zhou and Weiqing Liu and Jiang Bian},
title = {Learning Multiple Stock Trading Patterns with Temporal Routing Adaptor and Optimal Transport},
booktitle = {Proceedings of the 27th ACM SIGKDD Conference on Knowledge Discovery \& Data Mining},
series = {KDD '21},
year = {2021},
publisher = {ACM},
}
```

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,63 @@
qlib_init:
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
data_loader_config: &data_loader_config
class: StaticDataLoader
module_path: qlib.data.dataset.loader
kwargs:
config:
feature: data/feature.pkl
label: data/label.pkl
model_config: &model_config
input_size: 16
hidden_size: 256
num_layers: 2
num_heads: 2
use_attn: True
dropout: 0.1
num_states: &num_states 1
tra_config: &tra_config
num_states: *num_states
hidden_size: 16
tau: 1.0
src_info: LR_TPE
task:
model:
class: TRAModel
module_path: src/model.py
kwargs:
lr: 0.0002
n_epochs: 500
max_steps_per_epoch: 100
early_stop: 20
seed: 1000
logdir: output/test/alstm
model_type: LSTM
model_config: *model_config
tra_config: *tra_config
lamb: 1.0
rho: 0.99
freeze_model: False
model_init_state:
dataset:
class: MTSDatasetH
module_path: src/dataset.py
kwargs:
handler:
class: DataHandler
module_path: qlib.data.dataset.handler
kwargs:
data_loader: *data_loader_config
segments:
train: [2007-10-30, 2016-05-27]
valid: [2016-09-26, 2018-05-29]
test: [2018-09-21, 2020-06-30]
seq_len: 60
horizon: 21
num_states: *num_states
batch_size: 1024

View File

@@ -0,0 +1,63 @@
qlib_init:
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
data_loader_config: &data_loader_config
class: StaticDataLoader
module_path: qlib.data.dataset.loader
kwargs:
config:
feature: data/feature.pkl
label: data/label.pkl
model_config: &model_config
input_size: 16
hidden_size: 256
num_layers: 2
num_heads: 2
use_attn: True
dropout: 0.1
num_states: &num_states 10
tra_config: &tra_config
num_states: *num_states
hidden_size: 16
tau: 1.0
src_info: LR_TPE
task:
model:
class: TRAModel
module_path: src/model.py
kwargs:
lr: 0.0001
n_epochs: 500
max_steps_per_epoch: 100
early_stop: 20
seed: 1000
logdir: output/test/alstm_tra
model_type: LSTM
model_config: *model_config
tra_config: *tra_config
lamb: 2.0
rho: 0.99
freeze_model: True
model_init_state: output/test/alstm_tra_init/model.bin
dataset:
class: MTSDatasetH
module_path: src/dataset.py
kwargs:
handler:
class: DataHandler
module_path: qlib.data.dataset.handler
kwargs:
data_loader: *data_loader_config
segments:
train: [2007-10-30, 2016-05-27]
valid: [2016-09-26, 2018-05-29]
test: [2018-09-21, 2020-06-30]
seq_len: 60
horizon: 21
num_states: *num_states
batch_size: 1024

View File

@@ -0,0 +1,63 @@
# ALSTM initialisation run: trains the LSTM backbone together with a 3-state
# router from scratch. Its `logdir` is where the `alstm_tra` config expects
# to find `model.bin`.
qlib_init:
  provider_uri: "~/.qlib/qlib_data/cn_data"
  region: cn

# Anchor reused as the `data_loader` of the dataset handler below.
data_loader_config: &data_loader_config
  class: StaticDataLoader
  module_path: qlib.data.dataset.loader
  kwargs:
    config:
      feature: data/feature.pkl
      label: data/label.pkl

# Keyword arguments of the backbone selected by `model_type`.
model_config: &model_config
  input_size: 16
  hidden_size: 256
  num_layers: 2
  num_heads: 2
  use_attn: True
  dropout: 0.1

# Shared by the TRA router and the dataset's memory columns; keep them in sync.
num_states: &num_states 3

tra_config: &tra_config
  num_states: *num_states
  hidden_size: 16
  tau: 1.0
  src_info: LR_TPE

task:
  model:
    class: TRAModel
    module_path: src/model.py
    kwargs:
      lr: 0.0002
      n_epochs: 500
      max_steps_per_epoch: 100
      early_stop: 20
      seed: 1000
      logdir: output/test/alstm_tra_init
      model_type: LSTM
      model_config: *model_config
      tra_config: *tra_config
      # Regularizer weight at step t is lamb * rho ** t.
      lamb: 1.0
      rho: 0.99
      freeze_model: False
      # Empty value: no pre-trained backbone is loaded.
      model_init_state:
  dataset:
    class: MTSDatasetH
    module_path: src/dataset.py
    kwargs:
      handler:
        class: DataHandler
        module_path: qlib.data.dataset.handler
        kwargs:
          data_loader: *data_loader_config
      segments:
        train: [2007-10-30, 2016-05-27]
        valid: [2016-09-26, 2018-05-29]
        test: [2018-09-21, 2020-06-30]
      seq_len: 60
      # Must be > 0; the dataset asserts this to avoid label leakage.
      horizon: 21
      num_states: *num_states
      batch_size: 512

View File

@@ -0,0 +1,63 @@
# Transformer baseline: `num_states` is 1, so TRA reduces to a single linear
# predictor and no router is built.
qlib_init:
  provider_uri: "~/.qlib/qlib_data/cn_data"
  region: cn

# Anchor reused as the `data_loader` of the dataset handler below.
data_loader_config: &data_loader_config
  class: StaticDataLoader
  module_path: qlib.data.dataset.loader
  kwargs:
    config:
      feature: data/feature.pkl
      label: data/label.pkl

# Keyword arguments of the backbone selected by `model_type`.
model_config: &model_config
  input_size: 16
  hidden_size: 64
  num_layers: 2
  num_heads: 4
  use_attn: False
  dropout: 0.1

# Shared by the TRA module and the dataset's memory columns; keep them in sync.
num_states: &num_states 1

tra_config: &tra_config
  num_states: *num_states
  hidden_size: 16
  tau: 1.0
  src_info: LR_TPE

task:
  model:
    class: TRAModel
    module_path: src/model.py
    kwargs:
      lr: 0.0002
      n_epochs: 500
      max_steps_per_epoch: 100
      early_stop: 20
      seed: 1000
      logdir: output/test/transformer
      model_type: Transformer
      model_config: *model_config
      tra_config: *tra_config
      # Only used when a router exists (num_states > 1).
      lamb: 1.0
      rho: 0.99
      freeze_model: False
      # Empty value: no pre-trained backbone is loaded.
      model_init_state:
  dataset:
    class: MTSDatasetH
    module_path: src/dataset.py
    kwargs:
      handler:
        class: DataHandler
        module_path: qlib.data.dataset.handler
        kwargs:
          data_loader: *data_loader_config
      segments:
        train: [2007-10-30, 2016-05-27]
        valid: [2016-09-26, 2018-05-29]
        test: [2018-09-21, 2020-06-30]
      seq_len: 60
      # Must be > 0; the dataset asserts this to avoid label leakage.
      horizon: 21
      num_states: *num_states
      batch_size: 1024

View File

@@ -0,0 +1,63 @@
# Transformer + TRA run: 3-state router on top of a Transformer backbone that
# is loaded from the `transformer_tra_init` run and kept frozen.
qlib_init:
  provider_uri: "~/.qlib/qlib_data/cn_data"
  region: cn

# Anchor reused as the `data_loader` of the dataset handler below.
data_loader_config: &data_loader_config
  class: StaticDataLoader
  module_path: qlib.data.dataset.loader
  kwargs:
    config:
      feature: data/feature.pkl
      label: data/label.pkl

# Keyword arguments of the backbone selected by `model_type`.
model_config: &model_config
  input_size: 16
  hidden_size: 64
  num_layers: 2
  num_heads: 4
  use_attn: False
  dropout: 0.1

# Shared by the TRA router and the dataset's memory columns; keep them in sync.
num_states: &num_states 3

tra_config: &tra_config
  num_states: *num_states
  hidden_size: 16
  tau: 1.0
  src_info: LR_TPE

task:
  model:
    class: TRAModel
    module_path: src/model.py
    kwargs:
      lr: 0.0005
      n_epochs: 500
      max_steps_per_epoch: 100
      early_stop: 20
      seed: 1000
      logdir: output/test/transformer_tra
      model_type: Transformer
      model_config: *model_config
      tra_config: *tra_config
      # Regularizer weight at step t is lamb * rho ** t.
      lamb: 1.0
      rho: 0.99
      # Backbone weights come from `model_init_state` and are not trained further.
      freeze_model: True
      model_init_state: output/test/transformer_tra_init/model.bin
  dataset:
    class: MTSDatasetH
    module_path: src/dataset.py
    kwargs:
      handler:
        class: DataHandler
        module_path: qlib.data.dataset.handler
        kwargs:
          data_loader: *data_loader_config
      segments:
        train: [2007-10-30, 2016-05-27]
        valid: [2016-09-26, 2018-05-29]
        test: [2018-09-21, 2020-06-30]
      seq_len: 60
      # Must be > 0; the dataset asserts this to avoid label leakage.
      horizon: 21
      num_states: *num_states
      batch_size: 512

View File

@@ -0,0 +1,63 @@
# Transformer initialisation run: trains the Transformer backbone together
# with a 3-state router from scratch. Its `logdir` is where the
# `transformer_tra` config expects to find `model.bin`.
qlib_init:
  provider_uri: "~/.qlib/qlib_data/cn_data"
  region: cn

# Anchor reused as the `data_loader` of the dataset handler below.
data_loader_config: &data_loader_config
  class: StaticDataLoader
  module_path: qlib.data.dataset.loader
  kwargs:
    config:
      feature: data/feature.pkl
      label: data/label.pkl

# Keyword arguments of the backbone selected by `model_type`.
model_config: &model_config
  input_size: 16
  hidden_size: 64
  num_layers: 2
  num_heads: 4
  use_attn: False
  dropout: 0.1

# Shared by the TRA router and the dataset's memory columns; keep them in sync.
num_states: &num_states 3

tra_config: &tra_config
  num_states: *num_states
  hidden_size: 16
  tau: 1.0
  src_info: LR_TPE

task:
  model:
    class: TRAModel
    module_path: src/model.py
    kwargs:
      lr: 0.0002
      n_epochs: 500
      max_steps_per_epoch: 100
      early_stop: 20
      seed: 1000
      logdir: output/test/transformer_tra_init
      model_type: Transformer
      model_config: *model_config
      tra_config: *tra_config
      # Regularizer weight at step t is lamb * rho ** t.
      lamb: 1.0
      rho: 0.99
      freeze_model: False
      # Empty value: no pre-trained backbone is loaded.
      model_init_state:
  dataset:
    class: MTSDatasetH
    module_path: src/dataset.py
    kwargs:
      handler:
        class: DataHandler
        module_path: qlib.data.dataset.handler
        kwargs:
          data_loader: *data_loader_config
      segments:
        train: [2007-10-30, 2016-05-27]
        valid: [2016-09-26, 2018-05-29]
        test: [2018-09-21, 2020-06-30]
      seq_len: 60
      # Must be > 0; the dataset asserts this to avoid label leakage.
      horizon: 21
      num_states: *num_states
      batch_size: 512

View File

@@ -0,0 +1 @@
Data Link: https://drive.google.com/drive/folders/1fMqZYSeLyrHiWmVzygeI4sw3vp5Gt8cY?usp=sharing

View File

@@ -0,0 +1,39 @@
import argparse
import qlib
import ruamel.yaml as yaml
from qlib.utils import init_instance_by_config
def main(seed, config_file="configs/config_alstm.yaml"):
    """Train the model described by a YAML workflow config.

    Args:
        seed (int): random seed written into the model kwargs.
        config_file (str): path of the YAML config to load.
    """
    with open(config_file) as stream:
        config = yaml.safe_load(stream)

    # The log directory currently gets no per-seed suffix.
    seed_suffix = ""
    model_kwargs = config["task"]["model"]["kwargs"]
    model_kwargs.update({"seed": seed, "logdir": model_kwargs["logdir"] + seed_suffix})

    # Initialise qlib with the data provider named in the config.
    init_cfg = config["qlib_init"]
    qlib.init(provider_uri=init_cfg["provider_uri"], region=init_cfg["region"])

    # Build dataset first, then the model, exactly as configured.
    task_cfg = config["task"]
    dataset = init_instance_by_config(task_cfg["dataset"])
    model = init_instance_by_config(task_cfg["model"])

    model.fit(dataset)
if __name__ == "__main__":
    # Read the run parameters from the command line and start training.
    cli = argparse.ArgumentParser(allow_abbrev=False)
    cli.add_argument("--seed", type=int, default=1000, help="random seed")
    cli.add_argument("--config_file", type=str, default="configs/config_alstm.yaml", help="config file")
    options = cli.parse_args()
    main(seed=options.seed, config_file=options.config_file)

View File

@@ -0,0 +1,29 @@
#!/bin/bash
# Reproduces the TRA benchmark runs.
# The experiments used the random seeds 1, 1000, 2000, 3000, 4000 and 5000.
#
# Ordering matters: each `*_tra.yaml` config loads `model.bin` from the output
# directory of the matching `*_tra_init.yaml` run, so the init config must be
# run first.

# Option 1: run each config directly with the Qlib command `qrun`.
qrun configs/config_alstm.yaml
qrun configs/config_transformer.yaml
qrun configs/config_transformer_tra_init.yaml
qrun configs/config_transformer_tra.yaml
qrun configs/config_alstm_tra_init.yaml
qrun configs/config_alstm_tra.yaml

# Option 2: run through example.py, which also accepts `--seed` to override
# the seed stored in the config.
python example.py --config_file configs/config_alstm.yaml
python example.py --config_file configs/config_transformer.yaml
python example.py --config_file configs/config_transformer_tra_init.yaml
python example.py --config_file configs/config_transformer_tra.yaml
python example.py --config_file configs/config_alstm_tra_init.yaml
python example.py --config_file configs/config_alstm_tra.yaml

View File

@@ -0,0 +1,253 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import copy
import torch
import numpy as np
import pandas as pd
from qlib.utils import init_instance_by_config
from qlib.data.dataset import DatasetH, DataHandler
device = "cuda" if torch.cuda.is_available() else "cpu"
def _to_tensor(x):
if not isinstance(x, torch.Tensor):
return torch.tensor(x, dtype=torch.float, device=device)
return x
def _create_ts_slices(index, seq_len):
"""
create time series slices from pandas index
Args:
index (pd.MultiIndex): pandas multiindex with <instrument, datetime> order
seq_len (int): sequence length
"""
assert index.is_lexsorted(), "index should be sorted"
# number of dates for each code
sample_count_by_codes = pd.Series(0, index=index).groupby(level=0).size().values
# start_index for each code
start_index_of_codes = np.roll(np.cumsum(sample_count_by_codes), 1)
start_index_of_codes[0] = 0
# all the [start, stop) indices of features
# features btw [start, stop) are used to predict the `stop - 1` label
slices = []
for cur_loc, cur_cnt in zip(start_index_of_codes, sample_count_by_codes):
for stop in range(1, cur_cnt + 1):
end = cur_loc + stop
start = max(end - seq_len, 0)
slices.append(slice(start, end))
slices = np.array(slices)
return slices
def _get_date_parse_fn(target):
"""get date parse function
This method is used to parse date arguments as target type.
Example:
get_date_parse_fn('20120101')('2017-01-01') => '20170101'
get_date_parse_fn(20120101)('2017-01-01') => 20170101
"""
if isinstance(target, pd.Timestamp):
_fn = lambda x: pd.Timestamp(x) # Timestamp('2020-01-01')
elif isinstance(target, str) and len(target) == 8:
_fn = lambda x: str(x).replace("-", "")[:8] # '20200201'
elif isinstance(target, int):
_fn = lambda x: int(str(x).replace("-", "")[:8]) # 20200201
else:
_fn = lambda x: x
return _fn
class MTSDatasetH(DatasetH):
    """Memory Augmented Time Series Dataset

    Every feature row is extended with ``num_states`` extra "memory" columns
    (initialised to zero) that the model overwrites through `assign_data`.
    Iterating the dataset yields batches of fixed-length windows as dicts
    with the keys ``"data"``, ``"label"`` and ``"index"``.

    Args:
        handler (DataHandler): data handler
        segments (dict): data split segments
        seq_len (int): time series sequence length
        horizon (int): label horizon (to mask historical loss for TRA);
            must be > 0
        num_states (int): how many memory states to be added (for TRA)
        batch_size (int): number of samples per batch; a negative value
            switches to daily batches holding ``-batch_size`` days each
        shuffle (bool): whether shuffle data
        pin_memory (bool): whether to keep data, labels and padding as tensors
            on the module-level ``device`` instead of numpy arrays
        drop_last (bool): whether drop last batch < batch_size
    """

    def __init__(
        self,
        handler,
        segments,
        seq_len=60,
        horizon=0,
        num_states=1,
        batch_size=-1,
        shuffle=True,
        pin_memory=False,
        drop_last=False,
        **kwargs
    ):

        # The default `horizon=0` is rejected on purpose: callers must state it.
        assert horizon > 0, "please specify `horizon` to avoid data leakage"

        self.seq_len = seq_len
        self.horizon = horizon
        self.num_states = num_states
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.pin_memory = pin_memory
        self.params = (batch_size, drop_last, shuffle)  # for train/eval switch

        # Attributes are set before the base constructor runs; `setup_data`
        # below reads them - presumably it is invoked from there (TODO confirm).
        super().__init__(handler, segments, **kwargs)

    def setup_data(self, handler_kwargs: dict = None, **kwargs):
        """Load handler data and precompute per-sample and per-day slices.

        NOTE: `handler_kwargs` and `**kwargs` are accepted but not forwarded
        to the base implementation.
        """

        super().setup_data()

        # change index to <code, date>
        # NOTE: we will use inplace sort to reduce memory use
        # (this mutates the handler's own dataframe)
        df = self.handler._data
        df.index = df.index.swaplevel()
        df.sort_index(inplace=True)

        self._data = df["feature"].values.astype("float32")
        self._label = df["label"].squeeze().astype("float32")
        self._index = df.index

        # add memory to feature
        self._data = np.c_[self._data, np.zeros((len(self._data), self.num_states), dtype=np.float32)]

        # padding tensor
        self.zeros = np.zeros((self.seq_len, self._data.shape[1]), dtype=np.float32)

        # pin memory
        if self.pin_memory:
            self._data = _to_tensor(self._data)
            self._label = _to_tensor(self._label)
            self.zeros = _to_tensor(self.zeros)

        # create batch slices
        self.batch_slices = _create_ts_slices(self._index, self.seq_len)

        # create daily slices: group the per-sample slices by the date of
        # their last row, ordered by date
        index = [slc.stop - 1 for slc in self.batch_slices]
        act_index = self.restore_index(index)
        daily_slices = {date: [] for date in sorted(act_index.unique(level=1))}
        for i, (code, date) in enumerate(act_index):
            daily_slices[date].append(self.batch_slices[i])
        self.daily_slices = list(daily_slices.values())

    def _prepare_seg(self, slc, **kwargs):
        """Return a shallow copy restricted to samples dated within ``slc``.

        A sample belongs to the segment when the date of its last row lies in
        the closed interval ``[slc.start, slc.stop]``. The copy shares
        `_data`, `_label` and `_index` with this object.
        """
        # Cast the segment bounds to the type of the dates stored in the index.
        fn = _get_date_parse_fn(self._index[0][1])
        start_date = fn(slc.start)
        end_date = fn(slc.stop)
        obj = copy.copy(self)  # shallow copy
        # NOTE: per the original author, the serializable base disables copying
        # of `self._data`, so the arrays are assigned manually here
        obj._data = self._data
        obj._label = self._label
        obj._index = self._index
        new_batch_slices = []
        for batch_slc in self.batch_slices:
            date = self._index[batch_slc.stop - 1][1]
            if start_date <= date <= end_date:
                new_batch_slices.append(batch_slc)
        obj.batch_slices = np.array(new_batch_slices)
        new_daily_slices = []
        for daily_slc in self.daily_slices:
            # all slices of one day share the date, so the first one suffices
            date = self._index[daily_slc[0].stop - 1][1]
            if start_date <= date <= end_date:
                new_daily_slices.append(daily_slc)
        obj.daily_slices = new_daily_slices
        return obj

    def restore_index(self, index):
        """Map positional row numbers back to <code, date> index entries."""
        if isinstance(index, torch.Tensor):
            index = index.cpu().numpy()
        return self._index[index]

    def assign_data(self, index, vals):
        """Write ``vals`` into the memory columns of the rows at ``index``."""
        if isinstance(self._data, torch.Tensor):
            vals = _to_tensor(vals)
        elif isinstance(vals, torch.Tensor):
            # numpy storage: bring both values and positions to numpy
            vals = vals.detach().cpu().numpy()
            index = index.detach().cpu().numpy()
        self._data[index, -self.num_states :] = vals

    def clear_memory(self):
        """Reset the memory columns of every row to zero."""
        self._data[:, -self.num_states :] = 0

    # TODO: better train/eval mode design
    def train(self):
        """Enable training mode: restore the constructor's batching settings."""
        self.batch_size, self.drop_last, self.shuffle = self.params

    def eval(self):
        """Enable evaluation mode: one day per batch, no shuffling, keep all."""
        self.batch_size = -1
        self.drop_last = False
        self.shuffle = False

    def _get_slices(self):
        """Return a copy of the active slice collection and the batch size.

        A negative `batch_size` selects the daily slices and its absolute
        value becomes the number of days per batch.
        """
        if self.batch_size < 0:
            slices = self.daily_slices.copy()
            batch_size = -1 * self.batch_size
        else:
            slices = self.batch_slices.copy()
            batch_size = self.batch_size
        return slices, batch_size

    def __len__(self):
        """Number of batches one pass over the dataset yields."""
        slices, batch_size = self._get_slices()
        if self.drop_last:
            return len(slices) // batch_size
        return (len(slices) + batch_size - 1) // batch_size

    def __iter__(self):
        """Yield batches as dicts of ``data``, ``label`` and ``index`` tensors."""
        slices, batch_size = self._get_slices()
        if self.shuffle:
            np.random.shuffle(slices)

        for i in range(len(slices))[::batch_size]:
            if self.drop_last and i + batch_size > len(slices):
                break
            # get slices for this batch
            slices_subset = slices[i : i + batch_size]
            if self.batch_size < 0:
                # daily mode: flatten the per-day lists into one list of slices
                slices_subset = np.concatenate(slices_subset)
            # collect data
            data = []
            label = []
            index = []
            for slc in slices_subset:
                # copy so the masking below does not touch the stored data
                _data = self._data[slc].clone() if self.pin_memory else self._data[slc].copy()
                if len(_data) != self.seq_len:
                    # left-pad short windows with zero rows up to `seq_len`
                    if self.pin_memory:
                        _data = torch.cat([self.zeros[: self.seq_len - len(_data)], _data], axis=0)
                    else:
                        _data = np.concatenate([self.zeros[: self.seq_len - len(_data)], _data], axis=0)
                if self.num_states > 0:
                    # hide the memory of the last `horizon` steps in the window
                    _data[-self.horizon :, -self.num_states :] = 0
                data.append(_data)
                label.append(self._label[slc.stop - 1])
                index.append(slc.stop - 1)
            # stack the samples into batch tensors
            index = torch.tensor(index, device=device)
            if isinstance(data[0], torch.Tensor):
                data = torch.stack(data)
                label = torch.stack(label)
            else:
                data = _to_tensor(np.stack(data))
                label = _to_tensor(np.stack(label))
            # yield -> generator
            yield {"data": data, "label": label, "index": index}

View File

@@ -0,0 +1,603 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
import copy
import math
import json
import collections
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from tqdm import tqdm
from qlib.utils import get_or_create_path
from qlib.log import get_module_logger
from qlib.model.base import Model
device = "cuda" if torch.cuda.is_available() else "cpu"
class TRAModel(Model):
    """Backbone network plus Temporal Routing Adaptor, with its training loop.

    Args:
        model_config (dict): keyword arguments of the backbone network
        tra_config (dict): keyword arguments of `TRA` (besides the input size)
        model_type (str): name of the backbone class, e.g. "LSTM" or "Transformer"
        lr (float): learning rate of the Adam optimizer
        n_epochs (int): maximum number of training epochs
        early_stop (int): stop after this many epochs without a better valid IC
        smooth_steps (int): number of most recent epoch checkpoints that are
            averaged before each evaluation
        max_steps_per_epoch (int or None): cap on training batches per epoch;
            `None` means one full pass over the training set
        freeze_model (bool): disable gradients of the backbone parameters
        model_init_state (str or None): checkpoint file whose "model" entry
            is loaded into the backbone
        lamb (float): initial weight of the routing regularizer
        rho (float): per-step decay of that weight (`lamb * rho ** global_step`)
        seed (int): seed for numpy and torch
        logdir (str or None): directory for logs, checkpoint and predictions
        eval_train (bool): evaluate on the training set after every epoch
            (always done when `num_states > 1`)
        eval_test (bool): evaluate on the test set after every epoch
        avg_params (bool): stored on the instance; not consulted by the code
            in this class (parameters are always averaged)
    """

    def __init__(
        self,
        model_config,
        tra_config,
        model_type="LSTM",
        lr=1e-3,
        n_epochs=500,
        early_stop=50,
        smooth_steps=5,
        max_steps_per_epoch=None,
        freeze_model=False,
        model_init_state=None,
        lamb=0.0,
        rho=0.99,
        seed=0,
        logdir=None,
        eval_train=True,
        eval_test=False,
        avg_params=True,
        **kwargs,
    ):

        np.random.seed(seed)
        torch.manual_seed(seed)

        self.logger = get_module_logger("TRA")
        self.logger.info("TRA Model...")

        # SECURITY NOTE: `eval` executes `model_type` as Python code. The value
        # comes from the workflow config; only run configs you trust.
        self.model = eval(model_type)(**model_config).to(device)
        if model_init_state:
            self.model.load_state_dict(torch.load(model_init_state, map_location="cpu")["model"])
        if freeze_model:
            for param in self.model.parameters():
                param.requires_grad_(False)
        else:
            self.logger.info("# model params: %d" % sum([p.numel() for p in self.model.parameters()]))

        self.tra = TRA(self.model.output_size, **tra_config).to(device)
        self.logger.info("# tra params: %d" % sum([p.numel() for p in self.tra.parameters()]))

        self.optimizer = optim.Adam(list(self.model.parameters()) + list(self.tra.parameters()), lr=lr)

        self.model_config = model_config
        self.tra_config = tra_config
        self.lr = lr
        self.n_epochs = n_epochs
        self.early_stop = early_stop
        self.smooth_steps = smooth_steps
        self.max_steps_per_epoch = max_steps_per_epoch
        self.lamb = lamb
        self.rho = rho
        self.seed = seed
        self.logdir = logdir
        self.eval_train = eval_train
        self.eval_test = eval_test
        # NOTE: kept for config compatibility; nothing in this class reads it.
        self.avg_params = avg_params

        if self.tra.num_states > 1 and not self.eval_train:
            # `Logger.warn` is deprecated in favour of `Logger.warning`.
            self.logger.warning("`eval_train` will be ignored when using TRA")

        if self.logdir is not None:
            if os.path.exists(self.logdir):
                self.logger.warning(f"logdir {self.logdir} is not empty")
            os.makedirs(self.logdir, exist_ok=True)

        self.fitted = False
        self.global_step = -1

    def train_epoch(self, data_set):
        """Run one training epoch and return the accumulated loss.

        At most `max_steps_per_epoch` batches are consumed (all batches when
        it is `None`). After every batch the per-predictor squared errors are
        written back into the dataset's memory columns.

        Returns:
            float: sum of batch losses divided by the number of samples seen
            (0 when the dataset produced no batch).
        """

        self.model.train()
        self.tra.train()

        data_set.train()

        # The step budget is bounded by the number of batches, not by the
        # number of epochs (the two are unrelated quantities).
        max_steps = len(data_set)
        if self.max_steps_per_epoch is not None:
            max_steps = min(self.max_steps_per_epoch, max_steps)

        count = 0
        total_loss = 0
        total_count = 0
        for batch in tqdm(data_set, total=max_steps):
            count += 1
            if count > max_steps:
                break

            self.global_step += 1

            data, label, index = batch["data"], batch["label"], batch["index"]

            # The trailing `num_states` columns are the memory of past losses;
            # the last `horizon` steps are excluded from the router input.
            feature = data[:, :, : -self.tra.num_states]
            hist_loss = data[:, : -data_set.horizon, -self.tra.num_states :]

            hidden = self.model(feature)
            pred, all_preds, prob = self.tra(hidden, hist_loss)

            loss = (pred - label).pow(2).mean()

            L = (all_preds.detach() - label[:, None]).pow(2)
            L -= L.min(dim=-1, keepdim=True).values  # normalize & ensure positive input

            data_set.assign_data(index, L)  # save loss to memory

            if prob is not None:
                P = sinkhorn(-L, epsilon=0.01)  # sample assignment matrix
                lamb = self.lamb * (self.rho ** self.global_step)
                reg = prob.log().mul(P).sum(dim=-1).mean()
                loss = loss - lamb * reg

            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()

            total_loss += loss.item()
            total_count += len(pred)

        # Guard against an empty dataset, which would divide by zero.
        if total_count > 0:
            total_loss /= total_count

        return total_loss

    def test_epoch(self, data_set, return_pred=False):
        """Evaluate on `data_set` and refresh its memory columns.

        Metrics are computed per batch (one day per batch in the dataset's
        evaluation mode) and then averaged.

        Returns:
            tuple: `(metrics, preds)` where `metrics` is a dict with the keys
            "MSE", "MAE", "IC" and "ICIR"; `preds` is a DataFrame of
            predictions when `return_pred` is true, else an empty list.
        """

        self.model.eval()
        self.tra.eval()
        data_set.eval()

        preds = []
        metrics = []
        for batch in tqdm(data_set):
            data, label, index = batch["data"], batch["label"], batch["index"]

            feature = data[:, :, : -self.tra.num_states]
            hist_loss = data[:, : -data_set.horizon, -self.tra.num_states :]

            with torch.no_grad():
                hidden = self.model(feature)
                pred, all_preds, prob = self.tra(hidden, hist_loss)

            L = (all_preds - label[:, None]).pow(2)

            L -= L.min(dim=-1, keepdim=True).values  # normalize & ensure positive input

            data_set.assign_data(index, L)  # save loss to memory

            X = np.c_[
                pred.cpu().numpy(),
                label.cpu().numpy(),
            ]
            columns = ["score", "label"]
            if prob is not None:
                # Also record every predictor's output and routing probability.
                X = np.c_[X, all_preds.cpu().numpy(), prob.cpu().numpy()]
                columns += ["score_%d" % d for d in range(all_preds.shape[1])] + [
                    "prob_%d" % d for d in range(all_preds.shape[1])
                ]

            pred = pd.DataFrame(X, index=index.cpu().numpy(), columns=columns)

            metrics.append(evaluate(pred))

            if return_pred:
                preds.append(pred)

        metrics = pd.DataFrame(metrics)
        metrics = {
            "MSE": metrics.MSE.mean(),
            "MAE": metrics.MAE.mean(),
            "IC": metrics.IC.mean(),
            "ICIR": metrics.IC.mean() / metrics.IC.std(),
        }

        if return_pred:
            preds = pd.concat(preds, axis=0)
            # Replace positional row numbers by the dataset's index entries
            # and swap the two levels back before sorting.
            preds.index = data_set.restore_index(preds.index)
            preds.index = preds.index.swaplevel()
            preds.sort_index(inplace=True)

        return metrics, preds

    def fit(self, dataset, evals_result=None):
        """Train with early stopping on the validation IC.

        Args:
            dataset: object whose `prepare(["train", "valid", "test"])` returns
                the three splits
            evals_result (dict or None): filled in place with per-epoch metric
                lists under "train", "valid" and "test"; a fresh dict is used
                when omitted
        """
        # A `dict()` default would be shared (and overwritten) across calls.
        if evals_result is None:
            evals_result = {}

        train_set, valid_set, test_set = dataset.prepare(["train", "valid", "test"])

        best_score = -1
        best_epoch = 0
        stop_rounds = 0
        best_params = {
            "model": copy.deepcopy(self.model.state_dict()),
            "tra": copy.deepcopy(self.tra.state_dict()),
        }
        # Sliding window of the latest checkpoints used for averaging.
        params_list = {
            "model": collections.deque(maxlen=self.smooth_steps),
            "tra": collections.deque(maxlen=self.smooth_steps),
        }
        evals_result["train"] = []
        evals_result["valid"] = []
        evals_result["test"] = []

        # train
        self.fitted = True
        self.global_step = -1

        if self.tra.num_states > 1:
            # Fill the memory columns once before the router sees them.
            self.logger.info("init memory...")
            self.test_epoch(train_set)

        for epoch in range(self.n_epochs):
            self.logger.info("Epoch %d:", epoch)

            self.logger.info("training...")
            self.train_epoch(train_set)

            self.logger.info("evaluating...")
            # average params for inference
            params_list["model"].append(copy.deepcopy(self.model.state_dict()))
            params_list["tra"].append(copy.deepcopy(self.tra.state_dict()))
            self.model.load_state_dict(average_params(params_list["model"]))
            self.tra.load_state_dict(average_params(params_list["tra"]))

            # NOTE: during evaluating, the whole memory will be refreshed
            if self.tra.num_states > 1 or self.eval_train:
                train_set.clear_memory()  # NOTE: clear the shared memory
                train_metrics = self.test_epoch(train_set)[0]
                evals_result["train"].append(train_metrics)
                self.logger.info("\ttrain metrics: %s" % train_metrics)

            valid_metrics = self.test_epoch(valid_set)[0]
            evals_result["valid"].append(valid_metrics)
            self.logger.info("\tvalid metrics: %s" % valid_metrics)

            if self.eval_test:
                test_metrics = self.test_epoch(test_set)[0]
                evals_result["test"].append(test_metrics)
                self.logger.info("\ttest metrics: %s" % test_metrics)

            if valid_metrics["IC"] > best_score:
                best_score = valid_metrics["IC"]
                stop_rounds = 0
                best_epoch = epoch
                # The averaged parameters are what gets checkpointed.
                best_params = {
                    "model": copy.deepcopy(self.model.state_dict()),
                    "tra": copy.deepcopy(self.tra.state_dict()),
                }
            else:
                stop_rounds += 1
                if stop_rounds >= self.early_stop:
                    self.logger.info("early stop @ %s" % epoch)
                    break

            # restore the un-averaged parameters before training continues
            self.model.load_state_dict(params_list["model"][-1])
            self.tra.load_state_dict(params_list["tra"][-1])

        self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
        self.model.load_state_dict(best_params["model"])
        self.tra.load_state_dict(best_params["tra"])

        metrics, preds = self.test_epoch(test_set, return_pred=True)
        self.logger.info("test metrics: %s" % metrics)

        if self.logdir:
            self.logger.info("save model & pred to local directory")

            pd.concat({name: pd.DataFrame(evals_result[name]) for name in evals_result}, axis=1).to_csv(
                self.logdir + "/logs.csv", index=False
            )

            torch.save(best_params, self.logdir + "/model.bin")

            preds.to_pickle(self.logdir + "/pred.pkl")

            info = {
                "config": {
                    "model_config": self.model_config,
                    "tra_config": self.tra_config,
                    "lr": self.lr,
                    "n_epochs": self.n_epochs,
                    "early_stop": self.early_stop,
                    "smooth_steps": self.smooth_steps,
                    "max_steps_per_epoch": self.max_steps_per_epoch,
                    "lamb": self.lamb,
                    "rho": self.rho,
                    "seed": self.seed,
                    "logdir": self.logdir,
                },
                "best_eval_metric": -best_score,  # NOTE: negated so that lower is better
                "metric": metrics,
            }
            with open(self.logdir + "/info.json", "w") as f:
                json.dump(info, f)

    def predict(self, dataset, segment="test"):
        """Return the prediction DataFrame for one segment of `dataset`.

        Raises:
            ValueError: if `fit` has not been called yet.
        """
        if not self.fitted:
            raise ValueError("model is not fitted yet!")

        test_set = dataset.prepare(segment)

        metrics, preds = self.test_epoch(test_set, return_pred=True)
        self.logger.info("test metrics: %s" % metrics)

        return preds
class LSTM(nn.Module):
    """LSTM Model

    Args:
        input_size (int): input size (# features)
        hidden_size (int): hidden size
        num_layers (int): number of hidden layers
        use_attn (bool): whether use attention layer.
            we use concat attention as https://github.com/fulifeng/Adv-ALSTM/
        dropout (float): dropout rate
        input_drop (float): input dropout for data augmentation
        noise_level (float): add gaussian noise to input for data augmentation

    Attributes:
        output_size (int): width of the returned representation; twice
            `hidden_size` with attention, `hidden_size` otherwise.
    """

    def __init__(
        self,
        input_size=16,
        hidden_size=64,
        num_layers=2,
        use_attn=True,
        dropout=0.0,
        input_drop=0.0,
        noise_level=0.0,
        *args,
        **kwargs,
    ):
        super().__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.use_attn = use_attn
        self.noise_level = noise_level

        self.input_drop = nn.Dropout(input_drop)

        self.rnn = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout,
        )

        if self.use_attn:
            self.W = nn.Linear(hidden_size, hidden_size)
            self.u = nn.Linear(hidden_size, 1, bias=False)
            self.output_size = hidden_size * 2
        else:
            self.output_size = hidden_size

    def forward(self, x):
        """Encode a batch-first sequence tensor into one vector per sample.

        Returns:
            torch.Tensor: `(batch, output_size)` tensor built from the last
            LSTM output, concatenated with the attention summary when
            `use_attn` is set.
        """
        x = self.input_drop(x)

        if self.training and self.noise_level > 0:
            noise = torch.randn_like(x).to(x)
            x = x + noise * self.noise_level

        rnn_out, _ = self.rnn(x)
        last_out = rnn_out[:, -1]

        if self.use_attn:
            laten = self.W(rnn_out).tanh()
            scores = self.u(laten).softmax(dim=1)
            # `sum(dim=1)` already yields (batch, hidden). A bare `.squeeze()`
            # here would also drop the batch axis for a single-sample batch
            # and break the concatenation below.
            att_out = (rnn_out * scores).sum(dim=1)
            last_out = torch.cat([last_out, att_out], dim=1)

        return last_out
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding added to a sequence-first input.

    Reference: https://pytorch.org/tutorials/beginner/transformer_tutorial.html

    Args:
        d_model (int): size of the last input dimension (even or odd)
        dropout (float): dropout rate applied after adding the encoding
        max_len (int): longest sequence length the precomputed table covers
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # With an odd `d_model` there is one cosine column fewer than sine
        # columns; trimming keeps the assignment shape-compatible. For even
        # sizes the slice is a no-op.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        # Stored as (max_len, 1, d_model) so it broadcasts over the batch axis.
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer("pe", pe)

    def forward(self, x):
        """Add the encoding of the first `x.size(0)` positions, then dropout."""
        x = x + self.pe[: x.size(0), :]
        return self.dropout(x)
class Transformer(nn.Module):
    """Transformer encoder that summarises a sequence by its last position.

    Args:
        input_size (int): input size (# features)
        hidden_size (int): hidden size
        num_layers (int): number of transformer layers
        num_heads (int): number of heads in transformer
        dropout (float): dropout rate
        input_drop (float): input dropout for data augmentation
        noise_level (float): add gaussian noise to input for data augmentation

    Attributes:
        output_size (int): width of the returned representation (`hidden_size`).
    """

    def __init__(
        self,
        input_size=16,
        hidden_size=64,
        num_layers=2,
        num_heads=2,
        dropout=0.0,
        input_drop=0.0,
        noise_level=0.0,
        **kwargs,
    ):
        super().__init__()

        # Plain hyper-parameters.
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.noise_level = noise_level
        self.output_size = hidden_size

        # Sub-modules are created in a fixed order so that seeded
        # initialisation and state-dict layout stay the same.
        self.input_drop = nn.Dropout(input_drop)
        self.input_proj = nn.Linear(input_size, hidden_size)
        # The positional encoding works on the raw feature width because it
        # is applied before the projection.
        self.pe = PositionalEncoding(input_size, dropout)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_size,
            nhead=num_heads,
            dim_feedforward=hidden_size * 4,
            dropout=dropout,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

    def forward(self, x):
        """Encode a batch-first input and return the last position's output."""
        x = self.input_drop(x)

        if self.training and self.noise_level > 0:
            x = x + torch.randn_like(x).to(x) * self.noise_level

        # the first dim need to be sequence
        seq_first = x.permute(1, 0, 2).contiguous()
        encoded = self.encoder(self.input_proj(self.pe(seq_first)))

        return encoded[-1]
class TRA(nn.Module):
    """Temporal Routing Adaptor (TRA)

    TRA takes historical prediction errors & latent representation as inputs,
    then routes the input sample to a specific predictor for training & inference.

    Args:
        input_size (int): input size (RNN/Transformer's hidden size)
        num_states (int): number of latent states (i.e., trading patterns)
            If `num_states=1`, then TRA falls back to traditional methods
        hidden_size (int): hidden size of the router
        tau (float): gumbel softmax temperature
        src_info (str): which router inputs are real: contains "LR" to use the
            latent representation and "TPE" to use the temporal prediction
            errors; a missing source is replaced by gaussian noise
    """

    def __init__(self, input_size, num_states=1, hidden_size=8, tau=1.0, src_info="LR_TPE"):
        super().__init__()

        self.num_states = num_states
        self.tau = tau
        self.src_info = src_info

        # The router only exists when there is more than one predictor.
        if num_states > 1:
            self.router = nn.LSTM(
                input_size=num_states,
                hidden_size=hidden_size,
                num_layers=1,
                batch_first=True,
            )
            self.fc = nn.Linear(hidden_size + input_size, num_states)

        self.predictors = nn.Linear(input_size, num_states)

    def forward(self, hidden, hist_loss):
        """Return `(final_pred, preds, prob)`; `prob` is None for one state."""
        preds = self.predictors(hidden)

        if self.num_states == 1:
            return preds.squeeze(-1), preds, None

        # information type
        router_out, _ = self.router(hist_loss)
        use_latent = "LR" in self.src_info
        use_errors = "TPE" in self.src_info
        latent_representation = hidden if use_latent else torch.randn(hidden.shape).to(hidden)
        temporal_pred_error = router_out[:, -1] if use_errors else torch.randn(router_out[:, -1].shape).to(hidden)

        logits = self.fc(torch.cat([temporal_pred_error, latent_representation], dim=-1))
        prob = F.gumbel_softmax(logits, dim=-1, tau=self.tau, hard=False)

        if not self.training:
            # inference: take the prediction of the most probable predictor
            chosen = prob.argmax(dim=-1)
            return preds.gather(-1, chosen.unsqueeze(-1)).squeeze(-1), preds, prob

        # training: probability-weighted mixture of all predictors
        return (preds * prob).sum(dim=-1), preds, prob
def evaluate(pred):
    """Score predictions after converting every column to percentile ranks.

    Args:
        pred (pd.DataFrame): frame with at least the columns `score` and `label`

    Returns:
        dict: "MSE" and "MAE" of the rank difference, and "IC", the
        correlation between the ranked score and the ranked label.
    """
    ranked = pred.rank(pct=True)  # transform into percentiles
    score, label = ranked["score"], ranked["label"]
    error = score - label
    return {
        "MSE": error.pow(2).mean(),
        "MAE": error.abs().mean(),
        "IC": score.corr(label),
    }
def average_params(params_list):
    """Average a sequence of parameter dicts entry by entry.

    Args:
        params_list (tuple, list or collections.deque): state dicts that all
            hold the same keys

    Returns:
        The single input dict itself when there is exactly one, otherwise a
        new `OrderedDict` with the element-wise mean of every entry.

    Raises:
        ValueError: if a dict's key set differs from the first dict's.
    """
    assert isinstance(params_list, (tuple, list, collections.deque))
    n = len(params_list)
    if n == 1:
        return params_list[0]
    new_params = collections.OrderedDict()
    keys = None
    for i, params in enumerate(params_list):
        if keys is None:
            keys = set(params.keys())
        elif set(params.keys()) != keys:
            # Catches missing keys as well as extra ones; a missing key would
            # otherwise leave that entry silently under-weighted.
            raise ValueError("the %d-th model has different params" % i)
        for k, v in params.items():
            if k not in new_params:
                # `v / n` allocates a new tensor, so the inputs stay untouched.
                new_params[k] = v / n
            else:
                new_params[k] += v / n
    return new_params
def shoot_infs(inp_tensor):
    """Replace every infinite entry by the maximum of the finite entries.

    Both `+inf` and `-inf` are replaced. The tensor is modified in place and
    also returned. If all entries are infinite they become 0. Works for
    tensors of any rank.
    """
    mask_inf = torch.isinf(inp_tensor)
    if mask_inf.any():
        # Zero the infinities first so they cannot win the max below.
        inp_tensor[mask_inf] = 0
        inp_tensor[mask_inf] = torch.max(inp_tensor)
    return inp_tensor
def sinkhorn(Q, n_iters=3, epsilon=0.01):
    """Turn a score matrix into an assignment matrix by alternating scaling.

    The scores are exponentiated with temperature `epsilon`, then columns and
    rows are normalised in turn `n_iters` times. No gradients are tracked.

    Args:
        Q (torch.Tensor): 2-D score matrix; infinite entries are replaced
            in place before exponentiation
        n_iters (int): number of column/row normalisation rounds
        epsilon (float): temperature; should be adjusted according to the
            scale of the scores
    """
    with torch.no_grad():
        Q = torch.exp(shoot_infs(Q) / epsilon)
        for _ in range(n_iters):
            Q /= Q.sum(dim=0, keepdim=True)  # columns sum to one
            Q /= Q.sum(dim=1, keepdim=True)  # rows sum to one
    return Q

View File

@@ -1,7 +1,5 @@
from qlib.data.dataset.handler import DataHandler, DataHandlerLP
from qlib.data.dataset.processor import Processor
from qlib.utils import get_cls_kwargs
from qlib.log import TimeInspector
from qlib.contrib.data.handler import check_transform_proc
class HighFreqHandler(DataHandlerLP):
@@ -16,20 +14,9 @@ class HighFreqHandler(DataHandlerLP):
fit_end_time=None,
drop_raw=True,
):
def check_transform_proc(proc_l):
new_l = []
for p in proc_l:
p["kwargs"].update(
{
"fit_start_time": fit_start_time,
"fit_end_time": fit_end_time,
}
)
new_l.append(p)
return new_l
infer_processors = check_transform_proc(infer_processors)
learn_processors = check_transform_proc(learn_processors)
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)
data_loader = {
"class": "QlibDataLoader",

View File

@@ -26,7 +26,7 @@ def get_calendar_day(freq="day", future=False):
if flag in H["c"]:
_calendar = H["c"][flag]
else:
_calendar = np.array(list(map(lambda x: x.date(), Cal.load_calendar(freq, future))))
_calendar = np.array(list(map(lambda x: pd.Timestamp(x.date()), Cal.load_calendar(freq, future))))
H["c"][flag] = _calendar
return _calendar

View File

@@ -33,7 +33,7 @@ class HighfreqWorkflow:
"fit_start_time": start_time,
"fit_end_time": train_end_time,
"instruments": MARKET,
"infer_processors": [{"class": "HighFreqNorm", "module_path": "highfreq_processor", "kwargs": {}}],
"infer_processors": [{"class": "HighFreqNorm", "module_path": "highfreq_processor"}],
}
DATA_HANDLER_CONFIG1 = {
"start_time": start_time,

View File

@@ -4,6 +4,7 @@
"""
This example shows how a TrainerRM works based on TaskManager with rolling tasks.
After training, how to collect the rolling results will be shown in task_collecting.
Building on TaskManager, the `worker` method offers a simple way to train tasks with multiple processes.
"""
from pprint import pprint
@@ -13,7 +14,7 @@ import qlib
from qlib.config import REG_CN
from qlib.workflow import R
from qlib.workflow.task.gen import RollingGen, task_generator
from qlib.workflow.task.manage import TaskManager
from qlib.workflow.task.manage import TaskManager, run_task
from qlib.workflow.task.collect import RecorderCollector
from qlib.model.ens.group import RollingGroup
from qlib.model.trainer import TrainerRM
@@ -68,6 +69,11 @@ class RollingTaskExample:
trainer = TrainerRM(self.experiment_name, self.task_pool)
trainer.train(tasks)
def worker(self):
    """Consume and train tasks from this example's task pool.

    Intended to be started from other processes or machines so that several
    workers train in parallel. Per the original author's note it is
    equivalent to `TrainerRM.worker`.
    """
    print("========== worker ==========")
    run_task(task_train, self.task_pool, experiment_name=self.experiment_name)
def task_collecting(self):
print("========== task_collecting ==========")

View File

@@ -5,6 +5,7 @@
This example shows how to simulate the OnlineManager based on rolling tasks.
"""
from pprint import pprint
import fire
import qlib
from qlib.model.trainer import DelayTrainerR, DelayTrainerRM, TrainerR, TrainerRM
@@ -13,7 +14,7 @@ from qlib.workflow.online.manager import OnlineManager
from qlib.workflow.online.strategy import RollingStrategy
from qlib.workflow.task.gen import RollingGen
from qlib.workflow.task.manage import TaskManager
from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG, CSI100_RECORD_XGBOOST_TASK_CONFIG
from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG_ONLINE, CSI100_RECORD_XGBOOST_TASK_CONFIG_ONLINE
class OnlineSimulationExample:
@@ -22,8 +23,8 @@ class OnlineSimulationExample:
provider_uri="~/.qlib/qlib_data/cn_data",
region="cn",
exp_name="rolling_exp",
task_url="mongodb://10.0.0.4:27017/",
task_db_name="rolling_db",
task_url="mongodb://10.0.0.4:27017/", # not necessary when using TrainerR or DelayTrainerR
task_db_name="rolling_db", # not necessary when using TrainerR or DelayTrainerR
task_pool="rolling_task",
rolling_step=80,
start_time="2018-09-10",
@@ -46,7 +47,7 @@ class OnlineSimulationExample:
tasks (dict or list[dict]): a set of the task config waiting for rolling and training
"""
if tasks is None:
tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG, CSI100_RECORD_LGB_TASK_CONFIG]
tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG_ONLINE, CSI100_RECORD_LGB_TASK_CONFIG_ONLINE]
self.exp_name = exp_name
self.task_pool = task_pool
self.start_time = start_time
@@ -59,7 +60,7 @@ class OnlineSimulationExample:
self.rolling_gen = RollingGen(
step=rolling_step, rtype=RollingGen.ROLL_SD, ds_extra_mod_func=None
) # The rolling tasks generator, ds_extra_mod_func is None because we just need to simulate to 2018-10-31 and needn't change the handler end time.
self.trainer = DelayTrainerRM(self.exp_name, self.task_pool) # Also can be TrainerR, TrainerRM, DelayTrainerR
self.trainer = TrainerRM(self.exp_name, self.task_pool) # Also can be TrainerR, TrainerRM, DelayTrainerR
self.rolling_online_manager = OnlineManager(
RollingStrategy(exp_name, task_template=tasks, rolling_gen=self.rolling_gen),
trainer=self.trainer,
@@ -85,6 +86,15 @@ class OnlineSimulationExample:
print("========== signals ==========")
print(self.rolling_online_manager.get_signals())
def worker(self):
# train tasks in other processes or on other machines for multiprocessing
# FIXME: when using DelayTrainerRM, this can only be called after the simulation has finished; otherwise exceptions will be raised.
print("========== worker ==========")
if isinstance(self.trainer, TrainerRM):
self.trainer.worker()
else:
print(f"{type(self.trainer)} is not supported for worker.")
if __name__ == "__main__":
## to run all workflow automatically with your own parameters, use the command below

View File

@@ -13,11 +13,13 @@ Finally, the OnlineManager will finish second routine and update all strategies.
import os
import fire
import qlib
from qlib.model.trainer import DelayTrainerR, DelayTrainerRM, TrainerR, TrainerRM, end_task_train, task_train
from qlib.workflow import R
from qlib.workflow.online.strategy import RollingStrategy
from qlib.workflow.task.gen import RollingGen
from qlib.workflow.online.manager import OnlineManager
from qlib.tests.config import CSI100_RECORD_XGBOOST_TASK_CONFIG, CSI100_RECORD_LGB_TASK_CONFIG
from qlib.tests.config import CSI100_RECORD_XGBOOST_TASK_CONFIG_ROLLING, CSI100_RECORD_LGB_TASK_CONFIG_ROLLING
from qlib.workflow.task.manage import TaskManager
class RollingOnlineExample:
@@ -25,16 +27,17 @@ class RollingOnlineExample:
self,
provider_uri="~/.qlib/qlib_data/cn_data",
region="cn",
task_url="mongodb://10.0.0.4:27017/",
task_db_name="rolling_db",
trainer=DelayTrainerRM(), # you can choose from TrainerR, TrainerRM, DelayTrainerR, DelayTrainerRM
task_url="mongodb://10.0.0.4:27017/", # not necessary when using TrainerR or DelayTrainerR
task_db_name="rolling_db", # not necessary when using TrainerR or DelayTrainerR
rolling_step=550,
tasks=None,
add_tasks=None,
):
if add_tasks is None:
add_tasks = [CSI100_RECORD_LGB_TASK_CONFIG]
add_tasks = [CSI100_RECORD_LGB_TASK_CONFIG_ROLLING]
if tasks is None:
tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG]
tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG_ROLLING]
mongo_conf = {
"task_url": task_url, # your MongoDB url
"task_db_name": task_db_name, # database name
@@ -53,17 +56,28 @@ class RollingOnlineExample:
RollingGen(step=rolling_step, rtype=RollingGen.ROLL_SD),
)
)
self.rolling_online_manager = OnlineManager(strategies)
self.trainer = trainer
self.rolling_online_manager = OnlineManager(strategies, trainer=self.trainer)
_ROLLING_MANAGER_PATH = (
".RollingOnlineExample" # the OnlineManager will dump to this file, for it can be loaded when calling routine.
)
def worker(self):
# train tasks in other processes or on other machines for multiprocessing
print("========== worker ==========")
if isinstance(self.trainer, TrainerRM):
for task in self.tasks + self.add_tasks:
name_id = task["model"]["class"]
self.trainer.worker(experiment_name=name_id)
else:
print(f"{type(self.trainer)} is not supported for worker.")
# Reset all things to the first status, be careful to save important data
def reset(self):
for task in self.tasks + self.add_tasks:
name_id = task["model"]["class"]
TaskManager(task_pool=name_id).remove()
exp = R.get_exp(experiment_name=name_id)
for rid in exp.list_recorders():
exp.delete_recorder(rid)

View File

@@ -220,7 +220,7 @@
"\n",
"# backtest and analysis\n",
"with R.start(experiment_name=\"backtest_analysis\"):\n",
" recorder = R.get_recorder(rid, experiment_name=\"train_model\")\n",
" recorder = R.get_recorder(recorder_id=rid, experiment_name=\"train_model\")\n",
" model = recorder.load_object(\"trained_model\")\n",
"\n",
" # prediction\n",
@@ -249,7 +249,7 @@
"source": [
"from qlib.contrib.report import analysis_model, analysis_position\n",
"from qlib.data import D\n",
"recorder = R.get_recorder(ba_rid, experiment_name=\"backtest_analysis\")\n",
"recorder = R.get_recorder(recorder_id=ba_rid, experiment_name=\"backtest_analysis\")\n",
"pred_df = recorder.load_object(\"pred.pkl\")\n",
"pred_df_dates = pred_df.index.get_level_values(level='datetime')\n",
"report_normal_df = recorder.load_object(\"portfolio_analysis/report_normal.pkl\")\n",

View File

@@ -2,7 +2,7 @@
# Licensed under the MIT License.
__version__ = "0.6.3.99"
__version__ = "0.7.0"
__version__bak = __version__ # This version is backup for QlibConfig.reset_qlib_version
@@ -20,11 +20,17 @@ def init(default_conf="client", **kwargs):
from .config import C
from .data.cache import H
H.clear()
# FIXME: this logger ignored the level in config
logger = get_module_logger("Initialization", level=logging.INFO)
skip_if_reg = kwargs.pop("skip_if_reg", False)
if skip_if_reg and C.registered:
# if we reinitialize Qlib during running an experiment `R.start`.
# it will result in loss of the recorder
logger.warning("Skip initialization because `skip_if_reg is True`")
return
H.clear()
C.set(default_conf, **kwargs)
# check path if server/local
@@ -197,14 +203,15 @@ def auto_init(**kwargs):
- Find the project configuration and init qlib
- The parsing process will be affected by the `conf_type` of the configuration file
- Init qlib with default config
- Skip initialization if already initialized
"""
kwargs["skip_if_reg"] = kwargs.get("skip_if_reg", True)
try:
pp = get_project_path(cur_path=kwargs.pop("cur_path", None))
except FileNotFoundError:
init(**kwargs)
else:
conf_pp = pp / "config.yaml"
with conf_pp.open() as f:
conf = yaml.safe_load(f)

View File

@@ -195,7 +195,10 @@ MODE_CONF = {
"timeout": 100,
"logging_level": logging.INFO,
"region": REG_CN,
## Custom Operator
# custom operator
# each element of custom_ops should be Type[ExpressionOps] or dict
# if element of custom_ops is Type[ExpressionOps], it represents the custom operator class
# if element of custom_ops is dict, it represents the config of custom operator and should include `class` and `module_path` keys.
"custom_ops": [],
},
}

View File

@@ -26,8 +26,10 @@ def check_transform_proc(proc_l, fit_start_time, fit_end_time):
"fit_end_time": fit_end_time,
}
)
# FIXME: the `module_path` parameter is missed.
new_l.append({"class": klass.__name__, "kwargs": pkwargs})
proc_config = {"class": klass.__name__, "kwargs": pkwargs}
if isinstance(p, dict) and "module_path" in p:
proc_config["module_path"] = p["module_path"]
new_l.append(proc_config)
else:
new_l.append(p)
return new_l

View File

@@ -53,7 +53,6 @@ class GATs(Model):
early_stop=20,
loss="mse",
base_model="GRU",
with_pretrain=True,
model_path=None,
optimizer="adam",
GPU=0,
@@ -76,7 +75,6 @@ class GATs(Model):
self.optimizer = optimizer.lower()
self.loss = loss
self.base_model = base_model
self.with_pretrain = with_pretrain
self.model_path = model_path
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
self.seed = seed
@@ -94,7 +92,6 @@ class GATs(Model):
"\noptimizer : {}"
"\nloss_type : {}"
"\nbase_model : {}"
"\nwith_pretrain : {}"
"\nmodel_path : {}"
"\ndevice : {}"
"\nuse_GPU : {}"
@@ -110,7 +107,6 @@ class GATs(Model):
optimizer.lower(),
loss,
base_model,
with_pretrain,
model_path,
self.device,
self.use_gpu,
@@ -253,24 +249,22 @@ class GATs(Model):
evals_result["valid"] = []
# load pretrained base_model
if self.with_pretrain:
if self.model_path == None:
raise ValueError("the path of the pretrained model should be given first!")
self.logger.info("Loading pretrained model...")
if self.base_model == "LSTM":
pretrained_model = LSTMModel()
pretrained_model.load_state_dict(torch.load(self.model_path))
elif self.base_model == "GRU":
pretrained_model = GRUModel()
pretrained_model.load_state_dict(torch.load(self.model_path))
else:
raise ValueError("unknown base model name `%s`" % self.base_model)
if self.base_model == "LSTM":
pretrained_model = LSTMModel()
elif self.base_model == "GRU":
pretrained_model = GRUModel()
else:
raise ValueError("unknown base model name `%s`" % self.base_model)
model_dict = self.GAT_model.state_dict()
pretrained_dict = {k: v for k, v in pretrained_model.state_dict().items() if k in model_dict}
model_dict.update(pretrained_dict)
self.GAT_model.load_state_dict(model_dict)
self.logger.info("Loading pretrained model Done...")
if self.model_path is not None:
self.logger.info("Loading pretrained model...")
pretrained_model.load_state_dict(torch.load(self.model_path))
model_dict = self.GAT_model.state_dict()
pretrained_dict = {k: v for k, v in pretrained_model.state_dict().items() if k in model_dict}
model_dict.update(pretrained_dict)
self.GAT_model.load_state_dict(model_dict)
self.logger.info("Loading pretrained model Done...")
# train
self.logger.info("training...")

View File

@@ -29,8 +29,8 @@ class DailyBatchSampler(Sampler):
def __init__(self, data_source):
self.data_source = data_source
self.data = self.data_source.data.loc[self.data_source.get_index()]
self.daily_count = self.data.groupby(level=0).size().values # calculate number of samples in each batch
# calculate number of samples in each batch
self.daily_count = pd.Series(index=self.data_source.get_index()).groupby("datetime").size().values
self.daily_index = np.roll(np.cumsum(self.daily_count), 1) # calculate begin index of each batch
self.daily_index[0] = 0
@@ -72,7 +72,6 @@ class GATs(Model):
early_stop=20,
loss="mse",
base_model="GRU",
with_pretrain=True,
model_path=None,
optimizer="adam",
GPU="0",
@@ -96,7 +95,6 @@ class GATs(Model):
self.optimizer = optimizer.lower()
self.loss = loss
self.base_model = base_model
self.with_pretrain = with_pretrain
self.model_path = model_path
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
self.n_jobs = n_jobs
@@ -115,7 +113,6 @@ class GATs(Model):
"\noptimizer : {}"
"\nloss_type : {}"
"\nbase_model : {}"
"\nwith_pretrain : {}"
"\nmodel_path : {}"
"\nvisible_GPU : {}"
"\nuse_GPU : {}"
@@ -131,7 +128,6 @@ class GATs(Model):
optimizer.lower(),
loss,
base_model,
with_pretrain,
model_path,
GPU,
self.use_gpu,
@@ -270,28 +266,22 @@ class GATs(Model):
evals_result["valid"] = []
# load pretrained base_model
if self.with_pretrain:
if self.model_path == None:
raise ValueError("the path of the pretrained model should be given first!")
self.logger.info("Loading pretrained model...")
if self.base_model == "LSTM":
pretrained_model = LSTMModel(
d_feat=self.d_feat, hidden_size=self.hidden_size, num_layers=self.num_layers
)
pretrained_model.load_state_dict(torch.load(self.model_path))
elif self.base_model == "GRU":
pretrained_model = GRUModel(
d_feat=self.d_feat, hidden_size=self.hidden_size, num_layers=self.num_layers
)
pretrained_model.load_state_dict(torch.load(self.model_path))
else:
raise ValueError("unknown base model name `%s`" % self.base_model)
if self.base_model == "LSTM":
pretrained_model = LSTMModel(d_feat=self.d_feat, hidden_size=self.hidden_size, num_layers=self.num_layers)
elif self.base_model == "GRU":
pretrained_model = GRUModel(d_feat=self.d_feat, hidden_size=self.hidden_size, num_layers=self.num_layers)
else:
raise ValueError("unknown base model name `%s`" % self.base_model)
model_dict = self.GAT_model.state_dict()
pretrained_dict = {k: v for k, v in pretrained_model.state_dict().items() if k in model_dict}
model_dict.update(pretrained_dict)
self.GAT_model.load_state_dict(model_dict)
self.logger.info("Loading pretrained model Done...")
if self.model_path is not None:
self.logger.info("Loading pretrained model...")
pretrained_model.load_state_dict(torch.load(self.model_path))
model_dict = self.GAT_model.state_dict()
pretrained_dict = {k: v for k, v in pretrained_model.state_dict().items() if k in model_dict}
model_dict.update(pretrained_dict)
self.GAT_model.load_state_dict(model_dict)
self.logger.info("Loading pretrained model Done...")
# train
self.logger.info("training...")

View File

@@ -297,7 +297,7 @@ class DNNModelPytorch(Model):
_model_path = os.path.join(model_dir, _model_name)
# Load model
self.dnn_model.load_state_dict(torch.load(_model_path))
self._fitted = True
self.fitted = True
class AverageMeter:

View File

@@ -0,0 +1,420 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import pandas as pd
import copy
import random
from sklearn.metrics import roc_auc_score, mean_squared_error
import logging
from ...utils import (
unpack_archive_with_buffer,
save_multiple_parts_file,
get_or_create_path,
drop_nan_by_y_index,
)
from ...log import get_module_logger, TimeInspector
import torch
import torch.nn as nn
import torch.optim as optim
from ...model.base import Model
from ...data.dataset import DatasetH
from ...data.dataset.handler import DataHandlerLP
class TCTS(Model):
    """TCTS Model

    A GRU forecasting model (``fore_model``) is trained together with an MLP
    weight model (``weight_model``). For every sample the weight model picks
    (via ``argmax``) the label column the forecasting model is trained on.

    Parameters
    ----------
    d_feat : int
        input dimension for each time step
    hidden_size : int
        hidden size used by both the forecasting model and the weight model
    num_layers : int
        number of layers used by both the forecasting model and the weight model
    dropout : float
        dropout rate used by both the forecasting model and the weight model
    n_epochs : int
        maximum number of epochs of one training run
    batch_size : int
        number of samples per batch; an incomplete trailing batch is skipped
        during training and evaluation (but not during prediction)
    early_stop : int
        number of consecutive epochs without a better validation loss after
        which a training run stops
    loss : str
        loss name; it is only stored and logged, `loss_fn` always uses a
        squared error
    fore_optimizer : str
        optimizer of the forecasting model, "adam" or "gd" (case-insensitive)
    weight_optimizer : str
        optimizer of the weight model, "adam" or "gd" (case-insensitive)
    output_dim : int
        output dimension of the weight model
        (presumably the number of label columns -- TODO confirm against the dataset config)
    fore_lr : float
        learning rate of the forecasting model
    weight_lr : float
        learning rate of the weight model
    steps : int
        number of passes over the training data made by the forecasting model
        in every epoch
    GPU : int
        index of the CUDA device used when CUDA is available
    seed : int
        random seed for numpy and torch
    target_label : int
        index of the label column used for evaluation (its absolute value is used)
    lowest_valid_performance : float
        `fit` keeps retraining with a new random seed while the best
        validation loss is greater than this value
    """

    def __init__(
        self,
        d_feat=6,
        hidden_size=64,
        num_layers=2,
        dropout=0.0,
        n_epochs=200,
        batch_size=2000,
        early_stop=20,
        loss="mse",
        fore_optimizer="adam",
        weight_optimizer="adam",
        output_dim=5,
        fore_lr=5e-7,
        weight_lr=5e-7,
        steps=3,
        GPU=0,
        seed=0,
        target_label=0,
        lowest_valid_performance=0.993,
        **kwargs
    ):
        # Set logger.
        self.logger = get_module_logger("TCTS")
        self.logger.info("TCTS pytorch version...")

        # set hyper-parameters.
        self.d_feat = d_feat
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.dropout = dropout
        self.n_epochs = n_epochs
        self.batch_size = batch_size
        self.early_stop = early_stop
        self.loss = loss
        self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() else "cpu")
        self.use_gpu = torch.cuda.is_available()
        self.seed = seed
        self.output_dim = output_dim
        self.fore_lr = fore_lr
        self.weight_lr = weight_lr
        self.steps = steps
        self.target_label = target_label
        self.lowest_valid_performance = lowest_valid_performance
        self._fore_optimizer = fore_optimizer
        self._weight_optimizer = weight_optimizer

        # Make sure `predict` can report "not fitted" instead of failing with
        # an AttributeError when it is called before `fit`.
        self.fitted = False

        self.logger.info(
            "TCTS parameters setting:"
            "\nd_feat : {}"
            "\nhidden_size : {}"
            "\nnum_layers : {}"
            "\ndropout : {}"
            "\nn_epochs : {}"
            "\nbatch_size : {}"
            "\nearly_stop : {}"
            "\nloss_type : {}"
            "\nvisible_GPU : {}"
            "\nuse_GPU : {}"
            "\nseed : {}".format(
                d_feat,
                hidden_size,
                num_layers,
                dropout,
                n_epochs,
                batch_size,
                early_stop,
                loss,
                GPU,
                self.use_gpu,
                seed,
            )
        )

    def loss_fn(self, pred, label, weight):
        """Mean squared error against the label column selected by `weight`.

        For every row, the column with the largest weight is selected from
        `label` and compared with `pred`.
        """
        loc = torch.argmax(weight, 1)
        loss = (pred - label[np.arange(weight.shape[0]), loc]) ** 2
        return torch.mean(loss)

    def train_epoch(self, x_train, y_train, x_valid, y_valid):
        """Run one epoch: update the forecasting model, then the weight model.

        Phase 1 trains `fore_model` on the training data for `self.steps`
        passes while `weight_model` is frozen. Phase 2 trains `weight_model`
        on the validation data while `fore_model` is frozen.
        """
        x_train_values = x_train.values
        y_train_values = np.squeeze(y_train.values)

        indices = np.arange(len(x_train_values))
        np.random.shuffle(indices)

        # Frozen snapshot of the forecasting model taken before this epoch's
        # updates; its predictions are only used as inputs of the weight model.
        init_fore_model = copy.deepcopy(self.fore_model)
        for p in init_fore_model.parameters():
            p.requires_grad = False

        self.fore_model.train()
        self.weight_model.train()

        # phase 1: fix weight model and train forecasting model
        for p in self.weight_model.parameters():
            p.requires_grad = False
        for p in self.fore_model.parameters():
            p.requires_grad = True

        for _ in range(self.steps):
            for start in range(0, len(indices), self.batch_size):
                # skip the incomplete trailing batch
                if len(indices) - start < self.batch_size:
                    break
                batch = indices[start : start + self.batch_size]

                feature = torch.from_numpy(x_train_values[batch]).float().to(self.device)
                label = torch.from_numpy(y_train_values[batch]).float().to(self.device)

                init_pred = init_fore_model(feature)
                pred = self.fore_model(feature)

                # error of the snapshot model on every label column
                dis = init_pred - label.transpose(0, 1)
                weight_feature = torch.cat((feature, dis.transpose(0, 1), label, init_pred.view(-1, 1)), 1)
                weight = self.weight_model(weight_feature)

                loss = self.loss_fn(pred, label, weight)  # hard

                self.fore_optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_value_(self.fore_model.parameters(), 3.0)
                self.fore_optimizer.step()

        x_valid_values = x_valid.values
        y_valid_values = np.squeeze(y_valid.values)

        indices = np.arange(len(x_valid_values))
        np.random.shuffle(indices)

        # phase 2: fix forecasting model and valid weight model
        for p in self.weight_model.parameters():
            p.requires_grad = True
        for p in self.fore_model.parameters():
            p.requires_grad = False

        for start in range(0, len(indices), self.batch_size):
            # skip the incomplete trailing batch
            if len(indices) - start < self.batch_size:
                break
            batch = indices[start : start + self.batch_size]

            feature = torch.from_numpy(x_valid_values[batch]).float().to(self.device)
            label = torch.from_numpy(y_valid_values[batch]).float().to(self.device)

            pred = self.fore_model(feature)
            dis = pred - label.transpose(0, 1)
            weight_feature = torch.cat((feature, dis.transpose(0, 1), label, pred.view(-1, 1)), 1)
            weight = self.weight_model(weight_feature)
            loc = torch.argmax(weight, 1)
            valid_loss = torch.mean((pred - label[:, 0]) ** 2)
            loss = torch.mean(-valid_loss * torch.log(weight[np.arange(weight.shape[0]), loc]))

            self.weight_optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_value_(self.weight_model.parameters(), 3.0)
            self.weight_optimizer.step()

    def test_epoch(self, data_x, data_y):
        """Return the mean squared error of the forecasting model over full batches.

        The error is measured against label column `abs(self.target_label)`.
        Incomplete trailing batches are skipped, so the result is NaN when
        the data holds fewer than `batch_size` samples.
        """
        # prepare evaluation data
        x_values = data_x.values
        y_values = np.squeeze(data_y.values)

        self.fore_model.eval()

        losses = []

        indices = np.arange(len(x_values))

        # evaluation needs no gradients
        with torch.no_grad():
            for start in range(0, len(indices), self.batch_size):
                if len(indices) - start < self.batch_size:
                    break
                batch = indices[start : start + self.batch_size]

                feature = torch.from_numpy(x_values[batch]).float().to(self.device)
                label = torch.from_numpy(y_values[batch]).float().to(self.device)

                pred = self.fore_model(feature)
                loss = torch.mean((pred - label[:, abs(self.target_label)]) ** 2)
                losses.append(loss.item())

        return np.mean(losses)

    def fit(
        self,
        dataset: DatasetH,
        verbose=True,
        save_path=None,
    ):
        """Train the model, retraining until the validation loss is low enough.

        A training run is repeated with a fresh random seed while its best
        validation loss is greater than `self.lowest_valid_performance`.

        Parameters
        ----------
        dataset : DatasetH
            dataset providing the "train", "valid" and "test" segments
        verbose : bool
            whether to print the validation and test loss of every epoch
        save_path : str
            prefix of the files the best model states are saved to; a path is
            obtained from `get_or_create_path` when it is None
        """
        df_train, df_valid, df_test = dataset.prepare(
            ["train", "valid", "test"],
            col_set=["feature", "label"],
            data_key=DataHandlerLP.DK_L,
        )

        x_train, y_train = df_train["feature"], df_train["label"]
        x_valid, y_valid = df_valid["feature"], df_valid["label"]
        x_test, y_test = df_test["feature"], df_test["label"]

        if save_path is None:
            save_path = get_or_create_path(save_path)

        best_loss = np.inf
        while best_loss > self.lowest_valid_performance:
            if best_loss < np.inf:
                print("Failed! Start retraining.")
                self.seed = random.randint(0, 1000)  # reset random seed

            if self.seed is not None:
                np.random.seed(self.seed)
                torch.manual_seed(self.seed)

            best_loss = self.training(
                x_train, y_train, x_valid, y_valid, x_test, y_test, verbose=verbose, save_path=save_path
            )

    def training(
        self,
        x_train,
        y_train,
        x_valid,
        y_valid,
        x_test,
        y_test,
        verbose=True,
        save_path=None,
    ):
        """Run one complete training run with freshly initialised models.

        The states with the best validation loss are saved to
        `save_path + "_fore_model.bin"` and `save_path + "_weight_model.bin"`
        and loaded back when the run ends.

        Returns
        -------
        float
            the best validation loss of this run
        """
        self.fore_model = GRUModel(
            d_feat=self.d_feat,
            hidden_size=self.hidden_size,
            num_layers=self.num_layers,
            dropout=self.dropout,
        )
        # The weight model consumes the concatenation built in `train_epoch`:
        # features, one error per label column, the label columns themselves
        # and the prediction. The width was hard-coded to 360 features before.
        self.weight_model = MLPModel(
            d_feat=x_train.shape[1] + 2 * self.output_dim + 1,
            hidden_size=self.hidden_size,
            num_layers=self.num_layers,
            dropout=self.dropout,
            output_dim=self.output_dim,
        )

        if self._fore_optimizer.lower() == "adam":
            self.fore_optimizer = optim.Adam(self.fore_model.parameters(), lr=self.fore_lr)
        elif self._fore_optimizer.lower() == "gd":
            self.fore_optimizer = optim.SGD(self.fore_model.parameters(), lr=self.fore_lr)
        else:
            raise NotImplementedError("optimizer {} is not supported!".format(self._fore_optimizer))

        if self._weight_optimizer.lower() == "adam":
            self.weight_optimizer = optim.Adam(self.weight_model.parameters(), lr=self.weight_lr)
        elif self._weight_optimizer.lower() == "gd":
            self.weight_optimizer = optim.SGD(self.weight_model.parameters(), lr=self.weight_lr)
        else:
            raise NotImplementedError("optimizer {} is not supported!".format(self._weight_optimizer))

        self.fitted = False
        self.fore_model.to(self.device)
        self.weight_model.to(self.device)

        best_loss = np.inf
        best_epoch = 0
        stop_round = 0

        for epoch in range(self.n_epochs):
            print("Epoch:", epoch)

            print("training...")
            self.train_epoch(x_train, y_train, x_valid, y_valid)

            print("evaluating...")
            val_loss = self.test_epoch(x_valid, y_valid)
            test_loss = self.test_epoch(x_test, y_test)

            if verbose:
                print("valid %.6f, test %.6f" % (val_loss, test_loss))

            if val_loss < best_loss:
                best_loss = val_loss
                stop_round = 0
                best_epoch = epoch
                torch.save(copy.deepcopy(self.fore_model.state_dict()), save_path + "_fore_model.bin")
                torch.save(copy.deepcopy(self.weight_model.state_dict()), save_path + "_weight_model.bin")
            else:
                stop_round += 1
                if stop_round >= self.early_stop:
                    print("early stop")
                    break

        print("best loss:", best_loss, "@", best_epoch)
        # restore the best states found during this run
        best_param = torch.load(save_path + "_fore_model.bin")
        self.fore_model.load_state_dict(best_param)
        best_param = torch.load(save_path + "_weight_model.bin")
        self.weight_model.load_state_dict(best_param)
        self.fitted = True

        if self.use_gpu:
            torch.cuda.empty_cache()

        return best_loss

    def predict(self, dataset):
        """Predict the "test" segment of `dataset` with the forecasting model.

        Returns
        -------
        pd.Series
            predictions indexed like the prepared test features

        Raises
        ------
        ValueError
            if the model has not been fitted
        """
        if not self.fitted:
            raise ValueError("model is not fitted yet!")

        x_test = dataset.prepare("test", col_set="feature")
        index = x_test.index
        self.fore_model.eval()
        x_values = x_test.values
        sample_num = x_values.shape[0]
        preds = []

        for begin in range(0, sample_num, self.batch_size):
            end = min(begin + self.batch_size, sample_num)

            x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)

            with torch.no_grad():
                # `.cpu()` is a no-op for CPU tensors; `reshape(-1)` keeps a
                # trailing batch of one sample 1-D so it can be concatenated.
                pred = self.fore_model(x_batch).detach().cpu().reshape(-1).numpy()

            preds.append(pred)

        return pd.Series(np.concatenate(preds), index=index)
class MLPModel(nn.Module):
    """MLP that maps every input row to a softmax distribution over `output_dim` classes.

    Parameters
    ----------
    d_feat : int
        number of input features
    hidden_size : int
        width of every hidden layer
    num_layers : int
        number of hidden (Linear + ReLU) layers
    dropout : float
        dropout rate applied before every hidden layer except the first
    output_dim : int
        number of outputs
    """

    def __init__(self, d_feat, hidden_size=256, num_layers=3, dropout=0.0, output_dim=1):
        super().__init__()

        self.mlp = nn.Sequential()
        self.softmax = nn.Softmax(dim=1)

        for i in range(num_layers):
            if i > 0:
                self.mlp.add_module("drop_%d" % i, nn.Dropout(dropout))
            self.mlp.add_module("fc_%d" % i, nn.Linear(d_feat if i == 0 else hidden_size, hidden_size))
            self.mlp.add_module("relu_%d" % i, nn.ReLU())

        self.mlp.add_module("fc_out", nn.Linear(hidden_size, output_dim))

    def forward(self, x):
        # feature
        # [N, F]
        out = self.mlp(x)
        # Always hand a [N, output_dim] tensor to the softmax. A bare
        # `squeeze()` dropped the batch axis for N == 1 (and the output axis
        # for output_dim == 1), which made `Softmax(dim=1)` fail.
        out = out.reshape(-1, out.shape[-1])
        out = self.softmax(out)
        return out
class GRUModel(nn.Module):
    """GRU followed by a linear layer producing one value per sample.

    Parameters
    ----------
    d_feat : int
        input dimension for each time step
    hidden_size : int
        hidden size of the GRU
    num_layers : int
        number of stacked GRU layers
    dropout : float
        dropout rate between the GRU layers
    """

    def __init__(self, d_feat=6, hidden_size=64, num_layers=2, dropout=0.0):
        super().__init__()

        self.rnn = nn.GRU(
            input_size=d_feat,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout,
        )
        self.fc_out = nn.Linear(hidden_size, 1)

        self.d_feat = d_feat

    def forward(self, x):
        # x: [N, F*T]
        x = x.reshape(len(x), self.d_feat, -1)  # [N, F, T]
        x = x.permute(0, 2, 1)  # [N, T, F]
        out, _ = self.rnn(x)
        # Only drop the trailing size-1 output axis: a bare `squeeze()` would
        # also drop the batch axis for N == 1 and return a 0-d tensor.
        return self.fc_out(out[:, -1, :]).squeeze(-1)

View File

@@ -62,7 +62,7 @@ class XGBModel(Model, FeatureInt):
if self.model is None:
raise ValueError("model is not fitted yet!")
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
return pd.Series(self.model.predict(xgb.DMatrix(x_test.values)), index=x_test.index)
return pd.Series(self.model.predict(xgb.DMatrix(x_test)), index=x_test.index)
def get_feature_importance(self, *args, **kwargs) -> pd.Series:
"""get feature importance

View File

@@ -3,7 +3,6 @@
import pandas as pd
import plotly.tools as tls
import plotly.graph_objs as go
import statsmodels.api as sm
@@ -80,9 +79,35 @@ def _plot_qq(data: pd.Series = None, dist=stats.norm) -> go.Figure:
:param dist:
:return:
"""
fig, ax = plt.subplots(figsize=(8, 5))
_mpl_fig = sm.qqplot(data.dropna(), dist, fit=True, line="45", ax=ax)
return tls.mpl_to_plotly(_mpl_fig)
# NOTE: plotly.tools.mpl_to_plotly not actively maintained, resulting in errors in the new version of matplotlib,
# ref: https://github.com/plotly/plotly.py/issues/2913#issuecomment-730071567
# removing plotly.tools.mpl_to_plotly for greater compatibility with matplotlib versions
_plt_fig = sm.qqplot(data.dropna(), dist=dist, fit=True, line="45")
plt.close(_plt_fig)
qqplot_data = _plt_fig.gca().lines
fig = go.Figure()
fig.add_trace(
{
"type": "scatter",
"x": qqplot_data[0].get_xdata(),
"y": qqplot_data[0].get_ydata(),
"mode": "markers",
"marker": {"color": "#19d3f3"},
}
)
fig.add_trace(
{
"type": "scatter",
"x": qqplot_data[1].get_xdata(),
"y": qqplot_data[1].get_ydata(),
"mode": "lines",
"line": {"color": "#636efa"},
}
)
del qqplot_data
return fig
def _pred_ic(pred_label: pd.DataFrame = None, rank: bool = False, **kwargs) -> tuple:

View File

@@ -148,7 +148,6 @@ class WeightStrategyBase(BaseStrategy, AdjustTimer):
pred score for this trade date, index is stock_id, contain 'score' column.
current : Position()
current position.
trade_exchange : Exchange()
trade_date : pd.Timestamp
trade date.
"""

View File

@@ -237,7 +237,7 @@ class CacheUtils:
lock.acquire()
except redis_lock.AlreadyAcquired:
raise QlibCacheException(
f"""It sees the key(lock:{repr(lock_name)[1:-1]}-wlock) of the redis lock has existed in your redis db now.
f"""It sees the key(lock:{repr(lock_name)[1:-1]}-wlock) of the redis lock has existed in your redis db now.
You can use the following command to clear your redis keys and rerun your commands:
$ redis-cli
> select {C.redis_task_db}
@@ -784,10 +784,10 @@ class DiskDatasetCache(DatasetCache):
def build_index_from_data(data, start_index=0):
if data.empty:
return pd.DataFrame()
line_data = data.iloc[:, 0].fillna(0).groupby("datetime").count()
line_data = data.groupby("datetime").size()
line_data.sort_index(inplace=True)
index_end = line_data.cumsum()
index_start = index_end.shift(1).fillna(0)
index_start = index_end.shift(1, fill_value=0)
index_data = pd.DataFrame()
index_data["start"] = index_start

View File

@@ -1,6 +1,6 @@
from ...utils.serial import Serializable
from typing import Union, List, Tuple, Dict, Text, Optional
from ...utils import init_instance_by_config, np_ffill
from ...utils import init_instance_by_config, np_ffill, time_to_slc_point
from ...log import get_module_logger
from .handler import DataHandler, DataHandlerLP
from copy import deepcopy
@@ -243,6 +243,8 @@ class TSDataSampler:
It works like `torch.data.utils.Dataset`, it provides a very convenient interface for constructing time-series
dataset based on tabular data.
- On time step dimension, the smaller index indicates the historical data and the larger index indicates the future
data.
If user have further requirements for processing data, user could process them based on `TSDataSampler` or create
more powerful subclasses.
@@ -309,11 +311,19 @@ class TSDataSampler:
self.data_index = deepcopy(self.data.index)
if flt_data is not None:
self.flt_data = np.array(flt_data.reindex(self.data_index)).reshape(-1)
if isinstance(flt_data, pd.DataFrame):
assert len(flt_data.columns) == 1
flt_data = flt_data.iloc[:, 0]
# NOTE: bool(np.nan) is True !!!!!!!!
# make sure reindex comes first. Otherwise extra NaN may appear.
flt_data = flt_data.reindex(self.data_index).fillna(False).astype(np.bool)
self.flt_data = flt_data.values
self.idx_map = self.flt_idx_map(self.flt_data, self.idx_map)
self.data_index = self.data_index[np.where(self.flt_data == True)[0]]
self.start_idx, self.end_idx = self.data_index.slice_locs(start=pd.Timestamp(start), end=pd.Timestamp(end))
self.start_idx, self.end_idx = self.data_index.slice_locs(
start=time_to_slc_point(start), end=time_to_slc_point(end)
)
self.idx_arr = np.array(self.idx_df.values, dtype=np.float64) # for better performance
del self.data # save memory
@@ -341,7 +351,7 @@ class TSDataSampler:
setattr(self, k, v)
@staticmethod
def build_index(data: pd.DataFrame) -> dict:
def build_index(data: pd.DataFrame) -> Tuple[pd.DataFrame, dict]:
"""
The relation of the data
@@ -352,9 +362,15 @@ class TSDataSampler:
Returns
-------
dict:
{<index>: <prev_index or None>}
# get the previous index of a line given index
Tuple[pd.DataFrame, dict]:
1) the first element: reshape the original index into a <datetime(row), instrument(column)> 2D dataframe
instrument SH600000 SH600004 SH600006 SH600007 SH600008 SH600009 ...
datetime
2021-01-11 0 1 2 3 4 5 ...
2021-01-12 4146 4147 4148 4149 4150 4151 ...
2021-01-13 8293 8294 8295 8296 8297 8298 ...
2021-01-14 12441 12442 12443 12444 12445 12446 ...
2) the second element: {<original index>: <row, col>}
"""
# use dtype object in case pandas converts int to float
idx_df = pd.Series(range(data.shape[0]), index=data.index, dtype=object)
@@ -491,7 +507,9 @@ class TSDatasetH(DatasetH):
- The dimension of a batch of data <batch_idx, feature, timestep>
"""
def __init__(self, step_len=30, **kwargs):
DEFAULT_STEP_LEN = 30
def __init__(self, step_len=DEFAULT_STEP_LEN, **kwargs):
self.step_len = step_len
super().__init__(**kwargs)

View File

@@ -12,7 +12,7 @@ from typing import Tuple, Union
from qlib.data import D
from qlib.data import filter as filter_module
from qlib.data.filter import BaseDFilter
from qlib.utils import load_dataset, init_instance_by_config
from qlib.utils import load_dataset, init_instance_by_config, time_to_slc_point
from qlib.log import get_module_logger
@@ -207,7 +207,10 @@ class StaticDataLoader(DataLoader):
df = self._data.loc(axis=0)[:, instruments]
if start_time is None and end_time is None:
return df # NOTE: avoid copy by loc
return df.loc[pd.Timestamp(start_time) : pd.Timestamp(end_time)]
# pd.Timestamp(None) == NaT, use NaT as index can not fetch correct thing, so do not change None.
start_time = time_to_slc_point(start_time)
end_time = time_to_slc_point(end_time)
return df.loc[start_time:end_time]
def _maybe_load_raw_data(self):
if self._data is not None:

View File

@@ -10,10 +10,12 @@ import abc
import numpy as np
import pandas as pd
from typing import Union, List, Type
from scipy.stats import percentileofscore
from .base import Expression, ExpressionOps
from ..log import get_module_logger
from ..utils import get_cls_kwargs
try:
from ._libs.rolling import rolling_slope, rolling_rsquare, rolling_resi
@@ -1495,16 +1497,34 @@ class OpsWrapper:
def reset(self):
self._ops = {}
def register(self, ops_list):
for operator in ops_list:
if not issubclass(operator, ExpressionOps):
raise TypeError("operator must be subclass of ExpressionOps, not {}".format(operator))
def register(self, ops_list: List[Union[Type[ExpressionOps], dict]]):
"""register operator
if operator.__name__ in self._ops:
Parameters
----------
ops_list : List[Union[Type[ExpressionOps], dict]]
- if type(ops_list) is List[Type[ExpressionOps]], each element of ops_list represents the operator class, which should be the subclass of `ExpressionOps`.
- if type(ops_list) is List[dict], each element of ops_list represents the config of operator, which has the following format:
{
"class": class_name,
"module_path": path,
}
Note: `class` should be the class name of operator, `module_path` should be a python module or path of file.
"""
for _operator in ops_list:
if isinstance(_operator, dict):
_ops_class, _ = get_cls_kwargs(_operator)
else:
_ops_class = _operator
if not issubclass(_ops_class, ExpressionOps):
raise TypeError("operator must be subclass of ExpressionOps, not {}".format(_ops_class))
if _ops_class.__name__ in self._ops:
get_module_logger(self.__class__.__name__).warning(
"The custom operator [{}] will override the qlib default definition".format(operator.__name__)
"The custom operator [{}] will override the qlib default definition".format(_ops_class.__name__)
)
self._ops[operator.__name__] = operator
self._ops[_ops_class.__name__] = _ops_class
def __getattr__(self, key):
if key not in self._ops:

View File

@@ -28,16 +28,18 @@ class QlibLogger(metaclass=MetaLogger):
def __init__(self, module_name):
self.module_name = module_name
self.level = 0
# this feature name conflicts with the attribute with Logger
# rename it to avoid some corner cases that result in comparing `str` and `int`
self.__level = 0
@property
def logger(self):
logger = logging.getLogger(self.module_name)
logger.setLevel(self.level)
logger.setLevel(self.__level)
return logger
def setLevel(self, level):
self.level = level
self.__level = level
def __getattr__(self, name):
# During unpickling, python will call __getattr__. Use this line to avoid maximum recursion error.
@@ -68,7 +70,7 @@ def get_module_logger(module_name, level: Optional[int] = None) -> logging.Logge
class TimeInspector:
timer_logger = get_module_logger("timer", level=logging.WARNING)
timer_logger = get_module_logger("timer", level=logging.INFO)
time_marks = []

View File

@@ -97,7 +97,7 @@ class ModelFT(Model):
# Finetune model based on previous trained model
with R.start(experiment_name="finetune model"):
recorder = R.get_recorder(rid, experiment_name="init models")
recorder = R.get_recorder(recorder_id=rid, experiment_name="init models")
model = recorder.load_object("init_model")
model.finetune(dataset, num_boost_round=10)

View File

@@ -8,13 +8,15 @@ There are two steps in each Trainer including ``train``(make model recorder) and
This is a concept called ``DelayTrainer``, which can be used in online simulating for parallel training.
In ``DelayTrainer``, the first step is only to save some necessary info to model recorders, and the second step which will be finished in the end can do some concurrent and time-consuming operations such as model fitting.
``Qlib`` offer two kinds of Trainer, ``TrainerR`` is the simplest way and ``TrainerRM`` is based on TaskManager to help manager tasks lifecycle automatically.
``Qlib`` offers two kinds of Trainer: ``TrainerR`` is the simplest way, and ``TrainerRM`` is based on TaskManager to help manage the task lifecycle automatically.
"""
import socket
import time
from typing import Callable, List
from qlib.data.dataset import Dataset
from qlib.log import get_module_logger
from qlib.model.base import Model
from qlib.utils import flatten_dict, get_cls_kwargs, init_instance_by_config
from qlib.workflow import R
@@ -151,6 +153,9 @@ class Trainer:
"""
return self.delay
def __call__(self, *args, **kwargs) -> list:
    """
    Run both training steps in one call.

    All arguments are forwarded to ``train``; its result is then handed to
    ``end_train`` and the outcome of ``end_train`` is returned.
    """
    recorders = self.train(*args, **kwargs)
    return self.end_train(recorders)
class TrainerR(Trainer):
"""
@@ -190,6 +195,8 @@ class TrainerR(Trainer):
Returns:
List[Recorder]: a list of Recorders
"""
if isinstance(tasks, dict):
tasks = [tasks]
if len(tasks) == 0:
return []
if train_func is None:
@@ -213,6 +220,8 @@ class TrainerR(Trainer):
Returns:
List[Recorder]: the same list as the param.
"""
if isinstance(recs, Recorder):
recs = [recs]
for rec in recs:
rec.set_tags(**{self.STATUS_KEY: self.STATUS_END})
return recs
@@ -250,6 +259,8 @@ class DelayTrainerR(TrainerR):
Returns:
List[Recorder]: a list of Recorders
"""
if isinstance(recs, Recorder):
recs = [recs]
if end_train_func is None:
end_train_func = self.end_train_func
if experiment_name is None:
@@ -275,7 +286,12 @@ class TrainerRM(Trainer):
STATUS_BEGIN = "begin_task_train"
STATUS_END = "end_task_train"
def __init__(self, experiment_name: str = None, task_pool: str = None, train_func=task_train):
# This tag is the _id in TaskManager to distinguish tasks.
TM_ID = "_id in TaskManager"
def __init__(
self, experiment_name: str = None, task_pool: str = None, train_func=task_train, skip_run_task: bool = False
):
"""
Init TrainerR.
@@ -283,11 +299,16 @@ class TrainerRM(Trainer):
experiment_name (str): the default name of experiment.
task_pool (str): task pool name in TaskManager. None for use same name as experiment_name.
train_func (Callable, optional): default training method. Defaults to `task_train`.
skip_run_task (bool):
If skip_run_task == True:
`train` will not call run_task itself; the tasks are only run by `worker`. Otherwise `train` runs the tasks directly.
"""
super().__init__()
self.experiment_name = experiment_name
self.task_pool = task_pool
self.train_func = train_func
self.skip_run_task = skip_run_task
def train(
self,
@@ -315,6 +336,8 @@ class TrainerRM(Trainer):
Returns:
List[Recorder]: a list of Recorders
"""
if isinstance(tasks, dict):
tasks = [tasks]
if len(tasks) == 0:
return []
if train_func is None:
@@ -326,19 +349,26 @@ class TrainerRM(Trainer):
task_pool = experiment_name
tm = TaskManager(task_pool=task_pool)
_id_list = tm.create_task(tasks) # all tasks will be saved to MongoDB
run_task(
train_func,
task_pool,
experiment_name=experiment_name,
before_status=before_status,
after_status=after_status,
**kwargs,
)
query = {"_id": {"$in": _id_list}}
if not self.skip_run_task:
run_task(
train_func,
task_pool,
query=query, # only train these tasks
experiment_name=experiment_name,
before_status=before_status,
after_status=after_status,
**kwargs,
)
if not self.is_delay():
tm.wait(query=query)
recs = []
for _id in _id_list:
rec = tm.re_query(_id)["res"]
rec.set_tags(**{self.STATUS_KEY: self.STATUS_BEGIN})
rec.set_tags(**{self.TM_ID: _id})
recs.append(rec)
return recs
@@ -352,10 +382,33 @@ class TrainerRM(Trainer):
Returns:
List[Recorder]: the same list as the param.
"""
if isinstance(recs, Recorder):
recs = [recs]
for rec in recs:
rec.set_tags(**{self.STATUS_KEY: self.STATUS_END})
return recs
def worker(
    self,
    train_func: Callable = None,
    experiment_name: str = None,
):
    """
    The multiprocessing counterpart of `train`. It can share the same task_pool with `train` and can run in another process or on another machine.

    Args:
        train_func (Callable): the training method which needs at least `task`s and `experiment_name`. None for the default training method.
        experiment_name (str): the experiment name, None for use default name.
    """
    # Anything not supplied falls back to the defaults stored on the trainer.
    train_func = self.train_func if train_func is None else train_func
    experiment_name = self.experiment_name if experiment_name is None else experiment_name
    # Without an explicit task pool, the pool is named after the experiment.
    task_pool = experiment_name if self.task_pool is None else self.task_pool
    run_task(train_func, task_pool=task_pool, experiment_name=experiment_name)
class DelayTrainerRM(TrainerRM):
"""
@@ -369,6 +422,7 @@ class DelayTrainerRM(TrainerRM):
task_pool: str = None,
train_func=begin_task_train,
end_train_func=end_task_train,
skip_run_task: bool = False,
):
"""
Init DelayTrainerRM.
@@ -378,10 +432,15 @@ class DelayTrainerRM(TrainerRM):
task_pool (str): task pool name in TaskManager. None for use same name as experiment_name.
train_func (Callable, optional): default train method. Defaults to `begin_task_train`.
end_train_func (Callable, optional): default end_train method. Defaults to `end_task_train`.
skip_run_task (bool):
If skip_run_task == True:
The trainer will not call run_task itself; the tasks are only run by `worker`. Otherwise the trainer runs the tasks directly.
E.g. Starting trainer on a CPU VM and then waiting for tasks to be finished on GPU VMs.
"""
super().__init__(experiment_name, task_pool, train_func)
self.end_train_func = end_train_func
self.delay = True
self.skip_run_task = skip_run_task
def train(self, tasks: list, train_func=None, experiment_name: str = None, **kwargs) -> List[Recorder]:
"""
@@ -395,6 +454,8 @@ class DelayTrainerRM(TrainerRM):
Returns:
List[Recorder]: a list of Recorders
"""
if isinstance(tasks, dict):
tasks = [tasks]
if len(tasks) == 0:
return []
return super().train(
@@ -410,8 +471,6 @@ class DelayTrainerRM(TrainerRM):
Given a list of Recorder and return a list of trained Recorder.
This class will finish real data loading and model fitting.
NOTE: This method will train all STATUS_PART_DONE tasks in the task pool, not only the ``recs``.
Args:
recs (list): a list of Recorder, the tasks have been saved to them.
end_train_func (Callable, optional): the end_train method which need at least `recorder`s and `experiment_name`. Defaults to None for using self.end_train_func.
@@ -421,7 +480,8 @@ class DelayTrainerRM(TrainerRM):
Returns:
List[Recorder]: a list of Recorders
"""
if isinstance(recs, Recorder):
recs = [recs]
if end_train_func is None:
end_train_func = self.end_train_func
if experiment_name is None:
@@ -429,18 +489,45 @@ class DelayTrainerRM(TrainerRM):
task_pool = self.task_pool
if task_pool is None:
task_pool = experiment_name
tasks = []
_id_list = []
for rec in recs:
tasks.append(rec.load_object("task"))
_id_list.append(rec.list_tags()[self.TM_ID])
query = {"_id": {"$in": _id_list}}
if not self.skip_run_task:
run_task(
end_train_func,
task_pool,
query=query, # only train these tasks
experiment_name=experiment_name,
before_status=TaskManager.STATUS_PART_DONE,
**kwargs,
)
TaskManager(task_pool=task_pool).wait(query=query)
run_task(
end_train_func,
task_pool,
query={"filter": {"$in": tasks}}, # only train these tasks
experiment_name=experiment_name,
before_status=TaskManager.STATUS_PART_DONE,
**kwargs,
)
for rec in recs:
rec.set_tags(**{self.STATUS_KEY: self.STATUS_END})
return recs
def worker(self, end_train_func=None, experiment_name: str = None):
    """
    The multiprocessing counterpart of `end_train`. It can share the same task_pool with `end_train` and can run in another process or on another machine.

    Args:
        end_train_func (Callable, optional): the end_train method which need at least `recorder`s and `experiment_name`. Defaults to None for using self.end_train_func.
        experiment_name (str): the experiment name, None for use default name.
    """
    # Anything not supplied falls back to the defaults stored on the trainer.
    end_train_func = self.end_train_func if end_train_func is None else end_train_func
    experiment_name = self.experiment_name if experiment_name is None else experiment_name
    # Without an explicit task pool, the pool is named after the experiment.
    task_pool = experiment_name if self.task_pool is None else self.task_pool
    # Only tasks currently marked STATUS_PART_DONE are picked up here.
    run_task(
        end_train_func,
        task_pool=task_pool,
        experiment_name=experiment_name,
        before_status=TaskManager.STATUS_PART_DONE,
    )

View File

@@ -43,17 +43,29 @@ RECORD_CONFIG = [
]
def get_data_handler_config(market=CSI300_MARKET):
def get_data_handler_config(
start_time="2008-01-01",
end_time="2020-08-01",
fit_start_time="2008-01-01",
fit_end_time="2014-12-31",
instruments=CSI300_MARKET,
):
return {
"start_time": "2008-01-01",
"end_time": "2020-08-01",
"fit_start_time": "2008-01-01",
"fit_end_time": "2014-12-31",
"instruments": market,
"start_time": start_time,
"end_time": end_time,
"fit_start_time": fit_start_time,
"fit_end_time": fit_end_time,
"instruments": instruments,
}
def get_dataset_config(market=CSI300_MARKET, dataset_class=DATASET_ALPHA158_CLASS):
def get_dataset_config(
dataset_class=DATASET_ALPHA158_CLASS,
train=("2008-01-01", "2014-12-31"),
valid=("2015-01-01", "2016-12-31"),
test=("2017-01-01", "2020-08-01"),
handler_kwargs={"instruments": CSI300_MARKET},
):
return {
"class": "DatasetH",
"module_path": "qlib.data.dataset",
@@ -61,48 +73,88 @@ def get_dataset_config(market=CSI300_MARKET, dataset_class=DATASET_ALPHA158_CLAS
"handler": {
"class": dataset_class,
"module_path": "qlib.contrib.data.handler",
"kwargs": get_data_handler_config(market),
"kwargs": get_data_handler_config(**handler_kwargs),
},
"segments": {
"train": ("2008-01-01", "2014-12-31"),
"valid": ("2015-01-01", "2016-12-31"),
"test": ("2017-01-01", "2020-08-01"),
"train": train,
"valid": valid,
"test": test,
},
},
}
def get_gbdt_task(market=CSI300_MARKET):
def get_gbdt_task(dataset_kwargs={}, handler_kwargs={"instruments": CSI300_MARKET}):
return {
"model": GBDT_MODEL,
"dataset": get_dataset_config(market),
"dataset": get_dataset_config(**dataset_kwargs, handler_kwargs=handler_kwargs),
}
def get_record_lgb_config(market=CSI300_MARKET):
def get_record_lgb_config(dataset_kwargs={}, handler_kwargs={"instruments": CSI300_MARKET}):
return {
"model": {
"class": "LGBModel",
"module_path": "qlib.contrib.model.gbdt",
},
"dataset": get_dataset_config(market),
"dataset": get_dataset_config(**dataset_kwargs, handler_kwargs=handler_kwargs),
"record": RECORD_CONFIG,
}
def get_record_xgboost_config(market=CSI300_MARKET):
def get_record_xgboost_config(dataset_kwargs={}, handler_kwargs={"instruments": CSI300_MARKET}):
return {
"model": {
"class": "XGBModel",
"module_path": "qlib.contrib.model.xgboost",
},
"dataset": get_dataset_config(market),
"dataset": get_dataset_config(**dataset_kwargs, handler_kwargs=handler_kwargs),
"record": RECORD_CONFIG,
}
CSI300_DATASET_CONFIG = get_dataset_config(market=CSI300_MARKET)
CSI300_GBDT_TASK = get_gbdt_task(market=CSI300_MARKET)
CSI300_DATASET_CONFIG = get_dataset_config(handler_kwargs={"instruments": CSI300_MARKET})
CSI300_GBDT_TASK = get_gbdt_task(handler_kwargs={"instruments": CSI300_MARKET})
CSI100_RECORD_XGBOOST_TASK_CONFIG = get_record_xgboost_config(market=CSI100_MARKET)
CSI100_RECORD_LGB_TASK_CONFIG = get_record_lgb_config(market=CSI100_MARKET)
CSI100_RECORD_XGBOOST_TASK_CONFIG = get_record_xgboost_config(handler_kwargs={"instruments": CSI100_MARKET})
CSI100_RECORD_LGB_TASK_CONFIG = get_record_lgb_config(handler_kwargs={"instruments": CSI100_MARKET})
# use for rolling_online_managment.py
ROLLING_HANDLER_CONFIG = {
"start_time": "2013-01-01",
"end_time": "2020-09-25",
"fit_start_time": "2013-01-01",
"fit_end_time": "2014-12-31",
"instruments": CSI100_MARKET,
}
ROLLING_DATASET_CONFIG = {
"train": ("2013-01-01", "2014-12-31"),
"valid": ("2015-01-01", "2015-12-31"),
"test": ("2016-01-01", "2020-07-10"),
}
CSI100_RECORD_XGBOOST_TASK_CONFIG_ROLLING = get_record_xgboost_config(
dataset_kwargs=ROLLING_DATASET_CONFIG, handler_kwargs=ROLLING_HANDLER_CONFIG
)
CSI100_RECORD_LGB_TASK_CONFIG_ROLLING = get_record_lgb_config(
dataset_kwargs=ROLLING_DATASET_CONFIG, handler_kwargs=ROLLING_HANDLER_CONFIG
)
# use for online_management_simulate.py
ONLINE_HANDLER_CONFIG = {
"start_time": "2018-01-01",
"end_time": "2018-10-31",
"fit_start_time": "2018-01-01",
"fit_end_time": "2018-03-31",
"instruments": CSI100_MARKET,
}
ONLINE_DATASET_CONFIG = {
"train": ("2018-01-01", "2018-03-31"),
"valid": ("2018-04-01", "2018-05-31"),
"test": ("2018-06-01", "2018-09-10"),
}
CSI100_RECORD_XGBOOST_TASK_CONFIG_ONLINE = get_record_xgboost_config(
dataset_kwargs=ONLINE_DATASET_CONFIG, handler_kwargs=ONLINE_HANDLER_CONFIG
)
CSI100_RECORD_LGB_TASK_CONFIG_ONLINE = get_record_lgb_config(
dataset_kwargs=ONLINE_DATASET_CONFIG, handler_kwargs=ONLINE_HANDLER_CONFIG
)

View File

@@ -642,6 +642,28 @@ def split_pred(pred, number=None, split_date=None):
return pred_left, pred_right
def time_to_slc_point(t: Union[None, str, pd.Timestamp]) -> Union[None, pd.Timestamp]:
    """
    Convert a user-supplied time value into a uniform format for time slicing.

    Time slicing in Qlib or Pandas is a frequently-used action, and users
    represent time in all kinds of formats. This helper converts those inputs
    into one format that is friendly to slicing.

    Parameters
    ----------
    t : Union[None, str, pd.Timestamp]
        original time

    Returns
    -------
    Union[None, pd.Timestamp]:
        ``None`` when ``t`` is ``None``, otherwise ``pd.Timestamp(t)``.
    """
    # None represents unbounded in Qlib or Pandas (e.g. df.loc[slice(None, "20210303")]),
    # so it is passed through untouched.
    return None if t is None else pd.Timestamp(t)
def can_use_cache():
res = True
r = get_redis_connection()

17
qlib/utils/exceptions.py Normal file
View File

@@ -0,0 +1,17 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# Base exception class
class QlibException(Exception):
    """Root of the Qlib-specific exception hierarchy."""

    def __init__(self, message):
        # Hand the message to Exception so it shows up in ``args`` and ``str(exc)``.
        super().__init__(message)
# Error type for reinitialization when starting an experiment
class RecorderInitializationError(QlibException):
    """Error type for reinitialization when starting an experiment."""
# Error type for Recorder when can not load object
class LoadObjectError(QlibException):
    """Error type for Recorder when it can not load an object."""

View File

@@ -92,16 +92,16 @@ class Serializable:
@classmethod
def load(cls, filepath):
"""
Load the collector from a filepath.
Load the serializable class from a filepath.
Args:
filepath (str): the path of file
Raises:
TypeError: the pickled file must be `Collector`
TypeError: the pickled file must be `type(cls)`
Returns:
Collector: the instance of Collector
`type(cls)`: the instance of `type(cls)`
"""
with open(filepath, "rb") as f:
object = cls.get_backend().load(f)

View File

@@ -7,6 +7,7 @@ from .expm import MLflowExpManager
from .exp import Experiment
from .recorder import Recorder
from ..utils import Wrapper
from ..utils.exceptions import RecorderInitializationError
class QlibRecorder:
@@ -215,9 +216,9 @@ class QlibRecorder:
-------
A dictionary (id -> recorder) of recorder information that being stored.
"""
return self.get_exp(experiment_id, experiment_name).list_recorders()
return self.get_exp(experiment_id=experiment_id, experiment_name=experiment_name).list_recorders()
def get_exp(self, experiment_id=None, experiment_name=None, create: bool = True) -> Experiment:
def get_exp(self, *, experiment_id=None, experiment_name=None, create: bool = True) -> Experiment:
"""
Method for retrieving an experiment with given id or name. Once the `create` argument is set to
True, if no valid experiment is found, this method will create one for you. Otherwise, it will
@@ -262,7 +263,7 @@ class QlibRecorder:
# Case 2
with R.start('test'):
exp = R.get_exp('test1')
exp = R.get_exp(experiment_name='test1')
# Case 3
exp = R.get_exp() -> a default experiment.
@@ -287,7 +288,9 @@ class QlibRecorder:
-------
An experiment instance with given id or name.
"""
return self.exp_manager.get_exp(experiment_id, experiment_name, create, start=False)
return self.exp_manager.get_exp(
experiment_id=experiment_id, experiment_name=experiment_name, create=create, start=False
)
def delete_exp(self, experiment_id=None, experiment_name=None):
"""
@@ -331,7 +334,9 @@ class QlibRecorder:
"""
self.exp_manager.set_uri(uri)
def get_recorder(self, recorder_id=None, recorder_name=None, experiment_name=None) -> Recorder:
def get_recorder(
self, *, recorder_id=None, recorder_name=None, experiment_id=None, experiment_name=None
) -> Recorder:
"""
Method for retrieving a recorder.
@@ -384,7 +389,7 @@ class QlibRecorder:
-------
A recorder instance.
"""
return self.get_exp(experiment_name=experiment_name, create=False).get_recorder(
return self.get_exp(experiment_name=experiment_name, experiment_id=experiment_id, create=False).get_recorder(
recorder_id, recorder_name, create=False, start=False
)
@@ -525,14 +530,29 @@ class QlibRecorder:
self.get_exp().get_recorder().set_tags(**kwargs)
class RecorderWrapper(Wrapper):
    """
    Wrapper class for QlibRecorder, which detects whether users reinitialize qlib when already starting an experiment.
    """

    def register(self, provider):
        """
        Install ``provider`` as the wrapped object.

        Raises
        ------
        RecorderInitializationError
            If a provider is already registered and its ``exp_manager`` has an
            active experiment.
        """
        if self._provider is not None:
            # Replacing the provider while an experiment is active would change
            # where that experiment is stored, so refuse instead.
            expm = self._provider.exp_manager
            if expm.active_experiment is not None:
                raise RecorderInitializationError(
                    "Please don't reinitialize Qlib if QlibRecorder is already activated. Otherwise, the experiment stored location will be modified."
                )
        self._provider = provider
import sys
if sys.version_info >= (3, 9):
from typing import Annotated
QlibRecorderWrapper = Annotated[QlibRecorder, Wrapper]
QlibRecorderWrapper = Annotated[QlibRecorder, RecorderWrapper]
else:
QlibRecorderWrapper = QlibRecorder
# global record
R: QlibRecorderWrapper = Wrapper()
R: QlibRecorderWrapper = RecorderWrapper()

View File

@@ -1,6 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from typing import Union
import mlflow, logging
from mlflow.entities import ViewType
from mlflow.exceptions import MlflowException
@@ -213,11 +214,15 @@ class Experiment:
"""
raise NotImplementedError(f"Please implement the `_get_recorder` method")
def list_recorders(self):
def list_recorders(self, **flt_kwargs):
"""
List all the existing recorders of this experiment. Please first get the experiment instance before calling this method.
If user want to use the method `R.list_recorders()`, please refer to the related API document in `QlibRecorder`.
flt_kwargs : dict
filter recorders by conditions
e.g. list_recorders(status=Recorder.STATUS_FI)
Returns
-------
A dictionary (id -> recorder) of recorder information that being stored.
@@ -320,11 +325,21 @@ class MLflowExperiment(Experiment):
UNLIMITED = 50000 # FIXME: Mlflow can only list 50000 records at most!!!!!!!
def list_recorders(self, max_results=UNLIMITED):
def list_recorders(self, max_results: int = UNLIMITED, status: Union[str, None] = None):
"""
Parameters
----------
max_results : int
the number limitation of the results
status : str
the criteria based on status to filter results.
`None` indicates no filtering.
"""
runs = self._client.search_runs(self.id, run_view_type=ViewType.ACTIVE_ONLY, max_results=max_results)
recorders = dict()
for i in range(len(runs)):
recorder = MLflowRecorder(self.id, self._uri, mlflow_run=runs[i])
recorders[runs[i].info.run_id] = recorder
if status is None or recorder.status == status:
recorders[runs[i].info.run_id] = recorder
return recorders

View File

@@ -109,7 +109,7 @@ class ExpManager:
"""
raise NotImplementedError(f"Please implement the `search_records` method.")
def get_exp(self, experiment_id=None, experiment_name=None, create: bool = True, start: bool = False):
def get_exp(self, *, experiment_id=None, experiment_name=None, create: bool = True, start: bool = False):
"""
Retrieve an experiment. This method includes getting an active experiment, and get_or_create a specific experiment.
@@ -190,7 +190,7 @@ class ExpManager:
except ValueError:
if experiment_name is None:
experiment_name = self._default_exp_name
logger.info(f"No valid experiment found. Create a new experiment with name {experiment_name}.")
logger.warning(f"No valid experiment found. Create a new experiment with name {experiment_name}.")
return self.create_exp(experiment_name), True
def _get_exp(self, experiment_id=None, experiment_name=None) -> Experiment:
@@ -352,6 +352,8 @@ class MLflowExpManager(ExpManager):
), "Please input at least one of experiment/recorder id or name before retrieving experiment/recorder."
if experiment_id is not None:
try:
# NOTE: the mlflow's experiment_id must be str type...
# https://www.mlflow.org/docs/latest/python_api/mlflow.tracking.html#mlflow.tracking.MlflowClient.get_experiment
exp = self.client.get_experiment(experiment_id)
if exp.lifecycle_stage.upper() == "DELETED":
raise MlflowException("No valid experiment has been found.")

View File

@@ -6,7 +6,7 @@ OnlineManager can manage a set of `Online Strategy <#Online Strategy>`_ and run
With the change of time, the decisive models will be also changed. In this module, we call those contributing models `online` models.
In every routine(such as every day or every minute), the `online` models may be changed and the prediction of them needs to be updated.
So this module provides a series of methods to control this process.
So this module provides a series of methods to control this process.
This module also provides a method to simulate `Online Strategy <#Online Strategy>`_ in history.
Which means you can verify your strategy or find a better one.
@@ -18,10 +18,12 @@ There are 4 total situations for using different trainers in different situation
========================= ===================================================================================
Situations Description
========================= ===================================================================================
Online + Trainer When you REAL want to do a routine, the Trainer will help you train the models.
Online + Trainer When you want to do a REAL routine, the Trainer will help you train the models. It
will train models task by task and strategy by strategy.
Online + DelayTrainer In normal online routine, whether Trainer or DelayTrainer will REAL train models
in this routine. So it is not necessary to use DelayTrainer when do a REAL routine.
Online + DelayTrainer When your models don't have any temporal dependence, the DelayTrainer will train
nothing until all tasks have been prepared. It makes user can train all tasks in
the end of `routine` or `first_train`.
Simulation + Trainer When your models have some temporal dependence on the previous models, then you
need to consider using Trainer. This means it will REAL train your models in
@@ -29,7 +31,7 @@ Simulation + Trainer When your models have some temporal dependence on the
Simulation + DelayTrainer When your models don't have any temporal dependence, you can use DelayTrainer
for the ability to multitasking. It means all tasks in all routines
can be REAL trained at the end of simulating. The signals will be prepared well at
can be REAL trained at the end of simulating. The signals will be prepared well at
different time segments (based on whether or not any new model is online).
========================= ===================================================================================
"""
@@ -103,17 +105,23 @@ class OnlineManager(Serializable):
"""
if strategies is None:
strategies = self.strategies
for strategy in strategies:
models_list = []
for strategy in strategies:
self.logger.info(f"Strategy `{strategy.name_id}` begins first training...")
tasks = strategy.first_tasks()
models = self.trainer.train(tasks, experiment_name=strategy.name_id)
models = self.trainer.end_train(models, experiment_name=strategy.name_id)
models_list.append(models)
self.logger.info(f"Finished training {len(models)} models.")
# FIXME: Training multiple online models at `first_train` will result in too many online models at the
# start.
online_models = strategy.prepare_online_models(models, **model_kwargs)
self.history.setdefault(self.cur_time, {})[strategy] = online_models
if not self.status == self.STATUS_SIMULATING or not self.trainer.is_delay():
for strategy, models in zip(strategies, models_list):
models = self.trainer.end_train(models, experiment_name=strategy.name_id)
def routine(
self,
cur_time: Union[str, pd.Timestamp] = None,
@@ -139,33 +147,41 @@ class OnlineManager(Serializable):
cur_time = D.calendar(freq=self.freq).max()
self.cur_time = pd.Timestamp(cur_time) # None for latest date
models_list = []
for strategy in self.strategies:
self.logger.info(f"Strategy `{strategy.name_id}` begins routine...")
if self.status == self.STATUS_NORMAL:
strategy.tool.update_online_pred()
tasks = strategy.prepare_tasks(self.cur_time, **task_kwargs)
models = self.trainer.train(tasks)
if self.status == self.STATUS_NORMAL or not self.trainer.is_delay():
models = self.trainer.end_train(models, experiment_name=strategy.name_id)
models = self.trainer.train(tasks, experiment_name=strategy.name_id)
models_list.append(models)
self.logger.info(f"Finished training {len(models)} models.")
online_models = strategy.prepare_online_models(models, **model_kwargs)
self.history.setdefault(self.cur_time, {})[strategy] = online_models
if not self.trainer.is_delay():
# The online models may change in the above processes
# So updating the predictions of online models should be the last step
if self.status == self.STATUS_NORMAL:
strategy.tool.update_online_pred()
if not self.status == self.STATUS_SIMULATING or not self.trainer.is_delay():
for strategy, models in zip(self.strategies, models_list):
models = self.trainer.end_train(models, experiment_name=strategy.name_id)
self.prepare_signals(**signal_kwargs)
def get_collector(self) -> MergeCollector:
def get_collector(self, **kwargs) -> MergeCollector:
"""
Get the instance of `Collector <../advanced/task_management.html#Task Collecting>`_ to collect results from every strategy.
This collector can be a basis as the signals preparation.
Args:
**kwargs: the params for get_collector.
Returns:
MergeCollector: the collector to merge other collectors.
"""
collector_dict = {}
for strategy in self.strategies:
collector_dict[strategy.name_id] = strategy.get_collector()
collector_dict[strategy.name_id] = strategy.get_collector(**kwargs)
return MergeCollector(collector_dict, process_list=[])
def add_strategy(self, strategies: Union[OnlineStrategy, List[OnlineStrategy]]):
@@ -225,7 +241,7 @@ class OnlineManager(Serializable):
SIM_LOG_NAME = "SIMULATE_INFO"
def simulate(
self, end_time, frequency="day", task_kwargs={}, model_kwargs={}, signal_kwargs={}
self, end_time=None, frequency="day", task_kwargs={}, model_kwargs={}, signal_kwargs={}
) -> Union[pd.Series, pd.DataFrame]:
"""
Starting from the current time, this method will simulate every routine in OnlineManager until the end time.
@@ -297,6 +313,7 @@ class OnlineManager(Serializable):
# NOTE: Assumption: the predictions of online models need less than next cur_time, or this method will work in a wrong way.
self.prepare_signals(**signal_kwargs)
if signals_time > cur_time:
# FIXME: if DelayTrainer and worker are used (and the worker is faster than the main process), this warning may be shown.
self.logger.warn(
f"The signals have already parpred to {signals_time} by last preparation, but current time is only {cur_time}. This may be because the online models predict more than they should, which can cause signals to be contaminated by the offline models."
)

View File

@@ -52,6 +52,12 @@ class OnlineStrategy:
NOTE: Reset all online models to trained models. If there are no trained models, then do nothing.
**NOTE**:
The current implementation is very naive. Here is a more complex situation which is closer to
practical scenarios.
1. Train new models at the day before `test_start` (at timestamp `T`)
2. Switch models at the `test_start` (at timestamp `T + 1` typically)
Args:
models (list): a list of models.
cur_time (pd.Dataframe): current time from OnlineManger. None for the latest.

View File

@@ -135,10 +135,9 @@ class PredUpdater(RecordUpdater):
# RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.
# https://github.com/pytorch/pytorch/issues/16797
start_time = get_date_by_shift(self.last_end, 1, freq=self.freq)
if start_time >= self.to_date:
if self.last_end >= self.to_date:
self.logger.info(
f"The prediction in {self.record.info['id']} are latest ({start_time}). No need to update to {self.to_date}."
f"The prediction in {self.record.info['id']} are latest ({self.last_end}). No need to update to {self.to_date}."
)
return

View File

@@ -8,8 +8,11 @@ This allows us to use efficient submodels as the market-style changing.
"""
from typing import List, Union
from qlib.data.dataset import TSDatasetH
from qlib.log import get_module_logger
from qlib.utils import get_cls_kwargs
from qlib.utils.exceptions import LoadObjectError
from qlib.workflow.online.update import PredUpdater
from qlib.workflow.recorder import Recorder
from qlib.workflow.task.utils import list_recorders
@@ -88,15 +91,15 @@ class OnlineToolR(OnlineTool):
The implementation of OnlineTool based on (R)ecorder.
"""
def __init__(self, experiment_name: str):
def __init__(self, default_exp_name: str = None):
"""
Init OnlineToolR.
Args:
experiment_name (str): the experiment name.
default_exp_name (str): the default experiment name.
"""
super().__init__()
self.exp_name = experiment_name
self.default_exp_name = default_exp_name
def set_online_tag(self, tag, recorder: Union[Recorder, List]):
"""
@@ -125,44 +128,68 @@ class OnlineToolR(OnlineTool):
tags = recorder.list_tags()
return tags.get(self.ONLINE_KEY, self.OFFLINE_TAG)
def reset_online_tag(self, recorder: Union[Recorder, List]):
def reset_online_tag(self, recorder: Union[Recorder, List], exp_name: str = None):
"""
Offline all models and set the recorders to 'online'.
Args:
recorder (Union[Recorder, List]):
the recorder you want to reset to 'online'.
exp_name (str): the experiment name. If None, then use default_exp_name.
"""
exp_name = self._get_exp_name(exp_name)
if isinstance(recorder, Recorder):
recorder = [recorder]
recs = list_recorders(self.exp_name)
recs = list_recorders(exp_name)
self.set_online_tag(self.OFFLINE_TAG, list(recs.values()))
self.set_online_tag(self.ONLINE_TAG, recorder)
def online_models(self) -> list:
def online_models(self, exp_name: str = None) -> list:
"""
Get current `online` models
Args:
exp_name (str): the experiment name. If None, then use default_exp_name.
Returns:
list: a list of `online` models.
"""
return list(list_recorders(self.exp_name, lambda rec: self.get_online_tag(rec) == self.ONLINE_TAG).values())
exp_name = self._get_exp_name(exp_name)
return list(list_recorders(exp_name, lambda rec: self.get_online_tag(rec) == self.ONLINE_TAG).values())
def update_online_pred(self, to_date=None):
def update_online_pred(self, to_date=None, exp_name: str = None):
"""
Update the predictions of online models to to_date.
Args:
to_date (pd.Timestamp): the pred before this date will be updated. None for updating to latest time in Calendar.
exp_name (str): the experiment name. If None, then use default_exp_name.
"""
online_models = self.online_models()
exp_name = self._get_exp_name(exp_name)
online_models = self.online_models(exp_name=exp_name)
for rec in online_models:
hist_ref = 0
task = rec.load_object("task")
# Special treatment of historical dependencies
if task["dataset"]["class"] == "TSDatasetH":
hist_ref = task["dataset"]["kwargs"]["step_len"]
PredUpdater(rec, to_date=to_date, hist_ref=hist_ref).update()
cls, kwargs = get_cls_kwargs(task["dataset"], default_module="qlib.data.dataset")
if issubclass(cls, TSDatasetH):
hist_ref = kwargs.get("step_len", TSDatasetH.DEFAULT_STEP_LEN)
try:
updater = PredUpdater(rec, to_date=to_date, hist_ref=hist_ref)
except LoadObjectError as e:
# skip the recorder without pred
self.logger.warn(f"An exception `{str(e)}` happened when load `pred.pkl`, skip it.")
continue
updater.update()
self.logger.info(f"Finished updating {len(online_models)} online model predictions of {self.exp_name}.")
self.logger.info(f"Finished updating {len(online_models)} online model predictions of {exp_name}.")
def _get_exp_name(self, exp_name):
if exp_name is None:
if self.default_exp_name is None:
raise ValueError(
"Both default_exp_name and exp_name are None. OnlineToolR needs a specific experiment."
)
exp_name = self.default_exp_name
return exp_name

View File

@@ -227,10 +227,11 @@ class SigAnaRecord(SignalRecord):
artifact_path = "sig_analysis"
def __init__(self, recorder, ana_long_short=False, ann_scaler=252, **kwargs):
def __init__(self, recorder, ana_long_short=False, ann_scaler=252, label_col=0, **kwargs):
super().__init__(recorder=recorder, **kwargs)
self.ana_long_short = ana_long_short
self.ann_scaler = ann_scaler
self.label_col = label_col
def generate(self, **kwargs):
try:
@@ -243,7 +244,7 @@ class SigAnaRecord(SignalRecord):
if label is None or not isinstance(label, pd.DataFrame) or label.empty:
logger.warn(f"Empty label.")
return
ic, ric = calc_ic(pred.iloc[:, 0], label.iloc[:, 0])
ic, ric = calc_ic(pred.iloc[:, 0], label.iloc[:, self.label_col])
metrics = {
"IC": ic.mean(),
"ICIR": ic.mean() / ic.std(),
@@ -252,7 +253,7 @@ class SigAnaRecord(SignalRecord):
}
objects = {"ic.pkl": ic, "ric.pkl": ric}
if self.ana_long_short:
long_short_r, long_avg_r = calc_long_short_return(pred.iloc[:, 0], label.iloc[:, 0])
long_short_r, long_avg_r = calc_long_short_return(pred.iloc[:, 0], label.iloc[:, self.label_col])
metrics.update(
{
"Long-Short Ann Return": long_short_r.mean() * self.ann_scaler,

View File

@@ -5,6 +5,8 @@ import mlflow, logging
import shutil, os, pickle, tempfile, codecs, pickle
from pathlib import Path
from datetime import datetime
from qlib.utils.exceptions import LoadObjectError
from ..utils.objm import FileManager
from ..log import get_module_logger
@@ -307,10 +309,26 @@ class MLflowRecorder(Recorder):
shutil.rmtree(temp_dir)
def load_object(self, name):
"""
Load object such as prediction file or model checkpoint in mlflow.
Args:
name (str): the object name
Raises:
LoadObjectError: if raise some exceptions when load the object
Returns:
object: the saved object in mlflow.
"""
assert self.uri is not None, "Please start the experiment and recorder first before using recorder directly."
path = self.client.download_artifacts(self.id, name)
with Path(path).open("rb") as f:
return pickle.load(f)
try:
path = self.client.download_artifacts(self.id, name)
with Path(path).open("rb") as f:
return pickle.load(f)
except Exception as e:
raise LoadObjectError(message=str(e))
def log_params(self, **kwargs):
for name, data in kwargs.items():

View File

@@ -6,6 +6,7 @@ Collector module can collect objects from everywhere and process them such as me
"""
from typing import Callable, Dict, List
from qlib.log import get_module_logger
from qlib.utils.serial import Serializable
from qlib.workflow import R
@@ -192,6 +193,7 @@ class RecorderCollector(Collector):
if rec_filter_func is None or rec_filter_func(rec):
recs_flt[rid] = rec
logger = get_module_logger("RecorderCollector")
for _, rec in recs_flt.items():
rec_key = self.rec_key_func(rec)
for key in artifacts_key:
@@ -205,7 +207,13 @@ class RecorderCollector(Collector):
# only collect existing artifact
continue
raise e
collect_dict.setdefault(key, {})[rec_key] = artifact
# give user some warning if the values are overridden
cdd = collect_dict.setdefault(key, {})
if rec_key in cdd:
logger.warning(
f"key '{rec_key}' is duplicated. Previous value will be overrides. Please check you `rec_key_func`"
)
cdd[rec_key] = artifact
return collect_dict

View File

@@ -6,6 +6,8 @@ TaskGenerator module can generate many tasks based on TaskGen and some task temp
import abc
import copy
from typing import List, Union, Callable
from qlib.utils import transform_end_date
from .utils import TimeAdjuster
@@ -199,7 +201,7 @@ class RollingGen(TaskGen):
# First rolling
# 1) prepare the end point
segments: dict = copy.deepcopy(self.ta.align_seg(t["dataset"]["kwargs"]["segments"]))
test_end = self.ta.max() if segments[self.test_key][1] is None else segments[self.test_key][1]
test_end = transform_end_date(segments[self.test_key][1])
# 2) and init test segments
test_start_idx = self.ta.align_idx(segments[self.test_key][0])
segments[self.test_key] = (self.ta.get(test_start_idx), self.ta.get(test_start_idx + self.step - 1))

View File

@@ -69,28 +69,29 @@ class TaskManager:
ENCODE_FIELDS_PREFIX = ["def", "res"]
def __init__(self, task_pool: str = None):
def __init__(self, task_pool: str):
"""
Init Task Manager, remember to make the statement of MongoDB url and database name firstly.
A TaskManager instance serves a specific task pool.
The static method of this module serves the whole MongoDB.
Parameters
----------
task_pool: str
the name of Collection in MongoDB
"""
self.mdb = get_mongodb()
if task_pool is not None:
self.task_pool = getattr(self.mdb, task_pool)
self.task_pool = getattr(get_mongodb(), task_pool)
self.logger = get_module_logger(self.__class__.__name__)
def list(self) -> list:
@staticmethod
def list() -> list:
"""
List the all collection(task_pool) of the db
List the all collection(task_pool) of the db.
Returns:
list
"""
return self.mdb.list_collection_names()
return get_mongodb().list_collection_names()
def _encode_task(self, task):
for prefix in self.ENCODE_FIELDS_PREFIX:
@@ -109,6 +110,25 @@ class TaskManager:
def _dict_to_str(self, flt):
return {k: str(v) for k, v in flt.items()}
def _decode_query(self, query):
"""
If the query includes any `_id`, then it needs `ObjectId` to decode.
For example, when using TrainerRM, it needs query `{"_id": {"$in": _id_list}}`. Then we need to `ObjectId` every `_id` in `_id_list`.
Args:
query (dict): query dict. Defaults to {}.
Returns:
dict: the query after decoding.
"""
if "_id" in query:
if isinstance(query["_id"], dict):
for key in query["_id"]:
query["_id"][key] = [ObjectId(i) for i in query["_id"][key]]
else:
query["_id"] = ObjectId(query["_id"])
return query
def replace_task(self, task, new_task):
"""
Use a new task to replace a old one
@@ -224,8 +244,7 @@ class TaskManager:
dict: a task(document in collection) after decoding
"""
query = query.copy()
if "_id" in query:
query["_id"] = ObjectId(query["_id"])
query = self._decode_query(query)
query.update({"status": status})
task = self.task_pool.find_one_and_update(
query, {"$set": {"status": self.STATUS_RUNNING}}, sort=[("priority", pymongo.DESCENDING)]
@@ -253,10 +272,10 @@ class TaskManager:
task = self.fetch_task(query=query, status=status)
try:
yield task
except Exception:
except (Exception, KeyboardInterrupt): # KeyboardInterrupt is not a subclass of Exception
if task is not None:
self.logger.info("Returning task before raising error")
self.return_task(task)
self.return_task(task, status=status) # return task as the original status
self.logger.info("Task returned")
raise
@@ -283,12 +302,11 @@ class TaskManager:
dict: a task(document in collection) after decoding
"""
query = query.copy()
if "_id" in query:
query["_id"] = ObjectId(query["_id"])
query = self._decode_query(query)
for t in self.task_pool.find(query):
yield self._decode_task(t)
def re_query(self, _id):
def re_query(self, _id) -> dict:
"""
Use _id to query task.
@@ -339,8 +357,7 @@ class TaskManager:
"""
query = query.copy()
if "_id" in query:
query["_id"] = ObjectId(query["_id"])
query = self._decode_query(query)
self.task_pool.delete_many(query)
def task_stat(self, query={}) -> dict:
@@ -354,8 +371,7 @@ class TaskManager:
dict
"""
query = query.copy()
if "_id" in query:
query["_id"] = ObjectId(query["_id"])
query = self._decode_query(query)
tasks = self.query(query=query, decode=False)
status_stat = {}
for t in tasks:
@@ -377,8 +393,7 @@ class TaskManager:
def reset_status(self, query, status):
query = query.copy()
if "_id" in query:
query["_id"] = ObjectId(query["_id"])
query = self._decode_query(query)
print(self.task_pool.update_many(query, {"$set": {"status": status}}))
def prioritize(self, task, priority: int):
@@ -396,15 +411,29 @@ class TaskManager:
self.task_pool.update_one({"_id": task["_id"]}, update_dict)
def _get_undone_n(self, task_stat):
return task_stat.get(self.STATUS_WAITING, 0) + task_stat.get(self.STATUS_RUNNING, 0)
return (
task_stat.get(self.STATUS_WAITING, 0)
+ task_stat.get(self.STATUS_RUNNING, 0)
+ task_stat.get(self.STATUS_PART_DONE, 0)
)
def _get_total(self, task_stat):
return sum(task_stat.values())
def wait(self, query={}):
"""
When multiprocessing, the main progress may fetch nothing from TaskManager because there are still some running tasks.
So main progress should wait until all tasks are trained well by other progress or machines.
Args:
query (dict, optional): the query dict. Defaults to {}.
"""
task_stat = self.task_stat(query)
total = self._get_total(task_stat)
last_undone_n = self._get_undone_n(task_stat)
if last_undone_n == 0:
return
self.logger.warning(f"Waiting for {last_undone_n} undone tasks. Please make sure they are running.")
with tqdm(total=total, initial=total - last_undone_n) as pbar:
while True:
time.sleep(10)

View File

@@ -17,7 +17,6 @@ def experiment_exit_handler():
Thus, if any exception or user interruption occurs beforehand, we should handle them first. Once `R` is
ended, another call of `R.end_exp` will not take effect.
"""
signal.signal(signal.SIGINT, experiment_kill_signal_handler) # handle user keyboard interupt
sys.excepthook = experiment_exception_hook # handle uncaught exception
atexit.register(R.end_exp, recorder_status=Recorder.STATUS_FI) # will not take effect if experiment ends
@@ -39,10 +38,3 @@ def experiment_exception_hook(type, value, tb):
print(f"{type.__name__}: {value}")
R.end_exp(recorder_status=Recorder.STATUS_FA)
def experiment_kill_signal_handler(signum, frame):
"""
End an experiment when user kill the program through keyboard (CTRL+C, etc.).
"""
R.end_exp(recorder_status=Recorder.STATUS_FA)

View File

@@ -7,12 +7,13 @@ import time
import datetime
import importlib
from pathlib import Path
from typing import Type
from typing import Type, Iterable
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
import pandas as pd
from tqdm import tqdm
from loguru import logger
from joblib import Parallel, delayed
from qlib.utils import code_to_fname
@@ -22,9 +23,9 @@ class BaseCollector(abc.ABC):
NORMAL_FLAG = "NORMAL"
DEFAULT_START_DATETIME_1D = pd.Timestamp("2000-01-01")
DEFAULT_START_DATETIME_1MIN = pd.Timestamp(datetime.datetime.now() - pd.Timedelta(days=5 * 6))
DEFAULT_END_DATETIME_1D = pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))
DEFAULT_END_DATETIME_1MIN = pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))
DEFAULT_START_DATETIME_1MIN = pd.Timestamp(datetime.datetime.now() - pd.Timedelta(days=5 * 6 - 1)).date()
DEFAULT_END_DATETIME_1D = pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1)).date()
DEFAULT_END_DATETIME_1MIN = DEFAULT_END_DATETIME_1D
INTERVAL_1min = "1min"
INTERVAL_1d = "1d"
@@ -35,10 +36,10 @@ class BaseCollector(abc.ABC):
start=None,
end=None,
interval="1d",
max_workers=4,
max_workers=1,
max_collector_count=2,
delay=0,
check_data_length: bool = False,
check_data_length: int = None,
limit_nums: int = None,
):
"""
@@ -48,7 +49,7 @@ class BaseCollector(abc.ABC):
save_dir: str
instrument save dir
max_workers: int
workers, default 4
workers, default 1; when collecting data, it is recommended that max_workers be set to 1
max_collector_count: int
default 2
delay: float
@@ -59,8 +60,8 @@ class BaseCollector(abc.ABC):
start datetime, default None
end: str
end datetime, default None
check_data_length: bool
check data length, by default False
check_data_length: int
check data length, if not None and greater than 0, each symbol will be considered complete if its data length is greater than or equal to this value, otherwise it will be fetched again, the maximum number of fetches being (max_collector_count). By default None.
limit_nums: int
using for debug, by default None
"""
@@ -72,7 +73,7 @@ class BaseCollector(abc.ABC):
self.max_collector_count = max_collector_count
self.mini_symbol_map = {}
self.interval = interval
self.check_small_data = check_data_length
self.check_data_length = max(int(check_data_length) if check_data_length is not None else 0, 0)
self.start_datetime = self.normalize_start_datetime(start)
self.end_datetime = self.normalize_end_datetime(end)
@@ -99,14 +100,6 @@ class BaseCollector(abc.ABC):
else getattr(self, f"DEFAULT_END_DATETIME_{self.interval.upper()}")
)
@property
@abc.abstractmethod
def min_numbers_trading(self):
# daily, one year: 252 / 4
# us 1min, a week: 6.5 * 60 * 5
# cn 1min, a week: 4 * 60 * 5
raise NotImplementedError("rewrite min_numbers_trading")
@abc.abstractmethod
def get_instrument_list(self):
raise NotImplementedError("rewrite get_instrument_list")
@@ -132,7 +125,7 @@ class BaseCollector(abc.ABC):
Returns
---------
pd.DataFrame, "symbol" in pd.columns
pd.DataFrame, "symbol" and "date" in pd.columns
"""
raise NotImplementedError("rewrite get_timezone")
@@ -151,7 +144,7 @@ class BaseCollector(abc.ABC):
self.sleep()
df = self.get_data(symbol, self.interval, self.start_datetime, self.end_datetime)
_result = self.NORMAL_FLAG
if self.check_small_data:
if self.check_data_length > 0:
_result = self.cache_small_data(symbol, df)
if _result == self.NORMAL_FLAG:
self.save_instrument(symbol, df)
@@ -181,8 +174,8 @@ class BaseCollector(abc.ABC):
df.to_csv(instrument_path, index=False)
def cache_small_data(self, symbol, df):
if len(df) <= self.min_numbers_trading:
logger.warning(f"the number of trading days of {symbol} is less than {self.min_numbers_trading}!")
if len(df) < self.check_data_length:
logger.warning(f"the number of trading days of {symbol} is less than {self.check_data_length}!")
_temp = self.mini_symbol_map.setdefault(symbol, [])
_temp.append(df.copy())
return self.CACHE_FLAG
@@ -194,12 +187,12 @@ class BaseCollector(abc.ABC):
def _collector(self, instrument_list):
error_symbol = []
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
with tqdm(total=len(instrument_list)) as p_bar:
for _symbol, _result in zip(instrument_list, executor.map(self._simple_collector, instrument_list)):
if _result != self.NORMAL_FLAG:
error_symbol.append(_symbol)
p_bar.update()
res = Parallel(n_jobs=self.max_workers)(
delayed(self._simple_collector)(_inst) for _inst in tqdm(instrument_list)
)
for _symbol, _result in zip(instrument_list, res):
if _result != self.NORMAL_FLAG:
error_symbol.append(_symbol)
print(error_symbol)
logger.info(f"error symbol nums: {len(error_symbol)}")
logger.info(f"current get symbol nums: {len(instrument_list)}")
@@ -217,20 +210,16 @@ class BaseCollector(abc.ABC):
instrument_list = self._collector(instrument_list)
logger.info(f"{i+1} finish.")
for _symbol, _df_list in self.mini_symbol_map.items():
self.save_instrument(
_symbol, pd.concat(_df_list, sort=False).drop_duplicates(["date"]).sort_values(["date"])
)
_df = pd.concat(_df_list, sort=False)
if not _df.empty:
self.save_instrument(_symbol, _df.drop_duplicates(["date"]).sort_values(["date"]))
if self.mini_symbol_map:
logger.warning(f"less than {self.min_numbers_trading} instrument list: {list(self.mini_symbol_map.keys())}")
logger.warning(f"less than {self.check_data_length} instrument list: {list(self.mini_symbol_map.keys())}")
logger.info(f"total {len(self.instrument_list)}, error: {len(set(instrument_list))}")
class BaseNormalize(abc.ABC):
def __init__(
self,
date_field_name: str = "date",
symbol_field_name: str = "symbol",
):
def __init__(self, date_field_name: str = "date", symbol_field_name: str = "symbol", **kwargs):
"""
Parameters
@@ -242,7 +231,7 @@ class BaseNormalize(abc.ABC):
"""
self._date_field_name = date_field_name
self._symbol_field_name = symbol_field_name
self.kwargs = kwargs
self._calendar_list = self._get_calendar_list()
@abc.abstractmethod
@@ -251,7 +240,7 @@ class BaseNormalize(abc.ABC):
raise NotImplementedError("")
@abc.abstractmethod
def _get_calendar_list(self):
def _get_calendar_list(self) -> Iterable[pd.Timestamp]:
"""Get benchmark calendar"""
raise NotImplementedError("")
@@ -265,6 +254,7 @@ class Normalize:
max_workers: int = 16,
date_field_name: str = "date",
symbol_field_name: str = "symbol",
**kwargs,
):
"""
@@ -288,16 +278,23 @@ class Normalize:
self._source_dir = Path(source_dir).expanduser()
self._target_dir = Path(target_dir).expanduser()
self._target_dir.mkdir(parents=True, exist_ok=True)
self._date_field_name = date_field_name
self._symbol_field_name = symbol_field_name
self._end_date = kwargs.get("end_date", None)
self._max_workers = max_workers
self._normalize_obj = normalize_class(date_field_name=date_field_name, symbol_field_name=symbol_field_name)
self._normalize_obj = normalize_class(
date_field_name=date_field_name, symbol_field_name=symbol_field_name, **kwargs
)
def _executor(self, file_path: Path):
file_path = Path(file_path)
df = pd.read_csv(file_path)
df = self._normalize_obj.normalize(df)
if not df.empty:
if df is not None and not df.empty:
if self._end_date is not None:
_mask = pd.to_datetime(df[self._date_field_name]) <= pd.Timestamp(self._end_date)
df = df[_mask]
df.to_csv(self._target_dir.joinpath(file_path.name), index=False)
def normalize(self):
@@ -311,7 +308,7 @@ class Normalize:
class BaseRun(abc.ABC):
def __init__(self, source_dir=None, normalize_dir=None, max_workers=4, interval="1d"):
def __init__(self, source_dir=None, normalize_dir=None, max_workers=1, interval="1d"):
"""
Parameters
@@ -321,7 +318,7 @@ class BaseRun(abc.ABC):
normalize_dir: str
Directory for normalize data, default "Path(__file__).parent/normalize"
max_workers: int
Concurrent number, default is 4
Concurrent number, default is 1; when collecting data, it is recommended that max_workers be set to 1
interval: str
freq, value from [1min, 1d], default 1d
"""
@@ -361,7 +358,7 @@ class BaseRun(abc.ABC):
start=None,
end=None,
interval="1d",
check_data_length=False,
check_data_length: int = None,
limit_nums=None,
):
"""download data from Internet
@@ -378,8 +375,8 @@ class BaseRun(abc.ABC):
start datetime, default "2000-01-01"
end: str
end datetime, default ``pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))``
check_data_length: bool
check data length, by default False
check_data_length: int
check data length, if not None and greater than 0, each symbol will be considered complete if its data length is greater than or equal to this value, otherwise it will be fetched again, the maximum number of fetches being (max_collector_count). By default None.
limit_nums: int
using for debug, by default None
@@ -404,7 +401,7 @@ class BaseRun(abc.ABC):
limit_nums=limit_nums,
).collector_data()
def normalize_data(self, date_field_name: str = "date", symbol_field_name: str = "symbol"):
def normalize_data(self, date_field_name: str = "date", symbol_field_name: str = "symbol", **kwargs):
"""normalize data
Parameters
@@ -426,5 +423,6 @@ class BaseRun(abc.ABC):
max_workers=self.max_workers,
date_field_name=date_field_name,
symbol_field_name=symbol_field_name,
**kwargs,
)
yc.normalize()

View File

@@ -19,12 +19,31 @@ CUR_DIR = Path(__file__).resolve().parent
sys.path.append(str(CUR_DIR.parent.parent))
from data_collector.index import IndexBase
from data_collector.utils import get_calendar_list, get_trading_date_by_shift
from data_collector.utils import get_calendar_list, get_trading_date_by_shift, deco_retry
NEW_COMPANIES_URL = "http://www.csindex.com.cn/uploads/file/autofile/cons/{index_code}cons.xls"
INDEX_CHANGES_URL = "http://www.csindex.com.cn/zh-CN/search/total?key=%E5%85%B3%E4%BA%8E%E8%B0%83%E6%95%B4%E6%B2%AA%E6%B7%B1300%E5%92%8C%E4%B8%AD%E8%AF%81%E9%A6%99%E6%B8%AF100%E7%AD%89%E6%8C%87%E6%95%B0%E6%A0%B7%E6%9C%AC%E8%82%A1%E7%9A%84%E5%85%AC%E5%91%8A"
# INDEX_CHANGES_URL = "http://www.csindex.com.cn/zh-CN/search/total?key=%E5%85%B3%E4%BA%8E%E8%B0%83%E6%95%B4%E6%B2%AA%E6%B7%B1300%E5%92%8C%E4%B8%AD%E8%AF%81%E9%A6%99%E6%B8%AF100%E7%AD%89%E6%8C%87%E6%95%B0%E6%A0%B7%E6%9C%AC%E8%82%A1%E7%9A%84%E5%85%AC%E5%91%8A"
# 2020-11-27 Announcement title change
INDEX_CHANGES_URL = "http://www.csindex.com.cn/zh-CN/search/total?key=%E5%85%B3%E4%BA%8E%E8%B0%83%E6%95%B4%E6%B2%AA%E6%B7%B1300%E5%92%8C%E4%B8%AD%E8%AF%81%E9%A6%99%E6%B8%AF100%E7%AD%89"
REQ_HEADERS = {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.101 Safari/537.36 Edg/91.0.864.48"
}
@deco_retry
def retry_request(url: str, method: str = "get", exclude_status: List = None):
    """Send a request with the shared ``REQ_HEADERS`` and validate the response status.

    Parameters
    ----------
    url: str
        request url
    method: str
        name of the ``requests`` function to call, by default "get"
    exclude_status: List
        status codes that are accepted in addition to 200, by default None (only 200 is accepted)

    Returns
    -------
    the response object returned by ``requests``

    Raises
    ------
    ValueError
        if the status code is neither 200 nor listed in ``exclude_status``
    """
    # NOTE(review): `deco_retry` presumably re-invokes this function when it raises -- confirm in data_collector.utils
    tolerated = [] if exclude_status is None else exclude_status
    response = getattr(requests, method)(url, headers=REQ_HEADERS)
    code = response.status_code
    if code == 200 or code in tolerated:
        return response
    raise ValueError(f"response status: {code}, url={url}")
class CSIIndex(IndexBase):
@@ -134,9 +153,8 @@ class CSIIndex(IndexBase):
date: pd.Timestamp
type: str, value from ["add", "remove"]
"""
resp = requests.get(url)
resp = retry_request(url)
_text = resp.text
date_list = re.findall(r"(\d{4}).*?年.*?(\d+).*?月.*?(\d+).*?日", _text)
if len(date_list) >= 2:
add_date = pd.Timestamp("-".join(date_list[0]))
@@ -147,7 +165,7 @@ class CSIIndex(IndexBase):
logger.info(f"get {add_date} changes")
try:
excel_url = re.findall('.*href="(.*?xls.*?)".*', _text)[0]
content = requests.get(f"http://www.csindex.com.cn{excel_url}").content
content = retry_request(f"http://www.csindex.com.cn{excel_url}", exclude_status=[404]).content
_io = BytesIO(content)
df_map = pd.read_excel(_io, sheet_name=None)
with self.cache_dir.joinpath(
@@ -201,7 +219,7 @@ class CSIIndex(IndexBase):
-------
[url1, url2]
"""
resp = requests.get(self.changes_url)
resp = retry_request(self.changes_url)
html = etree.HTML(resp.text)
return html.xpath("//*[@id='itemContainer']//li/a/@href")
@@ -221,7 +239,7 @@ class CSIIndex(IndexBase):
end_date: pd.Timestamp
"""
logger.info("get new companies......")
context = requests.get(self.new_companies_url).content
context = retry_request(self.new_companies_url).content
with self.cache_dir.joinpath(
f"{self.index_name.lower()}_new_companies.{self.new_companies_url.split('.')[-1]}"
).open("wb") as fp:
@@ -292,7 +310,7 @@ def get_instruments(
$ python collector.py --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data --method save_new_companies
"""
_cur_module = importlib.import_module("collector")
_cur_module = importlib.import_module("data_collector.cn_index.collector")
obj = getattr(_cur_module, f"{index_name.upper()}")(
qlib_dir=qlib_dir, index_name=index_name, request_retry=request_retry, retry_sleep=retry_sleep
)

View File

@@ -0,0 +1,23 @@
# Use 1d data to fill in the missing symbols relative to 1min
## Requirements
```bash
pip install -r requirements.txt
```
## fill 1min data
```bash
python fill_1min_using_1d.py --data_1min_dir ~/.qlib/csv_data/cn_data_1min --qlib_data_1d_dir ~/.qlib/qlib_data/cn_data
```
## Parameters
- data_1min_dir: csv data
- qlib_data_1d_dir: qlib data directory
- max_workers: `ThreadPoolExecutor(max_workers=max_workers)`, by default *16*
- date_field_name: date field name, by default *date*
- symbol_field_name: symbol field name, by default *symbol*

View File

@@ -0,0 +1,100 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import sys
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Union

import fire
import qlib
import pandas as pd
from tqdm import tqdm
from qlib.data import D
from loguru import logger
CUR_DIR = Path(__file__).resolve().parent
sys.path.append(str(CUR_DIR.parent.parent.parent))
from data_collector.utils import generate_minutes_calendar_from_daily
def get_date_range(data_1min_dir: Path, max_workers: int = 16, date_field_name: str = "date"):
    """Get the earliest and the latest datetime found in the 1min csv files

    Parameters
    ----------
    data_1min_dir: Path
        1min data dir, every ``*.csv`` file in it is scanned
    max_workers: int
        ThreadPoolExecutor(max_workers), by default 16
    date_field_name: str
        date field name, by default date

    Returns
    -------
    tuple
        (min_date, max_date); both are None if no file contains any row
    """
    csv_files = list(data_1min_dir.glob("*.csv"))

    def _read_dates(csv_path):
        # only the date column is needed; skipping the other columns keeps large 1min files cheap to scan
        return pd.read_csv(csv_path, usecols=[date_field_name])[date_field_name]

    min_date = None
    max_date = None
    with tqdm(total=len(csv_files)) as p_bar:
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            for _raw_dates in executor.map(_read_dates, csv_files):
                if not _raw_dates.empty:
                    _dates = pd.to_datetime(_raw_dates)

                    _tmp_min = _dates.min()
                    min_date = min(min_date, _tmp_min) if min_date is not None else _tmp_min
                    _tmp_max = _dates.max()
                    max_date = max(max_date, _tmp_max) if max_date is not None else _tmp_max
                # advance the bar for empty files as well, so it always reaches the total
                p_bar.update()
    return min_date, max_date
def get_symbols(data_1min_dir: Path):
    """Get the upper-cased symbols of the 1min csv files

    Parameters
    ----------
    data_1min_dir: Path
        1min data dir

    Returns
    -------
    list
        file name without the trailing ``.csv``, upper-cased, one item per ``*.csv`` file
    """
    symbols = []
    for csv_path in data_1min_dir.glob("*.csv"):
        file_name = csv_path.name
        # drop the 4-character ".csv" suffix
        symbols.append(file_name[: -len(".csv")].upper())
    return symbols
def fill_1min_using_1d(
    data_1min_dir: Union[str, Path],
    qlib_data_1d_dir: Union[str, Path],
    max_workers: int = 16,
    date_field_name: str = "date",
    symbol_field_name: str = "symbol",
):
    """Use 1d data to fill in the missing symbols relative to 1min

    For every instrument that exists in the 1d qlib data but has no csv file in ``data_1min_dir``,
    a csv file is written whose rows cover the 1min calendar generated from the instrument's 1d dates;
    only the date, symbol and ``paused_num`` columns are filled.

    Parameters
    ----------
    data_1min_dir: str
        1min data dir
    qlib_data_1d_dir: str
        1d qlib data(bin data) dir, from: https://qlib.readthedocs.io/en/latest/component/data.html#converting-csv-format-into-qlib-format
    max_workers: int
        ThreadPoolExecutor(max_workers), by default 16
    date_field_name: str
        date field name, by default date
    symbol_field_name: str
        symbol field name, by default symbol

    Raises
    ------
    FileNotFoundError
        if symbols are missing but ``data_1min_dir`` has no csv file to take the column layout from
    """
    data_1min_dir = Path(data_1min_dir).expanduser().resolve()
    qlib_data_1d_dir = Path(qlib_data_1d_dir).expanduser().resolve()

    min_date, max_date = get_date_range(data_1min_dir, max_workers, date_field_name)
    symbols_1min = get_symbols(data_1min_dir)

    qlib.init(provider_uri=str(qlib_data_1d_dir))
    data_1d = D.features(D.instruments("all"), ["$close"], min_date, max_date, freq="day")

    miss_symbols = set(data_1d.index.get_level_values(level="instrument").unique()) - set(symbols_1min)
    if not miss_symbols:
        logger.warning("More symbols in 1min than 1d, no padding required")
        return

    logger.info(f"miss_symbols {len(miss_symbols)}: {miss_symbols}")

    # an existing 1min file is used as the template for the column layout and the symbol case
    csv_files = list(data_1min_dir.glob("*.csv"))
    if not csv_files:
        raise FileNotFoundError(f"no csv file found in {data_1min_dir}, cannot infer the 1min columns")
    tmp_df = pd.read_csv(csv_files[0])
    columns = tmp_df.columns
    _si = tmp_df[symbol_field_name].first_valid_index()
    # first_valid_index() is None when the template has no valid symbol; keep the symbols upper-cased then
    is_lower = _si is not None and tmp_df.loc[_si][symbol_field_name].islower()
    for symbol in tqdm(miss_symbols):
        if is_lower:
            symbol = symbol.lower()
        # the 1d data is looked up with the upper-cased symbol
        index_1d = data_1d.loc(axis=0)[symbol.upper()].index
        index_1min = generate_minutes_calendar_from_daily(index_1d)
        index_1min.name = date_field_name
        _df = pd.DataFrame(columns=columns, index=index_1min)
        # the date column is rebuilt from the index below, so drop the empty one first
        if date_field_name in _df.columns:
            del _df[date_field_name]
        _df.reset_index(inplace=True)
        _df[symbol_field_name] = symbol
        _df["paused_num"] = 0
        _df.to_csv(data_1min_dir.joinpath(f"{symbol}.csv"), index=False)
if __name__ == "__main__":
fire.Fire(fill_1min_using_1d)

View File

@@ -0,0 +1,5 @@
fire
pandas
loguru
tqdm
pyqlib

View File

@@ -14,7 +14,7 @@ from loguru import logger
import baostock as bs
CUR_DIR = Path(__file__).resolve().parent
sys.path.append(str(CUR_DIR.parent.parent))
sys.path.append(str(CUR_DIR.parent.parent.parent))
from data_collector.utils import generate_minutes_calendar_from_daily

View File

@@ -3,18 +3,13 @@
import abc
import sys
import copy
import time
import datetime
import importlib
import json
from abc import ABC
from pathlib import Path
from typing import Iterable, Type
import fire
import requests
import numpy as np
import pandas as pd
from loguru import logger
from dateutil.tz import tzlocal
@@ -38,7 +33,7 @@ class FundCollector(BaseCollector):
max_workers=4,
max_collector_count=2,
delay=0,
check_data_length: bool = False,
check_data_length: int = None,
limit_nums: int = None,
):
"""
@@ -59,8 +54,8 @@ class FundCollector(BaseCollector):
start datetime, default None
end: str
end datetime, default None
check_data_length: bool
check data length, by default False
check_data_length: int
check data length, if not None and greater than 0, each symbol will be considered complete if its data length is greater than or equal to this value, otherwise it will be fetched again, the maximum number of fetches being (max_collector_count). By default None.
limit_nums: int
using for debug, by default None
"""
@@ -168,9 +163,7 @@ class FundollectorCN(FundCollector, ABC):
class FundCollectorCN1d(FundollectorCN):
@property
def min_numbers_trading(self):
return 252 / 4
pass
class FundNormalize(BaseNormalize):
@@ -261,7 +254,7 @@ class Run(BaseRun):
start=None,
end=None,
interval="1d",
check_data_length=False,
check_data_length: int = None,
limit_nums=None,
):
"""download data from Internet
@@ -278,8 +271,8 @@ class Run(BaseRun):
start datetime, default "2000-01-01"
end: str
end datetime, default ``pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))``
check_data_length: bool # if this param useful?
check data length, by default False
check_data_length: int # if this param useful?
check data length, if not None and greater than 0, each symbol will be considered complete if its data length is greater than or equal to this value, otherwise it will be fetched again, the maximum number of fetches being (max_collector_count). By default None.
limit_nums: int
using for debug, by default None

View File

@@ -271,7 +271,7 @@ def get_instruments(
$ python collector.py --index_name SP500 --qlib_dir ~/.qlib/qlib_data/cn_data --method save_new_companies
"""
_cur_module = importlib.import_module("collector")
_cur_module = importlib.import_module("data_collector.us_index.collector")
obj = getattr(_cur_module, f"{index_name.upper()}Index")(
qlib_dir=qlib_dir, index_name=index_name, request_retry=request_retry, retry_sleep=retry_sleep
)

View File

@@ -2,7 +2,6 @@
# Licensed under the MIT License.
import re
import os
import time
import bisect
import pickle
@@ -10,7 +9,7 @@ import random
import requests
import functools
from pathlib import Path
from typing import Iterable, Tuple
from typing import Iterable, Tuple, List
import numpy as np
import pandas as pd
@@ -47,7 +46,7 @@ _CALENDAR_MAP = {}
MINIMUM_SYMBOLS_NUM = 3900
def get_calendar_list(bench_code="CSI300") -> list:
def get_calendar_list(bench_code="CSI300") -> List[pd.Timestamp]:
"""get SH/SZ history calendar list
Parameters

View File

@@ -1,3 +1,11 @@
- [Collector Data](#collector-data)
- [Get Qlib data](#get-qlib-databin-file)
- [Collector *YahooFinance* data to qlib](#collector-yahoofinance-data-to-qlib)
- [Automatic update of daily frequency data](#automatic-update-of-daily-frequency-datafrom-yahoo-finance)
- [Using qlib data](#using-qlib-data)
# Collect Data From Yahoo Finance
> *Please pay **ATTENTION** that the data is collected from [Yahoo Finance](https://finance.yahoo.com/lookup) and the data might not be perfect. We recommend users to prepare their own data if they have high-quality dataset. For more information, users can refer to the [related document](https://qlib.readthedocs.io/en/latest/component/data.html#converting-csv-format-into-qlib-format)*
@@ -18,113 +26,170 @@ pip install -r requirements.txt
## Collector Data
### Get Qlib data(`bin file`)
> `qlib-data` from *YahooFinance*, is the data that has been dumped and can be used directly in `qlib`
### CN Data
- get data: `python scripts/get_data.py qlib_data`
- parameters:
- `target_dir`: save dir, by default *~/.qlib/qlib_data/cn_data*
- `version`: dataset version, value from [`v1`, `v2`], by default `v1`
- `v2` end date is *2021-06*, `v1` end date is *2020-09*
- user can append data to `v2`: [automatic update of daily frequency data](#automatic-update-of-daily-frequency-datafrom-yahoo-finance)
- **the [benchmarks](https://github.com/microsoft/qlib/tree/main/examples/benchmarks) for qlib use `v1`**, *due to the unstable access to historical data by YahooFinance, there are some differences between `v2` and `v1`*
- `interval`: `1d` or `1min`, by default `1d`
- `region`: `cn` or `us`, by default `cn`
- `delete_old`: delete existing data from `target_dir`(*features, calendars, instruments, dataset_cache, features_cache*), value from [`True`, `False`], by default `True`
- `exists_skip`: if data already exists in `target_dir`, skip `get_data`, value from [`True`, `False`], by default `False`
- examples:
```bash
# cn 1d
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/qlib_cn_1d --region cn
# cn 1min
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/qlib_cn_1min --region cn --interval 1min
# us 1d
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/qlib_us_1d --region us --interval 1d
# us 1min
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/qlib_us_1min --region us --interval 1min
```
#### 1d from yahoo
### Collector *YahooFinance* data to qlib
> collect *YahooFinance* data and *dump* it into `qlib` format
1. download data to csv: `python scripts/data_collector/yahoo/collector.py download_data`
```bash
- parameters:
- `source_dir`: save the directory
- `interval`: `1d` or `1min`, by default `1d`
> **due to the limitation of the *YahooFinance API*, only the last month's data is available in `1min`**
- `region`: `CN` or `US`, by default `CN`
- `delay`: `time.sleep(delay)`, by default *0.5*
- `start`: start datetime, by default *"2000-01-01"*; *closed interval(including start)*
- `end`: end datetime, by default `pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))`; *open interval(excluding end)*
- `max_workers`: number of symbols fetched concurrently; changing this parameter is not recommended, in order to maintain the integrity of the symbol data; by default *1*
- `check_data_length`: check the number of rows per *symbol*, by default `None`
> if `len(symbol_df) < check_data_length`, it will be re-fetched, with the number of re-fetches coming from the `max_collector_count` parameter
- `max_collector_count`: number of *"failed"* symbol retries, by default 2
- examples:
```bash
# cn 1d data
python collector.py download_data --source_dir ~/.qlib/stock_data/source/cn_1d --start 2020-01-01 --end 2020-12-31 --delay 1 --interval 1d --region CN
# cn 1min data
python collector.py download_data --source_dir ~/.qlib/stock_data/source/cn_1min --delay 1 --interval 1min --region CN
# us 1d data
python collector.py download_data --source_dir ~/.qlib/stock_data/source/us_1d --start 2020-01-01 --end 2020-12-31 --delay 1 --interval 1d --region US
# us 1min data
python collector.py download_data --source_dir ~/.qlib/stock_data/source/us_1min --delay 1 --interval 1min --region US
```
2. normalize data: `python scripts/data_collector/yahoo/collector.py normalize_data`
- parameters:
- `source_dir`: csv directory
- `normalize_dir`: result directory
- `max_workers`: number of concurrent workers, by default *1*
- `interval`: `1d` or `1min`, by default `1d`
> if **`interval == 1min`**, `qlib_data_1d_dir` cannot be `None`
- `region`: `CN` or `US`, by default `CN`
- `date_field_name`: column *name* identifying time in csv files, by default `date`
- `symbol_field_name`: column *name* identifying symbol in csv files, by default `symbol`
- `end_date`: if not `None`, normalize the last date saved (*including end_date*); if `None`, it will ignore this parameter; by default `None`
- `qlib_data_1d_dir`: qlib directory(1d data)
```
if interval==1min, qlib_data_1d_dir cannot be None, normalize 1min needs to use 1d data;
qlib_data_1d can be obtained like this:
$ python scripts/get_data.py qlib_data --target_dir <qlib_data_1d_dir> --interval 1d
$ python scripts/data_collector/yahoo/collector.py update_data_to_bin --qlib_data_1d_dir <qlib_data_1d_dir> --trading_date 2021-06-01
or:
download 1d data from YahooFinance
```
- examples:
```bash
# normalize 1d cn
python collector.py normalize_data --source_dir ~/.qlib/stock_data/source/cn_1d --normalize_dir ~/.qlib/stock_data/source/cn_1d_nor --region CN --interval 1d
# normalize 1min cn
python collector.py normalize_data --qlib_data_1d_dir ~/.qlib/qlib_data/qlib_cn_1d --source_dir ~/.qlib/stock_data/source/cn_1min --normalize_dir ~/.qlib/stock_data/source/cn_1min_nor --region CN --interval 1min
```
3. dump data: `python scripts/dump_bin.py dump_all`
- parameters:
- `csv_path`: stock data path or directory, **normalize result(normalize_dir)**
- `qlib_dir`: qlib(dump) data directory
- `freq`: transaction frequency, by default `day`
> `freq_map = {1d: day, 1min: 1min}`
- `max_workers`: number of threads, by default *16*
- `include_fields`: dump fields, by default `""`
- `exclude_fields`: fields not dumped, by default `""`
> dump_fields = `include_fields if include_fields else (set(symbol_df.columns) - set(exclude_fields) if exclude_fields else symbol_df.columns)`
- `symbol_field_name`: column *name* identifying symbol in csv files, by default `symbol`
- `date_field_name`: column *name* identifying time in csv files, by default `date`
- examples:
```bash
# dump 1d cn
python dump_bin.py dump_all --csv_path ~/.qlib/stock_data/source/cn_1d_nor --qlib_dir ~/.qlib/qlib_data/qlib_cn_1d --freq day --exclude_fields date,symbol
# dump 1min cn
python dump_bin.py dump_all --csv_path ~/.qlib/stock_data/source/cn_1min_nor --qlib_dir ~/.qlib/qlib_data/qlib_cn_1min --freq 1min --exclude_fields date,symbol
```
# download from yahoo finance
python collector.py download_data --source_dir ~/.qlib/stock_data/source/cn_1d --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d
### Automatic update of daily frequency data(from yahoo finance)
> It is recommended that users update the data manually once (--trading_date 2021-05-25) and then set it to update automatically.
# normalize
python collector.py normalize_data --source_dir ~/.qlib/stock_data/source/cn_1d --normalize_dir ~/.qlib/stock_data/source/cn_1d_nor --region CN --interval 1d
* Automatic update of data to the "qlib" directory each trading day(Linux)
* use *crontab*: `crontab -e`
* set up timed tasks:
# dump data
cd qlib/scripts
python dump_bin.py dump_all --csv_path ~/.qlib/stock_data/source/cn_1d_nor --qlib_dir ~/.qlib/qlib_data/qlib_cn_1d --freq day --exclude_fields date,adjclose,dividends,splits,symbol
```
* * * * 1-5 python <script path> update_data_to_bin --qlib_data_1d_dir <user data dir>
```
* **script path**: *scripts/data_collector/yahoo/collector.py*
```
* Manual update of data
```
python scripts/data_collector/yahoo/collector.py update_data_to_bin --qlib_data_1d_dir <user data dir> --trading_date <start date> --end_date <end date>
```
* `trading_date`: start of trading day
* `end_date`: end of trading day(not included)
* `check_data_length`: check the number of rows per *symbol*, by default `None`
> if `len(symbol_df) < check_data_length`, it will be re-fetched, with the number of re-fetches coming from the `max_collector_count` parameter
### 1d from qlib
```bash
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/qlib_cn_1d --region cn
```
### using data
```python
import qlib
from qlib.data import D
qlib.init(provider_uri="~/.qlib/qlib_data/qlib_cn_1d", region="cn")
df = D.features(D.instruments("all"), ["$close"], freq="day")
```
#### 1min from yahoo
```bash
# download from yahoo finance
python collector.py download_data --source_dir ~/.qlib/stock_data/source/cn_1min --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1min
# normalize
python collector.py normalize_data --source_dir ~/.qlib/stock_data/source/cn_1min --normalize_dir ~/.qlib/stock_data/source/cn_1min_nor --region CN --interval 1min
# dump data
cd qlib/scripts
python dump_bin.py dump_all --csv_path ~/.qlib/stock_data/source/cn_1min_nor --qlib_dir ~/.qlib/qlib_data/qlib_cn_1min --freq 1min --exclude_fields date,adjclose,dividends,splits,symbol
```
### 1min from qlib
```bash
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/qlib_cn_1min --interval 1min --region cn
```
### using data
```python
import qlib
from qlib.data import D
qlib.init(provider_uri="~/.qlib/qlib_data/qlib_cn_1min", region="cn")
df = D.features(D.instruments("all"), ["$close"], freq="1min")
```
### US Data
#### 1d from yahoo
```bash
# download from yahoo finance
python collector.py download_data --source_dir ~/.qlib/stock_data/source/us_1d --region US --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d
# normalize
python collector.py normalize_data --source_dir ~/.qlib/stock_data/source/us_1d --normalize_dir ~/.qlib/stock_data/source/us_1d_nor --region US --interval 1d
# dump data
cd qlib/scripts
python dump_bin.py dump_all --csv_path ~/.qlib/stock_data/source/us_1d_nor --qlib_dir ~/.qlib/stock_data/source/qlib_us_1d --freq day --exclude_fields date,adjclose,dividends,splits,symbol
```
#### 1d from qlib
```bash
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/qlib_us_1d --region us
```
### using data
```python
# using
import qlib
from qlib.data import D
qlib.init(provider_uri="~/.qlib/qlib_data/qlib_us_1d", region="us")
df = D.features(D.instruments("all"), ["$close"], freq="day")
```
* `scripts/data_collector/yahoo/collector.py update_data_to_bin` parameters:
* `source_dir`: The directory where the raw data collected from the Internet is saved, default "Path(__file__).parent/source"
* `normalize_dir`: Directory for normalize data, default "Path(__file__).parent/normalize"
* `qlib_data_1d_dir`: the qlib data to be updated for yahoo, usually from: [download qlib data](https://github.com/microsoft/qlib/tree/main/scripts#download-cn-data)
* `trading_date`: trading days to be updated, by default ``datetime.datetime.now().strftime("%Y-%m-%d")``
* `end_date`: end datetime, default ``pd.Timestamp(trading_date + pd.Timedelta(days=1))``; open interval(excluding end)
* `region`: region, value from ["CN", "US"], default "CN"
### Help
```bash
python collector.py collector_data --help
```
## Using qlib data
## Parameters
```python
import qlib
from qlib.data import D
# 1d data cn
# freq=day, freq default day
qlib.init(provider_uri="~/.qlib/qlib_data/qlib_cn_1d", region="cn")
df = D.features(D.instruments("all"), ["$close"], freq="day")
# 1min data cn
# freq=1min
qlib.init(provider_uri="~/.qlib/qlib_data/qlib_cn_1min", region="cn")
inst = D.list_instruments(D.instruments("all"), freq="1min", as_list=True)
# get 100 symbols
df = D.features(inst[:100], ["$close"], freq="1min")
# get all symbol data
# df = D.features(D.instruments("all"), ["$close"], freq="1min")
# 1d data us
qlib.init(provider_uri="~/.qlib/qlib_data/qlib_us_1d", region="us")
df = D.features(D.instruments("all"), ["$close"], freq="day")
# 1min data us
qlib.init(provider_uri="~/.qlib/qlib_data/qlib_us_1min", region="us")
inst = D.list_instruments(D.instruments("all"), freq="1min", as_list=True)
# get 100 symbols
df = D.features(inst[:100], ["$close"], freq="1min")
# get all symbol data
# df = D.features(D.instruments("all"), ["$close"], freq="1min")
```
- interval: 1min or 1d
- region: CN or US

View File

@@ -8,8 +8,9 @@ import time
import datetime
import importlib
from abc import ABC
import multiprocessing
from pathlib import Path
from typing import Iterable, Type
from typing import Iterable
import fire
import requests
@@ -18,13 +19,18 @@ import pandas as pd
from loguru import logger
from yahooquery import Ticker
from dateutil.tz import tzlocal
from qlib.utils import code_to_fname, fname_to_code
from qlib.tests.data import GetData
from qlib.utils import code_to_fname, fname_to_code, exists_qlib_data
from qlib.config import REG_CN as REGION_CN
CUR_DIR = Path(__file__).resolve().parent
sys.path.append(str(CUR_DIR.parent.parent))
from data_collector.base import BaseCollector, BaseNormalize, BaseRun
from dump_bin import DumpDataUpdate
from data_collector.base import BaseCollector, BaseNormalize, BaseRun, Normalize
from data_collector.utils import (
deco_retry,
get_calendar_list,
get_hs_stock_symbols,
get_us_stock_symbols,
@@ -44,7 +50,7 @@ class YahooCollector(BaseCollector):
max_workers=4,
max_collector_count=2,
delay=0,
check_data_length: bool = False,
check_data_length: int = None,
limit_nums: int = None,
):
"""
@@ -65,8 +71,8 @@ class YahooCollector(BaseCollector):
start datetime, default None
end: str
end datetime, default None
check_data_length: bool
check data length, by default False
check_data_length: int
check data length, by default None
limit_nums: int
using for debug, by default None
"""
@@ -92,10 +98,6 @@ class YahooCollector(BaseCollector):
else:
raise ValueError(f"interval error: {self.interval}")
# using for 1min
self._next_datetime = self.convert_datetime(self.start_datetime.date() + pd.Timedelta(days=1), self._timezone)
self._latest_datetime = self.convert_datetime(self.end_datetime.date(), self._timezone)
self.start_datetime = self.convert_datetime(self.start_datetime, self._timezone)
self.end_datetime = self.convert_datetime(self.end_datetime, self._timezone)
@@ -140,40 +142,39 @@ class YahooCollector(BaseCollector):
def get_data(
self, symbol: str, interval: str, start_datetime: pd.Timestamp, end_datetime: pd.Timestamp
) -> pd.DataFrame:
@deco_retry(retry_sleep=self.delay)
def _get_simple(start_, end_):
self.sleep()
_remote_interval = "1m" if interval == self.INTERVAL_1min else interval
return self.get_data_from_remote(
resp = self.get_data_from_remote(
symbol,
interval=_remote_interval,
start=start_,
end=end_,
)
if resp is None or resp.empty:
raise ValueError(f"get data error: {symbol}--{start_}--{end_}")
return resp
_result = None
if interval == self.INTERVAL_1d:
_result = _get_simple(start_datetime, end_datetime)
elif interval == self.INTERVAL_1min:
if self._next_datetime >= self._latest_datetime:
try:
_result = _get_simple(start_datetime, end_datetime)
else:
_res = []
def _get_multi(start_, end_):
_resp = _get_simple(start_, end_)
if _resp is not None and not _resp.empty:
_res.append(_resp)
for _s, _e in (
(self.start_datetime, self._next_datetime),
(self._latest_datetime, self.end_datetime),
):
_get_multi(_s, _e)
for _start in pd.date_range(self._next_datetime, self._latest_datetime, closed="left"):
_end = _start + pd.Timedelta(days=1)
_get_multi(_start, _end)
if _res:
_result = pd.concat(_res, sort=False).sort_values(["symbol", "date"])
except ValueError as e:
pass
elif interval == self.INTERVAL_1min:
_res = []
_start = self.start_datetime
while _start < self.end_datetime:
_tmp_end = min(_start + pd.Timedelta(days=7), self.end_datetime)
try:
_resp = _get_simple(_start, _tmp_end)
_res.append(_resp)
except ValueError as e:
pass
_start = _tmp_end
if _res:
_result = pd.concat(_res, sort=False).sort_values(["symbol", "date"])
else:
raise ValueError(f"cannot support {self.interval}")
return pd.DataFrame() if _result is None else _result
@@ -207,10 +208,6 @@ class YahooCollectorCN(YahooCollector, ABC):
class YahooCollectorCN1d(YahooCollectorCN):
@property
def min_numbers_trading(self):
return 252 / 4
def download_index_data(self):
# TODO: from MSN
_format = "%Y%m%d"
@@ -244,13 +241,12 @@ class YahooCollectorCN1d(YahooCollectorCN):
class YahooCollectorCN1min(YahooCollectorCN):
@property
def min_numbers_trading(self):
return 60 * 4 * 5
def get_instrument_list(self):
symbols = super(YahooCollectorCN1min, self).get_instrument_list()
return symbols + ["000300.ss", "000905.ss", "000903.ss"]
def download_index_data(self):
# TODO: 1m
logger.warning(f"{self.__class__.__name__} {self.interval} does not support: download_index_data")
pass
class YahooCollectorUS(YahooCollector, ABC):
@@ -276,15 +272,11 @@ class YahooCollectorUS(YahooCollector, ABC):
class YahooCollectorUS1d(YahooCollectorUS):
@property
def min_numbers_trading(self):
return 252 / 4
pass
class YahooCollectorUS1min(YahooCollectorUS):
@property
def min_numbers_trading(self):
return 60 * 6.5 * 5
pass
class YahooNormalize(BaseNormalize):
@@ -297,6 +289,7 @@ class YahooNormalize(BaseNormalize):
calendar_list: list = None,
date_field_name: str = "date",
symbol_field_name: str = "symbol",
last_close: float = None,
):
if df.empty:
return df
@@ -318,7 +311,10 @@ class YahooNormalize(BaseNormalize):
df.sort_index(inplace=True)
df.loc[(df["volume"] <= 0) | np.isnan(df["volume"]), set(df.columns) - {symbol_field_name}] = np.nan
_tmp_series = df["close"].fillna(method="ffill")
df["change"] = _tmp_series / _tmp_series.shift(1) - 1
_tmp_shift_series = _tmp_series.shift(1)
if last_close is not None:
_tmp_shift_series.iloc[0] = float(last_close)
df["change"] = _tmp_series / _tmp_shift_series - 1
columns += ["change"]
df.loc[(df["volume"] <= 0) | np.isnan(df["volume"]), columns] = np.nan
@@ -367,6 +363,17 @@ class YahooNormalize1d(YahooNormalize, ABC):
df = self._manual_adj_data(df)
return df
# Return the `close` value found at the first valid (non-NaN) index of `df["close"]`.
def _get_first_close(self, df: pd.DataFrame) -> float:
"""get first close value
Notes
-----
For incremental updates(append) to Yahoo 1D data, user need to use a close that is not 0 on the first trading day of the existing data
"""
# keep only the rows from the first valid `close` onwards
df = df.loc[df["close"].first_valid_index() :]
# the first remaining row therefore carries the first valid close
_close = df["close"].iloc[0]
return _close
def _manual_adj_data(self, df: pd.DataFrame) -> pd.DataFrame:
"""manual adjust data: All fields (except change) are standardized according to the close of the first day"""
if df.empty:
@@ -374,45 +381,112 @@ class YahooNormalize1d(YahooNormalize, ABC):
df = df.copy()
df.sort_values(self._date_field_name, inplace=True)
df = df.set_index(self._date_field_name)
df = df.loc[df["close"].first_valid_index() :]
_close = df["close"].iloc[0]
_close = self._get_first_close(df)
for _col in df.columns:
if _col == self._symbol_field_name:
# NOTE: retain original adjclose, required for incremental updates
if _col in [self._symbol_field_name, "adjclose", "change"]:
continue
if _col == "volume":
df[_col] = df[_col] * _close
elif _col != "change":
df[_col] = df[_col] / _close
else:
pass
df[_col] = df[_col] / _close
return df.reset_index()
# 1d normalizer that consults previously dumped qlib data (loaded in `_get_old_data`)
# when looking up a symbol's close values and last date.
# NOTE(review): indentation is not visible in this view; comments below describe each
# statement, not the exact nesting.
class YahooNormalize1dExtend(YahooNormalize1d):
def __init__(
self, old_qlib_data_dir: [str, Path], date_field_name: str = "date", symbol_field_name: str = "symbol", **kwargs
):
"""
Parameters
----------
old_qlib_data_dir: str, Path
the qlib data to be updated for yahoo, usually from: https://github.com/microsoft/qlib/tree/main/scripts#download-cn-data
date_field_name: str
date field name, default is date
symbol_field_name: str
symbol field name, default is symbol
"""
super(YahooNormalize1dExtend, self).__init__(date_field_name, symbol_field_name)
# column names given to the two expressions loaded by `_get_old_data`
self._first_close_field = "first_close"
self._ori_close_field = "ori_close"
# existing qlib data, loaded once at construction time
self.old_qlib_data = self._get_old_data(old_qlib_data_dir)
# Load "$close/$factor" and "$adjclose/$close" for all instruments from `qlib_data_dir`;
# the columns are renamed to (`_ori_close_field`, `_first_close_field`) in that order.
def _get_old_data(self, qlib_data_dir: [str, Path]):
import qlib
from qlib.data import D
qlib_data_dir = str(Path(qlib_data_dir).expanduser().resolve())
# both caches are explicitly disabled for this read
qlib.init(provider_uri=qlib_data_dir, expression_cache=None, dataset_cache=None)
df = D.features(D.instruments("all"), ["$close/$factor", "$adjclose/$close"])
df.columns = [self._ori_close_field, self._first_close_field]
return df
# Return `field_name` from the old data of df's symbol, taken at that symbol's last valid index.
# A KeyError from the symbol lookup is handled by the callers below.
def _get_close(self, df: pd.DataFrame, field_name: str):
# symbol is read from the first valid row of the symbol column and upper-cased for the lookup
_symbol = df.loc[df[self._symbol_field_name].first_valid_index()][self._symbol_field_name].upper()
_df = self.old_qlib_data.loc(axis=0)[_symbol]
_close = _df.loc[_df.last_valid_index()][field_name]
return _close
# Prefer the "first_close" value stored in the old data; fall back to the parent
# implementation when the lookup raises KeyError.
def _get_first_close(self, df: pd.DataFrame) -> float:
try:
_close = self._get_close(df, field_name=self._first_close_field)
except KeyError:
_close = super(YahooNormalize1dExtend, self)._get_first_close(df)
return _close
# Return the "ori_close" value from the old data, or None when the lookup raises KeyError.
def _get_last_close(self, df: pd.DataFrame) -> float:
try:
_close = self._get_close(df, field_name=self._ori_close_field)
except KeyError:
_close = None
return _close
# Return the maximum index value of the symbol's old data, or None on KeyError.
# NOTE(review): presumably the index here is the datetime level — confirm.
def _get_last_date(self, df: pd.DataFrame) -> pd.Timestamp:
_symbol = df.loc[df[self._symbol_field_name].first_valid_index()][self._symbol_field_name].upper()
try:
_df = self.old_qlib_data.loc(axis=0)[_symbol]
_date = _df.index.max()
except KeyError:
_date = None
return _date
# Normalize `df`; when a last date is found in the old data, rows are first restricted to
# dates strictly after it. Returns an empty DataFrame when no rows remain.
def normalize(self, df: pd.DataFrame) -> pd.DataFrame:
# forwarded to `normalize_yahoo` as `last_close` below
_last_close = self._get_last_close(df)
# reindex
_last_date = self._get_last_date(df)
if _last_date is not None:
df = df.set_index(self._date_field_name)
df.index = pd.to_datetime(df.index)
# keep the first row for any duplicated date
df = df[~df.index.duplicated(keep="first")]
_max_date = df.index.max()
# align to the calendar, cut at the last date present in df, then restore the date column
df = df.reindex(self._calendar_list).loc[:_max_date].reset_index()
# keep only rows strictly newer than the last date in the old data
df = df[df[self._date_field_name] > _last_date]
if df.empty:
return pd.DataFrame()
# warn when the first valid close is not on the first remaining row
_si = df["close"].first_valid_index()
if _si > df.index[0]:
logger.warning(
f"{df.loc[_si][self._symbol_field_name]} missing data: {df.loc[:_si - 1][self._date_field_name].to_list()}"
)
# normalize
df = self.normalize_yahoo(
df, self._calendar_list, self._date_field_name, self._symbol_field_name, last_close=_last_close
)
# adjusted price
df = self.adjusted_price(df)
df = self._manual_adj_data(df)
return df
class YahooNormalize1min(YahooNormalize, ABC):
AM_RANGE = None # type: tuple # eg: ("09:30:00", "11:29:00")
PM_RANGE = None # type: tuple # eg: ("13:00:00", "14:59:00")
# Whether the trading day of 1min data is consistent with 1d
CONSISTENT_1d = False
def __init__(
self,
date_field_name: str = "date",
symbol_field_name: str = "symbol",
):
"""
Parameters
----------
date_field_name: str
date field name, default is date
symbol_field_name: str
symbol field name, default is symbol
"""
super(YahooNormalize1min, self).__init__(date_field_name, symbol_field_name)
_class_name = self.__class__.__name__.replace("min", "d")
_class = getattr(importlib.import_module("collector"), _class_name) # type: Type[YahooNormalize]
self.data_1d_obj = _class(self._date_field_name, self._symbol_field_name)
CONSISTENT_1d = True
CALC_PAUSED_NUM = True
@property
def calendar_list_1d(self):
@@ -427,24 +501,40 @@ class YahooNormalize1min(YahooNormalize, ABC):
calendars, freq="1min", am_range=self.AM_RANGE, pm_range=self.PM_RANGE
)
# Fetch 1d data for `symbol` via `YahooCollector.get_data_from_remote` and, when the
# response is non-empty, normalize it with the matching 1d normalize class.
def get_1d_data(self, symbol: str, start: str, end: str) -> pd.DataFrame:
"""get 1d data
Returns
-------
data_1d: pd.DataFrame
data_1d.columns = [self._date_field_name, self._symbol_field_name, "paused", "volume", "factor", "close"]
"""
data_1d = YahooCollector.get_data_from_remote(self.symbol_to_yahoo(symbol), interval="1d", start=start, end=end)
# a None or empty response is returned unchanged
if not (data_1d is None or data_1d.empty):
# the 1d class name is derived from this class's name by replacing "min" with "d"
_class_name = self.__class__.__name__.replace("min", "d")
# the class is looked up by name in the module importable as "collector"
_class: type(YahooNormalize) = getattr(importlib.import_module("collector"), _class_name)
data_1d_obj = _class(self._date_field_name, self._symbol_field_name)
data_1d = data_1d_obj.normalize(data_1d)
return data_1d
def adjusted_price(self, df: pd.DataFrame) -> pd.DataFrame:
# TODO: using daily data factor
if df.empty:
return df
df = df.copy()
df = df.sort_values(self._date_field_name)
symbol = df.iloc[0][self._symbol_field_name]
# get 1d data from yahoo
_start = pd.Timestamp(df[self._date_field_name].min()).strftime(self.DAILY_FORMAT)
_end = (pd.Timestamp(df[self._date_field_name].max()) + pd.Timedelta(days=1)).strftime(self.DAILY_FORMAT)
data_1d = YahooCollector.get_data_from_remote(
self.symbol_to_yahoo(symbol), interval="1d", start=_start, end=_end
)
data_1d: pd.DataFrame = self.get_1d_data(symbol, _start, _end)
data_1d = data_1d.copy()
if data_1d is None or data_1d.empty:
df["factor"] = 1
df["factor"] = 1 / df.loc[df["close"].first_valid_index()]["close"]
# TODO: np.nan or 1 or 0
df["paused"] = np.nan
else:
data_1d = self.data_1d_obj.normalize(data_1d) # type: pd.DataFrame
# NOTE: volume is np.nan or volume <= 0, paused = 1
# FIXME: find a more accurate data source
data_1d["paused"] = 0
@@ -452,9 +542,13 @@ class YahooNormalize1min(YahooNormalize, ABC):
data_1d = data_1d.set_index(self._date_field_name)
# add factor from 1d data
# NOTE: yahoo 1d data info:
# - Close price adjusted for splits. Adjusted close price adjusted for both dividends and splits.
# - data_1d.adjclose: Adjusted close price adjusted for both dividends and splits.
# - data_1d.close: `data_1d.adjclose / (close for the first trading day that is not np.nan)`
df["date_tmp"] = df[self._date_field_name].apply(lambda x: pd.Timestamp(x).date())
df.set_index("date_tmp", inplace=True)
df.loc[:, "factor"] = data_1d["factor"]
df.loc[:, "factor"] = data_1d["close"] / df["close"]
df.loc[:, "paused"] = data_1d["paused"]
df.reset_index("date_tmp", drop=True, inplace=True)
@@ -478,6 +572,54 @@ class YahooNormalize1min(YahooNormalize, ABC):
df[_col] = df[_col] / df["factor"]
else:
df[_col] = df[_col] * df["factor"]
if self.CALC_PAUSED_NUM:
df = self.calc_paused_num(df)
return df
# Walk `df` day by day, set `paused` / `paused_num` per day, and drop the trailing
# all-NaN days before concatenating the per-day frames.
# NOTE(review): indentation is not visible in this view; comments describe each
# statement, and the nesting of the `if all_data:` lines should be confirmed.
def calc_paused_num(self, df: pd.DataFrame):
# symbol is read from the first row and used only in log messages
_symbol = df.iloc[0][self._symbol_field_name]
df = df.copy()
# helper column holding the calendar date of each row, used as the group key
df["_tmp_date"] = df[self._date_field_name].apply(lambda x: pd.Timestamp(x).date())
# remove data that starts and ends with `np.nan` all day
all_data = []
# Record the number of consecutive trading days where the whole day is nan, to remove the last trading day where the whole day is nan
all_nan_nums = 0
# Record the number of consecutive occurrences of trading days that are not nan throughout the day
not_nan_nums = 0
for _date, _df in df.groupby("_tmp_date"):
_df["paused"] = 0
# negative volumes are logged and replaced with NaN
if not _df.loc[_df["volume"] < 0].empty:
logger.warning(f"volume < 0, will fill np.nan: {_date} {_symbol}")
_df.loc[_df["volume"] < 0, "volume"] = np.nan
# columns inspected for the all-NaN test: everything except these helper/meta columns
check_fields = set(_df.columns) - {
"_tmp_date",
"paused",
"factor",
self._date_field_name,
self._symbol_field_name,
}
# a day counts as paused when every checked value is NaN or every volume is 0
if _df.loc[:, check_fields].isna().values.all() or (_df["volume"] == 0).all():
all_nan_nums += 1
not_nan_nums = 0
_df["paused"] = 1
# `all_data` is still empty until a day has been appended
if all_data:
_df["paused_num"] = not_nan_nums
all_data.append(_df)
else:
all_nan_nums = 0
not_nan_nums += 1
_df["paused_num"] = not_nan_nums
all_data.append(_df)
# drop the last `all_nan_nums` entries, i.e. the trailing run of paused days
all_data = all_data[: len(all_data) - all_nan_nums]
if all_data:
df = pd.concat(all_data, sort=False)
else:
logger.warning(f"data is empty: {_symbol}")
df = pd.DataFrame()
# NOTE(review): presumably this early return belongs to the empty-data branch — confirm nesting
return df
# remove the helper column before returning
del df["_tmp_date"]
return df
@abc.abstractmethod
@@ -485,12 +627,67 @@ class YahooNormalize1min(YahooNormalize, ABC):
raise NotImplementedError("rewrite symbol_to_yahoo")
@abc.abstractmethod
def _get_1d_calendar_list(self):
def _get_1d_calendar_list(self) -> Iterable[pd.Timestamp]:
raise NotImplementedError("rewrite _get_1d_calendar_list")
class YahooNormalize1minOffline(YahooNormalize1min):
"""Normalised to 1min using local 1d data"""
def __init__(
self, qlib_data_1d_dir: [str, Path], date_field_name: str = "date", symbol_field_name: str = "symbol", **kwargs
):
"""
Parameters
----------
qlib_data_1d_dir: str, Path
directory of the local qlib 1d data; it is passed to ``qlib.init`` as ``provider_uri`` to load the 1d calendar and the 1d features
date_field_name: str
date field name, default is date
symbol_field_name: str
symbol field name, default is symbol
"""
# assigned before the parent constructor runs
# NOTE(review): presumably the parent __init__ needs it (e.g. via `_get_1d_calendar_list`) — confirm
self.qlib_data_1d_dir = qlib_data_1d_dir
super(YahooNormalize1minOffline, self).__init__(date_field_name, symbol_field_name)
# all local 1d data is loaded once here and filtered per symbol in `get_1d_data`
self._all_1d_data = self._get_all_1d_data()
# Return the "day" calendar of the local qlib data as a list.
def _get_1d_calendar_list(self) -> Iterable[pd.Timestamp]:
import qlib
from qlib.data import D
qlib.init(provider_uri=self.qlib_data_1d_dir)
return list(D.calendar(freq="day"))
# Load $paused, $volume, $factor and $close at "day" frequency for all instruments,
# flattened into plain columns.
def _get_all_1d_data(self):
import qlib
from qlib.data import D
qlib.init(provider_uri=self.qlib_data_1d_dir)
df = D.features(D.instruments("all"), ["$paused", "$volume", "$factor", "$close"], freq="day")
# move the index levels into columns and rename them to the configured field names
df.reset_index(inplace=True)
df.rename(columns={"datetime": self._date_field_name, "instrument": self._symbol_field_name}, inplace=True)
# strip the leading "$" from the feature column names
df.columns = list(map(lambda x: x[1:] if x.startswith("$") else x, df.columns))
return df
def get_1d_data(self, symbol: str, start: str, end: str) -> pd.DataFrame:
"""get 1d data
Returns
-------
data_1d: pd.DataFrame
data_1d.columns = [self._date_field_name, self._symbol_field_name, "paused", "volume", "factor", "close"]
"""
# rows for the upper-cased symbol with start <= date < end, taken from the preloaded data
return self._all_1d_data[
(self._all_1d_data[self._symbol_field_name] == symbol.upper())
& (self._all_1d_data[self._date_field_name] >= pd.Timestamp(start))
& (self._all_1d_data[self._date_field_name] < pd.Timestamp(end))
]
class YahooNormalizeUS:
def _get_calendar_list(self):
def _get_calendar_list(self) -> Iterable[pd.Timestamp]:
# TODO: from MSN
return get_calendar_list("US_ALL")
@@ -499,10 +696,10 @@ class YahooNormalizeUS1d(YahooNormalizeUS, YahooNormalize1d):
pass
class YahooNormalizeUS1min(YahooNormalizeUS, YahooNormalize1min):
CONSISTENT_1d = False
class YahooNormalizeUS1min(YahooNormalizeUS, YahooNormalize1minOffline):
CALC_PAUSED_NUM = False
def _get_calendar_list(self):
def _get_calendar_list(self) -> Iterable[pd.Timestamp]:
# TODO: support 1min
raise ValueError("Does not support 1min")
@@ -514,7 +711,7 @@ class YahooNormalizeUS1min(YahooNormalizeUS, YahooNormalize1min):
class YahooNormalizeCN:
def _get_calendar_list(self):
def _get_calendar_list(self) -> Iterable[pd.Timestamp]:
# TODO: from MSN
return get_calendar_list("ALL")
@@ -523,28 +720,30 @@ class YahooNormalizeCN1d(YahooNormalizeCN, YahooNormalize1d):
pass
class YahooNormalizeCN1min(YahooNormalizeCN, YahooNormalize1min):
class YahooNormalizeCN1dExtend(YahooNormalizeCN, YahooNormalize1dExtend):
pass
class YahooNormalizeCN1min(YahooNormalizeCN, YahooNormalize1minOffline):
AM_RANGE = ("09:30:00", "11:29:00")
PM_RANGE = ("13:00:00", "14:59:00")
CONSISTENT_1d = True
def _get_calendar_list(self):
def _get_calendar_list(self) -> Iterable[pd.Timestamp]:
return self.generate_1min_from_daily(self.calendar_list_1d)
def symbol_to_yahoo(self, symbol):
if "." not in symbol:
_exchange = symbol[:2]
_exchange = "ss" if _exchange == "sh" else _exchange
_exchange = ("ss" if _exchange.islower() else "SS") if _exchange.lower() == "sh" else _exchange
symbol = symbol[2:] + "." + _exchange
return symbol
def _get_1d_calendar_list(self):
def _get_1d_calendar_list(self) -> Iterable[pd.Timestamp]:
return get_calendar_list("ALL")
class Run(BaseRun):
def __init__(self, source_dir=None, normalize_dir=None, max_workers=4, interval="1d", region=REGION_CN):
def __init__(self, source_dir=None, normalize_dir=None, max_workers=1, interval="1d", region=REGION_CN):
"""
Parameters
@@ -554,7 +753,7 @@ class Run(BaseRun):
normalize_dir: str
Directory for normalize data, default "Path(__file__).parent/normalize"
max_workers: int
Concurrent number, default is 4
Concurrent number, default is 1; when collecting data, it is recommended that max_workers be set to 1
interval: str
freq, value from [1min, 1d], default 1d
region: str
@@ -578,10 +777,10 @@ class Run(BaseRun):
def download_data(
self,
max_collector_count=2,
delay=0,
delay=0.5,
start=None,
end=None,
check_data_length=False,
check_data_length=None,
limit_nums=None,
):
"""download data from Internet
@@ -591,16 +790,23 @@ class Run(BaseRun):
max_collector_count: int
default 2
delay: float
time.sleep(delay), default 0
time.sleep(delay), default 0.5
start: str
start datetime, default "2000-01-01"
start datetime, default "2000-01-01"; closed interval(including start)
end: str
end datetime, default ``pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))``
check_data_length: bool
check data length, by default False
end datetime, default ``pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))``; open interval(excluding end)
check_data_length: int
check data length, if not None and greater than 0, each symbol will be considered complete if its data length is greater than or equal to this value, otherwise it will be fetched again, the maximum number of fetches being (max_collector_count). By default None.
limit_nums: int
using for debug, by default None
Notes
-----
check_data_length, example:
daily, one year: 252 // 4
us 1min, a week: 6.5 * 60 * 5
cn 1min, a week: 4 * 60 * 5
Examples
---------
# get daily data
@@ -612,7 +818,13 @@ class Run(BaseRun):
max_collector_count, delay, start, end, self.interval, check_data_length, limit_nums
)
def normalize_data(self, date_field_name: str = "date", symbol_field_name: str = "symbol"):
def normalize_data(
self,
date_field_name: str = "date",
symbol_field_name: str = "symbol",
end_date: str = None,
qlib_data_1d_dir: str = None,
):
"""normalize data
Parameters
@@ -621,12 +833,205 @@ class Run(BaseRun):
date field name, default date
symbol_field_name: str
symbol field name, default symbol
end_date: str
if not None, normalize the last date saved (including end_date); if None, it will ignore this parameter; by default None
qlib_data_1d_dir: str
if interval==1min, qlib_data_1d_dir cannot be None, normalize 1min needs to use 1d data;
qlib_data_1d can be obtained like this:
$ python scripts/get_data.py qlib_data --target_dir <qlib_data_1d_dir> --interval 1d
$ python scripts/data_collector/yahoo/collector.py update_data_to_bin --qlib_data_1d_dir <qlib_data_1d_dir> --trading_date 2021-06-01
or:
download 1d data, reference: https://github.com/microsoft/qlib/tree/main/scripts/data_collector/yahoo#1d-from-yahoo
Examples
---------
$ python collector.py normalize_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region CN --interval 1d
$ python collector.py normalize_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region cn --interval 1d
$ python collector.py normalize_data --qlib_data_1d_dir ~/.qlib/qlib_data/cn_1d --source_dir ~/.qlib/stock_data/source_cn_1min --normalize_dir ~/.qlib/stock_data/normalize_cn_1min --region CN --interval 1min
"""
super(Run, self).normalize_data(date_field_name, symbol_field_name)
if self.interval.lower() == "1min":
if qlib_data_1d_dir is None or not Path(qlib_data_1d_dir).expanduser().exists():
raise ValueError(
"If normalize 1min, the qlib_data_1d_dir parameter must be set: --qlib_data_1d_dir <user qlib 1d data >, Reference: https://github.com/zhupr/qlib/tree/support_extend_data/scripts/data_collector/yahoo#automatic-update-of-daily-frequency-datafrom-yahoo-finance"
)
super(Run, self).normalize_data(
date_field_name, symbol_field_name, end_date=end_date, qlib_data_1d_dir=qlib_data_1d_dir
)
def normalize_data_1d_extend(
    self, old_qlib_data_dir, date_field_name: str = "date", symbol_field_name: str = "symbol"
):
    """Normalize newly collected source data so it can extend an existing qlib dataset.

    The normalize class is looked up on ``self._cur_module`` under the name
    ``f"{self.normalize_class_name}Extend"`` and handed to ``Normalize`` together
    with the source/normalize directories and ``old_qlib_data_dir``.

    Notes
    -----
        Steps to extend yahoo qlib data:

            1. download qlib data: https://github.com/microsoft/qlib/tree/main/scripts#download-cn-data; save to <dir1>

            2. collect source data: https://github.com/microsoft/qlib/tree/main/scripts/data_collector/yahoo#collector-data; save to <dir2>

            3. normalize the new source data (from step 2): python scripts/data_collector/yahoo/collector.py normalize_data_1d_extend --old_qlib_data_dir <dir1> --source_dir <dir2> --normalize_dir <dir3> --region CN --interval 1d

            4. dump data: python scripts/dump_bin.py dump_update --csv_path <dir3> --qlib_dir <dir1> --freq day --date_field_name date --symbol_field_name symbol --exclude_fields symbol,date

            5. update instruments (e.g. csi300): python scripts/data_collector/cn_index/collector.py --index_name CSI300 --qlib_dir <dir1> --method parse_instruments

    Parameters
    ----------
    old_qlib_data_dir: str
        the qlib data to be extended, usually obtained from: https://github.com/microsoft/qlib/tree/main/scripts#download-cn-data
    date_field_name: str
        date field name, default date
    symbol_field_name: str
        symbol field name, default symbol

    Examples
    ---------
        $ python collector.py normalize_data_1d_extend --old_qlib_data_dir ~/.qlib/qlib_data/cn_1d --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region CN --interval 1d
    """
    # The "<normalize_class_name>Extend" variant is resolved dynamically from the current module.
    extend_cls = getattr(self._cur_module, f"{self.normalize_class_name}Extend")
    normalize_kwargs = dict(
        source_dir=self.source_dir,
        target_dir=self.normalize_dir,
        normalize_class=extend_cls,
        max_workers=self.max_workers,
        date_field_name=date_field_name,
        symbol_field_name=symbol_field_name,
        old_qlib_data_dir=old_qlib_data_dir,
    )
    Normalize(**normalize_kwargs).normalize()
def download_today_data(
    self,
    max_collector_count=2,
    delay=0.5,
    check_data_length=None,
    limit_nums=None,
):
    """Download only today's data by delegating to ``download_data``.

    Parameters
    ----------
    max_collector_count: int
        default 2
    delay: float
        time.sleep(delay), default 0.5
    check_data_length: int
        forwarded unchanged to ``download_data``; if not None and greater than 0, a symbol is
        considered complete once its data length reaches this value, otherwise it is fetched
        again (at most ``max_collector_count`` times). By default None.
    limit_nums: int
        used for debugging, by default None

    Notes
    -----
        The requested window is derived from the local clock:

            start = datetime.datetime.now().date(); closed interval (including start)

            end = start + 1 day; open interval (excluding end)

        Both are passed to ``download_data`` as "%Y-%m-%d" strings.

        check_data_length, example:

            daily, one year: 252 // 4

            us 1min, a week: 6.5 * 60 * 5

            cn 1min, a week: 4 * 60 * 5

    Examples
    ---------
        # get daily data
        $ python collector.py download_today_data --source_dir ~/.qlib/stock_data/source --region CN --delay 0.1 --interval 1d
        # get 1min data
        $ python collector.py download_today_data --source_dir ~/.qlib/stock_data/source --region CN --delay 0.1 --interval 1min
    """
    date_fmt = "%Y-%m-%d"
    today = datetime.datetime.now().date()
    # Round-trip through pd.Timestamp, then drop back to a plain date for formatting.
    tomorrow = pd.Timestamp(today + pd.Timedelta(days=1)).date()
    start_str = today.strftime(date_fmt)
    end_str = tomorrow.strftime(date_fmt)
    self.download_data(max_collector_count, delay, start_str, end_str, check_data_length, limit_nums)
def update_data_to_bin(
    self,
    qlib_data_1d_dir: str,
    trading_date: str = None,
    end_date: str = None,
    check_data_length: int = None,
    delay: float = 1,
):
    """Download, normalize and dump yahoo data into an existing qlib bin dataset.

    The method runs these steps in order: make sure qlib 1d data exists in
    ``qlib_data_1d_dir`` (fetching it with ``GetData`` when it does not), download
    source data for ``[trading_date, end_date)``, normalize it with
    ``normalize_data_1d_extend``, dump it with ``DumpDataUpdate`` and finally
    refresh the index instruments for the region.

    Parameters
    ----------
    qlib_data_1d_dir: str
        the qlib data to be updated for yahoo, usually from: https://github.com/microsoft/qlib/tree/main/scripts#download-cn-data
    trading_date: str
        first trading day to update, by default ``datetime.datetime.now().strftime("%Y-%m-%d")``
    end_date: str
        end datetime, default ``pd.Timestamp(trading_date) + pd.Timedelta(days=1)``; open interval (excluding end)
    check_data_length: int
        forwarded unchanged to ``download_data``; if not None and greater than 0, a symbol is
        considered complete once its data length reaches this value, otherwise it is fetched
        again. By default None.
    delay: float
        time.sleep(delay), default 1

    Notes
    -----
        A non-"1d" interval only produces a warning; the update still proceeds.

        ``self.max_workers`` is raised to ``max(cpu_count - 2, 1)`` after the download step
        when it is None or <= 1, and stays changed on the instance afterwards.

        Index instruments are refreshed only for regions "cn" (CSI100, CSI300) and
        "us" (SP500, NASDAQ100, DJIA, SP400); other regions log a warning and return.

        NOTE(review): incomplete data in qlib_data_1d_dir is presumably padded with np.nan up to
        trading_date by the normalize/dump steps -- confirm in those implementations.

    Examples
    -------
        $ python collector.py update_data_to_bin --qlib_data_1d_dir <user data dir> --trading_date <start date> --end_date <end date>
    """
    day_fmt = "%Y-%m-%d"

    if self.interval.lower() != "1d":
        logger.warning("currently supports 1d data updates: --interval 1d")

    # Resolve the [trading_date, end_date) window.
    if trading_date is None:
        trading_date = datetime.datetime.now().strftime(day_fmt)
        logger.warning(f"trading_date is None, use the current date: {trading_date}")
    if end_date is None:
        following_day = pd.Timestamp(trading_date) + pd.Timedelta(days=1)
        end_date = following_day.strftime(day_fmt)

    # Make sure base qlib 1d data is present before extending it.
    qlib_data_1d_dir = str(Path(qlib_data_1d_dir).expanduser().resolve())
    if not exists_qlib_data(qlib_data_1d_dir):
        GetData().qlib_data(target_dir=qlib_data_1d_dir, interval=self.interval, region=self.region)

    # NOTE: when downloading data from YahooFinance, max_workers is recommended to be 1
    self.download_data(delay=delay, start=trading_date, end=end_date, check_data_length=check_data_length)

    # NOTE: normalizing and dumping benefit from more workers than downloading does
    if self.max_workers is None or self.max_workers <= 1:
        self.max_workers = max(multiprocessing.cpu_count() - 2, 1)

    self.normalize_data_1d_extend(qlib_data_1d_dir)

    dumper = DumpDataUpdate(
        csv_path=self.normalize_dir,
        qlib_dir=qlib_data_1d_dir,
        exclude_fields="symbol,date",
        max_workers=self.max_workers,
    )
    dumper.dump()

    # Refresh index constituents; only cn/us index collectors are wired up here.
    region_key = self.region.lower()
    if region_key not in ["cn", "us"]:
        logger.warning(f"Unsupported region: region={region_key}, component downloads will be ignored")
        return
    if region_key == "cn":
        index_list = ["CSI100", "CSI300"]
    else:
        index_list = ["SP500", "NASDAQ100", "DJIA", "SP400"]
    collector_module = importlib.import_module(f"data_collector.{region_key}_index.collector")
    get_instruments = getattr(collector_module, "get_instruments")
    for index_name in index_list:
        get_instruments(qlib_data_1d_dir, index_name)
if __name__ == "__main__":

View File

@@ -6,3 +6,4 @@ pandas
tqdm
lxml
yahooquery
joblib

View File

@@ -401,6 +401,8 @@ class DumpDataUpdate(DumpDataBase):
)
self._mode = self.UPDATE_MODE
self._old_calendar_list = self._read_calendars(self._calendars_dir.joinpath(f"{self.freq}.txt"))
# NOTE: all.txt only exists once for each stock
# NOTE: if a stock corresponds to multiple different time ranges, user need to modify self._update_instruments
self._update_instruments = (
self._read_instruments(self._instruments_dir.joinpath(self.INSTRUMENTS_FILE_NAME))
.set_index([self.symbol_field_name])
@@ -409,10 +411,9 @@ class DumpDataUpdate(DumpDataBase):
# load all csv files
self._all_data = self._load_all_source_data() # type: pd.DataFrame
self._update_calendars = sorted(
self._new_calendar_list = self._old_calendar_list + sorted(
filter(lambda x: x > self._old_calendar_list[-1], self._all_data[self.date_field_name].unique())
)
self._new_calendar_list = self._old_calendar_list + self._update_calendars
def _load_all_source_data(self):
# NOTE: Need more memory
@@ -452,8 +453,16 @@ class DumpDataUpdate(DumpDataBase):
if not (isinstance(_start, pd.Timestamp) and isinstance(_end, pd.Timestamp)):
continue
if _code in self._update_instruments:
# exists stock, will append data
_update_calendars = (
_df[_df[self.date_field_name] > self._update_instruments[_code][self.INSTRUMENTS_START_FIELD]][
self.date_field_name
]
.sort_values()
.to_list()
)
self._update_instruments[_code][self.INSTRUMENTS_END_FIELD] = self._format_datetime(_end)
futures[executor.submit(self._dump_bin, _df, self._update_calendars)] = _code
futures[executor.submit(self._dump_bin, _df, _update_calendars)] = _code
else:
# new stock
_dt_range = self._update_instruments.setdefault(_code, dict())

View File

@@ -11,7 +11,7 @@ NAME = "pyqlib"
DESCRIPTION = "A Quantitative-research Platform"
REQUIRES_PYTHON = ">=3.5.0"
VERSION = "0.6.3.99"
VERSION = "0.7.0"
# Detect Cython
try: