Merge nested main (#597)
* MVP for Indian Stocks in qlib using yahooquery * cleaned with black * cleaned with black * add YahooNormalizeIN and YahooNormalizeIN1d * cleaned the code * added 1min for IN and also updated readme * update comments * fix comments * recorder support upload both raw file and directory * fix comments * Update README.md * Fix docs of QlibRecorder * sort index after loader (#538) make sure the fetch method is based on a index-sorted pd.DataFrame * refactor online serving rolling api * refactor TRA * format by black * fix horizon * fix TRA when use single head * clean up * improve pretrain * update README * fix tra when logdir is None * fix tra when logdir is None * Update strategy.py * Update README.md * Update README.md * Conda Suggestion * code standard docs * Update ensemble.py (#560) * Fix CI Bug (#575) Co-authored-by: yuxwang <anduinnn@foxmail.com> * Update gen.py (#576) * Fix multi-process loop calls (#574) * check lexsort in the 'lazy_sort_index' function (#566) * check lexsort * check lexsort * lexsort comment * lexsort comment * Delete .DS_Store * Update README.md * bug fix & use oracle transport pretrain * mend * Add `backend_freq_config` parameter, support multi-freq uri * Add sample_config to QlibDataLoader, support multi-freq * add multi-freq example * get_cls_kwargs renamed get_callable_kwargs * support multi-freq uri * Add inst_processors to D.features * Fix typo * Fix the index type of the multi-freq example * Fix duplicate mlflow directories in tests * Add DataPathManager to QlibConfig && modify inst_processors to supports list only * Modify the default value in the multi_freq example * Modify client-server mode and dataset-cache to disable inst_processor * Add wheel package to github CI * fix comment * Update FAQ.rst * Update README.md Fix wrong link * Update the docs of TaskManager (#586) * Update manage.py * update yaml * update run_all_model * Modify the Feature to be case sensitive (#589) * update README * remove verbose * fix spell bug * fix typos (#592) * Update Release Note * fix portfolio bug * Add calendar support for resample * add freq kwargs * test.yml: Remove redundant code (#595) * Supporting shared processor (#596) * Supporting shared processor * fix readonly reverse bug * remove pytests dependency * with fit bug * fix parameter error * fix comments * Fix undefined names in Python code (#599) * Update pytorch_tabnet.py $ `flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics` ``` ./qlib/qlib/contrib/model/pytorch_tabnet.py:567:38: F821 undefined name 'inp' self.independ.append(GLU(inp, out_dim, vbs=vbs)) ^ ./qlib/examples/model_rolling/task_manager_rolling.py:75:18: F821 undefined name 'task_train' run_task(task_train, self.task_pool, experiment_name=self.experiment_name) ^ 2 F821 undefined name 'task_train' 2 ``` * Fix undefined names in Python code * from qlib.model.trainer import task_train * update seed * fix some docstring * add comments * Fix SimpleDatasetCache * Update setup.py updated classifiers * Update setup.py change to matplotlib==3.3 * Update python-publish.yml added python 3.9 * updategrade version number * Update model list * fix the type of filter_pipe * fix comment * fix record_temp * update cvxpy version * Update code_standard.rst (#587) * Update code_standard.rst * Update docs/developer/code_standard.rst Co-authored-by: you-n-g <you-n-g@users.noreply.github.com> Co-authored-by: you-n-g <you-n-g@users.noreply.github.com> * Add file lock for MLflowExpManager (#619) * fix torch version * Share version number (#620) * Update initialization.rst (#622) * Update initialization.rst * Update docs/start/initialization.rst Co-authored-by: you-n-g <you-n-g@users.noreply.github.com> * Update docs/start/initialization.rst Co-authored-by: you-n-g <you-n-g@users.noreply.github.com> Co-authored-by: you-n-g <you-n-g@users.noreply.github.com> * fix bugs for running previous exmaple * fix deal amount bug * update change doc (#623) * Add files via upload * Update README.md * Update README.md * Update README.md * Delete change doc.gif * Add files via upload * Update README.md * Delete change doc.gif * Add files via upload * Delete change doc.gif * Add files via upload * Update README.md Co-authored-by: you-n-g <you-n-g@users.noreply.github.com> Co-authored-by: you-n-g <you-n-g@users.noreply.github.com> * update doc * simplify run all model * fix run all model bug * Fix Models (#483) * fix gat dataset * fix tft model * Update tft.py * Fix tft.py Co-authored-by: Pengrong Zhu <zhu.pengrong@foxmail.com> * type and skip empty exp * fix model yaml config * fix tft import bug * skip empty result * fix model and yaml bug * fix wrong generate parameter * Modify multi-freq example (#626) * modify the example of multi-freq * add Copyright * add a comment to average_ops.py * modify the example of multi-freq * add comment to multi_freq_handler.py * add the Ref expression description to multi_freq_handler.py * add expression description to multi_freq_handler.py * update images * fix workflow and update framework Co-authored-by: Gaurav <2796gaurav@gmail.com> Co-authored-by: 2796gaurav <17353992+2796gaurav@users.noreply.github.com> Co-authored-by: bxdd <bxd98@126.com> Co-authored-by: Young <afe.young@gmail.com> Co-authored-by: you-n-g <you-n-g@users.noreply.github.com> Co-authored-by: Dong Zhou <Zhou.Dong@microsoft.com> Co-authored-by: ZhangTP1996 <ztp18@mails.tsinghua.edu.cn> Co-authored-by: demon143 <59681577+demon143@users.noreply.github.com> Co-authored-by: Wangwuyi123 <51237097+Wangwuyi123@users.noreply.github.com> Co-authored-by: yuxwang <anduinnn@foxmail.com> Co-authored-by: Pengrong Zhu <zhu.pengrong@foxmail.com> Co-authored-by: Mark Zhao <50850474+markzhao98@users.noreply.github.com> Co-authored-by: cslwqxx <cslwqxx@users.noreply.github.com> Co-authored-by: Dong Zhou <evanzd@users.noreply.github.com> Co-authored-by: SaintMalik <37118134+saintmalik@users.noreply.github.com> Co-authored-by: Christian Clauss <cclauss@me.com> Co-authored-by: Anurag Kumar <mailanu98@gmail.com> Co-authored-by: demon143 <785696300@qq.com>
2
.github/workflows/python-publish.yml
vendored
@@ -13,7 +13,7 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
os: [windows-latest, macos-latest]
|
||||
python-version: [3.6, 3.7, 3.8]
|
||||
python-version: [3.6, 3.7, 3.8, 3.9]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
|
||||
93
.github/workflows/test.yml
vendored
@@ -13,7 +13,7 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
os: [windows-latest, ubuntu-16.04, ubuntu-18.04, ubuntu-20.04]
|
||||
python-version: [3.6, 3.7, 3.8, 3.9]
|
||||
python-version: [3.6, 3.7, 3.8]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
@@ -25,63 +25,29 @@ jobs:
|
||||
|
||||
- name: Lint with Black
|
||||
run: |
|
||||
cd ..
|
||||
if [ "$RUNNER_OS" == "Windows" ]; then
|
||||
$CONDA\\python.exe -m pip install black
|
||||
$CONDA\\python.exe -m black qlib -l 120 --check --diff
|
||||
else
|
||||
sudo $CONDA/bin/python -m pip install black
|
||||
$CONDA/bin/python -m black qlib -l 120 --check --diff
|
||||
fi
|
||||
shell: bash
|
||||
pip install --upgrade pip
|
||||
pip install black wheel
|
||||
black qlib -l 120 --check --diff
|
||||
|
||||
# Test Qlib installed with pip
|
||||
# - name: Install Qlib with pip
|
||||
# run: |
|
||||
# if [ "$RUNNER_OS" == "Windows" ]; then
|
||||
# $CONDA\\python.exe -m pip install numpy==1.19.5
|
||||
# $CONDA\\python.exe -m pip install pyqlib --ignore-installed ruamel.yaml numpy --user
|
||||
# else
|
||||
# sudo $CONDA/bin/python -m pip install numpy==1.19.5
|
||||
# sudo $CONDA/bin/python -m pip install pyqlib --ignore-installed ruamel.yaml numpy
|
||||
# fi
|
||||
# shell: bash
|
||||
- name: Install Qlib with pip
|
||||
run: |
|
||||
pip install numpy==1.19.5 ruamel.yaml
|
||||
pip install pyqlib --ignore-installed
|
||||
|
||||
# - name: Test data downloads
|
||||
# run: |
|
||||
# if [ "$RUNNER_OS" == "Windows" ]; then
|
||||
# $CONDA\\python.exe scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
|
||||
# else
|
||||
# $CONDA/bin/python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
|
||||
# fi
|
||||
# shell: bash
|
||||
- name: Test data downloads
|
||||
run: |
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
|
||||
|
||||
# - name: Test workflow by config (install from pip)
|
||||
# run: |
|
||||
# if [ "$RUNNER_OS" == "Windows" ]; then
|
||||
# $CONDA\\python.exe qlib\\workflow\\cli.py examples\\benchmarks\\LightGBM\\workflow_config_lightgbm_Alpha158.yaml
|
||||
# $CONDA\\python.exe -m pip uninstall -y pyqlib
|
||||
# else
|
||||
# $CONDA/bin/python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
# sudo $CONDA/bin/python -m pip uninstall -y pyqlib
|
||||
# fi
|
||||
# shell: bash
|
||||
- name: Test workflow by config (install from pip)
|
||||
run: |
|
||||
python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
python -m pip uninstall -y pyqlib
|
||||
|
||||
# Test Qlib installed from source
|
||||
- name: Install Qlib from source
|
||||
run: |
|
||||
if [ "$RUNNER_OS" == "Windows" ]; then
|
||||
$CONDA\\python.exe -m pip install --upgrade cython
|
||||
$CONDA\\python.exe -m pip install numpy jupyter jupyter_contrib_nbextensions
|
||||
$CONDA\\python.exe -m pip install -U scipy scikit-learn # installing without this line will cause errors on GitHub Actions, while instsalling locally won't
|
||||
$CONDA\\python.exe setup.py install
|
||||
else
|
||||
sudo $CONDA/bin/python -m pip install --upgrade cython
|
||||
sudo $CONDA/bin/python -m pip install numpy jupyter jupyter_contrib_nbextensions
|
||||
sudo $CONDA/bin/python -m pip install -U scipy scikit-learn # installing without this line will cause errors on GitHub Actions, while instsalling locally won't
|
||||
sudo $CONDA/bin/python setup.py install
|
||||
fi
|
||||
shell: bash
|
||||
pip install --upgrade cython jupyter jupyter_contrib_nbextensions numpy scipy scikit-learn # installing without this line will cause errors on GitHub Actions, while instsalling locally won't
|
||||
pip install -e .
|
||||
|
||||
- name: Test data downloads
|
||||
run: |
|
||||
@@ -94,30 +60,15 @@ jobs:
|
||||
|
||||
- name: Install test dependencies
|
||||
run: |
|
||||
if [ "$RUNNER_OS" == "Windows" ]; then
|
||||
$CONDA\\python.exe -m pip install --upgrade pip
|
||||
$CONDA\\python.exe -m pip install black pytest
|
||||
else
|
||||
sudo $CONDA/bin/python -m pip install --upgrade pip
|
||||
sudo $CONDA/bin/python -m pip install black pytest
|
||||
fi
|
||||
shell: bash
|
||||
pip install --upgrade pip
|
||||
pip install black pytest
|
||||
|
||||
- name: Unit tests with Pytest
|
||||
run: |
|
||||
cd tests
|
||||
if [ "$RUNNER_OS" == "Windows" ]; then
|
||||
$CONDA\\python.exe -m pytest . --durations=0
|
||||
else
|
||||
$CONDA/bin/python -m pytest . --durations=0
|
||||
fi
|
||||
shell: bash
|
||||
python -m pytest . --durations=10
|
||||
|
||||
- name: Test workflow by config (install from source)
|
||||
run: |
|
||||
if [ "$RUNNER_OS" == "Windows" ]; then
|
||||
$CONDA\\python.exe qlib\\workflow\\cli.py examples\\benchmarks\\LightGBM\\workflow_config_lightgbm_Alpha158.yaml
|
||||
else
|
||||
$CONDA/bin/python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
fi
|
||||
shell: bash
|
||||
python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
|
||||
|
||||
58
.github/workflows/test_macos.yml
vendored
@@ -13,7 +13,7 @@ jobs:
|
||||
runs-on: macos-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: [3.6, 3.7, 3.8, 3.9]
|
||||
python-version: [3.6, 3.7, 3.8]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
@@ -26,52 +26,46 @@ jobs:
|
||||
- name: Lint with Black
|
||||
run: |
|
||||
cd ..
|
||||
sudo $CONDA/bin/python -m pip install black
|
||||
$CONDA/bin/python -m black qlib -l 120 --check --diff
|
||||
|
||||
python -m pip install pip --upgrade
|
||||
python -m pip install wheel --upgrade
|
||||
python -m pip install black
|
||||
python -m black qlib -l 120 --check --diff
|
||||
# Test Qlib installed with pip
|
||||
# - name: Install Qlib with pip
|
||||
# run: |
|
||||
# sudo $CONDA/bin/python -m pip install numpy==1.19.5
|
||||
# sudo $CONDA/bin/python -m pip install pyqlib --ignore-installed ruamel.yaml numpy
|
||||
|
||||
- name: Install Qlib with pip
|
||||
run: |
|
||||
python -m pip install numpy==1.19.5
|
||||
python -m pip install pyqlib --ignore-installed ruamel.yaml numpy
|
||||
|
||||
- name: Install Lightgbm for MacOS
|
||||
run: |
|
||||
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Microsoft/qlib/main/.github/brew_install.sh)"
|
||||
HOMEBREW_NO_AUTO_UPDATE=1 brew install lightgbm
|
||||
|
||||
# - name: Test data downloads
|
||||
# run: |
|
||||
# $CONDA/bin/python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
|
||||
|
||||
# - name: Test workflow by config (install from pip)
|
||||
# run: |
|
||||
# $CONDA/bin/python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
# sudo $CONDA/bin/python -m pip uninstall -y pyqlib
|
||||
|
||||
- name: Test data downloads
|
||||
run: |
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
|
||||
- name: Test workflow by config (install from pip)
|
||||
run: |
|
||||
python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
python -m pip uninstall -y pyqlib
|
||||
# Test Qlib installed from source
|
||||
- name: Install Qlib from source
|
||||
run: |
|
||||
sudo $CONDA/bin/python -m pip install --upgrade cython
|
||||
sudo $CONDA/bin/python -m pip install numpy jupyter jupyter_contrib_nbextensions
|
||||
sudo $CONDA/bin/python -m pip install -U scipy scikit-learn # installing without this line will cause errors on GitHub Actions, while instsalling locally won't
|
||||
sudo $CONDA/bin/python setup.py install
|
||||
|
||||
- name: Test data downloads
|
||||
run: |
|
||||
$CONDA/bin/python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
|
||||
python -m pip install --upgrade cython
|
||||
python -m pip install numpy jupyter jupyter_contrib_nbextensions
|
||||
python -m pip install -U scipy scikit-learn # installing without this line will cause errors on GitHub Actions, while instsalling locally won't
|
||||
python setup.py install
|
||||
|
||||
- name: Install test dependencies
|
||||
run: |
|
||||
sudo $CONDA/bin/python -m pip install --upgrade pip
|
||||
sudo $CONDA/bin/python -m pip install -U pyopenssl idna
|
||||
sudo $CONDA/bin/python -m pip install black pytest
|
||||
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install -U pyopenssl idna
|
||||
python -m pip install black pytest
|
||||
- name: Unit tests with Pytest
|
||||
run: |
|
||||
cd tests
|
||||
$CONDA/bin/python -m pytest . --durations=0
|
||||
|
||||
python -m pytest . --durations=0
|
||||
- name: Test workflow by config (install from source)
|
||||
run: |
|
||||
$CONDA/bin/python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
1
.gitignore
vendored
@@ -20,6 +20,7 @@ dist/
|
||||
.nvimrc
|
||||
.vscode
|
||||
|
||||
qlib/VERSION.txt
|
||||
qlib/data/_libs/expanding.cpp
|
||||
qlib/data/_libs/rolling.cpp
|
||||
examples/estimator/estimator_example/
|
||||
|
||||
1
MANIFEST.in
Normal file
@@ -0,0 +1 @@
|
||||
include qlib/VERSION.txt
|
||||
43
README.md
@@ -11,6 +11,7 @@
|
||||
Recent released features
|
||||
| Feature | Status |
|
||||
| -- | ------ |
|
||||
|Temporal Routing Adaptor (TRA) | [Released](https://github.com/microsoft/qlib/pull/531) on July 30, 2021 |
|
||||
| Transformer & Localformer | [Released](https://github.com/microsoft/qlib/pull/508) on July 22, 2021 |
|
||||
| Release Qlib v0.7.0 | [Released](https://github.com/microsoft/qlib/releases/tag/v0.7.0) on July 12, 2021 |
|
||||
| TCTS Model | [Released](https://github.com/microsoft/qlib/pull/491) on July 1, 2021 |
|
||||
@@ -23,10 +24,8 @@ Recent released features
|
||||
|
||||
Features released before 2021 are not listed here.
|
||||
|
||||
|
||||
|
||||
<p align="center">
|
||||
<img src="http://fintech.msra.cn/images_v060/logo/1.png" />
|
||||
<img src="http://fintech.msra.cn/images_v070/logo/1.png" />
|
||||
</p>
|
||||
|
||||
|
||||
@@ -45,7 +44,7 @@ For more details, please refer to our paper ["Qlib: An AI-oriented Quantitative
|
||||
- [Data Preparation](#data-preparation)
|
||||
- [Auto Quant Research Workflow](#auto-quant-research-workflow)
|
||||
- [Building Customized Quant Research Workflow by Code](#building-customized-quant-research-workflow-by-code)
|
||||
- [**Quant Model Zoo**](#quant-model-zoo)
|
||||
- [**Quant Model(Paper) Zoo**](#quant-model-paper-zoo)
|
||||
- [Run a single model](#run-a-single-model)
|
||||
- [Run multiple models](#run-multiple-models)
|
||||
- [**Quant Dataset Zoo**](#quant-dataset-zoo)
|
||||
@@ -71,7 +70,7 @@ Your feedbacks about the features are very important.
|
||||
# Framework of Qlib
|
||||
|
||||
<div style="align: center">
|
||||
<img src="http://fintech.msra.cn/images_v060/framework.png?v=0.2" />
|
||||
<img src="docs/_static/img/framework.svg" />
|
||||
</div>
|
||||
|
||||
|
||||
@@ -107,8 +106,9 @@ This table demonstrates the supported Python version of `Qlib`:
|
||||
| Python 3.9 | :x: | :heavy_check_mark: | :x: |
|
||||
|
||||
**Note**:
|
||||
1. **Conda** is suggested for managing your Python environment.
|
||||
1. Please pay attention that installing cython in Python 3.6 will raise some error when installing ``Qlib`` from source. If users use Python 3.6 on their machines, it is recommended to *upgrade* Python to version 3.7 or use `conda`'s Python to install ``Qlib`` from source.
|
||||
2. For Python 3.9, `Qlib` supports running workflows such as training models, doing backtest and plot most of the related figures (those included in [notebook](examples/workflow_by_code.ipynb)). However, plotting for the *model performance* is not supported for now and we will fix this when the dependent packages are upgraded in the future.
|
||||
1. For Python 3.9, `Qlib` supports running workflows such as training models, doing backtest and plot most of the related figures (those included in [notebook](examples/workflow_by_code.ipynb)). However, plotting for the *model performance* is not supported for now and we will fix this when the dependent packages are upgraded in the future.
|
||||
|
||||
### Install with pip
|
||||
Users can easily install ``Qlib`` by pip according to the following command.
|
||||
@@ -162,7 +162,7 @@ Users could create the same dataset with it.
|
||||
*Please pay **ATTENTION** that the data is collected from [Yahoo Finance](https://finance.yahoo.com/lookup), and the data might not be perfect.
|
||||
We recommend users to prepare their own data if they have a high-quality dataset. For more information, users can refer to the [related document](https://qlib.readthedocs.io/en/latest/component/data.html#converting-csv-format-into-qlib-format)*.
|
||||
|
||||
### Automatic update of daily frequency data(from yahoo finance)
|
||||
### Automatic update of daily frequency data (from yahoo finance)
|
||||
> It is recommended that users update the data manually once (--trading_date 2021-05-25) and then set it to update automatically.
|
||||
|
||||
> For more information refer to: [yahoo collector](https://github.com/microsoft/qlib/tree/main/scripts/data_collector/yahoo#automatic-update-of-daily-frequency-datafrom-yahoo-finance)
|
||||
@@ -247,19 +247,19 @@ Qlib provides a tool named `qrun` to run the whole workflow automatically (inclu
|
||||
2. Graphical Reports Analysis: Run `examples/workflow_by_code.ipynb` with `jupyter notebook` to get graphical reports
|
||||
- Forecasting signal (model prediction) analysis
|
||||
- Cumulative Return of groups
|
||||

|
||||

|
||||
- Return distribution
|
||||

|
||||

|
||||
- Information Coefficient (IC)
|
||||

|
||||

|
||||

|
||||

|
||||

|
||||

|
||||
- Auto Correlation of forecasting signal (model prediction)
|
||||

|
||||

|
||||
|
||||
- Portfolio analysis
|
||||
- Backtest return
|
||||

|
||||

|
||||
<!--
|
||||
- Score IC
|
||||

|
||||
@@ -276,7 +276,7 @@ Qlib provides a tool named `qrun` to run the whole workflow automatically (inclu
|
||||
The automatic workflow may not suit the research workflow of all Quant researchers. To support a flexible Quant research workflow, Qlib also provides a modularized interface to allow researchers to build their own workflow by code. [Here](examples/workflow_by_code.ipynb) is a demo for customized Quant research workflow by code.
|
||||
|
||||
|
||||
# [Quant Model Zoo](examples/benchmarks)
|
||||
# [Quant Model (Paper) Zoo](examples/benchmarks)
|
||||
|
||||
Here is a list of models built on `Qlib`.
|
||||
- [GBDT based on XGBoost (Tianqi Chen, et al. KDD 2016)](qlib/contrib/model/xgboost.py)
|
||||
@@ -294,6 +294,7 @@ Here is a list of models built on `Qlib`.
|
||||
- [TCTS based on pytorch (Xueqing Wu, et al. ICML 2021)](qlib/contrib/model/pytorch_tcts.py)
|
||||
- [Transformer based on pytorch (Ashish Vaswani, et al. NeurIPS 2017)](qlib/contrib/model/pytorch_transformer.py)
|
||||
- [Localformer based on pytorch (Juyong Jiang, et al.)](qlib/contrib/model/pytorch_localformer.py)
|
||||
- [TRA based on pytorch (Hengxu, Dong, et al. KDD 2021)](qlib/contrib/model/pytorch_tra.py)
|
||||
|
||||
Your PR of new Quant models is highly welcomed.
|
||||
|
||||
@@ -307,9 +308,10 @@ All the models listed above are runnable with ``Qlib``. Users can find the confi
|
||||
- Users can create a `workflow_by_code` python script based on the [one](examples/workflow_by_code.py) listed in the `examples` folder.
|
||||
|
||||
- Users can use the script [`run_all_model.py`](examples/run_all_model.py) listed in the `examples` folder to run a model. Here is an example of the specific shell command to be used: `python run_all_model.py --models=lightgbm`, where the `--models` arguments can take any number of models listed above(the available models can be found in [benchmarks](examples/benchmarks/)). For more use cases, please refer to the file's [docstrings](examples/run_all_model.py).
|
||||
- **NOTE**: Each baseline has different environment dependencies, please make sure that your python version aligns with the requirements(e.g. TFT only supports Python 3.6~3.7 due to the limitation of `tensorflow==1.15.0`)
|
||||
|
||||
## Run multiple models
|
||||
`Qlib` also provides a script [`run_all_model.py`](examples/run_all_model.py) which can run multiple models for several iterations. (**Note**: the script only support *Linux* for now. Other OS will be supported in the future. Besides, it doesn't support parrallel running the same model for multiple times as well, and this will be fixed in the future development too.)
|
||||
`Qlib` also provides a script [`run_all_model.py`](examples/run_all_model.py) which can run multiple models for several iterations. (**Note**: the script only support *Linux* for now. Other OS will be supported in the future. Besides, it doesn't support parallel running the same model for multiple times as well, and this will be fixed in the future development too.)
|
||||
|
||||
The script will create a unique virtual environment for each model, and delete the environments after training. Thus, only experiment results such as `IC` and `backtest` results will be generated and stored.
|
||||
|
||||
@@ -374,9 +376,7 @@ Such overheads greatly slow down the data loading process.
|
||||
Qlib data are stored in a compact format, which is efficient to be combined into arrays for scientific computation.
|
||||
|
||||
# Related Reports
|
||||
- [【华泰金工林晓明团队】图神经网络选股与Qlib实践——华泰人工智能系列之四十二](https://mp.weixin.qq.com/s/w5fDB6oAv9dO6vlhf1kmhA)
|
||||
- [Guide To Qlib: Microsoft’s AI Investment Platform](https://analyticsindiamag.com/qlib/)
|
||||
- [【华泰金工林晓明团队】微软AI量化投资平台Qlib体验——华泰人工智能系列之四十](https://mp.weixin.qq.com/s/Brcd7im4NibJOJzZfMn6tQ)
|
||||
- [微软也搞AI量化平台?还是开源的!](https://mp.weixin.qq.com/s/47bP5YwxfTp2uTHjUBzJQQ)
|
||||
- [微矿Qlib:业内首个AI量化投资开源平台](https://mp.weixin.qq.com/s/vsJv7lsgjEi-ALYUz4CvtQ)
|
||||
|
||||
@@ -389,7 +389,7 @@ Qlib data are stored in a compact format, which is efficient to be combined into
|
||||
Join IM discussion groups:
|
||||
|[Gitter](https://gitter.im/Microsoft/qlib)|
|
||||
|----|
|
||||
||
|
||||
||
|
||||
|
||||
# Contributing
|
||||
|
||||
@@ -397,6 +397,11 @@ This project welcomes contributions and suggestions.
|
||||
**Here are some
|
||||
[code standards](docs/developer/code_standard.rst) when you submit a pull request.**
|
||||
|
||||
If you want to contribute to Qlib's document, you can follow the steps in the figure below.
|
||||
<p align="center">
|
||||
<img src="https://github.com/demon143/qlib/blob/main/docs/_static/img/change%20doc.gif" />
|
||||
</p>
|
||||
|
||||
|
||||
Most contributions require you to agree to a
|
||||
Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
|
||||
|
||||
1
VERSION.txt
Normal file
@@ -0,0 +1 @@
|
||||
0.7.1.99
|
||||
@@ -98,3 +98,56 @@ Also, feel free to post a new issue in our GitHub repository. We always check ea
|
||||
python setup.py build_ext --inplace
|
||||
|
||||
- If the error occurs when importing ``qlib`` package with command ``python`` , users need to change the running directory to ensure that the script does not run in the project directory.
|
||||
|
||||
|
||||
4. BadNamespaceError: / is not a connected namespace
|
||||
------------------------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
File "qlib_online.py", line 35, in <module>
|
||||
cal = D.calendar()
|
||||
File "e:\code\python\microsoft\qlib_latest\qlib\qlib\data\data.py", line 973, in calendar
|
||||
return Cal.calendar(start_time, end_time, freq, future=future)
|
||||
File "e:\code\python\microsoft\qlib_latest\qlib\qlib\data\data.py", line 798, in calendar
|
||||
self.conn.send_request(
|
||||
File "e:\code\python\microsoft\qlib_latest\qlib\qlib\data\client.py", line 101, in send_request
|
||||
self.sio.emit(request_type + "_request", request_content)
|
||||
File "G:\apps\miniconda\envs\qlib\lib\site-packages\python_socketio-5.3.0-py3.8.egg\socketio\client.py", line 369, in emit
|
||||
raise exceptions.BadNamespaceError(
|
||||
BadNamespaceError: / is not a connected namespace.
|
||||
|
||||
- The version of ``python-socketio`` in qlib needs to be the same as the version of ``python-socketio`` in qlib-server:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install -U python-socketio==<qlib-server python-socketio version>
|
||||
|
||||
|
||||
5. TypeError: send() got an unexpected keyword argument 'binary'
|
||||
------------------------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
File "qlib_online.py", line 35, in <module>
|
||||
cal = D.calendar()
|
||||
File "e:\code\python\microsoft\qlib_latest\qlib\qlib\data\data.py", line 973, in calendar
|
||||
return Cal.calendar(start_time, end_time, freq, future=future)
|
||||
File "e:\code\python\microsoft\qlib_latest\qlib\qlib\data\data.py", line 798, in calendar
|
||||
self.conn.send_request(
|
||||
File "e:\code\python\microsoft\qlib_latest\qlib\qlib\data\client.py", line 101, in send_request
|
||||
self.sio.emit(request_type + "_request", request_content)
|
||||
File "G:\apps\miniconda\envs\qlib\lib\site-packages\socketio\client.py", line 263, in emit
|
||||
self._send_packet(packet.Packet(packet.EVENT, namespace=namespace,
|
||||
File "G:\apps\miniconda\envs\qlib\lib\site-packages\socketio\client.py", line 339, in _send_packet
|
||||
self.eio.send(ep, binary=binary)
|
||||
TypeError: send() got an unexpected keyword argument 'binary'
|
||||
|
||||
|
||||
- The ``python-engineio`` version needs to be compatible with the ``python-socketio`` version, reference: https://github.com/miguelgrinberg/python-socketio#version-compatibility
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install -U python-engineio==<compatible python-socketio version>
|
||||
# or
|
||||
pip install -U python-socketio==3.1.2 python-engineio==3.13.2
|
||||
|
||||
BIN
docs/_static/img/analysis/analysis_model_IC.png
vendored
|
Before Width: | Height: | Size: 33 KiB After Width: | Height: | Size: 37 KiB |
BIN
docs/_static/img/analysis/analysis_model_NDQ.png
vendored
|
Before Width: | Height: | Size: 23 KiB After Width: | Height: | Size: 23 KiB |
|
Before Width: | Height: | Size: 47 KiB After Width: | Height: | Size: 44 KiB |
|
Before Width: | Height: | Size: 63 KiB After Width: | Height: | Size: 53 KiB |
|
Before Width: | Height: | Size: 16 KiB After Width: | Height: | Size: 16 KiB |
|
Before Width: | Height: | Size: 16 KiB After Width: | Height: | Size: 15 KiB |
BIN
docs/_static/img/analysis/report.png
vendored
|
Before Width: | Height: | Size: 160 KiB After Width: | Height: | Size: 144 KiB |
|
Before Width: | Height: | Size: 46 KiB After Width: | Height: | Size: 45 KiB |
BIN
docs/_static/img/analysis/risk_analysis_bar.png
vendored
|
Before Width: | Height: | Size: 13 KiB After Width: | Height: | Size: 10 KiB |
|
Before Width: | Height: | Size: 54 KiB After Width: | Height: | Size: 52 KiB |
|
Before Width: | Height: | Size: 53 KiB After Width: | Height: | Size: 48 KiB |
BIN
docs/_static/img/analysis/risk_analysis_std.png
vendored
|
Before Width: | Height: | Size: 47 KiB After Width: | Height: | Size: 44 KiB |
BIN
docs/_static/img/analysis/score_ic.png
vendored
|
Before Width: | Height: | Size: 102 KiB After Width: | Height: | Size: 93 KiB |
BIN
docs/_static/img/change doc.gif
vendored
Normal file
|
After Width: | Height: | Size: 1.3 MiB |
4
docs/_static/img/framework.svg
vendored
Normal file
|
After Width: | Height: | Size: 98 KiB |
@@ -30,7 +30,7 @@ The simple example of the default strategy is as follows.
|
||||
|
||||
from qlib.contrib.evaluate import backtest
|
||||
# pred_score is the prediction score
|
||||
report, positions = backtest(pred_score, topk=50, n_drop=0.5, verbose=False, limit_threshold=0.0095)
|
||||
report, positions = backtest(pred_score, topk=50, n_drop=0.5, limit_threshold=0.0095)
|
||||
|
||||
To know more about backtesting with a specific ``Strategy``, please refer to `Portfolio Strategy <strategy.html>`_.
|
||||
|
||||
|
||||
120
docs/component/highfreq.rst
Normal file
@@ -0,0 +1,120 @@
|
||||
.. _highfreq:
|
||||
|
||||
============================================
|
||||
Design of hierarchical order execution framework
|
||||
============================================
|
||||
.. currentmodule:: qlib
|
||||
|
||||
Introduction
|
||||
===================
|
||||
|
||||
In order to support reinforcement learning algorithms for high-frequency trading, a corresponding framework is required. None of the publicly available high-frequency trading frameworks now consider multi-layer trading mechanisms, and the currently designed algorithms cannot directly use existing frameworks.
|
||||
In addition to supporting the basic intraday multi-layer trading, the linkage with the day-ahead strategy is also a factor that affects the performance evaluation of the strategy. Different day strategies generate different order distributions and different patterns on different stocks. To verify that high-frequency trading strategies perform well on real trading orders, it is necessary to support day-frequency and high-frequency multi-level linkage trading. In addition to more accurate backtesting of high-frequency trading algorithms, if the distribution of day-frequency orders is considered when training a high-frequency trading model, the algorithm can also be optimized more for product-specific day-frequency orders.
|
||||
Therefore, innovation in the high-frequency trading framework is necessary to solve the various problems mentioned above, for which we designed a hierarchical order execution framework that can link daily-frequency and intra-day trading at different granularities.
|
||||
|
||||
.. image:: ../_static/img/framework.svg
|
||||
|
||||
The design of the framework is shown in the figure above. At each layer consists of Trading Agent and Execution Env. The Trading Agent has its own data processing module (Information Extractor), forecasting module (Forecast Model) and decision generator (Decision Generator). The trading algorithm generates the corresponding decisions by the Decision Generator based on the forecast signals output by the Forecast Module, and the decisions generated by the trading algorithm are passed to the Execution Env, which returns the execution results. Here the frequency of trading algorithm, decision content and execution environment can be customized by users (e.g. intra-day trading, daily-frequency trading, weekly-frequency trading), and the execution environment can be nested with finer-grained trading algorithm and execution environment inside (i.e. sub-workflow in the figure, e.g. daily-frequency orders can be turned into finer-grained decisions by splitting orders within the day). The hierarchical order execution framework is user-defined in terms of hierarchy division and decision frequency, making it easy for users to explore the effects of combining different levels of trading algorithms and breaking down the barriers between different levels of trading algorithm optimization.
|
||||
In addition to the innovation in the framework, the hierarchical order execution framework also takes into account various details of the real backtesting environment, minimizing the differences with the final real environment as much as possible. At the same time, the framework is designed to unify the interface between online and offline (e.g. data pre-processing level supports using the same set of code to process both offline and online data) to reduce the cost of strategy go-live as much as possible.
|
||||
|
||||
Prepare Data
|
||||
===================
|
||||
.. _data:: ../../examples/highfreq/README.md
|
||||
|
||||
|
||||
Example
|
||||
===========================
|
||||
|
||||
Here is an example of highfreq execution.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import qlib
|
||||
# init qlib
|
||||
provider_uri_day = "~/.qlib/qlib_data/cn_data"
|
||||
provider_uri_1min = "~/.qlib/qlib_data/cn_data_1min"
|
||||
provider_uri_map = {"1min": provider_uri_1min, "day": provider_uri_day}
|
||||
qlib.init(provider_uri=provider_uri_day, expression_cache=None, dataset_cache=None)
|
||||
|
||||
# data freq and backtest time
|
||||
freq = "1min"
|
||||
inst_list = D.list_instruments(D.instruments("all"), as_list=True)
|
||||
start_time = "2020-01-01"
|
||||
start_time = "2020-01-31"
|
||||
|
||||
When initializing qlib, if the default data is used, then both daily and minute frequency data need to be passed in.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# random order strategy config
|
||||
strategy_config = {
|
||||
"class": "RandomOrderStrategy",
|
||||
"module_path": "qlib.contrib.strategy.rule_strategy",
|
||||
"kwargs": {
|
||||
"trade_range": TradeRangeByTime("9:30", "15:00"),
|
||||
"sample_ratio": 1.0,
|
||||
"volume_ratio": 0.01,
|
||||
"market": market,
|
||||
},
|
||||
}
|
||||
|
||||
.. code-block:: python
|
||||
# backtest config
|
||||
backtest_config = {
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
"account": 100000000,
|
||||
"benchmark": None,
|
||||
"exchange_kwargs": {
|
||||
"freq": freq,
|
||||
"limit_threshold": 0.095,
|
||||
"deal_price": "close",
|
||||
"open_cost": 0.0005,
|
||||
"close_cost": 0.0015,
|
||||
"min_cost": 5,
|
||||
"codes": market,
|
||||
},
|
||||
"pos_type": "InfPosition", # Position with infinitive position
|
||||
}
|
||||
|
||||
please refer to "../../qlib/backtest".
|
||||
|
||||
.. code-block:: python
|
||||
# excutor config
|
||||
executor_config = {
|
||||
"class": "NestedExecutor",
|
||||
"module_path": "qlib.backtest.executor",
|
||||
"kwargs": {
|
||||
"time_per_step": "day",
|
||||
"inner_executor": {
|
||||
"class": "SimulatorExecutor",
|
||||
"module_path": "qlib.backtest.executor",
|
||||
"kwargs": {
|
||||
"time_per_step": freq,
|
||||
"generate_portfolio_metrics": True,
|
||||
"verbose": False,
|
||||
# "verbose": True,
|
||||
"indicator_config": {
|
||||
"show_indicator": False,
|
||||
},
|
||||
},
|
||||
},
|
||||
"inner_strategy": {
|
||||
"class": "TWAPStrategy",
|
||||
"module_path": "qlib.contrib.strategy.rule_strategy",
|
||||
},
|
||||
"track_data": True,
|
||||
"generate_portfolio_metrics": True,
|
||||
"indicator_config": {
|
||||
"show_indicator": True,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
NestedExecutor represents not the innermost layer, the initialization parameters should contain inner_executor and inner_strategy. simulatorExecutor represents the current excutor is the innermost layer, the innermost strategy used here is the TWAP strategy, the framework currently also supports the VWAP strategy
|
||||
|
||||
.. code-block:: python
|
||||
# backtest
|
||||
portfolio_metrics_dict, indicator_dict = backtest(executor=executor_config, strategy=strategy_config, **backtest_config)
|
||||
|
||||
The metrics of backtest are included in the portfolio_metrics_dict and indicator_dict.
|
||||
@@ -123,7 +123,6 @@ Here is a simple exampke of what is done in ``PortAnaRecord``, which users can r
|
||||
"n_drop": 5,
|
||||
}
|
||||
BACKTEST_CONFIG = {
|
||||
"verbose": False,
|
||||
"limit_threshold": 0.095,
|
||||
"account": 100000000,
|
||||
"benchmark": BENCHMARK,
|
||||
|
||||
@@ -93,7 +93,6 @@ Usage & Example
|
||||
"n_drop": 5,
|
||||
}
|
||||
BACKTEST_CONFIG = {
|
||||
"verbose": False,
|
||||
"limit_threshold": 0.095,
|
||||
"account": 100000000,
|
||||
"benchmark": BENCHMARK,
|
||||
|
||||
@@ -54,7 +54,6 @@ Below is a typical config file of ``qrun``.
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
@@ -242,7 +241,6 @@ The following script is the configuration of `backtest` and the `strategy` used
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
|
||||
@@ -6,12 +6,14 @@ Code Standard
|
||||
|
||||
Docstring
|
||||
=================================
|
||||
Please use the Numpy Style.
|
||||
Please use the `Numpydoc Style <https://stackoverflow.com/a/24385103>`_.
|
||||
|
||||
Continuous Integration
|
||||
=================================
|
||||
Continuous Integration (CI) tools help you stick to the quality standards by running tests every time you push a new commit and reporting the results to a pull request.
|
||||
|
||||
When you submit a PR request, you can check whether your code passes the CI tests in the "check" section at the bottom of the web page.
|
||||
|
||||
A common error is the mixed use of space and tab. You can fix the bug by inputing the following code in the command line.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@@ -93,7 +93,6 @@ We write a simple configuration example as following,
|
||||
fend_time: 2018-12-11
|
||||
backtest:
|
||||
normal_backtest_args:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 500000
|
||||
benchmark: SH000905
|
||||
@@ -306,7 +305,6 @@ About the data and backtest
|
||||
fend_time: 2018-12-11
|
||||
backtest:
|
||||
normal_backtest_args:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 500000
|
||||
benchmark: SH000905
|
||||
|
||||
@@ -15,7 +15,7 @@ With ``Qlib``, users can easily try their ideas to create better Quant investmen
|
||||
Framework
|
||||
===================
|
||||
|
||||
.. image:: ../_static/img/framework.png
|
||||
.. image:: ../_static/img/framework.svg
|
||||
:align: center
|
||||
|
||||
|
||||
|
||||
@@ -77,7 +77,8 @@ Besides `provider_uri` and `region`, `qlib.init` has other parameters. The follo
|
||||
})
|
||||
- `mongo`
|
||||
Type: dict, optional parameter, the setting of `MongoDB <https://www.mongodb.com/>`_ which will be used in some features such as `Task Management <../advanced/task_management.html>`_, with high performance and clustered processing.
|
||||
Users need finished `installation <https://www.mongodb.com/try/download/community>`_ firstly, and run it in a fixed URL.
|
||||
Users need to follow the steps in `installation <https://www.mongodb.com/try/download/community>`_ to install MongoDB firstly and then access it via a URI.
|
||||
Users can access mongodb with credential by setting "task_url" to a string like `"mongodb://%s:%s@%s" % (user, pwd, host + ":" + port)`.
|
||||
|
||||
.. code-block:: Python
|
||||
|
||||
|
||||
@@ -93,8 +93,6 @@ task:
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
|
||||
@@ -83,8 +83,6 @@ task:
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
|
||||
@@ -65,8 +65,6 @@ task:
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
|
||||
@@ -72,8 +72,6 @@ task:
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
|
||||
@@ -90,8 +90,6 @@ task:
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
|
||||
@@ -97,8 +97,6 @@ task:
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
|
||||
@@ -91,8 +91,6 @@ task:
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
|
||||
@@ -83,8 +83,6 @@ task:
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
|
||||
@@ -92,8 +92,6 @@ task:
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
|
||||
@@ -82,8 +82,6 @@ task:
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
|
||||
@@ -92,8 +92,6 @@ task:
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
|
||||
@@ -82,8 +82,6 @@ task:
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
|
||||
18
examples/benchmarks/LightGBM/features_resample_N.py
Normal file
@@ -0,0 +1,18 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from qlib.data.inst_processor import InstProcessor
|
||||
from qlib.utils.resam import resam_calendar
|
||||
|
||||
|
||||
class ResampleNProcessor(InstProcessor):
|
||||
def __init__(self, target_frq: str, **kwargs):
|
||||
self.target_frq = target_frq
|
||||
|
||||
def __call__(self, df: pd.DataFrame, *args, **kwargs):
|
||||
df.index = pd.to_datetime(df.index)
|
||||
res_index = resam_calendar(df.index, "1min", self.target_frq)
|
||||
df = df.resample(self.target_frq).last().reindex(res_index)
|
||||
return df
|
||||
135
examples/benchmarks/LightGBM/multi_freq_handler.py
Normal file
@@ -0,0 +1,135 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from qlib.data.dataset.loader import QlibDataLoader
|
||||
from qlib.contrib.data.handler import DataHandlerLP, _DEFAULT_LEARN_PROCESSORS, check_transform_proc
|
||||
|
||||
|
||||
class Avg15minLoader(QlibDataLoader):
|
||||
def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame:
|
||||
df = super(Avg15minLoader, self).load(instruments, start_time, end_time)
|
||||
if self.is_group:
|
||||
# feature_day(day freq) and feature_15min(1min freq, Average every 15 minutes) renamed feature
|
||||
df.columns = df.columns.map(lambda x: ("feature", x[1]) if x[0].startswith("feature") else x)
|
||||
return df
|
||||
|
||||
|
||||
class Avg15minHandler(DataHandlerLP):
|
||||
def __init__(
|
||||
self,
|
||||
instruments="csi500",
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
freq="day",
|
||||
infer_processors=[],
|
||||
learn_processors=_DEFAULT_LEARN_PROCESSORS,
|
||||
fit_start_time=None,
|
||||
fit_end_time=None,
|
||||
process_type=DataHandlerLP.PTYPE_A,
|
||||
filter_pipe=None,
|
||||
inst_processor=None,
|
||||
**kwargs,
|
||||
):
|
||||
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
|
||||
learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)
|
||||
data_loader = Avg15minLoader(
|
||||
config=self.loader_config(), filter_pipe=filter_pipe, freq=freq, inst_processor=inst_processor
|
||||
)
|
||||
super().__init__(
|
||||
instruments=instruments,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
data_loader=data_loader,
|
||||
infer_processors=infer_processors,
|
||||
learn_processors=learn_processors,
|
||||
process_type=process_type,
|
||||
)
|
||||
|
||||
def loader_config(self):
|
||||
|
||||
# Results for dataset: df: pd.DataFrame
|
||||
# len(df.columns) == 6 + 6 * 16, len(df.index.get_level_values(level="datetime").unique()) == T
|
||||
# df.columns: close0, close1, ..., close16, open0, ..., open16, ..., vwap16
|
||||
# freq == day:
|
||||
# close0, open0, low0, high0, volume0, vwap0
|
||||
# freq == 1min:
|
||||
# close1, ..., close16, ..., vwap1, ..., vwap16
|
||||
# df.index.name == ["datetime", "instrument"]: pd.MultiIndex
|
||||
# Example:
|
||||
# feature ... label
|
||||
# close0 open0 low0 ... vwap1 vwap16 LABEL0
|
||||
# datetime instrument ...
|
||||
# 2020-10-09 SH600000 11.794546 11.819587 11.769505 ... NaN NaN -0.005214
|
||||
# 2020-10-15 SH600000 12.044961 11.944795 11.932274 ... NaN NaN -0.007202
|
||||
# ... ... ... ... ... ... ... ...
|
||||
# 2021-05-28 SZ300676 6.369684 6.495406 6.306568 ... NaN NaN -0.001321
|
||||
# 2021-05-31 SZ300676 6.601626 6.465643 6.465130 ... NaN NaN -0.023428
|
||||
|
||||
# features day: len(columns) == 6, freq = day
|
||||
# $close is the closing price of the current trading day:
|
||||
# if the user needs to get the `close` before the last T days, use Ref($close, T-1), for example:
|
||||
# $close Ref($close, 1) Ref($close, 2) Ref($close, 3) Ref($close, 4)
|
||||
# instrument datetime
|
||||
# SH600519 2021-06-01 244.271530
|
||||
# 2021-06-02 242.205917 244.271530
|
||||
# 2021-06-03 242.229889 242.205917 244.271530
|
||||
# 2021-06-04 245.421524 242.229889 242.205917 244.271530
|
||||
# 2021-06-07 247.547089 245.421524 242.229889 242.205917 244.271530
|
||||
|
||||
# WARNING: Ref($close, N), if N == 0, Ref($close, N) ==> $close
|
||||
|
||||
fields = ["$close", "$open", "$low", "$high", "$volume", "$vwap"]
|
||||
# names: close0, open0, ..., vwap0
|
||||
names = list(map(lambda x: x.strip("$") + "0", fields))
|
||||
|
||||
config = {"feature_day": (fields, names)}
|
||||
|
||||
# features 15min: len(columns) == 6 * 16, freq = 1min
|
||||
# $close is the closing price of the current trading day:
|
||||
# if the user gets 'close' for the i-th 15min of the last T days, use `Ref(Mean($close, 15), (T-1) * 240 + i * 15)`, for example:
|
||||
# Ref(Mean($close, 15), 225) Ref(Mean($close, 15), 465) Ref(Mean($close, 15), 705)
|
||||
# instrument datetime
|
||||
# SH600519 2021-05-31 241.769897 243.077942 244.712997
|
||||
# 2021-06-01 244.271530 241.769897 243.077942
|
||||
# 2021-06-02 242.205917 244.271530 241.769897
|
||||
|
||||
# WARNING: Ref(Mean($close, 15), N), if N == 0, Ref(Mean($close, 15), N) ==> Mean($close, 15)
|
||||
|
||||
# Results of the current script:
|
||||
# time: 09:00 --> 09:14, ..., 14:45 --> 14:59
|
||||
# fields: Ref(Mean($close, 15), 225), ..., Mean($close, 15)
|
||||
# name: close1, ..., close16
|
||||
#
|
||||
|
||||
# Expression description: take close as an example
|
||||
# Mean($close, 15) ==> df["$close"].rolling(15, min_periods=1).mean()
|
||||
# Ref(Mean($close, 15), 15) ==> df["$close"].rolling(15, min_periods=1).mean().shift(15)
|
||||
|
||||
# NOTE: The last data of each trading day, which is the average of the i-th 15 minutes
|
||||
|
||||
# Average:
|
||||
# Average of the i-th 15-minute period of each trading day: 1 <= i <= 250 // 16
|
||||
# Avg(15minutes): Ref(Mean($close, 15), 240 - i * 15)
|
||||
#
|
||||
# Average of the first 15 minutes of each trading day; i = 1
|
||||
# Avg(09:00 --> 09:14), df.index.loc["09:14"]: Ref(Mean($close, 15), 240- 1 * 15) ==> Ref(Mean($close, 15), 225)
|
||||
# Average of the last 15 minutes of each trading day; i = 16
|
||||
# Avg(14:45 --> 14:59), df.index.loc["14:59"]: Ref(Mean($close, 15), 240 - 16 * 15) ==> Ref(Mean($close, 15), 0) ==> Mean($close, 15)
|
||||
|
||||
# 15min resample to day
|
||||
# df.resample("1d").last()
|
||||
tmp_fields = []
|
||||
tmp_names = []
|
||||
for i, _f in enumerate(fields):
|
||||
_fields = [f"Ref(Mean({_f}, 15), {j * 15})" for j in range(1, 240 // 15)]
|
||||
_names = [f"{names[i][:-1]}{int(names[i][-1])+j}" for j in range(240 // 15 - 1, 0, -1)]
|
||||
_fields.append(f"Mean({_f}, 15)")
|
||||
_names.append(f"{names[i][:-1]}{int(names[i][-1])+240 // 15}")
|
||||
tmp_fields += _fields
|
||||
tmp_names += _names
|
||||
config["feature_15min"] = (tmp_fields, tmp_names)
|
||||
# label
|
||||
config["label"] = (["Ref($close, -2)/Ref($close, -1) - 1"], ["LABEL0"])
|
||||
return config
|
||||
@@ -66,8 +66,6 @@ task:
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
|
||||
@@ -73,8 +73,6 @@ task:
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
|
||||
@@ -81,8 +81,6 @@ task:
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
|
||||
@@ -0,0 +1,86 @@
|
||||
qlib_init:
|
||||
provider_uri:
|
||||
day: "~/.qlib/qlib_data/cn_data"
|
||||
1min: "~/.qlib/qlib_data/cn_data_1min"
|
||||
region: cn
|
||||
dataset_cache: null
|
||||
maxtasksperchild: null
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
# 1min closing time is 15:00:00
|
||||
end_time: "2020-08-01 15:00:00"
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
freq:
|
||||
label: day
|
||||
feature_15min: 1min
|
||||
feature_day: day
|
||||
# with label as reference
|
||||
inst_processor:
|
||||
feature_15min:
|
||||
- class: ResampleNProcessor
|
||||
module_path: features_resample_N.py
|
||||
kwargs:
|
||||
target_frq: 1d
|
||||
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: LGBModel
|
||||
module_path: qlib.contrib.model.gbdt
|
||||
kwargs:
|
||||
loss: mse
|
||||
colsample_bytree: 0.8879
|
||||
learning_rate: 0.2
|
||||
subsample: 0.8789
|
||||
lambda_l1: 205.6999
|
||||
lambda_l2: 580.9768
|
||||
max_depth: 8
|
||||
num_leaves: 210
|
||||
num_threads: 20
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Avg15minHandler
|
||||
module_path: multi_freq_handler.py
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -72,8 +72,6 @@ task:
|
||||
kwargs:
|
||||
ana_long_short: True
|
||||
ann_scaler: 252
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
|
||||
@@ -34,15 +34,19 @@ data_handler_config: &data_handler_config
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
start_time: 2017-01-01
|
||||
end_time: 2020-08-01
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
exchange_kwargs:
|
||||
limit_threshold: 0.095
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
@@ -70,7 +74,9 @@ task:
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
|
||||
@@ -26,15 +26,19 @@ data_handler_config: &data_handler_config
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
start_time: 2017-01-01
|
||||
end_time: 2020-08-01
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
exchange_kwargs:
|
||||
limit_threshold: 0.095
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
@@ -61,7 +65,9 @@ task:
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
|
||||
@@ -95,8 +95,6 @@ task:
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
|
||||
@@ -82,8 +82,6 @@ task:
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
|
||||
@@ -25,6 +25,7 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
|
||||
| TCTS (Xueqing Wu, et al.)| Alpha360 | 0.0485±0.00 | 0.3689±0.04| 0.0586±0.00 | 0.4669±0.02 | 0.0816±0.02 | 1.1572±0.30| -0.0689±0.02 |
|
||||
| Transformer (Ashish Vaswani, et al.)| Alpha360 | 0.0141±0.00 | 0.0917±0.02| 0.0331±0.00 | 0.2357±0.03 | -0.0259±0.03 | -0.3323±0.43| -0.1763±0.07 |
|
||||
| Localformer (Juyong Jiang, et al.)| Alpha360 | 0.0408±0.00 | 0.2988±0.03| 0.0538±0.00 | 0.4105±0.02 | 0.0275±0.03 | 0.3464±0.37| -0.1182±0.03 |
|
||||
| TRA (Hengxu Lin, et al.)| Alpha360 | 0.0491±0.01 | 0.3868±0.06 | 0.0589±0.00 | 0.4802±0.04 | 0.0898±0.02 | 1.2490±0.32 | -0.0778±0.02 |
|
||||
|
||||
## Alpha158 dataset
|
||||
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |
|
||||
@@ -43,6 +44,8 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
|
||||
| TabNet (Sercan O. Arik, et al.)| Alpha158 | 0.0383±0.00 | 0.3414±0.00| 0.0388±0.00 | 0.3460±0.00 | 0.0226±0.00 | 0.2652±0.00| -0.1072±0.00 |
|
||||
| Transformer (Ashish Vaswani, et al.)| Alpha158 | 0.0274±0.00 | 0.2166±0.04| 0.0409±0.00 | 0.3342±0.04 | 0.0204±0.03 | 0.2888±0.40| -0.1216±0.04 |
|
||||
| Localformer (Juyong Jiang, et al.)| Alpha158 | 0.0355±0.00 | 0.2747±0.04| 0.0466±0.00 | 0.3762±0.03 | 0.0506±0.02 | 0.7447±0.34| -0.0875±0.02 |
|
||||
| TRA (Hengxu Lin, et al.)| Alpha158 (with selected 20 features)| 0.0409±0.00 | 0.3253±0.04 | 0.0488±0.00 | 0.4045±0.02 | 0.0673±0.02 | 1.0389±0.39 | -0.0830±0.02 |
|
||||
| TRA (Hengxu Lin, et al.)| Alpha158 | 0.0442±0.00 | 0.3426±0.03 | 0.0555±0.00 | 0.4395±0.03 | 0.0833±0.03 | 1.2064±0.36 | -0.0849±0.02 |
|
||||
|
||||
- The selected 20 features are based on the feature importance of a lightgbm-based model.
|
||||
- The base model of DoubleEnsemble is LGBM.
|
||||
|
||||
@@ -85,8 +85,6 @@ task:
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
|
||||
4
examples/benchmarks/TCTS/requirements.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
pandas==1.1.2
|
||||
numpy==1.17.4
|
||||
scikit_learn==0.23.2
|
||||
torch==1.7.0
|
||||
@@ -90,8 +90,6 @@ task:
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
label_col: 1
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
Users can follow the ``workflow_by_code_tft.py`` to run the benchmark.
|
||||
|
||||
### Notes
|
||||
1. Please be **aware** that this script can only support `Python 3.5 - 3.8`.
|
||||
1. Please be **aware** that this script can only support `Python 3.6 - 3.7`.
|
||||
2. If the CUDA version on your machine is not 10.0, please remember to run the following commands `conda install anaconda cudatoolkit=10.0` and `conda install cudnn` on your machine.
|
||||
3. The model must run in GPU, or an error will be raised.
|
||||
4. New datasets should be registered in ``data_formatters``, for detail please visit the source.
|
||||
|
||||
@@ -1,3 +1,2 @@
|
||||
tensorflow-gpu==1.15.0
|
||||
numpy == 1.19.4
|
||||
pandas==1.1.0
|
||||
@@ -1,6 +1,8 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import tensorflow.compat.v1 as tf
|
||||
@@ -243,7 +245,7 @@ class TFTModel(ModelFT):
|
||||
# extract_numerical_data(targets), extract_numerical_data(p90_forecast),
|
||||
# 0.9)
|
||||
tf.keras.backend.set_session(default_keras_session)
|
||||
print("Training completed.".format(dte.datetime.now()))
|
||||
print("Training completed at {}.".format(dte.datetime.now()))
|
||||
# ===========================Training Process===========================
|
||||
|
||||
def predict(self, dataset):
|
||||
@@ -289,3 +291,25 @@ class TFTModel(ModelFT):
|
||||
dataset for finetuning
|
||||
"""
|
||||
pass
|
||||
|
||||
def to_pickle(self, path: Union[Path, str]):
|
||||
"""
|
||||
Tensorflow model can't be dumped directly.
|
||||
So the data should be save seperatedly
|
||||
|
||||
**TODO**: Please implement the function to load the files
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : Union[Path, str]
|
||||
the target path to be dumped
|
||||
"""
|
||||
# FIXME: implementing saving tensorflow models
|
||||
# save tensorflow model
|
||||
# path = Path(path)
|
||||
# path.mkdir(parents=True)
|
||||
# self.model.save(path)
|
||||
|
||||
# save qlib model wrapper
|
||||
self.model = None
|
||||
super(TFTModel, self).to_pickle(path)
|
||||
|
||||
@@ -58,8 +58,6 @@ task:
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
|
||||
@@ -1,29 +1,57 @@
|
||||
# Learning Multiple Stock Trading Patterns with Temporal Routing Adaptor and Optimal Transport
|
||||
|
||||
This code provides a PyTorch implementation for TRA (Temporal Routing Adaptor), as described in the paper [Learning Multiple Stock Trading Patterns with Temporal Routing Adaptor and Optimal Transport](http://arxiv.org/abs/2106.12950).
|
||||
Temporal Routing Adaptor (TRA) is designed to capture multiple trading patterns in the stock market data. Please refer to [our paper](http://arxiv.org/abs/2106.12950) for more details.
|
||||
|
||||
* TRA (Temporal Routing Adaptor) is a lightweight module that consists of a set of independent predictors for learning multiple patterns as well as a router to dispatch samples to different predictors.
|
||||
* We also design a learning algorithm based on Optimal Transport (OT) to obtain the optimal sample to predictor assignment and effectively optimize the router with such assignment through an auxiliary loss term.
|
||||
If you find our work useful in your research, please cite:
|
||||
```
|
||||
@inproceedings{HengxuKDD2021,
|
||||
author = {Hengxu Lin and Dong Zhou and Weiqing Liu and Jiang Bian},
|
||||
title = {Learning Multiple Stock Trading Patterns with Temporal Routing Adaptor and Optimal Transport},
|
||||
booktitle = {Proceedings of the 27th ACM SIGKDD Conference on Knowledge Discovery \& Data Mining},
|
||||
series = {KDD '21},
|
||||
year = {2021},
|
||||
publisher = {ACM},
|
||||
}
|
||||
|
||||
@article{yang2020qlib,
|
||||
title={Qlib: An AI-oriented Quantitative Investment Platform},
|
||||
author={Yang, Xiao and Liu, Weiqing and Zhou, Dong and Bian, Jiang and Liu, Tie-Yan},
|
||||
journal={arXiv preprint arXiv:2009.11189},
|
||||
year={2020}
|
||||
}
|
||||
```
|
||||
|
||||
# Running TRA
|
||||
## Usage (Recommended)
|
||||
|
||||
## Requirements
|
||||
- Install `Qlib` main branch
|
||||
**Update**: `TRA` has been moved to `qlib.contrib.model.pytorch_tra` to support other `Qlib` components like `qlib.workflow` and `Alpha158/Alpha360` dataset.
|
||||
|
||||
## Running
|
||||
Please follow the official [doc](https://qlib.readthedocs.io/en/latest/component/workflow.html) to use `TRA` with `workflow`. Here we also provide several example config files:
|
||||
|
||||
- `workflow_config_tra_Alpha360.yaml`: running `TRA` with `Alpha360` dataset
|
||||
- `workflow_config_tra_Alpha158.yaml`: running `TRA` with `Alpha158` dataset (with feature subsampling)
|
||||
- `workflow_config_tra_Alpha158_full.yaml`: running `TRA` with `Alpha158` dataset (without feature subsampling)
|
||||
|
||||
The performances of `TRA` are reported in [Benchmarks](https://github.com/microsoft/qlib/tree/main/examples/benchmarks).
|
||||
|
||||
## Usage (Not Maintained)
|
||||
|
||||
This section is used to reproduce the results in the paper.
|
||||
|
||||
### Running
|
||||
|
||||
We attach our running scripts for the paper in `run.sh`.
|
||||
|
||||
And here are two ways to run the model:
|
||||
|
||||
* Running from scripts with default parameters
|
||||
|
||||
You can directly run from Qlib command `qrun`:
|
||||
```
|
||||
qrun configs/config_alstm.yaml
|
||||
```
|
||||
|
||||
* Running from code with self-defined parameters
|
||||
|
||||
Setting different parameters is also allowed. See codes in `example.py`:
|
||||
```
|
||||
python example.py --config_file configs/config_alstm.yaml
|
||||
@@ -31,23 +59,20 @@ And here are two ways to run the model:
|
||||
|
||||
Here we trained TRA on a pretrained backbone model. Therefore we run `*_init.yaml` before TRA's scipts.
|
||||
|
||||
# Results
|
||||
|
||||
## Outputs
|
||||
### Results
|
||||
|
||||
After running the scripts, you can find result files in path `./output`:
|
||||
|
||||
`info.json` - config settings and result metrics.
|
||||
* `info.json` - config settings and result metrics.
|
||||
* `log.csv` - running logs.
|
||||
* `model.bin` - the model parameter dictionary.
|
||||
* `pred.pkl` - the prediction scores and output for inference.
|
||||
|
||||
`log.csv` - running logs.
|
||||
Evaluation metrics reported in the paper:
|
||||
This result is generated by qlib==0.7.1.
|
||||
|
||||
`model.bin` - the model parameter dictionary.
|
||||
|
||||
`pred.pkl` - the prediction scores and output for inference.
|
||||
|
||||
## Our Results
|
||||
| Methods | MSE| MAE| IC | ICIR | AR | AV | SR | MDD |
|
||||
|-------------------|-------------------|---------------------|--------------------|--------------------|--------------------|--------------------|--------------------|--------------------|
|
||||
|-------|-------|------|-----|-----|-----|-----|-----|-----|
|
||||
|Linear|0.163|0.327|0.020|0.132|-3.2%|16.8%|-0.191|32.1%|
|
||||
|LightGBM|0.160(0.000)|0.323(0.000)|0.041|0.292|7.8%|15.5%|0.503|25.7%|
|
||||
|MLP|0.160(0.002)|0.323(0.003)|0.037|0.273|3.7%|15.3%|0.264|26.2%|
|
||||
@@ -61,21 +86,8 @@ After running the scripts, you can find result files in path `./output`:
|
||||
|
||||
A more detailed demo for our experiment results in the paper can be found in `Report.ipynb`.
|
||||
|
||||
# Common Issues
|
||||
## Common Issues
|
||||
|
||||
For help or issues using TRA, please submit a GitHub issue.
|
||||
|
||||
Sometimes we might encounter situation where the loss is `NaN`, please check the `epsilon` parameter in the sinkhorn algorithm, adjusting the `epsilon` according to input's scale is important.
|
||||
|
||||
# Citation
|
||||
If you find this repository useful in your research, please cite:
|
||||
```
|
||||
@inproceedings{HengxuKDD2021,
|
||||
author = {Hengxu Lin and Dong Zhou and Weiqing Liu and Jiang Bian},
|
||||
title = {Learning Multiple Stock Trading Patterns with Temporal Routing Adaptor and Optimal Transport},
|
||||
booktitle = {Proceedings of the 27th ACM SIGKDD Conference on Knowledge Discovery \& Data Mining},
|
||||
series = {KDD '21},
|
||||
year = {2021},
|
||||
publisher = {ACM},
|
||||
}
|
||||
```
|
||||
|
||||
5
examples/benchmarks/TRA/requirements.txt
Normal file
@@ -0,0 +1,5 @@
|
||||
pandas==1.1.2
|
||||
numpy==1.17.4
|
||||
scikit_learn==0.23.2
|
||||
torch==1.7.0
|
||||
seaborn
|
||||
132
examples/benchmarks/TRA/workflow_config_tra_Alpha158.yaml
Normal file
@@ -0,0 +1,132 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: FilterCol
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
col_list: ["RESI5", "WVMA5", "RSQR5", "KLEN", "RSQR10", "CORR5", "CORD5", "CORR10",
|
||||
"ROC60", "RESI10", "VSTD5", "RSQR60", "CORR60", "WVMA60", "STD5",
|
||||
"RSQR20", "CORD60", "CORD10", "CORR20", "KLOW"]
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
|
||||
num_states: &num_states 3
|
||||
|
||||
memory_mode: &memory_mode sample
|
||||
|
||||
tra_config: &tra_config
|
||||
num_states: *num_states
|
||||
rnn_arch: LSTM
|
||||
hidden_size: 32
|
||||
num_layers: 1
|
||||
dropout: 0.0
|
||||
tau: 1.0
|
||||
src_info: LR_TPE
|
||||
|
||||
model_config: &model_config
|
||||
input_size: 20
|
||||
hidden_size: 64
|
||||
num_layers: 2
|
||||
rnn_arch: LSTM
|
||||
use_attn: True
|
||||
dropout: 0.0
|
||||
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
|
||||
task:
|
||||
model:
|
||||
class: TRAModel
|
||||
module_path: qlib.contrib.model.pytorch_tra
|
||||
kwargs:
|
||||
tra_config: *tra_config
|
||||
model_config: *model_config
|
||||
model_type: RNN
|
||||
lr: 1e-3
|
||||
n_epochs: 100
|
||||
max_steps_per_epoch:
|
||||
early_stop: 20
|
||||
logdir: output/Alpha158
|
||||
seed: 0
|
||||
lamb: 1.0
|
||||
rho: 0.99
|
||||
alpha: 0.5
|
||||
transport_method: router
|
||||
memory_mode: *memory_mode
|
||||
eval_train: False
|
||||
eval_test: True
|
||||
pretrain: True
|
||||
init_state:
|
||||
freeze_model: False
|
||||
freeze_predictors: False
|
||||
dataset:
|
||||
class: MTSDatasetH
|
||||
module_path: qlib.contrib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha158
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
seq_len: 60
|
||||
horizon: 2
|
||||
input_size:
|
||||
num_states: *num_states
|
||||
batch_size: 1024
|
||||
n_samples:
|
||||
memory_mode: *memory_mode
|
||||
drop_last: True
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
126
examples/benchmarks/TRA/workflow_config_tra_Alpha158_full.yaml
Normal file
@@ -0,0 +1,126 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
|
||||
num_states: &num_states 3
|
||||
|
||||
memory_mode: &memory_mode sample
|
||||
|
||||
tra_config: &tra_config
|
||||
num_states: *num_states
|
||||
rnn_arch: LSTM
|
||||
hidden_size: 32
|
||||
num_layers: 1
|
||||
dropout: 0.0
|
||||
tau: 1.0
|
||||
src_info: LR_TPE
|
||||
|
||||
model_config: &model_config
|
||||
input_size: 158
|
||||
hidden_size: 256
|
||||
num_layers: 2
|
||||
rnn_arch: LSTM
|
||||
use_attn: True
|
||||
dropout: 0.2
|
||||
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
|
||||
task:
|
||||
model:
|
||||
class: TRAModel
|
||||
module_path: qlib.contrib.model.pytorch_tra
|
||||
kwargs:
|
||||
tra_config: *tra_config
|
||||
model_config: *model_config
|
||||
model_type: RNN
|
||||
lr: 1e-3
|
||||
n_epochs: 100
|
||||
max_steps_per_epoch:
|
||||
early_stop: 20
|
||||
logdir: output/Alpha158_full
|
||||
seed: 0
|
||||
lamb: 1.0
|
||||
rho: 0.99
|
||||
alpha: 0.5
|
||||
transport_method: router
|
||||
memory_mode: *memory_mode
|
||||
eval_train: False
|
||||
eval_test: True
|
||||
pretrain: True
|
||||
init_state:
|
||||
freeze_model: False
|
||||
freeze_predictors: False
|
||||
dataset:
|
||||
class: MTSDatasetH
|
||||
module_path: qlib.contrib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha158
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
seq_len: 60
|
||||
horizon: 2
|
||||
input_size:
|
||||
num_states: *num_states
|
||||
batch_size: 1024
|
||||
n_samples:
|
||||
memory_mode: *memory_mode
|
||||
drop_last: True
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
126
examples/benchmarks/TRA/workflow_config_tra_Alpha360.yaml
Normal file
@@ -0,0 +1,126 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
|
||||
num_states: &num_states 3
|
||||
|
||||
memory_mode: &memory_mode sample
|
||||
|
||||
tra_config: &tra_config
|
||||
num_states: *num_states
|
||||
rnn_arch: LSTM
|
||||
hidden_size: 32
|
||||
num_layers: 1
|
||||
dropout: 0.0
|
||||
tau: 1.0
|
||||
src_info: LR_TPE
|
||||
|
||||
model_config: &model_config
|
||||
input_size: 6
|
||||
hidden_size: 64
|
||||
num_layers: 2
|
||||
rnn_arch: LSTM
|
||||
use_attn: True
|
||||
dropout: 0.0
|
||||
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
|
||||
task:
|
||||
model:
|
||||
class: TRAModel
|
||||
module_path: qlib.contrib.model.pytorch_tra
|
||||
kwargs:
|
||||
tra_config: *tra_config
|
||||
model_config: *model_config
|
||||
model_type: RNN
|
||||
lr: 1e-3
|
||||
n_epochs: 100
|
||||
max_steps_per_epoch:
|
||||
early_stop: 20
|
||||
logdir: output/Alpha360
|
||||
seed: 0
|
||||
lamb: 1.0
|
||||
rho: 0.99
|
||||
alpha: 0.5
|
||||
transport_method: router
|
||||
memory_mode: *memory_mode
|
||||
eval_train: False
|
||||
eval_test: True
|
||||
pretrain: True
|
||||
init_state:
|
||||
freeze_model: False
|
||||
freeze_predictors: False
|
||||
dataset:
|
||||
class: MTSDatasetH
|
||||
module_path: qlib.contrib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha360
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
seq_len: 60
|
||||
horizon: 2
|
||||
input_size: 6
|
||||
num_states: *num_states
|
||||
batch_size: 1024
|
||||
n_samples:
|
||||
memory_mode: *memory_mode
|
||||
drop_last: True
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -75,8 +75,6 @@ task:
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
|
||||
@@ -75,8 +75,6 @@ task:
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
|
||||
@@ -34,15 +34,19 @@ data_handler_config: &data_handler_config
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
start_time: 2017-01-01
|
||||
end_time: 2020-08-01
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
exchange_kwargs:
|
||||
limit_threshold: 0.095
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
@@ -70,7 +74,9 @@ task:
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
|
||||
@@ -26,15 +26,19 @@ data_handler_config: &data_handler_config
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
start_time: 2017-01-01
|
||||
end_time: 2020-08-01
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
exchange_kwargs:
|
||||
limit_threshold: 0.095
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
@@ -61,7 +65,9 @@ task:
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
|
||||
@@ -64,8 +64,6 @@ task:
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
|
||||
@@ -71,8 +71,6 @@ task:
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
|
||||
@@ -60,8 +60,6 @@ task:
|
||||
- class: "SignalRecord"
|
||||
module_path: "qlib.workflow.record_temp"
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: "HFSignalRecord"
|
||||
module_path: "qlib.workflow.record_temp"
|
||||
kwargs: {}
|
||||
1
examples/model_rolling/requirements.txt
Normal file
@@ -0,0 +1 @@
|
||||
xgboost
|
||||
@@ -17,7 +17,7 @@ from qlib.workflow.task.gen import RollingGen, task_generator
|
||||
from qlib.workflow.task.manage import TaskManager, run_task
|
||||
from qlib.workflow.task.collect import RecorderCollector
|
||||
from qlib.model.ens.group import RollingGroup
|
||||
from qlib.model.trainer import TrainerRM
|
||||
from qlib.model.trainer import TrainerRM, task_train
|
||||
from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG, CSI100_RECORD_XGBOOST_TASK_CONFIG
|
||||
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ class NestedDecisionExecutionWorkflow:
|
||||
benchmark = "SH000300"
|
||||
data_handler_config = {
|
||||
"start_time": "2008-01-01",
|
||||
"end_time": "2020-12-31",
|
||||
"end_time": "2021-05-31",
|
||||
"fit_start_time": "2008-01-01",
|
||||
"fit_end_time": "2014-12-31",
|
||||
"instruments": market,
|
||||
@@ -53,7 +53,7 @@ class NestedDecisionExecutionWorkflow:
|
||||
"segments": {
|
||||
"train": ("2007-01-01", "2014-12-31"),
|
||||
"valid": ("2015-01-01", "2016-12-31"),
|
||||
"test": ("2020-01-01", "2020-12-31"),
|
||||
"test": ("2020-01-01", "2021-05-31"),
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -75,7 +75,7 @@ class NestedDecisionExecutionWorkflow:
|
||||
"module_path": "qlib.backtest.executor",
|
||||
"kwargs": {
|
||||
"time_per_step": "5min",
|
||||
"generate_report": True,
|
||||
"generate_portfolio_metrics": True,
|
||||
"verbose": True,
|
||||
"indicator_config": {
|
||||
"show_indicator": True,
|
||||
@@ -86,7 +86,7 @@ class NestedDecisionExecutionWorkflow:
|
||||
"class": "TWAPStrategy",
|
||||
"module_path": "qlib.contrib.strategy.rule_strategy",
|
||||
},
|
||||
"generate_report": True,
|
||||
"generate_portfolio_metrics": True,
|
||||
"indicator_config": {
|
||||
"show_indicator": True,
|
||||
},
|
||||
@@ -101,15 +101,15 @@ class NestedDecisionExecutionWorkflow:
|
||||
},
|
||||
},
|
||||
"track_data": True,
|
||||
"generate_report": True,
|
||||
"generate_portfolio_metrics": True,
|
||||
"indicator_config": {
|
||||
"show_indicator": True,
|
||||
},
|
||||
},
|
||||
},
|
||||
"backtest": {
|
||||
"start_time": "2020-01-01",
|
||||
"end_time": "2020-12-31",
|
||||
"start_time": "2020-09-20",
|
||||
"end_time": "2021-05-20",
|
||||
"account": 100000000,
|
||||
"exchange_kwargs": {
|
||||
"freq": "1min",
|
||||
@@ -124,8 +124,6 @@ class NestedDecisionExecutionWorkflow:
|
||||
|
||||
def _init_qlib(self):
|
||||
"""initialize qlib"""
|
||||
# provider_uri_day = "/data/stock_data/huaxia/qlib"
|
||||
# provider_uri_1min = "/data2/stock_data/huaxia_1min_qlib"
|
||||
provider_uri_day = "~/.qlib/qlib_data/cn_data" # target_dir
|
||||
GetData().qlib_data(target_dir=provider_uri_day, region=REG_CN, version="v2", exists_skip=True)
|
||||
provider_uri_1min = HIGH_FREQ_CONFIG.get("provider_uri")
|
||||
@@ -133,31 +131,7 @@ class NestedDecisionExecutionWorkflow:
|
||||
target_dir=provider_uri_1min, interval="1min", region=REG_CN, version="v2", exists_skip=True
|
||||
)
|
||||
provider_uri_map = {"1min": provider_uri_1min, "day": provider_uri_day}
|
||||
client_config = {
|
||||
"calendar_provider": {
|
||||
"class": "LocalCalendarProvider",
|
||||
"module_path": "qlib.data.data",
|
||||
"kwargs": {
|
||||
"backend": {
|
||||
"class": "FileCalendarStorage",
|
||||
"module_path": "qlib.data.storage.file_storage",
|
||||
"kwargs": {"provider_uri_map": provider_uri_map},
|
||||
}
|
||||
},
|
||||
},
|
||||
"feature_provider": {
|
||||
"class": "LocalFeatureProvider",
|
||||
"module_path": "qlib.data.data",
|
||||
"kwargs": {
|
||||
"backend": {
|
||||
"class": "FileFeatureStorage",
|
||||
"module_path": "qlib.data.storage.file_storage",
|
||||
"kwargs": {"provider_uri_map": provider_uri_map},
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
qlib.init(provider_uri=provider_uri_day, **client_config, redis_port=-1)
|
||||
qlib.init(provider_uri=provider_uri_map, dataset_cache=None, expression_cache=None)
|
||||
|
||||
def _train_model(self, model, dataset):
|
||||
with R.start(experiment_name="train"):
|
||||
@@ -186,9 +160,8 @@ class NestedDecisionExecutionWorkflow:
|
||||
},
|
||||
}
|
||||
self.port_analysis_config["strategy"] = strategy_config
|
||||
self.port_analysis_config["backtest"]["benchmark"] = D.list_instruments(
|
||||
instruments=D.instruments(market=self.market), as_list=True
|
||||
)
|
||||
self.port_analysis_config["backtest"]["benchmark"] = self.benchmark
|
||||
|
||||
with R.start(experiment_name="backtest"):
|
||||
|
||||
recorder = R.get_recorder()
|
||||
@@ -201,6 +174,7 @@ class NestedDecisionExecutionWorkflow:
|
||||
)
|
||||
par.generate()
|
||||
|
||||
# user could use following methods to analysis the position
|
||||
# report_normal_df = recorder.load_object("portfolio_analysis/report_normal_1day.pkl")
|
||||
# from qlib.contrib.report import analysis_position
|
||||
# analysis_position.report_graph(report_normal_df)
|
||||
@@ -212,7 +186,7 @@ class NestedDecisionExecutionWorkflow:
|
||||
self._train_model(model, dataset)
|
||||
executor_config = self.port_analysis_config["executor"]
|
||||
backtest_config = self.port_analysis_config["backtest"]
|
||||
backtest_config["benchmark"] = D.list_instruments(instruments=D.instruments(market=self.market), as_list=True)
|
||||
backtest_config["benchmark"] = self.benchmark
|
||||
strategy_config = {
|
||||
"class": "TopkDropoutStrategy",
|
||||
"module_path": "qlib.contrib.strategy.model_strategy",
|
||||
|
||||
@@ -6,6 +6,7 @@ import sys
|
||||
import fire
|
||||
import time
|
||||
import glob
|
||||
import yaml
|
||||
import shutil
|
||||
import signal
|
||||
import inspect
|
||||
@@ -23,22 +24,6 @@ from qlib.config import REG_CN
|
||||
from qlib.workflow import R
|
||||
from qlib.tests.data import GetData
|
||||
|
||||
# init qlib
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data"
|
||||
exp_folder_name = "run_all_model_records"
|
||||
exp_path = str(Path(os.getcwd()).resolve() / exp_folder_name)
|
||||
exp_manager = {
|
||||
"class": "MLflowExpManager",
|
||||
"module_path": "qlib.workflow.expm",
|
||||
"kwargs": {
|
||||
"uri": "file:" + exp_path,
|
||||
"default_exp_name": "Experiment",
|
||||
},
|
||||
}
|
||||
|
||||
GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True)
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN, exp_manager=exp_manager)
|
||||
|
||||
|
||||
# decorator to check the arguments
|
||||
def only_allow_defined_args(function_to_decorate):
|
||||
@@ -88,11 +73,11 @@ def create_env():
|
||||
sys.stderr.write("\n")
|
||||
# get anaconda activate path
|
||||
conda_activate = Path(os.environ["CONDA_PREFIX"]) / "bin" / "activate" # TODO: FIX ME!
|
||||
return env_path, python_path, conda_activate
|
||||
return temp_dir, env_path, python_path, conda_activate
|
||||
|
||||
|
||||
# function to execute the cmd
|
||||
def execute(cmd, wait_when_err=False):
|
||||
def execute(cmd, wait_when_err=False, raise_err=True):
|
||||
print("Running CMD:", cmd)
|
||||
with subprocess.Popen(cmd, stdout=subprocess.PIPE, bufsize=1, universal_newlines=True, shell=True) as p:
|
||||
for line in p.stdout:
|
||||
@@ -105,6 +90,8 @@ def execute(cmd, wait_when_err=False):
|
||||
if p.returncode != 0:
|
||||
if wait_when_err:
|
||||
input("Press Enter to Continue")
|
||||
if raise_err:
|
||||
raise RuntimeError(f"Error when executing command: {cmd}")
|
||||
return p.stderr
|
||||
else:
|
||||
return None
|
||||
@@ -134,14 +121,23 @@ def get_all_folders(models, exclude) -> dict:
|
||||
def get_all_files(folder_path, dataset) -> (str, str):
|
||||
yaml_path = str(Path(f"{folder_path}") / f"*{dataset}*.yaml")
|
||||
req_path = str(Path(f"{folder_path}") / f"*.txt")
|
||||
return glob.glob(yaml_path)[0], glob.glob(req_path)[0]
|
||||
yaml_file = glob.glob(yaml_path)
|
||||
req_file = glob.glob(req_path)
|
||||
if len(yaml_file) == 0:
|
||||
return None, None
|
||||
else:
|
||||
return yaml_file[0], req_file[0]
|
||||
|
||||
|
||||
# function to retrieve all the results
|
||||
def get_all_results(folders) -> dict:
|
||||
results = dict()
|
||||
for fn in folders:
|
||||
try:
|
||||
exp = R.get_exp(experiment_name=fn, create=False)
|
||||
except ValueError:
|
||||
# No experiment results
|
||||
continue
|
||||
recorders = exp.list_recorders()
|
||||
result = dict()
|
||||
result["annualized_return_with_cost"] = list()
|
||||
@@ -155,9 +151,9 @@ def get_all_results(folders) -> dict:
|
||||
if recorders[recorder_id].status == "FINISHED":
|
||||
recorder = R.get_recorder(recorder_id=recorder_id, experiment_name=fn)
|
||||
metrics = recorder.list_metrics()
|
||||
result["annualized_return_with_cost"].append(metrics["excess_return_with_cost.annualized_return"])
|
||||
result["information_ratio_with_cost"].append(metrics["excess_return_with_cost.information_ratio"])
|
||||
result["max_drawdown_with_cost"].append(metrics["excess_return_with_cost.max_drawdown"])
|
||||
result["annualized_return_with_cost"].append(metrics["1day.excess_return_with_cost.annualized_return"])
|
||||
result["information_ratio_with_cost"].append(metrics["1day.excess_return_with_cost.information_ratio"])
|
||||
result["max_drawdown_with_cost"].append(metrics["1day.excess_return_with_cost.max_drawdown"])
|
||||
result["ic"].append(metrics["IC"])
|
||||
result["icir"].append(metrics["ICIR"])
|
||||
result["rank_ic"].append(metrics["Rank IC"])
|
||||
@@ -185,6 +181,25 @@ def gen_and_save_md_table(metrics, dataset):
|
||||
return table
|
||||
|
||||
|
||||
# read yaml, remove seed kwargs of model, and then save file in the temp_dir
|
||||
def gen_yaml_file_without_seed_kwargs(yaml_path, temp_dir):
|
||||
with open(yaml_path, "r") as fp:
|
||||
config = yaml.load(fp)
|
||||
try:
|
||||
del config["task"]["model"]["kwargs"]["seed"]
|
||||
except KeyError:
|
||||
# If the key does not exists, use original yaml
|
||||
# NOTE: it is very important if the model most run in original path(when sys.rel_path is used)
|
||||
return yaml_path
|
||||
else:
|
||||
# otherwise, generating a new yaml without random seed
|
||||
file_name = yaml_path.split("/")[-1]
|
||||
temp_path = os.path.join(temp_dir, file_name)
|
||||
with open(temp_path, "w") as fp:
|
||||
yaml.dump(config, fp)
|
||||
return temp_path
|
||||
|
||||
|
||||
# function to run the all the models
|
||||
@only_allow_defined_args
|
||||
def run(
|
||||
@@ -193,12 +208,13 @@ def run(
|
||||
dataset="Alpha360",
|
||||
exclude=False,
|
||||
qlib_uri: str = "git+https://github.com/microsoft/qlib#egg=pyqlib",
|
||||
exp_folder_name: str = "run_all_model_records",
|
||||
wait_before_rm_env: bool = False,
|
||||
wait_when_err: bool = False,
|
||||
):
|
||||
"""
|
||||
Please be aware that this function can only work under Linux. MacOS and Windows will be supported in the future.
|
||||
Any PR to enhance this method is highly welcomed. Besides, this script doesn't support parrallel running the same model
|
||||
Any PR to enhance this method is highly welcomed. Besides, this script doesn't support parallel running the same model
|
||||
for multiple times, and this will be fixed in the future development.
|
||||
|
||||
Parameters:
|
||||
@@ -214,6 +230,8 @@ def run(
|
||||
qlib_uri : str
|
||||
the uri to install qlib with pip
|
||||
it could be url on the we or local path
|
||||
exp_folder_name: str
|
||||
the name of the experiment folder
|
||||
wait_before_rm_env : bool
|
||||
wait before remove environment.
|
||||
wait_when_err : bool
|
||||
@@ -240,26 +258,58 @@ def run(
|
||||
# Case 5 - run specific models for one time
|
||||
python run_all_model.py --models=[mlp,lightgbm]
|
||||
|
||||
# Case 6 - run other models except those are given as aruments for one time
|
||||
# Case 6 - run other models except those are given as arguments for one time
|
||||
python run_all_model.py --models=[mlp,tft,sfm] --exclude=True
|
||||
|
||||
"""
|
||||
# init qlib
|
||||
GetData().qlib_data(exists_skip=True)
|
||||
qlib.init(
|
||||
exp_manager={
|
||||
"class": "MLflowExpManager",
|
||||
"module_path": "qlib.workflow.expm",
|
||||
"kwargs": {
|
||||
"uri": "file:" + str(Path(os.getcwd()).resolve() / exp_folder_name),
|
||||
"default_exp_name": "Experiment",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# get all folders
|
||||
folders = get_all_folders(models, exclude)
|
||||
# init error messages:
|
||||
errors = dict()
|
||||
# run all the model for iterations
|
||||
for fn in folders:
|
||||
# create env by anaconda
|
||||
env_path, python_path, conda_activate = create_env()
|
||||
# get all files
|
||||
sys.stderr.write("Retrieving files...\n")
|
||||
yaml_path, req_path = get_all_files(folders[fn], dataset)
|
||||
if yaml_path is None:
|
||||
sys.stderr.write(f"There is no {dataset}.yaml file in {folders[fn]}")
|
||||
continue
|
||||
sys.stderr.write("\n")
|
||||
# create env by anaconda
|
||||
temp_dir, env_path, python_path, conda_activate = create_env()
|
||||
|
||||
# install requirements.txt
|
||||
sys.stderr.write("Installing requirements.txt...\n")
|
||||
with open(req_path) as f:
|
||||
content = f.read()
|
||||
if "torch" in content:
|
||||
# automatically install pytorch according to nvidia's version
|
||||
execute(
|
||||
f"{python_path} -m pip install light-the-torch", wait_when_err=wait_when_err
|
||||
) # for automatically installing torch according to the nvidia driver
|
||||
execute(
|
||||
f"{env_path / 'bin' / 'ltt'} install --install-cmd '{python_path} -m pip install {{packages}}' -- -r {req_path}",
|
||||
wait_when_err=wait_when_err,
|
||||
)
|
||||
else:
|
||||
execute(f"{python_path} -m pip install -r {req_path}", wait_when_err=wait_when_err)
|
||||
sys.stderr.write("\n")
|
||||
|
||||
# read yaml, remove seed kwargs of model, and then save file in the temp_dir
|
||||
yaml_path = gen_yaml_file_without_seed_kwargs(yaml_path, temp_dir)
|
||||
# setup gpu for tft
|
||||
if fn == "TFT":
|
||||
execute(
|
||||
@@ -302,6 +352,7 @@ def run(
|
||||
# getting all results
|
||||
sys.stderr.write(f"Retrieving results...\n")
|
||||
results = get_all_results(folders)
|
||||
if len(results) > 0:
|
||||
# calculating the mean and std
|
||||
sys.stderr.write(f"Calculating the mean and std of results...\n")
|
||||
results = cal_mean_std(results)
|
||||
@@ -309,12 +360,12 @@ def run(
|
||||
sys.stderr.write(f"Generating markdown table...\n")
|
||||
gen_and_save_md_table(results, dataset)
|
||||
sys.stderr.write("\n")
|
||||
# print erros
|
||||
# print errors
|
||||
sys.stderr.write(f"Here are some of the errors of the models...\n")
|
||||
pprint(errors)
|
||||
sys.stderr.write("\n")
|
||||
# move results folder
|
||||
shutil.move(exp_path, exp_path + f"_{dataset}_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}")
|
||||
shutil.move(exp_folder_name, exp_folder_name + f"_{dataset}_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}")
|
||||
shutil.move("table.md", f"table_{dataset}_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}.md")
|
||||
|
||||
|
||||
|
||||
@@ -20,9 +20,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import sys, site\n",
|
||||
@@ -201,7 +199,7 @@
|
||||
" \"module_path\": \"qlib.backtest.executor\",\n",
|
||||
" \"kwargs\": {\n",
|
||||
" \"time_per_step\": \"day\",\n",
|
||||
" \"generate_report\": True,\n",
|
||||
" \"generate_portfolio_metrics\": True,\n",
|
||||
" },\n",
|
||||
" },\n",
|
||||
" \"strategy\": {\n",
|
||||
@@ -362,7 +360,7 @@
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
@@ -375,8 +373,7 @@
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.3"
|
||||
"pygments_lexer": "ipython3"
|
||||
},
|
||||
"toc": {
|
||||
"base_numbering": 1,
|
||||
|
||||
@@ -26,7 +26,7 @@ if __name__ == "__main__":
|
||||
"module_path": "qlib.backtest.executor",
|
||||
"kwargs": {
|
||||
"time_per_step": "day",
|
||||
"generate_report": True,
|
||||
"generate_portfolio_metrics": True,
|
||||
},
|
||||
},
|
||||
"strategy": {
|
||||
|
||||
@@ -1,17 +1,15 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
__version__ = "0.7.0.99"
|
||||
_version_path = Path(__file__).absolute().parent / "VERSION.txt" # This file is copyed from setup.py
|
||||
__version__ = _version_path.read_text(encoding="utf-8").strip()
|
||||
__version__bak = __version__ # This version is backup for QlibConfig.reset_qlib_version
|
||||
|
||||
|
||||
import os
|
||||
import yaml
|
||||
import logging
|
||||
import platform
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from .log import get_module_logger
|
||||
|
||||
|
||||
@@ -33,17 +31,21 @@ def init(default_conf="client", **kwargs):
|
||||
H.clear()
|
||||
C.set(default_conf, **kwargs)
|
||||
|
||||
# mount nfs
|
||||
for _freq, provider_uri in C.provider_uri.items():
|
||||
mount_path = C["mount_path"][_freq]
|
||||
# check path if server/local
|
||||
if C.get_uri_type() == C.LOCAL_URI:
|
||||
if not os.path.exists(C["provider_uri"]):
|
||||
uri_type = C.dpm.get_uri_type(provider_uri)
|
||||
if uri_type == C.LOCAL_URI:
|
||||
if not Path(provider_uri).exists():
|
||||
if C["auto_mount"]:
|
||||
logger.error(
|
||||
f"Invalid provider uri: {C['provider_uri']}, please check if a valid provider uri has been set. This path does not exist."
|
||||
f"Invalid provider uri: {provider_uri}, please check if a valid provider uri has been set. This path does not exist."
|
||||
)
|
||||
else:
|
||||
logger.warning(f"auto_path is False, please make sure {C['mount_path']} is mounted")
|
||||
elif C.get_uri_type() == C.NFS_URI:
|
||||
_mount_nfs_uri(C)
|
||||
logger.warning(f"auto_path is False, please make sure {mount_path} is mounted")
|
||||
elif uri_type == C.NFS_URI:
|
||||
_mount_nfs_uri(provider_uri, mount_path, C["auto_mount"])
|
||||
else:
|
||||
raise NotImplementedError(f"This type of URI is not supported")
|
||||
|
||||
@@ -52,50 +54,48 @@ def init(default_conf="client", **kwargs):
|
||||
if "flask_server" in C:
|
||||
logger.info(f"flask_server={C['flask_server']}, flask_port={C['flask_port']}")
|
||||
logger.info("qlib successfully initialized based on %s settings." % default_conf)
|
||||
logger.info(f"data_path={C.get_data_path()}")
|
||||
data_path = {_freq: C.dpm.get_data_uri(_freq) for _freq in C.dpm.provider_uri.keys()}
|
||||
logger.info(f"data_path={data_path}")
|
||||
|
||||
|
||||
def _mount_nfs_uri(C):
|
||||
def _mount_nfs_uri(provider_uri, mount_path, auto_mount: bool = False):
|
||||
|
||||
LOG = get_module_logger("mount nfs", level=logging.INFO)
|
||||
|
||||
if mount_path is None:
|
||||
raise ValueError(f"Invalid mount path: {mount_path}!")
|
||||
# FIXME: the C["provider_uri"] is modified in this function
|
||||
# If it is not modified, we can pass only provider_uri or mount_path instead of C
|
||||
mount_command = "sudo mount.nfs %s %s" % (C["provider_uri"], C["mount_path"])
|
||||
mount_command = "sudo mount.nfs %s %s" % (provider_uri, mount_path)
|
||||
# If the provider uri looks like this 172.23.233.89//data/csdesign'
|
||||
# It will be a nfs path. The client provider will be used
|
||||
if not C["auto_mount"]:
|
||||
if not os.path.exists(C["mount_path"]):
|
||||
if not auto_mount:
|
||||
if not Path(mount_path).exists():
|
||||
raise FileNotFoundError(
|
||||
f"Invalid mount path: {C['mount_path']}! Please mount manually: {mount_command} or Set init parameter `auto_mount=True`"
|
||||
f"Invalid mount path: {mount_path}! Please mount manually: {mount_command} or Set init parameter `auto_mount=True`"
|
||||
)
|
||||
else:
|
||||
# Judging system type
|
||||
sys_type = platform.system()
|
||||
if "win" in sys_type.lower():
|
||||
# system: window
|
||||
exec_result = os.popen("mount -o anon %s %s" % (C["provider_uri"], C["mount_path"] + ":"))
|
||||
exec_result = os.popen("mount -o anon %s %s" % (provider_uri, mount_path + ":"))
|
||||
result = exec_result.read()
|
||||
if "85" in result:
|
||||
LOG.warning("already mounted or window mount path already exists")
|
||||
LOG.warning(f"{provider_uri} on Windows:{mount_path} is already mounted")
|
||||
elif "53" in result:
|
||||
raise OSError("not find network path")
|
||||
elif "error" in result or "错误" in result:
|
||||
raise OSError("Invalid mount path")
|
||||
elif C["provider_uri"] in result:
|
||||
elif provider_uri in result:
|
||||
LOG.info("window success mount..")
|
||||
else:
|
||||
raise OSError(f"unknown error: {result}")
|
||||
|
||||
# config mount path
|
||||
C["mount_path"] = C["mount_path"] + ":\\"
|
||||
else:
|
||||
# system: linux/Unix/Mac
|
||||
# check mount
|
||||
_remote_uri = C["provider_uri"]
|
||||
_remote_uri = _remote_uri[:-1] if _remote_uri.endswith("/") else _remote_uri
|
||||
_mount_path = C["mount_path"]
|
||||
_mount_path = _mount_path[:-1] if _mount_path.endswith("/") else _mount_path
|
||||
_remote_uri = provider_uri[:-1] if provider_uri.endswith("/") else provider_uri
|
||||
_mount_path = mount_path[:-1] if mount_path.endswith("/") else mount_path
|
||||
_check_level_num = 2
|
||||
_is_mount = False
|
||||
while _check_level_num:
|
||||
@@ -121,11 +121,9 @@ def _mount_nfs_uri(C):
|
||||
|
||||
if not _is_mount:
|
||||
try:
|
||||
os.makedirs(C["mount_path"], exist_ok=True)
|
||||
Path(mount_path).mkdir(parents=True, exist_ok=True)
|
||||
except Exception:
|
||||
raise OSError(
|
||||
f"Failed to create directory {C['mount_path']}, please create {C['mount_path']} manually!"
|
||||
)
|
||||
raise OSError(f"Failed to create directory {mount_path}, please create {mount_path} manually!")
|
||||
|
||||
# check nfs-common
|
||||
command_res = os.popen("dpkg -l | grep nfs-common")
|
||||
@@ -136,11 +134,11 @@ def _mount_nfs_uri(C):
|
||||
command_status = os.system(mount_command)
|
||||
if command_status == 256:
|
||||
raise OSError(
|
||||
f"mount {C['provider_uri']} on {C['mount_path']} error! Needs SUDO! Please mount manually: {mount_command}"
|
||||
f"mount {provider_uri} on {mount_path} error! Needs SUDO! Please mount manually: {mount_command}"
|
||||
)
|
||||
elif command_status == 32512:
|
||||
# LOG.error("Command error")
|
||||
raise OSError(f"mount {C['provider_uri']} on {C['mount_path']} error! Command error")
|
||||
raise OSError(f"mount {provider_uri} on {mount_path} error! Command error")
|
||||
elif command_status == 0:
|
||||
LOG.info("Mount finished")
|
||||
else:
|
||||
|
||||
@@ -9,13 +9,13 @@ from .account import Account
|
||||
if TYPE_CHECKING:
|
||||
from ..strategy.base import BaseStrategy
|
||||
from .executor import BaseExecutor
|
||||
from .order import BaseTradeDecision
|
||||
from .order import Order
|
||||
from .decision import BaseTradeDecision
|
||||
from .position import Position
|
||||
from .exchange import Exchange
|
||||
from .backtest import backtest_loop
|
||||
from .backtest import collect_data_loop
|
||||
from .utils import CommonInfrastructure, LevelInfrastructure, TradeCalendarManager
|
||||
from .utils import CommonInfrastructure
|
||||
from .decision import Order
|
||||
from ..utils import init_instance_by_config
|
||||
from ..log import get_module_logger
|
||||
from ..config import C
|
||||
@@ -231,10 +231,9 @@ def backtest(
|
||||
|
||||
Returns
|
||||
-------
|
||||
report: Report
|
||||
it records the trading report information
|
||||
It is organized in a dict format
|
||||
indicator: Indicator
|
||||
portfolio_metrics_dict: Dict[PortfolioMetrics]
|
||||
it records the trading portfolio_metrics information
|
||||
indicator_dict: Dict[Indicator]
|
||||
it computes the trading indicator
|
||||
It is organized in a dict format
|
||||
|
||||
@@ -249,9 +248,8 @@ def backtest(
|
||||
exchange_kwargs,
|
||||
pos_type=pos_type,
|
||||
)
|
||||
report, indicator = backtest_loop(start_time, end_time, trade_strategy, trade_executor)
|
||||
|
||||
return report, indicator
|
||||
portfolio_metrics, indicator = backtest_loop(start_time, end_time, trade_strategy, trade_executor)
|
||||
return portfolio_metrics, indicator
|
||||
|
||||
|
||||
def collect_data(
|
||||
|
||||
@@ -4,22 +4,19 @@ from __future__ import annotations
|
||||
import copy
|
||||
from typing import Dict, List, Tuple, TYPE_CHECKING
|
||||
from qlib.utils import init_instance_by_config
|
||||
import warnings
|
||||
import pandas as pd
|
||||
|
||||
from .position import BasePosition, InfPosition, Position
|
||||
from .report import Report, Indicator
|
||||
from .order import BaseTradeDecision, Order
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .exchange import Exchange
|
||||
from .report import PortfolioMetrics, Indicator
|
||||
from .decision import BaseTradeDecision, Order
|
||||
from .exchange import Exchange
|
||||
|
||||
"""
|
||||
rtn & earning in the Account
|
||||
rtn:
|
||||
from order's view
|
||||
1.change if any order is executed, sell order or buy order
|
||||
2.change at the end of today, (today_clse - stock_price) * amount
|
||||
2.change at the end of today, (today_close - stock_price) * amount
|
||||
earning
|
||||
from value of current position
|
||||
earning will be updated at the end of trade date
|
||||
@@ -32,7 +29,7 @@ rtn & earning in the Account
|
||||
|
||||
|
||||
class AccumulatedInfo:
|
||||
"""accumulated trading info, including accumulated return\cost\turnover"""
|
||||
"""accumulated trading info, including accumulated return/cost/turnover"""
|
||||
|
||||
def __init__(self):
|
||||
self.reset()
|
||||
@@ -94,9 +91,12 @@ class Account:
|
||||
|
||||
self._pos_type = pos_type
|
||||
self._port_metr_enabled = port_metr_enabled
|
||||
self.benchmark_config = None # avoid no attribute error
|
||||
self.init_vars(init_cash, position_dict, freq, benchmark_config)
|
||||
|
||||
def init_vars(self, init_cash, position_dict, freq: str, benchmark_config: dict):
|
||||
self.init_cash = init_cash
|
||||
self.current: BasePosition = init_instance_by_config(
|
||||
self.current_position: BasePosition = init_instance_by_config(
|
||||
{
|
||||
"class": self._pos_type,
|
||||
"kwargs": {
|
||||
@@ -106,37 +106,33 @@ class Account:
|
||||
"module_path": "qlib.backtest.position",
|
||||
}
|
||||
)
|
||||
self.report = None
|
||||
self.positions = {}
|
||||
|
||||
# in of reset ignore None values
|
||||
self.benchmark_config = benchmark_config
|
||||
self.freq = freq
|
||||
|
||||
self.reset(freq=freq, benchmark_config=benchmark_config, init_report=True)
|
||||
self.portfolio_metrics = None
|
||||
self.hist_positions = {}
|
||||
self.reset(freq=freq, benchmark_config=benchmark_config)
|
||||
|
||||
def is_port_metr_enabled(self):
|
||||
"""
|
||||
Is portfolio-based metrics enabled.
|
||||
"""
|
||||
return self._port_metr_enabled and not self.current.skip_update()
|
||||
return self._port_metr_enabled and not self.current_position.skip_update()
|
||||
|
||||
def reset_report(self, freq, benchmark_config):
|
||||
# portfolio related metrics
|
||||
if self.is_port_metr_enabled():
|
||||
self.accum_info = AccumulatedInfo()
|
||||
self.report = Report(freq, benchmark_config)
|
||||
self.positions = {}
|
||||
self.portfolio_metrics = PortfolioMetrics(freq, benchmark_config)
|
||||
self.hist_positions = {}
|
||||
|
||||
# fill stock value
|
||||
# The frequency of account may not align with the trading frequency.
|
||||
# This may result in obscure bugs when data quality is low.
|
||||
if isinstance(self.benchmark_config, dict) and self.benchmark_config.get("start_time") is not None:
|
||||
self.current.fill_stock_value(self.benchmark_config["start_time"], self.freq)
|
||||
self.current_position.fill_stock_value(self.benchmark_config["start_time"], self.freq)
|
||||
|
||||
# trading related metrics(e.g. high-frequency trading)
|
||||
self.indicator = Indicator()
|
||||
|
||||
def reset(self, freq=None, benchmark_config=None, init_report=False, port_metr_enabled: bool = None):
|
||||
def reset(self, freq=None, benchmark_config=None, port_metr_enabled: bool = None):
|
||||
"""reset freq and report of account
|
||||
|
||||
Parameters
|
||||
@@ -145,27 +141,23 @@ class Account:
|
||||
frequency of account & report, by default None
|
||||
benchmark_config : {}, optional
|
||||
benchmark config of report, by default None
|
||||
init_report : bool, optional
|
||||
whether to initialize the report, by default False
|
||||
"""
|
||||
if freq is not None:
|
||||
self.freq = freq
|
||||
if benchmark_config is not None:
|
||||
self.benchmark_config = benchmark_config
|
||||
|
||||
if port_metr_enabled is not None:
|
||||
self._port_metr_enabled = port_metr_enabled
|
||||
|
||||
if freq is not None or benchmark_config is not None or init_report:
|
||||
self.reset_report(self.freq, self.benchmark_config)
|
||||
|
||||
def get_positions(self):
|
||||
return self.positions
|
||||
def get_hist_positions(self):
|
||||
return self.hist_positions
|
||||
|
||||
def get_cash(self):
|
||||
return self.current.get_cash()
|
||||
return self.current_position.get_cash()
|
||||
|
||||
def _update_accum_info_from_order(self, order, trade_val, cost, trade_price):
|
||||
def _update_state_from_order(self, order, trade_val, cost, trade_price):
|
||||
if self.is_port_metr_enabled():
|
||||
# update turnover
|
||||
self.accum_info.add_turnover(trade_val)
|
||||
@@ -176,17 +168,17 @@ class Account:
|
||||
trade_amount = trade_val / trade_price
|
||||
if order.direction == Order.SELL: # 0 for sell
|
||||
# when sell stock, get profit from price change
|
||||
profit = trade_val - self.current.get_stock_price(order.stock_id) * trade_amount
|
||||
profit = trade_val - self.current_position.get_stock_price(order.stock_id) * trade_amount
|
||||
self.accum_info.add_return_value(profit) # note here do not consider cost
|
||||
|
||||
elif order.direction == Order.BUY: # 1 for buy
|
||||
# when buy stock, we get return for the rtn computing method
|
||||
# profit in buy order is to make rtn is consistent with earning at the end of bar
|
||||
profit = self.current.get_stock_price(order.stock_id) * trade_amount - trade_val
|
||||
profit = self.current_position.get_stock_price(order.stock_id) * trade_amount - trade_val
|
||||
self.accum_info.add_return_value(profit) # note here do not consider cost
|
||||
|
||||
def update_order(self, order, trade_val, cost, trade_price):
|
||||
if self.current.skip_update():
|
||||
if self.current_position.skip_update():
|
||||
# TODO: supporting polymorphism for account
|
||||
# updating order for infinite position is meaningless
|
||||
return
|
||||
@@ -196,65 +188,61 @@ class Account:
|
||||
# The cost will be substracted from the cash at last. So the trading logic can ignore the cost calculation
|
||||
if order.direction == Order.SELL:
|
||||
# sell stock
|
||||
self._update_accum_info_from_order(order, trade_val, cost, trade_price)
|
||||
self._update_state_from_order(order, trade_val, cost, trade_price)
|
||||
# update current position
|
||||
# for may sell all of stock_id
|
||||
self.current.update_order(order, trade_val, cost, trade_price)
|
||||
self.current_position.update_order(order, trade_val, cost, trade_price)
|
||||
else:
|
||||
# buy stock
|
||||
# deal order, then update state
|
||||
self.current.update_order(order, trade_val, cost, trade_price)
|
||||
self._update_accum_info_from_order(order, trade_val, cost, trade_price)
|
||||
self.current_position.update_order(order, trade_val, cost, trade_price)
|
||||
self._update_state_from_order(order, trade_val, cost, trade_price)
|
||||
|
||||
def update_bar_count(self):
|
||||
"""at the end of the trading bar, update holding bar, count of stock"""
|
||||
# update holding day count
|
||||
# NOTE: updating bar_count does not only serve portfolio metrics, it also serve the strategy
|
||||
if not self.current.skip_update():
|
||||
self.current.add_count_all(bar=self.freq)
|
||||
|
||||
def update_current(self, trade_start_time, trade_end_time, trade_exchange):
|
||||
"""update current to make rtn consistent with earning at the end of bar"""
|
||||
def update_current_position(self, trade_start_time, trade_end_time, trade_exchange):
|
||||
"""update current to make rtn consistent with earning at the end of bar, and update holding bar count of stock"""
|
||||
# update price for stock in the position and the profit from changed_price
|
||||
# NOTE: updating position does not only serve portfolio metrics, it also serve the strategy
|
||||
if not self.current.skip_update():
|
||||
stock_list = self.current.get_stock_list()
|
||||
if not self.current_position.skip_update():
|
||||
stock_list = self.current_position.get_stock_list()
|
||||
for code in stock_list:
|
||||
# if suspend, no new price to be updated, profit is 0
|
||||
if trade_exchange.check_stock_suspended(code, trade_start_time, trade_end_time):
|
||||
continue
|
||||
bar_close = trade_exchange.get_close(code, trade_start_time, trade_end_time)
|
||||
self.current.update_stock_price(stock_id=code, price=bar_close)
|
||||
self.current_position.update_stock_price(stock_id=code, price=bar_close)
|
||||
# update holding day count
|
||||
# NOTE: updating bar_count does not only serve portfolio metrics, it also serve the strategy
|
||||
self.current_position.add_count_all(bar=self.freq)
|
||||
|
||||
def update_report(self, trade_start_time, trade_end_time):
|
||||
"""update position history, report"""
|
||||
def update_portfolio_metrics(self, trade_start_time, trade_end_time):
|
||||
"""update portfolio_metrics"""
|
||||
# calculate earning
|
||||
# account_value - last_account_value
|
||||
# for the first trade date, account_value - init_cash
|
||||
# self.report.is_empty() to judge is_first_trade_date
|
||||
# self.portfolio_metrics.is_empty() to judge is_first_trade_date
|
||||
# get last_account_value, last_total_cost, last_total_turnover
|
||||
if self.report.is_empty():
|
||||
if self.portfolio_metrics.is_empty():
|
||||
last_account_value = self.init_cash
|
||||
last_total_cost = 0
|
||||
last_total_turnover = 0
|
||||
else:
|
||||
last_account_value = self.report.get_latest_account_value()
|
||||
last_total_cost = self.report.get_latest_total_cost()
|
||||
last_total_turnover = self.report.get_latest_total_turnover()
|
||||
last_account_value = self.portfolio_metrics.get_latest_account_value()
|
||||
last_total_cost = self.portfolio_metrics.get_latest_total_cost()
|
||||
last_total_turnover = self.portfolio_metrics.get_latest_total_turnover()
|
||||
# get now_account_value, now_stock_value, now_earning, now_cost, now_turnover
|
||||
now_account_value = self.current.calculate_value()
|
||||
now_stock_value = self.current.calculate_stock_value()
|
||||
now_account_value = self.current_position.calculate_value()
|
||||
now_stock_value = self.current_position.calculate_stock_value()
|
||||
now_earning = now_account_value - last_account_value
|
||||
now_cost = self.accum_info.get_cost - last_total_cost
|
||||
now_turnover = self.accum_info.get_turnover - last_total_turnover
|
||||
# update report for today
|
||||
# update portfolio_metrics for today
|
||||
# judge whether the the trading is begin.
|
||||
# and don't add init account state into report, due to we don't have excess return in those days.
|
||||
self.report.update_report_record(
|
||||
# and don't add init account state into portfolio_metrics, due to we don't have excess return in those days.
|
||||
self.portfolio_metrics.update_portfolio_metrics_record(
|
||||
trade_start_time=trade_start_time,
|
||||
trade_end_time=trade_end_time,
|
||||
account_value=now_account_value,
|
||||
cash=self.current.position["cash"],
|
||||
cash=self.current_position.position["cash"],
|
||||
return_rate=(now_earning + now_cost) / last_account_value,
|
||||
# here use earning to calculate return, position's view, earning consider cost, true return
|
||||
# in order to make same definition with original backtest in evaluate.py
|
||||
@@ -264,12 +252,51 @@ class Account:
|
||||
cost_rate=now_cost / last_account_value,
|
||||
stock_value=now_stock_value,
|
||||
)
|
||||
|
||||
def update_hist_positions(self, trade_start_time):
|
||||
"""update history position"""
|
||||
now_account_value = self.current_position.calculate_value()
|
||||
# set now_account_value to position
|
||||
self.current.position["now_account_value"] = now_account_value
|
||||
self.current.update_weight_all()
|
||||
# update positions
|
||||
self.current_position.position["now_account_value"] = now_account_value
|
||||
self.current_position.update_weight_all()
|
||||
# update hist_positions
|
||||
# note use deepcopy
|
||||
self.positions[trade_start_time] = copy.deepcopy(self.current)
|
||||
self.hist_positions[trade_start_time] = copy.deepcopy(self.current_position)
|
||||
|
||||
def update_indicator(
|
||||
self,
|
||||
trade_start_time: pd.Timestamp,
|
||||
trade_exchange: Exchange,
|
||||
atomic: bool,
|
||||
outer_trade_decision: BaseTradeDecision,
|
||||
trade_info: list = None,
|
||||
inner_order_indicators: List[Dict[str, pd.Series]] = None,
|
||||
decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]] = None,
|
||||
indicator_config: dict = {},
|
||||
):
|
||||
"""update trade indicators and order indicators in each bar end"""
|
||||
# TODO: will skip empty decisions make it faster? `outer_trade_decision.empty():`
|
||||
|
||||
# indicator is trading (e.g. high-frequency order execution) related analysis
|
||||
self.indicator.reset()
|
||||
|
||||
# aggregate the information for each order
|
||||
if atomic:
|
||||
self.indicator.update_order_indicators(trade_info)
|
||||
else:
|
||||
self.indicator.agg_order_indicators(
|
||||
inner_order_indicators,
|
||||
decision_list=decision_list,
|
||||
outer_trade_decision=outer_trade_decision,
|
||||
trade_exchange=trade_exchange,
|
||||
indicator_config=indicator_config,
|
||||
)
|
||||
|
||||
# aggregate all the order metrics a single step
|
||||
self.indicator.cal_trade_indicators(trade_start_time, self.freq, indicator_config)
|
||||
|
||||
# record the metrics
|
||||
self.indicator.record(trade_start_time)
|
||||
|
||||
def update_bar_end(
|
||||
self,
|
||||
@@ -316,44 +343,34 @@ class Account:
|
||||
elif atomic is False and inner_order_indicators is None:
|
||||
raise ValueError("inner_order_indicators is necessary in un-atomic executor")
|
||||
|
||||
# TODO: `update_bar_count` and `update_current` should placed in Position and be merged.
|
||||
self.update_bar_count()
|
||||
self.update_current(trade_start_time, trade_end_time, trade_exchange)
|
||||
# update current position and hold bar count in each bar end
|
||||
self.update_current_position(trade_start_time, trade_end_time, trade_exchange)
|
||||
|
||||
if self.is_port_metr_enabled():
|
||||
# report is portfolio related analysis
|
||||
self.update_report(trade_start_time, trade_end_time)
|
||||
# portfolio_metrics is portfolio related analysis
|
||||
self.update_portfolio_metrics(trade_start_time, trade_end_time)
|
||||
self.update_hist_positions(trade_start_time)
|
||||
|
||||
# TODO: will skip empty decisions make it faster? `outer_trade_decision.empty():`
|
||||
|
||||
# indicator is trading (e.g. high-frequency order execution) related analysis
|
||||
self.indicator.reset()
|
||||
|
||||
# aggregate the information for each order
|
||||
if atomic:
|
||||
self.indicator.update_order_indicators(trade_info)
|
||||
else:
|
||||
self.indicator.agg_order_indicators(
|
||||
inner_order_indicators,
|
||||
decision_list=decision_list,
|
||||
outer_trade_decision=outer_trade_decision,
|
||||
# update indicator in each bar end
|
||||
self.update_indicator(
|
||||
trade_start_time=trade_start_time,
|
||||
trade_exchange=trade_exchange,
|
||||
atomic=atomic,
|
||||
outer_trade_decision=outer_trade_decision,
|
||||
trade_info=trade_info,
|
||||
inner_order_indicators=inner_order_indicators,
|
||||
decision_list=decision_list,
|
||||
indicator_config=indicator_config,
|
||||
)
|
||||
|
||||
# aggregate all the order metrics a single step
|
||||
self.indicator.cal_trade_indicators(trade_start_time, self.freq, indicator_config)
|
||||
|
||||
# record the metrics
|
||||
self.indicator.record(trade_start_time)
|
||||
|
||||
def get_report(self):
|
||||
"""get the history report and postions instance"""
|
||||
def get_portfolio_metrics(self):
|
||||
"""get the history portfolio_metrics and postions instance"""
|
||||
if self.is_port_metr_enabled():
|
||||
_report = self.report.generate_report_dataframe()
|
||||
_positions = self.get_positions()
|
||||
return _report, _positions
|
||||
_portfolio_metrics = self.portfolio_metrics.generate_portfolio_metrics_dataframe()
|
||||
_positions = self.get_hist_positions()
|
||||
return _portfolio_metrics, _positions
|
||||
else:
|
||||
raise ValueError("generate_report should be True if you want to generate report")
|
||||
raise ValueError("generate_portfolio_metrics should be True if you want to generate portfolio_metrics")
|
||||
|
||||
def get_trade_indicator(self) -> Indicator:
|
||||
"""get the trade indicator instance, which has pa/pos/ffr info."""
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from __future__ import annotations
|
||||
from qlib.backtest.order import BaseTradeDecision
|
||||
from qlib.backtest.decision import BaseTradeDecision
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -19,15 +19,15 @@ def backtest_loop(start_time, end_time, trade_strategy: BaseStrategy, trade_exec
|
||||
|
||||
Returns
|
||||
-------
|
||||
report: Report
|
||||
it records the trading report information
|
||||
portfolio_metrics: PortfolioMetrics
|
||||
it records the trading portfolio_metrics information
|
||||
indicator: Indicator
|
||||
it computes the trading indicator
|
||||
"""
|
||||
return_value = {}
|
||||
for _decision in collect_data_loop(start_time, end_time, trade_strategy, trade_executor, return_value):
|
||||
pass
|
||||
return return_value.get("report"), return_value.get("indicator")
|
||||
return return_value.get("portfolio_metrics"), return_value.get("indicator")
|
||||
|
||||
|
||||
def collect_data_loop(
|
||||
@@ -68,9 +68,8 @@ def collect_data_loop(
|
||||
|
||||
if return_value is not None:
|
||||
all_executors = trade_executor.get_all_executors()
|
||||
|
||||
all_reports = {
|
||||
"{}{}".format(*Freq.parse(_executor.time_per_step)): _executor.trade_account.get_report()
|
||||
all_portfolio_metrics = {
|
||||
"{}{}".format(*Freq.parse(_executor.time_per_step)): _executor.trade_account.get_portfolio_metrics()
|
||||
for _executor in all_executors
|
||||
if _executor.trade_account.is_port_metr_enabled()
|
||||
}
|
||||
@@ -79,4 +78,4 @@ def collect_data_loop(
|
||||
key = "{}{}".format(*Freq.parse(_executor.time_per_step))
|
||||
all_indicators[key] = _executor.trade_account.get_trade_indicator().generate_trade_indicators_dataframe()
|
||||
all_indicators[key + "_obj"] = _executor.trade_account.get_trade_indicator()
|
||||
return_value.update({"report": all_reports, "indicator": all_indicators})
|
||||
return_value.update({"portfolio_metrics": all_portfolio_metrics, "indicator": all_indicators})
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
# TODO: rename it with decision.py
|
||||
|
||||
from __future__ import annotations
|
||||
from enum import IntEnum
|
||||
from qlib.data.data import Cal
|
||||
@@ -15,10 +15,9 @@ import pandas as pd
|
||||
|
||||
from ..data.data import D
|
||||
from ..config import C, REG_CN
|
||||
from ..utils.resam import resam_ts_data, ts_data_last
|
||||
from ..log import get_module_logger
|
||||
from .order import Order, OrderDir, OrderHelper
|
||||
from .high_performance_ds import BaseQuote, PandasQuote, CN1minNumpyQuote
|
||||
from .decision import Order, OrderDir, OrderHelper
|
||||
from .high_performance_ds import BaseQuote, PandasQuote, NumpyQuote
|
||||
|
||||
|
||||
class Exchange:
|
||||
@@ -36,29 +35,24 @@ class Exchange:
|
||||
close_cost=0.0025,
|
||||
min_cost=5,
|
||||
extra_quote=None,
|
||||
quote_cls=CN1minNumpyQuote,
|
||||
quote_cls=NumpyQuote,
|
||||
**kwargs,
|
||||
):
|
||||
"""__init__
|
||||
|
||||
:param freq: frequency of data
|
||||
:param start_time: closed start time for backtest
|
||||
:param end_time: closed end time for backtest
|
||||
:param codes: list stock_id list or a string of instruments(i.e. all, csi500, sse50)
|
||||
|
||||
:param deal_price: Union[str, Tuple[str, str], List[str]]
|
||||
The `deal_price` supports following two types of input
|
||||
- <deal_price> : str
|
||||
- (<buy_price>, <sell_price>): Tuple[str] or List[str]
|
||||
|
||||
<deal_price>, <buy_price> or <sell_price> := <price>
|
||||
<price> := str
|
||||
- for example '$close', '$open', '$vwap' ("close" is OK. `Exchange` will help to prepend
|
||||
"$" to the expression)
|
||||
|
||||
:param subscribe_fields: list, subscribe fields. This expressions will be added to the query and `self.quote`.
|
||||
It is useful when users want more fields to be queried
|
||||
|
||||
:param limit_threshold: Union[Tuple[str, str], float, None]
|
||||
1) `None`: no limitation
|
||||
2) float, 0.1 for example, default None
|
||||
@@ -66,7 +60,6 @@ class Exchange:
|
||||
<the expression for sell stock limitation>)
|
||||
`False` value indicates the stock is tradable
|
||||
`True` value indicates the stock is limited and not tradable
|
||||
|
||||
:param volume_threshold: Union[
|
||||
Dict[
|
||||
"all": ("cum" or "current", limit_str),
|
||||
@@ -85,26 +78,22 @@ class Exchange:
|
||||
- "current" means that this is a real-time value and will not accumulate over time,
|
||||
so it can be directly used as a capacity limit.
|
||||
e.g. ("cum", "0.2 * DayCumsum($volume, '9:45', '14:45')"), ("current", "$bidV1")
|
||||
|
||||
2) "all" means the volume limits are both buying and selling.
|
||||
"buy" means the volume limits of buying. "sell" means the volume limits of selling.
|
||||
Different volume limits will be aggregated with min(). If volume_threshold is only
|
||||
("cum" or "current", limit_str) instead of a dict, the volume limits are for
|
||||
both by deault. In other words, it is same as {"all": ("cum" or "current", limit_str)}.
|
||||
|
||||
3) e.g. "volume_threshold": {
|
||||
"all": ("cum", "0.2 * DayCumsum($volume, '9:45', '14:45')"),
|
||||
"buy": ("current", "$askV1"),
|
||||
"sell": ("current", "$bidV1"),
|
||||
}
|
||||
|
||||
:param open_cost: cost rate for open, default 0.0015
|
||||
:param close_cost: cost rate for close, default 0.0025
|
||||
:param trade_unit: trade unit, 100 for China A market.
|
||||
None for disable trade unit.
|
||||
**NOTE**: `trade_unit` is included in the `kwargs`. It is necessary because we must
|
||||
distinguish `not set` and `disable trade_unit`
|
||||
|
||||
:param min_cost: min cost, default 5
|
||||
:param extra_quote: pandas, dataframe consists of
|
||||
columns: like ['$vwap', '$close', '$volume', '$factor', 'limit_sell', 'limit_buy'].
|
||||
@@ -185,7 +174,7 @@ class Exchange:
|
||||
|
||||
# init quote by quote_df
|
||||
self.quote_cls = quote_cls
|
||||
self.quote: BaseQuote = self.quote_cls(self.quote_df)
|
||||
self.quote: BaseQuote = self.quote_cls(self.quote_df, freq)
|
||||
|
||||
def get_quote_from_qlib(self):
|
||||
# get stock data from qlib
|
||||
@@ -273,12 +262,10 @@ class Exchange:
|
||||
preproccess the volume limit.
|
||||
get the fields need to get from qlib.
|
||||
get the volume limit list of buying and selling which is composed of all limits.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
volume_threshold :
|
||||
please refer to the doc of exchange.
|
||||
|
||||
Returns
|
||||
-------
|
||||
fields: set
|
||||
@@ -287,7 +274,6 @@ class Exchange:
|
||||
all volume limits of buying.
|
||||
sell_vol_limit: List[Tuple[str]]
|
||||
all volume limits of selling.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
@@ -324,7 +310,6 @@ class Exchange:
|
||||
- if direction is None, check if tradable for buying and selling.
|
||||
- if direction == Order.BUY, check the if tradable for buying
|
||||
- if direction == Order.SELL, check the sell limit for selling.
|
||||
|
||||
"""
|
||||
if direction is None:
|
||||
buy_limit = self.quote.get_data(stock_id, start_time, end_time, field="limit_buy", method="all")
|
||||
@@ -372,9 +357,7 @@ class Exchange:
|
||||
):
|
||||
"""
|
||||
Deal order when the actual transaction
|
||||
|
||||
the results section in `Order` will be changed.
|
||||
|
||||
:param order: Deal the order.
|
||||
:param trade_account: Trade account to be updated after dealing the order.
|
||||
:param position: position to be updated after dealing the order.
|
||||
@@ -393,12 +376,12 @@ class Exchange:
|
||||
|
||||
# NOTE: order will be changed in this function
|
||||
trade_price, trade_val, trade_cost = self._calc_trade_info_by_order(
|
||||
order, trade_account.current if trade_account else position, dealt_order_amount
|
||||
order, trade_account.current_position if trade_account else position, dealt_order_amount
|
||||
)
|
||||
if order.deal_amount > 1e-5:
|
||||
# If the order can only be deal 0 amount. Nothing to be updated
|
||||
if trade_val > 1e-5:
|
||||
# If the order can only be deal 0 value. Nothing to be updated
|
||||
# Otherwise, it will result in
|
||||
# 1) some stock with 0 amount in the position
|
||||
# 1) some stock with 0 value in the position
|
||||
# 2) `trade_unit` of trade_cost will be lost in user account
|
||||
if trade_account:
|
||||
trade_account.update_order(order=order, trade_val=trade_val, cost=trade_cost, trade_price=trade_price)
|
||||
@@ -407,16 +390,17 @@ class Exchange:
|
||||
|
||||
return trade_val, trade_cost, trade_price
|
||||
|
||||
def get_quote_info(self, stock_id, start_time, end_time, method=ts_data_last):
|
||||
def get_quote_info(self, stock_id, start_time, end_time, method="ts_data_last"):
|
||||
return self.quote.get_data(stock_id, start_time, end_time, method=method)
|
||||
|
||||
def get_close(self, stock_id, start_time, end_time, method=ts_data_last):
|
||||
def get_close(self, stock_id, start_time, end_time, method="ts_data_last"):
|
||||
return self.quote.get_data(stock_id, start_time, end_time, field="$close", method=method)
|
||||
|
||||
def get_volume(self, stock_id, start_time, end_time, method="sum"):
|
||||
return self.quote.get_data(stock_id, start_time, end_time, field="$volume", method=method)
|
||||
def get_volume(self, stock_id, start_time, end_time):
|
||||
"""get the total deal volume of stock with `stock_id` between the time interval [start_time, end_time)"""
|
||||
return self.quote.get_data(stock_id, start_time, end_time, field="$volume", method="sum")
|
||||
|
||||
def get_deal_price(self, stock_id, start_time, end_time, direction: OrderDir, method=ts_data_last):
|
||||
def get_deal_price(self, stock_id, start_time, end_time, direction: OrderDir, method="ts_data_last"):
|
||||
if direction == OrderDir.SELL:
|
||||
pstr = self.sell_price
|
||||
elif direction == OrderDir.BUY:
|
||||
@@ -441,7 +425,7 @@ class Exchange:
|
||||
assert start_time is not None and end_time is not None, "the time range must be given"
|
||||
if stock_id not in self.quote.get_all_stock():
|
||||
return None
|
||||
return self.quote.get_data(stock_id, start_time, end_time, field="$factor", method=ts_data_last)
|
||||
return self.quote.get_data(stock_id, start_time, end_time, field="$factor", method="ts_data_last")
|
||||
|
||||
def generate_amount_position_from_weight_position(
|
||||
self, weight_position, cash, start_time, end_time, direction=OrderDir.BUY
|
||||
@@ -449,7 +433,6 @@ class Exchange:
|
||||
"""
|
||||
The generate the target position according to the weight and the cash.
|
||||
NOTE: All the cash will assigned to the tadable stock.
|
||||
|
||||
Parameter:
|
||||
weight_position : dict {stock_id : weight}; allocate cash by weight_position
|
||||
among then, weight must be in this range: 0 < weight < 1
|
||||
@@ -493,7 +476,6 @@ class Exchange:
|
||||
def get_real_deal_amount(self, current_amount, target_amount, factor):
|
||||
"""
|
||||
Calculate the real adjust deal amount when considering the trading unit
|
||||
|
||||
:param current_amount:
|
||||
:param target_amount:
|
||||
:param factor:
|
||||
@@ -516,7 +498,6 @@ class Exchange:
|
||||
def generate_order_for_target_amount_position(self, target_position, current_position, start_time, end_time):
|
||||
"""
|
||||
Note: some future information is used in this function
|
||||
|
||||
Parameter:
|
||||
target_position : dict { stock_id : amount }
|
||||
current_postion : dict { stock_id : amount}
|
||||
@@ -590,8 +571,10 @@ class Exchange:
|
||||
value = 0
|
||||
for stock_id in amount_dict:
|
||||
if (
|
||||
self.check_stock_suspended(stock_id=stock_id, start_time=start_time, end_time=end_time) is False
|
||||
only_tradable is True
|
||||
and self.check_stock_suspended(stock_id=stock_id, start_time=start_time, end_time=end_time) is False
|
||||
and self.check_stock_limit(stock_id=stock_id, start_time=start_time, end_time=end_time) is False
|
||||
or only_tradable is False
|
||||
):
|
||||
value += (
|
||||
self.get_deal_price(
|
||||
@@ -613,10 +596,8 @@ class Exchange:
|
||||
def get_amount_of_trade_unit(self, factor: float = None, stock_id: str = None, start_time=None, end_time=None):
|
||||
"""
|
||||
get the trade unit of amount based on **factor**
|
||||
|
||||
the factor can be given directly or calculated in given time range and stock id.
|
||||
`factor` has higher priority than `stock_id`, `start_time` and `end_time`
|
||||
|
||||
Parameters
|
||||
----------
|
||||
factor : float
|
||||
@@ -641,7 +622,6 @@ class Exchange:
|
||||
):
|
||||
"""Parameter
|
||||
Please refer to the docs of get_amount_of_trade_unit
|
||||
|
||||
deal_amount : float, adjusted amount
|
||||
factor : float, adjusted factor
|
||||
return : float, real amount
|
||||
@@ -656,11 +636,9 @@ class Exchange:
|
||||
|
||||
def _clip_amount_by_volume(self, order: Order, dealt_order_amount: dict) -> int:
|
||||
"""parse the capacity limit string and return the actual amount of orders that can be executed.
|
||||
|
||||
NOTE:
|
||||
this function will change the order.deal_amount **inplace**
|
||||
- This will make the order info more accurate
|
||||
|
||||
Parameters
|
||||
----------
|
||||
order : Order
|
||||
@@ -694,7 +672,7 @@ class Exchange:
|
||||
order.start_time,
|
||||
order.end_time,
|
||||
field=limit[1],
|
||||
method=ts_data_last,
|
||||
method="ts_data_last",
|
||||
)
|
||||
vol_limit_num.append(limit_value - dealt_order_amount[order.stock_id])
|
||||
else:
|
||||
@@ -709,12 +687,10 @@ class Exchange:
|
||||
|
||||
def _get_buy_amount_by_cash_limit(self, trade_price, cash):
|
||||
"""return the real order amount after cash limit for buying.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
trade_price : float
|
||||
position : cash
|
||||
|
||||
Return
|
||||
----------
|
||||
float
|
||||
@@ -735,9 +711,7 @@ class Exchange:
|
||||
def _calc_trade_info_by_order(self, order, position: Position, dealt_order_amount):
|
||||
"""
|
||||
Calculation of trade info
|
||||
|
||||
**NOTE**: Order will be changed in this function
|
||||
|
||||
:param order:
|
||||
:param position: Position
|
||||
:param dealt_order_amount: the dealt order amount dict with the format of {stock_id: float}
|
||||
@@ -745,18 +719,27 @@ class Exchange:
|
||||
"""
|
||||
trade_price = self.get_deal_price(order.stock_id, order.start_time, order.end_time, direction=order.direction)
|
||||
order.factor = self.get_factor(order.stock_id, order.start_time, order.end_time)
|
||||
order.deal_amount = order.amount # set to full amount and clip it step by step
|
||||
# Clipping amount first
|
||||
# - It simulates that the order is rejected directly by the exchange due to large order
|
||||
# Another choice is placing it after rounding the order
|
||||
# - It simulates that the large order is submitted, but partial is dealt regardless of rounding by trading unit.
|
||||
self._clip_amount_by_volume(order, dealt_order_amount)
|
||||
|
||||
if order.direction == Order.SELL:
|
||||
cost_ratio = self.close_cost
|
||||
# sell
|
||||
# if we don't know current position, we choose to sell all
|
||||
# Otherwise, we clip the amount based on current position
|
||||
if position is not None:
|
||||
current_amount = (
|
||||
position.get_stock_amount(order.stock_id) if position.check_stock(order.stock_id) else 0
|
||||
)
|
||||
if np.isclose(order.amount, current_amount):
|
||||
# when selling last stock. The amount don't need rounding
|
||||
order.deal_amount = order.amount
|
||||
else:
|
||||
order.deal_amount = self.round_amount_by_trade_unit(min(current_amount, order.amount), order.factor)
|
||||
if not np.isclose(order.deal_amount, current_amount):
|
||||
# when not selling last stock. rounding is necessary
|
||||
order.deal_amount = self.round_amount_by_trade_unit(
|
||||
min(current_amount, order.deal_amount), order.factor
|
||||
)
|
||||
|
||||
# in case of negative value of cash
|
||||
if position.get_cash() + order.deal_amount * trade_price < max(
|
||||
@@ -765,33 +748,30 @@ class Exchange:
|
||||
):
|
||||
order.deal_amount = 0
|
||||
self.logger.debug(f"Order clipped due to cash limitation: {order}")
|
||||
else:
|
||||
# TODO: We don't know current position.
|
||||
# We choose to sell all
|
||||
order.deal_amount = order.amount
|
||||
|
||||
elif order.direction == Order.BUY:
|
||||
cost_ratio = self.open_cost
|
||||
# buy
|
||||
if position is not None:
|
||||
cash = position.get_cash()
|
||||
trade_val = order.amount * trade_price
|
||||
trade_val = order.deal_amount * trade_price
|
||||
if cash < trade_val + max(trade_val * cost_ratio, self.min_cost):
|
||||
# The money is not enough
|
||||
max_buy_amount = self._get_buy_amount_by_cash_limit(trade_price, cash)
|
||||
order.deal_amount = self.round_amount_by_trade_unit(max_buy_amount, order.factor)
|
||||
order.deal_amount = self.round_amount_by_trade_unit(
|
||||
min(max_buy_amount, order.deal_amount), order.factor
|
||||
)
|
||||
self.logger.debug(f"Order clipped due to cash limitation: {order}")
|
||||
else:
|
||||
# The money is enough
|
||||
order.deal_amount = self.round_amount_by_trade_unit(order.amount, order.factor)
|
||||
order.deal_amount = self.round_amount_by_trade_unit(order.deal_amount, order.factor)
|
||||
else:
|
||||
# Unknown amount of money. Just round the amount
|
||||
order.deal_amount = self.round_amount_by_trade_unit(order.amount, order.factor)
|
||||
order.deal_amount = self.round_amount_by_trade_unit(order.deal_amount, order.factor)
|
||||
|
||||
else:
|
||||
raise NotImplementedError("order type {} error".format(order.type))
|
||||
|
||||
self._clip_amount_by_volume(order, dealt_order_amount)
|
||||
trade_val = order.deal_amount * trade_price
|
||||
trade_cost = max(trade_val * cost_ratio, self.min_cost)
|
||||
if trade_val <= 1e-5:
|
||||
|
||||
@@ -11,7 +11,7 @@ from collections import defaultdict
|
||||
|
||||
from qlib.backtest.report import Indicator
|
||||
|
||||
from .order import EmptyTradeDecision, Order, BaseTradeDecision
|
||||
from .decision import EmptyTradeDecision, Order, BaseTradeDecision
|
||||
from .exchange import Exchange
|
||||
from .utils import TradeCalendarManager, CommonInfrastructure, LevelInfrastructure, get_start_end_idx
|
||||
|
||||
@@ -29,7 +29,7 @@ class BaseExecutor:
|
||||
start_time: Union[str, pd.Timestamp] = None,
|
||||
end_time: Union[str, pd.Timestamp] = None,
|
||||
indicator_config: dict = {},
|
||||
generate_report: bool = False,
|
||||
generate_portfolio_metrics: bool = False,
|
||||
verbose: bool = False,
|
||||
track_data: bool = False,
|
||||
trade_exchange: Exchange = None,
|
||||
@@ -77,8 +77,8 @@ class BaseExecutor:
|
||||
'weight_method': 'value_weighted',
|
||||
}
|
||||
}
|
||||
generate_report : bool, optional
|
||||
whether to generate report, by default False
|
||||
generate_portfolio_metrics : bool, optional
|
||||
whether to generate portfolio_metrics, by default False
|
||||
verbose : bool, optional
|
||||
whether to print trading info, by default False
|
||||
track_data : bool, optional
|
||||
@@ -87,8 +87,8 @@ class BaseExecutor:
|
||||
- Else, `trade_decision` will not be generated
|
||||
|
||||
trade_exchange : Exchange
|
||||
exchange that provides market info, used to generate report
|
||||
- If generate_report is None, trade_exchange will be ignored
|
||||
exchange that provides market info, used to generate portfolio_metrics
|
||||
- If generate_portfolio_metrics is None, trade_exchange will be ignored
|
||||
- Else If `trade_exchange` is None, self.trade_exchange will be set with common_infra
|
||||
|
||||
common_infra : CommonInfrastructure, optional:
|
||||
@@ -103,7 +103,7 @@ class BaseExecutor:
|
||||
"""
|
||||
self.time_per_step = time_per_step
|
||||
self.indicator_config = indicator_config
|
||||
self.generate_report = generate_report
|
||||
self.generate_portfolio_metrics = generate_portfolio_metrics
|
||||
self.verbose = verbose
|
||||
self.track_data = track_data
|
||||
self._trade_exchange = trade_exchange
|
||||
@@ -132,7 +132,7 @@ class BaseExecutor:
|
||||
# NOTE: there is a trick in the code.
|
||||
# copy is used instead of deepcopy. So positions are shared
|
||||
self.trade_account: Account = copy.copy(common_infra.get("trade_account"))
|
||||
self.trade_account.reset(freq=self.time_per_step, init_report=True, port_metr_enabled=self.generate_report)
|
||||
self.trade_account.reset(freq=self.time_per_step, port_metr_enabled=self.generate_portfolio_metrics)
|
||||
|
||||
@property
|
||||
def trade_exchange(self) -> Exchange:
|
||||
@@ -246,7 +246,7 @@ class BaseExecutor:
|
||||
raise ValueError("atomic executor doesn't support specify `range_limit`")
|
||||
|
||||
if self._settle_type != BasePosition.ST_NO:
|
||||
self.trade_account.current.settle_start(self._settle_type)
|
||||
self.trade_account.current_position.settle_start(self._settle_type)
|
||||
|
||||
obj = self._collect_data(trade_decision=trade_decision, level=level)
|
||||
|
||||
@@ -271,7 +271,7 @@ class BaseExecutor:
|
||||
self.trade_calendar.step()
|
||||
|
||||
if self._settle_type != BasePosition.ST_NO:
|
||||
self.trade_account.current.settle_commit()
|
||||
self.trade_account.current_position.settle_commit()
|
||||
|
||||
if return_value is not None:
|
||||
return_value.update({"execute_result": res})
|
||||
@@ -296,7 +296,7 @@ class NestedExecutor(BaseExecutor):
|
||||
start_time: Union[str, pd.Timestamp] = None,
|
||||
end_time: Union[str, pd.Timestamp] = None,
|
||||
indicator_config: dict = {},
|
||||
generate_report: bool = False,
|
||||
generate_portfolio_metrics: bool = False,
|
||||
verbose: bool = False,
|
||||
track_data: bool = False,
|
||||
skip_empty_decision: bool = True,
|
||||
@@ -335,7 +335,7 @@ class NestedExecutor(BaseExecutor):
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
indicator_config=indicator_config,
|
||||
generate_report=generate_report,
|
||||
generate_portfolio_metrics=generate_portfolio_metrics,
|
||||
verbose=verbose,
|
||||
track_data=track_data,
|
||||
common_infra=common_infra,
|
||||
@@ -444,7 +444,7 @@ class SimulatorExecutor(BaseExecutor):
|
||||
start_time: Union[str, pd.Timestamp] = None,
|
||||
end_time: Union[str, pd.Timestamp] = None,
|
||||
indicator_config: dict = {},
|
||||
generate_report: bool = False,
|
||||
generate_portfolio_metrics: bool = False,
|
||||
verbose: bool = False,
|
||||
track_data: bool = False,
|
||||
common_infra: CommonInfrastructure = None,
|
||||
@@ -462,7 +462,7 @@ class SimulatorExecutor(BaseExecutor):
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
indicator_config=indicator_config,
|
||||
generate_report=generate_report,
|
||||
generate_portfolio_metrics=generate_portfolio_metrics,
|
||||
verbose=verbose,
|
||||
track_data=track_data,
|
||||
common_infra=common_infra,
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
from functools import lru_cache
|
||||
import logging
|
||||
from typing import List, Text, Union, Callable, Iterable, Dict
|
||||
@@ -14,12 +13,12 @@ import numpy as np
|
||||
from ..utils.index_data import IndexData, SingleData
|
||||
from ..utils.resam import resam_ts_data, ts_data_last
|
||||
from ..log import get_module_logger
|
||||
from ..utils.time import is_single_value
|
||||
from ..utils.time import is_single_value, Freq
|
||||
import qlib.utils.index_data as idd
|
||||
|
||||
|
||||
class BaseQuote:
|
||||
def __init__(self, quote_df: pd.DataFrame):
|
||||
def __init__(self, quote_df: pd.DataFrame, freq):
|
||||
self.logger = get_module_logger("online operator", level=logging.INFO)
|
||||
|
||||
def get_all_stock(self) -> Iterable:
|
||||
@@ -39,7 +38,7 @@ class BaseQuote:
|
||||
start_time: Union[pd.Timestamp, str],
|
||||
end_time: Union[pd.Timestamp, str],
|
||||
field: Union[str],
|
||||
method: Union[str, Callable, None] = None,
|
||||
method: Union[str, None] = None,
|
||||
) -> Union[None, int, float, bool, IndexData]:
|
||||
"""get the specific field of stock data during start time and end_time,
|
||||
and apply method to the data.
|
||||
@@ -83,9 +82,9 @@ class BaseQuote:
|
||||
closed end time for backtest
|
||||
field : str
|
||||
the columns of data to fetch
|
||||
method : Union[str, Callable, None]
|
||||
method : Union[str, None]
|
||||
the method apply to data.
|
||||
e.g [None, "last", "all", "sum", "mean", qlib/utils/resam.py/ts_data_last]
|
||||
e.g [None, "last", "all", "sum", "mean", "ts_data_last"]
|
||||
|
||||
Return
|
||||
----------
|
||||
@@ -99,8 +98,8 @@ class BaseQuote:
|
||||
|
||||
|
||||
class PandasQuote(BaseQuote):
|
||||
def __init__(self, quote_df: pd.DataFrame):
|
||||
super().__init__(quote_df=quote_df)
|
||||
def __init__(self, quote_df: pd.DataFrame, freq):
|
||||
super().__init__(quote_df=quote_df, freq=freq)
|
||||
quote_dict = {}
|
||||
for stock_id, stock_val in quote_df.groupby(level="instrument"):
|
||||
quote_dict[stock_id] = stock_val.droplevel(level="instrument")
|
||||
@@ -110,6 +109,8 @@ class PandasQuote(BaseQuote):
|
||||
return self.data.keys()
|
||||
|
||||
def get_data(self, stock_id, start_time, end_time, field, method=None):
|
||||
if method == "ts_data_last":
|
||||
method = ts_data_last
|
||||
stock_data = resam_ts_data(self.data[stock_id][field], start_time, end_time, method=method)
|
||||
if stock_data is None:
|
||||
return None
|
||||
@@ -121,9 +122,9 @@ class PandasQuote(BaseQuote):
|
||||
raise ValueError(f"stock data from resam_ts_data must be a number, pd.Series or pd.DataFrame")
|
||||
|
||||
|
||||
class CN1minNumpyQuote(BaseQuote):
|
||||
def __init__(self, quote_df: pd.DataFrame):
|
||||
"""CN1minNumpyQuote
|
||||
class NumpyQuote(BaseQuote):
|
||||
def __init__(self, quote_df: pd.DataFrame, freq, region="cn"):
|
||||
"""NumpyQuote
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -131,13 +132,19 @@ class CN1minNumpyQuote(BaseQuote):
|
||||
the init dataframe from qlib.
|
||||
self.data : Dict(stock_id, IndexData.DataFrame)
|
||||
"""
|
||||
super().__init__(quote_df=quote_df)
|
||||
super().__init__(quote_df=quote_df, freq=freq)
|
||||
quote_dict = {}
|
||||
for stock_id, stock_val in quote_df.groupby(level="instrument"):
|
||||
quote_dict[stock_id] = idd.MultiData(stock_val.droplevel(level="instrument"))
|
||||
quote_dict[stock_id].sort_index() # To support more flexible slicing, we must sort data first
|
||||
self.data = quote_dict
|
||||
self.freq = pd.Timedelta(minutes=1)
|
||||
|
||||
n, unit = Freq.parse(freq)
|
||||
if unit in Freq.SUPPORT_CAL_LIST:
|
||||
self.freq = Freq.get_timedelta(1, unit)
|
||||
else:
|
||||
raise ValueError(f"{freq} is not supported in NumpyQuote")
|
||||
self.region = region
|
||||
|
||||
def get_all_stock(self):
|
||||
return self.data.keys()
|
||||
@@ -150,7 +157,7 @@ class CN1minNumpyQuote(BaseQuote):
|
||||
|
||||
# single data
|
||||
# If it don't consider the classification of single data, it will consume a lot of time.
|
||||
if is_single_value(start_time, end_time, self.freq):
|
||||
if is_single_value(start_time, end_time, self.freq, self.region):
|
||||
# this is a very special case.
|
||||
# skip aggregating function to speed-up the query calculation
|
||||
try:
|
||||
@@ -178,9 +185,7 @@ class CN1minNumpyQuote(BaseQuote):
|
||||
return data[-1]
|
||||
elif method == "all":
|
||||
return data.all()
|
||||
elif method == "any":
|
||||
return data.any()
|
||||
elif method == ts_data_last:
|
||||
elif method == "ts_data_last":
|
||||
valid_data = data.loc[~data.isna().data.astype(bool)]
|
||||
if len(valid_data) == 0:
|
||||
return None
|
||||
|
||||
@@ -10,7 +10,7 @@ import pandas as pd
|
||||
from datetime import timedelta
|
||||
import numpy as np
|
||||
|
||||
from .order import Order
|
||||
from .decision import Order
|
||||
from ..data.data import D
|
||||
|
||||
|
||||
@@ -151,7 +151,8 @@ class BasePosition:
|
||||
def get_stock_weight_dict(self, only_stock: bool = False) -> Dict:
|
||||
"""
|
||||
generate stock weight dict {stock_id : value weight of stock in the position}
|
||||
it is meaningful in the beginning or the end of each trade date
|
||||
it is meaningful in the beginning or the end of each trade step
|
||||
- During execution of each trading step, the weight may be not consistant with the portfolio value
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -408,7 +409,7 @@ class Position(BasePosition):
|
||||
return self.position[code]["price"]
|
||||
|
||||
def get_stock_amount(self, code):
|
||||
return self.position[code]["amount"]
|
||||
return self.position[code]["amount"] if code in self.position else 0
|
||||
|
||||
def get_stock_count(self, code, bar):
|
||||
"""the days the account has been hold, it may be used in some special strategies"""
|
||||
@@ -531,7 +532,7 @@ class InfPosition(BasePosition):
|
||||
raise NotImplementedError(f"InfPosition doesn't support get_stock_weight_dict")
|
||||
|
||||
def add_count_all(self, bar):
|
||||
raise NotImplementedError(f"InfPosition doesn't support get_stock_weight_dict")
|
||||
raise NotImplementedError(f"InfPosition doesn't support add_count_all")
|
||||
|
||||
def update_weight_all(self):
|
||||
raise NotImplementedError(f"InfPosition doesn't support update_weight_all")
|
||||
|
||||
@@ -18,6 +18,7 @@ def get_benchmark_weight(
|
||||
start_date=None,
|
||||
end_date=None,
|
||||
path=None,
|
||||
freq="day",
|
||||
):
|
||||
"""get_benchmark_weight
|
||||
|
||||
@@ -27,6 +28,7 @@ def get_benchmark_weight(
|
||||
:param start_date:
|
||||
:param end_date:
|
||||
:param path:
|
||||
:param freq:
|
||||
|
||||
:return: The weight distribution of the the benchmark described by a pandas dataframe
|
||||
Every row corresponds to a trading day.
|
||||
@@ -35,7 +37,7 @@ def get_benchmark_weight(
|
||||
|
||||
"""
|
||||
if not path:
|
||||
path = Path(C.get_data_path()).expanduser() / "raw" / "AIndexMembers" / "weights.csv"
|
||||
path = Path(C.dpm.get_data_uri(freq)).expanduser() / "raw" / "AIndexMembers" / "weights.csv"
|
||||
# TODO: the storage of weights should be implemented in a more elegent way
|
||||
# TODO: The benchmark is not consistant with the filename in instruments.
|
||||
bench_weight_df = pd.read_csv(path, usecols=["code", "date", "index", "weight"])
|
||||
@@ -224,6 +226,7 @@ def brinson_pa(
|
||||
group_method="category",
|
||||
group_n=None,
|
||||
deal_price="vwap",
|
||||
freq="day",
|
||||
):
|
||||
"""brinson profit attribution
|
||||
|
||||
@@ -245,7 +248,7 @@ def brinson_pa(
|
||||
|
||||
start_date, end_date = min(dates), max(dates)
|
||||
|
||||
bench_stock_weight = get_benchmark_weight(bench, start_date, end_date)
|
||||
bench_stock_weight = get_benchmark_weight(bench, start_date, end_date, freq)
|
||||
|
||||
# The attributes for allocation will not
|
||||
if not group_field.startswith("$"):
|
||||
@@ -261,13 +264,14 @@ def brinson_pa(
|
||||
start_time=shift_start_date,
|
||||
end_time=end_date,
|
||||
as_list=True,
|
||||
freq=freq,
|
||||
)
|
||||
stock_df = D.features(
|
||||
instruments,
|
||||
[group_field, deal_price],
|
||||
start_time=shift_start_date,
|
||||
end_time=end_date,
|
||||
freq="day",
|
||||
freq=freq,
|
||||
)
|
||||
stock_df.columns = [group_field, "deal_price"]
|
||||
|
||||
|
||||
@@ -10,21 +10,24 @@ import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from qlib.backtest.exchange import Exchange
|
||||
from qlib.backtest.order import BaseTradeDecision, Order, OrderDir
|
||||
from .decision import IdxTradeRange
|
||||
from qlib.backtest.decision import BaseTradeDecision, Order, OrderDir
|
||||
from qlib.backtest.utils import TradeCalendarManager
|
||||
from .high_performance_ds import BaseOrderIndicator, PandasOrderIndicator, NumpyOrderIndicator, SingleMetric
|
||||
from ..data import D
|
||||
from ..tests.config import CSI300_BENCH
|
||||
from ..utils.resam import get_higher_eq_freq_feature, resam_ts_data
|
||||
from .order import IdxTradeRange
|
||||
import qlib.utils.index_data as idd
|
||||
|
||||
|
||||
class Report:
|
||||
class PortfolioMetrics:
|
||||
"""
|
||||
Motivation:
|
||||
Report is for supporting portfolio related metrics.
|
||||
PortfolioMetrics is for supporting portfolio related metrics.
|
||||
|
||||
Implementation:
|
||||
daily report of the account
|
||||
|
||||
daily portfolio metrics of the account
|
||||
contain those followings: return, cost, turnover, account, cash, bench, value
|
||||
For each step(bar/day/minute), each column represents
|
||||
- return: the return of the portfolio generated by strategy **without transaction fee**.
|
||||
@@ -79,7 +82,7 @@ class Report:
|
||||
self.values = OrderedDict() # value for each trade time
|
||||
self.cashes = OrderedDict()
|
||||
self.benches = OrderedDict()
|
||||
self.latest_report_time = None # pd.TimeStamp
|
||||
self.latest_pm_time = None # pd.TimeStamp
|
||||
|
||||
def init_bench(self, freq=None, benchmark_config=None):
|
||||
if freq is not None:
|
||||
@@ -123,18 +126,18 @@ class Report:
|
||||
return len(self.accounts) == 0
|
||||
|
||||
def get_latest_date(self):
|
||||
return self.latest_report_time
|
||||
return self.latest_pm_time
|
||||
|
||||
def get_latest_account_value(self):
|
||||
return self.accounts[self.latest_report_time]
|
||||
return self.accounts[self.latest_pm_time]
|
||||
|
||||
def get_latest_total_cost(self):
|
||||
return self.total_costs[self.latest_report_time]
|
||||
return self.total_costs[self.latest_pm_time]
|
||||
|
||||
def get_latest_total_turnover(self):
|
||||
return self.total_turnovers[self.latest_report_time]
|
||||
return self.total_turnovers[self.latest_pm_time]
|
||||
|
||||
def update_report_record(
|
||||
def update_portfolio_metrics_record(
|
||||
self,
|
||||
trade_start_time=None,
|
||||
trade_end_time=None,
|
||||
@@ -169,7 +172,7 @@ class Report:
|
||||
elif bench_value is None:
|
||||
bench_value = self._sample_benchmark(self.bench, trade_start_time, trade_end_time)
|
||||
|
||||
# update report data
|
||||
# update pm data
|
||||
self.accounts[trade_start_time] = account_value
|
||||
self.returns[trade_start_time] = return_rate
|
||||
self.total_turnovers[trade_start_time] = total_turnover
|
||||
@@ -179,30 +182,30 @@ class Report:
|
||||
self.values[trade_start_time] = stock_value
|
||||
self.cashes[trade_start_time] = cash
|
||||
self.benches[trade_start_time] = bench_value
|
||||
# update latest_report_date
|
||||
self.latest_report_time = trade_start_time
|
||||
# finish report update in each step
|
||||
# update pm
|
||||
self.latest_pm_time = trade_start_time
|
||||
# finish pm update in each step
|
||||
|
||||
def generate_report_dataframe(self):
|
||||
report = pd.DataFrame()
|
||||
report["account"] = pd.Series(self.accounts)
|
||||
report["return"] = pd.Series(self.returns)
|
||||
report["total_turnover"] = pd.Series(self.total_turnovers)
|
||||
report["turnover"] = pd.Series(self.turnovers)
|
||||
report["total_cost"] = pd.Series(self.total_costs)
|
||||
report["cost"] = pd.Series(self.costs)
|
||||
report["value"] = pd.Series(self.values)
|
||||
report["cash"] = pd.Series(self.cashes)
|
||||
report["bench"] = pd.Series(self.benches)
|
||||
report.index.name = "datetime"
|
||||
return report
|
||||
def generate_portfolio_metrics_dataframe(self):
|
||||
pm = pd.DataFrame()
|
||||
pm["account"] = pd.Series(self.accounts)
|
||||
pm["return"] = pd.Series(self.returns)
|
||||
pm["total_turnover"] = pd.Series(self.total_turnovers)
|
||||
pm["turnover"] = pd.Series(self.turnovers)
|
||||
pm["total_cost"] = pd.Series(self.total_costs)
|
||||
pm["cost"] = pd.Series(self.costs)
|
||||
pm["value"] = pd.Series(self.values)
|
||||
pm["cash"] = pd.Series(self.cashes)
|
||||
pm["bench"] = pd.Series(self.benches)
|
||||
pm.index.name = "datetime"
|
||||
return pm
|
||||
|
||||
def save_report(self, path):
|
||||
r = self.generate_report_dataframe()
|
||||
def save_portfolio_metrics(self, path):
|
||||
r = self.generate_portfolio_metrics_dataframe()
|
||||
r.to_csv(path)
|
||||
|
||||
def load_report(self, path):
|
||||
"""load report from a file
|
||||
def load_portfolio_metrics(self, path):
|
||||
"""load pm from a file
|
||||
should have format like
|
||||
columns = ['account', 'return', 'total_turnover', 'turnover', 'cost', 'total_cost', 'value', 'cash', 'bench']
|
||||
:param
|
||||
@@ -215,7 +218,7 @@ class Report:
|
||||
index = r.index
|
||||
self.init_vars()
|
||||
for trade_start_time in index:
|
||||
self.update_report_record(
|
||||
self.update_portfolio_metrics_record(
|
||||
trade_start_time=trade_start_time,
|
||||
account_value=r.loc[trade_start_time]["account"],
|
||||
cash=r.loc[trade_start_time]["cash"],
|
||||
@@ -376,8 +379,6 @@ class Indicator:
|
||||
price = pa_config.get("price", "deal_price").lower()
|
||||
|
||||
if decision.trade_range is not None:
|
||||
if isinstance(decision.trade_range, IdxTradeRange):
|
||||
raise TypeError(f"IdxTradeRange is not supported")
|
||||
trade_start_time, trade_end_time = decision.trade_range.clip_time_range(
|
||||
start_time=trade_start_time, end_time=trade_end_time
|
||||
)
|
||||
|
||||
@@ -1,18 +1,17 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from __future__ import annotations
|
||||
import bisect
|
||||
from qlib.utils.time import epsilon_change
|
||||
from typing import Union, TYPE_CHECKING, Tuple, Union, List, Set
|
||||
from typing import TYPE_CHECKING, Tuple, Union
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from qlib.backtest.order import BaseTradeDecision
|
||||
from qlib.strategy.base import BaseStrategy
|
||||
from qlib.backtest.decision import BaseTradeDecision
|
||||
|
||||
import pandas as pd
|
||||
import warnings
|
||||
|
||||
from ..utils.resam import get_resam_calendar
|
||||
from ..data.data import Cal
|
||||
|
||||
|
||||
@@ -56,9 +55,9 @@ class TradeCalendarManager:
|
||||
self.start_time = pd.Timestamp(start_time) if start_time else None
|
||||
self.end_time = pd.Timestamp(end_time) if end_time else None
|
||||
|
||||
_calendar, freq, freq_sam = get_resam_calendar(freq=freq)
|
||||
_calendar = Cal.calendar(freq=freq)
|
||||
self._calendar = _calendar
|
||||
_, _, _start_index, _end_index = Cal.locate_index(start_time, end_time, freq=freq, freq_sam=freq_sam)
|
||||
_, _, _start_index, _end_index = Cal.locate_index(start_time, end_time, freq=freq)
|
||||
self.start_index = _start_index
|
||||
self.end_index = _end_index
|
||||
self.trade_len = _end_index - _start_index + 1
|
||||
|
||||
120
qlib/config.py
@@ -15,8 +15,10 @@ import os
|
||||
import re
|
||||
import copy
|
||||
import logging
|
||||
import platform
|
||||
import multiprocessing
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
|
||||
class Config:
|
||||
@@ -73,6 +75,12 @@ REG_US = "us"
|
||||
|
||||
NUM_USABLE_CPU = max(multiprocessing.cpu_count() - 2, 1)
|
||||
|
||||
DISK_DATASET_CACHE = "DiskDatasetCache"
|
||||
SIMPLE_DATASET_CACHE = "SimpleDatasetCache"
|
||||
DISK_EXPRESSION_CACHE = "DiskExpressionCache"
|
||||
|
||||
DEPENDENCY_REDIS_CACHE = (DISK_DATASET_CACHE, DISK_EXPRESSION_CACHE)
|
||||
|
||||
_default_config = {
|
||||
# data provider config
|
||||
"calendar_provider": "LocalCalendarProvider",
|
||||
@@ -82,6 +90,15 @@ _default_config = {
|
||||
"dataset_provider": "LocalDatasetProvider",
|
||||
"provider": "LocalProvider",
|
||||
# config it in qlib.init()
|
||||
# "provider_uri" str or dict:
|
||||
# # str
|
||||
# "~/.qlib/stock_data/cn_data"
|
||||
# # dict
|
||||
# {"day": "~/.qlib/stock_data/cn_data", "1min": "~/.qlib/stock_data/cn_data_1min"}
|
||||
# NOTE: provider_uri priority:
|
||||
# 1. backend_config: backend_obj["kwargs"]["provider_uri"]
|
||||
# 2. backend_config: backend_obj["kwargs"]["provider_uri_map"]
|
||||
# 3. qlib.init: provider_uri
|
||||
"provider_uri": "",
|
||||
# cache
|
||||
"expression_cache": None,
|
||||
@@ -173,8 +190,9 @@ MODE_CONF = {
|
||||
"redis_task_db": 1,
|
||||
"kernels": NUM_USABLE_CPU,
|
||||
# cache
|
||||
"expression_cache": "DiskExpressionCache",
|
||||
"dataset_cache": "DiskDatasetCache",
|
||||
"expression_cache": DISK_EXPRESSION_CACHE,
|
||||
"dataset_cache": DISK_DATASET_CACHE,
|
||||
"local_cache_path": Path("~/.cache/qlib_simple_cache").expanduser().resolve(),
|
||||
"mount_path": None,
|
||||
},
|
||||
"client": {
|
||||
@@ -189,8 +207,10 @@ MODE_CONF = {
|
||||
"provider_uri": "~/.qlib/qlib_data/cn_data",
|
||||
# cache
|
||||
# Using parameter 'remote' to announce the client is using server_cache, and the writing access will be disabled.
|
||||
"expression_cache": "DiskExpressionCache",
|
||||
"dataset_cache": "DiskDatasetCache",
|
||||
"expression_cache": DISK_EXPRESSION_CACHE,
|
||||
"dataset_cache": DISK_DATASET_CACHE,
|
||||
# SimpleDatasetCache directory
|
||||
"local_cache_path": Path("~/.cache/qlib_simple_cache").expanduser().resolve(),
|
||||
"calendar_cache": None,
|
||||
# client config
|
||||
"kernels": NUM_USABLE_CPU,
|
||||
@@ -234,11 +254,43 @@ class QlibConfig(Config):
|
||||
# URI_TYPE
|
||||
LOCAL_URI = "local"
|
||||
NFS_URI = "nfs"
|
||||
DEFAULT_FREQ = "__DEFAULT_FREQ"
|
||||
|
||||
def __init__(self, default_conf):
|
||||
super().__init__(default_conf)
|
||||
self._registered = False
|
||||
|
||||
class DataPathManager:
|
||||
def __init__(self, provider_uri: Union[str, Path, dict], mount_path: Union[str, Path, dict]):
|
||||
self.provider_uri = provider_uri
|
||||
self.mount_path = mount_path
|
||||
|
||||
@staticmethod
|
||||
def get_uri_type(uri: Union[str, Path]):
|
||||
uri = uri if isinstance(uri, str) else str(uri.expanduser().resolve())
|
||||
is_win = re.match("^[a-zA-Z]:.*", uri) is not None # such as 'C:\\data', 'D:'
|
||||
# such as 'host:/data/' (User may define short hostname by themselves or use localhost)
|
||||
is_nfs_or_win = re.match("^[^/]+:.+", uri) is not None
|
||||
|
||||
if is_nfs_or_win and not is_win:
|
||||
return QlibConfig.NFS_URI
|
||||
else:
|
||||
return QlibConfig.LOCAL_URI
|
||||
|
||||
def get_data_uri(self, freq: str = None) -> Path:
|
||||
if freq is None or freq not in self.provider_uri:
|
||||
freq = QlibConfig.DEFAULT_FREQ
|
||||
_provider_uri = self.provider_uri[freq]
|
||||
if self.get_uri_type(_provider_uri) == QlibConfig.LOCAL_URI:
|
||||
return Path(_provider_uri)
|
||||
elif self.get_uri_type(_provider_uri) == QlibConfig.NFS_URI:
|
||||
if "win" in platform.system().lower():
|
||||
# windows, mount_path is the drive
|
||||
return Path(f"{self.mount_path[freq]}:\\")
|
||||
return Path(self.mount_path[freq])
|
||||
else:
|
||||
raise NotImplementedError(f"This type of uri is not supported")
|
||||
|
||||
def set_mode(self, mode):
|
||||
# raise KeyError
|
||||
self.update(MODE_CONF[mode])
|
||||
@@ -248,13 +300,42 @@ class QlibConfig(Config):
|
||||
# raise KeyError
|
||||
self.update(_default_region_config[region])
|
||||
|
||||
@staticmethod
|
||||
def is_depend_redis(cache_name: str):
|
||||
return cache_name in DEPENDENCY_REDIS_CACHE
|
||||
|
||||
@property
|
||||
def dpm(self):
|
||||
return self.DataPathManager(self["provider_uri"], self["mount_path"])
|
||||
|
||||
def resolve_path(self):
|
||||
# resolve path
|
||||
if self["mount_path"] is not None:
|
||||
self["mount_path"] = str(Path(self["mount_path"]).expanduser().resolve())
|
||||
_mount_path = self["mount_path"]
|
||||
_provider_uri = self["provider_uri"]
|
||||
if _provider_uri is None:
|
||||
raise ValueError("provider_uri cannot be None")
|
||||
if not isinstance(_provider_uri, dict):
|
||||
_provider_uri = {self.DEFAULT_FREQ: _provider_uri}
|
||||
if not isinstance(_mount_path, dict):
|
||||
_mount_path = {_freq: _mount_path for _freq in _provider_uri.keys()}
|
||||
|
||||
if self.get_uri_type() == QlibConfig.LOCAL_URI:
|
||||
self["provider_uri"] = str(Path(self["provider_uri"]).expanduser().resolve())
|
||||
# check provider_uri and mount_path
|
||||
_miss_freq = set(_provider_uri.keys()) - set(_mount_path.keys())
|
||||
assert len(_miss_freq) == 0, f"mount_path is missing freq: {_miss_freq}"
|
||||
|
||||
# resolve
|
||||
for _freq, _uri in _provider_uri.items():
|
||||
# provider_uri
|
||||
if self.DataPathManager.get_uri_type(_uri) == QlibConfig.LOCAL_URI:
|
||||
_provider_uri[_freq] = str(Path(_uri).expanduser().resolve())
|
||||
# mount_path
|
||||
_mount_path[_freq] = (
|
||||
_mount_path[_freq]
|
||||
if _mount_path[_freq] is None
|
||||
else str(Path(_mount_path[_freq]).expanduser().resolve())
|
||||
)
|
||||
self["provider_uri"] = _provider_uri
|
||||
self["mount_path"] = _mount_path
|
||||
|
||||
def get_uri_type(self):
|
||||
path = self["provider_uri"]
|
||||
@@ -270,14 +351,6 @@ class QlibConfig(Config):
|
||||
else:
|
||||
return QlibConfig.LOCAL_URI
|
||||
|
||||
def get_data_path(self):
|
||||
if self.get_uri_type() == QlibConfig.LOCAL_URI:
|
||||
return self["provider_uri"]
|
||||
elif self.get_uri_type() == QlibConfig.NFS_URI:
|
||||
return self["mount_path"]
|
||||
else:
|
||||
raise NotImplementedError(f"This type of uri is not supported")
|
||||
|
||||
def set(self, default_conf: str = "client", **kwargs):
|
||||
"""
|
||||
configure qlib based on the input parameters
|
||||
@@ -325,11 +398,20 @@ class QlibConfig(Config):
|
||||
if not (self["expression_cache"] is None and self["dataset_cache"] is None):
|
||||
# check redis
|
||||
if not can_use_cache():
|
||||
logger.warning(
|
||||
f"redis connection failed(host={self['redis_host']} port={self['redis_port']}), cache will not be used!"
|
||||
)
|
||||
log_str = ""
|
||||
# check expression cache
|
||||
if self.is_depend_redis(self["expression_cache"]):
|
||||
log_str += self["expression_cache"]
|
||||
self["expression_cache"] = None
|
||||
# check dataset cache
|
||||
if self.is_depend_redis(self["dataset_cache"]):
|
||||
log_str += f" and {self['dataset_cache']}" if log_str else self["dataset_cache"]
|
||||
self["dataset_cache"] = None
|
||||
if log_str:
|
||||
logger.warning(
|
||||
f"redis connection failed(host={self['redis_host']} port={self['redis_port']}), "
|
||||
f"{log_str} will not be used!"
|
||||
)
|
||||
|
||||
def register(self):
|
||||
from .utils import init_instance_by_config
|
||||
|
||||
346
qlib/contrib/data/dataset.py
Normal file
@@ -0,0 +1,346 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import copy
|
||||
import torch
|
||||
import warnings
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from qlib.utils import init_instance_by_config
|
||||
from qlib.data.dataset import DatasetH, DataHandler
|
||||
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
|
||||
def _to_tensor(x):
|
||||
if not isinstance(x, torch.Tensor):
|
||||
return torch.tensor(x, dtype=torch.float, device=device)
|
||||
return x
|
||||
|
||||
|
||||
def _create_ts_slices(index, seq_len):
|
||||
"""
|
||||
create time series slices from pandas index
|
||||
|
||||
Args:
|
||||
index (pd.MultiIndex): pandas multiindex with <instrument, datetime> order
|
||||
seq_len (int): sequence length
|
||||
"""
|
||||
assert isinstance(index, pd.MultiIndex), "unsupported index type"
|
||||
assert seq_len > 0, "sequence length should be larger than 0"
|
||||
assert index.is_monotonic_increasing, "index should be sorted"
|
||||
|
||||
# number of dates for each instrument
|
||||
sample_count_by_insts = index.to_series().groupby(level=0).size().values
|
||||
|
||||
# start index for each instrument
|
||||
start_index_of_insts = np.roll(np.cumsum(sample_count_by_insts), 1)
|
||||
start_index_of_insts[0] = 0
|
||||
|
||||
# all the [start, stop) indices of features
|
||||
# features between [start, stop) will be used to predict label at `stop - 1`
|
||||
slices = []
|
||||
for cur_loc, cur_cnt in zip(start_index_of_insts, sample_count_by_insts):
|
||||
for stop in range(1, cur_cnt + 1):
|
||||
end = cur_loc + stop
|
||||
start = max(end - seq_len, 0)
|
||||
slices.append(slice(start, end))
|
||||
slices = np.array(slices, dtype="object")
|
||||
|
||||
assert len(slices) == len(index) # the i-th slice = index[i]
|
||||
|
||||
return slices
|
||||
|
||||
|
||||
def _get_date_parse_fn(target):
|
||||
"""get date parse function
|
||||
|
||||
This method is used to parse date arguments as target type.
|
||||
|
||||
Example:
|
||||
get_date_parse_fn('20120101')('2017-01-01') => '20170101'
|
||||
get_date_parse_fn(20120101)('2017-01-01') => 20170101
|
||||
"""
|
||||
if isinstance(target, pd.Timestamp):
|
||||
_fn = lambda x: pd.Timestamp(x) # Timestamp('2020-01-01')
|
||||
elif isinstance(target, int):
|
||||
_fn = lambda x: int(str(x).replace("-", "")[:8]) # 20200201
|
||||
elif isinstance(target, str) and len(target) == 8:
|
||||
_fn = lambda x: str(x).replace("-", "")[:8] # '20200201'
|
||||
else:
|
||||
_fn = lambda x: x # '2021-01-01'
|
||||
return _fn
|
||||
|
||||
|
||||
def _maybe_padding(x, seq_len, zeros=None):
|
||||
"""padding 2d <time * feature> data with zeros
|
||||
|
||||
Args:
|
||||
x (np.ndarray): 2d data with shape <time * feature>
|
||||
seq_len (int): target sequence length
|
||||
zeros (np.ndarray): zeros with shape <seq_len * feature>
|
||||
"""
|
||||
assert seq_len > 0, "sequence length should be larger than 0"
|
||||
if zeros is None:
|
||||
zeros = np.zeros((seq_len, x.shape[1]), dtype=np.float32)
|
||||
else:
|
||||
assert len(zeros) >= seq_len, "zeros matrix is not large enough for padding"
|
||||
if len(x) != seq_len: # padding zeros
|
||||
x = np.concatenate([zeros[: seq_len - len(x), : x.shape[1]], x], axis=0)
|
||||
return x
|
||||
|
||||
|
||||
class MTSDatasetH(DatasetH):
|
||||
"""Memory Augmented Time Series Dataset
|
||||
|
||||
Args:
|
||||
handler (DataHandler): data handler
|
||||
segments (dict): data split segments
|
||||
seq_len (int): time series sequence length
|
||||
horizon (int): label horizon
|
||||
num_states (int): how many memory states to be added
|
||||
memory_mode (str): memory mode (daily or sample)
|
||||
batch_size (int): batch size (<0 will use daily sampling)
|
||||
n_samples (int): number of samples in the same day
|
||||
shuffle (bool): whether shuffle data
|
||||
drop_last (bool): whether drop last batch < batch_size
|
||||
input_size (int): reshape flatten rows as this input_size (backward compatibility)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
handler,
|
||||
segments,
|
||||
seq_len=60,
|
||||
horizon=0,
|
||||
num_states=0,
|
||||
memory_mode="sample",
|
||||
batch_size=-1,
|
||||
n_samples=None,
|
||||
shuffle=True,
|
||||
drop_last=False,
|
||||
input_size=None,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
assert num_states == 0 or horizon > 0, "please specify `horizon` to avoid data leakage"
|
||||
assert memory_mode in ["sample", "daily"], "unsupported memory mode"
|
||||
assert memory_mode == "sample" or batch_size < 0, "daily memory requires daily sampling (`batch_size < 0`)"
|
||||
assert batch_size != 0, "invalid batch size"
|
||||
|
||||
if batch_size > 0 and n_samples is not None:
|
||||
warnings.warn("`n_samples` can only be used for daily sampling (`batch_size < 0`)")
|
||||
|
||||
self.seq_len = seq_len
|
||||
self.horizon = horizon
|
||||
self.num_states = num_states
|
||||
self.memory_mode = memory_mode
|
||||
self.batch_size = batch_size
|
||||
self.n_samples = n_samples
|
||||
self.shuffle = shuffle
|
||||
self.drop_last = drop_last
|
||||
self.input_size = input_size
|
||||
self.params = (batch_size, n_samples, drop_last, shuffle) # for train/eval switch
|
||||
|
||||
super().__init__(handler, segments, **kwargs)
|
||||
|
||||
def setup_data(self, handler_kwargs: dict = None, **kwargs):
|
||||
|
||||
super().setup_data(**kwargs)
|
||||
|
||||
if handler_kwargs is not None:
|
||||
self.handler.setup_data(**handler_kwargs)
|
||||
|
||||
# pre-fetch data and change index to <code, date>
|
||||
# NOTE: we will use inplace sort to reduce memory use
|
||||
try:
|
||||
df = self.handler._learn.copy() # use copy otherwise recorder will fail
|
||||
# FIXME: currently we cannot support switching from `_learn` to `_infer` for inference
|
||||
except:
|
||||
warnings.warn("cannot access `_learn`, will load raw data")
|
||||
df = self.handler._data.copy()
|
||||
df.index = df.index.swaplevel()
|
||||
df.sort_index(inplace=True)
|
||||
|
||||
# convert to numpy
|
||||
self._data = df["feature"].values.astype("float32")
|
||||
np.nan_to_num(self._data, copy=False) # NOTE: fillna in case users forget using the fillna processor
|
||||
self._label = df["label"].squeeze().values.astype("float32")
|
||||
self._index = df.index
|
||||
|
||||
if self.input_size is not None and self.input_size != self._data.shape[1]:
|
||||
warnings.warn("the data has different shape from input_size and the data will be reshaped")
|
||||
assert self._data.shape[1] % self.input_size == 0, "data mismatch, please check `input_size`"
|
||||
|
||||
# create batch slices
|
||||
self._batch_slices = _create_ts_slices(self._index, self.seq_len)
|
||||
|
||||
# create daily slices
|
||||
daily_slices = {date: [] for date in sorted(self._index.unique(level=1))} # sorted by date
|
||||
for i, (code, date) in enumerate(self._index):
|
||||
daily_slices[date].append(self._batch_slices[i])
|
||||
self._daily_slices = np.array(list(daily_slices.values()), dtype="object")
|
||||
self._daily_index = pd.Series(list(daily_slices.keys())) # index is the original date index
|
||||
|
||||
# add memory (sample wise and daily)
|
||||
if self.memory_mode == "sample":
|
||||
self._memory = np.zeros((len(self._data), self.num_states), dtype=np.float32)
|
||||
elif self.memory_mode == "daily":
|
||||
self._memory = np.zeros((len(self._daily_index), self.num_states), dtype=np.float32)
|
||||
else:
|
||||
raise ValueError(f"invalid memory_mode `{self.memory_mode}`")
|
||||
|
||||
# padding tensor
|
||||
self._zeros = np.zeros((self.seq_len, max(self.num_states, self._data.shape[1])), dtype=np.float32)
|
||||
|
||||
def _prepare_seg(self, slc, **kwargs):
|
||||
fn = _get_date_parse_fn(self._index[0][1])
|
||||
start_date = fn(slc.start)
|
||||
end_date = fn(slc.stop)
|
||||
obj = copy.copy(self) # shallow copy
|
||||
# NOTE: Seriable will disable copy `self._data` so we manually assign them here
|
||||
obj._data = self._data # reference (no copy)
|
||||
obj._label = self._label
|
||||
obj._index = self._index
|
||||
obj._memory = self._memory
|
||||
obj._zeros = self._zeros
|
||||
# update index for this batch
|
||||
date_index = self._index.get_level_values(1)
|
||||
obj._batch_slices = self._batch_slices[(date_index >= start_date) & (date_index <= end_date)]
|
||||
mask = (self._daily_index.values >= start_date) & (self._daily_index.values <= end_date)
|
||||
obj._daily_slices = self._daily_slices[mask]
|
||||
obj._daily_index = self._daily_index[mask]
|
||||
return obj
|
||||
|
||||
def restore_index(self, index):
|
||||
return self._index[index]
|
||||
|
||||
def restore_daily_index(self, daily_index):
|
||||
return pd.Index(self._daily_index.loc[daily_index])
|
||||
|
||||
def assign_data(self, index, vals):
|
||||
if self.num_states == 0:
|
||||
raise ValueError("cannot assign data as `num_states==0`")
|
||||
if isinstance(vals, torch.Tensor):
|
||||
vals = vals.detach().cpu().numpy()
|
||||
self._memory[index] = vals
|
||||
|
||||
def clear_memory(self):
|
||||
if self.num_states == 0:
|
||||
raise ValueError("cannot clear memory as `num_states==0`")
|
||||
self._memory[:] = 0
|
||||
|
||||
def train(self):
|
||||
"""enable traning mode"""
|
||||
self.batch_size, self.n_samples, self.drop_last, self.shuffle = self.params
|
||||
|
||||
def eval(self):
|
||||
"""enable evaluation mode"""
|
||||
self.batch_size = -1
|
||||
self.n_samples = None
|
||||
self.drop_last = False
|
||||
self.shuffle = False
|
||||
|
||||
def _get_slices(self):
|
||||
if self.batch_size < 0: # daily sampling
|
||||
slices = self._daily_slices.copy()
|
||||
batch_size = -1 * self.batch_size
|
||||
else: # normal sampling
|
||||
slices = self._batch_slices.copy()
|
||||
batch_size = self.batch_size
|
||||
return slices, batch_size
|
||||
|
||||
def __len__(self):
|
||||
slices, batch_size = self._get_slices()
|
||||
if self.drop_last:
|
||||
return len(slices) // batch_size
|
||||
return (len(slices) + batch_size - 1) // batch_size
|
||||
|
||||
def __iter__(self):
|
||||
slices, batch_size = self._get_slices()
|
||||
indices = np.arange(len(slices))
|
||||
if self.shuffle:
|
||||
np.random.shuffle(indices)
|
||||
|
||||
for i in range(len(indices))[::batch_size]:
|
||||
if self.drop_last and i + batch_size > len(indices):
|
||||
break
|
||||
|
||||
data = [] # store features
|
||||
label = [] # store labels
|
||||
index = [] # store index
|
||||
state = [] # store memory states
|
||||
daily_index = [] # store daily index
|
||||
daily_count = [] # store number of samples for each day
|
||||
|
||||
for j in indices[i : i + batch_size]:
|
||||
|
||||
# normal sampling: self.batch_size > 0 => slices is a list => slices_subset is a slice
|
||||
# daily sampling: self.batch_size < 0 => slices is a nested list => slices_subset is a list
|
||||
slices_subset = slices[j]
|
||||
|
||||
# daily sampling
|
||||
# each slices_subset contains a list of slices for multiple stocks
|
||||
# NOTE: daily sampling is used in 1) eval mode, 2) train mode with self.batch_size < 0
|
||||
if self.batch_size < 0:
|
||||
|
||||
# store daily index
|
||||
idx = self._daily_index.index[j] # daily_index.index is the index of the original data
|
||||
daily_index.append(idx)
|
||||
|
||||
# store daily memory if specified
|
||||
# NOTE: daily memory always requires daily sampling (self.batch_size < 0)
|
||||
if self.memory_mode == "daily":
|
||||
slc = slice(max(idx - self.seq_len - self.horizon, 0), max(idx - self.horizon, 0))
|
||||
state.append(_maybe_padding(self._memory[slc], self.seq_len, self._zeros))
|
||||
|
||||
# down-sample stocks and store count
|
||||
if self.n_samples and 0 < self.n_samples < len(slices_subset): # intraday subsample
|
||||
slices_subset = np.random.choice(slices_subset, self.n_samples, replace=False)
|
||||
daily_count.append(len(slices_subset))
|
||||
|
||||
# normal sampling
|
||||
# each slices_subset is a single slice
|
||||
# NOTE: normal sampling is used in train mode with self.batch_size > 0
|
||||
else:
|
||||
slices_subset = [slices_subset]
|
||||
|
||||
for slc in slices_subset:
|
||||
|
||||
# legacy support for Alpha360 data by `input_size`
|
||||
if self.input_size:
|
||||
data.append(self._data[slc.stop - 1].reshape(self.input_size, -1).T)
|
||||
else:
|
||||
data.append(_maybe_padding(self._data[slc], self.seq_len, self._zeros))
|
||||
|
||||
if self.memory_mode == "sample":
|
||||
state.append(_maybe_padding(self._memory[slc], self.seq_len, self._zeros)[: -self.horizon])
|
||||
|
||||
label.append(self._label[slc.stop - 1])
|
||||
index.append(slc.stop - 1)
|
||||
|
||||
# end slices loop
|
||||
|
||||
# end indices batch loop
|
||||
|
||||
# concate
|
||||
data = _to_tensor(np.stack(data))
|
||||
state = _to_tensor(np.stack(state))
|
||||
label = _to_tensor(np.stack(label))
|
||||
index = np.array(index)
|
||||
daily_index = np.array(daily_index)
|
||||
daily_count = np.array(daily_count)
|
||||
|
||||
# yield -> generator
|
||||
yield {
|
||||
"data": data,
|
||||
"label": label,
|
||||
"state": state,
|
||||
"index": index,
|
||||
"daily_index": daily_index,
|
||||
"daily_count": daily_count,
|
||||
}
|
||||
|
||||
# end indice loop
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
from ...data.dataset.processor import Processor
|
||||
from ...utils import get_cls_kwargs
|
||||
from ...utils import get_callable_kwargs
|
||||
from ...data.dataset import processor as processor_module
|
||||
from ...log import TimeInspector
|
||||
from inspect import getfullargspec
|
||||
@@ -14,7 +14,7 @@ def check_transform_proc(proc_l, fit_start_time, fit_end_time):
|
||||
new_l = []
|
||||
for p in proc_l:
|
||||
if not isinstance(p, Processor):
|
||||
klass, pkwargs = get_cls_kwargs(p, processor_module)
|
||||
klass, pkwargs = get_callable_kwargs(p, processor_module)
|
||||
args = getfullargspec(klass).args
|
||||
if "fit_start_time" in args and "fit_end_time" in args:
|
||||
assert (
|
||||
@@ -58,6 +58,7 @@ class Alpha360(DataHandlerLP):
|
||||
fit_start_time=None,
|
||||
fit_end_time=None,
|
||||
filter_pipe=None,
|
||||
inst_processor=None,
|
||||
**kwargs,
|
||||
):
|
||||
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
|
||||
@@ -72,6 +73,7 @@ class Alpha360(DataHandlerLP):
|
||||
},
|
||||
"filter_pipe": filter_pipe,
|
||||
"freq": freq,
|
||||
"inst_processor": inst_processor,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -144,6 +146,7 @@ class Alpha158(DataHandlerLP):
|
||||
fit_end_time=None,
|
||||
process_type=DataHandlerLP.PTYPE_A,
|
||||
filter_pipe=None,
|
||||
inst_processor=None,
|
||||
**kwargs,
|
||||
):
|
||||
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
|
||||
@@ -158,6 +161,7 @@ class Alpha158(DataHandlerLP):
|
||||
},
|
||||
"filter_pipe": filter_pipe,
|
||||
"freq": freq,
|
||||
"inst_processor": inst_processor,
|
||||
},
|
||||
}
|
||||
super().__init__(
|
||||
|
||||
@@ -36,9 +36,10 @@ def risk_analysis(r, N: int = None, freq: str = "day"):
|
||||
|
||||
def cal_risk_analysis_scaler(freq):
|
||||
_count, _freq = Freq.parse(freq)
|
||||
# len(D.calendar(start_time='2010-01-01', end_time='2019-12-31', freq='day')) = 2384
|
||||
_freq_scaler = {
|
||||
Freq.NORM_FREQ_MINUTE: 240 * 252,
|
||||
Freq.NORM_FREQ_DAY: 252,
|
||||
Freq.NORM_FREQ_MINUTE: 240 * 238,
|
||||
Freq.NORM_FREQ_DAY: 238,
|
||||
Freq.NORM_FREQ_WEEK: 50,
|
||||
Freq.NORM_FREQ_MONTH: 12,
|
||||
}
|
||||
|
||||
@@ -27,7 +27,6 @@ from ...contrib.model.pytorch_gru import GRUModel
|
||||
|
||||
class DailyBatchSampler(Sampler):
|
||||
def __init__(self, data_source):
|
||||
|
||||
self.data_source = data_source
|
||||
# calculate number of samples in each batch
|
||||
self.daily_count = pd.Series(index=self.data_source.get_index()).groupby("datetime").size().values
|
||||
|
||||
@@ -564,7 +564,7 @@ class FeatureTransformer(nn.Module):
|
||||
self.shared = None
|
||||
self.independ = nn.ModuleList()
|
||||
if first:
|
||||
self.independ.append(GLU(inp, out_dim, vbs=vbs))
|
||||
self.independ.append(GLU(inp_dim, out_dim, vbs=vbs))
|
||||
for x in range(first, n_ind):
|
||||
self.independ.append(GLU(out_dim, out_dim, vbs=vbs))
|
||||
self.scale = float(np.sqrt(0.5))
|
||||
|
||||
944
qlib/contrib/model/pytorch_tra.py
Normal file
@@ -0,0 +1,944 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import io
|
||||
import os
|
||||
import copy
|
||||
import math
|
||||
import json
|
||||
import collections
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import seaborn as sns
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import torch.nn.functional as F
|
||||
|
||||
try:
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
except:
|
||||
SummaryWriter = None
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from qlib.utils import get_or_create_path
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.model.base import Model
|
||||
from qlib.contrib.data.dataset import MTSDatasetH
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
|
||||
class TRAModel(Model):
|
||||
"""
|
||||
TRA Model
|
||||
|
||||
Args:
|
||||
model_config (dict): model config (will be used by RNN or Transformer)
|
||||
tra_config (dict): TRA config (will be used by TRA)
|
||||
model_type (str): which backbone model to use (RNN/Transformer)
|
||||
lr (float): learning rate
|
||||
n_epochs (int): number of total epochs
|
||||
early_stop (int): early stop when performance not improved at this step
|
||||
update_freq (int): gradient update frequency
|
||||
max_steps_per_epoch (int): maximum number of steps in one epoch
|
||||
lamb (float): regularization parameter
|
||||
rho (float): exponential decay rate for `lamb`
|
||||
alpha (float): fusion parameter for calculating transport loss matrix
|
||||
seed (int): random seed
|
||||
logdir (str): local log directory
|
||||
eval_train (bool): whether evaluate train set between epochs
|
||||
eval_test (bool): whether evaluate test set between epochs
|
||||
pretrain (bool): whether pretrain the backbone model before training TRA.
|
||||
Note that only TRA will be optimized after pretraining
|
||||
init_state (str): model init state path
|
||||
freeze_model (bool): whether freeze backbone model parameters
|
||||
freeze_predictors (bool): whether freeze predictors parameters
|
||||
transport_method (str): transport method, can be none/router/oracle
|
||||
memory_mode (str): memory mode, the same argument for MTSDatasetH
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_config,
|
||||
tra_config,
|
||||
model_type="RNN",
|
||||
lr=1e-3,
|
||||
n_epochs=500,
|
||||
early_stop=50,
|
||||
update_freq=1,
|
||||
max_steps_per_epoch=None,
|
||||
lamb=0.0,
|
||||
rho=0.99,
|
||||
alpha=1.0,
|
||||
seed=0,
|
||||
logdir=None,
|
||||
eval_train=False,
|
||||
eval_test=False,
|
||||
pretrain=False,
|
||||
init_state=None,
|
||||
reset_router=False,
|
||||
freeze_model=False,
|
||||
freeze_predictors=False,
|
||||
transport_method="none",
|
||||
memory_mode="sample",
|
||||
):
|
||||
|
||||
self.logger = get_module_logger("TRA")
|
||||
|
||||
assert memory_mode in ["sample", "daily"], "invalid memory mode"
|
||||
assert transport_method in ["none", "router", "oracle"], f"invalid transport method {transport_method}"
|
||||
assert transport_method == "none" or tra_config["num_states"] > 1, "optimal transport requires `num_states` > 1"
|
||||
assert (
|
||||
memory_mode != "daily" or tra_config["src_info"] == "TPE"
|
||||
), "daily transport can only support TPE as `src_info`"
|
||||
|
||||
if transport_method == "router" and not eval_train:
|
||||
self.logger.warning("`eval_train` will be ignored when using TRA.router")
|
||||
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
|
||||
self.model_config = model_config
|
||||
self.tra_config = tra_config
|
||||
self.model_type = model_type
|
||||
self.lr = lr
|
||||
self.n_epochs = n_epochs
|
||||
self.early_stop = early_stop
|
||||
self.update_freq = update_freq
|
||||
self.max_steps_per_epoch = max_steps_per_epoch
|
||||
self.lamb = lamb
|
||||
self.rho = rho
|
||||
self.alpha = alpha
|
||||
self.seed = seed
|
||||
self.logdir = logdir
|
||||
self.eval_train = eval_train
|
||||
self.eval_test = eval_test
|
||||
self.pretrain = pretrain
|
||||
self.init_state = init_state
|
||||
self.reset_router = reset_router
|
||||
self.freeze_model = freeze_model
|
||||
self.freeze_predictors = freeze_predictors
|
||||
self.transport_method = transport_method
|
||||
self.use_daily_transport = memory_mode == "daily"
|
||||
self.transport_fn = transport_daily if self.use_daily_transport else transport_sample
|
||||
|
||||
self._writer = None
|
||||
if self.logdir is not None:
|
||||
if os.path.exists(self.logdir):
|
||||
self.logger.warning(f"logdir {self.logdir} is not empty")
|
||||
os.makedirs(self.logdir, exist_ok=True)
|
||||
if SummaryWriter is not None:
|
||||
self._writer = SummaryWriter(log_dir=self.logdir)
|
||||
|
||||
self._init_model()
|
||||
|
||||
def _init_model(self):
|
||||
|
||||
self.logger.info("init TRAModel...")
|
||||
|
||||
self.model = eval(self.model_type)(**self.model_config).to(device)
|
||||
print(self.model)
|
||||
|
||||
self.tra = TRA(self.model.output_size, **self.tra_config).to(device)
|
||||
print(self.tra)
|
||||
|
||||
if self.init_state:
|
||||
self.logger.warning(f"load state dict from `init_state`")
|
||||
state_dict = torch.load(self.init_state, map_location="cpu")
|
||||
self.model.load_state_dict(state_dict["model"])
|
||||
res = load_state_dict_unsafe(self.tra, state_dict["tra"])
|
||||
self.logger.warning(str(res))
|
||||
|
||||
if self.reset_router:
|
||||
self.logger.warning(f"reset TRA.router parameters")
|
||||
self.tra.fc.reset_parameters()
|
||||
self.tra.router.reset_parameters()
|
||||
|
||||
if self.freeze_model:
|
||||
self.logger.warning(f"freeze model parameters")
|
||||
for param in self.model.parameters():
|
||||
param.requires_grad_(False)
|
||||
|
||||
if self.freeze_predictors:
|
||||
self.logger.warning(f"freeze TRA.predictors parameters")
|
||||
for param in self.tra.predictors.parameters():
|
||||
param.requires_grad_(False)
|
||||
|
||||
self.logger.info("# model params: %d" % sum([p.numel() for p in self.model.parameters() if p.requires_grad]))
|
||||
self.logger.info("# tra params: %d" % sum([p.numel() for p in self.tra.parameters() if p.requires_grad]))
|
||||
|
||||
self.optimizer = optim.Adam(list(self.model.parameters()) + list(self.tra.parameters()), lr=self.lr)
|
||||
|
||||
self.fitted = False
|
||||
self.global_step = -1
|
||||
|
||||
def train_epoch(self, epoch, data_set, is_pretrain=False):
|
||||
|
||||
self.model.train()
|
||||
self.tra.train()
|
||||
data_set.train()
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
P_all = []
|
||||
prob_all = []
|
||||
choice_all = []
|
||||
max_steps = len(data_set)
|
||||
if self.max_steps_per_epoch is not None:
|
||||
if epoch == 0 and self.max_steps_per_epoch < max_steps:
|
||||
self.logger.info(f"max steps updated from {max_steps} to {self.max_steps_per_epoch}")
|
||||
max_steps = min(self.max_steps_per_epoch, max_steps)
|
||||
|
||||
cur_step = 0
|
||||
total_loss = 0
|
||||
total_count = 0
|
||||
for batch in tqdm(data_set, total=max_steps):
|
||||
cur_step += 1
|
||||
if cur_step > max_steps:
|
||||
break
|
||||
|
||||
if not is_pretrain:
|
||||
self.global_step += 1
|
||||
|
||||
data, state, label, count = batch["data"], batch["state"], batch["label"], batch["daily_count"]
|
||||
index = batch["daily_index"] if self.use_daily_transport else batch["index"]
|
||||
|
||||
with torch.set_grad_enabled(not self.freeze_model):
|
||||
hidden = self.model(data)
|
||||
|
||||
all_preds, choice, prob = self.tra(hidden, state)
|
||||
|
||||
if is_pretrain or self.transport_method != "none":
|
||||
# NOTE: use oracle transport for pre-training
|
||||
loss, pred, L, P = self.transport_fn(
|
||||
all_preds,
|
||||
label,
|
||||
choice,
|
||||
prob,
|
||||
state.mean(dim=1),
|
||||
count,
|
||||
self.transport_method if not is_pretrain else "oracle",
|
||||
self.alpha,
|
||||
training=True,
|
||||
)
|
||||
data_set.assign_data(index, L) # save loss to memory
|
||||
if self.use_daily_transport: # only save for daily transport
|
||||
P_all.append(pd.DataFrame(P.detach().cpu().numpy(), index=index))
|
||||
prob_all.append(pd.DataFrame(prob.detach().cpu().numpy(), index=index))
|
||||
choice_all.append(pd.DataFrame(choice.detach().cpu().numpy(), index=index))
|
||||
decay = self.rho ** (self.global_step // 100) # decay every 100 steps
|
||||
lamb = 0 if is_pretrain else self.lamb * decay
|
||||
reg = prob.log().mul(P).sum(dim=1).mean() # train router to predict OT assignment
|
||||
if self._writer is not None and not is_pretrain:
|
||||
self._writer.add_scalar("training/router_loss", -reg.item(), self.global_step)
|
||||
self._writer.add_scalar("training/reg_loss", loss.item(), self.global_step)
|
||||
self._writer.add_scalar("training/lamb", lamb, self.global_step)
|
||||
if not self.use_daily_transport:
|
||||
P_mean = P.mean(axis=0).detach()
|
||||
self._writer.add_scalar("training/P", P_mean.max() / P_mean.min(), self.global_step)
|
||||
loss = loss - lamb * reg
|
||||
else:
|
||||
pred = all_preds.mean(dim=1)
|
||||
loss = loss_fn(pred, label)
|
||||
|
||||
(loss / self.update_freq).backward()
|
||||
if cur_step % self.update_freq == 0:
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
if self._writer is not None and not is_pretrain:
|
||||
self._writer.add_scalar("training/total_loss", loss.item(), self.global_step)
|
||||
|
||||
total_loss += loss.item()
|
||||
total_count += 1
|
||||
|
||||
if self.use_daily_transport and len(P_all):
|
||||
P_all = pd.concat(P_all, axis=0)
|
||||
prob_all = pd.concat(prob_all, axis=0)
|
||||
choice_all = pd.concat(choice_all, axis=0)
|
||||
P_all.index = data_set.restore_daily_index(P_all.index)
|
||||
prob_all.index = P_all.index
|
||||
choice_all.index = P_all.index
|
||||
if not is_pretrain:
|
||||
self._writer.add_image("P", plot(P_all), epoch, dataformats="HWC")
|
||||
self._writer.add_image("prob", plot(prob_all), epoch, dataformats="HWC")
|
||||
self._writer.add_image("choice", plot(choice_all), epoch, dataformats="HWC")
|
||||
|
||||
total_loss /= total_count
|
||||
|
||||
if self._writer is not None and not is_pretrain:
|
||||
self._writer.add_scalar("training/loss", total_loss, epoch)
|
||||
|
||||
return total_loss
|
||||
|
||||
def test_epoch(self, epoch, data_set, return_pred=False, prefix="test", is_pretrain=False):
|
||||
|
||||
self.model.eval()
|
||||
self.tra.eval()
|
||||
data_set.eval()
|
||||
|
||||
preds = []
|
||||
probs = []
|
||||
P_all = []
|
||||
metrics = []
|
||||
for batch in tqdm(data_set):
|
||||
data, state, label, count = batch["data"], batch["state"], batch["label"], batch["daily_count"]
|
||||
index = batch["daily_index"] if self.use_daily_transport else batch["index"]
|
||||
|
||||
with torch.no_grad():
|
||||
hidden = self.model(data)
|
||||
all_preds, choice, prob = self.tra(hidden, state)
|
||||
|
||||
if is_pretrain or self.transport_method != "none":
|
||||
loss, pred, L, P = self.transport_fn(
|
||||
all_preds,
|
||||
label,
|
||||
choice,
|
||||
prob,
|
||||
state.mean(dim=1),
|
||||
count,
|
||||
self.transport_method if not is_pretrain else "oracle",
|
||||
self.alpha,
|
||||
training=False,
|
||||
)
|
||||
data_set.assign_data(index, L) # save loss to memory
|
||||
if P is not None and return_pred:
|
||||
P_all.append(pd.DataFrame(P.cpu().numpy(), index=index))
|
||||
else:
|
||||
pred = all_preds.mean(dim=1)
|
||||
|
||||
X = np.c_[pred.cpu().numpy(), label.cpu().numpy(), all_preds.cpu().numpy()]
|
||||
columns = ["score", "label"] + ["score_%d" % d for d in range(all_preds.shape[1])]
|
||||
pred = pd.DataFrame(X, index=batch["index"], columns=columns)
|
||||
|
||||
metrics.append(evaluate(pred))
|
||||
|
||||
if return_pred:
|
||||
preds.append(pred)
|
||||
if prob is not None:
|
||||
columns = ["prob_%d" % d for d in range(all_preds.shape[1])]
|
||||
probs.append(pd.DataFrame(prob.cpu().numpy(), index=index, columns=columns))
|
||||
|
||||
metrics = pd.DataFrame(metrics)
|
||||
metrics = {
|
||||
"MSE": metrics.MSE.mean(),
|
||||
"MAE": metrics.MAE.mean(),
|
||||
"IC": metrics.IC.mean(),
|
||||
"ICIR": metrics.IC.mean() / metrics.IC.std(),
|
||||
}
|
||||
|
||||
if self._writer is not None and epoch >= 0 and not is_pretrain:
|
||||
for key, value in metrics.items():
|
||||
self._writer.add_scalar(prefix + "/" + key, value, epoch)
|
||||
|
||||
if return_pred:
|
||||
preds = pd.concat(preds, axis=0)
|
||||
preds.index = data_set.restore_index(preds.index)
|
||||
preds.index = preds.index.swaplevel()
|
||||
preds.sort_index(inplace=True)
|
||||
|
||||
if probs:
|
||||
probs = pd.concat(probs, axis=0)
|
||||
if self.use_daily_transport:
|
||||
probs.index = data_set.restore_daily_index(probs.index)
|
||||
else:
|
||||
probs.index = data_set.restore_index(probs.index)
|
||||
probs.index = probs.index.swaplevel()
|
||||
probs.sort_index(inplace=True)
|
||||
|
||||
if len(P_all):
|
||||
P_all = pd.concat(P_all, axis=0)
|
||||
if self.use_daily_transport:
|
||||
P_all.index = data_set.restore_daily_index(P_all.index)
|
||||
else:
|
||||
P_all.index = data_set.restore_index(P_all.index)
|
||||
P_all.index = P_all.index.swaplevel()
|
||||
P_all.sort_index(inplace=True)
|
||||
|
||||
return metrics, preds, probs, P_all
|
||||
|
||||
def _fit(self, train_set, valid_set, test_set, evals_result, is_pretrain=True):
|
||||
|
||||
best_score = -1
|
||||
best_epoch = 0
|
||||
stop_rounds = 0
|
||||
best_params = {
|
||||
"model": copy.deepcopy(self.model.state_dict()),
|
||||
"tra": copy.deepcopy(self.tra.state_dict()),
|
||||
}
|
||||
# train
|
||||
if not is_pretrain and self.transport_method != "none":
|
||||
self.logger.info("init memory...")
|
||||
self.test_epoch(-1, train_set)
|
||||
|
||||
for epoch in range(self.n_epochs):
|
||||
self.logger.info("Epoch %d:", epoch)
|
||||
|
||||
self.logger.info("training...")
|
||||
self.train_epoch(epoch, train_set, is_pretrain=is_pretrain)
|
||||
|
||||
self.logger.info("evaluating...")
|
||||
# NOTE: during evaluating, the whole memory will be refreshed
|
||||
if not is_pretrain and (self.transport_method == "router" or self.eval_train):
|
||||
train_set.clear_memory() # NOTE: clear the shared memory
|
||||
train_metrics = self.test_epoch(epoch, train_set, is_pretrain=is_pretrain, prefix="train")[0]
|
||||
evals_result["train"].append(train_metrics)
|
||||
self.logger.info("train metrics: %s" % train_metrics)
|
||||
|
||||
valid_metrics = self.test_epoch(epoch, valid_set, is_pretrain=is_pretrain, prefix="valid")[0]
|
||||
evals_result["valid"].append(valid_metrics)
|
||||
self.logger.info("valid metrics: %s" % valid_metrics)
|
||||
|
||||
if self.eval_test:
|
||||
test_metrics = self.test_epoch(epoch, test_set, is_pretrain=is_pretrain, prefix="test")[0]
|
||||
evals_result["test"].append(test_metrics)
|
||||
self.logger.info("test metrics: %s" % test_metrics)
|
||||
|
||||
if valid_metrics["IC"] > best_score:
|
||||
best_score = valid_metrics["IC"]
|
||||
stop_rounds = 0
|
||||
best_epoch = epoch
|
||||
best_params = {
|
||||
"model": copy.deepcopy(self.model.state_dict()),
|
||||
"tra": copy.deepcopy(self.tra.state_dict()),
|
||||
}
|
||||
if self.logdir is not None:
|
||||
torch.save(best_params, self.logdir + "/model.bin")
|
||||
else:
|
||||
stop_rounds += 1
|
||||
if stop_rounds >= self.early_stop:
|
||||
self.logger.info("early stop @ %s" % epoch)
|
||||
break
|
||||
|
||||
self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
|
||||
self.model.load_state_dict(best_params["model"])
|
||||
self.tra.load_state_dict(best_params["tra"])
|
||||
|
||||
return best_score
|
||||
|
||||
def fit(self, dataset, evals_result=dict()):
|
||||
|
||||
assert isinstance(dataset, MTSDatasetH), "TRAModel only supports `qlib.contrib.data.dataset.MTSDatasetH`"
|
||||
|
||||
train_set, valid_set, test_set = dataset.prepare(["train", "valid", "test"])
|
||||
|
||||
self.fitted = True
|
||||
self.global_step = -1
|
||||
|
||||
evals_result["train"] = []
|
||||
evals_result["valid"] = []
|
||||
evals_result["test"] = []
|
||||
|
||||
if self.pretrain:
|
||||
self.logger.info("pretraining...")
|
||||
self.optimizer = optim.Adam(
|
||||
list(self.model.parameters()) + list(self.tra.predictors.parameters()), lr=self.lr
|
||||
)
|
||||
self._fit(train_set, valid_set, test_set, evals_result, is_pretrain=True)
|
||||
|
||||
# reset optimizer
|
||||
self.optimizer = optim.Adam(list(self.model.parameters()) + list(self.tra.parameters()), lr=self.lr)
|
||||
|
||||
self.logger.info("training...")
|
||||
best_score = self._fit(train_set, valid_set, test_set, evals_result, is_pretrain=False)
|
||||
|
||||
self.logger.info("inference")
|
||||
train_metrics, train_preds, train_probs, train_P = self.test_epoch(-1, train_set, return_pred=True)
|
||||
self.logger.info("train metrics: %s" % train_metrics)
|
||||
|
||||
valid_metrics, valid_preds, valid_probs, valid_P = self.test_epoch(-1, valid_set, return_pred=True)
|
||||
self.logger.info("valid metrics: %s" % valid_metrics)
|
||||
|
||||
test_metrics, test_preds, test_probs, test_P = self.test_epoch(-1, test_set, return_pred=True)
|
||||
self.logger.info("test metrics: %s" % test_metrics)
|
||||
|
||||
if self.logdir:
|
||||
self.logger.info("save model & pred to local directory")
|
||||
|
||||
pd.concat({name: pd.DataFrame(evals_result[name]) for name in evals_result}, axis=1).to_csv(
|
||||
self.logdir + "/logs.csv", index=False
|
||||
)
|
||||
|
||||
torch.save({"model": self.model.state_dict(), "tra": self.tra.state_dict()}, self.logdir + "/model.bin")
|
||||
|
||||
train_preds.to_pickle(self.logdir + "/train_pred.pkl")
|
||||
valid_preds.to_pickle(self.logdir + "/valid_pred.pkl")
|
||||
test_preds.to_pickle(self.logdir + "/test_pred.pkl")
|
||||
|
||||
if len(train_probs):
|
||||
train_probs.to_pickle(self.logdir + "/train_prob.pkl")
|
||||
valid_probs.to_pickle(self.logdir + "/valid_prob.pkl")
|
||||
test_probs.to_pickle(self.logdir + "/test_prob.pkl")
|
||||
|
||||
if len(train_P):
|
||||
train_P.to_pickle(self.logdir + "/train_P.pkl")
|
||||
valid_P.to_pickle(self.logdir + "/valid_P.pkl")
|
||||
test_P.to_pickle(self.logdir + "/test_P.pkl")
|
||||
|
||||
info = {
|
||||
"config": {
|
||||
"model_config": self.model_config,
|
||||
"tra_config": self.tra_config,
|
||||
"model_type": self.model_type,
|
||||
"lr": self.lr,
|
||||
"n_epochs": self.n_epochs,
|
||||
"early_stop": self.early_stop,
|
||||
"max_steps_per_epoch": self.max_steps_per_epoch,
|
||||
"lamb": self.lamb,
|
||||
"rho": self.rho,
|
||||
"alpha": self.alpha,
|
||||
"seed": self.seed,
|
||||
"logdir": self.logdir,
|
||||
"pretrain": self.pretrain,
|
||||
"init_state": self.init_state,
|
||||
"transport_method": self.transport_method,
|
||||
"use_daily_transport": self.use_daily_transport,
|
||||
},
|
||||
"best_eval_metric": -best_score, # NOTE: -1 for minimize
|
||||
"metrics": {"train": train_metrics, "valid": valid_metrics, "test": test_metrics},
|
||||
}
|
||||
with open(self.logdir + "/info.json", "w") as f:
|
||||
json.dump(info, f)
|
||||
|
||||
def predict(self, dataset, segment="test"):
|
||||
|
||||
assert isinstance(dataset, MTSDatasetH), "TRAModel only supports `qlib.contrib.data.dataset.MTSDatasetH`"
|
||||
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
test_set = dataset.prepare(segment)
|
||||
|
||||
metrics, preds, _, _ = self.test_epoch(-1, test_set, return_pred=True)
|
||||
self.logger.info("test metrics: %s" % metrics)
|
||||
|
||||
return preds
|
||||
|
||||
|
||||
class RNN(nn.Module):
|
||||
|
||||
"""RNN Model
|
||||
|
||||
Args:
|
||||
input_size (int): input size (# features)
|
||||
hidden_size (int): hidden size
|
||||
num_layers (int): number of hidden layers
|
||||
rnn_arch (str): rnn architecture
|
||||
use_attn (bool): whether use attention layer.
|
||||
we use concat attention as https://github.com/fulifeng/Adv-ALSTM/
|
||||
dropout (float): dropout rate
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size=16,
|
||||
hidden_size=64,
|
||||
num_layers=2,
|
||||
rnn_arch="GRU",
|
||||
use_attn=True,
|
||||
dropout=0.0,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.input_size = input_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_layers = num_layers
|
||||
self.rnn_arch = rnn_arch
|
||||
self.use_attn = use_attn
|
||||
|
||||
if hidden_size < input_size:
|
||||
# compression
|
||||
self.input_proj = nn.Linear(input_size, hidden_size)
|
||||
else:
|
||||
self.input_proj = None
|
||||
|
||||
self.rnn = getattr(nn, rnn_arch)(
|
||||
input_size=min(input_size, hidden_size),
|
||||
hidden_size=hidden_size,
|
||||
num_layers=num_layers,
|
||||
batch_first=True,
|
||||
dropout=dropout,
|
||||
)
|
||||
|
||||
if self.use_attn:
|
||||
self.W = nn.Linear(hidden_size, hidden_size)
|
||||
self.u = nn.Linear(hidden_size, 1, bias=False)
|
||||
self.output_size = hidden_size * 2
|
||||
else:
|
||||
self.output_size = hidden_size
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
if self.input_proj is not None:
|
||||
x = self.input_proj(x)
|
||||
|
||||
rnn_out, last_out = self.rnn(x)
|
||||
if self.rnn_arch == "LSTM":
|
||||
last_out = last_out[0]
|
||||
last_out = last_out.mean(dim=0)
|
||||
|
||||
if self.use_attn:
|
||||
laten = self.W(rnn_out).tanh()
|
||||
scores = self.u(laten).softmax(dim=1)
|
||||
att_out = (rnn_out * scores).sum(dim=1)
|
||||
last_out = torch.cat([last_out, att_out], dim=1)
|
||||
|
||||
return last_out
|
||||
|
||||
|
||||
class PositionalEncoding(nn.Module):
|
||||
# reference: https://pytorch.org/tutorials/beginner/transformer_tutorial.html
|
||||
def __init__(self, d_model, dropout=0.1, max_len=5000):
|
||||
super(PositionalEncoding, self).__init__()
|
||||
self.dropout = nn.Dropout(p=dropout)
|
||||
|
||||
pe = torch.zeros(max_len, d_model)
|
||||
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
|
||||
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
|
||||
pe[:, 0::2] = torch.sin(position * div_term)
|
||||
pe[:, 1::2] = torch.cos(position * div_term)
|
||||
pe = pe.unsqueeze(0).transpose(0, 1)
|
||||
self.register_buffer("pe", pe)
|
||||
|
||||
def forward(self, x):
|
||||
x = x + self.pe[: x.size(0), :]
|
||||
return self.dropout(x)
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
|
||||
"""Transformer Model
|
||||
|
||||
Args:
|
||||
input_size (int): input size (# features)
|
||||
hidden_size (int): hidden size
|
||||
num_layers (int): number of transformer layers
|
||||
num_heads (int): number of heads in transformer
|
||||
dropout (float): dropout rate
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size=16,
|
||||
hidden_size=64,
|
||||
num_layers=2,
|
||||
num_heads=2,
|
||||
dropout=0.0,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.input_size = input_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_layers = num_layers
|
||||
self.num_heads = num_heads
|
||||
|
||||
self.input_proj = nn.Linear(input_size, hidden_size)
|
||||
|
||||
self.pe = PositionalEncoding(input_size, dropout)
|
||||
layer = nn.TransformerEncoderLayer(
|
||||
nhead=num_heads, dropout=dropout, d_model=hidden_size, dim_feedforward=hidden_size * 4
|
||||
)
|
||||
self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)
|
||||
|
||||
self.output_size = hidden_size
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
x = x.permute(1, 0, 2).contiguous() # the first dim need to be time
|
||||
x = self.pe(x)
|
||||
|
||||
x = self.input_proj(x)
|
||||
out = self.encoder(x)
|
||||
|
||||
return out[-1]
|
||||
|
||||
|
||||
class TRA(nn.Module):
|
||||
|
||||
"""Temporal Routing Adaptor (TRA)
|
||||
|
||||
TRA takes historical prediction erros & latent representation as inputs,
|
||||
then routes the input sample to a specific predictor for training & inference.
|
||||
|
||||
Args:
|
||||
input_size (int): input size (RNN/Transformer's hidden size)
|
||||
num_states (int): number of latent states (i.e., trading patterns)
|
||||
If `num_states=1`, then TRA falls back to traditional methods
|
||||
hidden_size (int): hidden size of the router
|
||||
tau (float): gumbel softmax temperature
|
||||
src_info (str): information for the router
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size,
|
||||
num_states=1,
|
||||
hidden_size=8,
|
||||
rnn_arch="GRU",
|
||||
num_layers=1,
|
||||
dropout=0.0,
|
||||
tau=1.0,
|
||||
src_info="LR_TPE",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
assert src_info in ["LR", "TPE", "LR_TPE"], "invalid `src_info`"
|
||||
|
||||
self.num_states = num_states
|
||||
self.tau = tau
|
||||
self.rnn_arch = rnn_arch
|
||||
self.src_info = src_info
|
||||
|
||||
self.predictors = nn.Linear(input_size, num_states)
|
||||
|
||||
if self.num_states > 1:
|
||||
if "TPE" in src_info:
|
||||
self.router = getattr(nn, rnn_arch)(
|
||||
input_size=num_states,
|
||||
hidden_size=hidden_size,
|
||||
num_layers=num_layers,
|
||||
batch_first=True,
|
||||
dropout=dropout,
|
||||
)
|
||||
self.fc = nn.Linear(hidden_size + input_size if "LR" in src_info else hidden_size, num_states)
|
||||
else:
|
||||
self.fc = nn.Linear(input_size, num_states)
|
||||
|
||||
def reset_parameters(self):
|
||||
for child in self.children():
|
||||
child.reset_parameters()
|
||||
|
||||
def forward(self, hidden, hist_loss):
|
||||
|
||||
preds = self.predictors(hidden)
|
||||
|
||||
if self.num_states == 1: # no need for router when having only one prediction
|
||||
return preds, None, None
|
||||
|
||||
if "TPE" in self.src_info:
|
||||
out = self.router(hist_loss)[1] # TPE
|
||||
if self.rnn_arch == "LSTM":
|
||||
out = out[0]
|
||||
out = out.mean(dim=0)
|
||||
if "LR" in self.src_info:
|
||||
out = torch.cat([hidden, out], dim=-1) # LR_TPE
|
||||
else:
|
||||
out = hidden # LR
|
||||
|
||||
out = self.fc(out)
|
||||
|
||||
choice = F.gumbel_softmax(out, dim=-1, tau=self.tau, hard=True)
|
||||
prob = torch.softmax(out / self.tau, dim=-1)
|
||||
|
||||
return preds, choice, prob
|
||||
|
||||
|
||||
def evaluate(pred):
|
||||
pred = pred.rank(pct=True) # transform into percentiles
|
||||
score = pred.score
|
||||
label = pred.label
|
||||
diff = score - label
|
||||
MSE = (diff ** 2).mean()
|
||||
MAE = (diff.abs()).mean()
|
||||
IC = score.corr(label, method="spearman")
|
||||
return {"MSE": MSE, "MAE": MAE, "IC": IC}
|
||||
|
||||
|
||||
def shoot_infs(inp_tensor):
|
||||
"""Replaces inf by maximum of tensor"""
|
||||
mask_inf = torch.isinf(inp_tensor)
|
||||
ind_inf = torch.nonzero(mask_inf, as_tuple=False)
|
||||
if len(ind_inf) > 0:
|
||||
for ind in ind_inf:
|
||||
if len(ind) == 2:
|
||||
inp_tensor[ind[0], ind[1]] = 0
|
||||
elif len(ind) == 1:
|
||||
inp_tensor[ind[0]] = 0
|
||||
m = torch.max(inp_tensor)
|
||||
for ind in ind_inf:
|
||||
if len(ind) == 2:
|
||||
inp_tensor[ind[0], ind[1]] = m
|
||||
elif len(ind) == 1:
|
||||
inp_tensor[ind[0]] = m
|
||||
return inp_tensor
|
||||
|
||||
|
||||
def sinkhorn(Q, n_iters=3, epsilon=0.1):
|
||||
# epsilon should be adjusted according to logits value's scale
|
||||
with torch.no_grad():
|
||||
Q = torch.exp(Q / epsilon)
|
||||
Q = shoot_infs(Q)
|
||||
for i in range(n_iters):
|
||||
Q /= Q.sum(dim=0, keepdim=True)
|
||||
Q /= Q.sum(dim=1, keepdim=True)
|
||||
return Q
|
||||
|
||||
|
||||
def loss_fn(pred, label):
|
||||
mask = ~torch.isnan(label)
|
||||
if len(pred.shape) == 2:
|
||||
label = label[:, None]
|
||||
return (pred[mask] - label[mask]).pow(2).mean(dim=0)
|
||||
|
||||
|
||||
def minmax_norm(x):
|
||||
xmin = x.min(dim=-1, keepdim=True).values
|
||||
xmax = x.max(dim=-1, keepdim=True).values
|
||||
mask = (xmin == xmax).squeeze()
|
||||
x = (x - xmin) / (xmax - xmin + 1e-12)
|
||||
x[mask] = 1
|
||||
return x
|
||||
|
||||
|
||||
def transport_sample(all_preds, label, choice, prob, hist_loss, count, transport_method, alpha, training=False):
|
||||
"""
|
||||
sample-wise transport
|
||||
|
||||
Args:
|
||||
all_preds (torch.Tensor): predictions from all predictors, [sample x states]
|
||||
label (torch.Tensor): label, [sample]
|
||||
choice (torch.Tensor): gumbel softmax choice, [sample x states]
|
||||
prob (torch.Tensor): router predicted probility, [sample x states]
|
||||
hist_loss (torch.Tensor): history loss matrix, [sample x states]
|
||||
count (list): sample counts for each day, empty list for sample-wise transport
|
||||
transport_method (str): transportation method
|
||||
alpha (float): fusion parameter for calculating transport loss matrix
|
||||
training (bool): indicate training or inference
|
||||
"""
|
||||
assert all_preds.shape == choice.shape
|
||||
assert len(all_preds) == len(label)
|
||||
assert transport_method in ["oracle", "router"]
|
||||
|
||||
all_loss = torch.zeros_like(all_preds)
|
||||
mask = ~torch.isnan(label)
|
||||
all_loss[mask] = (all_preds[mask] - label[mask, None]).pow(2) # [sample x states]
|
||||
|
||||
L = minmax_norm(all_loss.detach())
|
||||
Lh = L * alpha + minmax_norm(hist_loss) * (1 - alpha) # add hist loss for transport
|
||||
Lh = minmax_norm(Lh)
|
||||
P = sinkhorn(-Lh)
|
||||
del Lh
|
||||
|
||||
if transport_method == "router":
|
||||
if training:
|
||||
pred = (all_preds * choice).sum(dim=1) # gumbel softmax
|
||||
else:
|
||||
pred = all_preds[range(len(all_preds)), prob.argmax(dim=-1)] # argmax
|
||||
else:
|
||||
pred = (all_preds * P).sum(dim=1)
|
||||
|
||||
if transport_method == "router":
|
||||
loss = loss_fn(pred, label)
|
||||
else:
|
||||
loss = (all_loss * P).sum(dim=1).mean()
|
||||
|
||||
return loss, pred, L, P
|
||||
|
||||
|
||||
def transport_daily(all_preds, label, choice, prob, hist_loss, count, transport_method, alpha, training=False):
|
||||
"""
|
||||
daily transport
|
||||
|
||||
Args:
|
||||
all_preds (torch.Tensor): predictions from all predictors, [sample x states]
|
||||
label (torch.Tensor): label, [sample]
|
||||
choice (torch.Tensor): gumbel softmax choice, [days x states]
|
||||
prob (torch.Tensor): router predicted probility, [days x states]
|
||||
hist_loss (torch.Tensor): history loss matrix, [days x states]
|
||||
count (list): sample counts for each day, [days]
|
||||
transport_method (str): transportation method
|
||||
alpha (float): fusion parameter for calculating transport loss matrix
|
||||
training (bool): indicate training or inference
|
||||
"""
|
||||
assert len(prob) == len(count)
|
||||
assert len(all_preds) == sum(count)
|
||||
assert transport_method in ["oracle", "router"]
|
||||
|
||||
all_loss = [] # loss of all predictions
|
||||
start = 0
|
||||
for i, cnt in enumerate(count):
|
||||
slc = slice(start, start + cnt) # samples from the i-th day
|
||||
start += cnt
|
||||
tloss = loss_fn(all_preds[slc], label[slc]) # loss of the i-th day
|
||||
all_loss.append(tloss)
|
||||
all_loss = torch.stack(all_loss, dim=0) # [days x states]
|
||||
|
||||
L = minmax_norm(all_loss.detach())
|
||||
Lh = L * alpha + minmax_norm(hist_loss) * (1 - alpha) # add hist loss for transport
|
||||
Lh = minmax_norm(Lh)
|
||||
P = sinkhorn(-Lh)
|
||||
del Lh
|
||||
|
||||
pred = []
|
||||
start = 0
|
||||
for i, cnt in enumerate(count):
|
||||
slc = slice(start, start + cnt) # samples from the i-th day
|
||||
start += cnt
|
||||
if transport_method == "router":
|
||||
if training:
|
||||
tpred = all_preds[slc] @ choice[i] # gumbel softmax
|
||||
else:
|
||||
tpred = all_preds[slc][:, prob[i].argmax(dim=-1)] # argmax
|
||||
else:
|
||||
tpred = all_preds[slc] @ P[i]
|
||||
pred.append(tpred)
|
||||
pred = torch.cat(pred, dim=0) # [samples]
|
||||
|
||||
if transport_method == "router":
|
||||
loss = loss_fn(pred, label)
|
||||
else:
|
||||
loss = (all_loss * P).sum(dim=1).mean()
|
||||
|
||||
return loss, pred, L, P
|
||||
|
||||
|
||||
def load_state_dict_unsafe(model, state_dict):
|
||||
"""
|
||||
Load state dict to provided model while ignore exceptions.
|
||||
"""
|
||||
|
||||
missing_keys = []
|
||||
unexpected_keys = []
|
||||
error_msgs = []
|
||||
|
||||
# copy state_dict so _load_from_state_dict can modify it
|
||||
metadata = getattr(state_dict, "_metadata", None)
|
||||
state_dict = state_dict.copy()
|
||||
if metadata is not None:
|
||||
state_dict._metadata = metadata
|
||||
|
||||
def load(module, prefix=""):
|
||||
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
|
||||
module._load_from_state_dict(
|
||||
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs
|
||||
)
|
||||
for name, child in module._modules.items():
|
||||
if child is not None:
|
||||
load(child, prefix + name + ".")
|
||||
|
||||
load(model)
|
||||
load = None # break load->load reference cycle
|
||||
|
||||
return {"unexpected_keys": unexpected_keys, "missing_keys": missing_keys, "error_msgs": error_msgs}
|
||||
|
||||
|
||||
def plot(P):
|
||||
assert isinstance(P, pd.DataFrame)
|
||||
|
||||
fig, axes = plt.subplots(1, 2, figsize=(10, 4))
|
||||
P.plot.area(ax=axes[0], xlabel="")
|
||||
P.idxmax(axis=1).value_counts().sort_index().plot.bar(ax=axes[1], xlabel="")
|
||||
plt.tight_layout()
|
||||
|
||||
with io.BytesIO() as buf:
|
||||
plt.savefig(buf, format="png")
|
||||
buf.seek(0)
|
||||
img = plt.imread(buf)
|
||||
plt.close()
|
||||
|
||||
return np.uint8(img * 255)
|
||||