1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 14:01:28 +08:00

Compare commits

...

82 Commits

Author SHA1 Message Date
Young
ee3d4092ae release 0.8.0 2021-12-08 07:32:03 +08:00
you-n-g
ae83f9056f docs improvement (#730) 2021-12-07 21:45:29 +08:00
Pengrong Zhu
c276de4040 Fix backtest (#719)
* modify FileStorage to support multiple freqs

* modify backtest's sample documentation

* change the logging level of read data exception from error to debug

* fix the backtest exception when volume is 0 or np.nan

* fix test_storage.py

* add backtest_daily

* modify backtest_daily's docstring

* add __repr__/__str__ to Position

* fix the bug of nested_decision_execution example

Co-authored-by: Young <afe.young@gmail.com>
Co-authored-by: you-n-g <you-n-g@users.noreply.github.com>
2021-12-07 19:04:23 +08:00
you-n-g
84103c7d43 Fix update record bug(#723) 2021-12-03 10:03:30 +08:00
you-n-g
6d6c586dc2 Update data crawler 2021-11-28 13:44:49 +08:00
you-n-g
54ef18ec4e Update the docs ops.py (#718)
* Update ops.py

* Update ops.py
2021-11-26 18:15:44 +08:00
demon143
0dfbf8c413 add logger (#709)
* Add files via upload

* Update README.md

* Update README.md

* Update README.md

* Delete change doc.gif

* Add files via upload

* Update README.md

* Delete change doc.gif

* Add files via upload

* Delete change doc.gif

* Add files via upload

* Update README.md

Co-authored-by: you-n-g <you-n-g@users.noreply.github.com>

* Add files via upload

* Delete Task-Gen-Recorder-CollectorV3.drawio (2).drawio.svg

* Add files via upload

* Delete Task-Gen-Recorder-CollectorV3.drawio.drawio.svg

* Add files via upload

* Delete Task-Gen-Recorder-CollectorV3.drawio.drawio (1).svg

* Add files via upload

* Delete Task-Gen-Recorder-CollectorV3.drawio (2).drawio (1).svg

* Add files via upload

* Delete Task-Gen-Recorder-CollectorV3.drawio (1).drawio (1).svg

* Add files via upload

* Add files via upload

* Delete Task-Gen-Recorder-CollectorV3.drawio (1).svg

* Delete Task-Gen-Recorder-Collector.svg

* Update manage.py

* Update utils.py

Co-authored-by: you-n-g <you-n-g@users.noreply.github.com>
2021-11-23 17:29:02 +08:00
cslwqxx
f9c35284e1 Update README.md 2021-11-23 15:58:32 +08:00
you-n-g
3974bfe746 Add SigAnaRecord in to workflow_by_code.py (#707)
* Update workflow_by_code.py

* Update workflow_by_code.py
2021-11-23 12:37:50 +08:00
you-n-g
45ebb1d0e0 support adding from date when updating pred (#703)
* support adding from date when updating pred

* fix updating data error
2021-11-22 19:52:52 +08:00
you-n-g
103f8579c1 Update README.md 2021-11-22 18:19:16 +08:00
fengcunguang
654033733d Add A New Baseline: ADD (#704) 2021-11-22 18:16:50 +08:00
Pengrong Zhu
d224ea447e Fix high-freq data (#702)
* fix the collector.py yahoo 1min factor calculation

* fix HFSignalRecord
2021-11-20 15:03:53 +08:00
demon143
9265b66e09 Update task_management.rst (#700) 2021-11-18 18:23:43 +08:00
you-n-g
d4b56d97b5 Update __init__.py 2021-11-18 18:22:09 +08:00
Chao Wang
0b3b95f22f Update workflow_by_code.ipynb (#697)
changed model_strategy to signal_stragey. there is no model_strategy under qlib.contrib.strategy.
2021-11-18 09:21:45 +08:00
demon143
0596174b94 Update task_management.rst (#654)
* Update task_management.rst

* Update task_management.rst

* Update task_management.rst

* Update task_management.rst

* Update task_management.rst

* Update task_management.rst

* Update task_management.rst

* Update task_management.rst

* Update task_management.rst

* Update task_management.rst

* Update task_management.rst

* Add files via upload
2021-11-18 09:20:54 +08:00
you-n-g
779b1786bd Merging Backtest improve (#694)
* Fix private import

* temporarily fix create exp conflicts for remote mlflow

Co-authored-by: Dong Zhou <Zhou.Dong@microsoft.com>
2021-11-16 11:43:35 +08:00
Lewen Wang
007082a112 Add AdaRNN baseline. (#689)
* Update TCTS.

* Update TCTS README.

* Update TCTS README.

* Update TCTS.

* Add ADARNN.

* Update README.

* Reformat ADARNN.

* Add README for adarnn.

Co-authored-by: lewwang <lwwang@microsoft.com>
2021-11-16 11:24:43 +08:00
you-n-g
4e380b611e Delete .DS_Store 2021-11-14 23:54:20 +08:00
Young
1e410c99be Support auto skip flawed data 2021-11-13 07:44:26 +00:00
Pengrong Zhu
b223c4304d add ops_warning_log to config (#685) 2021-11-12 22:23:29 +08:00
you-n-g
f2771f1beb Callable Exp (#683) 2021-11-12 14:56:22 +08:00
Ckend
01bdf6c1b1 bugfix: Fix the problem that caused highfreq's yaml to be unusable (#678) 2021-11-10 22:32:58 +08:00
Pengrong Zhu
9639a8cac9 add default protocol_version (#677)
* add default protocol_version

* add comment to serial.Serializable.get_backend
2021-11-10 14:37:18 +08:00
you-n-g
cae4c9c924 An example to get index from TSDataSampler (#679) 2021-11-10 14:35:27 +08:00
you-n-g
a2be6e28e9 handler demo cache (#606)
* handler demo  cache

* Update data_cache_demo.py

* example to reusing processed data in memory

* Skip dumping task of task_train

* FIX Black

Co-authored-by: Wangwuyi123 <51237097+Wangwuyi123@users.noreply.github.com>
2021-11-08 17:33:10 +08:00
you-n-g
fdbc666678 Fix private import 2021-11-08 09:53:58 +08:00
you-n-g
7800dd4ec9 Merge pull request #650 from microsoft/backtest_improve
Improve the backtest design and APIs
2021-11-08 09:10:33 +08:00
Young
3fa48d7017 simplify record tmp 2021-11-05 12:57:14 +00:00
Mark Zhao
b18132dce1 More citations (#673)
* check lexsort

* check lexsort

* lexsort comment

* lexsort comment

* fix GPU identification bug

* Create README.md

* Update README.md

* Create README.md

* Create README.md

* Update README.md

* Create README.md

* Update README.md

* Update README.md

* Create README.md

* Create README.md

* Create README.md
2021-11-05 19:43:50 +08:00
you-n-g
16957176a9 Update README.md
Update  benchmark link
2021-11-05 13:00:32 +08:00
fengcunguang
f0b9a807ea Add A New Baseline: TCN (#668) 2021-11-04 20:30:52 +08:00
Mark Zhao
5ee2d9496b GPU identification bug fix for GATs (#669)
* check lexsort

* check lexsort

* lexsort comment

* lexsort comment

* fix GPU identification bug
2021-11-03 14:41:46 +08:00
Young
4f2d6b0d84 fix pytorch memory amount error 2021-11-02 20:41:39 +08:00
Young
3943b7001f fix CI bug for AyncCaller 2021-11-02 14:32:09 +08:00
Young
2593185721 Simplify TSDataset and async recorder 2021-11-02 11:07:40 +08:00
Young
7a884fa9f2 remove redundant file only when remote artifact 2021-11-01 18:55:44 +08:00
Dong Zhou
d929d4bb21 rm recorder temp file 2021-11-01 09:29:44 +00:00
Young
e54b019ee2 solve init kwargs conflictions 2021-11-01 06:22:25 +00:00
Young
426b98a3bc make the logic of online manager cleaner 2021-11-01 02:40:54 +00:00
Young
82f8ff9066 Update seperate dataframe 2021-11-01 00:51:21 +08:00
you-n-g
7b15682c63 Update initialization.rst 2021-10-30 21:42:22 +08:00
you-n-g
df36839a7f Update README.md 2021-10-30 21:36:15 +08:00
Pengrong Zhu
4cecaba618 fix 'duplicate axis' error when updating cache (#661) 2021-10-29 22:52:42 +08:00
Pengrong Zhu
63b823f343 add logs in case of ops.py errors (#659)
* add logs in case of ops.py errors

* Update qlib/data/ops.py

Co-authored-by: you-n-g <you-n-g@users.noreply.github.com>
2021-10-29 08:52:43 +08:00
you-n-g
e41c0ac90a Adjusting gbdt.py's parameter (#660)
* Update gbdt.py

* Update gbdt.py

* Update gbdt.py
2021-10-28 19:43:05 +08:00
Young
31e9d529de Add multi horizon task generator 2021-10-28 00:01:19 +08:00
Young
5fa56703ae add handler pickle attr, enhance init_instance_by_config 2021-10-26 23:32:33 +08:00
Dong Zhou
c6bb11fe56 avoid trade without enough cash 2021-10-25 05:46:19 +00:00
Dong Zhou
3d7ebd1fe0 add back trade_val 2021-10-22 10:13:15 +00:00
Dong Zhou
7313b4dad0 fix impact cost 2021-10-22 08:58:37 +00:00
Dong Zhou
b70caff522 add doc 2021-10-22 08:49:20 +00:00
Dong Zhou
96b422a906 support market impact cost 2021-10-22 08:44:47 +00:00
Young
64130d9407 Fix the aggregation function of IndexData 2021-10-22 15:20:45 +08:00
Young
a58bc03a8e add sepdf(make mini project only rely on qlib) 2021-10-21 13:15:02 +00:00
Young
f537222ce3 make handler seperable 2021-10-21 12:38:24 +00:00
Dong Zhou
c427c64845 fix calendar 2021-10-19 06:17:53 +00:00
Young
22ff8fdc44 simple change log 2021-10-16 17:14:37 +00:00
Young
4efb0a75c1 Being compatible with previous Qlib version 2021-10-16 16:43:38 +00:00
Young
052aad7982 simplify signal parameter 2021-10-15 14:48:31 +00:00
Young
12f05c7182 Merge branch 'backtest_improve' of github.com:microsoft/qlib into backtest_improve 2021-10-15 11:27:33 +00:00
Young
ac08468330 Make static prediction easier 2021-10-15 11:21:03 +00:00
Dong Zhou
df9745f134 support empty order 2021-10-15 09:07:03 +00:00
Dong Zhou
2e49a5f7c0 fix order generator 2021-10-15 07:04:47 +00:00
you-n-g
3ab5721448 Fix OrderGenerator's return value 2021-10-15 14:28:08 +08:00
you-n-g
6a94b45503 Update order_generator.py 2021-10-15 13:52:55 +08:00
you-n-g
7c31012b50 Auto injecting model and dataset for Recorder (#645)
* Auto injecting model and dataset for Recorder

* Support using Feature in expression
2021-10-15 13:50:24 +08:00
you-n-g
334b92ace7 Checking dataset empty (#647)
* Checking dataset empty

* add dataset checker
2021-10-14 23:35:12 +08:00
you-n-g
9a175d7507 improve the doc of auto init (#541)
* improve the doc of auto init

* Update setup.py

* Update setup.py

* change cvxpy version

Co-authored-by: Wangwuyi123 <51237097+Wangwuyi123@users.noreply.github.com>
2021-10-12 11:58:27 +08:00
Lewen Wang
17ea44e0cf Update TCTS. (#643)
* Update TCTS.

* Update TCTS README.

* Update TCTS README.

* Update TCTS.

Co-authored-by: lewwang <lwwang@microsoft.com>
2021-10-12 10:08:48 +08:00
you-n-g
c0ce712be9 more detailed docs for workflow (#639)
* more detailed docs for workflow

* add more detailed docs for workflow
2021-10-11 15:38:18 +08:00
demon143
8e81a017c1 Update manage.py (#628)
* Update manage.py

* Update manage.py

* Update manage.py

* Create manage.py

* Update manage.py

* Update qlib/workflow/task/manage.py

Co-authored-by: you-n-g <you-n-g@users.noreply.github.com>

Co-authored-by: you-n-g <you-n-g@users.noreply.github.com>
2021-10-11 15:37:50 +08:00
you-n-g
706727988c Update README.md 2021-10-09 23:37:07 +08:00
you-n-g
e99224e5c2 Update benchmark based on new backtest (#634)
* free random seed

* update model baselines

* more robust for parameters
2021-10-07 22:57:19 +08:00
Pengrong Zhu
8c8d1336de fix workflow_config_lightgbm_multi_freq.yaml (#635) 2021-10-06 17:18:27 +08:00
Pengrong Zhu
d01de411a8 add support for macos-11 (#630)
Co-authored-by: you-n-g <you-n-g@users.noreply.github.com>
2021-10-03 16:49:17 +08:00
Young
28fe4d4bb4 update file strategy test 2021-10-03 14:58:37 +08:00
Young
873129aa9b update fix CI tests bugs 2021-10-03 14:58:37 +08:00
Young
3a152f9b8b fix CI 2021-10-03 14:58:37 +08:00
Young
2b75b41a08 remove 3.6 2021-10-03 14:58:37 +08:00
you-n-g
00d17f0a52 Update python-publish.yml 2021-10-01 03:03:26 +08:00
173 changed files with 5496 additions and 1366 deletions

View File

@@ -12,8 +12,9 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [windows-latest, macos-latest]
python-version: [3.6, 3.7, 3.8, 3.9]
os: [windows-latest, macos-latest, macos-11]
# not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129
python-version: [3.7, 3.8]
steps:
- uses: actions/checkout@v2
@@ -44,7 +45,8 @@ jobs:
- name: Build wheel on Linux
uses: RalfG/python-wheels-manylinux-build@v0.3.1-manylinux2010_x86_64
with:
python-versions: 'cp36-cp36m cp37-cp37m cp38-cp38'
# not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129
python-versions: 'cp37-cp37m cp38-cp38'
build-requirements: 'numpy cython'
- name: Set up Python
uses: actions/setup-python@v2

View File

@@ -13,7 +13,8 @@ jobs:
strategy:
matrix:
os: [windows-latest, ubuntu-18.04, ubuntu-20.04]
python-version: [3.6, 3.7, 3.8]
# not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129
python-version: [3.7, 3.8]
steps:
- uses: actions/checkout@v2
@@ -49,15 +50,6 @@ jobs:
pip install --upgrade cython jupyter jupyter_contrib_nbextensions numpy scipy scikit-learn # installing without this line will cause errors on GitHub Actions, while instsalling locally won't
pip install -e .
- name: Test data downloads
run: |
if [ "$RUNNER_OS" == "Windows" ]; then
$CONDA\\python.exe scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
else
$CONDA/bin/python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
fi
shell: bash
- name: Install test dependencies
run: |
pip install --upgrade pip

View File

@@ -10,10 +10,12 @@ on:
jobs:
build:
runs-on: macos-latest
runs-on: ${{ matrix.os }}
strategy:
matrix:
python-version: [3.6, 3.7, 3.8]
os: [macos-11, macos-latest]
# not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129
python-version: [3.7, 3.8]
steps:
- uses: actions/checkout@v2

View File

@@ -159,6 +159,21 @@ Version 0.5.0
- Add baselines
- public data crawler
Version greater than Version 0.5.0
Version 0.8.0
--------------------
- The backtest is greatly refactored.
- Nested decision execution framework is supported
- There are lots of changes for daily trading, it is hard to list all of them. But a few important changes could be noticed
- The trading limitation is more accurate;
- In `previous version <https://github.com/microsoft/qlib/blob/v0.7.2/qlib/contrib/backtest/exchange.py#L160>`_, longing and shorting actions share the same action.
- In `current verison <https://github.com/microsoft/qlib/blob/7c31012b507a3823117bddcc693fc64899460b2a/qlib/backtest/exchange.py#L304>`_, the trading limitation is different between loging and shorting action.
- The constant is different when calculating annualized metrics.
- `Current version <https://github.com/microsoft/qlib/blob/7c31012b507a3823117bddcc693fc64899460b2a/qlib/contrib/evaluate.py#L42>`_ uses more accurate constant than `previous version <https://github.com/microsoft/qlib/blob/v0.7.2/qlib/contrib/evaluate.py#L22>`_
- `A new version <https://github.com/microsoft/qlib/blob/7c31012b507a3823117bddcc693fc64899460b2a/qlib/tests/data.py#L17>`_ of data is released. Due to the unstability of Yahoo data source, the data may be different after downloading data again.
- Users could chec kout the backtesting results between `Current version <https://github.com/microsoft/qlib/tree/7c31012b507a3823117bddcc693fc64899460b2a/examples/benchmarks>`_ and `previous version <https://github.com/microsoft/qlib/tree/v0.7.2/examples/benchmarks>`_
Other Versions
----------------------------------
Please refer to `Github release Notes <https://github.com/microsoft/qlib/releases>`_

View File

@@ -11,6 +11,9 @@
Recent released features
| Feature | Status |
| -- | ------ |
| ADD model | [Released](https://github.com/microsoft/qlib/pull/704) on Nov 22, 2021 |
| ADARNN model | [Released](https://github.com/microsoft/qlib/pull/689) on Nov 14, 2021 |
| TCN model | [Released](https://github.com/microsoft/qlib/pull/668) on Nov 4, 2021 |
|Temporal Routing Adaptor (TRA) | [Released](https://github.com/microsoft/qlib/pull/531) on July 30, 2021 |
| Transformer & Localformer | [Released](https://github.com/microsoft/qlib/pull/508) on July 22, 2021 |
| Release Qlib v0.7.0 | [Released](https://github.com/microsoft/qlib/releases/tag/v0.7.0) on July 12, 2021 |
@@ -79,7 +82,7 @@ At the module level, Qlib is a platform that consists of the above components. T
| Name | Description |
| ------ | ----- |
| `Infrastructure` layer | `Infrastructure` layer provides underlying support for Quant research. `DataServer` provides a high-performance infrastructure for users to manage and retrieve raw data. `Trainer` provides a flexible interface to control the training process of models, which enable algorithms to control the training process. |
| `Workflow` layer | `Workflow` layer covers the whole workflow of quantitative investment. `Information Extractor` extracts data for models. `Forecast Model` focuses on producing all kinds of forecast signals (e.g. _alpha_, risk) for other modules. With these signals `Portfolio Generator` will generate the target portfolio and produce orders to be executed by `Order Executor`. |
| `Workflow` layer | `Workflow` layer covers the whole workflow of quantitative investment. `Information Extractor` extracts data for models. `Forecast Model` focuses on producing all kinds of forecast signals (e.g. _alpha_, risk) for other modules. With these signals `Decision Generator` will generate the target trading decisions(i.e. portfolio, orders) to be executed by `Execution Env` (i.e. the trading market). There may be multiple levels of `Trading Agent` and `Execution Env` (e.g. an _order executor trading agent and intraday order execution environment_ could behave like an interday trading environment and nested in _daily portfolio management trading agent and interday trading environment_ ) |
| `Interface` layer | `Interface` layer tries to present a user-friendly interface for the underlying system. `Analyser` module will provide users detailed analysis reports of forecasting signals, portfolios and execution results |
* The modules with hand-drawn style are under development and will be released in the future.
@@ -100,7 +103,6 @@ Here is a quick **[demo](https://terminalizer.com/view/3f24561a4470)** shows how
This table demonstrates the supported Python version of `Qlib`:
| | install with pip | install from source | plot |
| ------------- |:---------------------:|:--------------------:|:----:|
| Python 3.6 | :heavy_check_mark: | :heavy_check_mark: (only with `Anaconda`) | :heavy_check_mark: |
| Python 3.7 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| Python 3.8 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| Python 3.9 | :x: | :heavy_check_mark: | :x: |
@@ -279,22 +281,25 @@ The automatic workflow may not suit the research workflow of all Quant researche
# [Quant Model (Paper) Zoo](examples/benchmarks)
Here is a list of models built on `Qlib`.
- [GBDT based on XGBoost (Tianqi Chen, et al. KDD 2016)](qlib/contrib/model/xgboost.py)
- [GBDT based on LightGBM (Guolin Ke, et al. NIPS 2017)](qlib/contrib/model/gbdt.py)
- [GBDT based on Catboost (Liudmila Prokhorenkova, et al. NIPS 2018)](qlib/contrib/model/catboost_model.py)
- [MLP based on pytorch](qlib/contrib/model/pytorch_nn.py)
- [LSTM based on pytorch (Sepp Hochreiter, et al. Neural omputation 1997)](qlib/contrib/model/pytorch_lstm.py)
- [GRU based on pytorch (Kyunghyun Cho, et al. 2014)](qlib/contrib/model/pytorch_gru.py)
- [ALSTM based on pytorch (Yao Qin, et al. IJCAI 2017)](qlib/contrib/model/pytorch_alstm.py)
- [GATs based on pytorch (Petar Velickovic, et al. 2017)](qlib/contrib/model/pytorch_gats.py)
- [SFM based on pytorch (Liheng Zhang, et al. KDD 2017)](qlib/contrib/model/pytorch_sfm.py)
- [TFT based on tensorflow (Bryan Lim, et al. International Journal of Forecasting 2019)](examples/benchmarks/TFT/tft.py)
- [TabNet based on pytorch (Sercan O. Arik, et al. AAAI 2019)](qlib/contrib/model/pytorch_tabnet.py)
- [DoubleEnsemble based on LightGBM (Chuheng Zhang, et al. ICDM 2020)](qlib/contrib/model/double_ensemble.py)
- [TCTS based on pytorch (Xueqing Wu, et al. ICML 2021)](qlib/contrib/model/pytorch_tcts.py)
- [Transformer based on pytorch (Ashish Vaswani, et al. NeurIPS 2017)](qlib/contrib/model/pytorch_transformer.py)
- [Localformer based on pytorch (Juyong Jiang, et al.)](qlib/contrib/model/pytorch_localformer.py)
- [TRA based on pytorch (Hengxu, Dong, et al. KDD 2021)](qlib/contrib/model/pytorch_tra.py)
- [GBDT based on XGBoost (Tianqi Chen, et al. KDD 2016)](examples/benchmarks/XGBoost/)
- [GBDT based on LightGBM (Guolin Ke, et al. NIPS 2017)](examples/benchmarks/LightGBM/)
- [GBDT based on Catboost (Liudmila Prokhorenkova, et al. NIPS 2018)](examples/benchmarks/CatBoost/)
- [MLP based on pytorch](examples/benchmarks/MLP/)
- [LSTM based on pytorch (Sepp Hochreiter, et al. Neural computation 1997)](examples/benchmarks/LSTM/)
- [GRU based on pytorch (Kyunghyun Cho, et al. 2014)](examples/benchmarks/GRU/)
- [ALSTM based on pytorch (Yao Qin, et al. IJCAI 2017)](examples/benchmarks/ALSTM)
- [GATs based on pytorch (Petar Velickovic, et al. 2017)](examples/benchmarks/GATs/)
- [SFM based on pytorch (Liheng Zhang, et al. KDD 2017)](examples/benchmarks/SFM/)
- [TFT based on tensorflow (Bryan Lim, et al. International Journal of Forecasting 2019)](examples/benchmarks/TFT/)
- [TabNet based on pytorch (Sercan O. Arik, et al. AAAI 2019)](examples/benchmarks/TabNet/)
- [DoubleEnsemble based on LightGBM (Chuheng Zhang, et al. ICDM 2020)](examples/benchmarks/DoubleEnsemble/)
- [TCTS based on pytorch (Xueqing Wu, et al. ICML 2021)](examples/benchmarks/TCTS/)
- [Transformer based on pytorch (Ashish Vaswani, et al. NeurIPS 2017)](examples/benchmarks/Transformer/)
- [Localformer based on pytorch (Juyong Jiang, et al.)](examples/benchmarks/Localformer/)
- [TRA based on pytorch (Hengxu, Dong, et al. KDD 2021)](examples/benchmarks/TRA/)
- [TCN based on pytorch (Shaojie Bai, et al. 2018)](examples/benchmarks/TCN/)
- [ADARNN based on pytorch (YunTao Du, et al. 2021)](examples/benchmarks/ADARNN/)
- [ADD based on pytorch (Hongshun Tang, et al.2020)](examples/benchmarks/ADD/)
Your PR of new Quant models is highly welcomed.
@@ -307,7 +312,7 @@ All the models listed above are runnable with ``Qlib``. Users can find the confi
- Users can use the tool `qrun` mentioned above to run a model's workflow based from a config file.
- Users can create a `workflow_by_code` python script based on the [one](examples/workflow_by_code.py) listed in the `examples` folder.
- Users can use the script [`run_all_model.py`](examples/run_all_model.py) listed in the `examples` folder to run a model. Here is an example of the specific shell command to be used: `python run_all_model.py --models=lightgbm`, where the `--models` arguments can take any number of models listed above(the available models can be found in [benchmarks](examples/benchmarks/)). For more use cases, please refer to the file's [docstrings](examples/run_all_model.py).
- Users can use the script [`run_all_model.py`](examples/run_all_model.py) listed in the `examples` folder to run a model. Here is an example of the specific shell command to be used: `python run_all_model.py run --models=lightgbm`, where the `--models` arguments can take any number of models listed above(the available models can be found in [benchmarks](examples/benchmarks/)). For more use cases, please refer to the file's [docstrings](examples/run_all_model.py).
- **NOTE**: Each baseline has different environment dependencies, please make sure that your python version aligns with the requirements(e.g. TFT only supports Python 3.6~3.7 due to the limitation of `tensorflow==1.15.0`)
## Run multiple models
@@ -317,7 +322,7 @@ The script will create a unique virtual environment for each model, and delete t
Here is an example of running all the models for 10 iterations:
```python
python run_all_model.py 10
python run_all_model.py run 10
```
It also provides the API to run specific models at once. For more use cases, please refer to the file's [docstrings](examples/run_all_model.py).

View File

@@ -1 +1 @@
0.7.2.99
0.8.0

File diff suppressed because one or more lines are too long

After

Width:  |  Height:  |  Size: 198 KiB

View File

@@ -11,7 +11,10 @@ Introduction
The `Workflow <../component/introduction.html>`_ part introduces how to run research workflow in a loosely-coupled way. But it can only execute one ``task`` when you use ``qrun``.
To automatically generate and execute different tasks, ``Task Management`` provides a whole process including `Task Generating`_, `Task Storing`_, `Task Training`_ and `Task Collecting`_.
With this module, users can run their ``task`` automatically at different periods, in different losses, or even by different models.
With this module, users can run their ``task`` automatically at different periods, in different losses, or even by different models.The processes of task generation, model training and combine and collect data are shown in the following figure.
.. image:: ../_static/img/Task-Gen-Recorder-Collector.svg
:align: center
This whole process can be used in `Online Serving <../component/online.html>`_.
@@ -74,6 +77,8 @@ If you do not want to use ``Task Manager`` to manage tasks, then use TrainerR to
Task Collecting
===============
Before collecting model training results, you need to use the ``qlib.init`` to specify the path of mlruns.
To collect the results of ``task`` after training, ``Qlib`` provides `Collector <../reference/api.html#Collector>`_, `Group <../reference/api.html#Group>`_ and `Ensemble <../reference/api.html#Ensemble>`_ to collect the results in a readable, expandable and loosely-coupled way.
`Collector <../reference/api.html#Collector>`_ can collect objects from everywhere and process them such as merging, grouping, averaging and so on. It has 2 step action including ``collect`` (collect anything in a dict) and ``process_collect`` (process collected dict).
@@ -82,8 +87,10 @@ To collect the results of ``task`` after training, ``Qlib`` provides `Collector
For example: {(A,B,C1): object, (A,B,C2): object} ---``group``---> {(A,B): {C1: object, C2: object}} ---``reduce``---> {(A,B): object}
`Ensemble <../reference/api.html#Ensemble>`_ can merge the objects in an ensemble.
For example: {C1: object, C2: object} ---``Ensemble``---> object
For example: {C1: object, C2: object} ---``Ensemble``---> object.
You can set the ensembles you want in the ``Collector``'s process_list.
Common ensembles include ``AverageEnsemble`` and ``RollingEnsemble``. Average ensemble is used to ensemble the results of different models in the same time period. Rollingensemble is used to ensemble the results of different models in the same time period
So the hierarchy is ``Collector``'s second step corresponds to ``Group``. And ``Group``'s second step correspond to ``Ensemble``.
For more information, please see `Collector <../reference/api.html#Collector>`_, `Group <../reference/api.html#Group>`_ and `Ensemble <../reference/api.html#Ensemble>`_, or the `example <https://github.com/microsoft/qlib/tree/main/examples/model_rolling/task_manager_rolling.py>`_.
For more information, please see `Collector <../reference/api.html#Collector>`_, `Group <../reference/api.html#Group>`_ and `Ensemble <../reference/api.html#Ensemble>`_, or the `example <https://github.com/microsoft/qlib/tree/main/examples/model_rolling/task_manager_rolling.py>`_.

View File

@@ -1,114 +0,0 @@
.. _backtest:
============================================
Intraday Trading: Model&Strategy Testing
============================================
.. currentmodule:: qlib
Introduction
===================
``Intraday Trading`` is designed to test models and strategies, which help users to check the performance of a custom model/strategy.
.. note::
``Intraday Trading`` uses ``Order Executor`` to trade and execute orders output by ``Portfolio Strategy``. ``Order Executor`` is a component in `Qlib Framework <../introduction/introduction.html#framework>`_, which can execute orders. ``VWAP Executor`` and ``Close Executor`` is supported by ``Qlib`` now. In the future, ``Qlib`` will support ``HighFreq Executor`` also.
Example
===========================
Users need to generate a `prediction score`(a pandas DataFrame) with MultiIndex<instrument, datetime> and a `score` column. And users need to assign a strategy used in backtest, if strategy is not assigned,
a `TopkDropoutStrategy` strategy with `(topk=50, n_drop=5, risk_degree=0.95, limit_threshold=0.0095)` will be used.
If ``Strategy`` module is not users' interested part, `TopkDropoutStrategy` is enough.
The simple example of the default strategy is as follows.
.. code-block:: python
from qlib.contrib.evaluate import backtest
# pred_score is the prediction score
report, positions = backtest(pred_score, topk=50, n_drop=0.5, limit_threshold=0.0095)
To know more about backtesting with a specific ``Strategy``, please refer to `Portfolio Strategy <strategy.html>`_.
To know more about the prediction score `pred_score` output by ``Forecast Model``, please refer to `Forecast Model: Model Training & Prediction <model.html>`_.
Prediction Score
-----------------
The `prediction score` is a pandas DataFrame. Its index is <datetime(pd.Timestamp), instrument(str)> and it must
contains a `score` column.
A prediction sample is shown as follows.
.. code-block:: python
datetime instrument score
2019-01-04 SH600000 -0.505488
2019-01-04 SZ002531 -0.320391
2019-01-04 SZ000999 0.583808
2019-01-04 SZ300569 0.819628
2019-01-04 SZ001696 -0.137140
... ...
2019-04-30 SZ000996 -1.027618
2019-04-30 SH603127 0.225677
2019-04-30 SH603126 0.462443
2019-04-30 SH603133 -0.302460
2019-04-30 SZ300760 -0.126383
``Forecast Model`` module can make predictions, please refer to `Forecast Model: Model Training & Prediction <model.html>`_.
Backtest Result
------------------
The backtest results are in the following form:
.. code-block:: python
risk
excess_return_without_cost mean 0.000605
std 0.005481
annualized_return 0.152373
information_ratio 1.751319
max_drawdown -0.059055
excess_return_with_cost mean 0.000410
std 0.005478
annualized_return 0.103265
information_ratio 1.187411
max_drawdown -0.075024
- `excess_return_without_cost`
- `mean`
Mean value of the `CAR` (cumulative abnormal return) without cost
- `std`
The `Standard Deviation` of `CAR` (cumulative abnormal return) without cost.
- `annualized_return`
The `Annualized Rate` of `CAR` (cumulative abnormal return) without cost.
- `information_ratio`
The `Information Ratio` without cost. please refer to `Information Ratio IR <https://www.investopedia.com/terms/i/informationratio.asp>`_.
- `max_drawdown`
The `Maximum Drawdown` of `CAR` (cumulative abnormal return) without cost, please refer to `Maximum Drawdown (MDD) <https://www.investopedia.com/terms/m/maximum-drawdown-mdd.asp>`_.
- `excess_return_with_cost`
- `mean`
Mean value of the `CAR` (cumulative abnormal return) series with cost
- `std`
The `Standard Deviation` of `CAR` (cumulative abnormal return) series with cost.
- `annualized_return`
The `Annualized Rate` of `CAR` (cumulative abnormal return) with cost.
- `information_ratio`
The `Information Ratio` with cost. please refer to `Information Ratio IR <https://www.investopedia.com/terms/i/informationratio.asp>`_.
- `max_drawdown`
The `Maximum Drawdown` of `CAR` (cumulative abnormal return) with cost, please refer to `Maximum Drawdown (MDD) <https://www.investopedia.com/terms/m/maximum-drawdown-mdd.asp>`_.
Reference
==============
To know more about ``Intraday Trading``, please refer to `Intraday Trading <../reference/api.html#module-qlib.contrib.evaluate>`_.

View File

@@ -1,120 +1,31 @@
.. _highfreq:
============================================
Design of hierarchical order execution framework
Design of Nested Decision Execution Framework for High-Frequency Trading
============================================
.. currentmodule:: qlib
Introduction
===================
In order to support reinforcement learning algorithms for high-frequency trading, a corresponding framework is required. None of the publicly available high-frequency trading frameworks now consider multi-layer trading mechanisms, and the currently designed algorithms cannot directly use existing frameworks.
In addition to supporting the basic intraday multi-layer trading, the linkage with the day-ahead strategy is also a factor that affects the performance evaluation of the strategy. Different day strategies generate different order distributions and different patterns on different stocks. To verify that high-frequency trading strategies perform well on real trading orders, it is necessary to support day-frequency and high-frequency multi-level linkage trading. In addition to more accurate backtesting of high-frequency trading algorithms, if the distribution of day-frequency orders is considered when training a high-frequency trading model, the algorithm can also be optimized more for product-specific day-frequency orders.
Therefore, innovation in the high-frequency trading framework is necessary to solve the various problems mentioned above, for which we designed a hierarchical order execution framework that can link daily-frequency and intra-day trading at different granularities.
Daily trading (e.g. portfolio management) and intraday trading (e.g. orders execution) are two hot topics in Quant investment and usually studied separately.
To get the join trading performance of daily and intraday trading, they must interact with each other and run backtest jointly.
In order to support the joint backtest strategies in multiple levels, a corresponding framework is required. None of the publicly available high-frequency trading frameworks considers multi-level joint trading, which make the backtesting aforementioned inaccurate.
Besides backtesting, the optimization of strategies from different levels is not standalone and can be affected by each other.
For example, the best portfolio management strategy may change with the performance of order executions(e.g. a portfolio with higher turnover may becomes a better choice when we imporve the order execution strategies).
To achieve the overall good performance , it is necessary to consider the interaction of strategies in different level.
Therefore, building a new framework for trading in multiple levels becomes necessary to solve the various problems mentioned above, for which we designed a nested decision execution framework that consider the interaction of strategies.
.. image:: ../_static/img/framework.svg
The design of the framework is shown in the figure above. At each layer consists of Trading Agent and Execution Env. The Trading Agent has its own data processing module (Information Extractor), forecasting module (Forecast Model) and decision generator (Decision Generator). The trading algorithm generates the corresponding decisions by the Decision Generator based on the forecast signals output by the Forecast Module, and the decisions generated by the trading algorithm are passed to the Execution Env, which returns the execution results. Here the frequency of trading algorithm, decision content and execution environment can be customized by users (e.g. intra-day trading, daily-frequency trading, weekly-frequency trading), and the execution environment can be nested with finer-grained trading algorithm and execution environment inside (i.e. sub-workflow in the figure, e.g. daily-frequency orders can be turned into finer-grained decisions by splitting orders within the day). The hierarchical order execution framework is user-defined in terms of hierarchy division and decision frequency, making it easy for users to explore the effects of combining different levels of trading algorithms and breaking down the barriers between different levels of trading algorithm optimization.
In addition to the innovation in the framework, the hierarchical order execution framework also takes into account various details of the real backtesting environment, minimizing the differences with the final real environment as much as possible. At the same time, the framework is designed to unify the interface between online and offline (e.g. data pre-processing level supports using the same set of code to process both offline and online data) to reduce the cost of strategy go-live as much as possible.
Prepare Data
===================
.. _data:: ../../examples/highfreq/README.md
The design of the framework is shown in the yellow part in the middle of the figure above. Each level consists of ``Trading Agent`` and ``Execution Env``. ``Trading Agent`` has its own data processing module (``Information Extractor``), forecasting module (``Forecast Model``) and decision generator (``Decision Generator``). The trading algorithm generates the decisions by the ``Decision Generator`` based on the forecast signals output by the ``Forecast Module``, and the decisions generated by the trading algorithm are passed to the ``Execution Env``, which returns the execution results.
The frequency of trading algorithm, decision content and execution environment can be customized by users (e.g. intraday trading, daily-frequency trading, weekly-frequency trading), and the execution environment can be nested with finer-grained trading algorithm and execution environment inside (i.e. sub-workflow in the figure, e.g. daily-frequency orders can be turned into finer-grained decisions by splitting orders within the day). The flexibility of nested decision execution framework makes it easy for users to explore the effects of combining different levels of trading strategies and break down the optimization barriers between different levels of trading algorithm.
Example
===========================
Here is an example of highfreq execution.
.. code-block:: python
import qlib
# init qlib
provider_uri_day = "~/.qlib/qlib_data/cn_data"
provider_uri_1min = "~/.qlib/qlib_data/cn_data_1min"
provider_uri_map = {"1min": provider_uri_1min, "day": provider_uri_day}
qlib.init(provider_uri=provider_uri_day, expression_cache=None, dataset_cache=None)
# data freq and backtest time
freq = "1min"
inst_list = D.list_instruments(D.instruments("all"), as_list=True)
start_time = "2020-01-01"
start_time = "2020-01-31"
When initializing qlib, if the default data is used, then both daily and minute frequency data need to be passed in.
.. code-block:: python
# random order strategy config
strategy_config = {
"class": "RandomOrderStrategy",
"module_path": "qlib.contrib.strategy.rule_strategy",
"kwargs": {
"trade_range": TradeRangeByTime("9:30", "15:00"),
"sample_ratio": 1.0,
"volume_ratio": 0.01,
"market": market,
},
}
.. code-block:: python
# backtest config
backtest_config = {
"start_time": start_time,
"end_time": end_time,
"account": 100000000,
"benchmark": None,
"exchange_kwargs": {
"freq": freq,
"limit_threshold": 0.095,
"deal_price": "close",
"open_cost": 0.0005,
"close_cost": 0.0015,
"min_cost": 5,
"codes": market,
},
"pos_type": "InfPosition", # Position with infinitive position
}
please refer to "../../qlib/backtest".
.. code-block:: python
# excutor config
executor_config = {
"class": "NestedExecutor",
"module_path": "qlib.backtest.executor",
"kwargs": {
"time_per_step": "day",
"inner_executor": {
"class": "SimulatorExecutor",
"module_path": "qlib.backtest.executor",
"kwargs": {
"time_per_step": freq,
"generate_portfolio_metrics": True,
"verbose": False,
# "verbose": True,
"indicator_config": {
"show_indicator": False,
},
},
},
"inner_strategy": {
"class": "TWAPStrategy",
"module_path": "qlib.contrib.strategy.rule_strategy",
},
"track_data": True,
"generate_portfolio_metrics": True,
"indicator_config": {
"show_indicator": True,
},
},
}
NestedExecutor represents not the innermost layer, the initialization parameters should contain inner_executor and inner_strategy. simulatorExecutor represents the current excutor is the innermost layer, the innermost strategy used here is the TWAP strategy, the framework currently also supports the VWAP strategy
.. code-block:: python
# backtest
portfolio_metrics_dict, indicator_dict = backtest(executor=executor_config, strategy=strategy_config, **backtest_config)
The metrics of backtest are included in the portfolio_metrics_dict and indicator_dict.
An example of nested decision execution framework for high-frequency can be found `here <https://github.com/microsoft/qlib/blob/main/examples/nested_decision_execution/workflow.py>`_.

View File

@@ -12,7 +12,9 @@ Introduction
Because the components in ``Qlib`` are designed in a loosely-coupled way, ``Portfolio Strategy`` can be used as an independent module also.
``Qlib`` provides several implemented portfolio strategies. Also, ``Qlib`` supports custom strategy, users can customize strategies according to their own needs.
``Qlib`` provides several implemented portfolio strategies. Also, ``Qlib`` supports custom strategy, users can customize strategies according to their own requirements.
After users specifying the models(forecasting signals) and strategies, running backtest will help users to check the performance of a custom model(forecasting signals)/strategy.
Base Class & Interface
======================
@@ -82,38 +84,203 @@ TopkDropoutStrategy
Usage & Example
====================
``Portfolio Strategy`` can be specified in the ``Intraday Trading(Backtest)``, the example is as follows.
First, user can create a model to get trading signals(the variable name is ``pred_score`` in following cases).
Prediction Score
-----------------
The `prediction score` is a pandas DataFrame. Its index is <datetime(pd.Timestamp), instrument(str)> and it must
contains a `score` column.
A prediction sample is shown as follows.
.. code-block:: python
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
from qlib.contrib.evaluate import backtest
STRATEGY_CONFIG = {
"topk": 50,
"n_drop": 5,
}
BACKTEST_CONFIG = {
"limit_threshold": 0.095,
"account": 100000000,
"benchmark": BENCHMARK,
"deal_price": "close",
"open_cost": 0.0005,
"close_cost": 0.0015,
"min_cost": 5,
}
# use default strategy
strategy = TopkDropoutStrategy(**STRATEGY_CONFIG)
datetime instrument score
2019-01-04 SH600000 -0.505488
2019-01-04 SZ002531 -0.320391
2019-01-04 SZ000999 0.583808
2019-01-04 SZ300569 0.819628
2019-01-04 SZ001696 -0.137140
... ...
2019-04-30 SZ000996 -1.027618
2019-04-30 SH603127 0.225677
2019-04-30 SH603126 0.462443
2019-04-30 SH603133 -0.302460
2019-04-30 SZ300760 -0.126383
# pred_score is the `prediction score` output by Model
report_normal, positions_normal = backtest(
pred_score, strategy=strategy, **BACKTEST_CONFIG
)
``Forecast Model`` module can make predictions, please refer to `Forecast Model: Model Training & Prediction <model.html>`_.
To know more about the `prediction score` `pred_score` output by ``Forecast Model``, please refer to `Forecast Model: Model Training & Prediction <model.html>`_.
To know more about ``Intraday Trading``, please refer to `Intraday Trading: Model&Strategy Testing <backtest.html>`_.
Running backtest
-----------------
- In most cases, users could backtest their portfolio management strategy with ``backtest_daily``.
.. code-block:: python
from pprint import pprint
import qlib
import pandas as pd
from qlib.utils.time import Freq
from qlib.utils import flatten_dict
from qlib.contrib.evaluate import backtest_daily
from qlib.contrib.evaluate import risk_analysis
from qlib.contrib.strategy import TopkDropoutStrategy
# init qlib
qlib.init(provider_uri=<qlib data dir>)
CSI300_BENCH = "SH000300"
STRATEGY_CONFIG = {
"topk": 50,
"n_drop": 5,
# pred_score, pd.Series
"signal": pred_score,
}
strategy_obj = TopkDropoutStrategy(**STRATEGY_CONFIG)
report_normal, positions_normal = backtest_daily(
start_time="2017-01-01", end_time="2020-08-01", strategy=strategy_obj
)
analysis = dict()
analysis["excess_return_without_cost"] = risk_analysis(
report_normal["return"] - report_normal["bench"], freq=analysis_freq
)
analysis["excess_return_with_cost"] = risk_analysis(
report_normal["return"] - report_normal["bench"] - report_normal["cost"], freq=analysis_freq
)
analysis_df = pd.concat(analysis) # type: pd.DataFrame
pprint(analysis_df)
- If users would like to control their strategies in a more detailed(e.g. users have a more advanced version of executor), user could follow this example.
.. code-block:: python
from pprint import pprint
import qlib
import pandas as pd
from qlib.utils.time import Freq
from qlib.utils import flatten_dict
from qlib.backtest import backtest, executor
from qlib.contrib.evaluate import risk_analysis
from qlib.contrib.strategy import TopkDropoutStrategy
# init qlib
qlib.init(provider_uri=<qlib data dir>)
CSI300_BENCH = "SH000300"
FREQ = "day"
STRATEGY_CONFIG = {
"topk": 50,
"n_drop": 5,
# pred_score, pd.Series
"signal": pred_score,
}
EXECUTOR_CONFIG = {
"time_per_step": "day",
"generate_portfolio_metrics": True,
}
backtest_config = {
"start_time": "2017-01-01",
"end_time": "2020-08-01",
"account": 100000000,
"benchmark": CSI300_BENCH,
"exchange_kwargs": {
"freq": FREQ,
"limit_threshold": 0.095,
"deal_price": "close",
"open_cost": 0.0005,
"close_cost": 0.0015,
"min_cost": 5,
},
}
# strategy object
strategy_obj = TopkDropoutStrategy(**STRATEGY_CONFIG)
# executor object
executor_obj = executor.SimulatorExecutor(**EXECUTOR_CONFIG)
# backtest
portfolio_metric_dict, indicator_dict = backtest(executor=executor_obj, strategy=strategy_obj, **backtest_config)
analysis_freq = "{0}{1}".format(*Freq.parse(FREQ))
# backtest info
report_normal, positions_normal = portfolio_metric_dict.get(analysis_freq)
# analysis
analysis = dict()
analysis["excess_return_without_cost"] = risk_analysis(
report_normal["return"] - report_normal["bench"], freq=analysis_freq
)
analysis["excess_return_with_cost"] = risk_analysis(
report_normal["return"] - report_normal["bench"] - report_normal["cost"], freq=analysis_freq
)
analysis_df = pd.concat(analysis) # type: pd.DataFrame
# log metrics
analysis_dict = flatten_dict(analysis_df["risk"].unstack().T.to_dict())
# print out results
pprint(f"The following are analysis results of benchmark return({analysis_freq}).")
pprint(risk_analysis(report_normal["bench"], freq=analysis_freq))
pprint(f"The following are analysis results of the excess return without cost({analysis_freq}).")
pprint(analysis["excess_return_without_cost"])
pprint(f"The following are analysis results of the excess return with cost({analysis_freq}).")
pprint(analysis["excess_return_with_cost"])
Result
------------------
The backtest results are in the following form:
.. code-block:: python
risk
excess_return_without_cost mean 0.000605
std 0.005481
annualized_return 0.152373
information_ratio 1.751319
max_drawdown -0.059055
excess_return_with_cost mean 0.000410
std 0.005478
annualized_return 0.103265
information_ratio 1.187411
max_drawdown -0.075024
- `excess_return_without_cost`
- `mean`
Mean value of the `CAR` (cumulative abnormal return) without cost
- `std`
The `Standard Deviation` of `CAR` (cumulative abnormal return) without cost.
- `annualized_return`
The `Annualized Rate` of `CAR` (cumulative abnormal return) without cost.
- `information_ratio`
The `Information Ratio` without cost. please refer to `Information Ratio IR <https://www.investopedia.com/terms/i/informationratio.asp>`_.
- `max_drawdown`
The `Maximum Drawdown` of `CAR` (cumulative abnormal return) without cost, please refer to `Maximum Drawdown (MDD) <https://www.investopedia.com/terms/m/maximum-drawdown-mdd.asp>`_.
- `excess_return_with_cost`
- `mean`
Mean value of the `CAR` (cumulative abnormal return) series with cost
- `std`
The `Standard Deviation` of `CAR` (cumulative abnormal return) series with cost.
- `annualized_return`
The `Annualized Rate` of `CAR` (cumulative abnormal return) with cost.
- `information_ratio`
The `Information Ratio` with cost. please refer to `Information Ratio IR <https://www.investopedia.com/terms/i/informationratio.asp>`_.
- `max_drawdown`
The `Maximum Drawdown` of `CAR` (cumulative abnormal return) with cost, please refer to `Maximum Drawdown (MDD) <https://www.investopedia.com/terms/m/maximum-drawdown-mdd.asp>`_.
Reference
===================
To know more about ``Portfolio Strategy``, please refer to `Strategy API <../reference/api.html#module-qlib.contrib.strategy.strategy>`_.
To know more about the `prediction score` `pred_score` output by ``Forecast Model``, please refer to `Forecast Model: Model Training & Prediction <model.html>`_.

View File

@@ -53,6 +53,9 @@ Below is a typical config file of ``qrun``.
kwargs:
topk: 50
n_drop: 5
signal:
- <MODEL>
- <DATASET>
backtest:
limit_threshold: 0.095
account: 100000000
@@ -240,6 +243,9 @@ The following script is the configuration of `backtest` and the `strategy` used
kwargs:
topk: 50
n_drop: 5
signal:
- <MODEL>
- <DATASET>
backtest:
limit_threshold: 0.095
account: 100000000

View File

@@ -38,8 +38,8 @@ Document Structure
Workflow: Workflow Management <component/workflow.rst>
Data Layer: Data Framework&Usage <component/data.rst>
Forecast Model: Model Training & Prediction <component/model.rst>
Strategy: Portfolio Management <component/strategy.rst>
Intraday Trading: Model&Strategy Testing <component/backtest.rst>
Portfolio Management and Backtest <component/strategy.rst>
Nested Decision Execution: High-Frequency Trading <component/highfreq.rst>
Qlib Recorder: Experiment Management <component/recorder.rst>
Analysis: Evaluation & Results Analysis <component/report.rst>
Online Serving: Online Management & Strategy & Tool <component/online.rst>

View File

@@ -34,9 +34,14 @@ Name Description
`Workflow` layer `Workflow` layer covers the whole workflow of quantitative investment.
`Information Extractor` extracts data for models. `Forecast Model` focuses
on producing all kinds of forecast signals (e.g. _alpha_, risk) for other
modules. With these signals `Portfolio Generator` will generate the target
portfolio and produce orders to be executed by `Order Executor`.
on producing all kinds of forecast signals (e.g. *alpha*, risk) for other
modules. With these signals `Decision Generator` will generate the target
trading decisions(i.e. portfolio, orders) to be executed by `Execution Env`
(i.e. the trading market). There may be multiple levels of `Trading Agent`
and `Execution Env` (e.g. an *order executor trading agent and intraday
order execution environment* could behave like an interday trading
environment and nested in *daily portfolio management trading agent and
interday trading environment* )
`Interface` layer `Interface` layer tries to present a user-friendly interface for the underlying
system. `Analyser` module will provide users detailed analysis reports of

View File

@@ -48,6 +48,7 @@ Besides `provider_uri` and `region`, `qlib.init` has other parameters. The follo
- ``qlib.config.REG_CN``: China stock market.
Different modes will result in different trading limitations and costs.
The region is just `shortcuts for defining a batch of configurations <https://github.com/microsoft/qlib/blob/main/qlib/config.py#L239>`_. Users can set the key configurations manually if the existing region setting can't meet their requirements.
- `redis_host`
Type: str, optional parameter(default: "127.0.0.1"), host of `redis`
The lock and cache mechanism relies on redis.

View File

@@ -0,0 +1,4 @@
# AdaRNN
* Code: [https://github.com/jindongwang/transferlearning/tree/master/code/deep/adarnn](https://github.com/jindongwang/transferlearning/tree/master/code/deep/adarnn)
* Paper: [AdaRNN: Adaptive Learning and Forecasting for Time Series](https://arxiv.org/pdf/2108.04443.pdf).

View File

@@ -0,0 +1,4 @@
pandas==1.1.2
numpy==1.17.4
scikit_learn==0.23.2
torch==1.7.0

View File

@@ -0,0 +1,88 @@
qlib_init:
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
market: &market csi300
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
start_time: 2008-01-01
end_time: 2020-08-01
fit_start_time: 2008-01-01
fit_end_time: 2014-12-31
instruments: *market
infer_processors:
- class: RobustZScoreNorm
kwargs:
fields_group: feature
clip_outlier: true
- class: Fillna
kwargs:
fields_group: feature
learn_processors:
- class: DropnaLabel
- class: CSRankNorm
kwargs:
fields_group: label
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
topk: 50
n_drop: 5
backtest:
start_time: 2017-01-01
end_time: 2020-08-01
account: 100000000
benchmark: *benchmark
exchange_kwargs:
limit_threshold: 0.095
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
task:
model:
class: ADARNN
module_path: qlib.contrib.model.pytorch_adarnn
kwargs:
d_feat: 6
hidden_size: 64
num_layers: 2
dropout: 0.0
n_epochs: 200
lr: 1e-3
early_stop: 20
batch_size: 800
metric: loss
loss: mse
GPU: 0
dataset:
class: DatasetH
module_path: qlib.data.dataset
kwargs:
handler:
class: Alpha360
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:
train: [2008-01-01, 2014-12-31]
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs:
model: <MODEL>
dataset: <DATASET>
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
ana_long_short: False
ann_scaler: 252
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config

View File

@@ -0,0 +1,3 @@
# ADD
* Paper: [ADD: Augmented Disentanglement Distillation Framework for Improving Stock Trend Forecasting](https://arxiv.org/abs/2012.06289).

View File

@@ -0,0 +1,4 @@
numpy==1.17.4
pandas==1.1.2
scikit_learn==0.23.2
torch==1.7.0

View File

@@ -0,0 +1,94 @@
qlib_init:
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
market: &market csi300
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
start_time: 2008-01-01
end_time: 2020-08-01
fit_start_time: 2008-01-01
fit_end_time: 2014-12-31
instruments: *market
infer_processors:
- class: RobustZScoreNorm
kwargs:
fields_group: feature
clip_outlier: true
- class: Fillna
kwargs:
fields_group: feature
learn_processors:
- class: DropnaLabel
- class: CSRankNorm
kwargs:
fields_group: label
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:
start_time: 2017-01-01
end_time: 2020-08-01
account: 100000000
benchmark: *benchmark
exchange_kwargs:
limit_threshold: 0.095
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
task:
model:
class: ADD
module_path: qlib.contrib.model.pytorch_add
kwargs:
d_feat: 6
hidden_size: 64
num_layers: 2
dropout: 0.1
dec_dropout: 0.0
n_epochs: 200
lr: 1e-3
early_stop: 20
batch_size: 5000
metric: ic
base_model: GRU
gamma: 0.1
gamma_clip: 0.2
optimizer: adam
mu: 0.2
GPU: 0
dataset:
class: DatasetH
module_path: qlib.data.dataset
kwargs:
handler:
class: Alpha360
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:
train: [2008-01-01, 2014-12-31]
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs:
model: <MODEL>
dataset: <DATASET>
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
ana_long_short: False
ann_scaler: 252
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config

View File

@@ -36,8 +36,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:

View File

@@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:
@@ -86,4 +87,4 @@ task:
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config
config: *port_analysis_config

View File

@@ -14,8 +14,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:

View File

@@ -21,8 +21,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:

View File

@@ -14,8 +14,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:

View File

@@ -21,8 +21,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:
@@ -100,4 +101,4 @@ task:
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config
config: *port_analysis_config

View File

@@ -35,8 +35,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:
@@ -94,4 +95,4 @@ task:
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config
config: *port_analysis_config

View File

@@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:

View File

@@ -0,0 +1,2 @@
# Gated Recurrent Unit (GRU)
* Paper: [Learning Phrase Representations using RNN EncoderDecoder for Statistical Machine Translation](https://aclanthology.org/D14-1179.pdf).

View File

@@ -36,8 +36,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:

View File

@@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:
@@ -85,4 +86,4 @@ task:
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config
config: *port_analysis_config

View File

@@ -0,0 +1,2 @@
# Long Short-Term Memory (LSTM)
* Paper: [Long Short-Term Memory](https://direct.mit.edu/neco/article-abstract/9/8/1735/6109/Long-Short-Term-Memory?redirectedFrom=fulltext).

View File

@@ -36,8 +36,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:

View File

@@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:
@@ -85,4 +86,4 @@ task:
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config
config: *port_analysis_config

View File

@@ -14,7 +14,7 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
model: <MODEL>
dataset: <DATASET>
topk: 50
n_drop: 5

View File

@@ -33,6 +33,9 @@ port_analysis_config: &port_analysis_config
kwargs:
topk: 50
n_drop: 5
signal:
- <MODEL>
- <DATASET>
backtest:
verbose: False
limit_threshold: 0.095
@@ -80,4 +83,4 @@ task:
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config
config: *port_analysis_config

View File

@@ -21,8 +21,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:
@@ -76,4 +77,4 @@ task:
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config
config: *port_analysis_config

View File

@@ -29,8 +29,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:

View File

@@ -31,18 +31,22 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:
limit_threshold: 0.095
start_time: 2017-01-01
end_time: 2020-08-01
account: 100000000
benchmark: *benchmark
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
exchange_kwargs:
limit_threshold: 0.095
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
task:
model:
class: LGBModel

View File

@@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:

View File

@@ -0,0 +1 @@
# Localformer

View File

@@ -36,8 +36,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:

View File

@@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:

View File

@@ -0,0 +1 @@
# Multi-Layer Perceptron (MLP)

View File

@@ -41,8 +41,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:
@@ -98,4 +99,4 @@ task:
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config
config: *port_analysis_config

View File

@@ -29,8 +29,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:
@@ -85,4 +86,4 @@ task:
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config
config: *port_analysis_config

View File

@@ -1,51 +1,68 @@
# Benchmarks Performance
This page lists a batch of methods designed for alpha seeking. Each method tries to give scores/predictions for all stocks each day(e.g. forecasting the future excess return of stocks). The scores/predictions of the models will be used as the mined alpha. Investing in stocks with higher scores is expected to yield more profit.
The alpha is evaluated in two ways.
1. The correlation between the alpha and future return.
1. Constructing portfolio based on the alpha and evaluating the final total return.
Here are the results of each benchmark model running on Qlib's `Alpha360` and `Alpha158` dataset with China's A shared-stock & CSI300 data respectively. The values of each metric are the mean and std calculated based on 20 runs with different random seeds.
The numbers shown below demonstrate the performance of the entire `workflow` of each model. We will update the `workflow` as well as models in the near future for better results.
<!--
> If you need to reproduce the results below, please use the **v1** dataset: `python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/qlib_cn_1d --region cn --version v1`
>
> In the new version of qlib, the default dataset is **v2**. Since the data is collected from the YahooFinance API (which is not very stable), the results of *v2* and *v1* may differ
> In the new version of qlib, the default dataset is **v2**. Since the data is collected from the YahooFinance API (which is not very stable), the results of *v2* and *v1* may differ -->
> NOTE:
> The backtest start from 0.8.0 is quite different from previous version. Please check out the changelog for the difference.
## Alpha360 dataset
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |
|---|---|---|---|---|---|---|---|---|
| Linear | Alpha360 | 0.0150±0.00 | 0.1049±0.00| 0.0284±0.00 | 0.1970±0.00 | -0.0659±0.00 | -0.7072±0.00| -0.2955±0.00 |
| CatBoost (Liudmila Prokhorenkova, et al.) | Alpha360 | 0.0397±0.00 | 0.2878±0.00| 0.0470±0.00 | 0.3703±0.00 | 0.0342±0.00 | 0.4092±0.00| -0.1057±0.00 |
| XGBoost (Tianqi Chen, et al.) | Alpha360 | 0.0400±0.00 | 0.3031±0.00| 0.0461±0.00 | 0.3862±0.00 | 0.0528±0.00 | 0.6307±0.00| -0.1113±0.00 |
| LightGBM (Guolin Ke, et al.) | Alpha360 | 0.0399±0.00 | 0.3075±0.00| 0.0492±0.00 | 0.4019±0.00 | 0.0323±0.00 | 0.4370±0.00| -0.0917±0.00 |
| MLP | Alpha360 | 0.0285±0.00 | 0.1981±0.02| 0.0402±0.00 | 0.2993±0.02 | 0.0073±0.02 | 0.0880±0.22| -0.1446±0.03 |
| GRU (Kyunghyun Cho, et al.) | Alpha360 | 0.0490±0.01 | 0.3787±0.05| 0.0581±0.00 | 0.4664±0.04 | 0.0726±0.02 | 0.9817±0.34| -0.0902±0.03 |
| LSTM (Sepp Hochreiter, et al.) | Alpha360 | 0.0443±0.01 | 0.3401±0.05| 0.0536±0.01 | 0.4248±0.05 | 0.0627±0.03 | 0.8441±0.48| -0.0882±0.03 |
| ALSTM (Yao Qin, et al.) | Alpha360 | 0.0493±0.01 | 0.3778±0.06| 0.0585±0.00 | 0.4606±0.04 | 0.0513±0.03 | 0.6727±0.38| -0.1085±0.02 |
| GATs (Petar Velickovic, et al.) | Alpha360 | 0.0475±0.00 | 0.3515±0.02| 0.0592±0.00 | 0.4585±0.01 | 0.0876±0.02 | 1.1513±0.27| -0.0795±0.02 |
| DoubleEnsemble (Chuheng Zhang, et al.) | Alpha360 | 0.0407±0.00| 0.3053±0.00 | 0.0490±0.00 | 0.3840±0.00 | 0.0380±0.02 | 0.5000±0.21 | -0.0984±0.02 |
| TabNet (Sercan O. Arik, et al.)| Alpha360 | 0.0192±0.00 | 0.1401±0.00| 0.0291±0.00 | 0.2163±0.00 | -0.0258±0.00 | -0.2961±0.00| -0.1429±0.00 |
| TCTS (Xueqing Wu, et al.)| Alpha360 | 0.0485±0.00 | 0.3689±0.04| 0.0586±0.00 | 0.4669±0.02 | 0.0816±0.02 | 1.1572±0.30| -0.0689±0.02 |
| Transformer (Ashish Vaswani, et al.)| Alpha360 | 0.0141±0.00 | 0.0917±0.02| 0.0331±0.00 | 0.2357±0.03 | -0.0259±0.03 | -0.3323±0.43| -0.1763±0.07 |
| Localformer (Juyong Jiang, et al.)| Alpha360 | 0.0408±0.00 | 0.2988±0.03| 0.0538±0.00 | 0.4105±0.02 | 0.0275±0.03 | 0.3464±0.37| -0.1182±0.03 |
| TRA (Hengxu Lin, et al.)| Alpha360 | 0.0491±0.01 | 0.3868±0.06 | 0.0589±0.00 | 0.4802±0.04 | 0.0898±0.02 | 1.2490±0.32 | -0.0778±0.02 |
## Alpha158 dataset
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |
|---|---|---|---|---|---|---|---|---|
| Linear | Alpha158 | 0.0393±0.00 | 0.2980±0.00| 0.0475±0.00 | 0.3546±0.00 | 0.0795±0.00 | 1.0712±0.00| -0.1449±0.00 |
| CatBoost (Liudmila Prokhorenkova, et al.) | Alpha158 | 0.0503±0.00 | 0.3586±0.00| 0.0483±0.00 | 0.3667±0.00 | 0.1080±0.00 | 1.1561±0.00| -0.0787±0.00 |
| XGBoost (Tianqi Chen, et al.) | Alpha158 | 0.0481±0.00 | 0.3659±0.00| 0.0495±0.00 | 0.4033±0.00 | 0.1111±0.00 | 1.2915±0.00| -0.0893±0.00 |
| LightGBM (Guolin Ke, et al.) | Alpha158 | 0.0475±0.00 | 0.3979±0.00| 0.0485±0.00 | 0.4123±0.00 | 0.1143±0.00 | 1.2744±0.00| -0.0800±0.00 |
| MLP | Alpha158 | 0.0358±0.00 | 0.2738±0.03| 0.0425±0.00 | 0.3221±0.01 | 0.0836±0.02 | 1.0323±0.25| -0.1127±0.02 |
| TFT (Bryan Lim, et al.) | Alpha158 (with selected 20 features) | 0.0343±0.00 | 0.2071±0.02| 0.0107±0.00 | 0.0660±0.02 | 0.0623±0.02 | 0.5818±0.20| -0.1762±0.01 |
| GRU (Kyunghyun Cho, et al.) | Alpha158 (with selected 20 features) | 0.0311±0.00 | 0.2418±0.04| 0.0425±0.00 | 0.3434±0.02 | 0.0330±0.02 | 0.4805±0.30| -0.1021±0.02 |
| LSTM (Sepp Hochreiter, et al.) | Alpha158 (with selected 20 features) | 0.0312±0.00 | 0.2394±0.04| 0.0418±0.00 | 0.3324±0.03 | 0.0298±0.02 | 0.4198±0.33| -0.1348±0.03 |
| ALSTM (Yao Qin, et al.) | Alpha158 (with selected 20 features) | 0.0385±0.01 | 0.3022±0.06| 0.0478±0.00 | 0.3874±0.04 | 0.0486±0.03 | 0.7141±0.45| -0.1088±0.03 |
| GATs (Petar Velickovic, et al.) | Alpha158 (with selected 20 features) | 0.0349±0.00 | 0.2511±0.01| 0.0457±0.00 | 0.3537±0.01 | 0.0578±0.02 | 0.8221±0.25| -0.0824±0.02 |
| DoubleEnsemble (Chuheng Zhang, et al.) | Alpha158 | 0.0544±0.00 | 0.4338±0.01 | 0.0523±0.00 | 0.4257±0.01 | 0.1253±0.01 | 1.4105±0.14 | -0.0902±0.01 |
| TabNet (Sercan O. Arik, et al.)| Alpha158 | 0.0383±0.00 | 0.3414±0.00| 0.0388±0.00 | 0.3460±0.00 | 0.0226±0.00 | 0.2652±0.00| -0.1072±0.00 |
| Transformer (Ashish Vaswani, et al.)| Alpha158 | 0.0274±0.00 | 0.2166±0.04| 0.0409±0.00 | 0.3342±0.04 | 0.0204±0.03 | 0.2888±0.40| -0.1216±0.04 |
| Localformer (Juyong Jiang, et al.)| Alpha158 | 0.0355±0.00 | 0.2747±0.04| 0.0466±0.00 | 0.3762±0.03 | 0.0506±0.02 | 0.7447±0.34| -0.0875±0.02 |
| TRA (Hengxu Lin, et al.)| Alpha158 (with selected 20 features)| 0.0409±0.00 | 0.3253±0.04 | 0.0488±0.00 | 0.4045±0.02 | 0.0673±0.02 | 1.0389±0.39 | -0.0830±0.02 |
| TRA (Hengxu Lin, et al.)| Alpha158 | 0.0442±0.00 | 0.3426±0.03 | 0.0555±0.00 | 0.4395±0.03 | 0.0833±0.03 | 1.2064±0.36 | -0.0849±0.02 |
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |
|------------------------------------------|-------------------------------------|-------------|-------------|-------------|-------------|-------------------|-------------------|--------------|
| TCN(Shaojie Bai, et al.) | Alpha158 | 0.0275±0.00 | 0.2157±0.01 | 0.0411±0.00 | 0.3379±0.01 | 0.0190±0.02 | 0.2887±0.27 | -0.1202±0.03 |
| TabNet(Sercan O. Arik, et al.) | Alpha158 | 0.0204±0.01 | 0.1554±0.07 | 0.0333±0.00 | 0.2552±0.05 | 0.0227±0.04 | 0.3676±0.54 | -0.1089±0.08 |
| Transformer(Ashish Vaswani, et al.) | Alpha158 | 0.0264±0.00 | 0.2053±0.02 | 0.0407±0.00 | 0.3273±0.02 | 0.0273±0.02 | 0.3970±0.26 | -0.1101±0.02 |
| GRU(Kyunghyun Cho, et al.) | Alpha158(with selected 20 features) | 0.0315±0.00 | 0.2450±0.04 | 0.0428±0.00 | 0.3440±0.03 | 0.0344±0.02 | 0.5160±0.25 | -0.1017±0.02 |
| LSTM(Sepp Hochreiter, et al.) | Alpha158(with selected 20 features) | 0.0318±0.00 | 0.2367±0.04 | 0.0435±0.00 | 0.3389±0.03 | 0.0381±0.03 | 0.5561±0.46 | -0.1207±0.04 |
| Localformer(Juyong Jiang, et al.) | Alpha158 | 0.0356±0.00 | 0.2756±0.03 | 0.0468±0.00 | 0.3784±0.03 | 0.0438±0.02 | 0.6600±0.33 | -0.0952±0.02 |
| SFM(Liheng Zhang, et al.) | Alpha158 | 0.0379±0.00 | 0.2959±0.04 | 0.0464±0.00 | 0.3825±0.04 | 0.0465±0.02 | 0.5672±0.29 | -0.1282±0.03 |
| ALSTM (Yao Qin, et al.) | Alpha158(with selected 20 features) | 0.0362±0.01 | 0.2789±0.06 | 0.0463±0.01 | 0.3661±0.05 | 0.0470±0.03 | 0.6992±0.47 | -0.1072±0.03 |
| GATs (Petar Velickovic, et al.) | Alpha158(with selected 20 features) | 0.0349±0.00 | 0.2511±0.01 | 0.0462±0.00 | 0.3564±0.01 | 0.0497±0.01 | 0.7338±0.19 | -0.0777±0.02 |
| TRA(Hengxu Lin, et al.) | Alpha158(with selected 20 features) | 0.0404±0.00 | 0.3197±0.05 | 0.0490±0.00 | 0.4047±0.04 | 0.0649±0.02 | 1.0091±0.30 | -0.0860±0.02 |
| Linear | Alpha158 | 0.0397±0.00 | 0.3000±0.00 | 0.0472±0.00 | 0.3531±0.00 | 0.0692±0.00 | 0.9209±0.00 | -0.1509±0.00 |
| TRA(Hengxu Lin, et al.) | Alpha158 | 0.0440±0.00 | 0.3535±0.05 | 0.0540±0.00 | 0.4451±0.03 | 0.0718±0.02 | 1.0835±0.35 | -0.0760±0.02 |
| CatBoost(Liudmila Prokhorenkova, et al.) | Alpha158 | 0.0481±0.00 | 0.3366±0.00 | 0.0454±0.00 | 0.3311±0.00 | 0.0765±0.00 | 0.8032±0.01 | -0.1092±0.00 |
| XGBoost(Tianqi Chen, et al.) | Alpha158 | 0.0498±0.00 | 0.3779±0.00 | 0.0505±0.00 | 0.4131±0.00 | 0.0780±0.00 | 0.9070±0.00 | -0.1168±0.00 |
| TFT (Bryan Lim, et al.) | Alpha158(with selected 20 features) | 0.0358±0.00 | 0.2160±0.03 | 0.0116±0.01 | 0.0720±0.03 | 0.0847±0.02 | 0.8131±0.19 | -0.1824±0.03 |
| MLP | Alpha158 | 0.0376±0.00 | 0.2846±0.02 | 0.0429±0.00 | 0.3220±0.01 | 0.0895±0.02 | 1.1408±0.23 | -0.1103±0.02 |
| LightGBM(Guolin Ke, et al.) | Alpha158 | 0.0448±0.00 | 0.3660±0.00 | 0.0469±0.00 | 0.3877±0.00 | 0.0901±0.00 | 1.0164±0.00 | -0.1038±0.00 |
| DoubleEnsemble(Chuheng Zhang, et al.) | Alpha158 | 0.0544±0.00 | 0.4340±0.00 | 0.0523±0.00 | 0.4284±0.01 | 0.1168±0.01 | 1.3384±0.12 | -0.1036±0.01 |
## Alpha360 dataset
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |
|-------------------------------------------|----------|-------------|-------------|-------------|-------------|-------------------|-------------------|--------------|
| Transformer(Ashish Vaswani, et al.) | Alpha360 | 0.0114±0.00 | 0.0716±0.03 | 0.0327±0.00 | 0.2248±0.02 | -0.0270±0.03 | -0.3378±0.37 | -0.1653±0.05 |
| TabNet(Sercan O. Arik, et al.) | Alpha360 | 0.0099±0.00 | 0.0593±0.00 | 0.0290±0.00 | 0.1887±0.00 | -0.0369±0.00 | -0.3892±0.00 | -0.2145±0.00 |
| MLP | Alpha360 | 0.0273±0.00 | 0.1870±0.02 | 0.0396±0.00 | 0.2910±0.02 | 0.0029±0.02 | 0.0274±0.23 | -0.1385±0.03 |
| Localformer(Juyong Jiang, et al.) | Alpha360 | 0.0404±0.00 | 0.2932±0.04 | 0.0542±0.00 | 0.4110±0.03 | 0.0246±0.02 | 0.3211±0.21 | -0.1095±0.02 |
| CatBoost((Liudmila Prokhorenkova, et al.) | Alpha360 | 0.0378±0.00 | 0.2714±0.00 | 0.0467±0.00 | 0.3659±0.00 | 0.0292±0.00 | 0.3781±0.00 | -0.0862±0.00 |
| XGBoost(Tianqi Chen, et al.) | Alpha360 | 0.0394±0.00 | 0.2909±0.00 | 0.0448±0.00 | 0.3679±0.00 | 0.0344±0.00 | 0.4527±0.02 | -0.1004±0.00 |
| DoubleEnsemble(Chuheng Zhang, et al.) | Alpha360 | 0.0404±0.00 | 0.3023±0.00 | 0.0495±0.00 | 0.3898±0.00 | 0.0468±0.01 | 0.6302±0.20 | -0.0860±0.01 |
| LightGBM(Guolin Ke, et al.) | Alpha360 | 0.0400±0.00 | 0.3037±0.00 | 0.0499±0.00 | 0.4042±0.00 | 0.0558±0.00 | 0.7632±0.00 | -0.0659±0.00 |
| TCN(Shaojie Bai, et al.) | Alpha360 | 0.0441±0.00 | 0.3301±0.02 | 0.0519±0.00 | 0.4130±0.01 | 0.0604±0.02 | 0.8295±0.34 | -0.1018±0.03 |
| ALSTM (Yao Qin, et al.) | Alpha360 | 0.0497±0.00 | 0.3829±0.04 | 0.0599±0.00 | 0.4736±0.03 | 0.0626±0.02 | 0.8651±0.31 | -0.0994±0.03 |
| LSTM(Sepp Hochreiter, et al.) | Alpha360 | 0.0448±0.00 | 0.3474±0.04 | 0.0549±0.00 | 0.4366±0.03 | 0.0647±0.03 | 0.8963±0.39 | -0.0875±0.02 |
| ADD | Alpha360 | 0.0430±0.00 | 0.3188±0.04 | 0.0559±0.00 | 0.4301±0.03 | 0.0667±0.02 | 0.8992±0.34 | -0.0855±0.02 |
| GRU(Kyunghyun Cho, et al.) | Alpha360 | 0.0493±0.00 | 0.3772±0.04 | 0.0584±0.00 | 0.4638±0.03 | 0.0720±0.02 | 0.9730±0.33 | -0.0821±0.02 |
| AdaRNN(Yuntao Du, et al.) | Alpha360 | 0.0464±0.01 | 0.3619±0.08 | 0.0539±0.01 | 0.4287±0.06 | 0.0753±0.03 | 1.0200±0.40 | -0.0936±0.03 |
| GATs (Petar Velickovic, et al.) | Alpha360 | 0.0476±0.00 | 0.3508±0.02 | 0.0598±0.00 | 0.4604±0.01 | 0.0824±0.02 | 1.1079±0.26 | -0.0894±0.03 |
| TCTS(Xueqing Wu, et al.) | Alpha360 | 0.0508±0.00 | 0.3931±0.04 | 0.0599±0.00 | 0.4756±0.03 | 0.0893±0.03 | 1.2256±0.36 | -0.0857±0.02 |
| TRA(Hengxu Lin, et al.) | Alpha360 | 0.0485±0.00 | 0.3787±0.03 | 0.0587±0.00 | 0.4756±0.03 | 0.0920±0.03 | 1.2789±0.42 | -0.0834±0.02 |
- The selected 20 features are based on the feature importance of a lightgbm-based model.
- The base model of DoubleEnsemble is LGBM.
- The base model of TCTS is GRU.

View File

@@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:

View File

@@ -0,0 +1,4 @@
# TCN
* Code: [https://github.com/locuslab/TCN](https://github.com/locuslab/TCN)
* Paper: [An Empirical Evaluation of Generic Convolutional and Recurrent Networks for Sequence Modeling](https://arxiv.org/abs/1803.01271).

View File

@@ -0,0 +1,4 @@
numpy==1.17.4
pandas==1.1.2
scikit_learn==0.23.2
torch==1.7.0

View File

@@ -0,0 +1,100 @@
qlib_init:
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
market: &market csi300
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
start_time: 2008-01-01
end_time: 2020-08-01
fit_start_time: 2008-01-01
fit_end_time: 2014-12-31
instruments: *market
infer_processors:
- class: FilterCol
kwargs:
fields_group: feature
col_list: ["RESI5", "WVMA5", "RSQR5", "KLEN", "RSQR10", "CORR5", "CORD5", "CORR10",
"ROC60", "RESI10", "VSTD5", "RSQR60", "CORR60", "WVMA60", "STD5",
"RSQR20", "CORD60", "CORD10", "CORR20", "KLOW"
]
- class: RobustZScoreNorm
kwargs:
fields_group: feature
clip_outlier: true
- class: Fillna
kwargs:
fields_group: feature
learn_processors:
- class: DropnaLabel
- class: CSRankNorm
kwargs:
fields_group: label
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
topk: 50
n_drop: 5
backtest:
start_time: 2017-01-01
end_time: 2020-08-01
account: 100000000
benchmark: *benchmark
exchange_kwargs:
limit_threshold: 0.095
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
task:
model:
class: TCN
module_path: qlib.contrib.model.pytorch_tcn_ts
kwargs:
d_feat: 20
num_layers: 5
n_chans: 32
kernel_size: 7
dropout: 0.5
n_epochs: 200
lr: 1e-4
early_stop: 20
batch_size: 2000
metric: loss
loss: mse
optimizer: adam
n_jobs: 20
GPU: 0
dataset:
class: TSDatasetH
module_path: qlib.data.dataset
kwargs:
handler:
class: Alpha158
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:
train: [2008-01-01, 2014-12-31]
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
step_len: 20
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs:
model: <MODEL>
dataset: <DATASET>
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
ana_long_short: False
ann_scaler: 252
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config

View File

@@ -0,0 +1,90 @@
qlib_init:
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
market: &market csi300
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
start_time: 2008-01-01
end_time: 2020-08-01
fit_start_time: 2008-01-01
fit_end_time: 2014-12-31
instruments: *market
infer_processors:
- class: RobustZScoreNorm
kwargs:
fields_group: feature
clip_outlier: true
- class: Fillna
kwargs:
fields_group: feature
learn_processors:
- class: DropnaLabel
- class: CSRankNorm
kwargs:
fields_group: label
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
topk: 50
n_drop: 5
backtest:
start_time: 2017-01-01
end_time: 2020-08-01
account: 100000000
benchmark: *benchmark
exchange_kwargs:
limit_threshold: 0.095
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
task:
model:
class: TCN
module_path: qlib.contrib.model.pytorch_tcn
kwargs:
d_feat: 6
num_layers: 5
n_chans: 128
kernel_size: 3
dropout: 0.5
n_epochs: 200
lr: 1e-3
early_stop: 20
batch_size: 2000
metric: loss
loss: mse
optimizer: adam
GPU: 0
dataset:
class: DatasetH
module_path: qlib.data.dataset
kwargs:
handler:
class: Alpha360
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:
train: [2008-01-01, 2014-12-31]
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs:
model: <MODEL>
dataset: <DATASET>
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
ana_long_short: False
ann_scaler: 252
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config

View File

@@ -1,52 +1,38 @@
# Temporally Correlated Task Scheduling for Sequence Learning
We provide the [code](https://github.com/microsoft/qlib/blob/main/qlib/contrib/model/pytorch_tcts.py) for reproducing the stock trend forecasting experiments.
### Background
Sequence learning has attracted much research attention from the machine learning community in recent years. In many applications, a sequence learning task is usually associated with multiple temporally correlated auxiliary tasks, which are different in terms of how much input information to use or which future step to predict. In stock trend forecasting, as demonstrated in Figure1, one can predict the price of a stock in different future days (e.g., tomorrow, the day after tomorrow). In this paper, we propose a framework to make use of those temporally correlated tasks to help each other.
<p align="center">
<img src="task_description.png" width="600" height="200"/>
</p>
### Method
Given that there are usually multiple temporally correlated tasks, the key challenge lies in which tasks to use and when to use them in the training process. In this work, we introduce a learnable task scheduler for sequence learning, which adaptively selects temporally correlated tasks during the training process. The scheduler accesses the model status and the current training data (e.g., in current minibatch), and selects the best auxiliary task to help the training of the main task. The scheduler and the model for the main task are jointly trained through bi-level optimization: the scheduler is trained to maximize the validation performance of the model, and the model is trained to minimize the training loss guided by the scheduler. The process is demonstrated in Figure2.
Given that there are usually multiple temporally correlated tasks, the key challenge lies in which tasks to use and when to use them in the training process. This work introduces a learnable task scheduler for sequence learning, which adaptively selects temporally correlated tasks during the training process. The scheduler accesses the model status and the current training data (e.g., in the current minibatch) and selects the best auxiliary task to help the training of the main task. The scheduler and the model for the main task are jointly trained through bi-level optimization: the scheduler is trained to maximize the validation performance of the model, and the model is trained to minimize the training loss guided by the scheduler. The process is demonstrated in Figure2.
<p align="center">
<img src="workflow.png"/>
</p>
At step <img src="https://render.githubusercontent.com/render/math?math=s">, with training data <img src="https://render.githubusercontent.com/render/math?math=x_s,y_s">, the scheduler <img src="https://render.githubusercontent.com/render/math?math=\varphi"> chooses a suitable task <img src="https://render.githubusercontent.com/render/math?math=T_{i_s}"> (green solid lines) to update the model <img src="https://render.githubusercontent.com/render/math?math=f"> (blue solid lines). After <img src="https://render.githubusercontent.com/render/math?math=S"> steps, we evaluate the model <img src="https://render.githubusercontent.com/render/math?math=f"> on the validation set and update the scheduler <img src="https://render.githubusercontent.com/render/math?math=\varphi"> (green dashed lines).
### DataSet
* We use the historical transaction data for 300 stocks on [CSI300](http://www.csindex.com.cn/en/indices/index-detail/000300) from 01/01/2008 to 08/01/2020.
* We split the data into training (01/01/2008-12/31/2013), validation (01/01/2014-12/31/2015), and test sets (01/01/2016-08/01/2020) based on the transaction time.
At step <img src="https://latex.codecogs.com/png.latex?s" title="s" />, with training data <img src="https://latex.codecogs.com/png.latex?x_s,y_s" title="x_s,y_s" />, the scheduler <img src="https://latex.codecogs.com/png.latex?\varphi" title="\varphi" /> chooses a suitable task <img src="https://latex.codecogs.com/png.latex?T_{i_s}" title="T_{i_s}" /> (green solid lines) to update the model <img src="https://latex.codecogs.com/png.latex?f" title="f" /> (blue solid lines). After <img src="https://latex.codecogs.com/png.latex?S" title="S" /> steps, we evaluate the model <img src="https://latex.codecogs.com/png.latex?f" title="f" /> on the validation set and update the scheduler <img src="https://latex.codecogs.com/png.latex?\varphi" title="\varphi" /> (green dashed lines).
### Experiments
#### Task Description
* The main tasks <img src="https://render.githubusercontent.com/render/math?math=T_k"> (<img src="https://render.githubusercontent.com/render/math?math=task_k"> in Figure1) refers to forecasting return of stock <img src="https://render.githubusercontent.com/render/math?math=i"> as following,
Due to different data versions and different Qlib versions, the original data and data preprocessing methods of the experimental settings in the paper are different from those experimental settings in the existing Qlib version. Therefore, we provide two versions of the code according to the two kinds of settings, 1) the [code](https://github.com/lwwang1995/tcts) that can be used to reproduce the experimental results and 2) the [code](https://github.com/microsoft/qlib/blob/main/qlib/contrib/model/pytorch_tcts.py) in the current Qlib baseline.
#### Setting1
* Dataset: We use the historical transaction data for 300 stocks on [CSI300](http://www.csindex.com.cn/en/indices/index-detail/000300) from 01/01/2008 to 08/01/2020. We split the data into training (01/01/2008-12/31/2013), validation (01/01/2014-12/31/2015), and test sets (01/01/2016-08/01/2020) based on the transaction time.
* The main tasks <img src="https://latex.codecogs.com/png.latex?T_k" title="T_k" /> refers to forecasting return of stock <img src="https://latex.codecogs.com/png.latex?i" title="i" /> as following,
<div align=center>
<img src="https://render.githubusercontent.com/render/math?math=r_{i}^k = \frac{\price_i^{t+k}}{\price_i^{t+k-1}} - 1">
<img src="https://latex.codecogs.com/png.image?\dpi{110}&space;r_{i}^{t,k}&space;=&space;\frac{price_i^{t&plus;k}}{price_i^{t&plus;k-1}}-1" title="r_{i}^{t,k} = \frac{price_i^{t+k}}{price_i^{t+k-1}}-1" />
</div>
* Temporally correlated task sets <img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_k = \{T_1, T_2, ... , T_k\}">, in this paper, <img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_3">, <img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_5"> and <img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_10"> are used.
#### Baselines
* GRU/MLP/LightGBM (LGB)/Graph Attention Networks (GAT)
* Multi-task learning (MTL): In multi-task learning, multiple tasks are jointly trained and mutually boosted. Each task is treated equally, while in our setting, we focus on the main task.
* Curriculum transfer learning (CL): Transfer learning also leverages auxiliary tasks to boost the main task. [Curriculum transfer learning](https://arxiv.org/pdf/1804.00810.pdf) is one kind of transfer learning which schedules auxiliary tasks according to certain rules. Our problem can also be regarded as a special kind of transfer learning, where the auxiliary tasks are temporally correlated with the main task. Our learning process is dynamically controlled by a scheduler rather than some pre-defined rules. In the CL baseline, we start from the task <img src="https://render.githubusercontent.com/render/math?math=T_1" >, then <img src="https://render.githubusercontent.com/render/math?math=T_2" >, and gradually move to the last one.
#### Result
| Methods | <img src="https://render.githubusercontent.com/render/math?math=T_1" > | <img src="https://render.githubusercontent.com/render/math?math=T_2"> | <img src="https://render.githubusercontent.com/render/math?math=T_3"> |
| :----: | :----: | :----: | :----: |
| GRU | 0.049 / 1.903 | 0.018 / 1.972 | 0.014 / 1.989 |
| MLP | 0.023 / 1.961 | 0.022 / 1.962 | 0.015 / 1.978 |
| LGB | 0.038 / 1.883 | 0.023 / 1.952 | 0.007 / 1.987 |
| GAT | 0.052 / 1.898 | 0.024 / 1.954 | 0.015 / 1.973 |
| MTL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_3">) | 0.061 / 1.862 | 0.023 / 1.942 | 0.012 / 1.956 |
| CL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_3">) | 0.051 / 1.880 | 0.028 / 1.941 | 0.016 / 1.962 |
| Ours(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_3">) | 0.071 / 1.851 | 0.030 / 1.939 | 0.017 / 1.963 |
| MTL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_5">) | 0.057 / 1.875 | 0.021 / 1.939 | 0.017 / 1.959 |
| CL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_5">) | 0.056 / 1.877 | 0.028 / 1.942 | 0.015 / 1.962 |
| Ours(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_5">) | 0.075 / 1.849 | 0.032 /1.939 | 0.021 / 1.955 |
| MTL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_{10}">) | 0.052 / 1.882 | 0.020 / 1.947 | 0.019 / 1.952 |
| CL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_{10}">) | 0.051 / 1.882 | 0.028 / 1.950 | 0.016 / 1.961 |
| Ours(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_{10}">) | 0.067 / 1.867 | 0.030 / 1.960 | 0.022 / 1.942|
* Temporally correlated task sets <img src="https://latex.codecogs.com/png.latex?\mathcal{T}_k&space;=&space;\{T_1,&space;T_2,&space;...&space;,&space;T_k\}" title="\mathcal{T}_k = \{T_1, T_2, ... , T_k\}" />, in this paper, <img src="https://latex.codecogs.com/png.latex?\mathcal{T}_3" title="\mathcal{T}_3" />, <img src="https://latex.codecogs.com/png.latex?\mathcal{T}_5" title="\mathcal{T}_5" /> and <img src="https://latex.codecogs.com/png.latex?\mathcal{T}_{10}" title="\mathcal{T}_{10}" /> are used in <img src="https://latex.codecogs.com/png.latex?T_1" title="T_1" />, <img src="https://latex.codecogs.com/png.latex?T_2" title="T_2" />, and <img src="https://latex.codecogs.com/png.latex?T_3" title="T_3" />.
#### Setting2
* Dataset: We use the historical transaction data for 300 stocks on [CSI300](http://www.csindex.com.cn/en/indices/index-detail/000300) from 01/01/2008 to 08/01/2020. We split the data into training (01/01/2008-12/31/2014), validation (01/01/2015-12/31/2016), and test sets (01/01/2017-08/01/2020) based on the transaction time.
* The main tasks <img src="https://latex.codecogs.com/png.latex?T_k" title="T_k" /> refers to forecasting return of stock <img src="https://latex.codecogs.com/png.latex?i" title="i" /> as following,
<div align=center>
<img src="https://latex.codecogs.com/png.image?\dpi{110}&space;r_{i}^{t,k}&space;=&space;\frac{price_i^{t&plus;1&plus;k}}{price_i^{t&plus;1}}-1" title="r_{i}^{t,k} = \frac{price_i^{t+1+k}}{price_i^{t+1}}-1" />
</div>
* In Qlib baseline, <img src="https://latex.codecogs.com/png.latex?\mathcal{T}_3" title="\mathcal{T}_3" />, is used in <img src="https://latex.codecogs.com/png.latex?T_1" title="T_1" />.
### Experimental Result
You can find the experimental result of setting1 in the [paper](http://proceedings.mlr.press/v139/wu21e/wu21e.pdf) and the experimental result of setting2 in this [page](https://github.com/microsoft/qlib/tree/main/examples/benchmarks).

Binary file not shown.

Before

Width:  |  Height:  |  Size: 25 KiB

View File

@@ -22,16 +22,17 @@ data_handler_config: &data_handler_config
- class: CSRankNorm
kwargs:
fields_group: label
label: ["Ref($close, -1) / $close - 1",
"Ref($close, -2) / Ref($close, -1) - 1",
"Ref($close, -3) / Ref($close, -2) - 1"]
label: ["Ref($close, -2) / Ref($close, -1) - 1",
"Ref($close, -3) / Ref($close, -1) - 1",
"Ref($close, -4) / Ref($close, -1) - 1"]
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:
@@ -53,9 +54,8 @@ task:
d_feat: 6
hidden_size: 64
num_layers: 2
dropout: 0.0
dropout: 0.3
n_epochs: 200
lr: 1e-3
early_stop: 20
batch_size: 800
metric: loss
@@ -64,10 +64,10 @@ task:
fore_optimizer: adam
weight_optimizer: adam
output_dim: 3
fore_lr: 5e-4
weight_lr: 5e-4
fore_lr: 2e-3
weight_lr: 2e-3
steps: 3
target_label: 1
target_label: 0
lowest_valid_performance: 0.993
dataset:
class: DatasetH
@@ -92,8 +92,7 @@ task:
kwargs:
ana_long_short: False
ann_scaler: 252
label_col: 1
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config
config: *port_analysis_config

View File

@@ -195,7 +195,8 @@ class Alpha158Formatter(GenericDataFormatter):
for col in column_names:
if col not in {"forecast_time", "identifier"}:
output[col] = self._target_scaler.inverse_transform(predictions[col])
# Using [col] is for aligning with the format when fitting
output[col] = self._target_scaler.inverse_transform(predictions[[col]])
return output

View File

@@ -311,5 +311,11 @@ class TFTModel(ModelFT):
# self.model.save(path)
# save qlib model wrapper
self.model = None
drop_attrs = ["model", "tf_graph", "sess", "data_formatter"]
orig_attr = {}
for attr in drop_attrs:
orig_attr[attr] = getattr(self, attr)
setattr(self, attr, None)
super(TFTModel, self).to_pickle(path)
for attr in drop_attrs:
setattr(self, attr, orig_attr[attr])

View File

@@ -16,8 +16,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:

View File

@@ -38,7 +38,7 @@ class TRAModel(Model):
model_init_state=None,
lamb=0.0,
rho=0.99,
seed=0,
seed=None,
logdir=None,
eval_train=True,
eval_test=False,

View File

@@ -57,8 +57,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:

View File

@@ -51,8 +51,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:

View File

@@ -51,8 +51,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:

View File

@@ -0,0 +1,3 @@
# TabNet
* Code: [https://github.com/dreamquark-ai/tabnet](https://github.com/dreamquark-ai/tabnet)
* Paper: [TabNet: Attentive Interpretable Tabular Learning](https://arxiv.org/pdf/1908.07442.pdf).

View File

@@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:
@@ -50,6 +51,7 @@ task:
kwargs:
d_feat: 158
pretrain: True
seed: 993
dataset:
class: DatasetH
module_path: qlib.data.dataset

View File

@@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:
@@ -50,6 +51,7 @@ task:
kwargs:
d_feat: 360
pretrain: True
seed: 993
dataset:
class: DatasetH
module_path: qlib.data.dataset

View File

@@ -0,0 +1,3 @@
# Transformer
* Code: [https://github.com/tensorflow/tensor2tensor](https://github.com/tensorflow/tensor2tensor)
* Paper: [Attention is All you Need](https://proceedings.neurips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf).

View File

@@ -36,8 +36,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:

View File

@@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:

View File

@@ -14,8 +14,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:

View File

@@ -21,8 +21,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:

View File

@@ -0,0 +1,2 @@
# Introduction
The examples in this folder try to demonstrate some common usage of data-related modules of Qlib

View File

@@ -0,0 +1,53 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
The motivation of this demo
- To show the data modules of Qlib is Serializable, users can dump processed data to disk to avoid duplicated data preprocessing
"""
from copy import deepcopy
from pathlib import Path
import pickle
from pprint import pprint
import subprocess
import yaml
from qlib.log import TimeInspector
from qlib import init
from qlib.data.dataset.handler import DataHandlerLP
from qlib.utils import init_instance_by_config
# For general purpose, we use relative path
DIRNAME = Path(__file__).absolute().resolve().parent
if __name__ == "__main__":
init()
config_path = DIRNAME.parent / "benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml"
# 1) show original time
with TimeInspector.logt("The original time without handler cache:"):
subprocess.run(f"qrun {config_path}", shell=True)
# 2) dump handler
task_config = yaml.safe_load(config_path.open())
hd_conf = task_config["task"]["dataset"]["kwargs"]["handler"]
pprint(hd_conf)
hd: DataHandlerLP = init_instance_by_config(hd_conf)
hd_path = DIRNAME / "handler.pkl"
hd.to_pickle(hd_path, dump_all=True)
# 3) create new task with handler cache
new_task_config = deepcopy(task_config)
new_task_config["task"]["dataset"]["kwargs"]["handler"] = f"file://{hd_path}"
new_task_config["sys"] = {"path": [str(config_path.parent.resolve())]}
new_task_path = DIRNAME / "new_task.yaml"
print("The location of the new task", new_task_path)
# save new task
with new_task_path.open("w") as f:
yaml.safe_dump(new_task_config, f, indent=4, sort_keys=False)
# 4) train model with new task
with TimeInspector.logt("The time for task with handler cache:"):
subprocess.run(f"qrun {new_task_path}", shell=True)

View File

@@ -0,0 +1,59 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
The motivation of this demo
- To show the data modules of Qlib is Serializable, users can dump processed data to disk to avoid duplicated data preprocessing
"""
from copy import deepcopy
from pathlib import Path
import pickle
from pprint import pprint
import subprocess
import yaml
from qlib import init
from qlib.data.dataset.handler import DataHandlerLP
from qlib.log import TimeInspector
from qlib.model.trainer import task_train
from qlib.utils import init_instance_by_config
# For general purpose, we use relative path
DIRNAME = Path(__file__).absolute().resolve().parent
if __name__ == "__main__":
init()
repeat = 2
exp_name = "data_mem_reuse_demo"
config_path = DIRNAME.parent / "benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml"
task_config = yaml.safe_load(config_path.open())
# 1) without using processed data in memory
with TimeInspector.logt("The original time without reusing processed data in memory:"):
for i in range(repeat):
task_train(task_config["task"], experiment_name=exp_name)
# 2) prepare processed data in memory.
hd_conf = task_config["task"]["dataset"]["kwargs"]["handler"]
pprint(hd_conf)
hd: DataHandlerLP = init_instance_by_config(hd_conf)
# 3) with reusing processed data in memory
new_task = deepcopy(task_config["task"])
new_task["dataset"]["kwargs"]["handler"] = hd
print(new_task)
with TimeInspector.logt("The time with reusing processed data in memory:"):
# this will save the time to reload and process data from disk(in `DataHandlerLP`)
# It still takes a lot of time in the backtest phase
for i in range(repeat):
task_train(new_task, experiment_name=exp_name)
# 4) User can change other parts exclude processed data in memory(handler)
new_task = deepcopy(task_config["task"])
new_task["dataset"]["kwargs"]["segments"]["train"] = ("20100101", "20131231")
with TimeInspector.logt("The time with reusing processed data in memory:"):
task_train(new_task, experiment_name=exp_name)

View File

@@ -30,6 +30,7 @@ Run the example by running the following command:
## Benchmarks Performance
### Signal Test
Here are the results of signal test for benchmark models. We will keep updating benchmark models in future.
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Long precision| Short Precision | Long-Short Average Return | Long-Short Average Sharpe |
|---|---|---|---|---|---|---|---|---|---|
| LightGBM | Alpha158 | 0.3042±0.00 | 1.5372±0.00| 0.3117±0.00 | 1.6258±0.00 | 0.6720±0.00 | 0.6870±0.00 | 0.000769±0.00 | 1.0190±0.00 |
| LightGBM | Alpha158 | 0.0349±0.00 | 0.3805±0.00| 0.0435±0.00 | 0.4724±0.00 | 0.5111±0.00 | 0.5428±0.00 | 0.000074±0.00 | 0.2677±0.00 |

View File

@@ -59,7 +59,7 @@ task:
record:
- class: "SignalRecord"
module_path: "qlib.workflow.record_temp"
kwargs:
kwargs: {}
- class: "HFSignalRecord"
module_path: "qlib.workflow.record_temp"
kwargs: {}

View File

@@ -1,9 +1,105 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
The expect result of `backtest` is following in current version
'The following are analysis results of benchmark return(1day).'
risk
mean 0.000651
std 0.012472
annualized_return 0.154967
information_ratio 0.805422
max_drawdown -0.160445
'The following are analysis results of the excess return without cost(1day).'
risk
mean 0.001258
std 0.007575
annualized_return 0.299303
information_ratio 2.561219
max_drawdown -0.068386
'The following are analysis results of the excess return with cost(1day).'
risk
mean 0.001110
std 0.007575
annualized_return 0.264280
information_ratio 2.261392
max_drawdown -0.071842
[1706497:MainThread](2021-12-07 14:08:30,263) INFO - qlib.workflow - [record_temp.py:441] - Portfolio analysis record 'port_analysis_30minute.
pkl' has been saved as the artifact of the Experiment 2
'The following are analysis results of benchmark return(30minute).'
risk
mean 0.000078
std 0.003646
annualized_return 0.148787
information_ratio 0.935252
max_drawdown -0.142830
('The following are analysis results of the excess return without '
'cost(30minute).')
risk
mean 0.000174
std 0.003343
annualized_return 0.331867
information_ratio 2.275019
max_drawdown -0.074752
'The following are analysis results of the excess return with cost(30minute).'
risk
mean 0.000155
std 0.003343
annualized_return 0.294536
information_ratio 2.018860
max_drawdown -0.075579
[1706497:MainThread](2021-12-07 14:08:30,277) INFO - qlib.workflow - [record_temp.py:441] - Portfolio analysis record 'port_analysis_5minute.p
kl' has been saved as the artifact of the Experiment 2
'The following are analysis results of benchmark return(5minute).'
risk
mean 0.000015
std 0.001460
annualized_return 0.172170
information_ratio 1.103439
max_drawdown -0.144807
'The following are analysis results of the excess return without cost(5minute).'
risk
mean 0.000028
std 0.001412
annualized_return 0.319771
information_ratio 2.119563
max_drawdown -0.077426
'The following are analysis results of the excess return with cost(5minute).'
risk
mean 0.000025
std 0.001412
annualized_return 0.281536
information_ratio 1.866091
max_drawdown -0.078194
[1706497:MainThread](2021-12-07 14:08:30,287) INFO - qlib.workflow - [record_temp.py:466] - Indicator analysis record 'indicator_analysis_1day
.pkl' has been saved as the artifact of the Experiment 2
'The following are analysis results of indicators(1day).'
value
ffr 0.945821
pa 0.000324
pos 0.542882
[1706497:MainThread](2021-12-07 14:08:30,293) INFO - qlib.workflow - [record_temp.py:466] - Indicator analysis record 'indicator_analysis_30mi
nute.pkl' has been saved as the artifact of the Experiment 2
'The following are analysis results of indicators(30minute).'
value
ffr 0.982910
pa 0.000037
pos 0.500806
[1706497:MainThread](2021-12-07 14:08:30,302) INFO - qlib.workflow - [record_temp.py:466] - Indicator analysis record 'indicator_analysis_5min
ute.pkl' has been saved as the artifact of the Experiment 2
'The following are analysis results of indicators(5minute).'
value
ffr 0.991017
pa 0.000000
pos 0.000000
[1706497:MainThread](2021-12-07 14:08:30,627) INFO - qlib.timer - [log.py:113] - Time cost: 0.014s | waiting `async_log` Done
"""
from copy import deepcopy
import qlib
import fire
import pandas as pd
from qlib.config import REG_CN, HIGH_FREQ_CONFIG
from qlib.data import D
from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict
@@ -14,7 +110,6 @@ from qlib.backtest import collect_data
class NestedDecisionExecutionWorkflow:
market = "csi300"
benchmark = "SH000300"
data_handler_config = {
@@ -151,10 +246,9 @@ class NestedDecisionExecutionWorkflow:
self._train_model(model, dataset)
strategy_config = {
"class": "TopkDropoutStrategy",
"module_path": "qlib.contrib.strategy.model_strategy",
"module_path": "qlib.contrib.strategy.signal_strategy",
"kwargs": {
"model": model,
"dataset": dataset,
"signal": (model, dataset),
"topk": 50,
"n_drop": 5,
},
@@ -163,13 +257,10 @@ class NestedDecisionExecutionWorkflow:
self.port_analysis_config["backtest"]["benchmark"] = self.benchmark
with R.start(experiment_name="backtest"):
recorder = R.get_recorder()
par = PortAnaRecord(
recorder,
self.port_analysis_config,
risk_analysis_freq=["day", "30min", "5min"],
indicator_analysis_freq=["day", "30min", "5min"],
indicator_analysis_method="value_weighted",
)
par.generate()
@@ -189,10 +280,9 @@ class NestedDecisionExecutionWorkflow:
backtest_config["benchmark"] = self.benchmark
strategy_config = {
"class": "TopkDropoutStrategy",
"module_path": "qlib.contrib.strategy.model_strategy",
"module_path": "qlib.contrib.strategy.signal_strategy",
"kwargs": {
"model": model,
"dataset": dataset,
"signal": (model, dataset),
"topk": 50,
"n_drop": 5,
},
@@ -201,6 +291,101 @@ class NestedDecisionExecutionWorkflow:
for trade_decision in data_generator:
print(trade_decision)
# the code below are for checking, users don't have to care about it
# The tests can be categorized into 2 types
# 1) comparing same backtest
# - Basic test idea: the shared accumulated value are equal in multiple levels
# - Aligning the profit calculation between multiple levels and single levels.
# 2) comparing different backtest
# - Basic test idea:
# - the daily backtest will be similar as multi-level(the data quality makes this gap samller)
def check_diff_freq(self):
self._init_qlib()
exp = R.get_exp(experiment_name="backtest")
rec = next(iter(exp.list_recorders().values())) # assuming this will get the latest recorder
for check_key in "account", "total_turnover", "total_cost":
check_key = "total_cost"
acc_dict = {}
for freq in ["30minute", "5minute", "1day"]:
acc_dict[freq] = rec.load_object(f"portfolio_analysis/report_normal_{freq}.pkl")[check_key]
acc_df = pd.DataFrame(acc_dict)
acc_resam = acc_df.resample("1d").last().dropna()
assert (acc_resam["30minute"] == acc_resam["1day"]).all()
def backtest_only_daily(self):
"""
This backtest is used for comparing the nested execution and single layer execution
Due to the low quality daily-level and miniute-level data, they are hardly comparable.
So it is used for detecting serious bugs which make the results different greatly.
.. code-block:: shell
[1724971:MainThread](2021-12-07 16:24:31,156) INFO - qlib.workflow - [record_temp.py:441] - Portfolio analysis record 'port_analysis_1day.pkl'
has been saved as the artifact of the Experiment 2
'The following are analysis results of benchmark return(1day).'
risk
mean 0.000651
std 0.012472
annualized_return 0.154967
information_ratio 0.805422
max_drawdown -0.160445
'The following are analysis results of the excess return without cost(1day).'
risk
mean 0.001375
std 0.006103
annualized_return 0.327204
information_ratio 3.475016
max_drawdown -0.024927
'The following are analysis results of the excess return with cost(1day).'
risk
mean 0.001184
std 0.006091
annualized_return 0.281801
information_ratio 2.998749
max_drawdown -0.029568
[1724971:MainThread](2021-12-07 16:24:31,170) INFO - qlib.workflow - [record_temp.py:466] - Indicator analysis record 'indicator_analysis_1day.
pkl' has been saved as the artifact of the Experiment 2
'The following are analysis results of indicators(1day).'
value
ffr 1.0
pa 0.0
pos 0.0
[1724971:MainThread](2021-12-07 16:24:31,188) INFO - qlib.timer - [log.py:113] - Time cost: 0.007s | waiting `async_log` Done
"""
self._init_qlib()
model = init_instance_by_config(self.task["model"])
dataset = init_instance_by_config(self.task["dataset"])
self._train_model(model, dataset)
strategy_config = {
"class": "TopkDropoutStrategy",
"module_path": "qlib.contrib.strategy.signal_strategy",
"kwargs": {
"signal": (model, dataset),
"topk": 50,
"n_drop": 5,
},
}
pa_conf = deepcopy(self.port_analysis_config)
pa_conf["strategy"] = strategy_config
pa_conf["executor"] = {
"class": "SimulatorExecutor",
"module_path": "qlib.backtest.executor",
"kwargs": {
"time_per_step": "day",
"generate_portfolio_metrics": True,
"verbose": True,
},
}
pa_conf["backtest"]["benchmark"] = self.benchmark
with R.start(experiment_name="backtest"):
recorder = R.get_recorder()
par = PortAnaRecord(recorder, pa_conf)
par.generate()
if __name__ == "__main__":
fire.Fire(NestedDecisionExecutionWorkflow)

View File

@@ -151,6 +151,9 @@ def get_all_results(folders) -> dict:
if recorders[recorder_id].status == "FINISHED":
recorder = R.get_recorder(recorder_id=recorder_id, experiment_name=fn)
metrics = recorder.list_metrics()
if "1day.excess_return_with_cost.annualized_return" not in metrics:
print(f"{recorder_id} is skipped due to incomplete result")
continue
result["annualized_return_with_cost"].append(metrics["1day.excess_return_with_cost.annualized_return"])
result["information_ratio_with_cost"].append(metrics["1day.excess_return_with_cost.information_ratio"])
result["max_drawdown_with_cost"].append(metrics["1day.excess_return_with_cost.max_drawdown"])
@@ -200,174 +203,183 @@ def gen_yaml_file_without_seed_kwargs(yaml_path, temp_dir):
return temp_path
# function to run the all the models
@only_allow_defined_args
def run(
times=1,
models=None,
dataset="Alpha360",
exclude=False,
qlib_uri: str = "git+https://github.com/microsoft/qlib#egg=pyqlib",
exp_folder_name: str = "run_all_model_records",
wait_before_rm_env: bool = False,
wait_when_err: bool = False,
):
"""
Please be aware that this function can only work under Linux. MacOS and Windows will be supported in the future.
Any PR to enhance this method is highly welcomed. Besides, this script doesn't support parallel running the same model
for multiple times, and this will be fixed in the future development.
class ModelRunner:
def _init_qlib(self, exp_folder_name):
# init qlib
GetData().qlib_data(exists_skip=True)
qlib.init(
exp_manager={
"class": "MLflowExpManager",
"module_path": "qlib.workflow.expm",
"kwargs": {
"uri": "file:" + str(Path(os.getcwd()).resolve() / exp_folder_name),
"default_exp_name": "Experiment",
},
}
)
Parameters:
-----------
times : int
determines how many times the model should be running.
models : str or list
determines the specific model or list of models to run or exclude.
exclude : boolean
determines whether the model being used is excluded or included.
dataset : str
determines the dataset to be used for each model.
qlib_uri : str
the uri to install qlib with pip
it could be url on the we or local path
exp_folder_name: str
the name of the experiment folder
wait_before_rm_env : bool
wait before remove environment.
wait_when_err : bool
wait when errors raised when executing commands
# function to run the all the models
@only_allow_defined_args
def run(
self,
times=1,
models=None,
dataset="Alpha360",
exclude=False,
qlib_uri: str = "git+https://github.com/microsoft/qlib#egg=pyqlib",
exp_folder_name: str = "run_all_model_records",
wait_before_rm_env: bool = False,
wait_when_err: bool = False,
):
"""
Please be aware that this function can only work under Linux. MacOS and Windows will be supported in the future.
Any PR to enhance this method is highly welcomed. Besides, this script doesn't support parallel running the same model
for multiple times, and this will be fixed in the future development.
Usage:
-------
Here are some use cases of the function in the bash:
Parameters:
-----------
times : int
determines how many times the model should be running.
models : str or list
determines the specific model or list of models to run or exclude.
exclude : boolean
determines whether the model being used is excluded or included.
dataset : str
determines the dataset to be used for each model.
qlib_uri : str
the uri to install qlib with pip
it could be url on the we or local path
exp_folder_name: str
the name of the experiment folder
wait_before_rm_env : bool
wait before remove environment.
wait_when_err : bool
wait when errors raised when executing commands
.. code-block:: bash
Usage:
-------
Here are some use cases of the function in the bash:
# Case 1 - run all models multiple times
python run_all_model.py 3
.. code-block:: bash
# Case 2 - run specific models multiple times
python run_all_model.py 3 mlp
# Case 1 - run all models multiple times
python run_all_model.py run 3
# Case 3 - run specific models multiple times with specific dataset
python run_all_model.py 3 mlp Alpha158
# Case 2 - run specific models multiple times
python run_all_model.py run 3 mlp
# Case 4 - run other models except those are given as arguments for multiple times
python run_all_model.py 3 [mlp,tft,lstm] --exclude=True
# Case 3 - run specific models multiple times with specific dataset
python run_all_model.py run 3 mlp Alpha158
# Case 5 - run specific models for one time
python run_all_model.py --models=[mlp,lightgbm]
# Case 4 - run other models except those are given as arguments for multiple times
python run_all_model.py run 3 [mlp,tft,lstm] --exclude=True
# Case 6 - run other models except those are given as arguments for one time
python run_all_model.py --models=[mlp,tft,sfm] --exclude=True
# Case 5 - run specific models for one time
python run_all_model.py run --models=[mlp,lightgbm]
"""
# init qlib
GetData().qlib_data(exists_skip=True)
qlib.init(
exp_manager={
"class": "MLflowExpManager",
"module_path": "qlib.workflow.expm",
"kwargs": {
"uri": "file:" + str(Path(os.getcwd()).resolve() / exp_folder_name),
"default_exp_name": "Experiment",
},
}
)
# Case 6 - run other models except those are given as arguments for one time
python run_all_model.py run --models=[mlp,tft,sfm] --exclude=True
# get all folders
folders = get_all_folders(models, exclude)
# init error messages:
errors = dict()
# run all the model for iterations
for fn in folders:
# get all files
sys.stderr.write("Retrieving files...\n")
yaml_path, req_path = get_all_files(folders[fn], dataset)
if yaml_path is None:
sys.stderr.write(f"There is no {dataset}.yaml file in {folders[fn]}")
continue
sys.stderr.write("\n")
# create env by anaconda
temp_dir, env_path, python_path, conda_activate = create_env()
"""
self._init_qlib(exp_folder_name)
# install requirements.txt
sys.stderr.write("Installing requirements.txt...\n")
with open(req_path) as f:
content = f.read()
if "torch" in content:
# automatically install pytorch according to nvidia's version
execute(
f"{python_path} -m pip install light-the-torch", wait_when_err=wait_when_err
) # for automatically installing torch according to the nvidia driver
execute(
f"{env_path / 'bin' / 'ltt'} install --install-cmd '{python_path} -m pip install {{packages}}' -- -r {req_path}",
wait_when_err=wait_when_err,
)
else:
execute(f"{python_path} -m pip install -r {req_path}", wait_when_err=wait_when_err)
sys.stderr.write("\n")
# read yaml, remove seed kwargs of model, and then save file in the temp_dir
yaml_path = gen_yaml_file_without_seed_kwargs(yaml_path, temp_dir)
# setup gpu for tft
if fn == "TFT":
execute(
f"conda install -y --prefix {env_path} anaconda cudatoolkit=10.0 && conda install -y --prefix {env_path} cudnn",
wait_when_err=wait_when_err,
)
# get all folders
folders = get_all_folders(models, exclude)
# init error messages:
errors = dict()
# run all the model for iterations
for fn in folders:
# get all files
sys.stderr.write("Retrieving files...\n")
yaml_path, req_path = get_all_files(folders[fn], dataset)
if yaml_path is None:
sys.stderr.write(f"There is no {dataset}.yaml file in {folders[fn]}")
continue
sys.stderr.write("\n")
# install qlib
sys.stderr.write("Installing qlib...\n")
execute(f"{python_path} -m pip install --upgrade pip", wait_when_err=wait_when_err) # TODO: FIX ME!
execute(f"{python_path} -m pip install --upgrade cython", wait_when_err=wait_when_err) # TODO: FIX ME!
if fn == "TFT":
execute(
f"cd {env_path} && {python_path} -m pip install --upgrade --force-reinstall --ignore-installed PyYAML -e {qlib_uri}",
wait_when_err=wait_when_err,
) # TODO: FIX ME!
else:
execute(
f"cd {env_path} && {python_path} -m pip install --upgrade --force-reinstall -e {qlib_uri}",
wait_when_err=wait_when_err,
) # TODO: FIX ME!
sys.stderr.write("\n")
# run workflow_by_config for multiple times
for i in range(times):
sys.stderr.write(f"Running the model: {fn} for iteration {i+1}...\n")
errs = execute(
f"{python_path} {env_path / 'bin' / 'qrun'} {yaml_path} {fn} {exp_folder_name}",
wait_when_err=wait_when_err,
)
if errs is not None:
_errs = errors.get(fn, {})
_errs.update({i: errs})
errors[fn] = _errs
# create env by anaconda
temp_dir, env_path, python_path, conda_activate = create_env()
# install requirements.txt
sys.stderr.write("Installing requirements.txt...\n")
with open(req_path) as f:
content = f.read()
if "torch" in content:
# automatically install pytorch according to nvidia's version
execute(
f"{python_path} -m pip install light-the-torch", wait_when_err=wait_when_err
) # for automatically installing torch according to the nvidia driver
execute(
f"{env_path / 'bin' / 'ltt'} install --install-cmd '{python_path} -m pip install {{packages}}' -- -r {req_path}",
wait_when_err=wait_when_err,
)
else:
execute(f"{python_path} -m pip install -r {req_path}", wait_when_err=wait_when_err)
sys.stderr.write("\n")
# read yaml, remove seed kwargs of model, and then save file in the temp_dir
yaml_path = gen_yaml_file_without_seed_kwargs(yaml_path, temp_dir)
# setup gpu for tft
if fn == "TFT":
execute(
f"conda install -y --prefix {env_path} anaconda cudatoolkit=10.0 && conda install -y --prefix {env_path} cudnn",
wait_when_err=wait_when_err,
)
sys.stderr.write("\n")
# install qlib
sys.stderr.write("Installing qlib...\n")
execute(f"{python_path} -m pip install --upgrade pip", wait_when_err=wait_when_err) # TODO: FIX ME!
execute(f"{python_path} -m pip install --upgrade cython", wait_when_err=wait_when_err) # TODO: FIX ME!
if fn == "TFT":
execute(
f"cd {env_path} && {python_path} -m pip install --upgrade --force-reinstall --ignore-installed PyYAML -e {qlib_uri}",
wait_when_err=wait_when_err,
) # TODO: FIX ME!
else:
execute(
f"cd {env_path} && {python_path} -m pip install --upgrade --force-reinstall -e {qlib_uri}",
wait_when_err=wait_when_err,
) # TODO: FIX ME!
sys.stderr.write("\n")
# run workflow_by_config for multiple times
for i in range(times):
sys.stderr.write(f"Running the model: {fn} for iteration {i+1}...\n")
errs = execute(
f"{python_path} {env_path / 'bin' / 'qrun'} {yaml_path} {fn} {exp_folder_name}",
wait_when_err=wait_when_err,
)
if errs is not None:
_errs = errors.get(fn, {})
_errs.update({i: errs})
errors[fn] = _errs
sys.stderr.write("\n")
# remove env
sys.stderr.write(f"Deleting the environment: {env_path}...\n")
if wait_before_rm_env:
input("Press Enter to Continue")
shutil.rmtree(env_path)
# print errors
sys.stderr.write(f"Here are some of the errors of the models...\n")
pprint(errors)
self._collect_results(exp_folder_name, dataset)
def _collect_results(self, exp_folder_name, dataset):
folders = get_all_folders(exp_folder_name, dataset)
# getting all results
sys.stderr.write(f"Retrieving results...\n")
results = get_all_results(folders)
if len(results) > 0:
# calculating the mean and std
sys.stderr.write(f"Calculating the mean and std of results...\n")
results = cal_mean_std(results)
# generating md table
sys.stderr.write(f"Generating markdown table...\n")
gen_and_save_md_table(results, dataset)
sys.stderr.write("\n")
# remove env
sys.stderr.write(f"Deleting the environment: {env_path}...\n")
if wait_before_rm_env:
input("Press Enter to Continue")
shutil.rmtree(env_path)
# getting all results
sys.stderr.write(f"Retrieving results...\n")
results = get_all_results(folders)
if len(results) > 0:
# calculating the mean and std
sys.stderr.write(f"Calculating the mean and std of results...\n")
results = cal_mean_std(results)
# generating md table
sys.stderr.write(f"Generating markdown table...\n")
gen_and_save_md_table(results, dataset)
sys.stderr.write("\n")
# print errors
sys.stderr.write(f"Here are some of the errors of the models...\n")
pprint(errors)
sys.stderr.write("\n")
# move results folder
shutil.move(exp_folder_name, exp_folder_name + f"_{dataset}_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}")
shutil.move("table.md", f"table_{dataset}_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}.md")
# move results folder
shutil.move(exp_folder_name, exp_folder_name + f"_{dataset}_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}")
shutil.move("table.md", f"table_{dataset}_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}.md")
if __name__ == "__main__":
fire.Fire(run) # run all the model
fire.Fire(ModelRunner) # run all the model

View File

@@ -204,7 +204,7 @@
" },\n",
" \"strategy\": {\n",
" \"class\": \"TopkDropoutStrategy\",\n",
" \"module_path\": \"qlib.contrib.strategy.model_strategy\",\n",
" \"module_path\": \"qlib.contrib.strategy.signal_strategy\",\n",
" \"kwargs\": {\n",
" \"model\": model,\n",
" \"dataset\": dataset,\n",

View File

@@ -5,7 +5,7 @@ import qlib
from qlib.config import REG_CN
from qlib.utils import init_instance_by_config, flatten_dict
from qlib.workflow import R
from qlib.workflow.record_temp import SignalRecord, PortAnaRecord
from qlib.workflow.record_temp import SignalRecord, PortAnaRecord, SigAnaRecord
from qlib.tests.data import GetData
from qlib.tests.config import CSI300_BENCH, CSI300_GBDT_TASK
@@ -31,10 +31,9 @@ if __name__ == "__main__":
},
"strategy": {
"class": "TopkDropoutStrategy",
"module_path": "qlib.contrib.strategy.model_strategy",
"module_path": "qlib.contrib.strategy.signal_strategy",
"kwargs": {
"model": model,
"dataset": dataset,
"signal": (model, dataset),
"topk": 50,
"n_drop": 5,
},
@@ -71,6 +70,10 @@ if __name__ == "__main__":
sr = SignalRecord(model, dataset, recorder)
sr.generate()
# Signal Analysis
sar = SigAnaRecord(recorder)
sar.generate()
# backtest. If users want to use backtest based on their own prediction,
# please refer to https://qlib.readthedocs.io/en/latest/component/recorder.html#record-template.
par = PortAnaRecord(recorder, port_analysis_config, "day")

View File

@@ -6,6 +6,7 @@ _version_path = Path(__file__).absolute().parent / "VERSION.txt" # This file is
__version__ = _version_path.read_text(encoding="utf-8").strip()
__version__bak = __version__ # This version is backup for QlibConfig.reset_qlib_version
import os
from typing import Union
import yaml
import logging
import platform
@@ -151,14 +152,17 @@ def init_from_yaml_conf(conf_path, **kwargs):
:param conf_path: A path to the qlib config in yml format
"""
with open(conf_path) as f:
config = yaml.safe_load(f)
if conf_path is None:
config = {}
else:
with open(conf_path) as f:
config = yaml.safe_load(f)
config.update(kwargs)
default_conf = config.pop("default_conf", "client")
init(default_conf, **config)
def get_project_path(config_name="config.yaml", cur_path=None) -> Path:
def get_project_path(config_name="config.yaml", cur_path: Union[Path, str, None] = None) -> Path:
"""
If users are building a project follow the following pattern.
- Qlib is a sub folder in project path
@@ -187,6 +191,7 @@ def get_project_path(config_name="config.yaml", cur_path=None) -> Path:
"""
if cur_path is None:
cur_path = Path(__file__).absolute().resolve()
cur_path = Path(cur_path)
while True:
if (cur_path / config_name).exists():
return cur_path
@@ -202,6 +207,40 @@ def auto_init(**kwargs):
- The parsing process will be affected by the `conf_type` of the configuration file
- Init qlib with default config
- Skip initialization if already initialized
:**kwargs: it may contain following parameters
cur_path: the start path to find the project path
Here are two examples of the configuration
Example 1)
If you want create a new project-specific config based on a shared configure, you can use `conf_type: ref`
.. code-block:: yaml
conf_type: ref
qlib_cfg: '<shared_yaml_config_path>' # this could be null reference no config from other files
# following configs in `qlib_cfg_update` is project=specific
qlib_cfg_update:
exp_manager:
class: "MLflowExpManager"
module_path: "qlib.workflow.expm"
kwargs:
uri: "file://<your mlflow experiment path>"
default_exp_name: "Experiment"
Example 2)
If you wan to create simple a stand alone config, you can use following config(a.k.a `conf_type: origin`)
.. code-block:: python
exp_manager:
class: "MLflowExpManager"
module_path: "qlib.workflow.expm"
kwargs:
uri: "file://<your mlflow experiment path>"
default_exp_name: "Experiment"
"""
kwargs["skip_if_reg"] = kwargs.get("skip_if_reg", True)
@@ -210,6 +249,7 @@ def auto_init(**kwargs):
except FileNotFoundError:
init(**kwargs)
else:
logger = get_module_logger("Initialization")
conf_pp = pp / "config.yaml"
with conf_pp.open() as f:
conf = yaml.safe_load(f)
@@ -223,8 +263,14 @@ def auto_init(**kwargs):
# - There is a shared configure file and you don't want to edit it inplace.
# - The shared configure may be updated later and you don't want to copy it.
# - You have some customized config.
qlib_conf_path = conf["qlib_cfg"]
qlib_conf_update = conf.get("qlib_cfg_update")
init_from_yaml_conf(qlib_conf_path, **qlib_conf_update, **kwargs)
logger = get_module_logger("Initialization")
qlib_conf_path = conf.get("qlib_cfg", None)
# merge the arguments
qlib_conf_update = conf.get("qlib_cfg_update", {})
for k, v in kwargs.items():
if k in qlib_conf_update:
logger.warning(f"`qlib_conf_update` from conf_pp is override by `kwargs` on key '{k}'")
qlib_conf_update.update(kwargs)
init_from_yaml_conf(qlib_conf_path, **qlib_conf_update)
logger.info(f"Auto load project config: {conf_pp}")

View File

@@ -50,11 +50,12 @@ def get_exchange(
subscribe_fields: list
subscribe fields.
open_cost : float
open transaction cost.
open transaction cost. It is a ratio. The cost is proportional to your order's deal amount.
close_cost : float
close transaction cost.
close transaction cost. It is a ratio. The cost is proportional to your order's deal amount.
min_cost : float
min transaction cost.
min transaction cost. It is an absolute amount of cost instead of a ratio of your order's deal amount.
e.g. You must pay at least 5 yuan of commission regardless of your order's deal amount.
trade_unit : int
Included in kwargs. Please refer to the docs of `__init__` of `Exchange`
deal_price: Union[str, Tuple[str], List[str]]
@@ -185,8 +186,10 @@ def get_strategy_executor(
trade_exchange = get_exchange(**exchange_kwargs)
common_infra = CommonInfrastructure(trade_account=trade_account, trade_exchange=trade_exchange)
trade_strategy = init_instance_by_config(strategy, accept_types=BaseStrategy, common_infra=common_infra)
trade_executor = init_instance_by_config(executor, accept_types=BaseExecutor, common_infra=common_infra)
trade_strategy = init_instance_by_config(strategy, accept_types=BaseStrategy)
trade_strategy.reset_common_infra(common_infra)
trade_executor = init_instance_by_config(executor, accept_types=BaseExecutor)
trade_executor.reset_common_infra(common_infra)
return trade_strategy, trade_executor

View File

@@ -29,7 +29,10 @@ rtn & earning in the Account
class AccumulatedInfo:
"""accumulated trading info, including accumulated return/cost/turnover"""
"""
accumulated trading info, including accumulated return/cost/turnover
AccumulatedInfo should be shared accross different levels
"""
def __init__(self):
self.reset()
@@ -62,6 +65,11 @@ class AccumulatedInfo:
class Account:
"""
The correctness of the metrics of Account in nested execution depends on the shallow copy of `trade_account` in qlib/backtest/executor.py:NestedExecutor
Different level of executor has different Account object when calculating metrics. But the position object is shared cross all the Account object.
"""
def __init__(
self,
init_cash: float = 1e9,
@@ -95,6 +103,8 @@ class Account:
self.init_vars(init_cash, position_dict, freq, benchmark_config)
def init_vars(self, init_cash, position_dict, freq: str, benchmark_config: dict):
# 1) the following variables are shared by multiple layers
# - you will see a shallow copy instead of deepcopy in the NestedExecutor;
self.init_cash = init_cash
self.current_position: BasePosition = init_instance_by_config(
{
@@ -106,6 +116,9 @@ class Account:
"module_path": "qlib.backtest.position",
}
)
self.accum_info = AccumulatedInfo()
# 2) following variables are not shared between layers
self.portfolio_metrics = None
self.hist_positions = {}
self.reset(freq=freq, benchmark_config=benchmark_config)
@@ -119,7 +132,8 @@ class Account:
def reset_report(self, freq, benchmark_config):
# portfolio related metrics
if self.is_port_metr_enabled():
self.accum_info = AccumulatedInfo()
# NOTE:
# `accum_info` and `current_position` are shared here
self.portfolio_metrics = PortfolioMetrics(freq, benchmark_config)
self.hist_positions = {}

View File

@@ -34,6 +34,7 @@ class Exchange:
open_cost=0.0015,
close_cost=0.0025,
min_cost=5,
impact_cost=0.0,
extra_quote=None,
quote_cls=NumpyQuote,
**kwargs,
@@ -95,6 +96,7 @@ class Exchange:
**NOTE**: `trade_unit` is included in the `kwargs`. It is necessary because we must
distinguish `not set` and `disable trade_unit`
:param min_cost: min cost, default 5
:param impact_cost: market impact cost rate (a.k.a. slippage). A recommended value is 0.1.
:param extra_quote: pandas, dataframe consists of
columns: like ['$vwap', '$close', '$volume', '$factor', 'limit_sell', 'limit_buy'].
The limit indicates that the etf is tradable on a specific day.
@@ -164,9 +166,12 @@ class Exchange:
all_fields = list(all_fields | set(subscribe_fields))
self.all_fields = all_fields
self.open_cost = open_cost
self.close_cost = close_cost
self.min_cost = min_cost
self.impact_cost = impact_cost
self.limit_threshold: Union[Tuple[str, str], float, None] = limit_threshold
self.volume_threshold = volume_threshold
self.extra_quote = extra_quote
@@ -226,7 +231,7 @@ class Exchange:
self.extra_quote["limit_buy"] = False
self.logger.warning("No limit_buy set for extra_quote. All stock will be able to be bought.")
assert set(self.extra_quote.columns) == set(self.quote_df.columns) - {"$change"}
self.quote_df = pd.concat([self.quote_df, extra_quote], sort=False, axis=0)
self.quote_df = pd.concat([self.quote_df, self.extra_quote], sort=False, axis=0)
LT_TP_EXP = "(exp)" # Tuple[str, str]
LT_FLT = "float" # float
@@ -685,12 +690,14 @@ class Exchange:
f"Order clipped due to volume limitation: {order}, {[(vol, rule) for vol, rule in zip(vol_limit_num, vol_limit)]}"
)
def _get_buy_amount_by_cash_limit(self, trade_price, cash):
def _get_buy_amount_by_cash_limit(self, trade_price, cash, cost_ratio):
"""return the real order amount after cash limit for buying.
Parameters
----------
trade_price : float
position : cash
cost_ratio : float
Return
----------
float
@@ -699,10 +706,10 @@ class Exchange:
max_trade_amount = 0
if cash >= self.min_cost:
# critical_price means the stock transaction price when the service fee is equal to min_cost.
critical_price = self.min_cost / self.open_cost + self.min_cost
critical_price = self.min_cost / cost_ratio + self.min_cost
if cash >= critical_price:
# the service fee is equal to open_cost * trade_amount
max_trade_amount = cash / (1 + self.open_cost) / trade_price
# the service fee is equal to cost_ratio * trade_amount
max_trade_amount = cash / (1 + cost_ratio) / trade_price
else:
# the service fee is equal to min_cost
max_trade_amount = (cash - self.min_cost) / trade_price
@@ -718,6 +725,7 @@ class Exchange:
:return: trade_price, trade_val, trade_cost
"""
trade_price = self.get_deal_price(order.stock_id, order.start_time, order.end_time, direction=order.direction)
total_trade_val = self.get_volume(order.stock_id, order.start_time, order.end_time) * trade_price
order.factor = self.get_factor(order.stock_id, order.start_time, order.end_time)
order.deal_amount = order.amount # set to full amount and clip it step by step
# Clipping amount first
@@ -726,8 +734,16 @@ class Exchange:
# - It simulates that the large order is submitted, but partial is dealt regardless of rounding by trading unit.
self._clip_amount_by_volume(order, dealt_order_amount)
# TODO: the adjusted cost ratio can be overestimated as deal_amount will be clipped in the next steps
trade_val = order.deal_amount * trade_price
if not total_trade_val or np.isnan(total_trade_val):
# TODO: assert trade_val == 0, f"trade_val != 0, total_trade_val: {total_trade_val}; order info: {order}"
adj_cost_ratio = self.impact_cost
else:
adj_cost_ratio = self.impact_cost * (trade_val / total_trade_val) ** 2
if order.direction == Order.SELL:
cost_ratio = self.close_cost
cost_ratio = self.close_cost + adj_cost_ratio
# sell
# if we don't know current position, we choose to sell all
# Otherwise, we clip the amount based on current position
@@ -750,14 +766,18 @@ class Exchange:
self.logger.debug(f"Order clipped due to cash limitation: {order}")
elif order.direction == Order.BUY:
cost_ratio = self.open_cost
cost_ratio = self.open_cost + adj_cost_ratio
# buy
if position is not None:
cash = position.get_cash()
trade_val = order.deal_amount * trade_price
if cash < trade_val + max(trade_val * cost_ratio, self.min_cost):
if cash < max(trade_val * cost_ratio, self.min_cost):
# cash cannot cover cost
order.deal_amount = 0
self.logger.debug(f"Order clipped due to cost higher than cash: {order}")
elif cash < trade_val + max(trade_val * cost_ratio, self.min_cost):
# The money is not enough
max_buy_amount = self._get_buy_amount_by_cash_limit(trade_price, cash)
max_buy_amount = self._get_buy_amount_by_cash_limit(trade_price, cash, cost_ratio)
order.deal_amount = self.round_amount_by_trade_unit(
min(max_buy_amount, order.deal_amount), order.factor
)

View File

@@ -130,7 +130,7 @@ class BaseExecutor:
if common_infra.has("trade_account"):
# NOTE: there is a trick in the code.
# copy is used instead of deepcopy. So positions are shared
# shallow copy is used instead of deepcopy. So positions are shared
self.trade_account: Account = copy.copy(common_infra.get("trade_account"))
self.trade_account.reset(freq=self.time_per_step, port_metr_enabled=self.generate_portfolio_metrics)

View File

@@ -160,6 +160,11 @@ class NumpyQuote(BaseQuote):
if is_single_value(start_time, end_time, self.freq, self.region):
# this is a very special case.
# skip aggregating function to speed-up the query calculation
# FIXME:
# it will go to the else logic when it comes to the
# 1) the day before holiday when daily trading
# 2) the last minute of the day when intraday trading
try:
return self.data[stock_id].loc[start_time, field]
except KeyError:

View File

@@ -223,6 +223,12 @@ class BasePosition:
"""
raise NotImplementedError(f"Please implement the `settle_commit` method")
def __str__(self):
return self.__dict__.__str__()
def __repr__(self):
return self.__dict__.__repr__()
class Position(BasePosition):
"""Position
@@ -345,15 +351,19 @@ class Position(BasePosition):
if stock_id not in self.position:
raise KeyError("{} not in current position".format(stock_id))
else:
# decrease the amount of stock
self.position[stock_id]["amount"] -= trade_amount
# check if to delete
if self.position[stock_id]["amount"] < -1e-5:
raise ValueError(
"only have {} {}, require {}".format(self.position[stock_id]["amount"], stock_id, trade_amount)
)
elif abs(self.position[stock_id]["amount"]) <= 1e-5:
if np.isclose(self.position[stock_id]["amount"], trade_amount):
# Selling all the stocks
# we use np.isclose instead of abs(<the final amount>) <= 1e-5 because `np.isclose` consider both ralative amount and absolute amount
# Using abs(<the final amount>) <= 1e-5 will result in error when the amount is large
self._del_stock(stock_id)
else:
# decrease the amount of stock
self.position[stock_id]["amount"] -= trade_amount
# check if to delete
if self.position[stock_id]["amount"] < -1e-5:
raise ValueError(
"only have {} {}, require {}".format(self.position[stock_id]["amount"], stock_id, trade_amount)
)
new_cash = trade_val - cost
if self._settle_type == self.ST_CASH:

102
qlib/backtest/signal.py Normal file
View File

@@ -0,0 +1,102 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from qlib.utils import init_instance_by_config
from typing import Dict, List, Text, Tuple, Union
from ..model.base import BaseModel
from ..data.dataset import Dataset
from ..data.dataset.utils import convert_index_format
from ..utils.resam import resam_ts_data
import pandas as pd
import abc
class Signal(metaclass=abc.ABCMeta):
"""
Some trading strategy make decisions based on other prediction signals
The signals may comes from different sources(e.g. prepared data, online prediction from model and dataset)
This interface is tries to provide unified interface for those different sources
"""
@abc.abstractmethod
def get_signal(self, start_time, end_time) -> Union[pd.Series, pd.DataFrame, None]:
"""
get the signal at the end of the decision step(from `start_time` to `end_time`)
Returns
-------
Union[pd.Series, pd.DataFrame, None]:
returns None if no signal in the specific day
"""
...
class SignalWCache(Signal):
"""
Signal With pandas with based Cache
SignalWCache will store the prepared signal as a attribute and give the according signal based on input query
"""
def __init__(self, signal: Union[pd.Series, pd.DataFrame]):
"""
Parameters
----------
signal : Union[pd.Series, pd.DataFrame]
The expected format of the signal is like the data below (the order of index is not important and can be automatically adjusted)
instrument datetime
SH600000 2008-01-02 0.079704
2008-01-03 0.120125
2008-01-04 0.878860
2008-01-07 0.505539
2008-01-08 0.395004
"""
self.signal_cache = convert_index_format(signal, level="datetime")
def get_signal(self, start_time, end_time) -> Union[pd.Series, pd.DataFrame]:
# the frequency of the signal may not algin with the decision frequency of strategy
# so resampling from the data is necessary
# the latest signal leverage more recent data and therefore is used in trading.
signal = resam_ts_data(self.signal_cache, start_time=start_time, end_time=end_time, method="last")
return signal
class ModelSignal(SignalWCache):
def __init__(self, model: BaseModel, dataset: Dataset):
self.model = model
self.dataset = dataset
pred_scores = self.model.predict(dataset)
if isinstance(pred_scores, pd.DataFrame):
pred_scores = pred_scores.iloc[:, 0]
super().__init__(pred_scores)
def _update_model(self):
"""
When using online data, update model in each bar as the following steps:
- update dataset with online data, the dataset should support online update
- make the latest prediction scores of the new bar
- update the pred score into the latest prediction
"""
# TODO: this method is not included in the framework and could be refactor later
raise NotImplementedError("_update_model is not implemented!")
def create_signal_from(
obj: Union[Signal, Tuple[BaseModel, Dataset], List, Dict, Text, pd.Series, pd.DataFrame]
) -> Signal:
"""
create signal from diverse information
This method will choose the right method to create a signal based on `obj`
Please refer to the code below.
"""
if isinstance(obj, Signal):
return obj
elif isinstance(obj, (tuple, list)):
return ModelSignal(*obj)
elif isinstance(obj, (dict, str)):
return init_instance_by_config(obj)
elif isinstance(obj, (pd.DataFrame, pd.Series)):
return SignalWCache(signal=obj)
else:
raise NotImplementedError(f"This type of signal is not supported")

View File

@@ -222,7 +222,7 @@ class CommonInfrastructure(BaseInfrastructure):
class LevelInfrastructure(BaseInfrastructure):
"""level instrastructure is created by executor, and then shared to strategies on the same level"""
"""level infrastructure is created by executor, and then shared to strategies on the same level"""
def get_support_infra(self):
"""

View File

@@ -73,6 +73,9 @@ class Config:
REG_CN = "cn"
REG_US = "us"
# pickle.dump protocol version: https://docs.python.org/3/library/pickle.html#data-stream-format
PROTOCOL_VERSION = 4
NUM_USABLE_CPU = max(multiprocessing.cpu_count() - 2, 1)
DISK_DATASET_CACHE = "DiskDatasetCache"
@@ -107,6 +110,8 @@ _default_config = {
# for simple dataset cache
"local_cache_path": None,
"kernels": NUM_USABLE_CPU,
# pickle.dump protocol version
"dump_protocol_version": PROTOCOL_VERSION,
# How many tasks belong to one process. Recommend 1 for high-frequency data and None for daily data.
"maxtasksperchild": None,
# If joblib_backend is None, use loky
@@ -239,8 +244,8 @@ HIGH_FREQ_CONFIG = {
_default_region_config = {
REG_CN: {
"trade_unit": 100,
"limit_threshold": 0.099,
"deal_price": "vwap",
"limit_threshold": 0.095,
"deal_price": "close",
},
REG_US: {
"trade_unit": 1,
@@ -265,6 +270,20 @@ class QlibConfig(Config):
self.provider_uri = provider_uri
self.mount_path = mount_path
@staticmethod
def format_provider_uri(provider_uri: Union[str, dict, Path]) -> dict:
if provider_uri is None:
raise ValueError("provider_uri cannot be None")
if isinstance(provider_uri, (str, dict, Path)):
if not isinstance(provider_uri, dict):
provider_uri = {QlibConfig.DEFAULT_FREQ: provider_uri}
else:
raise TypeError(f"provider_uri does not support {type(provider_uri)}")
for freq, _uri in provider_uri.items():
if QlibConfig.DataPathManager.get_uri_type(_uri) == QlibConfig.LOCAL_URI:
provider_uri[freq] = str(Path(_uri).expanduser().resolve())
return provider_uri
@staticmethod
def get_uri_type(uri: Union[str, Path]):
uri = uri if isinstance(uri, str) else str(uri.expanduser().resolve())
@@ -311,11 +330,7 @@ class QlibConfig(Config):
def resolve_path(self):
# resolve path
_mount_path = self["mount_path"]
_provider_uri = self["provider_uri"]
if _provider_uri is None:
raise ValueError("provider_uri cannot be None")
if not isinstance(_provider_uri, dict):
_provider_uri = {self.DEFAULT_FREQ: _provider_uri}
_provider_uri = self.DataPathManager.format_provider_uri(self["provider_uri"])
if not isinstance(_mount_path, dict):
_mount_path = {_freq: _mount_path for _freq in _provider_uri.keys()}
@@ -324,10 +339,7 @@ class QlibConfig(Config):
assert len(_miss_freq) == 0, f"mount_path is missing freq: {_miss_freq}"
# resolve
for _freq, _uri in _provider_uri.items():
# provider_uri
if self.DataPathManager.get_uri_type(_uri) == QlibConfig.LOCAL_URI:
_provider_uri[_freq] = str(Path(_uri).expanduser().resolve())
for _freq in _provider_uri.keys():
# mount_path
_mount_path[_freq] = (
_mount_path[_freq]
@@ -337,20 +349,6 @@ class QlibConfig(Config):
self["provider_uri"] = _provider_uri
self["mount_path"] = _mount_path
def get_uri_type(self):
path = self["provider_uri"]
if isinstance(path, Path):
path = str(path)
is_win = re.match("^[a-zA-Z]:.*", path) is not None # such as 'C:\\data', 'D:'
is_nfs_or_win = (
re.match("^[^/]+:.+", path) is not None
) # such as 'host:/data/' (User may define short hostname by themselves or use localhost)
if is_nfs_or_win and not is_win:
return QlibConfig.NFS_URI
else:
return QlibConfig.LOCAL_URI
def set(self, default_conf: str = "client", **kwargs):
"""
configure qlib based on the input parameters

View File

View File

@@ -0,0 +1,183 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import pandas as pd
from typing import Dict, Iterable
def align_index(df_dict, join):
res = {}
for k, df in df_dict.items():
if join is not None and k != join:
df = df.reindex(df_dict[join].index)
res[k] = df
return res
# Mocking the pd.DataFrame class
class SepDataFrame:
"""
(Sep)erate DataFrame
We usually concat multiple dataframe to be processed together(Such as feature, label, weight, filter).
However, they are usally be used seperately at last.
This will result in extra cost for concating and spliting data(reshaping and copying data in the memory is very expensive)
SepDataFrame tries to act like a DataFrame whose column with multiindex
"""
def __init__(self, df_dict: Dict[str, pd.DataFrame], join: str, skip_align=False):
"""
initialize the data based on the dataframe dictionary
Parameters
----------
df_dict : Dict[str, pd.DataFrame]
dataframe dictionary
join : str
how to join the data
It will reindex the dataframe based on the join key.
If join is None, the reindex step will be skipped
skip_align :
for some cases, we can improve performance by skipping aligning index
"""
self.join = join
if skip_align:
self._df_dict = df_dict
else:
self._df_dict = align_index(df_dict, join)
@property
def loc(self):
return SDFLoc(self, join=self.join)
@property
def index(self):
return self._df_dict[self.join].index
def apply_each(self, method: str, skip_align=True, *args, **kwargs):
"""
Assumptions:
- inplace methods will return None
"""
inplace = False
df_dict = {}
for k, df in self._df_dict.items():
df_dict[k] = getattr(df, method)(*args, **kwargs)
if df_dict[k] is None:
inplace = True
if not inplace:
return SepDataFrame(df_dict=df_dict, join=self.join, skip_align=skip_align)
def sort_index(self, *args, **kwargs):
return self.apply_each("sort_index", True, *args, **kwargs)
def copy(self, *args, **kwargs):
return self.apply_each("copy", True, *args, **kwargs)
def _update_join(self):
if self.join not in self:
self.join = next(iter(self._df_dict.keys()))
def __getitem__(self, item):
return self._df_dict[item]
def __setitem__(self, item: str, df: pd.DataFrame):
# TODO: consider the join behavior
self._df_dict[item] = df
def __delitem__(self, item: str):
del self._df_dict[item]
self._update_join()
def __contains__(self, item):
return item in self._df_dict
def __len__(self):
return len(self._df_dict[self.join])
def droplevel(self, *args, **kwargs):
raise NotImplementedError(f"Please implement the `droplevel` method")
@property
def columns(self):
dfs = []
for k, df in self._df_dict.items():
df = df.head(0)
df.columns = pd.MultiIndex.from_product([[k], df.columns])
dfs.append(df)
return pd.concat(dfs, axis=1).columns
# Useless methods
@staticmethod
def merge(df_dict: Dict[str, pd.DataFrame], join: str):
all_df = df_dict[join]
for k, df in df_dict.items():
if k != join:
all_df = all_df.join(df)
return all_df
class SDFLoc:
"""Mock Class"""
def __init__(self, sdf: SepDataFrame, join):
self._sdf = sdf
self.axis = None
self.join = join
def __call__(self, axis):
self.axis = axis
return self
def __getitem__(self, args):
if self.axis == 1:
if isinstance(args, str):
return self._sdf[args]
elif isinstance(args, (tuple, list)):
new_df_dict = {k: self._sdf[k] for k in args}
return SepDataFrame(new_df_dict, join=self.join if self.join in args else args[0], skip_align=True)
else:
raise NotImplementedError(f"This type of input is not supported")
elif self.axis == 0:
return SepDataFrame(
{k: df.loc(axis=0)[args] for k, df in self._sdf._df_dict.items()}, join=self.join, skip_align=True
)
else:
df = self._sdf
if isinstance(args, tuple):
ax0, *ax1 = args
if len(ax1) == 0:
ax1 = None
if ax1 is not None:
df = df.loc(axis=1)[ax1]
if ax0 is not None:
df = df.loc(axis=0)[ax0]
return df
else:
return df.loc(axis=0)[args]
# Patch pandas DataFrame
# Tricking isinstance to accept SepDataFrame as its subclass
import builtins
def _isinstance(instance, cls):
if isinstance_orig(instance, SepDataFrame): # pylint: disable=E0602
if isinstance(cls, Iterable):
for c in cls:
if c is pd.DataFrame:
return True
elif cls is pd.DataFrame:
return True
return isinstance_orig(instance, cls) # pylint: disable=E0602
builtins.isinstance_orig = builtins.isinstance
builtins.isinstance = _isinstance
if __name__ == "__main__":
sdf = SepDataFrame({}, join=None)
print(isinstance(sdf, (pd.DataFrame,)))
print(isinstance(sdf, pd.DataFrame))

View File

@@ -3,15 +3,18 @@
from __future__ import division
from __future__ import print_function
from logging import warn
import numpy as np
import pandas as pd
import warnings
from typing import Union
from ..log import get_module_logger
from ..backtest import get_exchange, backtest as backtest_func
from ..utils import get_date_range
from ..utils.resam import Freq
from ..strategy.base import BaseStrategy
from ..backtest import get_exchange, position, backtest as backtest_func, executor as _executor
from ..data import D
from ..config import C
@@ -117,84 +120,129 @@ def indicator_analysis(df, method="mean"):
# This is the API for compatibility for legacy code
def backtest(pred, account=1e9, shift=1, benchmark="SH000905", verbose=True, **kwargs):
"""This function will help you set a reasonable Exchange and provide default value for strategy
def backtest_daily(
start_time: Union[str, pd.Timestamp],
end_time: Union[str, pd.Timestamp],
strategy: Union[str, dict, BaseStrategy],
executor: Union[str, dict, _executor.BaseExecutor] = None,
account: Union[float, int, position.Position] = 1e8,
benchmark: str = "SH000300",
exchange_kwargs: dict = None,
pos_type: str = "Position",
):
"""initialize the strategy and executor, then executor the backtest of daily frequency
Parameters
----------
start_time : Union[str, pd.Timestamp]
closed start time for backtest
**NOTE**: This will be applied to the outmost executor's calendar.
end_time : Union[str, pd.Timestamp]
closed end time for backtest
**NOTE**: This will be applied to the outmost executor's calendar.
E.g. Executor[day](Executor[1min]), setting `end_time == 20XX0301` will include all the minutes on 20XX0301
strategy : Union[str, dict, BaseStrategy]
for initializing outermost portfolio strategy. Please refer to the docs of init_instance_by_config for more information.
- **backtest workflow related or commmon arguments**
E.g.
pred : pandas.DataFrame
predict should has <datetime, instrument> index and one `score` column.
account : float
init account value.
shift : int
whether to shift prediction by one day.
benchmark : str
benchmark code, default is SH000905 CSI 500.
verbose : bool
whether to print log.
.. code-block:: python
# dict
strategy = {
"class": "TopkDropoutStrategy",
"module_path": "qlib.contrib.strategy.signal_strategy",
"kwargs": {
"signal": (model, dataset),
"topk": 50,
"n_drop": 5,
},
}
# BaseStrategy
pred_score = pd.read_pickle("score.pkl")["score"]
STRATEGY_CONFIG = {
"topk": 50,
"n_drop": 5,
"signal": pred_score,
}
strategy = TopkDropoutStrategy(**STRATEGY_CONFIG)
# str example.
# 1) specify a pickle object
# - path like 'file:///<path to pickle file>/obj.pkl'
# 2) specify a class name
# - "ClassName": getattr(module, "ClassName")() will be used.
# 3) specify module path with class name
# - "a.b.c.ClassName" getattr(<a.b.c.module>, "ClassName")() will be used.
- **strategy related arguments**
strategy : Strategy()
strategy used in backtest.
topk : int (Default value: 50)
top-N stocks to buy.
margin : int or float(Default value: 0.5)
- if isinstance(margin, int):
executor : Union[str, dict, BaseExecutor]
for initializing the outermost executor.
benchmark: str
the benchmark for reporting.
account : Union[float, int, Position]
information for describing how to creating the account
For `float` or `int`:
Using Account with only initial cash
For `Position`:
Using Account with a Position
exchange_kwargs : dict
the kwargs for initializing Exchange
E.g.
sell_limit = margin
.. code-block:: python
- else:
exchange_kwargs = {
"freq": freq,
"limit_threshold": None, # limit_threshold is None, using C.limit_threshold
"deal_price": None, # deal_price is None, using C.deal_price
"open_cost": 0.0005,
"close_cost": 0.0015,
"min_cost": 5,
}
sell_limit = pred_in_a_day.count() * margin
pos_type : str
the type of Position.
buffer margin, in single score_mode, continue holding stock if it is in nlargest(sell_limit).
sell_limit should be no less than topk.
n_drop : int
number of stocks to be replaced in each trading date.
risk_degree: float
0-1, 0.95 for example, use 95% money to trade.
str_type: 'amount', 'weight' or 'dropout'
strategy type: TopkAmountStrategy ,TopkWeightStrategy or TopkDropoutStrategy.
- **exchange related arguments**
exchange: Exchange()
pass the exchange for speeding up.
subscribe_fields: list
subscribe fields.
open_cost : float
open transaction cost. The default value is 0.002(0.2%).
close_cost : float
close transaction cost. The default value is 0.002(0.2%).
min_cost : float
min transaction cost.
trade_unit : int
100 for China A.
deal_price: str
dealing price type: 'close', 'open', 'vwap'.
limit_threshold : float
limit move 0.1 (10%) for example, long and short with same limit.
extract_codes: bool
will we pass the codes extracted from the pred to the exchange.
.. note:: This will be faster with offline qlib.
- **executor related arguments**
executor : BaseExecutor()
executor used in backtest.
verbose : bool
whether to print log.
Returns
-------
report_normal: pd.DataFrame
backtest report
positions_normal: pd.DataFrame
backtest positions
"""
warnings.warn("this function is deprecated, please use backtest function in qlib.backtest", DeprecationWarning)
report_dict = backtest_func(
pred=pred, account=account, shift=shift, benchmark=benchmark, verbose=verbose, return_order=False, **kwargs
freq = "day"
if executor is None:
executor_config = {
"time_per_step": freq,
"generate_portfolio_metrics": True,
}
executor = _executor.SimulatorExecutor(**executor_config)
_exchange_kwargs = {
"freq": freq,
"limit_threshold": None,
"deal_price": None,
"open_cost": 0.0005,
"close_cost": 0.0015,
"min_cost": 5,
}
if exchange_kwargs is not None:
_exchange_kwargs.update(exchange_kwargs)
portfolio_metric_dict, indicator_dict = backtest_func(
start_time=start_time,
end_time=end_time,
strategy=strategy,
executor=executor,
account=account,
benchmark=benchmark,
exchange_kwargs=_exchange_kwargs,
pos_type=pos_type,
)
return report_dict.get("report_df"), report_dict.get("positions")
analysis_freq = "{0}{1}".format(*Freq.parse(freq))
report_normal, positions_normal = portfolio_metric_dict.get(analysis_freq)
return report_normal, positions_normal
def long_short_backtest(
@@ -327,7 +375,12 @@ def t_run():
pred["datetime"] = pd.to_datetime(pred["datetime"])
pred = pred.set_index([pred.columns[0], pred.columns[1]])
pred = pred.iloc[:9000]
report_df, positions = backtest(pred=pred)
strategy_config = {
"topk": 50,
"n_drop": 5,
"signal": pred,
}
report_df, positions = backtest_daily(start_time="2017-01-01", end_time="2020-08-01", strategy=strategy_config)
print(report_df.head())
print(positions.keys())
print(positions[list(positions.keys())[0]])

View File

@@ -30,8 +30,10 @@ try:
from .pytorch_nn import DNNModelPytorch
from .pytorch_tabnet import TabnetModel
from .pytorch_sfm import SFM_Model
from .pytorch_tcn import TCN
from .pytorch_add import ADD
pytorch_classes = (ALSTM, GATs, GRU, LSTM, DNNModelPytorch, TabnetModel, SFM_Model)
pytorch_classes = (ALSTM, GATs, GRU, LSTM, DNNModelPytorch, TabnetModel, SFM_Model, TCN, ADD)
except ModuleNotFoundError:
pytorch_classes = ()
print("Please install necessary libs for PyTorch models.")

View File

@@ -38,6 +38,8 @@ class CatBoostModel(Model, FeatureInt):
col_set=["feature", "label"],
data_key=DataHandlerLP.DK_L,
)
if df_train.empty or df_valid.empty:
raise ValueError("Empty data from dataset, please check your dataset config.")
x_train, y_train = df_train["feature"], df_train["label"]
x_valid, y_valid = df_valid["feature"], df_valid["label"]

View File

@@ -64,6 +64,8 @@ class DEnsembleModel(Model, FeatureInt):
df_train, df_valid = dataset.prepare(
["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
)
if df_train.empty or df_valid.empty:
raise ValueError("Empty data from dataset, please check your dataset config.")
x_train, y_train = df_train["feature"], df_train["label"]
# initialize the sample weights
N, F = x_train.shape

View File

@@ -14,17 +14,20 @@ from ...model.interpret.base import LightGBMFInt
class LGBModel(ModelFT, LightGBMFInt):
"""LightGBM Model"""
def __init__(self, loss="mse", **kwargs):
def __init__(self, loss="mse", early_stopping_rounds=50, **kwargs):
if loss not in {"mse", "binary"}:
raise NotImplementedError
self.params = {"objective": loss, "verbosity": -1}
self.params.update(kwargs)
self.early_stopping_rounds = early_stopping_rounds
self.model = None
def _prepare_data(self, dataset: DatasetH):
df_train, df_valid = dataset.prepare(
["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
)
if df_train.empty or df_valid.empty:
raise ValueError("Empty data from dataset, please check your dataset config.")
x_train, y_train = df_train["feature"], df_train["label"]
x_valid, y_valid = df_valid["feature"], df_valid["label"]
@@ -42,7 +45,7 @@ class LGBModel(ModelFT, LightGBMFInt):
self,
dataset: DatasetH,
num_boost_round=1000,
early_stopping_rounds=50,
early_stopping_rounds=None,
verbose_eval=20,
evals_result=dict(),
**kwargs
@@ -54,7 +57,9 @@ class LGBModel(ModelFT, LightGBMFInt):
num_boost_round=num_boost_round,
valid_sets=[dtrain, dvalid],
valid_names=["train", "valid"],
early_stopping_rounds=early_stopping_rounds,
early_stopping_rounds=(
self.early_stopping_rounds if early_stopping_rounds is None else early_stopping_rounds
),
verbose_eval=verbose_eval,
evals_result=evals_result,
**kwargs
@@ -83,6 +88,8 @@ class LGBModel(ModelFT, LightGBMFInt):
"""
# Based on existing model and finetune by train more rounds
dtrain, _ = self._prepare_data(dataset)
if dtrain.empty:
raise ValueError("Empty data from dataset, please check your dataset config.")
self.model = lgb.train(
self.params,
dtrain,

View File

@@ -82,6 +82,8 @@ class HFLGBModel(ModelFT, LightGBMFInt):
df_train, df_valid = dataset.prepare(
["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
)
if df_train.empty or df_valid.empty:
raise ValueError("Empty data from dataset, please check your dataset config.")
x_train, y_train = df_train["feature"], df_train["label"]
x_valid, y_valid = df_train["feature"], df_valid["label"]

View File

@@ -51,6 +51,8 @@ class LinearModel(Model):
def fit(self, dataset: DatasetH):
df_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
if df_train.empty:
raise ValueError("Empty data from dataset, please check your dataset config.")
X, y = df_train["feature"].values, np.squeeze(df_train["label"].values)
if self.estimator in [self.OLS, self.RIDGE, self.LASSO]:

View File

@@ -0,0 +1,789 @@
# Copyright (c) Microsoft Corporation.
import os
from pdb import set_trace
from torch.utils.data import Dataset, DataLoader
import copy
from typing import Text, Union
import math
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Function
from qlib.contrib.model.pytorch_utils import count_parameters
from qlib.data.dataset import DatasetH
from qlib.data.dataset.handler import DataHandlerLP
from qlib.log import get_module_logger
from qlib.model.base import Model
from qlib.utils import get_or_create_path
class ADARNN(Model):
"""ADARNN Model
Parameters
----------
d_feat : int
input dimension for each time step
metric: str
the evaluate metric used in early stop
optimizer : str
optimizer name
GPU : str
the GPU ID(s) used for training
"""
def __init__(
self,
d_feat=6,
hidden_size=64,
num_layers=2,
dropout=0.0,
n_epochs=200,
pre_epoch=40,
dw=0.5,
loss_type="cosine",
len_seq=60,
len_win=0,
lr=0.001,
metric="mse",
batch_size=2000,
early_stop=20,
loss="mse",
optimizer="adam",
n_splits=2,
GPU=0,
seed=None,
**kwargs
):
# Set logger.
self.logger = get_module_logger("ADARNN")
self.logger.info("ADARNN pytorch version...")
os.environ["CUDA_VISIBLE_DEVICES"] = str(GPU)
# set hyper-parameters.
self.d_feat = d_feat
self.hidden_size = hidden_size
self.num_layers = num_layers
self.dropout = dropout
self.n_epochs = n_epochs
self.pre_epoch = pre_epoch
self.dw = dw
self.loss_type = loss_type
self.len_seq = len_seq
self.len_win = len_win
self.lr = lr
self.metric = metric
self.batch_size = batch_size
self.early_stop = early_stop
self.optimizer = optimizer.lower()
self.loss = loss
self.n_splits = n_splits
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
self.seed = seed
self.logger.info(
"ADARNN parameters setting:"
"\nd_feat : {}"
"\nhidden_size : {}"
"\nnum_layers : {}"
"\ndropout : {}"
"\nn_epochs : {}"
"\nlr : {}"
"\nmetric : {}"
"\nbatch_size : {}"
"\nearly_stop : {}"
"\noptimizer : {}"
"\nloss_type : {}"
"\nvisible_GPU : {}"
"\nuse_GPU : {}"
"\nseed : {}".format(
d_feat,
hidden_size,
num_layers,
dropout,
n_epochs,
lr,
metric,
batch_size,
early_stop,
optimizer.lower(),
loss,
GPU,
self.use_gpu,
seed,
)
)
if self.seed is not None:
np.random.seed(self.seed)
torch.manual_seed(self.seed)
n_hiddens = [hidden_size for _ in range(num_layers)]
self.model = AdaRNN(
use_bottleneck=False,
bottleneck_width=64,
n_input=d_feat,
n_hiddens=n_hiddens,
n_output=1,
dropout=dropout,
model_type="AdaRNN",
len_seq=len_seq,
trans_loss=loss_type,
)
self.logger.info("model:\n{:}".format(self.model))
self.logger.info("model size: {:.4f} MB".format(count_parameters(self.model)))
if optimizer.lower() == "adam":
self.train_optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
elif optimizer.lower() == "gd":
self.train_optimizer = optim.SGD(self.model.parameters(), lr=self.lr)
else:
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
self.fitted = False
self.model.cuda()
@property
def use_gpu(self):
return self.device != torch.device("cpu")
def train_AdaRNN(self, train_loader_list, epoch, dist_old=None, weight_mat=None):
self.model.train()
criterion = nn.MSELoss()
dist_mat = torch.zeros(self.num_layers, self.len_seq).cuda()
len_loader = np.inf
for loader in train_loader_list:
if len(loader) < len_loader:
len_loader = len(loader)
for data_all in zip(*train_loader_list):
# for data_all in zip(*train_loader_list):
self.train_optimizer.zero_grad()
list_feat = []
list_label = []
for data in data_all:
# feature :[36, 24, 6]
feature, label_reg = data[0].cuda().float(), data[1].cuda().float()
list_feat.append(feature)
list_label.append(label_reg)
flag = False
index = get_index(len(data_all) - 1)
for temp_index in index:
s1 = temp_index[0]
s2 = temp_index[1]
if list_feat[s1].shape[0] != list_feat[s2].shape[0]:
flag = True
break
if flag:
continue
total_loss = torch.zeros(1).cuda()
for i in range(len(index)):
feature_s = list_feat[index[i][0]]
feature_t = list_feat[index[i][1]]
label_reg_s = list_label[index[i][0]]
label_reg_t = list_label[index[i][1]]
feature_all = torch.cat((feature_s, feature_t), 0)
if epoch < self.pre_epoch:
pred_all, loss_transfer, out_weight_list = self.model.forward_pre_train(
feature_all, len_win=self.len_win
)
else:
pred_all, loss_transfer, dist, weight_mat = self.model.forward_Boosting(feature_all, weight_mat)
dist_mat = dist_mat + dist
pred_s = pred_all[0 : feature_s.size(0)]
pred_t = pred_all[feature_s.size(0) :]
loss_s = criterion(pred_s, label_reg_s)
loss_t = criterion(pred_t, label_reg_t)
total_loss = total_loss + loss_s + loss_t + self.dw * loss_transfer
self.train_optimizer.zero_grad()
total_loss.backward()
torch.nn.utils.clip_grad_value_(self.model.parameters(), 3.0)
self.train_optimizer.step()
if epoch >= self.pre_epoch:
if epoch > self.pre_epoch:
weight_mat = self.model.update_weight_Boosting(weight_mat, dist_old, dist_mat)
return weight_mat, dist_mat
else:
weight_mat = self.transform_type(out_weight_list)
return weight_mat, None
def calc_all_metrics(self, pred):
"""pred is a pandas dataframe that has two attributes: score (pred) and label (real)"""
res = {}
ic = pred.groupby(level="datetime").apply(lambda x: x.label.corr(x.score))
rank_ic = pred.groupby(level="datetime").apply(lambda x: x.label.corr(x.score, method="spearman"))
res["ic"] = ic.mean()
res["icir"] = ic.mean() / ic.std()
res["ric"] = rank_ic.mean()
res["ricir"] = rank_ic.mean() / rank_ic.std()
res["mse"] = -(pred["label"] - pred["score"]).mean()
res["loss"] = res["mse"]
return res
def test_epoch(self, df):
self.model.eval()
preds = self.infer(df["feature"])
label = df["label"].squeeze()
preds = pd.DataFrame({"label": label, "score": preds}, index=df.index)
metrics = self.calc_all_metrics(preds)
return metrics
def log_metrics(self, mode, metrics):
metrics = ["{}/{}: {:.6f}".format(k, mode, v) for k, v in metrics.items()]
metrics = ", ".join(metrics)
self.logger.info(metrics)
def fit(
self,
dataset: DatasetH,
evals_result=dict(),
save_path=None,
):
df_train, df_valid = dataset.prepare(
["train", "valid"],
col_set=["feature", "label"],
data_key=DataHandlerLP.DK_L,
)
# splits = ['2011-06-30']
days = df_train.index.get_level_values(level=0).unique()
train_splits = np.array_split(days, self.n_splits)
train_splits = [df_train[s[0] : s[-1]] for s in train_splits]
train_loader_list = [get_stock_loader(df, self.batch_size) for df in train_splits]
save_path = get_or_create_path(save_path)
stop_steps = 0
best_score = -np.inf
best_epoch = 0
evals_result["train"] = []
evals_result["valid"] = []
# train
self.logger.info("training...")
self.fitted = True
best_score = -np.inf
best_epoch = 0
weight_mat, dist_mat = None, None
for step in range(self.n_epochs):
self.logger.info("Epoch%d:", step)
self.logger.info("training...")
weight_mat, dist_mat = self.train_AdaRNN(train_loader_list, step, dist_mat, weight_mat)
self.logger.info("evaluating...")
train_metrics = self.test_epoch(df_train)
valid_metrics = self.test_epoch(df_valid)
self.log_metrics("train: ", train_metrics)
self.log_metrics("valid: ", valid_metrics)
valid_score = valid_metrics[self.metric]
train_score = train_metrics[self.metric]
evals_result["train"].append(train_score)
evals_result["valid"].append(valid_score)
if valid_score > best_score:
best_score = valid_score
stop_steps = 0
best_epoch = step
best_param = copy.deepcopy(self.model.state_dict())
else:
stop_steps += 1
if stop_steps >= self.early_stop:
self.logger.info("early stop")
break
self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
self.model.load_state_dict(best_param)
torch.save(best_param, save_path)
if self.use_gpu:
torch.cuda.empty_cache()
return best_score
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
if not self.fitted:
raise ValueError("model is not fitted yet!")
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
return self.infer(x_test)
def infer(self, x_test):
index = x_test.index
self.model.eval()
x_values = x_test.values
sample_num = x_values.shape[0]
x_values = x_values.reshape(sample_num, self.d_feat, -1).transpose(0, 2, 1)
preds = []
for begin in range(sample_num)[:: self.batch_size]:
if sample_num - begin < self.batch_size:
end = sample_num
else:
end = begin + self.batch_size
x_batch = torch.from_numpy(x_values[begin:end]).float().cuda()
with torch.no_grad():
pred = self.model.predict(x_batch).detach().cpu().numpy()
preds.append(pred)
return pd.Series(np.concatenate(preds), index=index)
def transform_type(self, init_weight):
weight = torch.ones(self.num_layers, self.len_seq).cuda()
for i in range(self.num_layers):
for j in range(self.len_seq):
weight[i, j] = init_weight[i][j].item()
return weight
class data_loader(Dataset):
def __init__(self, df):
self.df_feature = df["feature"]
self.df_label_reg = df["label"]
self.df_index = df.index
self.df_feature = torch.tensor(
self.df_feature.values.reshape(-1, 6, 60).transpose(0, 2, 1), dtype=torch.float32
)
self.df_label_reg = torch.tensor(self.df_label_reg.values.reshape(-1), dtype=torch.float32)
def __getitem__(self, index):
sample, label_reg = self.df_feature[index], self.df_label_reg[index]
return sample, label_reg
def __len__(self):
return len(self.df_feature)
def get_stock_loader(df, batch_size, shuffle=True):
train_loader = DataLoader(data_loader(df), batch_size=batch_size, shuffle=shuffle)
return train_loader
def get_index(num_domain=2):
index = []
for i in range(num_domain):
for j in range(i + 1, num_domain + 1):
index.append((i, j))
return index
class AdaRNN(nn.Module):
"""
model_type: 'Boosting', 'AdaRNN'
"""
def __init__(
self,
use_bottleneck=False,
bottleneck_width=256,
n_input=128,
n_hiddens=[64, 64],
n_output=6,
dropout=0.0,
len_seq=9,
model_type="AdaRNN",
trans_loss="mmd",
):
super(AdaRNN, self).__init__()
self.use_bottleneck = use_bottleneck
self.n_input = n_input
self.num_layers = len(n_hiddens)
self.hiddens = n_hiddens
self.n_output = n_output
self.model_type = model_type
self.trans_loss = trans_loss
self.len_seq = len_seq
in_size = self.n_input
features = nn.ModuleList()
for hidden in n_hiddens:
rnn = nn.GRU(input_size=in_size, num_layers=1, hidden_size=hidden, batch_first=True, dropout=dropout)
features.append(rnn)
in_size = hidden
self.features = nn.Sequential(*features)
if use_bottleneck == True: # finance
self.bottleneck = nn.Sequential(
nn.Linear(n_hiddens[-1], bottleneck_width),
nn.Linear(bottleneck_width, bottleneck_width),
nn.BatchNorm1d(bottleneck_width),
nn.ReLU(),
nn.Dropout(),
)
self.bottleneck[0].weight.data.normal_(0, 0.005)
self.bottleneck[0].bias.data.fill_(0.1)
self.bottleneck[1].weight.data.normal_(0, 0.005)
self.bottleneck[1].bias.data.fill_(0.1)
self.fc = nn.Linear(bottleneck_width, n_output)
torch.nn.init.xavier_normal_(self.fc.weight)
else:
self.fc_out = nn.Linear(n_hiddens[-1], self.n_output)
if self.model_type == "AdaRNN":
gate = nn.ModuleList()
for i in range(len(n_hiddens)):
gate_weight = nn.Linear(len_seq * self.hiddens[i] * 2, len_seq)
gate.append(gate_weight)
self.gate = gate
bnlst = nn.ModuleList()
for i in range(len(n_hiddens)):
bnlst.append(nn.BatchNorm1d(len_seq))
self.bn_lst = bnlst
self.softmax = torch.nn.Softmax(dim=0)
self.init_layers()
def init_layers(self):
for i in range(len(self.hiddens)):
self.gate[i].weight.data.normal_(0, 0.05)
self.gate[i].bias.data.fill_(0.0)
def forward_pre_train(self, x, len_win=0):
out = self.gru_features(x)
fea = out[0] # [2N,L,H]
if self.use_bottleneck == True:
fea_bottleneck = self.bottleneck(fea[:, -1, :])
fc_out = self.fc(fea_bottleneck).squeeze()
else:
fc_out = self.fc_out(fea[:, -1, :]).squeeze() # [N,]
out_list_all, out_weight_list = out[1], out[2]
out_list_s, out_list_t = self.get_features(out_list_all)
loss_transfer = torch.zeros((1,)).cuda()
for i in range(len(out_list_s)):
criterion_transder = TransferLoss(loss_type=self.trans_loss, input_dim=out_list_s[i].shape[2])
h_start = 0
for j in range(h_start, self.len_seq, 1):
i_start = j - len_win if j - len_win >= 0 else 0
i_end = j + len_win if j + len_win < self.len_seq else self.len_seq - 1
for k in range(i_start, i_end + 1):
weight = (
out_weight_list[i][j]
if self.model_type == "AdaRNN"
else 1 / (self.len_seq - h_start) * (2 * len_win + 1)
)
loss_transfer = loss_transfer + weight * criterion_transder.compute(
out_list_s[i][:, j, :], out_list_t[i][:, k, :]
)
return fc_out, loss_transfer, out_weight_list
def gru_features(self, x, predict=False):
x_input = x
out = None
out_lis = []
out_weight_list = [] if (self.model_type == "AdaRNN") else None
for i in range(self.num_layers):
out, _ = self.features[i](x_input.float())
x_input = out
out_lis.append(out)
if self.model_type == "AdaRNN" and predict == False:
out_gate = self.process_gate_weight(x_input, i)
out_weight_list.append(out_gate)
return out, out_lis, out_weight_list
def process_gate_weight(self, out, index):
x_s = out[0 : int(out.shape[0] // 2)]
x_t = out[out.shape[0] // 2 : out.shape[0]]
x_all = torch.cat((x_s, x_t), 2)
x_all = x_all.view(x_all.shape[0], -1)
weight = torch.sigmoid(self.bn_lst[index](self.gate[index](x_all.float())))
weight = torch.mean(weight, dim=0)
res = self.softmax(weight).squeeze()
return res
def get_features(self, output_list):
fea_list_src, fea_list_tar = [], []
for fea in output_list:
fea_list_src.append(fea[0 : fea.size(0) // 2])
fea_list_tar.append(fea[fea.size(0) // 2 :])
return fea_list_src, fea_list_tar
# For Boosting-based
def forward_Boosting(self, x, weight_mat=None):
out = self.gru_features(x)
fea = out[0]
if self.use_bottleneck:
fea_bottleneck = self.bottleneck(fea[:, -1, :])
fc_out = self.fc(fea_bottleneck).squeeze()
else:
fc_out = self.fc_out(fea[:, -1, :]).squeeze()
out_list_all = out[1]
out_list_s, out_list_t = self.get_features(out_list_all)
loss_transfer = torch.zeros((1,)).cuda()
if weight_mat is None:
weight = (1.0 / self.len_seq * torch.ones(self.num_layers, self.len_seq)).cuda()
else:
weight = weight_mat
dist_mat = torch.zeros(self.num_layers, self.len_seq).cuda()
for i in range(len(out_list_s)):
criterion_transder = TransferLoss(loss_type=self.trans_loss, input_dim=out_list_s[i].shape[2])
for j in range(self.len_seq):
loss_trans = criterion_transder.compute(out_list_s[i][:, j, :], out_list_t[i][:, j, :])
loss_transfer = loss_transfer + weight[i, j] * loss_trans
dist_mat[i, j] = loss_trans
return fc_out, loss_transfer, dist_mat, weight
# For Boosting-based
def update_weight_Boosting(self, weight_mat, dist_old, dist_new):
epsilon = 1e-5
dist_old = dist_old.detach()
dist_new = dist_new.detach()
ind = dist_new > dist_old + epsilon
weight_mat[ind] = weight_mat[ind] * (1 + torch.sigmoid(dist_new[ind] - dist_old[ind]))
weight_norm = torch.norm(weight_mat, dim=1, p=1)
weight_mat = weight_mat / weight_norm.t().unsqueeze(1).repeat(1, self.len_seq)
return weight_mat
def predict(self, x):
out = self.gru_features(x, predict=True)
fea = out[0]
if self.use_bottleneck == True:
fea_bottleneck = self.bottleneck(fea[:, -1, :])
fc_out = self.fc(fea_bottleneck).squeeze()
else:
fc_out = self.fc_out(fea[:, -1, :]).squeeze()
return fc_out
class TransferLoss(object):
def __init__(self, loss_type="cosine", input_dim=512):
"""
Supported loss_type: mmd(mmd_lin), mmd_rbf, coral, cosine, kl, js, mine, adv
"""
self.loss_type = loss_type
self.input_dim = input_dim
def compute(self, X, Y):
"""Compute adaptation loss
Arguments:
X {tensor} -- source matrix
Y {tensor} -- target matrix
Returns:
[tensor] -- transfer loss
"""
if self.loss_type == "mmd_lin" or self.loss_type == "mmd":
mmdloss = MMD_loss(kernel_type="linear")
loss = mmdloss(X, Y)
elif self.loss_type == "coral":
loss = CORAL(X, Y)
elif self.loss_type == "cosine" or self.loss_type == "cos":
loss = 1 - cosine(X, Y)
elif self.loss_type == "kl":
loss = kl_div(X, Y)
elif self.loss_type == "js":
loss = js(X, Y)
elif self.loss_type == "mine":
mine_model = Mine_estimator(input_dim=self.input_dim, hidden_dim=60).cuda()
loss = mine_model(X, Y)
elif self.loss_type == "adv":
loss = adv(X, Y, input_dim=self.input_dim, hidden_dim=32)
elif self.loss_type == "mmd_rbf":
mmdloss = MMD_loss(kernel_type="rbf")
loss = mmdloss(X, Y)
elif self.loss_type == "pairwise":
pair_mat = pairwise_dist(X, Y)
loss = torch.norm(pair_mat)
return loss
def cosine(source, target):
source, target = source.mean(), target.mean()
cos = nn.CosineSimilarity(dim=0)
loss = cos(source, target)
return loss.mean()
class ReverseLayerF(Function):
@staticmethod
def forward(ctx, x, alpha):
ctx.alpha = alpha
return x.view_as(x)
@staticmethod
def backward(ctx, grad_output):
output = grad_output.neg() * ctx.alpha
return output, None
class Discriminator(nn.Module):
def __init__(self, input_dim=256, hidden_dim=256):
super(Discriminator, self).__init__()
self.input_dim = input_dim
self.hidden_dim = hidden_dim
self.dis1 = nn.Linear(input_dim, hidden_dim)
self.dis2 = nn.Linear(hidden_dim, 1)
def forward(self, x):
x = F.relu(self.dis1(x))
x = self.dis2(x)
x = torch.sigmoid(x)
return x
def adv(source, target, input_dim=256, hidden_dim=512):
domain_loss = nn.BCELoss()
# !!! Pay attention to .cuda !!!
adv_net = Discriminator(input_dim, hidden_dim).cuda()
domain_src = torch.ones(len(source)).cuda()
domain_tar = torch.zeros(len(target)).cuda()
domain_src, domain_tar = domain_src.view(domain_src.shape[0], 1), domain_tar.view(domain_tar.shape[0], 1)
reverse_src = ReverseLayerF.apply(source, 1)
reverse_tar = ReverseLayerF.apply(target, 1)
pred_src = adv_net(reverse_src)
pred_tar = adv_net(reverse_tar)
loss_s, loss_t = domain_loss(pred_src, domain_src), domain_loss(pred_tar, domain_tar)
loss = loss_s + loss_t
return loss
def CORAL(source, target):
d = source.size(1)
ns, nt = source.size(0), target.size(0)
# source covariance
tmp_s = torch.ones((1, ns)).cuda() @ source
cs = (source.t() @ source - (tmp_s.t() @ tmp_s) / ns) / (ns - 1)
# target covariance
tmp_t = torch.ones((1, nt)).cuda() @ target
ct = (target.t() @ target - (tmp_t.t() @ tmp_t) / nt) / (nt - 1)
# frobenius norm
loss = (cs - ct).pow(2).sum()
loss = loss / (4 * d * d)
return loss
class MMD_loss(nn.Module):
def __init__(self, kernel_type="linear", kernel_mul=2.0, kernel_num=5):
super(MMD_loss, self).__init__()
self.kernel_num = kernel_num
self.kernel_mul = kernel_mul
self.fix_sigma = None
self.kernel_type = kernel_type
def guassian_kernel(self, source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
n_samples = int(source.size()[0]) + int(target.size()[0])
total = torch.cat([source, target], dim=0)
total0 = total.unsqueeze(0).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))
total1 = total.unsqueeze(1).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))
L2_distance = ((total0 - total1) ** 2).sum(2)
if fix_sigma:
bandwidth = fix_sigma
else:
bandwidth = torch.sum(L2_distance.data) / (n_samples ** 2 - n_samples)
bandwidth /= kernel_mul ** (kernel_num // 2)
bandwidth_list = [bandwidth * (kernel_mul ** i) for i in range(kernel_num)]
kernel_val = [torch.exp(-L2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list]
return sum(kernel_val)
def linear_mmd(self, X, Y):
delta = X.mean(axis=0) - Y.mean(axis=0)
loss = delta.dot(delta.T)
return loss
def forward(self, source, target):
if self.kernel_type == "linear":
return self.linear_mmd(source, target)
elif self.kernel_type == "rbf":
batch_size = int(source.size()[0])
kernels = self.guassian_kernel(
source, target, kernel_mul=self.kernel_mul, kernel_num=self.kernel_num, fix_sigma=self.fix_sigma
)
with torch.no_grad():
XX = torch.mean(kernels[:batch_size, :batch_size])
YY = torch.mean(kernels[batch_size:, batch_size:])
XY = torch.mean(kernels[:batch_size, batch_size:])
YX = torch.mean(kernels[batch_size:, :batch_size])
loss = torch.mean(XX + YY - XY - YX)
return loss
class Mine_estimator(nn.Module):
def __init__(self, input_dim=2048, hidden_dim=512):
super(Mine_estimator, self).__init__()
self.mine_model = Mine(input_dim, hidden_dim)
def forward(self, X, Y):
Y_shffle = Y[torch.randperm(len(Y))]
loss_joint = self.mine_model(X, Y)
loss_marginal = self.mine_model(X, Y_shffle)
ret = torch.mean(loss_joint) - torch.log(torch.mean(torch.exp(loss_marginal)))
loss = -ret
return loss
class Mine(nn.Module):
def __init__(self, input_dim=2048, hidden_dim=512):
super(Mine, self).__init__()
self.fc1_x = nn.Linear(input_dim, hidden_dim)
self.fc1_y = nn.Linear(input_dim, hidden_dim)
self.fc2 = nn.Linear(hidden_dim, 1)
def forward(self, x, y):
h1 = F.leaky_relu(self.fc1_x(x) + self.fc1_y(y))
h2 = self.fc2(h1)
return h2
def pairwise_dist(X, Y):
n, d = X.shape
m, _ = Y.shape
assert d == Y.shape[1]
a = X.unsqueeze(1).expand(n, m, d)
b = Y.unsqueeze(0).expand(n, m, d)
return torch.pow(a - b, 2).sum(2)
def pairwise_dist_np(X, Y):
n, d = X.shape
m, _ = Y.shape
assert d == Y.shape[1]
a = np.expand_dims(X, 1)
b = np.expand_dims(Y, 0)
a = np.tile(a, (1, m, 1))
b = np.tile(b, (n, 1, 1))
return np.power(a - b, 2).sum(2)
def pa(X, Y):
XY = np.dot(X, Y.T)
XX = np.sum(np.square(X), axis=1)
XX = np.transpose([XX])
YY = np.sum(np.square(Y), axis=1)
dist = XX + YY - 2 * XY
return dist
def kl_div(source, target):
if len(source) < len(target):
target = target[: len(source)]
elif len(source) > len(target):
source = source[: len(target)]
criterion = nn.KLDivLoss(reduction="batchmean")
loss = criterion(source.log(), target)
return loss
def js(source, target):
if len(source) < len(target):
target = target[: len(source)]
elif len(source) > len(target):
source = source[: len(target)]
M = 0.5 * (source + target)
loss_1, loss_2 = kl_div(source, M), kl_div(target, M)
return 0.5 * (loss_1 + loss_2)

Some files were not shown because too many files have changed in this diff Show More