1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-29 00:51:19 +08:00

Compare commits

...

113 Commits

Author SHA1 Message Date
you-n-g
92055d64ec Update VERSION.txt 2021-09-30 22:53:57 +08:00
you-n-g
b9809a4c33 make the prediction update more friendly (#609)
* make the prediction update more friendly

* Update test_storage.py

* LabelUpdater

* Update test_storage.py

* Update test_storage.py

* Update test_storage.py

* Update test_storage.py

* Update setup.py

* Update workflow_config_lightgbm_Alpha158.yaml

* Update workflow_config_lightgbm_Alpha158.yaml

* Update workflow_config_lightgbm_Alpha158.yaml

* Update workflow_config_lightgbm_Alpha158.yaml

* Update workflow_config_lightgbm_Alpha158.yaml

* Update setup.py

* Update setup.py

* test CI only

* test CI only

* Update workflow_config_lightgbm_Alpha158.yaml

* Update setup.py

* fix "Segmentation fault" in macos

* Update test.yml

github action no longer supported ubuntu-16.04

* Update api.rst

update doc with new_lable

* Update api.rst

Co-authored-by: Wangwuyi123 <51237097+Wangwuyi123@users.noreply.github.com>
Co-authored-by: Pengrong Zhu <zhu.pengrong@foxmail.com>
2021-09-30 20:54:44 +08:00
you-n-g
fc243fd29b Fix Models (#483)
* fix gat dataset

* fix tft model

* Update tft.py

* Fix tft.py

Co-authored-by: Pengrong Zhu <zhu.pengrong@foxmail.com>
2021-09-30 13:11:06 +08:00
demon143
b6a8bd5b80 update change doc (#623)
* Add files via upload

* Update README.md

* Update README.md

* Update README.md

* Delete change doc.gif

* Add files via upload

* Update README.md

* Delete change doc.gif

* Add files via upload

* Delete change doc.gif

* Add files via upload

* Update README.md

Co-authored-by: you-n-g <you-n-g@users.noreply.github.com>

Co-authored-by: you-n-g <you-n-g@users.noreply.github.com>
2021-09-29 19:42:38 +08:00
demon143
6ee0fe366c Update initialization.rst (#622)
* Update initialization.rst

* Update docs/start/initialization.rst

Co-authored-by: you-n-g <you-n-g@users.noreply.github.com>

* Update docs/start/initialization.rst

Co-authored-by: you-n-g <you-n-g@users.noreply.github.com>

Co-authored-by: you-n-g <you-n-g@users.noreply.github.com>
2021-09-27 21:44:06 +08:00
you-n-g
55b6ff123e Share version number (#620) 2021-09-27 16:12:12 +08:00
you-n-g
45ea4bae4e Add file lock for MLflowExpManager (#619) 2021-09-26 16:21:15 +08:00
demon143
17d472cf01 Update code_standard.rst (#587)
* Update code_standard.rst

* Update docs/developer/code_standard.rst

Co-authored-by: you-n-g <you-n-g@users.noreply.github.com>

Co-authored-by: you-n-g <you-n-g@users.noreply.github.com>
2021-09-26 15:35:14 +08:00
you-n-g
c500a01226 update cvxpy version 2021-09-25 17:12:02 +08:00
zhupr
114c38b4c3 fix the type of filter_pipe 2021-09-20 19:04:59 +08:00
you-n-g
414c3082c0 Update model list 2021-09-18 12:57:58 +08:00
Young
3fc2f8c93c updategrade version number 2021-09-16 02:15:16 +00:00
Anurag Kumar
66ff3e5bf6 Update python-publish.yml
added python 3.9
2021-09-16 10:09:39 +08:00
Anurag Kumar
8ff68a182e Update setup.py
change to matplotlib==3.3
2021-09-16 10:09:39 +08:00
Anurag Kumar
a105ef1d76 Update setup.py
updated classifiers
2021-09-16 10:09:39 +08:00
zhupr
d02965ea70 Fix SimpleDatasetCache 2021-09-16 10:08:56 +08:00
Christian Clauss
b8d1e08010 Fix undefined names in Python code (#599)
* Update pytorch_tabnet.py

$ `flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics`
```
./qlib/qlib/contrib/model/pytorch_tabnet.py:567:38: F821 undefined name 'inp'
            self.independ.append(GLU(inp, out_dim, vbs=vbs))
                                     ^
./qlib/examples/model_rolling/task_manager_rolling.py:75:18: F821 undefined name 'task_train'
        run_task(task_train, self.task_pool, experiment_name=self.experiment_name)
                 ^
2     F821 undefined name 'task_train'
2
```

* Fix undefined names in Python code

* from qlib.model.trainer import task_train
2021-09-14 12:13:27 +08:00
you-n-g
51709c20d8 Supporting shared processor (#596)
* Supporting shared processor

* fix readonly reverse bug

* remove pytests dependency

* with fit bug

* fix parameter error
2021-09-13 17:11:08 +08:00
Christian Clauss
28c99c77be test.yml: Remove redundant code (#595) 2021-09-13 14:31:32 +08:00
you-n-g
bb5cdfe050 Update Release Note 2021-09-12 17:06:00 +08:00
SaintMalik
fb21c591bb fix typos (#592) 2021-09-12 16:39:22 +08:00
Dong Zhou
5279e71423 Merge pull request #591 from evanzd/fix_tra
Fix TRA
2021-09-11 18:48:13 +08:00
Dong Zhou
f35254c288 update README 2021-09-10 07:38:22 +00:00
Pengrong Zhu
5e82c18cb2 Modify the Feature to be case sensitive (#589) 2021-09-10 11:47:23 +08:00
demon143
2759e8c28d Update the docs of TaskManager (#586)
* Update manage.py
2021-09-09 20:13:45 +08:00
you-n-g
2461575d30 Update README.md
Fix wrong link
2021-09-09 08:28:48 +08:00
Pengrong Zhu
867667531d Update FAQ.rst 2021-09-08 18:06:51 +08:00
zhupr
0fc52333b7 Add wheel package to github CI 2021-09-07 20:41:10 +08:00
zhupr
ab9b6dc47a Modify client-server mode and dataset-cache to disable inst_processor 2021-09-07 20:41:10 +08:00
zhupr
4c5a4d5cd7 Modify the default value in the multi_freq example 2021-09-07 20:41:10 +08:00
zhupr
e84cc23589 Add DataPathManager to QlibConfig && modify inst_processors to supports list only 2021-09-07 20:41:10 +08:00
zhupr
707399a245 Fix duplicate mlflow directories in tests 2021-09-07 20:41:10 +08:00
zhupr
6e88ccca88 Fix the index type of the multi-freq example 2021-09-07 20:41:10 +08:00
zhupr
ee5f3de800 Fix typo 2021-09-07 20:41:10 +08:00
zhupr
3605cd7b96 Add inst_processors to D.features 2021-09-07 20:41:10 +08:00
zhupr
d1cbf4c3d9 support multi-freq uri 2021-09-07 20:41:10 +08:00
zhupr
6011a21308 get_cls_kwargs renamed get_callable_kwargs 2021-09-07 20:41:10 +08:00
zhupr
76a05f37a9 add multi-freq example 2021-09-07 20:41:10 +08:00
zhupr
c99494eb76 Add sample_config to QlibDataLoader, support multi-freq 2021-09-07 20:41:10 +08:00
zhupr
e8126b0c39 Add backend_freq_config parameter, support multi-freq uri 2021-09-07 20:41:10 +08:00
Dong Zhou
8f4d320832 bug fix & use oracle transport pretrain 2021-08-30 07:32:04 +00:00
cslwqxx
e2739ac72c Update README.md 2021-08-29 12:29:11 +08:00
you-n-g
19d15ddc38 Merge pull request #513 from 2796gaurav/main
MVP for Indian Stocks in qlib using yahooquery
2021-08-26 20:59:26 +08:00
you-n-g
12af8f304b Delete .DS_Store 2021-08-26 15:36:35 +08:00
Mark Zhao
25b771ddf1 check lexsort in the 'lazy_sort_index' function (#566)
* check lexsort

* check lexsort

* lexsort comment

* lexsort comment
2021-08-25 18:07:30 +08:00
Pengrong Zhu
1158472489 Fix multi-process loop calls (#574) 2021-08-25 18:05:35 +08:00
you-n-g
84d2cb3226 Update gen.py (#576) 2021-08-25 18:05:10 +08:00
Wangwuyi123
509bfcb02e Fix CI Bug (#575)
Co-authored-by: yuxwang <anduinnn@foxmail.com>
2021-08-25 08:51:39 +08:00
demon143
6608a40965 Update ensemble.py (#560) 2021-08-14 18:07:49 +08:00
you-n-g
3e75cead93 code standard docs 2021-08-12 09:19:57 +00:00
you-n-g
6697f209d4 Conda Suggestion 2021-08-12 16:30:46 +08:00
you-n-g
e3b57b1901 Update README.md 2021-08-06 09:59:30 +08:00
you-n-g
82a5223166 Update README.md 2021-08-06 09:59:30 +08:00
ZhangTP1996
398131cff7 Update strategy.py 2021-08-05 17:21:10 +08:00
Dong Zhou
e71e2f941c fix tra when logdir is None 2021-08-02 19:02:37 +08:00
Dong Zhou
0483406c12 fix tra when logdir is None 2021-08-02 03:57:14 -07:00
Dong Zhou
da1f4db968 update README 2021-07-30 16:05:07 +08:00
Dong Zhou
a7c41b6969 improve pretrain 2021-07-30 16:05:07 +08:00
Dong Zhou
5b7b48e376 clean up 2021-07-30 16:05:07 +08:00
Dong Zhou
4f9f978909 fix TRA when use single head 2021-07-30 16:05:07 +08:00
Dong Zhou
319a2f38cc fix horizon 2021-07-30 16:05:07 +08:00
Dong Zhou
a2c38c979e format by black 2021-07-30 16:05:07 +08:00
Dong Zhou
07655f2d5b refactor TRA 2021-07-30 16:05:07 +08:00
Young
9303415666 refactor online serving rolling api 2021-07-29 18:13:12 +08:00
you-n-g
05d28469ad sort index after loader (#538)
make sure the fetch method is based on a index-sorted pd.DataFrame
2021-07-29 12:06:59 +08:00
you-n-g
dc6859bdd9 Fix docs of QlibRecorder 2021-07-26 19:00:47 +08:00
you-n-g
a6f9dde006 Update README.md 2021-07-26 18:36:09 +08:00
Young
1d22ee56d3 recorder support upload both raw file and directory 2021-07-25 16:35:16 +00:00
panshuaiyin
3810a4cd33 Update data.rst
use own alpha-factor
2021-07-22 20:07:04 +08:00
you-n-g
48af7126b6 Update news about models 2021-07-22 11:07:09 +08:00
Ying-Tao Luo
025b1dcff9 Add two new models in model zoo 2021-07-22 11:05:39 +08:00
Ying-Tao Luo
29e66b2dea Add two new model in zoo
Add transformer and localformer (SLGT) models for time series prediction in finance in the Quant Model Zoo.
2021-07-22 11:05:39 +08:00
Ying-Tao Luo
698e59ac72 Add performance of two new models
Add the performance of transformer and localformer.
2021-07-22 11:05:39 +08:00
Ying-Tao Luo
e006ef40ad Update pytorch_localformer_ts.py 2021-07-22 11:05:39 +08:00
Young
59d4bc9394 update run_all_model and black format 2021-07-22 11:05:39 +08:00
Ying-Tao Luo
b07e0bffb1 Add files via upload 2021-07-22 11:05:39 +08:00
Ying-Tao Luo
161343018f Add files via upload 2021-07-22 11:05:39 +08:00
Ying-Tao Luo
bee031af68 Add files via upload 2021-07-22 11:05:39 +08:00
Ying-Tao Luo
35840606a8 Update pytorch_localformer.py 2021-07-22 11:05:39 +08:00
Ying-Tao Luo
2df9b6e076 Add files via upload 2021-07-22 11:05:39 +08:00
Ying-Tao Luo
0c3eaf3f16 Add files via upload 2021-07-22 11:05:39 +08:00
Ying-Tao Luo
2eee064eb8 Add files via upload 2021-07-22 11:05:39 +08:00
Ying-Tao Luo
096ef5a62b Update pytorch_transformer.py
Have passed black
2021-07-22 11:05:39 +08:00
Ying-Tao Luo
dd0eebed53 Update pytorch_localformer.py
Have passed black.
2021-07-22 11:05:39 +08:00
Ying-Tao Luo
7b20abeda1 Add files via upload
Add naive transformer model and a improved transformer model.
2021-07-22 11:05:39 +08:00
you-n-g
5519420efd Update test_macos.yml
Give more comments about the MacOS test yaml
2021-07-21 18:30:25 +08:00
zhupr
eb3c5b3088 macos-test-ci split out separately 2021-07-21 18:25:31 +08:00
zhupr
f03df874bf fix macos-test-ci 2021-07-21 18:25:31 +08:00
2796gaurav
8fa22bd2e1 added 1min for IN and also updated readme 2021-07-21 14:16:22 +05:30
Gaurav
d1c8d885aa cleaned the code 2021-07-21 17:59:50 +05:30
zhupr
bf7732e284 fix df_features.index conține np.nan 2021-07-21 14:28:20 +08:00
wuzhe1234
3f5334ab39 Update qrun to automaticly save the config to the artifacts uri 2021-07-19 13:32:14 +08:00
zhupr
c97a96363d Add a check if change is mutated to YahooNormalize1d 2021-07-18 20:28:46 +08:00
slowy07
2023f714c9 [fixed] lgtm issue : unused imported module of 'signal' and change to PEP8 style code imported module 2021-07-18 15:25:18 +08:00
slowy07
f8a2b0533b lgtm issue: fixing unused import of 'time' 2021-07-18 15:25:18 +08:00
chaosyu
3183a232df update doc str 2021-07-18 15:24:23 +08:00
chaosyu
8b715268bd use list_kwargs instead filter_string 2021-07-18 15:24:23 +08:00
chaosyu
28cb827a23 fix lint issue 2021-07-18 15:24:23 +08:00
chaosyu
b723f14619 apply filter string to recorder collector 2021-07-18 15:24:23 +08:00
chaosyu
47535ba530 add mlflow filter string support to limit too much run number 2021-07-18 15:24:23 +08:00
Gaurav
d70e5a4f88 add YahooNormalizeIN and YahooNormalizeIN1d 2021-07-17 10:40:16 +05:30
you-n-g
3b8087677c Update online.rst 2021-07-16 12:24:33 +08:00
zhupr
4ec41ea0e7 Add a check if change is mutated to YahooNormalize1d 2021-07-15 19:13:25 +08:00
Gaurav
cfcd9fb1f8 cleaned with black 2021-07-15 11:24:41 +05:30
Gaurav
457dcaa466 cleaned with black 2021-07-14 20:12:00 +05:30
Gaurav
3c740fc2de MVP for Indian Stocks in qlib using yahooquery 2021-07-14 19:54:55 +05:30
you-n-g
6d91f28474 Update README.md 2021-07-14 10:07:02 +08:00
you-n-g
be8653c505 Update contributing section 2021-07-14 09:56:12 +08:00
chaosyu
a8974ce535 bug fix: ClientProvider cannot set connection to calendar and instrument providers 2021-07-13 10:49:21 +08:00
chaosyu
79026e5390 fix bug that duplicate rows will cause reindex failed when dumping with csv files 2021-07-13 10:49:21 +08:00
Gaurav Chauhan
4610e16ac2 updated readme of yahoo collector where region parameter was incorrect (#504)
* updated readme of yahoo collector where region parameter was incorrect

* changes

update readme of yahoo collector where region parameter was incorrect

* update readme of yahoo collector

update readme of yahoo collector where region parameter was incorrect

* updated changes

* updated readme of cn1d data

Co-authored-by: Gaurav Chauhan01/HO/Analytics/General <Gaurav.Chauhan01@bajajallianz.in>
2021-07-13 09:46:13 +08:00
wangwenxi.handsome
b504cc6ac8 update readme and rst 2021-07-12 21:51:08 +08:00
Young
d5059e609f change to dev version 2021-07-12 02:49:25 +00:00
85 changed files with 4867 additions and 623 deletions

View File

@@ -13,7 +13,7 @@ jobs:
strategy:
matrix:
os: [windows-latest, macos-latest]
python-version: [3.6, 3.7, 3.8]
python-version: [3.6, 3.7, 3.8, 3.9]
steps:
- uses: actions/checkout@v2

View File

@@ -1,4 +1,4 @@
name: Test
name: Test
on:
push:
@@ -12,8 +12,8 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [windows-latest, ubuntu-16.04, ubuntu-18.04, ubuntu-20.04, macos-latest]
python-version: [3.6, 3.7, 3.8, 3.9]
os: [windows-latest, ubuntu-18.04, ubuntu-20.04]
python-version: [3.6, 3.7, 3.8]
steps:
- uses: actions/checkout@v2
@@ -25,96 +25,41 @@ jobs:
- name: Lint with Black
run: |
cd ..
if [ "$RUNNER_OS" == "Windows" ]; then
$CONDA\\python.exe -m pip install black
$CONDA\\python.exe -m black qlib -l 120 --check --diff
else
sudo $CONDA/bin/python -m pip install black
$CONDA/bin/python -m black qlib -l 120 --check --diff
fi
shell: bash
pip install --upgrade pip
pip install black wheel
black qlib -l 120 --check --diff
# Test Qlib installed with pip
- name: Install Qlib with pip
run: |
if [ "$RUNNER_OS" == "Windows" ]; then
$CONDA\\python.exe -m pip install numpy==1.19.5
$CONDA\\python.exe -m pip install pyqlib --ignore-installed ruamel.yaml numpy --user
else
sudo $CONDA/bin/python -m pip install numpy==1.19.5
sudo $CONDA/bin/python -m pip install pyqlib --ignore-installed ruamel.yaml numpy
fi
shell: bash
- name: Install Lightgbm for MacOS
if: runner.os == 'macOS'
run: |
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Microsoft/qlib/main/.github/brew_install.sh)"
HOMEBREW_NO_AUTO_UPDATE=1 brew install lightgbm
pip install numpy==1.19.5 ruamel.yaml
pip install pyqlib --ignore-installed
- name: Test data downloads
run: |
if [ "$RUNNER_OS" == "Windows" ]; then
$CONDA\\python.exe scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
else
$CONDA/bin/python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
fi
shell: bash
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
- name: Test workflow by config (install from pip)
run: |
if [ "$RUNNER_OS" == "Windows" ]; then
$CONDA\\python.exe qlib\\workflow\\cli.py examples\\benchmarks\\LightGBM\\workflow_config_lightgbm_Alpha158.yaml
$CONDA\\python.exe -m pip uninstall -y pyqlib
else
$CONDA/bin/python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
sudo $CONDA/bin/python -m pip uninstall -y pyqlib
fi
shell: bash
# Test Qlib installed from source
python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
python -m pip uninstall -y pyqlib
# Test Qlib installed from source
- name: Install Qlib from source
run: |
if [ "$RUNNER_OS" == "Windows" ]; then
$CONDA\\python.exe -m pip install --upgrade cython
$CONDA\\python.exe -m pip install numpy jupyter jupyter_contrib_nbextensions
$CONDA\\python.exe -m pip install -U scipy scikit-learn # installing without this line will cause errors on GitHub Actions, while instsalling locally won't
$CONDA\\python.exe setup.py install
else
sudo $CONDA/bin/python -m pip install --upgrade cython
sudo $CONDA/bin/python -m pip install numpy jupyter jupyter_contrib_nbextensions
sudo $CONDA/bin/python -m pip install -U scipy scikit-learn # installing without this line will cause errors on GitHub Actions, while instsalling locally won't
sudo $CONDA/bin/python setup.py install
fi
shell: bash
pip install --upgrade cython jupyter jupyter_contrib_nbextensions numpy scipy scikit-learn # installing without this line will cause errors on GitHub Actions, while instsalling locally won't
pip install -e .
- name: Install test dependencies
run: |
if [ "$RUNNER_OS" == "Windows" ]; then
$CONDA\\python.exe -m pip install --upgrade pip
$CONDA\\python.exe -m pip install black pytest
else
sudo $CONDA/bin/python -m pip install --upgrade pip
sudo $CONDA/bin/python -m pip install black pytest
fi
shell: bash
pip install --upgrade pip
pip install black pytest
- name: Unit tests with Pytest
run: |
cd tests
if [ "$RUNNER_OS" == "Windows" ]; then
$CONDA\\python.exe -m pytest . --durations=0
else
$CONDA/bin/python -m pytest . --durations=0
fi
shell: bash
python -m pytest . --durations=10
- name: Test workflow by config (install from source)
run: |
if [ "$RUNNER_OS" == "Windows" ]; then
$CONDA\\python.exe qlib\\workflow\\cli.py examples\\benchmarks\\LightGBM\\workflow_config_lightgbm_Alpha158.yaml
else
$CONDA/bin/python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
fi
shell: bash
python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml

72
.github/workflows/test_macos.yml vendored Normal file
View File

@@ -0,0 +1,72 @@
# There are some issues (in the downloading data phase) on MacOS when running with other tests. So we split it into an individual config.
name: Test MacOS
on:
push:
branches: [ main ]
pull_request:
branches: [ main ]
jobs:
build:
runs-on: macos-latest
strategy:
matrix:
python-version: [3.6, 3.7, 3.8]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Lint with Black
run: |
cd ..
python -m pip install pip --upgrade
python -m pip install wheel --upgrade
python -m pip install black
python -m black qlib -l 120 --check --diff
# Test Qlib installed with pip
- name: Install Qlib with pip
run: |
python -m pip install numpy==1.19.5
python -m pip install pyqlib --ignore-installed ruamel.yaml numpy
- name: Install Lightgbm for MacOS
run: |
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Microsoft/qlib/main/.github/brew_install.sh)"
HOMEBREW_NO_AUTO_UPDATE=1 brew install lightgbm
# FIX MacOS error: Segmentation fault
# reference: https://github.com/microsoft/LightGBM/issues/4229
wget https://raw.githubusercontent.com/Homebrew/homebrew-core/fb8323f2b170bd4ae97e1bac9bf3e2983af3fdb0/Formula/libomp.rb
brew unlink libomp
brew install libomp.rb
- name: Test data downloads
run: |
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
- name: Test workflow by config (install from pip)
run: |
python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
python -m pip uninstall -y pyqlib
# Test Qlib installed from source
- name: Install Qlib from source
run: |
python -m pip install --upgrade cython
python -m pip install numpy jupyter jupyter_contrib_nbextensions
python -m pip install -U scipy scikit-learn # installing without this line will cause errors on GitHub Actions, while instsalling locally won't
python setup.py install
- name: Install test dependencies
run: |
python -m pip install --upgrade pip
python -m pip install -U pyopenssl idna
python -m pip install black pytest
- name: Unit tests with Pytest
run: |
cd tests
python -m pytest . --durations=0
- name: Test workflow by config (install from source)
run: |
python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml

1
MANIFEST.in Normal file
View File

@@ -0,0 +1 @@
include qlib/VERSION.txt

View File

@@ -11,6 +11,9 @@
Recent released features
| Feature | Status |
| -- | ------ |
|Temporal Routing Adaptor (TRA) | [Released](https://github.com/microsoft/qlib/pull/531) on July 30, 2021 |
| Transformer & Localformer | [Released](https://github.com/microsoft/qlib/pull/508) on July 22, 2021 |
| Release Qlib v0.7.0 | [Released](https://github.com/microsoft/qlib/releases/tag/v0.7.0) on July 12, 2021 |
| TCTS Model | [Released](https://github.com/microsoft/qlib/pull/491) on July 1, 2021 |
| Online serving and automatic model rolling | :star: [Released](https://github.com/microsoft/qlib/pull/290) on May 17, 2021 |
| DoubleEnsemble Model | [Released](https://github.com/microsoft/qlib/pull/286) on Mar 2, 2021 |
@@ -21,8 +24,6 @@ Recent released features
Features released before 2021 are not listed here.
<p align="center">
<img src="http://fintech.msra.cn/images_v060/logo/1.png" />
</p>
@@ -43,7 +44,7 @@ For more details, please refer to our paper ["Qlib: An AI-oriented Quantitative
- [Data Preparation](#data-preparation)
- [Auto Quant Research Workflow](#auto-quant-research-workflow)
- [Building Customized Quant Research Workflow by Code](#building-customized-quant-research-workflow-by-code)
- [**Quant Model Zoo**](#quant-model-zoo)
- [**Quant Model(Paper) Zoo**](#quant-model-paper-zoo)
- [Run a single model](#run-a-single-model)
- [Run multiple models](#run-multiple-models)
- [**Quant Dataset Zoo**](#quant-dataset-zoo)
@@ -105,8 +106,9 @@ This table demonstrates the supported Python version of `Qlib`:
| Python 3.9 | :x: | :heavy_check_mark: | :x: |
**Note**:
1. **Conda** is suggested for managing your Python environment.
1. Please pay attention that installing cython in Python 3.6 will raise some error when installing ``Qlib`` from source. If users use Python 3.6 on their machines, it is recommended to *upgrade* Python to version 3.7 or use `conda`'s Python to install ``Qlib`` from source.
2. For Python 3.9, `Qlib` supports running workflows such as training models, doing backtest and plot most of the related figures (those included in [notebook](examples/workflow_by_code.ipynb)). However, plotting for the *model performance* is not supported for now and we will fix this when the dependent packages are upgraded in the future.
1. For Python 3.9, `Qlib` supports running workflows such as training models, doing backtest and plot most of the related figures (those included in [notebook](examples/workflow_by_code.ipynb)). However, plotting for the *model performance* is not supported for now and we will fix this when the dependent packages are upgraded in the future.
### Install with pip
Users can easily install ``Qlib`` by pip according to the following command.
@@ -160,7 +162,7 @@ Users could create the same dataset with it.
*Please pay **ATTENTION** that the data is collected from [Yahoo Finance](https://finance.yahoo.com/lookup), and the data might not be perfect.
We recommend users to prepare their own data if they have a high-quality dataset. For more information, users can refer to the [related document](https://qlib.readthedocs.io/en/latest/component/data.html#converting-csv-format-into-qlib-format)*.
### Automatic update of daily frequency data(from yahoo finance)
### Automatic update of daily frequency data (from yahoo finance)
> It is recommended that users update the data manually once (--trading_date 2021-05-25) and then set it to update automatically.
> For more information refer to: [yahoo collector](https://github.com/microsoft/qlib/tree/main/scripts/data_collector/yahoo#automatic-update-of-daily-frequency-datafrom-yahoo-finance)
@@ -274,7 +276,7 @@ Qlib provides a tool named `qrun` to run the whole workflow automatically (inclu
The automatic workflow may not suit the research workflow of all Quant researchers. To support a flexible Quant research workflow, Qlib also provides a modularized interface to allow researchers to build their own workflow by code. [Here](examples/workflow_by_code.ipynb) is a demo for customized Quant research workflow by code.
# [Quant Model Zoo](examples/benchmarks)
# [Quant Model (Paper) Zoo](examples/benchmarks)
Here is a list of models built on `Qlib`.
- [GBDT based on XGBoost (Tianqi Chen, et al. KDD 2016)](qlib/contrib/model/xgboost.py)
@@ -290,6 +292,9 @@ Here is a list of models built on `Qlib`.
- [TabNet based on pytorch (Sercan O. Arik, et al. AAAI 2019)](qlib/contrib/model/pytorch_tabnet.py)
- [DoubleEnsemble based on LightGBM (Chuheng Zhang, et al. ICDM 2020)](qlib/contrib/model/double_ensemble.py)
- [TCTS based on pytorch (Xueqing Wu, et al. ICML 2021)](qlib/contrib/model/pytorch_tcts.py)
- [Transformer based on pytorch (Ashish Vaswani, et al. NeurIPS 2017)](qlib/contrib/model/pytorch_transformer.py)
- [Localformer based on pytorch (Juyong Jiang, et al.)](qlib/contrib/model/pytorch_localformer.py)
- [TRA based on pytorch (Hengxu, Dong, et al. KDD 2021)](qlib/contrib/model/pytorch_tra.py)
Your PR of new Quant models is highly welcomed.
@@ -303,9 +308,10 @@ All the models listed above are runnable with ``Qlib``. Users can find the confi
- Users can create a `workflow_by_code` python script based on the [one](examples/workflow_by_code.py) listed in the `examples` folder.
- Users can use the script [`run_all_model.py`](examples/run_all_model.py) listed in the `examples` folder to run a model. Here is an example of the specific shell command to be used: `python run_all_model.py --models=lightgbm`, where the `--models` arguments can take any number of models listed above(the available models can be found in [benchmarks](examples/benchmarks/)). For more use cases, please refer to the file's [docstrings](examples/run_all_model.py).
- **NOTE**: Each baseline has different environment dependencies, please make sure that your python version aligns with the requirements(e.g. TFT only supports Python 3.6~3.7 due to the limitation of `tensorflow==1.15.0`)
## Run multiple models
`Qlib` also provides a script [`run_all_model.py`](examples/run_all_model.py) which can run multiple models for several iterations. (**Note**: the script only support *Linux* for now. Other OS will be supported in the future. Besides, it doesn't support parrallel running the same model for multiple times as well, and this will be fixed in the future development too.)
`Qlib` also provides a script [`run_all_model.py`](examples/run_all_model.py) which can run multiple models for several iterations. (**Note**: the script only support *Linux* for now. Other OS will be supported in the future. Besides, it doesn't support parallel running the same model for multiple times as well, and this will be fixed in the future development too.)
The script will create a unique virtual environment for each model, and delete the environments after training. Thus, only experiment results such as `IC` and `backtest` results will be generated and stored.
@@ -370,9 +376,7 @@ Such overheads greatly slow down the data loading process.
Qlib data are stored in a compact format, which is efficient to be combined into arrays for scientific computation.
# Related Reports
- [【华泰金工林晓明团队】图神经网络选股与Qlib实践——华泰人工智能系列之四十二](https://mp.weixin.qq.com/s/w5fDB6oAv9dO6vlhf1kmhA)
- [Guide To Qlib: Microsofts AI Investment Platform](https://analyticsindiamag.com/qlib/)
- [【华泰金工林晓明团队】微软AI量化投资平台Qlib体验——华泰人工智能系列之四十](https://mp.weixin.qq.com/s/Brcd7im4NibJOJzZfMn6tQ)
- [微软也搞AI量化平台还是开源的](https://mp.weixin.qq.com/s/47bP5YwxfTp2uTHjUBzJQQ)
- [微矿Qlib业内首个AI量化投资开源平台](https://mp.weixin.qq.com/s/vsJv7lsgjEi-ALYUz4CvtQ)
@@ -389,7 +393,17 @@ Join IM discussion groups:
# Contributing
This project welcomes contributions and suggestions. Most contributions require you to agree to a
This project welcomes contributions and suggestions.
**Here are some
[code standards](docs/developer/code_standard.rst) when you submit a pull request.**
If you want to contribute to Qlib's document, you can follow the steps in the figure below.
<p align="center">
<img src="https://github.com/demon143/qlib/blob/main/docs/_static/img/change%20doc.gif" />
</p>
Most contributions require you to agree to a
Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
the right to use your contribution. For details, visit https://cla.opensource.microsoft.com.

1
VERSION.txt Normal file
View File

@@ -0,0 +1 @@
0.7.2

View File

@@ -97,4 +97,57 @@ Also, feel free to post a new issue in our GitHub repository. We always check ea
python setup.py build_ext --inplace
- If the error occurs when importing ``qlib`` package with command ``python`` , users need to change the running directory to ensure that the script does not run in the project directory.
- If the error occurs when importing ``qlib`` package with command ``python`` , users need to change the running directory to ensure that the script does not run in the project directory.
4. BadNamespaceError: / is not a connected namespace
------------------------------------------------------------------------------------------------------------------------------------
.. code-block:: python
File "qlib_online.py", line 35, in <module>
cal = D.calendar()
File "e:\code\python\microsoft\qlib_latest\qlib\qlib\data\data.py", line 973, in calendar
return Cal.calendar(start_time, end_time, freq, future=future)
File "e:\code\python\microsoft\qlib_latest\qlib\qlib\data\data.py", line 798, in calendar
self.conn.send_request(
File "e:\code\python\microsoft\qlib_latest\qlib\qlib\data\client.py", line 101, in send_request
self.sio.emit(request_type + "_request", request_content)
File "G:\apps\miniconda\envs\qlib\lib\site-packages\python_socketio-5.3.0-py3.8.egg\socketio\client.py", line 369, in emit
raise exceptions.BadNamespaceError(
BadNamespaceError: / is not a connected namespace.
- The version of ``python-socketio`` in qlib needs to be the same as the version of ``python-socketio`` in qlib-server:
.. code-block:: bash
pip install -U python-socketio==<qlib-server python-socketio version>
5. TypeError: send() got an unexpected keyword argument 'binary'
------------------------------------------------------------------------------------------------------------------------------------
.. code-block:: python
File "qlib_online.py", line 35, in <module>
cal = D.calendar()
File "e:\code\python\microsoft\qlib_latest\qlib\qlib\data\data.py", line 973, in calendar
return Cal.calendar(start_time, end_time, freq, future=future)
File "e:\code\python\microsoft\qlib_latest\qlib\qlib\data\data.py", line 798, in calendar
self.conn.send_request(
File "e:\code\python\microsoft\qlib_latest\qlib\qlib\data\client.py", line 101, in send_request
self.sio.emit(request_type + "_request", request_content)
File "G:\apps\miniconda\envs\qlib\lib\site-packages\socketio\client.py", line 263, in emit
self._send_packet(packet.Packet(packet.EVENT, namespace=namespace,
File "G:\apps\miniconda\envs\qlib\lib\site-packages\socketio\client.py", line 339, in _send_packet
self.eio.send(ep, binary=binary)
TypeError: send() got an unexpected keyword argument 'binary'
- The ``python-engineio`` version needs to be compatible with the ``python-socketio`` version, reference: https://github.com/miguelgrinberg/python-socketio#version-compatibility
.. code-block:: bash
pip install -U python-engineio==<compatible python-socketio version>
# or
pip install -U python-socketio==3.1.2 python-engineio==3.13.2

BIN
docs/_static/img/change doc.gif vendored Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.3 MiB

View File

@@ -179,6 +179,7 @@ After conversion, users can find their Qlib format data in the directory `~/.qli
The Restoration factor. Normally, ``factor = adjusted_price / original_price``, `adjusted price` reference: `split adjusted <https://www.investopedia.com/terms/s/splitadjusted.asp>`_
In the convention of `Qlib` data processing, `open, close, high, low, volume, money and factor` will be set to NaN if the stock is suspended.
If you want to use your own alpha-factor which can't be calculate by OCHLV, like PE, EPS and so on, you could add it to the CSV files with OHCLV together and then dump it to the Qlib format data.
Stock Pool (Market)
--------------------------------

View File

@@ -21,6 +21,8 @@ which including `Online Manager <#Online Manager>`_, `Online Strategy <#Online S
If you have many models or `task` needs to be managed, please consider `Task Management <../advanced/task_management.html>`_.
The `examples <https://github.com/microsoft/qlib/tree/main/examples/online_srv>`_ are based on some components in `Task Management <../advanced/task_management.html>`_ such as ``TrainerRM`` or ``Collector``.
**NOTE**: User should keep his data source updated to support online serving. For example, Qlib provides `a batch of scripts <https://github.com/microsoft/qlib/blob/main/scripts/data_collector/yahoo/README.md#automatic-update-of-daily-frequency-datafrom-yahoo-finance>`_ to help users update Yahoo daily data.
Online Manager
=============
@@ -43,4 +45,4 @@ Updater
=============
.. automodule:: qlib.workflow.online.update
:members:
:members:

View File

@@ -0,0 +1,22 @@
.. _code_standard:
=================================
Code Standard
=================================
Docstring
=================================
Please use the `Numpydoc Style <https://stackoverflow.com/a/24385103>`_.
Continuous Integration
=================================
Continuous Integration (CI) tools help you stick to the quality standards by running tests every time you push a new commit and reporting the results to a pull request.
When you submit a PR request, you can check whether your code passes the CI tests in the "check" section at the bottom of the web page.
A common error is the mixed use of space and tab. You can fix the bug by inputing the following code in the command line.
.. code-block:: python
pip install black
python -m black . -l 120

View File

@@ -241,6 +241,7 @@ Online Tool
.. automodule:: qlib.workflow.online.utils
:members:
RecordUpdater
--------------------
.. automodule:: qlib.workflow.online.update
@@ -257,4 +258,4 @@ Serializable
:members:

View File

@@ -77,7 +77,8 @@ Besides `provider_uri` and `region`, `qlib.init` has other parameters. The follo
})
- `mongo`
Type: dict, optional parameter, the setting of `MongoDB <https://www.mongodb.com/>`_ which will be used in some features such as `Task Management <../advanced/task_management.html>`_, with high performance and clustered processing.
Users need finished `installation <https://www.mongodb.com/try/download/community>`_ firstly, and run it in a fixed URL.
Users need to follow the steps in `installation <https://www.mongodb.com/try/download/community>`_ to install MongoDB firstly and then access it via a URI.
Users can access mongodb with credential by setting "task_url" to a string like `"mongodb://%s:%s@%s" % (user, pwd, host + ":" + port)`.
.. code-block:: Python

View File

@@ -0,0 +1,16 @@
import datetime
import pandas as pd
from qlib.data.inst_processor import InstProcessor
class Resample1minProcessor(InstProcessor):
def __init__(self, hour: int, minute: int, **kwargs):
self.hour = hour
self.minute = minute
def __call__(self, df: pd.DataFrame, *args, **kwargs):
df.index = pd.to_datetime(df.index)
df = df.loc[df.index.time == datetime.time(self.hour, self.minute)]
df.index = df.index.normalize()
return df

View File

@@ -63,4 +63,4 @@ task:
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config
config: *port_analysis_config

View File

@@ -0,0 +1,83 @@
qlib_init:
provider_uri:
day: "~/.qlib/qlib_data/cn_data"
1min: "~/.qlib/qlib_data/cn_data_1min"
region: cn
dataset_cache: null
maxtasksperchild: 1
market: &market csi300
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
start_time: 2008-01-01
# 1min closing time is 15:00:00
end_time: "2020-08-01 15:00:00"
fit_start_time: 2008-01-01
fit_end_time: 2014-12-31
instruments: *market
freq:
label: day
feature: 1min
# with label as reference
inst_processor:
feature:
- class: Resample1minProcessor
module_path: features_sample.py
kwargs:
hour: 14
minute: 56
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy.strategy
kwargs:
topk: 50
n_drop: 5
backtest:
verbose: False
limit_threshold: 0.095
account: 100000000
benchmark: *benchmark
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
task:
model:
class: LGBModel
module_path: qlib.contrib.model.gbdt
kwargs:
loss: mse
colsample_bytree: 0.8879
learning_rate: 0.2
subsample: 0.8789
lambda_l1: 205.6999
lambda_l2: 580.9768
max_depth: 8
num_leaves: 210
num_threads: 20
dataset:
class: DatasetH
module_path: qlib.data.dataset
kwargs:
handler:
class: Alpha158
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:
train: [2008-01-01, 2014-12-31]
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs: {}
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
ana_long_short: False
ann_scaler: 252
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config

View File

@@ -78,4 +78,4 @@ task:
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config
config: *port_analysis_config

View File

@@ -0,0 +1,3 @@
numpy==1.17.4
pandas==1.1.2
torch==1.2.0

View File

@@ -0,0 +1,82 @@
qlib_init:
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
market: &market csi300
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
start_time: 2008-01-01
end_time: 2020-08-01
fit_start_time: 2008-01-01
fit_end_time: 2014-12-31
instruments: *market
infer_processors:
- class: FilterCol
kwargs:
fields_group: feature
col_list: ["RESI5", "WVMA5", "RSQR5", "KLEN", "RSQR10", "CORR5", "CORD5", "CORR10",
"ROC60", "RESI10", "VSTD5", "RSQR60", "CORR60", "WVMA60", "STD5",
"RSQR20", "CORD60", "CORD10", "CORR20", "KLOW"
]
- class: RobustZScoreNorm
kwargs:
fields_group: feature
clip_outlier: true
- class: Fillna
kwargs:
fields_group: feature
learn_processors:
- class: DropnaLabel
- class: CSRankNorm
kwargs:
fields_group: label
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy.strategy
kwargs:
topk: 50
n_drop: 5
backtest:
verbose: False
limit_threshold: 0.095
account: 100000000
benchmark: *benchmark
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
task:
model:
class: LocalformerModel
module_path: qlib.contrib.model.pytorch_localformer_ts
kwargs:
seed: 0
n_jobs: 20
dataset:
class: TSDatasetH
module_path: qlib.data.dataset
kwargs:
handler:
class: Alpha158
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:
train: [2008-01-01, 2014-12-31]
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
step_len: 20
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs: {}
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
ana_long_short: False
ann_scaler: 252
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config

View File

@@ -0,0 +1,73 @@
qlib_init:
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
market: &market csi300
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
start_time: 2008-01-01
end_time: 2020-08-01
fit_start_time: 2008-01-01
fit_end_time: 2014-12-31
instruments: *market
infer_processors:
- class: RobustZScoreNorm
kwargs:
fields_group: feature
clip_outlier: true
- class: Fillna
kwargs:
fields_group: feature
learn_processors:
- class: DropnaLabel
- class: CSRankNorm
kwargs:
fields_group: label
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy.strategy
kwargs:
topk: 50
n_drop: 5
backtest:
verbose: False
limit_threshold: 0.095
account: 100000000
benchmark: *benchmark
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
task:
model:
class: LocalformerModel
module_path: qlib.contrib.model.pytorch_localformer
kwargs:
d_feat: 6
seed: 0
dataset:
class: DatasetH
module_path: qlib.data.dataset
kwargs:
handler:
class: Alpha360
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:
train: [2008-01-01, 2014-12-31]
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs: {}
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
ana_long_short: False
ann_scaler: 252
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config

View File

@@ -1,6 +1,6 @@
# Benchmarks Performance
Here are the results of each benchmark model running on Qlib's `Alpha360` and `Alpha158` dataset with China's A shared-stock & CSI300 data respectively. The values of each metric are the mean and std calculated based on 20 runs.
Here are the results of each benchmark model running on Qlib's `Alpha360` and `Alpha158` dataset with China's A shared-stock & CSI300 data respectively. The values of each metric are the mean and std calculated based on 20 runs with different random seeds.
The numbers shown below demonstrate the performance of the entire `workflow` of each model. We will update the `workflow` as well as models in the near future for better results.
@@ -23,6 +23,9 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
| DoubleEnsemble (Chuheng Zhang, et al.) | Alpha360 | 0.0407±0.00| 0.3053±0.00 | 0.0490±0.00 | 0.3840±0.00 | 0.0380±0.02 | 0.5000±0.21 | -0.0984±0.02 |
| TabNet (Sercan O. Arik, et al.)| Alpha360 | 0.0192±0.00 | 0.1401±0.00| 0.0291±0.00 | 0.2163±0.00 | -0.0258±0.00 | -0.2961±0.00| -0.1429±0.00 |
| TCTS (Xueqing Wu, et al.)| Alpha360 | 0.0485±0.00 | 0.3689±0.04| 0.0586±0.00 | 0.4669±0.02 | 0.0816±0.02 | 1.1572±0.30| -0.0689±0.02 |
| Transformer (Ashish Vaswani, et al.)| Alpha360 | 0.0141±0.00 | 0.0917±0.02| 0.0331±0.00 | 0.2357±0.03 | -0.0259±0.03 | -0.3323±0.43| -0.1763±0.07 |
| Localformer (Juyong Jiang, et al.)| Alpha360 | 0.0408±0.00 | 0.2988±0.03| 0.0538±0.00 | 0.4105±0.02 | 0.0275±0.03 | 0.3464±0.37| -0.1182±0.03 |
| TRA (Hengxu Lin, et al.)| Alpha360 | 0.0491±0.01 | 0.3868±0.06 | 0.0589±0.00 | 0.4802±0.04 | 0.0898±0.02 | 1.2490±0.32 | -0.0778±0.02 |
## Alpha158 dataset
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |
@@ -39,6 +42,10 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
| GATs (Petar Velickovic, et al.) | Alpha158 (with selected 20 features) | 0.0349±0.00 | 0.2511±0.01| 0.0457±0.00 | 0.3537±0.01 | 0.0578±0.02 | 0.8221±0.25| -0.0824±0.02 |
| DoubleEnsemble (Chuheng Zhang, et al.) | Alpha158 | 0.0544±0.00 | 0.4338±0.01 | 0.0523±0.00 | 0.4257±0.01 | 0.1253±0.01 | 1.4105±0.14 | -0.0902±0.01 |
| TabNet (Sercan O. Arik, et al.)| Alpha158 | 0.0383±0.00 | 0.3414±0.00| 0.0388±0.00 | 0.3460±0.00 | 0.0226±0.00 | 0.2652±0.00| -0.1072±0.00 |
| Transformer (Ashish Vaswani, et al.)| Alpha158 | 0.0274±0.00 | 0.2166±0.04| 0.0409±0.00 | 0.3342±0.04 | 0.0204±0.03 | 0.2888±0.40| -0.1216±0.04 |
| Localformer (Juyong Jiang, et al.)| Alpha158 | 0.0355±0.00 | 0.2747±0.04| 0.0466±0.00 | 0.3762±0.03 | 0.0506±0.02 | 0.7447±0.34| -0.0875±0.02 |
| TRA (Hengxu Lin, et al.)| Alpha158 (with selected 20 features)| 0.0409±0.00 | 0.3253±0.04 | 0.0488±0.00 | 0.4045±0.02 | 0.0673±0.02 | 1.0389±0.39 | -0.0830±0.02 |
| TRA (Hengxu Lin, et al.)| Alpha158 | 0.0442±0.00 | 0.3426±0.03 | 0.0555±0.00 | 0.4395±0.03 | 0.0833±0.03 | 1.2064±0.36 | -0.0849±0.02 |
- The selected 20 features are based on the feature importance of a lightgbm-based model.
- The base model of DoubleEnsemble is LGBM.

View File

@@ -8,7 +8,7 @@
Users can follow the ``workflow_by_code_tft.py`` to run the benchmark.
### Notes
1. Please be **aware** that this script can only support `Python 3.5 - 3.8`.
1. Please be **aware** that this script can only support `Python 3.6 - 3.7`.
2. If the CUDA version on your machine is not 10.0, please remember to run the following commands `conda install anaconda cudatoolkit=10.0` and `conda install cudnn` on your machine.
3. The model must run in GPU, or an error will be raised.
4. New datasets should be registered in ``data_formatters``, for detail please visit the source.

View File

@@ -1,3 +1,2 @@
tensorflow-gpu==1.15.0
numpy == 1.19.4
pandas==1.1.0
pandas==1.1.0

View File

@@ -1,6 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from pathlib import Path
from typing import Union
import numpy as np
import pandas as pd
import tensorflow.compat.v1 as tf
@@ -243,7 +245,7 @@ class TFTModel(ModelFT):
# extract_numerical_data(targets), extract_numerical_data(p90_forecast),
# 0.9)
tf.keras.backend.set_session(default_keras_session)
print("Training completed.".format(dte.datetime.now()))
print("Training completed at {}.".format(dte.datetime.now()))
# ===========================Training Process===========================
def predict(self, dataset):
@@ -289,3 +291,24 @@ class TFTModel(ModelFT):
dataset for finetuning
"""
pass
def to_pickle(self, path: Union[Path, str]):
"""
Tensorflow model can't be dumped directly.
So the data should be save seperatedly
**TODO**: Please implement the function to load the files
Parameters
----------
path : Union[Path, str]
the target path to be dumped
"""
# save tensorflow model
# path = Path(path)
# path.mkdir(parents=True)
# self.model.save(path)
# save qlib model wrapper
self.model = None
super(TFTModel, self).to_pickle(path / "qlib_model")

View File

@@ -1,53 +1,77 @@
# Learning Multiple Stock Trading Patterns with Temporal Routing Adaptor and Optimal Transport
This code provides a PyTorch implementation for TRA (Temporal Routing Adaptor), as described in the paper [Learning Multiple Stock Trading Patterns with Temporal Routing Adaptor and Optimal Transport](http://arxiv.org/abs/2106.12950).
Temporal Routing Adaptor (TRA) is designed to capture multiple trading patterns in the stock market data. Please refer to [our paper](http://arxiv.org/abs/2106.12950) for more details.
* TRA (Temporal Routing Adaptor) is a lightweight module that consists of a set of independent predictors for learning multiple patterns as well as a router to dispatch samples to different predictors.
* We also design a learning algorithm based on Optimal Transport (OT) to obtain the optimal sample to predictor assignment and effectively optimize the router with such assignment through an auxiliary loss term.
If you find our work useful in your research, please cite:
```
@inproceedings{HengxuKDD2021,
author = {Hengxu Lin and Dong Zhou and Weiqing Liu and Jiang Bian},
title = {Learning Multiple Stock Trading Patterns with Temporal Routing Adaptor and Optimal Transport},
booktitle = {Proceedings of the 27th ACM SIGKDD Conference on Knowledge Discovery \& Data Mining},
series = {KDD '21},
year = {2021},
publisher = {ACM},
}
@article{yang2020qlib,
title={Qlib: An AI-oriented Quantitative Investment Platform},
author={Yang, Xiao and Liu, Weiqing and Zhou, Dong and Bian, Jiang and Liu, Tie-Yan},
journal={arXiv preprint arXiv:2009.11189},
year={2020}
}
```
# Running TRA
## Usage (Recommended)
## Requirements
- Install `Qlib` main branch
**Update**: `TRA` has been moved to `qlib.contrib.model.pytorch_tra` to support other `Qlib` components like `qlib.workflow` and `Alpha158/Alpha360` dataset.
## Running
Please follow the official [doc](https://qlib.readthedocs.io/en/latest/component/workflow.html) to use `TRA` with `workflow`. Here we also provide several example config files:
- `workflow_config_tra_Alpha360.yaml`: running `TRA` with `Alpha360` dataset
- `workflow_config_tra_Alpha158.yaml`: running `TRA` with `Alpha158` dataset (with feature subsampling)
- `workflow_config_tra_Alpha158_full.yaml`: running `TRA` with `Alpha158` dataset (without feature subsampling)
The performances of `TRA` are reported in [Benchmarks](https://github.com/microsoft/qlib/tree/main/examples/benchmarks).
## Usage (Not Maintained)
This section is used to reproduce the results in the paper.
### Running
We attach our running scripts for the paper in `run.sh`.
And here are two ways to run the model:
* Running from scripts with default parameters
You can directly run from Qlib command `qrun`:
```
qrun configs/config_alstm.yaml
```
You can directly run from Qlib command `qrun`:
```
qrun configs/config_alstm.yaml
```
* Running from code with self-defined parameters
Setting different parameters is also allowed. See codes in `example.py`:
```
python example.py --config_file configs/config_alstm.yaml
```
Setting different parameters is also allowed. See codes in `example.py`:
```
python example.py --config_file configs/config_alstm.yaml
```
Here we trained TRA on a pretrained backbone model. Therefore we run `*_init.yaml` before TRA's scipts.
# Results
## Outputs
### Results
After running the scripts, you can find result files in path `./output`:
`info.json` - config settings and result metrics.
* `info.json` - config settings and result metrics.
* `log.csv` - running logs.
* `model.bin` - the model parameter dictionary.
* `pred.pkl` - the prediction scores and output for inference.
`log.csv` - running logs.
Evaluation metrics reported in the paper:
`model.bin` - the model parameter dictionary.
`pred.pkl` - the prediction scores and output for inference.
## Our Results
| Methods | MSE| MAE| IC | ICIR | AR | AV | SR | MDD |
|-------------------|-------------------|---------------------|--------------------|--------------------|--------------------|--------------------|--------------------|--------------------|
|-------|-------|------|-----|-----|-----|-----|-----|-----|
|Linear|0.163|0.327|0.020|0.132|-3.2%|16.8%|-0.191|32.1%|
|LightGBM|0.160(0.000)|0.323(0.000)|0.041|0.292|7.8%|15.5%|0.503|25.7%|
|MLP|0.160(0.002)|0.323(0.003)|0.037|0.273|3.7%|15.3%|0.264|26.2%|
@@ -61,21 +85,8 @@ After running the scripts, you can find result files in path `./output`:
A more detailed demo for our experiment results in the paper can be found in `Report.ipynb`.
# Common Issues
## Common Issues
For help or issues using TRA, please submit a GitHub issue.
Sometimes we might encounter situation where the loss is `NaN`, please check the `epsilon` parameter in the sinkhorn algorithm, adjusting the `epsilon` according to input's scale is important.
# Citation
If you find this repository useful in your research, please cite:
```
@inproceedings{HengxuKDD2021,
author = {Hengxu Lin and Dong Zhou and Weiqing Liu and Jiang Bian},
title = {Learning Multiple Stock Trading Patterns with Temporal Routing Adaptor and Optimal Transport},
booktitle = {Proceedings of the 27th ACM SIGKDD Conference on Knowledge Discovery \& Data Mining},
series = {KDD '21},
year = {2021},
publisher = {ACM},
}
```
Sometimes we might encounter situation where the loss is `NaN`, please check the `epsilon` parameter in the sinkhorn algorithm, adjusting the `epsilon` according to input's scale is important.

View File

@@ -0,0 +1,129 @@
qlib_init:
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
market: &market csi300
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
start_time: 2008-01-01
end_time: 2020-08-01
fit_start_time: 2008-01-01
fit_end_time: 2014-12-31
instruments: *market
infer_processors:
- class: FilterCol
kwargs:
fields_group: feature
col_list: ["RESI5", "WVMA5", "RSQR5", "KLEN", "RSQR10", "CORR5", "CORD5", "CORR10",
"ROC60", "RESI10", "VSTD5", "RSQR60", "CORR60", "WVMA60", "STD5",
"RSQR20", "CORD60", "CORD10", "CORR20", "KLOW"]
- class: RobustZScoreNorm
kwargs:
fields_group: feature
clip_outlier: true
- class: Fillna
kwargs:
fields_group: feature
learn_processors:
- class: CSRankNorm
kwargs:
fields_group: label
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
num_states: &num_states 3
memory_mode: &memory_mode sample
tra_config: &tra_config
num_states: *num_states
rnn_arch: LSTM
hidden_size: 32
num_layers: 1
dropout: 0.0
tau: 1.0
src_info: LR_TPE
model_config: &model_config
input_size: 20
hidden_size: 64
num_layers: 2
rnn_arch: LSTM
use_attn: True
dropout: 0.0
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy.strategy
kwargs:
topk: 50
n_drop: 5
backtest:
verbose: False
limit_threshold: 0.095
account: 100000000
benchmark: *benchmark
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
task:
model:
class: TRAModel
module_path: qlib.contrib.model.pytorch_tra
kwargs:
tra_config: *tra_config
model_config: *model_config
model_type: RNN
lr: 1e-3
n_epochs: 100
max_steps_per_epoch:
early_stop: 20
logdir: output/Alpha158
seed: 0
lamb: 1.0
rho: 0.99
alpha: 0.5
transport_method: router
memory_mode: *memory_mode
eval_train: False
eval_test: True
pretrain: True
init_state:
freeze_model: False
freeze_predictors: False
dataset:
class: MTSDatasetH
module_path: qlib.contrib.data.dataset
kwargs:
handler:
class: Alpha158
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:
train: [2008-01-01, 2014-12-31]
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
seq_len: 60
horizon: 2
input_size:
num_states: *num_states
batch_size: 1024
n_samples:
memory_mode: *memory_mode
drop_last: True
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs: {}
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
ana_long_short: False
ann_scaler: 252
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config

View File

@@ -0,0 +1,123 @@
qlib_init:
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
market: &market csi300
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
start_time: 2008-01-01
end_time: 2020-08-01
fit_start_time: 2008-01-01
fit_end_time: 2014-12-31
instruments: *market
infer_processors:
- class: RobustZScoreNorm
kwargs:
fields_group: feature
clip_outlier: true
- class: Fillna
kwargs:
fields_group: feature
learn_processors:
- class: CSRankNorm
kwargs:
fields_group: label
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
num_states: &num_states 3
memory_mode: &memory_mode sample
tra_config: &tra_config
num_states: *num_states
rnn_arch: LSTM
hidden_size: 32
num_layers: 1
dropout: 0.0
tau: 1.0
src_info: LR_TPE
model_config: &model_config
input_size: 158
hidden_size: 256
num_layers: 2
rnn_arch: LSTM
use_attn: True
dropout: 0.2
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy.strategy
kwargs:
topk: 50
n_drop: 5
backtest:
verbose: False
limit_threshold: 0.095
account: 100000000
benchmark: *benchmark
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
task:
model:
class: TRAModel
module_path: qlib.contrib.model.pytorch_tra
kwargs:
tra_config: *tra_config
model_config: *model_config
model_type: RNN
lr: 1e-3
n_epochs: 100
max_steps_per_epoch:
early_stop: 20
logdir: output/Alpha158_full
seed: 0
lamb: 1.0
rho: 0.99
alpha: 0.5
transport_method: router
memory_mode: *memory_mode
eval_train: False
eval_test: True
pretrain: True
init_state:
freeze_model: False
freeze_predictors: False
dataset:
class: MTSDatasetH
module_path: qlib.contrib.data.dataset
kwargs:
handler:
class: Alpha158
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:
train: [2008-01-01, 2014-12-31]
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
seq_len: 60
horizon: 2
input_size:
num_states: *num_states
batch_size: 1024
n_samples:
memory_mode: *memory_mode
drop_last: True
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs: {}
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
ana_long_short: False
ann_scaler: 252
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config

View File

@@ -0,0 +1,123 @@
qlib_init:
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
market: &market csi300
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
start_time: 2008-01-01
end_time: 2020-08-01
fit_start_time: 2008-01-01
fit_end_time: 2014-12-31
instruments: *market
infer_processors:
- class: RobustZScoreNorm
kwargs:
fields_group: feature
clip_outlier: true
- class: Fillna
kwargs:
fields_group: feature
learn_processors:
- class: CSRankNorm
kwargs:
fields_group: label
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
num_states: &num_states 3
memory_mode: &memory_mode sample
tra_config: &tra_config
num_states: *num_states
rnn_arch: LSTM
hidden_size: 32
num_layers: 1
dropout: 0.0
tau: 1.0
src_info: LR_TPE
model_config: &model_config
input_size: 6
hidden_size: 64
num_layers: 2
rnn_arch: LSTM
use_attn: True
dropout: 0.0
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy.strategy
kwargs:
topk: 50
n_drop: 5
backtest:
verbose: False
limit_threshold: 0.095
account: 100000000
benchmark: *benchmark
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
task:
model:
class: TRAModel
module_path: qlib.contrib.model.pytorch_tra
kwargs:
tra_config: *tra_config
model_config: *model_config
model_type: RNN
lr: 1e-3
n_epochs: 100
max_steps_per_epoch:
early_stop: 20
logdir: output/Alpha360
seed: 0
lamb: 1.0
rho: 0.99
alpha: 0.5
transport_method: router
memory_mode: *memory_mode
eval_train: False
eval_test: True
pretrain: True
init_state:
freeze_model: False
freeze_predictors: False
dataset:
class: MTSDatasetH
module_path: qlib.contrib.data.dataset
kwargs:
handler:
class: Alpha360
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:
train: [2008-01-01, 2014-12-31]
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
seq_len: 60
horizon: 2
input_size: 6
num_states: *num_states
batch_size: 1024
n_samples:
memory_mode: *memory_mode
drop_last: True
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs: {}
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
ana_long_short: False
ann_scaler: 252
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config

View File

@@ -0,0 +1,3 @@
numpy==1.17.4
pandas==1.1.2
torch==1.2.0

View File

@@ -0,0 +1,82 @@
qlib_init:
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
market: &market csi300
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
start_time: 2008-01-01
end_time: 2020-08-01
fit_start_time: 2008-01-01
fit_end_time: 2014-12-31
instruments: *market
infer_processors:
- class: FilterCol
kwargs:
fields_group: feature
col_list: ["RESI5", "WVMA5", "RSQR5", "KLEN", "RSQR10", "CORR5", "CORD5", "CORR10",
"ROC60", "RESI10", "VSTD5", "RSQR60", "CORR60", "WVMA60", "STD5",
"RSQR20", "CORD60", "CORD10", "CORR20", "KLOW"
]
- class: RobustZScoreNorm
kwargs:
fields_group: feature
clip_outlier: true
- class: Fillna
kwargs:
fields_group: feature
learn_processors:
- class: DropnaLabel
- class: CSRankNorm
kwargs:
fields_group: label
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy.strategy
kwargs:
topk: 50
n_drop: 5
backtest:
verbose: False
limit_threshold: 0.095
account: 100000000
benchmark: *benchmark
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
task:
model:
class: TransformerModel
module_path: qlib.contrib.model.pytorch_transformer_ts
kwargs:
seed: 0
n_jobs: 20
dataset:
class: TSDatasetH
module_path: qlib.data.dataset
kwargs:
handler:
class: Alpha158
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:
train: [2008-01-01, 2014-12-31]
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
step_len: 20
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs: {}
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
ana_long_short: False
ann_scaler: 252
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config

View File

@@ -0,0 +1,73 @@
qlib_init:
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
market: &market csi300
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
start_time: 2008-01-01
end_time: 2020-08-01
fit_start_time: 2008-01-01
fit_end_time: 2014-12-31
instruments: *market
infer_processors:
- class: RobustZScoreNorm
kwargs:
fields_group: feature
clip_outlier: true
- class: Fillna
kwargs:
fields_group: feature
learn_processors:
- class: DropnaLabel
- class: CSRankNorm
kwargs:
fields_group: label
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy.strategy
kwargs:
topk: 50
n_drop: 5
backtest:
verbose: False
limit_threshold: 0.095
account: 100000000
benchmark: *benchmark
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
task:
model:
class: TransformerModel
module_path: qlib.contrib.model.pytorch_transformer
kwargs:
d_feat: 6
seed: 0
dataset:
class: DatasetH
module_path: qlib.data.dataset
kwargs:
handler:
class: Alpha360
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:
train: [2008-01-01, 2014-12-31]
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs: {}
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
ana_long_short: False
ann_scaler: 252
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config

View File

@@ -99,8 +99,6 @@ class HighFreqHandler(DataHandlerLP):
]
names += ["$volume_1"]
fields += ["Cut({0}, 240, None)".format(template_paused.format("Date($close)"))]
names += ["date"]
return fields, names

View File

@@ -33,6 +33,9 @@ class HighFreqNorm(Processor):
self.feature_vmin[name] = np.nanmin(part_values)
def __call__(self, df_features):
df_features["date"] = pd.to_datetime(
df_features.index.get_level_values(level="datetime").to_series().dt.date.values
)
df_features.set_index("date", append=True, drop=True, inplace=True)
df_values = df_features.values
names = {

View File

@@ -0,0 +1 @@
xgboost

View File

@@ -17,7 +17,7 @@ from qlib.workflow.task.gen import RollingGen, task_generator
from qlib.workflow.task.manage import TaskManager, run_task
from qlib.workflow.task.collect import RecorderCollector
from qlib.model.ens.group import RollingGroup
from qlib.model.trainer import TrainerRM
from qlib.model.trainer import TrainerRM, task_train
from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG, CSI100_RECORD_XGBOOST_TASK_CONFIG

View File

@@ -23,7 +23,6 @@ from qlib.config import REG_CN
from qlib.workflow import R
from qlib.tests.data import GetData
# init qlib
provider_uri = "~/.qlib/qlib_data/cn_data"
exp_folder_name = "run_all_model_records"
@@ -40,6 +39,7 @@ exp_manager = {
GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True)
qlib.init(provider_uri=provider_uri, region=REG_CN, exp_manager=exp_manager)
# decorator to check the arguments
def only_allow_defined_args(function_to_decorate):
@functools.wraps(function_to_decorate)
@@ -92,7 +92,8 @@ def create_env():
# function to execute the cmd
def execute(cmd):
def execute(cmd, wait_when_err=False):
print("Running CMD:", cmd)
with subprocess.Popen(cmd, stdout=subprocess.PIPE, bufsize=1, universal_newlines=True, shell=True) as p:
for line in p.stdout:
sys.stdout.write(line.split("\b")[0])
@@ -102,6 +103,8 @@ def execute(cmd):
sys.stdout.write("\b" * 10 + "\b".join(line.split("\b")[1:-1]))
if p.returncode != 0:
if wait_when_err:
input("Press Enter to Continue")
return p.stderr
else:
return None
@@ -184,7 +187,15 @@ def gen_and_save_md_table(metrics, dataset):
# function to run the all the models
@only_allow_defined_args
def run(times=1, models=None, dataset="Alpha360", exclude=False):
def run(
times=1,
models=None,
dataset="Alpha360",
exclude=False,
qlib_uri: str = "git+https://github.com/microsoft/qlib#egg=pyqlib",
wait_before_rm_env: bool = False,
wait_when_err: bool = False,
):
"""
Please be aware that this function can only work under Linux. MacOS and Windows will be supported in the future.
Any PR to enhance this method is highly welcomed. Besides, this script doesn't support parrallel running the same model
@@ -200,6 +211,13 @@ def run(times=1, models=None, dataset="Alpha360", exclude=False):
determines whether the model being used is excluded or included.
dataset : str
determines the dataset to be used for each model.
qlib_uri : str
the uri to install qlib with pip
it could be url on the we or local path
wait_before_rm_env : bool
wait before remove environment.
wait_when_err : bool
wait when errors raised when executing commands
Usage:
-------
@@ -240,32 +258,36 @@ def run(times=1, models=None, dataset="Alpha360", exclude=False):
sys.stderr.write("\n")
# install requirements.txt
sys.stderr.write("Installing requirements.txt...\n")
execute(f"{python_path} -m pip install -r {req_path}")
execute(f"{python_path} -m pip install -r {req_path}", wait_when_err=wait_when_err)
sys.stderr.write("\n")
# setup gpu for tft
if fn == "TFT":
execute(
f"conda install -y --prefix {env_path} anaconda cudatoolkit=10.0 && conda install -y --prefix {env_path} cudnn"
f"conda install -y --prefix {env_path} anaconda cudatoolkit=10.0 && conda install -y --prefix {env_path} cudnn",
wait_when_err=wait_when_err,
)
sys.stderr.write("\n")
# install qlib
sys.stderr.write("Installing qlib...\n")
execute(f"{python_path} -m pip install --upgrade pip") # TODO: FIX ME!
execute(f"{python_path} -m pip install --upgrade cython") # TODO: FIX ME!
execute(f"{python_path} -m pip install --upgrade pip", wait_when_err=wait_when_err) # TODO: FIX ME!
execute(f"{python_path} -m pip install --upgrade cython", wait_when_err=wait_when_err) # TODO: FIX ME!
if fn == "TFT":
execute(
f"cd {env_path} && {python_path} -m pip install --upgrade --force-reinstall --ignore-installed PyYAML -e git+https://github.com/microsoft/qlib#egg=pyqlib"
f"cd {env_path} && {python_path} -m pip install --upgrade --force-reinstall --ignore-installed PyYAML -e {qlib_uri}",
wait_when_err=wait_when_err,
) # TODO: FIX ME!
else:
execute(
f"cd {env_path} && {python_path} -m pip install --upgrade --force-reinstall -e git+https://github.com/microsoft/qlib#egg=pyqlib"
f"cd {env_path} && {python_path} -m pip install --upgrade --force-reinstall -e {qlib_uri}",
wait_when_err=wait_when_err,
) # TODO: FIX ME!
sys.stderr.write("\n")
# run workflow_by_config for multiple times
for i in range(times):
sys.stderr.write(f"Running the model: {fn} for iteration {i+1}...\n")
errs = execute(
f"{python_path} {env_path / 'src/pyqlib/qlib/workflow/cli.py'} {yaml_path} {fn} {exp_folder_name}"
f"{python_path} {env_path / 'bin' / 'qrun'} {yaml_path} {fn} {exp_folder_name}",
wait_when_err=wait_when_err,
)
if errs is not None:
_errs = errors.get(fn, {})
@@ -274,6 +296,8 @@ def run(times=1, models=None, dataset="Alpha360", exclude=False):
sys.stderr.write("\n")
# remove env
sys.stderr.write(f"Deleting the environment: {env_path}...\n")
if wait_before_rm_env:
input("Press Enter to Continue")
shutil.rmtree(env_path)
# getting all results
sys.stderr.write(f"Retrieving results...\n")

View File

@@ -1,17 +1,15 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from pathlib import Path
__version__ = "0.7.0"
_version_path = Path(__file__).absolute().parent / "VERSION.txt" # This file is copyed from setup.py
__version__ = _version_path.read_text(encoding="utf-8").strip()
__version__bak = __version__ # This version is backup for QlibConfig.reset_qlib_version
import os
import yaml
import logging
import platform
import subprocess
from pathlib import Path
from .log import get_module_logger
@@ -33,69 +31,70 @@ def init(default_conf="client", **kwargs):
H.clear()
C.set(default_conf, **kwargs)
# check path if server/local
if C.get_uri_type() == C.LOCAL_URI:
if not os.path.exists(C["provider_uri"]):
if C["auto_mount"]:
logger.error(
f"Invalid provider uri: {C['provider_uri']}, please check if a valid provider uri has been set. This path does not exist."
)
else:
logger.warning(f"auto_path is False, please make sure {C['mount_path']} is mounted")
elif C.get_uri_type() == C.NFS_URI:
_mount_nfs_uri(C)
else:
raise NotImplementedError(f"This type of URI is not supported")
# mount nfs
for _freq, provider_uri in C.provider_uri.items():
mount_path = C["mount_path"][_freq]
# check path if server/local
uri_type = C.dpm.get_uri_type(provider_uri)
if uri_type == C.LOCAL_URI:
if not Path(provider_uri).exists():
if C["auto_mount"]:
logger.error(
f"Invalid provider uri: {provider_uri}, please check if a valid provider uri has been set. This path does not exist."
)
else:
logger.warning(f"auto_path is False, please make sure {mount_path} is mounted")
elif uri_type == C.NFS_URI:
_mount_nfs_uri(provider_uri, mount_path, C["auto_mount"])
else:
raise NotImplementedError(f"This type of URI is not supported")
C.register()
if "flask_server" in C:
logger.info(f"flask_server={C['flask_server']}, flask_port={C['flask_port']}")
logger.info("qlib successfully initialized based on %s settings." % default_conf)
logger.info(f"data_path={C.get_data_path()}")
data_path = {_freq: C.dpm.get_data_path(_freq) for _freq in C.dpm.provider_uri.keys()}
logger.info(f"data_path={data_path}")
def _mount_nfs_uri(C):
def _mount_nfs_uri(provider_uri, mount_path, auto_mount: bool = False):
LOG = get_module_logger("mount nfs", level=logging.INFO)
# FIXME: the C["provider_uri"] is modified in this function
# If it is not modified, we can pass only provider_uri or mount_path instead of C
mount_command = "sudo mount.nfs %s %s" % (C["provider_uri"], C["mount_path"])
mount_command = "sudo mount.nfs %s %s" % (provider_uri, mount_path)
# If the provider uri looks like this 172.23.233.89//data/csdesign'
# It will be a nfs path. The client provider will be used
if not C["auto_mount"]:
if not os.path.exists(C["mount_path"]):
if not auto_mount:
if not Path(mount_path).exists():
raise FileNotFoundError(
f"Invalid mount path: {C['mount_path']}! Please mount manually: {mount_command} or Set init parameter `auto_mount=True`"
f"Invalid mount path: {mount_path}! Please mount manually: {mount_command} or Set init parameter `auto_mount=True`"
)
else:
# Judging system type
sys_type = platform.system()
if "win" in sys_type.lower():
# system: window
exec_result = os.popen("mount -o anon %s %s" % (C["provider_uri"], C["mount_path"] + ":"))
exec_result = os.popen("mount -o anon %s %s" % (provider_uri, mount_path + ":"))
result = exec_result.read()
if "85" in result:
LOG.warning("already mounted or window mount path already exists")
LOG.warning(f"{provider_uri} on Windows:{mount_path} is already mounted")
elif "53" in result:
raise OSError("not find network path")
elif "error" in result or "错误" in result:
raise OSError("Invalid mount path")
elif C["provider_uri"] in result:
elif provider_uri in result:
LOG.info("window success mount..")
else:
raise OSError(f"unknown error: {result}")
# config mount path
C["mount_path"] = C["mount_path"] + ":\\"
else:
# system: linux/Unix/Mac
# check mount
_remote_uri = C["provider_uri"]
_remote_uri = _remote_uri[:-1] if _remote_uri.endswith("/") else _remote_uri
_mount_path = C["mount_path"]
_mount_path = _mount_path[:-1] if _mount_path.endswith("/") else _mount_path
_remote_uri = provider_uri[:-1] if provider_uri.endswith("/") else provider_uri
_mount_path = mount_path[:-1] if mount_path.endswith("/") else mount_path
_check_level_num = 2
_is_mount = False
while _check_level_num:
@@ -121,11 +120,9 @@ def _mount_nfs_uri(C):
if not _is_mount:
try:
os.makedirs(C["mount_path"], exist_ok=True)
Path(mount_path).mkdir(parents=True, exist_ok=True)
except Exception:
raise OSError(
f"Failed to create directory {C['mount_path']}, please create {C['mount_path']} manually!"
)
raise OSError(f"Failed to create directory {mount_path}, please create {mount_path} manually!")
# check nfs-common
command_res = os.popen("dpkg -l | grep nfs-common")
@@ -136,11 +133,11 @@ def _mount_nfs_uri(C):
command_status = os.system(mount_command)
if command_status == 256:
raise OSError(
f"mount {C['provider_uri']} on {C['mount_path']} error! Needs SUDO! Please mount manually: {mount_command}"
f"mount {provider_uri} on {mount_path} error! Needs SUDO! Please mount manually: {mount_command}"
)
elif command_status == 32512:
# LOG.error("Command error")
raise OSError(f"mount {C['provider_uri']} on {C['mount_path']} error! Command error")
raise OSError(f"mount {provider_uri} on {mount_path} error! Command error")
elif command_status == 0:
LOG.info("Mount finished")
else:

View File

@@ -15,8 +15,10 @@ import os
import re
import copy
import logging
import platform
import multiprocessing
from pathlib import Path
from typing import Union
class Config:
@@ -73,6 +75,12 @@ REG_US = "us"
NUM_USABLE_CPU = max(multiprocessing.cpu_count() - 2, 1)
DISK_DATASET_CACHE = "DiskDatasetCache"
SIMPLE_DATASET_CACHE = "SimpleDatasetCache"
DISK_EXPRESSION_CACHE = "DiskExpressionCache"
DEPENDENCY_REDIS_CACHE = (DISK_DATASET_CACHE, DISK_EXPRESSION_CACHE)
_default_config = {
# data provider config
"calendar_provider": "LocalCalendarProvider",
@@ -82,6 +90,15 @@ _default_config = {
"dataset_provider": "LocalDatasetProvider",
"provider": "LocalProvider",
# config it in qlib.init()
# "provider_uri" str or dict:
# # str
# "~/.qlib/stock_data/cn_data"
# # dict
# {"day": "~/.qlib/stock_data/cn_data", "1min": "~/.qlib/stock_data/cn_data_1min"}
# NOTE: provider_uri priority
# 1. backend_config: backend_obj["kwargs"]["provider_uri"]
# 2. backend_config: backend_obj["kwargs"]["provider_uri_map"]
# 3. qlib.init: provider_uri
"provider_uri": "",
# cache
"expression_cache": None,
@@ -167,8 +184,9 @@ MODE_CONF = {
"redis_task_db": 1,
"kernels": NUM_USABLE_CPU,
# cache
"expression_cache": "DiskExpressionCache",
"dataset_cache": "DiskDatasetCache",
"expression_cache": DISK_EXPRESSION_CACHE,
"dataset_cache": DISK_DATASET_CACHE,
"local_cache_path": Path("~/.cache/qlib_simple_cache").expanduser().resolve(),
"mount_path": None,
},
"client": {
@@ -183,8 +201,10 @@ MODE_CONF = {
"provider_uri": "~/.qlib/qlib_data/cn_data",
# cache
# Using parameter 'remote' to announce the client is using server_cache, and the writing access will be disabled.
"expression_cache": "DiskExpressionCache",
"dataset_cache": "DiskDatasetCache",
"expression_cache": DISK_EXPRESSION_CACHE,
"dataset_cache": DISK_DATASET_CACHE,
# SimpleDatasetCache directory
"local_cache_path": Path("~/.cache/qlib_simple_cache").expanduser().resolve(),
"calendar_cache": None,
# client config
"kernels": NUM_USABLE_CPU,
@@ -228,11 +248,43 @@ class QlibConfig(Config):
# URI_TYPE
LOCAL_URI = "local"
NFS_URI = "nfs"
DEFAULT_FREQ = "__DEFAULT_FREQ"
def __init__(self, default_conf):
super().__init__(default_conf)
self._registered = False
class DataPathManager:
def __init__(self, provider_uri: Union[str, Path, dict], mount_path: Union[str, Path, dict]):
self.provider_uri = provider_uri
self.mount_path = mount_path
@staticmethod
def get_uri_type(uri: Union[str, Path]):
uri = uri if isinstance(uri, str) else str(uri.expanduser().resolve())
is_win = re.match("^[a-zA-Z]:.*", uri) is not None # such as 'C:\\data', 'D:'
# such as 'host:/data/' (User may define short hostname by themselves or use localhost)
is_nfs_or_win = re.match("^[^/]+:.+", uri) is not None
if is_nfs_or_win and not is_win:
return QlibConfig.NFS_URI
else:
return QlibConfig.LOCAL_URI
def get_data_path(self, freq: str = None) -> Path:
if freq is None or freq not in self.provider_uri:
freq = QlibConfig.DEFAULT_FREQ
_provider_uri = self.provider_uri[freq]
if self.get_uri_type(_provider_uri) == QlibConfig.LOCAL_URI:
return Path(_provider_uri)
elif self.get_uri_type(_provider_uri) == QlibConfig.NFS_URI:
if "win" in platform.system().lower():
# windows, mount_path is the drive
return Path(f"{self.mount_path[freq]}:\\")
return Path(self.mount_path[freq])
else:
raise NotImplementedError(f"This type of uri is not supported")
def set_mode(self, mode):
# raise KeyError
self.update(MODE_CONF[mode])
@@ -242,32 +294,43 @@ class QlibConfig(Config):
# raise KeyError
self.update(_default_region_config[region])
@staticmethod
def is_depend_redis(cache_name: str):
return cache_name in DEPENDENCY_REDIS_CACHE
@property
def dpm(self):
return self.DataPathManager(self["provider_uri"], self["mount_path"])
def resolve_path(self):
# resolve path
if self["mount_path"] is not None:
self["mount_path"] = str(Path(self["mount_path"]).expanduser().resolve())
_mount_path = self["mount_path"]
_provider_uri = self["provider_uri"]
if _provider_uri is None:
raise ValueError("provider_uri cannot be None")
if not isinstance(_provider_uri, dict):
_provider_uri = {self.DEFAULT_FREQ: _provider_uri}
if not isinstance(_mount_path, dict):
_mount_path = {_freq: _mount_path for _freq in _provider_uri.keys()}
if self.get_uri_type() == QlibConfig.LOCAL_URI:
self["provider_uri"] = str(Path(self["provider_uri"]).expanduser().resolve())
# check provider_uri and mount_path
_miss_freq = set(_provider_uri.keys()) - set(_mount_path.keys())
assert len(_miss_freq) == 0, f"mount_path is missing freq: {_miss_freq}"
def get_uri_type(self):
is_win = re.match("^[a-zA-Z]:.*", self["provider_uri"]) is not None # such as 'C:\\data', 'D:'
is_nfs_or_win = (
re.match("^[^/]+:.+", self["provider_uri"]) is not None
) # such as 'host:/data/' (User may define short hostname by themselves or use localhost)
# resolve
for _freq, _uri in _provider_uri.items():
# provider_uri
if self.DataPathManager.get_uri_type(_uri) == QlibConfig.LOCAL_URI:
_provider_uri[_freq] = str(Path(_uri).expanduser().resolve())
# mount_path
_mount_path[_freq] = (
_mount_path[_freq]
if _mount_path[_freq] is None
else str(Path(_mount_path[_freq]).expanduser().resolve())
)
if is_nfs_or_win and not is_win:
return QlibConfig.NFS_URI
else:
return QlibConfig.LOCAL_URI
def get_data_path(self):
if self.get_uri_type() == QlibConfig.LOCAL_URI:
return self["provider_uri"]
elif self.get_uri_type() == QlibConfig.NFS_URI:
return self["mount_path"]
else:
raise NotImplementedError(f"This type of uri is not supported")
self["provider_uri"] = _provider_uri
self["mount_path"] = _mount_path
def set(self, default_conf="client", **kwargs):
from .utils import set_log_with_config, get_module_logger, can_use_cache
@@ -299,11 +362,20 @@ class QlibConfig(Config):
if not (self["expression_cache"] is None and self["dataset_cache"] is None):
# check redis
if not can_use_cache():
logger.warning(
f"redis connection failed(host={self['redis_host']} port={self['redis_port']}), cache will not be used!"
)
self["expression_cache"] = None
self["dataset_cache"] = None
log_str = ""
# check expression cache
if self.is_depend_redis(self["expression_cache"]):
log_str += self["expression_cache"]
self["expression_cache"] = None
# check dataset cache
if self.is_depend_redis(self["dataset_cache"]):
log_str += f" and {self['dataset_cache']}" if log_str else self["dataset_cache"]
self["dataset_cache"] = None
if log_str:
logger.warning(
f"redis connection failed(host={self['redis_host']} port={self['redis_port']}), "
f"{log_str} will not be used!"
)
def register(self):
from .utils import init_instance_by_config

View File

@@ -16,6 +16,7 @@ def get_benchmark_weight(
start_date=None,
end_date=None,
path=None,
freq="day",
):
"""get_benchmark_weight
@@ -25,6 +26,7 @@ def get_benchmark_weight(
:param start_date:
:param end_date:
:param path:
:param freq:
:return: The weight distribution of the the benchmark described by a pandas dataframe
Every row corresponds to a trading day.
@@ -33,7 +35,7 @@ def get_benchmark_weight(
"""
if not path:
path = Path(C.get_data_path()).expanduser() / "raw" / "AIndexMembers" / "weights.csv"
path = Path(C.dpm.get_data_path(freq)).expanduser() / "raw" / "AIndexMembers" / "weights.csv"
# TODO: the storage of weights should be implemented in a more elegent way
# TODO: The benchmark is not consistant with the filename in instruments.
bench_weight_df = pd.read_csv(path, usecols=["code", "date", "index", "weight"])
@@ -222,6 +224,7 @@ def brinson_pa(
group_method="category",
group_n=None,
deal_price="vwap",
freq="day",
):
"""brinson profit attribution
@@ -243,7 +246,7 @@ def brinson_pa(
start_date, end_date = min(dates), max(dates)
bench_stock_weight = get_benchmark_weight(bench, start_date, end_date)
bench_stock_weight = get_benchmark_weight(bench, start_date, end_date, freq)
# The attributes for allocation will not
if not group_field.startswith("$"):
@@ -259,13 +262,14 @@ def brinson_pa(
start_time=shift_start_date,
end_time=end_date,
as_list=True,
freq=freq,
)
stock_df = D.features(
instruments,
[group_field, deal_price],
start_time=shift_start_date,
end_time=end_date,
freq="day",
freq=freq,
)
stock_df.columns = [group_field, "deal_price"]

View File

@@ -0,0 +1,346 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import copy
import torch
import warnings
import numpy as np
import pandas as pd
from qlib.utils import init_instance_by_config
from qlib.data.dataset import DatasetH, DataHandler
device = "cuda" if torch.cuda.is_available() else "cpu"
def _to_tensor(x):
if not isinstance(x, torch.Tensor):
return torch.tensor(x, dtype=torch.float, device=device)
return x
def _create_ts_slices(index, seq_len):
"""
create time series slices from pandas index
Args:
index (pd.MultiIndex): pandas multiindex with <instrument, datetime> order
seq_len (int): sequence length
"""
assert isinstance(index, pd.MultiIndex), "unsupported index type"
assert seq_len > 0, "sequence length should be larger than 0"
assert index.is_monotonic_increasing, "index should be sorted"
# number of dates for each instrument
sample_count_by_insts = index.to_series().groupby(level=0).size().values
# start index for each instrument
start_index_of_insts = np.roll(np.cumsum(sample_count_by_insts), 1)
start_index_of_insts[0] = 0
# all the [start, stop) indices of features
# features between [start, stop) will be used to predict label at `stop - 1`
slices = []
for cur_loc, cur_cnt in zip(start_index_of_insts, sample_count_by_insts):
for stop in range(1, cur_cnt + 1):
end = cur_loc + stop
start = max(end - seq_len, 0)
slices.append(slice(start, end))
slices = np.array(slices, dtype="object")
assert len(slices) == len(index) # the i-th slice = index[i]
return slices
def _get_date_parse_fn(target):
"""get date parse function
This method is used to parse date arguments as target type.
Example:
get_date_parse_fn('20120101')('2017-01-01') => '20170101'
get_date_parse_fn(20120101)('2017-01-01') => 20170101
"""
if isinstance(target, pd.Timestamp):
_fn = lambda x: pd.Timestamp(x) # Timestamp('2020-01-01')
elif isinstance(target, int):
_fn = lambda x: int(str(x).replace("-", "")[:8]) # 20200201
elif isinstance(target, str) and len(target) == 8:
_fn = lambda x: str(x).replace("-", "")[:8] # '20200201'
else:
_fn = lambda x: x # '2021-01-01'
return _fn
def _maybe_padding(x, seq_len, zeros=None):
"""padding 2d <time * feature> data with zeros
Args:
x (np.ndarray): 2d data with shape <time * feature>
seq_len (int): target sequence length
zeros (np.ndarray): zeros with shape <seq_len * feature>
"""
assert seq_len > 0, "sequence length should be larger than 0"
if zeros is None:
zeros = np.zeros((seq_len, x.shape[1]), dtype=np.float32)
else:
assert len(zeros) >= seq_len, "zeros matrix is not large enough for padding"
if len(x) != seq_len: # padding zeros
x = np.concatenate([zeros[: seq_len - len(x), : x.shape[1]], x], axis=0)
return x
class MTSDatasetH(DatasetH):
"""Memory Augmented Time Series Dataset
Args:
handler (DataHandler): data handler
segments (dict): data split segments
seq_len (int): time series sequence length
horizon (int): label horizon
num_states (int): how many memory states to be added
memory_mode (str): memory mode (daily or sample)
batch_size (int): batch size (<0 will use daily sampling)
n_samples (int): number of samples in the same day
shuffle (bool): whether shuffle data
drop_last (bool): whether drop last batch < batch_size
input_size (int): reshape flatten rows as this input_size (backward compatibility)
"""
def __init__(
self,
handler,
segments,
seq_len=60,
horizon=0,
num_states=0,
memory_mode="sample",
batch_size=-1,
n_samples=None,
shuffle=True,
drop_last=False,
input_size=None,
**kwargs,
):
assert num_states == 0 or horizon > 0, "please specify `horizon` to avoid data leakage"
assert memory_mode in ["sample", "daily"], "unsupported memory mode"
assert memory_mode == "sample" or batch_size < 0, "daily memory requires daily sampling (`batch_size < 0`)"
assert batch_size != 0, "invalid batch size"
if batch_size > 0 and n_samples is not None:
warnings.warn("`n_samples` can only be used for daily sampling (`batch_size < 0`)")
self.seq_len = seq_len
self.horizon = horizon
self.num_states = num_states
self.memory_mode = memory_mode
self.batch_size = batch_size
self.n_samples = n_samples
self.shuffle = shuffle
self.drop_last = drop_last
self.input_size = input_size
self.params = (batch_size, n_samples, drop_last, shuffle) # for train/eval switch
super().__init__(handler, segments, **kwargs)
def setup_data(self, handler_kwargs: dict = None, **kwargs):
super().setup_data(**kwargs)
if handler_kwargs is not None:
self.handler.setup_data(**handler_kwargs)
# pre-fetch data and change index to <code, date>
# NOTE: we will use inplace sort to reduce memory use
try:
df = self.handler._learn.copy() # use copy otherwise recorder will fail
# FIXME: currently we cannot support switching from `_learn` to `_infer` for inference
except:
warnings.warn("cannot access `_learn`, will load raw data")
df = self.handler._data.copy()
df.index = df.index.swaplevel()
df.sort_index(inplace=True)
# convert to numpy
self._data = df["feature"].values.astype("float32")
np.nan_to_num(self._data, copy=False) # NOTE: fillna in case users forget using the fillna processor
self._label = df["label"].squeeze().values.astype("float32")
self._index = df.index
if self.input_size is not None and self.input_size != self._data.shape[1]:
warnings.warn("the data has different shape from input_size and the data will be reshaped")
assert self._data.shape[1] % self.input_size == 0, "data mismatch, please check `input_size`"
# create batch slices
self._batch_slices = _create_ts_slices(self._index, self.seq_len)
# create daily slices
daily_slices = {date: [] for date in sorted(self._index.unique(level=1))} # sorted by date
for i, (code, date) in enumerate(self._index):
daily_slices[date].append(self._batch_slices[i])
self._daily_slices = np.array(list(daily_slices.values()), dtype="object")
self._daily_index = pd.Series(list(daily_slices.keys())) # index is the original date index
# add memory (sample wise and daily)
if self.memory_mode == "sample":
self._memory = np.zeros((len(self._data), self.num_states), dtype=np.float32)
elif self.memory_mode == "daily":
self._memory = np.zeros((len(self._daily_index), self.num_states), dtype=np.float32)
else:
raise ValueError(f"invalid memory_mode `{self.memory_mode}`")
# padding tensor
self._zeros = np.zeros((self.seq_len, max(self.num_states, self._data.shape[1])), dtype=np.float32)
def _prepare_seg(self, slc, **kwargs):
fn = _get_date_parse_fn(self._index[0][1])
start_date = fn(slc.start)
end_date = fn(slc.stop)
obj = copy.copy(self) # shallow copy
# NOTE: Seriable will disable copy `self._data` so we manually assign them here
obj._data = self._data # reference (no copy)
obj._label = self._label
obj._index = self._index
obj._memory = self._memory
obj._zeros = self._zeros
# update index for this batch
date_index = self._index.get_level_values(1)
obj._batch_slices = self._batch_slices[(date_index >= start_date) & (date_index <= end_date)]
mask = (self._daily_index.values >= start_date) & (self._daily_index.values <= end_date)
obj._daily_slices = self._daily_slices[mask]
obj._daily_index = self._daily_index[mask]
return obj
def restore_index(self, index):
return self._index[index]
def restore_daily_index(self, daily_index):
return pd.Index(self._daily_index.loc[daily_index])
def assign_data(self, index, vals):
if self.num_states == 0:
raise ValueError("cannot assign data as `num_states==0`")
if isinstance(vals, torch.Tensor):
vals = vals.detach().cpu().numpy()
self._memory[index] = vals
def clear_memory(self):
if self.num_states == 0:
raise ValueError("cannot clear memory as `num_states==0`")
self._memory[:] = 0
def train(self):
"""enable traning mode"""
self.batch_size, self.n_samples, self.drop_last, self.shuffle = self.params
def eval(self):
"""enable evaluation mode"""
self.batch_size = -1
self.n_samples = None
self.drop_last = False
self.shuffle = False
def _get_slices(self):
if self.batch_size < 0: # daily sampling
slices = self._daily_slices.copy()
batch_size = -1 * self.batch_size
else: # normal sampling
slices = self._batch_slices.copy()
batch_size = self.batch_size
return slices, batch_size
def __len__(self):
slices, batch_size = self._get_slices()
if self.drop_last:
return len(slices) // batch_size
return (len(slices) + batch_size - 1) // batch_size
def __iter__(self):
slices, batch_size = self._get_slices()
indices = np.arange(len(slices))
if self.shuffle:
np.random.shuffle(indices)
for i in range(len(indices))[::batch_size]:
if self.drop_last and i + batch_size > len(indices):
break
data = [] # store features
label = [] # store labels
index = [] # store index
state = [] # store memory states
daily_index = [] # store daily index
daily_count = [] # store number of samples for each day
for j in indices[i : i + batch_size]:
# normal sampling: self.batch_size > 0 => slices is a list => slices_subset is a slice
# daily sampling: self.batch_size < 0 => slices is a nested list => slices_subset is a list
slices_subset = slices[j]
# daily sampling
# each slices_subset contains a list of slices for multiple stocks
# NOTE: daily sampling is used in 1) eval mode, 2) train mode with self.batch_size < 0
if self.batch_size < 0:
# store daily index
idx = self._daily_index.index[j] # daily_index.index is the index of the original data
daily_index.append(idx)
# store daily memory if specified
# NOTE: daily memory always requires daily sampling (self.batch_size < 0)
if self.memory_mode == "daily":
slc = slice(max(idx - self.seq_len - self.horizon, 0), max(idx - self.horizon, 0))
state.append(_maybe_padding(self._memory[slc], self.seq_len, self._zeros))
# down-sample stocks and store count
if self.n_samples and 0 < self.n_samples < len(slices_subset): # intraday subsample
slices_subset = np.random.choice(slices_subset, self.n_samples, replace=False)
daily_count.append(len(slices_subset))
# normal sampling
# each slices_subset is a single slice
# NOTE: normal sampling is used in train mode with self.batch_size > 0
else:
slices_subset = [slices_subset]
for slc in slices_subset:
# legacy support for Alpha360 data by `input_size`
if self.input_size:
data.append(self._data[slc.stop - 1].reshape(self.input_size, -1).T)
else:
data.append(_maybe_padding(self._data[slc], self.seq_len, self._zeros))
if self.memory_mode == "sample":
state.append(_maybe_padding(self._memory[slc], self.seq_len, self._zeros)[: -self.horizon])
label.append(self._label[slc.stop - 1])
index.append(slc.stop - 1)
# end slices loop
# end indices batch loop
# concate
data = _to_tensor(np.stack(data))
state = _to_tensor(np.stack(state))
label = _to_tensor(np.stack(label))
index = np.array(index)
daily_index = np.array(daily_index)
daily_count = np.array(daily_count)
# yield -> generator
yield {
"data": data,
"label": label,
"state": state,
"index": index,
"daily_index": daily_index,
"daily_count": daily_count,
}
# end indice loop

View File

@@ -3,7 +3,7 @@
from ...data.dataset.handler import DataHandlerLP
from ...data.dataset.processor import Processor
from ...utils import get_cls_kwargs
from ...utils import get_callable_kwargs
from ...data.dataset import processor as processor_module
from ...log import TimeInspector
from inspect import getfullargspec
@@ -14,7 +14,7 @@ def check_transform_proc(proc_l, fit_start_time, fit_end_time):
new_l = []
for p in proc_l:
if not isinstance(p, Processor):
klass, pkwargs = get_cls_kwargs(p, processor_module)
klass, pkwargs = get_callable_kwargs(p, processor_module)
args = getfullargspec(klass).args
if "fit_start_time" in args and "fit_end_time" in args:
assert (
@@ -58,6 +58,7 @@ class Alpha360(DataHandlerLP):
fit_start_time=None,
fit_end_time=None,
filter_pipe=None,
inst_processor=None,
**kwargs,
):
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
@@ -72,6 +73,7 @@ class Alpha360(DataHandlerLP):
},
"filter_pipe": filter_pipe,
"freq": freq,
"inst_processor": inst_processor,
},
}
@@ -144,6 +146,7 @@ class Alpha158(DataHandlerLP):
fit_end_time=None,
process_type=DataHandlerLP.PTYPE_A,
filter_pipe=None,
inst_processor=None,
**kwargs,
):
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
@@ -158,6 +161,7 @@ class Alpha158(DataHandlerLP):
},
"filter_pipe": filter_pipe,
"freq": freq,
"inst_processor": inst_processor,
},
}
super().__init__(

View File

@@ -27,7 +27,6 @@ from ...contrib.model.pytorch_gru import GRUModel
class DailyBatchSampler(Sampler):
def __init__(self, data_source):
self.data_source = data_source
# calculate number of samples in each batch
self.daily_count = pd.Series(index=self.data_source.get_index()).groupby("datetime").size().values

View File

@@ -0,0 +1,331 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import pandas as pd
from typing import Text, Union
import copy
import math
from ...utils import get_or_create_path
from ...log import get_module_logger
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from .pytorch_utils import count_parameters
from ...model.base import Model
from ...data.dataset import DatasetH, TSDatasetH
from ...data.dataset.handler import DataHandlerLP
from torch.nn.modules.container import ModuleList
# qrun examples/benchmarks/Localformer/workflow_config_localformer_Alpha360.yaml ”
class LocalformerModel(Model):
def __init__(
self,
d_feat: int = 20,
d_model: int = 64,
batch_size: int = 2048,
nhead: int = 2,
num_layers: int = 2,
dropout: float = 0,
n_epochs=100,
lr=0.0001,
metric="",
early_stop=5,
loss="mse",
optimizer="adam",
reg=1e-3,
n_jobs=10,
GPU=0,
seed=None,
**kwargs
):
# set hyper-parameters.
self.d_model = d_model
self.dropout = dropout
self.n_epochs = n_epochs
self.lr = lr
self.reg = reg
self.metric = metric
self.batch_size = batch_size
self.early_stop = early_stop
self.optimizer = optimizer.lower()
self.loss = loss
self.n_jobs = n_jobs
self.device = torch.device("cuda:%d" % GPU if torch.cuda.is_available() and GPU >= 0 else "cpu")
self.seed = seed
self.logger = get_module_logger("TransformerModel")
self.logger.info("Naive Transformer:" "\nbatch_size : {}" "\ndevice : {}".format(self.batch_size, self.device))
if self.seed is not None:
np.random.seed(self.seed)
torch.manual_seed(self.seed)
self.model = Transformer(d_feat, d_model, nhead, num_layers, dropout, self.device)
if optimizer.lower() == "adam":
self.train_optimizer = optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=self.reg)
elif optimizer.lower() == "gd":
self.train_optimizer = optim.SGD(self.model.parameters(), lr=self.lr, weight_decay=self.reg)
else:
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
self.fitted = False
self.model.to(self.device)
@property
def use_gpu(self):
return self.device != torch.device("cpu")
def mse(self, pred, label):
loss = (pred.float() - label.float()) ** 2
return torch.mean(loss)
def loss_fn(self, pred, label):
mask = ~torch.isnan(label)
if self.loss == "mse":
return self.mse(pred[mask], label[mask])
raise ValueError("unknown loss `%s`" % self.loss)
def metric_fn(self, pred, label):
mask = torch.isfinite(label)
if self.metric == "" or self.metric == "loss":
return -self.loss_fn(pred[mask], label[mask])
raise ValueError("unknown metric `%s`" % self.metric)
def train_epoch(self, x_train, y_train):
x_train_values = x_train.values
y_train_values = np.squeeze(y_train.values)
self.model.train()
indices = np.arange(len(x_train_values))
np.random.shuffle(indices)
for i in range(len(indices))[:: self.batch_size]:
if len(indices) - i < self.batch_size:
break
feature = torch.from_numpy(x_train_values[indices[i : i + self.batch_size]]).float().to(self.device)
label = torch.from_numpy(y_train_values[indices[i : i + self.batch_size]]).float().to(self.device)
pred = self.model(feature)
loss = self.loss_fn(pred, label)
self.train_optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_value_(self.model.parameters(), 3.0)
self.train_optimizer.step()
def test_epoch(self, data_x, data_y):
# prepare training data
x_values = data_x.values
y_values = np.squeeze(data_y.values)
self.model.eval()
scores = []
losses = []
indices = np.arange(len(x_values))
for i in range(len(indices))[:: self.batch_size]:
if len(indices) - i < self.batch_size:
break
feature = torch.from_numpy(x_values[indices[i : i + self.batch_size]]).float().to(self.device)
label = torch.from_numpy(y_values[indices[i : i + self.batch_size]]).float().to(self.device)
with torch.no_grad():
pred = self.model(feature)
loss = self.loss_fn(pred, label)
losses.append(loss.item())
score = self.metric_fn(pred, label)
scores.append(score.item())
return np.mean(losses), np.mean(scores)
def fit(
self,
dataset: DatasetH,
evals_result=dict(),
save_path=None,
):
df_train, df_valid, df_test = dataset.prepare(
["train", "valid", "test"],
col_set=["feature", "label"],
data_key=DataHandlerLP.DK_L,
)
x_train, y_train = df_train["feature"], df_train["label"]
x_valid, y_valid = df_valid["feature"], df_valid["label"]
save_path = get_or_create_path(save_path)
stop_steps = 0
train_loss = 0
best_score = -np.inf
best_epoch = 0
evals_result["train"] = []
evals_result["valid"] = []
# train
self.logger.info("training...")
self.fitted = True
for step in range(self.n_epochs):
self.logger.info("Epoch%d:", step)
self.logger.info("training...")
self.train_epoch(x_train, y_train)
self.logger.info("evaluating...")
train_loss, train_score = self.test_epoch(x_train, y_train)
val_loss, val_score = self.test_epoch(x_valid, y_valid)
self.logger.info("train %.6f, valid %.6f" % (train_score, val_score))
evals_result["train"].append(train_score)
evals_result["valid"].append(val_score)
if val_score > best_score:
best_score = val_score
stop_steps = 0
best_epoch = step
best_param = copy.deepcopy(self.model.state_dict())
else:
stop_steps += 1
if stop_steps >= self.early_stop:
self.logger.info("early stop")
break
self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
self.model.load_state_dict(best_param)
torch.save(best_param, save_path)
if self.use_gpu:
torch.cuda.empty_cache()
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
if not self.fitted:
raise ValueError("model is not fitted yet!")
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
index = x_test.index
self.model.eval()
x_values = x_test.values
sample_num = x_values.shape[0]
preds = []
for begin in range(sample_num)[:: self.batch_size]:
if sample_num - begin < self.batch_size:
end = sample_num
else:
end = begin + self.batch_size
x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)
with torch.no_grad():
pred = self.model(x_batch).detach().cpu().numpy()
preds.append(pred)
return pd.Series(np.concatenate(preds), index=index)
class PositionalEncoding(nn.Module):
def __init__(self, d_model, max_len=1000):
super(PositionalEncoding, self).__init__()
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0).transpose(0, 1)
self.register_buffer("pe", pe)
def forward(self, x):
# [T, N, F]
return x + self.pe[: x.size(0), :]
def _get_clones(module, N):
return ModuleList([copy.deepcopy(module) for i in range(N)])
class LocalformerEncoder(nn.Module):
__constants__ = ["norm"]
def __init__(self, encoder_layer, num_layers, d_model):
super(LocalformerEncoder, self).__init__()
self.layers = _get_clones(encoder_layer, num_layers)
self.conv = _get_clones(nn.Conv1d(d_model, d_model, 3, 1, 1), num_layers)
self.num_layers = num_layers
def forward(self, src, mask):
output = src
out = src
for i, mod in enumerate(self.layers):
# [T, N, F] --> [N, T, F] --> [N, F, T]
out = output.transpose(1, 0).transpose(2, 1)
out = self.conv[i](out).transpose(2, 1).transpose(1, 0)
output = mod(output + out, src_mask=mask)
return output + out
class Transformer(nn.Module):
def __init__(self, d_feat=6, d_model=8, nhead=4, num_layers=2, dropout=0.5, device=None):
super(Transformer, self).__init__()
self.rnn = nn.GRU(
input_size=d_model,
hidden_size=d_model,
num_layers=num_layers,
batch_first=False,
dropout=dropout,
)
self.feature_layer = nn.Linear(d_feat, d_model)
self.pos_encoder = PositionalEncoding(d_model)
self.encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dropout=dropout)
self.transformer_encoder = LocalformerEncoder(self.encoder_layer, num_layers=num_layers, d_model=d_model)
self.decoder_layer = nn.Linear(d_model, 1)
self.device = device
self.d_feat = d_feat
def forward(self, src):
# src [N, F*T] --> [N, T, F]
src = src.reshape(len(src), self.d_feat, -1).permute(0, 2, 1)
src = self.feature_layer(src)
# src [N, T, F] --> [T, N, F], [60, 512, 8]
src = src.transpose(1, 0) # not batch first
mask = None
src = self.pos_encoder(src)
output = self.transformer_encoder(src, mask) # [60, 512, 8]
output, _ = self.rnn(output)
# [T, N, F] --> [N, T*F]
output = self.decoder_layer(output.transpose(1, 0)[:, -1, :]) # [512, 1]
return output.squeeze()

View File

@@ -0,0 +1,308 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import pandas as pd
import copy
import math
from ...utils import get_or_create_path
from ...log import get_module_logger
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from .pytorch_utils import count_parameters
from ...model.base import Model
from ...data.dataset import DatasetH, TSDatasetH
from ...data.dataset.handler import DataHandlerLP
from torch.nn.modules.container import ModuleList
class LocalformerModel(Model):
def __init__(
self,
d_feat: int = 20,
d_model: int = 64,
batch_size: int = 8192,
nhead: int = 2,
num_layers: int = 2,
dropout: float = 0,
n_epochs=100,
lr=0.0001,
metric="",
early_stop=5,
loss="mse",
optimizer="adam",
reg=1e-3,
n_jobs=10,
GPU=0,
seed=None,
**kwargs
):
# set hyper-parameters.
self.d_model = d_model
self.dropout = dropout
self.n_epochs = n_epochs
self.lr = lr
self.reg = reg
self.metric = metric
self.batch_size = batch_size
self.early_stop = early_stop
self.optimizer = optimizer.lower()
self.loss = loss
self.n_jobs = n_jobs
self.device = torch.device("cuda:%d" % GPU if torch.cuda.is_available() and GPU >= 0 else "cpu")
self.seed = seed
self.logger = get_module_logger("TransformerModel")
self.logger.info(
"Improved Transformer:" "\nbatch_size : {}" "\ndevice : {}".format(self.batch_size, self.device)
)
if self.seed is not None:
np.random.seed(self.seed)
torch.manual_seed(self.seed)
self.model = Transformer(d_feat, d_model, nhead, num_layers, dropout, self.device)
if optimizer.lower() == "adam":
self.train_optimizer = optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=self.reg)
elif optimizer.lower() == "gd":
self.train_optimizer = optim.SGD(self.model.parameters(), lr=self.lr, weight_decay=self.reg)
else:
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
self.fitted = False
self.model.to(self.device)
@property
def use_gpu(self):
return self.device != torch.device("cpu")
def mse(self, pred, label):
loss = (pred.float() - label.float()) ** 2
return torch.mean(loss)
def loss_fn(self, pred, label):
mask = ~torch.isnan(label)
if self.loss == "mse":
return self.mse(pred[mask], label[mask])
raise ValueError("unknown loss `%s`" % self.loss)
def metric_fn(self, pred, label):
mask = torch.isfinite(label)
if self.metric == "" or self.metric == "loss":
return -self.loss_fn(pred[mask], label[mask])
raise ValueError("unknown metric `%s`" % self.metric)
def train_epoch(self, data_loader):
self.model.train()
for data in data_loader:
feature = data[:, :, 0:-1].to(self.device)
label = data[:, -1, -1].to(self.device)
pred = self.model(feature.float()) # .float()
loss = self.loss_fn(pred, label)
self.train_optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_value_(self.model.parameters(), 3.0)
self.train_optimizer.step()
def test_epoch(self, data_loader):
self.model.eval()
scores = []
losses = []
for data in data_loader:
feature = data[:, :, 0:-1].to(self.device)
label = data[:, -1, -1].to(self.device)
with torch.no_grad():
pred = self.model(feature.float()) # .float()
loss = self.loss_fn(pred, label)
losses.append(loss.item())
score = self.metric_fn(pred, label)
scores.append(score.item())
return np.mean(losses), np.mean(scores)
def fit(
self,
dataset: DatasetH,
evals_result=dict(),
save_path=None,
):
dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
dl_valid = dataset.prepare("valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader
dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader
train_loader = DataLoader(
dl_train, batch_size=self.batch_size, shuffle=True, num_workers=self.n_jobs, drop_last=True
)
valid_loader = DataLoader(
dl_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.n_jobs, drop_last=True
)
save_path = get_or_create_path(save_path)
stop_steps = 0
train_loss = 0
best_score = -np.inf
best_epoch = 0
evals_result["train"] = []
evals_result["valid"] = []
# train
self.logger.info("training...")
self.fitted = True
for step in range(self.n_epochs):
self.logger.info("Epoch%d:", step)
self.logger.info("training...")
self.train_epoch(train_loader)
self.logger.info("evaluating...")
train_loss, train_score = self.test_epoch(train_loader)
val_loss, val_score = self.test_epoch(valid_loader)
self.logger.info("train %.6f, valid %.6f" % (train_score, val_score))
evals_result["train"].append(train_score)
evals_result["valid"].append(val_score)
if val_score > best_score:
best_score = val_score
stop_steps = 0
best_epoch = step
best_param = copy.deepcopy(self.model.state_dict())
else:
stop_steps += 1
if stop_steps >= self.early_stop:
self.logger.info("early stop")
break
self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
self.model.load_state_dict(best_param)
torch.save(best_param, save_path)
if self.use_gpu:
torch.cuda.empty_cache()
def predict(self, dataset):
if not self.fitted:
raise ValueError("model is not fitted yet!")
dl_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)
dl_test.config(fillna_type="ffill+bfill")
test_loader = DataLoader(dl_test, batch_size=self.batch_size, num_workers=self.n_jobs)
self.model.eval()
preds = []
for data in test_loader:
feature = data[:, :, 0:-1].to(self.device)
with torch.no_grad():
pred = self.model(feature.float()).detach().cpu().numpy()
preds.append(pred)
return pd.Series(np.concatenate(preds), index=dl_test.get_index())
class PositionalEncoding(nn.Module):
def __init__(self, d_model, max_len=1000):
super(PositionalEncoding, self).__init__()
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0).transpose(0, 1)
self.register_buffer("pe", pe)
def forward(self, x):
# [T, N, F]
return x + self.pe[: x.size(0), :]
def _get_clones(module, N):
return ModuleList([copy.deepcopy(module) for i in range(N)])
class LocalformerEncoder(nn.Module):
__constants__ = ["norm"]
def __init__(self, encoder_layer, num_layers, d_model):
super(LocalformerEncoder, self).__init__()
self.layers = _get_clones(encoder_layer, num_layers)
self.conv = _get_clones(nn.Conv1d(d_model, d_model, 3, 1, 1), num_layers)
self.num_layers = num_layers
def forward(self, src, mask):
output = src
out = src
for i, mod in enumerate(self.layers):
# [T, N, F] --> [N, T, F] --> [N, F, T]
out = output.transpose(1, 0).transpose(2, 1)
out = self.conv[i](out).transpose(2, 1).transpose(1, 0)
output = mod(output + out, src_mask=mask)
return output + out
class Transformer(nn.Module):
def __init__(self, d_feat=6, d_model=8, nhead=4, num_layers=2, dropout=0.5, device=None):
super(Transformer, self).__init__()
self.rnn = nn.GRU(
input_size=d_model,
hidden_size=d_model,
num_layers=num_layers,
batch_first=False,
dropout=dropout,
)
self.feature_layer = nn.Linear(d_feat, d_model)
self.pos_encoder = PositionalEncoding(d_model)
self.encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dropout=dropout)
self.transformer_encoder = LocalformerEncoder(self.encoder_layer, num_layers=num_layers, d_model=d_model)
self.decoder_layer = nn.Linear(d_model, 1)
self.device = device
self.d_feat = d_feat
def forward(self, src):
# src [N, T, F], [512, 60, 6]
src = self.feature_layer(src) # [512, 60, 8]
# src [N, T, F] --> [T, N, F], [60, 512, 8]
src = src.transpose(1, 0) # not batch first
mask = None
src = self.pos_encoder(src)
output = self.transformer_encoder(src, mask) # [60, 512, 8]
output, _ = self.rnn(output)
# [T, N, F] --> [N, T*F]
output = self.decoder_layer(output.transpose(1, 0)[:, -1, :]) # [512, 1]
return output.squeeze()

View File

@@ -564,7 +564,7 @@ class FeatureTransformer(nn.Module):
self.shared = None
self.independ = nn.ModuleList()
if first:
self.independ.append(GLU(inp, out_dim, vbs=vbs))
self.independ.append(GLU(inp_dim, out_dim, vbs=vbs))
for x in range(first, n_ind):
self.independ.append(GLU(out_dim, out_dim, vbs=vbs))
self.scale = float(np.sqrt(0.5))

View File

@@ -0,0 +1,944 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import io
import os
import copy
import math
import json
import collections
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
try:
from torch.utils.tensorboard import SummaryWriter
except:
SummaryWriter = None
from tqdm import tqdm
from qlib.utils import get_or_create_path
from qlib.log import get_module_logger
from qlib.model.base import Model
from qlib.contrib.data.dataset import MTSDatasetH
device = "cuda" if torch.cuda.is_available() else "cpu"
class TRAModel(Model):
"""
TRA Model
Args:
model_config (dict): model config (will be used by RNN or Transformer)
tra_config (dict): TRA config (will be used by TRA)
model_type (str): which backbone model to use (RNN/Transformer)
lr (float): learning rate
n_epochs (int): number of total epochs
early_stop (int): early stop when performance not improved at this step
update_freq (int): gradient update frequency
max_steps_per_epoch (int): maximum number of steps in one epoch
lamb (float): regularization parameter
rho (float): exponential decay rate for `lamb`
alpha (float): fusion parameter for calculating transport loss matrix
seed (int): random seed
logdir (str): local log directory
eval_train (bool): whether evaluate train set between epochs
eval_test (bool): whether evaluate test set between epochs
pretrain (bool): whether pretrain the backbone model before training TRA.
Note that only TRA will be optimized after pretraining
init_state (str): model init state path
freeze_model (bool): whether freeze backbone model parameters
freeze_predictors (bool): whether freeze predictors parameters
transport_method (str): transport method, can be none/router/oracle
memory_mode (str): memory mode, the same argument for MTSDatasetH
"""
def __init__(
self,
model_config,
tra_config,
model_type="RNN",
lr=1e-3,
n_epochs=500,
early_stop=50,
update_freq=1,
max_steps_per_epoch=None,
lamb=0.0,
rho=0.99,
alpha=1.0,
seed=0,
logdir=None,
eval_train=False,
eval_test=False,
pretrain=False,
init_state=None,
reset_router=False,
freeze_model=False,
freeze_predictors=False,
transport_method="none",
memory_mode="sample",
):
self.logger = get_module_logger("TRA")
assert memory_mode in ["sample", "daily"], "invalid memory mode"
assert transport_method in ["none", "router", "oracle"], f"invalid transport method {transport_method}"
assert transport_method == "none" or tra_config["num_states"] > 1, "optimal transport requires `num_states` > 1"
assert (
memory_mode != "daily" or tra_config["src_info"] == "TPE"
), "daily transport can only support TPE as `src_info`"
if transport_method == "router" and not eval_train:
self.logger.warning("`eval_train` will be ignored when using TRA.router")
np.random.seed(seed)
torch.manual_seed(seed)
self.model_config = model_config
self.tra_config = tra_config
self.model_type = model_type
self.lr = lr
self.n_epochs = n_epochs
self.early_stop = early_stop
self.update_freq = update_freq
self.max_steps_per_epoch = max_steps_per_epoch
self.lamb = lamb
self.rho = rho
self.alpha = alpha
self.seed = seed
self.logdir = logdir
self.eval_train = eval_train
self.eval_test = eval_test
self.pretrain = pretrain
self.init_state = init_state
self.reset_router = reset_router
self.freeze_model = freeze_model
self.freeze_predictors = freeze_predictors
self.transport_method = transport_method
self.use_daily_transport = memory_mode == "daily"
self.transport_fn = transport_daily if self.use_daily_transport else transport_sample
self._writer = None
if self.logdir is not None:
if os.path.exists(self.logdir):
self.logger.warning(f"logdir {self.logdir} is not empty")
os.makedirs(self.logdir, exist_ok=True)
if SummaryWriter is not None:
self._writer = SummaryWriter(log_dir=self.logdir)
self._init_model()
def _init_model(self):
self.logger.info("init TRAModel...")
self.model = eval(self.model_type)(**self.model_config).to(device)
print(self.model)
self.tra = TRA(self.model.output_size, **self.tra_config).to(device)
print(self.tra)
if self.init_state:
self.logger.warning(f"load state dict from `init_state`")
state_dict = torch.load(self.init_state, map_location="cpu")
self.model.load_state_dict(state_dict["model"])
res = load_state_dict_unsafe(self.tra, state_dict["tra"])
self.logger.warning(str(res))
if self.reset_router:
self.logger.warning(f"reset TRA.router parameters")
self.tra.fc.reset_parameters()
self.tra.router.reset_parameters()
if self.freeze_model:
self.logger.warning(f"freeze model parameters")
for param in self.model.parameters():
param.requires_grad_(False)
if self.freeze_predictors:
self.logger.warning(f"freeze TRA.predictors parameters")
for param in self.tra.predictors.parameters():
param.requires_grad_(False)
self.logger.info("# model params: %d" % sum([p.numel() for p in self.model.parameters() if p.requires_grad]))
self.logger.info("# tra params: %d" % sum([p.numel() for p in self.tra.parameters() if p.requires_grad]))
self.optimizer = optim.Adam(list(self.model.parameters()) + list(self.tra.parameters()), lr=self.lr)
self.fitted = False
self.global_step = -1
def train_epoch(self, epoch, data_set, is_pretrain=False):
self.model.train()
self.tra.train()
data_set.train()
self.optimizer.zero_grad()
P_all = []
prob_all = []
choice_all = []
max_steps = len(data_set)
if self.max_steps_per_epoch is not None:
if epoch == 0 and self.max_steps_per_epoch < max_steps:
self.logger.info(f"max steps updated from {max_steps} to {self.max_steps_per_epoch}")
max_steps = min(self.max_steps_per_epoch, max_steps)
cur_step = 0
total_loss = 0
total_count = 0
for batch in tqdm(data_set, total=max_steps):
cur_step += 1
if cur_step > max_steps:
break
if not is_pretrain:
self.global_step += 1
data, state, label, count = batch["data"], batch["state"], batch["label"], batch["daily_count"]
index = batch["daily_index"] if self.use_daily_transport else batch["index"]
with torch.set_grad_enabled(not self.freeze_model):
hidden = self.model(data)
all_preds, choice, prob = self.tra(hidden, state)
if is_pretrain or self.transport_method != "none":
# NOTE: use oracle transport for pre-training
loss, pred, L, P = self.transport_fn(
all_preds,
label,
choice,
prob,
state.mean(dim=1),
count,
self.transport_method if not is_pretrain else "oracle",
self.alpha,
training=True,
)
data_set.assign_data(index, L) # save loss to memory
if self.use_daily_transport: # only save for daily transport
P_all.append(pd.DataFrame(P.detach().cpu().numpy(), index=index))
prob_all.append(pd.DataFrame(prob.detach().cpu().numpy(), index=index))
choice_all.append(pd.DataFrame(choice.detach().cpu().numpy(), index=index))
decay = self.rho ** (self.global_step // 100) # decay every 100 steps
lamb = 0 if is_pretrain else self.lamb * decay
reg = prob.log().mul(P).sum(dim=1).mean() # train router to predict OT assignment
if self._writer is not None and not is_pretrain:
self._writer.add_scalar("training/router_loss", -reg.item(), self.global_step)
self._writer.add_scalar("training/reg_loss", loss.item(), self.global_step)
self._writer.add_scalar("training/lamb", lamb, self.global_step)
if not self.use_daily_transport:
P_mean = P.mean(axis=0).detach()
self._writer.add_scalar("training/P", P_mean.max() / P_mean.min(), self.global_step)
loss = loss - lamb * reg
else:
pred = all_preds.mean(dim=1)
loss = loss_fn(pred, label)
(loss / self.update_freq).backward()
if cur_step % self.update_freq == 0:
self.optimizer.step()
self.optimizer.zero_grad()
if self._writer is not None and not is_pretrain:
self._writer.add_scalar("training/total_loss", loss.item(), self.global_step)
total_loss += loss.item()
total_count += 1
if self.use_daily_transport and len(P_all):
P_all = pd.concat(P_all, axis=0)
prob_all = pd.concat(prob_all, axis=0)
choice_all = pd.concat(choice_all, axis=0)
P_all.index = data_set.restore_daily_index(P_all.index)
prob_all.index = P_all.index
choice_all.index = P_all.index
if not is_pretrain:
self._writer.add_image("P", plot(P_all), epoch, dataformats="HWC")
self._writer.add_image("prob", plot(prob_all), epoch, dataformats="HWC")
self._writer.add_image("choice", plot(choice_all), epoch, dataformats="HWC")
total_loss /= total_count
if self._writer is not None and not is_pretrain:
self._writer.add_scalar("training/loss", total_loss, epoch)
return total_loss
def test_epoch(self, epoch, data_set, return_pred=False, prefix="test", is_pretrain=False):
self.model.eval()
self.tra.eval()
data_set.eval()
preds = []
probs = []
P_all = []
metrics = []
for batch in tqdm(data_set):
data, state, label, count = batch["data"], batch["state"], batch["label"], batch["daily_count"]
index = batch["daily_index"] if self.use_daily_transport else batch["index"]
with torch.no_grad():
hidden = self.model(data)
all_preds, choice, prob = self.tra(hidden, state)
if is_pretrain or self.transport_method != "none":
loss, pred, L, P = self.transport_fn(
all_preds,
label,
choice,
prob,
state.mean(dim=1),
count,
self.transport_method if not is_pretrain else "oracle",
self.alpha,
training=False,
)
data_set.assign_data(index, L) # save loss to memory
if P is not None and return_pred:
P_all.append(pd.DataFrame(P.cpu().numpy(), index=index))
else:
pred = all_preds.mean(dim=1)
X = np.c_[pred.cpu().numpy(), label.cpu().numpy(), all_preds.cpu().numpy()]
columns = ["score", "label"] + ["score_%d" % d for d in range(all_preds.shape[1])]
pred = pd.DataFrame(X, index=batch["index"], columns=columns)
metrics.append(evaluate(pred))
if return_pred:
preds.append(pred)
if prob is not None:
columns = ["prob_%d" % d for d in range(all_preds.shape[1])]
probs.append(pd.DataFrame(prob.cpu().numpy(), index=index, columns=columns))
metrics = pd.DataFrame(metrics)
metrics = {
"MSE": metrics.MSE.mean(),
"MAE": metrics.MAE.mean(),
"IC": metrics.IC.mean(),
"ICIR": metrics.IC.mean() / metrics.IC.std(),
}
if self._writer is not None and epoch >= 0 and not is_pretrain:
for key, value in metrics.items():
self._writer.add_scalar(prefix + "/" + key, value, epoch)
if return_pred:
preds = pd.concat(preds, axis=0)
preds.index = data_set.restore_index(preds.index)
preds.index = preds.index.swaplevel()
preds.sort_index(inplace=True)
if probs:
probs = pd.concat(probs, axis=0)
if self.use_daily_transport:
probs.index = data_set.restore_daily_index(probs.index)
else:
probs.index = data_set.restore_index(probs.index)
probs.index = probs.index.swaplevel()
probs.sort_index(inplace=True)
if len(P_all):
P_all = pd.concat(P_all, axis=0)
if self.use_daily_transport:
P_all.index = data_set.restore_daily_index(P_all.index)
else:
P_all.index = data_set.restore_index(P_all.index)
P_all.index = P_all.index.swaplevel()
P_all.sort_index(inplace=True)
return metrics, preds, probs, P_all
def _fit(self, train_set, valid_set, test_set, evals_result, is_pretrain=True):
best_score = -1
best_epoch = 0
stop_rounds = 0
best_params = {
"model": copy.deepcopy(self.model.state_dict()),
"tra": copy.deepcopy(self.tra.state_dict()),
}
# train
if not is_pretrain and self.transport_method != "none":
self.logger.info("init memory...")
self.test_epoch(-1, train_set)
for epoch in range(self.n_epochs):
self.logger.info("Epoch %d:", epoch)
self.logger.info("training...")
self.train_epoch(epoch, train_set, is_pretrain=is_pretrain)
self.logger.info("evaluating...")
# NOTE: during evaluating, the whole memory will be refreshed
if not is_pretrain and (self.transport_method == "router" or self.eval_train):
train_set.clear_memory() # NOTE: clear the shared memory
train_metrics = self.test_epoch(epoch, train_set, is_pretrain=is_pretrain, prefix="train")[0]
evals_result["train"].append(train_metrics)
self.logger.info("train metrics: %s" % train_metrics)
valid_metrics = self.test_epoch(epoch, valid_set, is_pretrain=is_pretrain, prefix="valid")[0]
evals_result["valid"].append(valid_metrics)
self.logger.info("valid metrics: %s" % valid_metrics)
if self.eval_test:
test_metrics = self.test_epoch(epoch, test_set, is_pretrain=is_pretrain, prefix="test")[0]
evals_result["test"].append(test_metrics)
self.logger.info("test metrics: %s" % test_metrics)
if valid_metrics["IC"] > best_score:
best_score = valid_metrics["IC"]
stop_rounds = 0
best_epoch = epoch
best_params = {
"model": copy.deepcopy(self.model.state_dict()),
"tra": copy.deepcopy(self.tra.state_dict()),
}
if self.logdir is not None:
torch.save(best_params, self.logdir + "/model.bin")
else:
stop_rounds += 1
if stop_rounds >= self.early_stop:
self.logger.info("early stop @ %s" % epoch)
break
self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
self.model.load_state_dict(best_params["model"])
self.tra.load_state_dict(best_params["tra"])
return best_score
def fit(self, dataset, evals_result=dict()):
assert isinstance(dataset, MTSDatasetH), "TRAModel only supports `qlib.contrib.data.dataset.MTSDatasetH`"
train_set, valid_set, test_set = dataset.prepare(["train", "valid", "test"])
self.fitted = True
self.global_step = -1
evals_result["train"] = []
evals_result["valid"] = []
evals_result["test"] = []
if self.pretrain:
self.logger.info("pretraining...")
self.optimizer = optim.Adam(
list(self.model.parameters()) + list(self.tra.predictors.parameters()), lr=self.lr
)
self._fit(train_set, valid_set, test_set, evals_result, is_pretrain=True)
# reset optimizer
self.optimizer = optim.Adam(list(self.model.parameters()) + list(self.tra.parameters()), lr=self.lr)
self.logger.info("training...")
best_score = self._fit(train_set, valid_set, test_set, evals_result, is_pretrain=False)
self.logger.info("inference")
train_metrics, train_preds, train_probs, train_P = self.test_epoch(-1, train_set, return_pred=True)
self.logger.info("train metrics: %s" % train_metrics)
valid_metrics, valid_preds, valid_probs, valid_P = self.test_epoch(-1, valid_set, return_pred=True)
self.logger.info("valid metrics: %s" % valid_metrics)
test_metrics, test_preds, test_probs, test_P = self.test_epoch(-1, test_set, return_pred=True)
self.logger.info("test metrics: %s" % test_metrics)
if self.logdir:
self.logger.info("save model & pred to local directory")
pd.concat({name: pd.DataFrame(evals_result[name]) for name in evals_result}, axis=1).to_csv(
self.logdir + "/logs.csv", index=False
)
torch.save({"model": self.model.state_dict(), "tra": self.tra.state_dict()}, self.logdir + "/model.bin")
train_preds.to_pickle(self.logdir + "/train_pred.pkl")
valid_preds.to_pickle(self.logdir + "/valid_pred.pkl")
test_preds.to_pickle(self.logdir + "/test_pred.pkl")
if len(train_probs):
train_probs.to_pickle(self.logdir + "/train_prob.pkl")
valid_probs.to_pickle(self.logdir + "/valid_prob.pkl")
test_probs.to_pickle(self.logdir + "/test_prob.pkl")
if len(train_P):
train_P.to_pickle(self.logdir + "/train_P.pkl")
valid_P.to_pickle(self.logdir + "/valid_P.pkl")
test_P.to_pickle(self.logdir + "/test_P.pkl")
info = {
"config": {
"model_config": self.model_config,
"tra_config": self.tra_config,
"model_type": self.model_type,
"lr": self.lr,
"n_epochs": self.n_epochs,
"early_stop": self.early_stop,
"max_steps_per_epoch": self.max_steps_per_epoch,
"lamb": self.lamb,
"rho": self.rho,
"alpha": self.alpha,
"seed": self.seed,
"logdir": self.logdir,
"pretrain": self.pretrain,
"init_state": self.init_state,
"transport_method": self.transport_method,
"use_daily_transport": self.use_daily_transport,
},
"best_eval_metric": -best_score, # NOTE: -1 for minimize
"metrics": {"train": train_metrics, "valid": valid_metrics, "test": test_metrics},
}
with open(self.logdir + "/info.json", "w") as f:
json.dump(info, f)
def predict(self, dataset, segment="test"):
assert isinstance(dataset, MTSDatasetH), "TRAModel only supports `qlib.contrib.data.dataset.MTSDatasetH`"
if not self.fitted:
raise ValueError("model is not fitted yet!")
test_set = dataset.prepare(segment)
metrics, preds, _, _ = self.test_epoch(-1, test_set, return_pred=True)
self.logger.info("test metrics: %s" % metrics)
return preds
class RNN(nn.Module):
"""RNN Model
Args:
input_size (int): input size (# features)
hidden_size (int): hidden size
num_layers (int): number of hidden layers
rnn_arch (str): rnn architecture
use_attn (bool): whether use attention layer.
we use concat attention as https://github.com/fulifeng/Adv-ALSTM/
dropout (float): dropout rate
"""
def __init__(
self,
input_size=16,
hidden_size=64,
num_layers=2,
rnn_arch="GRU",
use_attn=True,
dropout=0.0,
**kwargs,
):
super().__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.num_layers = num_layers
self.rnn_arch = rnn_arch
self.use_attn = use_attn
if hidden_size < input_size:
# compression
self.input_proj = nn.Linear(input_size, hidden_size)
else:
self.input_proj = None
self.rnn = getattr(nn, rnn_arch)(
input_size=min(input_size, hidden_size),
hidden_size=hidden_size,
num_layers=num_layers,
batch_first=True,
dropout=dropout,
)
if self.use_attn:
self.W = nn.Linear(hidden_size, hidden_size)
self.u = nn.Linear(hidden_size, 1, bias=False)
self.output_size = hidden_size * 2
else:
self.output_size = hidden_size
def forward(self, x):
if self.input_proj is not None:
x = self.input_proj(x)
rnn_out, last_out = self.rnn(x)
if self.rnn_arch == "LSTM":
last_out = last_out[0]
last_out = last_out.mean(dim=0)
if self.use_attn:
laten = self.W(rnn_out).tanh()
scores = self.u(laten).softmax(dim=1)
att_out = (rnn_out * scores).sum(dim=1)
last_out = torch.cat([last_out, att_out], dim=1)
return last_out
class PositionalEncoding(nn.Module):
# reference: https://pytorch.org/tutorials/beginner/transformer_tutorial.html
def __init__(self, d_model, dropout=0.1, max_len=5000):
super(PositionalEncoding, self).__init__()
self.dropout = nn.Dropout(p=dropout)
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0).transpose(0, 1)
self.register_buffer("pe", pe)
def forward(self, x):
x = x + self.pe[: x.size(0), :]
return self.dropout(x)
class Transformer(nn.Module):
"""Transformer Model
Args:
input_size (int): input size (# features)
hidden_size (int): hidden size
num_layers (int): number of transformer layers
num_heads (int): number of heads in transformer
dropout (float): dropout rate
"""
def __init__(
self,
input_size=16,
hidden_size=64,
num_layers=2,
num_heads=2,
dropout=0.0,
**kwargs,
):
super().__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.num_layers = num_layers
self.num_heads = num_heads
self.input_proj = nn.Linear(input_size, hidden_size)
self.pe = PositionalEncoding(input_size, dropout)
layer = nn.TransformerEncoderLayer(
nhead=num_heads, dropout=dropout, d_model=hidden_size, dim_feedforward=hidden_size * 4
)
self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)
self.output_size = hidden_size
def forward(self, x):
x = x.permute(1, 0, 2).contiguous() # the first dim need to be time
x = self.pe(x)
x = self.input_proj(x)
out = self.encoder(x)
return out[-1]
class TRA(nn.Module):
"""Temporal Routing Adaptor (TRA)
TRA takes historical prediction erros & latent representation as inputs,
then routes the input sample to a specific predictor for training & inference.
Args:
input_size (int): input size (RNN/Transformer's hidden size)
num_states (int): number of latent states (i.e., trading patterns)
If `num_states=1`, then TRA falls back to traditional methods
hidden_size (int): hidden size of the router
tau (float): gumbel softmax temperature
src_info (str): information for the router
"""
def __init__(
self,
input_size,
num_states=1,
hidden_size=8,
rnn_arch="GRU",
num_layers=1,
dropout=0.0,
tau=1.0,
src_info="LR_TPE",
):
super().__init__()
assert src_info in ["LR", "TPE", "LR_TPE"], "invalid `src_info`"
self.num_states = num_states
self.tau = tau
self.rnn_arch = rnn_arch
self.src_info = src_info
self.predictors = nn.Linear(input_size, num_states)
if self.num_states > 1:
if "TPE" in src_info:
self.router = getattr(nn, rnn_arch)(
input_size=num_states,
hidden_size=hidden_size,
num_layers=num_layers,
batch_first=True,
dropout=dropout,
)
self.fc = nn.Linear(hidden_size + input_size if "LR" in src_info else hidden_size, num_states)
else:
self.fc = nn.Linear(input_size, num_states)
def reset_parameters(self):
for child in self.children():
child.reset_parameters()
def forward(self, hidden, hist_loss):
preds = self.predictors(hidden)
if self.num_states == 1: # no need for router when having only one prediction
return preds, None, None
if "TPE" in self.src_info:
out = self.router(hist_loss)[1] # TPE
if self.rnn_arch == "LSTM":
out = out[0]
out = out.mean(dim=0)
if "LR" in self.src_info:
out = torch.cat([hidden, out], dim=-1) # LR_TPE
else:
out = hidden # LR
out = self.fc(out)
choice = F.gumbel_softmax(out, dim=-1, tau=self.tau, hard=True)
prob = torch.softmax(out / self.tau, dim=-1)
return preds, choice, prob
def evaluate(pred):
pred = pred.rank(pct=True) # transform into percentiles
score = pred.score
label = pred.label
diff = score - label
MSE = (diff ** 2).mean()
MAE = (diff.abs()).mean()
IC = score.corr(label, method="spearman")
return {"MSE": MSE, "MAE": MAE, "IC": IC}
def shoot_infs(inp_tensor):
"""Replaces inf by maximum of tensor"""
mask_inf = torch.isinf(inp_tensor)
ind_inf = torch.nonzero(mask_inf, as_tuple=False)
if len(ind_inf) > 0:
for ind in ind_inf:
if len(ind) == 2:
inp_tensor[ind[0], ind[1]] = 0
elif len(ind) == 1:
inp_tensor[ind[0]] = 0
m = torch.max(inp_tensor)
for ind in ind_inf:
if len(ind) == 2:
inp_tensor[ind[0], ind[1]] = m
elif len(ind) == 1:
inp_tensor[ind[0]] = m
return inp_tensor
def sinkhorn(Q, n_iters=3, epsilon=0.1):
# epsilon should be adjusted according to logits value's scale
with torch.no_grad():
Q = torch.exp(Q / epsilon)
Q = shoot_infs(Q)
for i in range(n_iters):
Q /= Q.sum(dim=0, keepdim=True)
Q /= Q.sum(dim=1, keepdim=True)
return Q
def loss_fn(pred, label):
mask = ~torch.isnan(label)
if len(pred.shape) == 2:
label = label[:, None]
return (pred[mask] - label[mask]).pow(2).mean(dim=0)
def minmax_norm(x):
xmin = x.min(dim=-1, keepdim=True).values
xmax = x.max(dim=-1, keepdim=True).values
mask = (xmin == xmax).squeeze()
x = (x - xmin) / (xmax - xmin + 1e-12)
x[mask] = 1
return x
def transport_sample(all_preds, label, choice, prob, hist_loss, count, transport_method, alpha, training=False):
"""
sample-wise transport
Args:
all_preds (torch.Tensor): predictions from all predictors, [sample x states]
label (torch.Tensor): label, [sample]
choice (torch.Tensor): gumbel softmax choice, [sample x states]
prob (torch.Tensor): router predicted probility, [sample x states]
hist_loss (torch.Tensor): history loss matrix, [sample x states]
count (list): sample counts for each day, empty list for sample-wise transport
transport_method (str): transportation method
alpha (float): fusion parameter for calculating transport loss matrix
training (bool): indicate training or inference
"""
assert all_preds.shape == choice.shape
assert len(all_preds) == len(label)
assert transport_method in ["oracle", "router"]
all_loss = torch.zeros_like(all_preds)
mask = ~torch.isnan(label)
all_loss[mask] = (all_preds[mask] - label[mask, None]).pow(2) # [sample x states]
L = minmax_norm(all_loss.detach())
Lh = L * alpha + minmax_norm(hist_loss) * (1 - alpha) # add hist loss for transport
Lh = minmax_norm(Lh)
P = sinkhorn(-Lh)
del Lh
if transport_method == "router":
if training:
pred = (all_preds * choice).sum(dim=1) # gumbel softmax
else:
pred = all_preds[range(len(all_preds)), prob.argmax(dim=-1)] # argmax
else:
pred = (all_preds * P).sum(dim=1)
if transport_method == "router":
loss = loss_fn(pred, label)
else:
loss = (all_loss * P).sum(dim=1).mean()
return loss, pred, L, P
def transport_daily(all_preds, label, choice, prob, hist_loss, count, transport_method, alpha, training=False):
"""
daily transport
Args:
all_preds (torch.Tensor): predictions from all predictors, [sample x states]
label (torch.Tensor): label, [sample]
choice (torch.Tensor): gumbel softmax choice, [days x states]
prob (torch.Tensor): router predicted probility, [days x states]
hist_loss (torch.Tensor): history loss matrix, [days x states]
count (list): sample counts for each day, [days]
transport_method (str): transportation method
alpha (float): fusion parameter for calculating transport loss matrix
training (bool): indicate training or inference
"""
assert len(prob) == len(count)
assert len(all_preds) == sum(count)
assert transport_method in ["oracle", "router"]
all_loss = [] # loss of all predictions
start = 0
for i, cnt in enumerate(count):
slc = slice(start, start + cnt) # samples from the i-th day
start += cnt
tloss = loss_fn(all_preds[slc], label[slc]) # loss of the i-th day
all_loss.append(tloss)
all_loss = torch.stack(all_loss, dim=0) # [days x states]
L = minmax_norm(all_loss.detach())
Lh = L * alpha + minmax_norm(hist_loss) * (1 - alpha) # add hist loss for transport
Lh = minmax_norm(Lh)
P = sinkhorn(-Lh)
del Lh
pred = []
start = 0
for i, cnt in enumerate(count):
slc = slice(start, start + cnt) # samples from the i-th day
start += cnt
if transport_method == "router":
if training:
tpred = all_preds[slc] @ choice[i] # gumbel softmax
else:
tpred = all_preds[slc][:, prob[i].argmax(dim=-1)] # argmax
else:
tpred = all_preds[slc] @ P[i]
pred.append(tpred)
pred = torch.cat(pred, dim=0) # [samples]
if transport_method == "router":
loss = loss_fn(pred, label)
else:
loss = (all_loss * P).sum(dim=1).mean()
return loss, pred, L, P
def load_state_dict_unsafe(model, state_dict):
"""
Load state dict to provided model while ignore exceptions.
"""
missing_keys = []
unexpected_keys = []
error_msgs = []
# copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, "_metadata", None)
state_dict = state_dict.copy()
if metadata is not None:
state_dict._metadata = metadata
def load(module, prefix=""):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
module._load_from_state_dict(
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs
)
for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + ".")
load(model)
load = None # break load->load reference cycle
return {"unexpected_keys": unexpected_keys, "missing_keys": missing_keys, "error_msgs": error_msgs}
def plot(P):
assert isinstance(P, pd.DataFrame)
fig, axes = plt.subplots(1, 2, figsize=(10, 4))
P.plot.area(ax=axes[0], xlabel="")
P.idxmax(axis=1).value_counts().sort_index().plot.bar(ax=axes[1], xlabel="")
plt.tight_layout()
with io.BytesIO() as buf:
plt.savefig(buf, format="png")
buf.seek(0)
img = plt.imread(buf)
plt.close()
return np.uint8(img * 255)

View File

@@ -0,0 +1,294 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import pandas as pd
from typing import Text, Union
import copy
import math
from ...utils import get_or_create_path
from ...log import get_module_logger
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from .pytorch_utils import count_parameters
from ...model.base import Model
from ...data.dataset import DatasetH, TSDatasetH
from ...data.dataset.handler import DataHandlerLP
# qrun examples/benchmarks/Transformer/workflow_config_transformer_Alpha360.yaml ”
class TransformerModel(Model):
def __init__(
self,
d_feat: int = 20,
d_model: int = 64,
batch_size: int = 2048,
nhead: int = 2,
num_layers: int = 2,
dropout: float = 0,
n_epochs=100,
lr=0.0001,
metric="",
early_stop=5,
loss="mse",
optimizer="adam",
reg=1e-3,
n_jobs=10,
GPU=0,
seed=None,
**kwargs
):
# set hyper-parameters.
self.d_model = d_model
self.dropout = dropout
self.n_epochs = n_epochs
self.lr = lr
self.reg = reg
self.metric = metric
self.batch_size = batch_size
self.early_stop = early_stop
self.optimizer = optimizer.lower()
self.loss = loss
self.n_jobs = n_jobs
self.device = torch.device("cuda:%d" % GPU if torch.cuda.is_available() and GPU >= 0 else "cpu")
self.seed = seed
self.logger = get_module_logger("TransformerModel")
self.logger.info("Naive Transformer:" "\nbatch_size : {}" "\ndevice : {}".format(self.batch_size, self.device))
if self.seed is not None:
np.random.seed(self.seed)
torch.manual_seed(self.seed)
self.model = Transformer(d_feat, d_model, nhead, num_layers, dropout, self.device)
if optimizer.lower() == "adam":
self.train_optimizer = optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=self.reg)
elif optimizer.lower() == "gd":
self.train_optimizer = optim.SGD(self.model.parameters(), lr=self.lr, weight_decay=self.reg)
else:
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
self.fitted = False
self.model.to(self.device)
@property
def use_gpu(self):
return self.device != torch.device("cpu")
def mse(self, pred, label):
loss = (pred.float() - label.float()) ** 2
return torch.mean(loss)
def loss_fn(self, pred, label):
mask = ~torch.isnan(label)
if self.loss == "mse":
return self.mse(pred[mask], label[mask])
raise ValueError("unknown loss `%s`" % self.loss)
def metric_fn(self, pred, label):
mask = torch.isfinite(label)
if self.metric == "" or self.metric == "loss":
return -self.loss_fn(pred[mask], label[mask])
raise ValueError("unknown metric `%s`" % self.metric)
def train_epoch(self, x_train, y_train):
x_train_values = x_train.values
y_train_values = np.squeeze(y_train.values)
self.model.train()
indices = np.arange(len(x_train_values))
np.random.shuffle(indices)
for i in range(len(indices))[:: self.batch_size]:
if len(indices) - i < self.batch_size:
break
feature = torch.from_numpy(x_train_values[indices[i : i + self.batch_size]]).float().to(self.device)
label = torch.from_numpy(y_train_values[indices[i : i + self.batch_size]]).float().to(self.device)
pred = self.model(feature)
loss = self.loss_fn(pred, label)
self.train_optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_value_(self.model.parameters(), 3.0)
self.train_optimizer.step()
def test_epoch(self, data_x, data_y):
# prepare training data
x_values = data_x.values
y_values = np.squeeze(data_y.values)
self.model.eval()
scores = []
losses = []
indices = np.arange(len(x_values))
for i in range(len(indices))[:: self.batch_size]:
if len(indices) - i < self.batch_size:
break
feature = torch.from_numpy(x_values[indices[i : i + self.batch_size]]).float().to(self.device)
label = torch.from_numpy(y_values[indices[i : i + self.batch_size]]).float().to(self.device)
with torch.no_grad():
pred = self.model(feature)
loss = self.loss_fn(pred, label)
losses.append(loss.item())
score = self.metric_fn(pred, label)
scores.append(score.item())
return np.mean(losses), np.mean(scores)
def fit(
self,
dataset: DatasetH,
evals_result=dict(),
save_path=None,
):
df_train, df_valid, df_test = dataset.prepare(
["train", "valid", "test"],
col_set=["feature", "label"],
data_key=DataHandlerLP.DK_L,
)
x_train, y_train = df_train["feature"], df_train["label"]
x_valid, y_valid = df_valid["feature"], df_valid["label"]
save_path = get_or_create_path(save_path)
stop_steps = 0
train_loss = 0
best_score = -np.inf
best_epoch = 0
evals_result["train"] = []
evals_result["valid"] = []
# train
self.logger.info("training...")
self.fitted = True
for step in range(self.n_epochs):
self.logger.info("Epoch%d:", step)
self.logger.info("training...")
self.train_epoch(x_train, y_train)
self.logger.info("evaluating...")
train_loss, train_score = self.test_epoch(x_train, y_train)
val_loss, val_score = self.test_epoch(x_valid, y_valid)
self.logger.info("train %.6f, valid %.6f" % (train_score, val_score))
evals_result["train"].append(train_score)
evals_result["valid"].append(val_score)
if val_score > best_score:
best_score = val_score
stop_steps = 0
best_epoch = step
best_param = copy.deepcopy(self.model.state_dict())
else:
stop_steps += 1
if stop_steps >= self.early_stop:
self.logger.info("early stop")
break
self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
self.model.load_state_dict(best_param)
torch.save(best_param, save_path)
if self.use_gpu:
torch.cuda.empty_cache()
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
if not self.fitted:
raise ValueError("model is not fitted yet!")
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
index = x_test.index
self.model.eval()
x_values = x_test.values
sample_num = x_values.shape[0]
preds = []
for begin in range(sample_num)[:: self.batch_size]:
if sample_num - begin < self.batch_size:
end = sample_num
else:
end = begin + self.batch_size
x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)
with torch.no_grad():
pred = self.model(x_batch).detach().cpu().numpy()
preds.append(pred)
return pd.Series(np.concatenate(preds), index=index)
class PositionalEncoding(nn.Module):
def __init__(self, d_model, max_len=1000):
super(PositionalEncoding, self).__init__()
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0).transpose(0, 1)
self.register_buffer("pe", pe)
def forward(self, x):
# [T, N, F]
return x + self.pe[: x.size(0), :]
class Transformer(nn.Module):
def __init__(self, d_feat=6, d_model=8, nhead=4, num_layers=2, dropout=0.5, device=None):
super(Transformer, self).__init__()
self.feature_layer = nn.Linear(d_feat, d_model)
self.pos_encoder = PositionalEncoding(d_model)
self.encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dropout=dropout)
self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers)
self.decoder_layer = nn.Linear(d_model, 1)
self.device = device
self.d_feat = d_feat
def forward(self, src):
# src [N, F*T] --> [N, T, F]
src = src.reshape(len(src), self.d_feat, -1).permute(0, 2, 1)
src = self.feature_layer(src)
# src [N, T, F] --> [T, N, F], [60, 512, 8]
src = src.transpose(1, 0) # not batch first
mask = None
src = self.pos_encoder(src)
output = self.transformer_encoder(src, mask) # [60, 512, 8]
# [T, N, F] --> [N, T*F]
output = self.decoder_layer(output.transpose(1, 0)[:, -1, :]) # [512, 1]
return output.squeeze()

View File

@@ -0,0 +1,269 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import pandas as pd
import copy
import math
from ...utils import get_or_create_path
from ...log import get_module_logger
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from .pytorch_utils import count_parameters
from ...model.base import Model
from ...data.dataset import DatasetH, TSDatasetH
from ...data.dataset.handler import DataHandlerLP
class TransformerModel(Model):
def __init__(
self,
d_feat: int = 20,
d_model: int = 64,
batch_size: int = 8192,
nhead: int = 2,
num_layers: int = 2,
dropout: float = 0,
n_epochs=100,
lr=0.0001,
metric="",
early_stop=5,
loss="mse",
optimizer="adam",
reg=1e-3,
n_jobs=10,
GPU=0,
seed=None,
**kwargs
):
# set hyper-parameters.
self.d_model = d_model
self.dropout = dropout
self.n_epochs = n_epochs
self.lr = lr
self.reg = reg
self.metric = metric
self.batch_size = batch_size
self.early_stop = early_stop
self.optimizer = optimizer.lower()
self.loss = loss
self.n_jobs = n_jobs
self.device = torch.device("cuda:%d" % GPU if torch.cuda.is_available() and GPU >= 0 else "cpu")
self.seed = seed
self.logger = get_module_logger("TransformerModel")
self.logger.info("Naive Transformer:" "\nbatch_size : {}" "\ndevice : {}".format(self.batch_size, self.device))
if self.seed is not None:
np.random.seed(self.seed)
torch.manual_seed(self.seed)
self.model = Transformer(d_feat, d_model, nhead, num_layers, dropout, self.device)
if optimizer.lower() == "adam":
self.train_optimizer = optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=self.reg)
elif optimizer.lower() == "gd":
self.train_optimizer = optim.SGD(self.model.parameters(), lr=self.lr, weight_decay=self.reg)
else:
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
self.fitted = False
self.model.to(self.device)
@property
def use_gpu(self):
return self.device != torch.device("cpu")
def mse(self, pred, label):
loss = (pred.float() - label.float()) ** 2
return torch.mean(loss)
def loss_fn(self, pred, label):
mask = ~torch.isnan(label)
if self.loss == "mse":
return self.mse(pred[mask], label[mask])
raise ValueError("unknown loss `%s`" % self.loss)
def metric_fn(self, pred, label):
mask = torch.isfinite(label)
if self.metric == "" or self.metric == "loss":
return -self.loss_fn(pred[mask], label[mask])
raise ValueError("unknown metric `%s`" % self.metric)
def train_epoch(self, data_loader):
self.model.train()
for data in data_loader:
feature = data[:, :, 0:-1].to(self.device)
label = data[:, -1, -1].to(self.device)
pred = self.model(feature.float()) # .float()
loss = self.loss_fn(pred, label)
self.train_optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_value_(self.model.parameters(), 3.0)
self.train_optimizer.step()
def test_epoch(self, data_loader):
self.model.eval()
scores = []
losses = []
for data in data_loader:
feature = data[:, :, 0:-1].to(self.device)
label = data[:, -1, -1].to(self.device)
with torch.no_grad():
pred = self.model(feature.float()) # .float()
loss = self.loss_fn(pred, label)
losses.append(loss.item())
score = self.metric_fn(pred, label)
scores.append(score.item())
return np.mean(losses), np.mean(scores)
def fit(
self,
dataset: DatasetH,
evals_result=dict(),
save_path=None,
):
dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
dl_valid = dataset.prepare("valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader
dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader
train_loader = DataLoader(
dl_train, batch_size=self.batch_size, shuffle=True, num_workers=self.n_jobs, drop_last=True
)
valid_loader = DataLoader(
dl_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.n_jobs, drop_last=True
)
save_path = get_or_create_path(save_path)
stop_steps = 0
train_loss = 0
best_score = -np.inf
best_epoch = 0
evals_result["train"] = []
evals_result["valid"] = []
# train
self.logger.info("training...")
self.fitted = True
for step in range(self.n_epochs):
self.logger.info("Epoch%d:", step)
self.logger.info("training...")
self.train_epoch(train_loader)
self.logger.info("evaluating...")
train_loss, train_score = self.test_epoch(train_loader)
val_loss, val_score = self.test_epoch(valid_loader)
self.logger.info("train %.6f, valid %.6f" % (train_score, val_score))
evals_result["train"].append(train_score)
evals_result["valid"].append(val_score)
if val_score > best_score:
best_score = val_score
stop_steps = 0
best_epoch = step
best_param = copy.deepcopy(self.model.state_dict())
else:
stop_steps += 1
if stop_steps >= self.early_stop:
self.logger.info("early stop")
break
self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
self.model.load_state_dict(best_param)
torch.save(best_param, save_path)
if self.use_gpu:
torch.cuda.empty_cache()
def predict(self, dataset):
if not self.fitted:
raise ValueError("model is not fitted yet!")
dl_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)
dl_test.config(fillna_type="ffill+bfill")
test_loader = DataLoader(dl_test, batch_size=self.batch_size, num_workers=self.n_jobs)
self.model.eval()
preds = []
for data in test_loader:
feature = data[:, :, 0:-1].to(self.device)
with torch.no_grad():
pred = self.model(feature.float()).detach().cpu().numpy()
preds.append(pred)
return pd.Series(np.concatenate(preds), index=dl_test.get_index())
class PositionalEncoding(nn.Module):
def __init__(self, d_model, max_len=1000):
super(PositionalEncoding, self).__init__()
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0).transpose(0, 1)
self.register_buffer("pe", pe)
def forward(self, x):
# [T, N, F]
return x + self.pe[: x.size(0), :]
class Transformer(nn.Module):
def __init__(self, d_feat=6, d_model=8, nhead=4, num_layers=2, dropout=0.5, device=None):
super(Transformer, self).__init__()
self.feature_layer = nn.Linear(d_feat, d_model)
self.pos_encoder = PositionalEncoding(d_model)
self.encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dropout=dropout)
self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers)
self.decoder_layer = nn.Linear(d_model, 1)
self.device = device
self.d_feat = d_feat
def forward(self, src):
# src [N, T, F], [512, 60, 6]
src = self.feature_layer(src) # [512, 60, 8]
# src [N, T, F] --> [T, N, F], [60, 512, 8]
src = src.transpose(1, 0) # not batch first
mask = None
src = self.pos_encoder(src)
output = self.transformer_encoder(src, mask) # [60, 512, 8]
# [T, N, F] --> [N, T*F]
output = self.decoder_layer(output.transpose(1, 0)[:, -1, :]) # [512, 1]
return output.squeeze()

View File

@@ -221,9 +221,9 @@ class TopkDropoutStrategy(BaseStrategy, ListAdjustTimer):
only_tradable : bool
will the strategy only consider the tradable stock when buying and selling.
if only_tradable:
strategy will make buy sell decision without checking the tradable state of the stock.
the strategy will peek at the information in the short future to avoid untradable stocks (untradable stocks include stocks that meet suspension, or hit limit up or limit down).
else:
strategy will make decision with the tradable state of the stock info and avoid buy and sell them.
the strategy will generate orders without peeking any information in the future, so the order generated by the strategies may fail.
"""
super(TopkDropoutStrategy, self).__init__()
ListAdjustTimer.__init__(self, kwargs.get("adjust_dates", None))

View File

@@ -196,9 +196,9 @@ class Feature(Expression):
def __init__(self, name=None):
if name:
self._name = name.lower()
self._name = name
else:
self._name = type(self).__name__.lower()
self._name = type(self).__name__
def __str__(self):
return "$" + self._name

View File

@@ -17,6 +17,7 @@ import abc
from pathlib import Path
import numpy as np
import pandas as pd
from typing import Union, Iterable
from collections import OrderedDict
from ..config import C
@@ -216,12 +217,14 @@ class CacheUtils:
redis_lock.reset_all(r)
@staticmethod
def visit(cache_path):
def visit(cache_path: Union[str, Path]):
# FIXME: Because read_lock was canceled when reading the cache, multiple processes may have read and write exceptions here
try:
with open(cache_path + ".meta", "rb") as f:
cache_path = Path(cache_path)
meta_path = cache_path.with_suffix(".meta")
with meta_path.open("rb") as f:
d = pickle.load(f)
with open(cache_path + ".meta", "wb") as f:
with meta_path.open("wb") as f:
try:
d["meta"]["last_visit"] = str(time.time())
d["meta"]["visits"] = d["meta"]["visits"] + 1
@@ -249,17 +252,17 @@ class CacheUtils:
@staticmethod
@contextlib.contextmanager
def reader_lock(redis_t, lock_name):
lock_name = f"{C.provider_uri}:{lock_name}"
current_cache_rlock = redis_lock.Lock(redis_t, "%s-rlock" % lock_name)
current_cache_wlock = redis_lock.Lock(redis_t, "%s-wlock" % lock_name)
def reader_lock(redis_t, lock_name: str):
current_cache_rlock = redis_lock.Lock(redis_t, f"{lock_name}-rlock")
current_cache_wlock = redis_lock.Lock(redis_t, f"{lock_name}-wlock")
lock_reader = f"{lock_name}-reader"
# make sure only one reader is entering
current_cache_rlock.acquire(timeout=60)
try:
current_cache_readers = redis_t.get("%s-reader" % lock_name)
current_cache_readers = redis_t.get(lock_reader)
if current_cache_readers is None or int(current_cache_readers) == 0:
CacheUtils.acquire(current_cache_wlock, lock_name)
redis_t.incr("%s-reader" % lock_name)
redis_t.incr(lock_reader)
finally:
current_cache_rlock.release()
try:
@@ -268,9 +271,9 @@ class CacheUtils:
# make sure only one reader is leaving
current_cache_rlock.acquire(timeout=60)
try:
redis_t.decr("%s-reader" % lock_name)
if int(redis_t.get("%s-reader" % lock_name)) == 0:
redis_t.delete("%s-reader" % lock_name)
redis_t.decr(lock_reader)
if int(redis_t.get(lock_reader)) == 0:
redis_t.delete(lock_reader)
current_cache_wlock.reset()
finally:
current_cache_rlock.release()
@@ -278,8 +281,7 @@ class CacheUtils:
@staticmethod
@contextlib.contextmanager
def writer_lock(redis_t, lock_name):
lock_name = f"{C.provider_uri}:{lock_name}"
current_cache_wlock = redis_lock.Lock(redis_t, "%s-wlock" % lock_name, id=CacheUtils.LOCK_ID)
current_cache_wlock = redis_lock.Lock(redis_t, f"{lock_name}-wlock", id=CacheUtils.LOCK_ID)
CacheUtils.acquire(current_cache_wlock, lock_name)
try:
yield
@@ -297,6 +299,30 @@ class BaseProviderCache:
def __getattr__(self, attr):
return getattr(self.provider, attr)
@staticmethod
def check_cache_exists(cache_path: Union[str, Path], suffix_list: Iterable = (".index", ".meta")) -> bool:
cache_path = Path(cache_path)
for p in [cache_path] + [cache_path.with_suffix(_s) for _s in suffix_list]:
if not p.exists():
return False
return True
@staticmethod
def clear_cache(cache_path: Union[str, Path]):
for p in [
cache_path,
cache_path.with_suffix(".meta"),
cache_path.with_suffix(".index"),
]:
if p.exists():
p.unlink()
@staticmethod
def get_cache_dir(dir_name: str, freq: str = None) -> Path:
cache_dir = Path(C.dpm.get_data_path(freq)).joinpath(dir_name)
cache_dir.mkdir(parents=True, exist_ok=True)
return cache_dir
class ExpressionCache(BaseProviderCache):
"""Expression cache mechanism base class.
@@ -330,15 +356,16 @@ class ExpressionCache(BaseProviderCache):
"""
raise NotImplementedError("Implement this method if you want to use expression cache")
def update(self, cache_uri):
def update(self, cache_uri: Union[str, Path], freq: str = "day"):
"""Update expression cache to latest calendar.
Overide this method to define how to update expression cache corresponding to users' own cache mechanism.
Parameters
----------
cache_uri : str
cache_uri : str or Path
the complete uri of expression cache file (include dir path).
freq : str
Returns
-------
@@ -358,7 +385,9 @@ class DatasetCache(BaseProviderCache):
HDF_KEY = "df"
def dataset(self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1):
def dataset(
self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1, inst_processors=[]
):
"""Get feature dataset.
.. note:: Same interface as `dataset` method in dataset provider
@@ -369,13 +398,19 @@ class DatasetCache(BaseProviderCache):
"""
if disk_cache == 0:
# skip cache
return self.provider.dataset(instruments, fields, start_time, end_time, freq)
return self.provider.dataset(
instruments, fields, start_time, end_time, freq, inst_processors=inst_processors
)
else:
# use and replace cache
try:
return self._dataset(instruments, fields, start_time, end_time, freq, disk_cache)
return self._dataset(
instruments, fields, start_time, end_time, freq, disk_cache, inst_processors=inst_processors
)
except NotImplementedError:
return self.provider.dataset(instruments, fields, start_time, end_time, freq)
return self.provider.dataset(
instruments, fields, start_time, end_time, freq, inst_processors=inst_processors
)
def _uri(self, instruments, fields, start_time, end_time, freq, **kwargs):
"""Get dataset cache file uri.
@@ -384,14 +419,18 @@ class DatasetCache(BaseProviderCache):
"""
raise NotImplementedError("Implement this function to match your own cache mechanism")
def _dataset(self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1):
def _dataset(
self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1, inst_processors=[]
):
"""Get feature dataset using cache.
Override this method to define how to get feature dataset corresponding to users' own cache mechanism.
"""
raise NotImplementedError("Implement this method if you want to use dataset feature cache")
def _dataset_uri(self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1):
def _dataset_uri(
self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1, inst_processors=[]
):
"""Get a uri of feature dataset using cache.
specially:
disk_cache=1 means using data set cache and return the uri of cache file.
@@ -403,15 +442,16 @@ class DatasetCache(BaseProviderCache):
"Implement this method if you want to use dataset feature cache as a cache file for client"
)
def update(self, cache_uri):
def update(self, cache_uri: Union[str, Path], freq: str = "day"):
"""Update dataset cache to latest calendar.
Overide this method to define how to update dataset cache corresponding to users' own cache mechanism.
Parameters
----------
cache_uri : str
cache_uri : str or Path
the complete uri of dataset cache file (include dir path).
freq : str
Returns
-------
@@ -452,25 +492,19 @@ class DiskExpressionCache(ExpressionCache):
self.r = get_redis_connection()
# remote==True means client is using this module, writing behaviour will not be allowed.
self.remote = kwargs.get("remote", False)
self.expr_cache_path = os.path.join(C.get_data_path(), C.features_cache_dir_name)
os.makedirs(self.expr_cache_path, exist_ok=True)
def get_cache_dir(self, freq: str = None) -> Path:
return super(DiskExpressionCache, self).get_cache_dir(C.features_cache_dir_name, freq)
def _uri(self, instrument, field, start_time, end_time, freq):
field = remove_fields_space(field)
instrument = str(instrument).lower()
return hash_args(instrument, field, freq)
@staticmethod
def check_cache_exists(cache_path):
for p in [cache_path, cache_path + ".meta"]:
if not Path(p).exists():
return False
return True
def _expression(self, instrument, field, start_time=None, end_time=None, freq="day"):
_cache_uri = self._uri(instrument=instrument, field=field, start_time=None, end_time=None, freq=freq)
_instrument_dir = os.path.join(self.expr_cache_path, instrument.lower())
cache_path = os.path.join(_instrument_dir, _cache_uri)
_instrument_dir = self.get_cache_dir(freq).joinpath(instrument.lower())
cache_path = _instrument_dir.joinpath(_cache_uri)
# get calendar
from .data import Cal
@@ -478,7 +512,7 @@ class DiskExpressionCache(ExpressionCache):
_, _, start_index, end_index = Cal.locate_index(start_time, end_time, freq, future=False)
if self.check_cache_exists(cache_path):
if self.check_cache_exists(cache_path, suffix_list=[".meta"]):
"""
In most cases, we do not need reader_lock.
Because updating data is a small probability event compare to reading data.
@@ -502,8 +536,7 @@ class DiskExpressionCache(ExpressionCache):
# normalize field
field = remove_fields_space(field)
# cache unavailable, generate the cache
if not os.path.exists(_instrument_dir):
os.makedirs(_instrument_dir, exist_ok=True)
_instrument_dir.mkdir(parents=True, exist_ok=True)
if not isinstance(eval(parse_field(field)), Feature):
# When the expression is not a raw feature
# generate expression cache if the feature is not a Feature
@@ -511,7 +544,7 @@ class DiskExpressionCache(ExpressionCache):
series = self.provider.expression(instrument, field, _calendar[0], _calendar[-1], freq)
if not series.empty:
# This expresion is empty, we don't generate any cache for it.
with CacheUtils.writer_lock(self.r, "expression-%s" % _cache_uri):
with CacheUtils.writer_lock(self.r, f"{str(C.dpm.get_data_path(freq))}:expression-{_cache_uri}"):
self.gen_expression_cache(
expression_data=series,
cache_path=cache_path,
@@ -527,14 +560,6 @@ class DiskExpressionCache(ExpressionCache):
# If the expression is a raw feature(such as $close, $open)
return self.provider.expression(instrument, field, start_time, end_time, freq)
@staticmethod
def clear_cache(cache_path):
meta_path = cache_path + ".meta"
for p in [cache_path, meta_path]:
p = Path(p)
if p.exists():
p.unlink()
def gen_expression_cache(self, expression_data, cache_path, instrument, field, freq, last_update):
"""use bin file to save like feature-data."""
# Make sure the cache runs right when the directory is deleted
@@ -544,27 +569,28 @@ class DiskExpressionCache(ExpressionCache):
"meta": {"last_visit": time.time(), "visits": 1},
}
self.logger.debug(f"generating expression cache: {meta}")
os.makedirs(self.expr_cache_path, exist_ok=True)
self.clear_cache(cache_path)
meta_path = cache_path + ".meta"
meta_path = cache_path.with_suffix(".meta")
with open(meta_path, "wb") as f:
with meta_path.open("wb") as f:
pickle.dump(meta, f)
os.chmod(meta_path, stat.S_IRWXU | stat.S_IRGRP | stat.S_IROTH)
meta_path.chmod(stat.S_IRWXU | stat.S_IRGRP | stat.S_IROTH)
df = expression_data.to_frame()
r = np.hstack([df.index[0], expression_data]).astype("<f")
r.tofile(str(cache_path))
def update(self, sid, cache_uri):
cp_cache_uri = os.path.join(self.expr_cache_path, sid, cache_uri)
if not self.check_cache_exists(cp_cache_uri):
def update(self, sid, cache_uri, freq: str = "day"):
cp_cache_uri = self.get_cache_dir(freq).joinpath(sid).joinpath(cache_uri)
meta_path = cp_cache_uri.with_suffix(".meta")
if not self.check_cache_exists(cp_cache_uri, suffix_list=[".meta"]):
self.logger.info(f"The cache {cp_cache_uri} has corrupted. It will be removed")
self.clear_cache(cp_cache_uri)
return 2
with CacheUtils.writer_lock(self.r, "expression-%s" % cache_uri):
with open(cp_cache_uri + ".meta", "rb") as f:
with CacheUtils.writer_lock(self.r, f"{str(C.dpm.get_data_path())}:expression-{cache_uri}"):
with meta_path.open("rb") as f:
d = pickle.load(f)
instrument = d["info"]["instrument"]
field = d["info"]["field"]
@@ -611,7 +637,7 @@ class DiskExpressionCache(ExpressionCache):
f.write(data)
# update meta file
d["info"]["last_update"] = str(new_calendar[-1])
with open(cp_cache_uri + ".meta", "wb") as f:
with meta_path.open("wb") as f:
pickle.dump(d, f)
return 0
@@ -623,22 +649,16 @@ class DiskDatasetCache(DatasetCache):
super(DiskDatasetCache, self).__init__(provider)
self.r = get_redis_connection()
self.remote = kwargs.get("remote", False)
self.dtst_cache_path = os.path.join(C.get_data_path(), C.dataset_cache_dir_name)
os.makedirs(self.dtst_cache_path, exist_ok=True)
@staticmethod
def _uri(instruments, fields, start_time, end_time, freq, disk_cache=1, **kwargs):
return hash_args(*DatasetCache.normalize_uri_args(instruments, fields, freq), disk_cache)
def _uri(instruments, fields, start_time, end_time, freq, disk_cache=1, inst_processors=[], **kwargs):
return hash_args(*DatasetCache.normalize_uri_args(instruments, fields, freq), disk_cache, inst_processors)
@staticmethod
def check_cache_exists(cache_path):
for p in [cache_path, cache_path + ".index", cache_path + ".meta"]:
if not Path(p).exists():
return False
return True
def get_cache_dir(self, freq: str = None) -> Path:
return super(DiskDatasetCache, self).get_cache_dir(C.dataset_cache_dir_name, freq)
@classmethod
def read_data_from_cache(cls, cache_path, start_time, end_time, fields):
def read_data_from_cache(cls, cache_path: Union[str, Path], start_time, end_time, fields):
"""read_cache_from
This function can read data from the disk cache dataset
@@ -671,17 +691,32 @@ class DiskDatasetCache(DatasetCache):
df = pd.DataFrame(columns=fields)
return df
def _dataset(self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=0):
def _dataset(
self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=0, inst_processors=[]
):
if disk_cache == 0:
# In this case, data_set cache is configured but will not be used.
return self.provider.dataset(instruments, fields, start_time, end_time, freq)
return self.provider.dataset(
instruments, fields, start_time, end_time, freq, inst_processors=inst_processors
)
# FIXME: The cache after resample, when read again and intercepted with end_time, results in incomplete data date
if inst_processors:
raise ValueError(
f"{self.__class__.__name__} does not support inst_processor. "
f"Please use `D.features(disk_cache=0)` or `qlib.init(dataset_cache=None)`"
)
_cache_uri = self._uri(
instruments=instruments, fields=fields, start_time=None, end_time=None, freq=freq, disk_cache=disk_cache
instruments=instruments,
fields=fields,
start_time=None,
end_time=None,
freq=freq,
disk_cache=disk_cache,
inst_processors=inst_processors,
)
cache_path = os.path.join(self.dtst_cache_path, _cache_uri)
cache_path = self.get_cache_dir(freq).joinpath(_cache_uri)
features = pd.DataFrame()
gen_flag = False
@@ -689,7 +724,7 @@ class DiskDatasetCache(DatasetCache):
if self.check_cache_exists(cache_path):
if disk_cache == 1:
# use cache
with CacheUtils.reader_lock(self.r, "dataset-%s" % _cache_uri):
with CacheUtils.reader_lock(self.r, f"{str(C.dpm.get_data_path(freq))}:dataset-{_cache_uri}"):
CacheUtils.visit(cache_path)
features = self.read_data_from_cache(cache_path, start_time, end_time, fields)
elif disk_cache == 2:
@@ -699,15 +734,21 @@ class DiskDatasetCache(DatasetCache):
if gen_flag:
# cache unavailable, generate the cache
with CacheUtils.writer_lock(self.r, "dataset-%s" % _cache_uri):
with CacheUtils.writer_lock(self.r, f"{str(C.dpm.get_data_path(freq))}:dataset-{_cache_uri}"):
features = self.gen_dataset_cache(
cache_path=cache_path, instruments=instruments, fields=fields, freq=freq
cache_path=cache_path,
instruments=instruments,
fields=fields,
freq=freq,
inst_processors=inst_processors,
)
if not features.empty:
features = features.sort_index().loc(axis=0)[:, start_time:end_time]
return features
def _dataset_uri(self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=0):
def _dataset_uri(
self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=0, inst_processors=[]
):
if disk_cache == 0:
# In this case, server only checks the expression cache.
# The client will load the cache data by itself.
@@ -715,21 +756,38 @@ class DiskDatasetCache(DatasetCache):
LocalDatasetProvider.multi_cache_walker(instruments, fields, start_time, end_time, freq)
return ""
# FIXME: The cache after resample, when read again and intercepted with end_time, results in incomplete data date
if inst_processors:
raise ValueError(
f"{self.__class__.__name__} does not support inst_processor. "
f"Please use `D.features(disk_cache=0)` or `qlib.init(dataset_cache=None)`"
)
_cache_uri = self._uri(
instruments=instruments, fields=fields, start_time=None, end_time=None, freq=freq, disk_cache=disk_cache
instruments=instruments,
fields=fields,
start_time=None,
end_time=None,
freq=freq,
disk_cache=disk_cache,
inst_processors=inst_processors,
)
cache_path = os.path.join(self.dtst_cache_path, _cache_uri)
cache_path = self.get_cache_dir(freq).joinpath(_cache_uri)
if self.check_cache_exists(cache_path):
self.logger.debug(f"The cache dataset has already existed {cache_path}. Return the uri directly")
with CacheUtils.reader_lock(self.r, "dataset-%s" % _cache_uri):
with CacheUtils.reader_lock(self.r, f"{str(C.dpm.get_data_path(freq))}:dataset-{_cache_uri}"):
CacheUtils.visit(cache_path)
return _cache_uri
else:
# cache unavailable, generate the cache
with CacheUtils.writer_lock(self.r, "dataset-%s" % _cache_uri):
self.gen_dataset_cache(cache_path=cache_path, instruments=instruments, fields=fields, freq=freq)
with CacheUtils.writer_lock(self.r, f"{str(C.dpm.get_data_path(freq))}:dataset-{_cache_uri}"):
self.gen_dataset_cache(
cache_path=cache_path,
instruments=instruments,
fields=fields,
freq=freq,
inst_processors=inst_processors,
)
return _cache_uri
class IndexManager:
@@ -740,8 +798,9 @@ class DiskDatasetCache(DatasetCache):
KEY = "df"
def __init__(self, cache_path):
self.index_path = cache_path + ".index"
def __init__(self, cache_path: Union[str, Path]):
self.index_path = cache_path.with_suffix(".index")
self._data = None
self.logger = get_module_logger(self.__class__.__name__)
@@ -757,7 +816,7 @@ class DiskDatasetCache(DatasetCache):
self._data.sort_index(inplace=True)
self._data.to_hdf(self.index_path, key=self.KEY, mode="w", format="table")
# The index should be readable for all users
os.chmod(self.index_path, stat.S_IRWXU | stat.S_IRGRP | stat.S_IROTH)
self.index_path.chmod(stat.S_IRWXU | stat.S_IRGRP | stat.S_IROTH)
def sync_from_disk(self):
# The file will not be closed directly if we read_hdf from the disk directly
@@ -795,15 +854,7 @@ class DiskDatasetCache(DatasetCache):
index_data += start_index
return index_data
@staticmethod
def clear_cache(cache_path):
meta_path = cache_path + ".meta"
for p in [cache_path, meta_path, cache_path + ".index", cache_path + ".data"]:
p = Path(p)
if p.exists():
p.unlink()
def gen_dataset_cache(self, cache_path, instruments, fields, freq):
def gen_dataset_cache(self, cache_path: Union[str, Path], instruments, fields, freq, inst_processors=[]):
"""gen_dataset_cache
.. note:: This function does not consider the cache read write lock. Please
@@ -838,20 +889,23 @@ class DiskDatasetCache(DatasetCache):
:param instruments: The instruments to store the cache.
:param fields: The fields to store the cache.
:param freq: The freq to store the cache.
:param inst_processors: Instrument processors.
:return type pd.DataFrame; The fields of the returned DataFrame are consistent with the parameters of the function.
"""
# get calendar
from .data import Cal
cache_path = Path(cache_path)
_calendar = Cal.calendar(freq=freq)
self.logger.debug(f"Generating dataset cache {cache_path}")
# Make sure the cache runs right when the directory is deleted
# while running
os.makedirs(self.dtst_cache_path, exist_ok=True)
self.clear_cache(cache_path)
features = self.provider.dataset(instruments, fields, _calendar[0], _calendar[-1], freq)
features = self.provider.dataset(
instruments, fields, _calendar[0], _calendar[-1], freq, inst_processors=inst_processors
)
if features.empty:
return features
@@ -860,7 +914,7 @@ class DiskDatasetCache(DatasetCache):
features = features.swaplevel("instrument", "datetime").sort_index()
# write cache data
with pd.HDFStore(cache_path + ".data") as store:
with pd.HDFStore(str(cache_path.with_suffix(".data"))) as store:
cache_to_orig_map = dict(zip(remove_fields_space(features.columns), features.columns))
orig_to_cache_map = dict(zip(features.columns, remove_fields_space(features.columns)))
cache_features = features[list(cache_to_orig_map.values())].rename(columns=orig_to_cache_map)
@@ -876,12 +930,13 @@ class DiskDatasetCache(DatasetCache):
"fields": cache_columns,
"freq": freq,
"last_update": str(_calendar[-1]), # The last_update to store the cache
"inst_processors": inst_processors, # The last_update to store the cache
},
"meta": {"last_visit": time.time(), "visits": 1},
}
with open(cache_path + ".meta", "wb") as f:
with cache_path.with_suffix(".meta").open("wb") as f:
pickle.dump(meta, f)
os.chmod(cache_path + ".meta", stat.S_IRWXU | stat.S_IRGRP | stat.S_IROTH)
cache_path.with_suffix(".meta").chmod(stat.S_IRWXU | stat.S_IRGRP | stat.S_IROTH)
# write index file
im = DiskDatasetCache.IndexManager(cache_path)
index_data = im.build_index_from_data(features)
@@ -890,26 +945,27 @@ class DiskDatasetCache(DatasetCache):
# rename the file after the cache has been generated
# this doesn't work well on windows, but our server won't use windows
# temporarily
os.replace(cache_path + ".data", cache_path)
cache_path.with_suffix(".data").rename(cache_path)
# the fields of the cached features are converted to the original fields
return features.swaplevel("datetime", "instrument")
def update(self, cache_uri):
cp_cache_uri = os.path.join(self.dtst_cache_path, cache_uri)
def update(self, cache_uri, freq: str = "day"):
cp_cache_uri = self.get_cache_dir(freq).joinpath(cache_uri)
meta_path = cp_cache_uri.with_suffix(".meta")
if not self.check_cache_exists(cp_cache_uri):
self.logger.info(f"The cache {cp_cache_uri} has corrupted. It will be removed")
self.clear_cache(cp_cache_uri)
return 2
im = DiskDatasetCache.IndexManager(cp_cache_uri)
with CacheUtils.writer_lock(self.r, "dataset-%s" % cache_uri):
with open(cp_cache_uri + ".meta", "rb") as f:
with CacheUtils.writer_lock(self.r, f"{str(C.dpm.get_data_path())}:dataset-{cache_uri}"):
with meta_path.open("rb") as f:
d = pickle.load(f)
instruments = d["info"]["instruments"]
fields = d["info"]["fields"]
freq = d["info"]["freq"]
last_update_time = d["info"]["last_update"]
inst_processors = d["info"]["inst_processors"]
index_data = im.get_index()
self.logger.debug("Updating dataset: {}".format(d))
@@ -960,7 +1016,12 @@ class DiskDatasetCache(DatasetCache):
)
data = self.provider.dataset(
instruments, fields, whole_calendar[current_index - rm_n_period], new_calendar[-1], freq
instruments,
fields,
whole_calendar[current_index - rm_n_period],
new_calendar[-1],
freq,
inst_processors=inst_processors,
)
if not data.empty:
@@ -995,7 +1056,7 @@ class DiskDatasetCache(DatasetCache):
# update meta file
d["info"]["last_update"] = str(new_calendar[-1])
with open(cp_cache_uri + ".meta", "wb") as f:
with meta_path.open("wb") as f:
pickle.dump(d, f)
return 0
@@ -1006,26 +1067,36 @@ class SimpleDatasetCache(DatasetCache):
def __init__(self, provider):
super(SimpleDatasetCache, self).__init__(provider)
try:
self.local_cache_path = C["local_cache_path"]
except KeyError as e:
self.local_cache_path: Path = Path(C["local_cache_path"]).expanduser().resolve()
except (KeyError, TypeError) as e:
self.logger.error("Assign a local_cache_path in config if you want to use this cache mechanism")
raise
self.logger.info(
f"DatasetCache directory: {self.local_cache_path}, "
f"modify the cache directory via the local_cache_path in the config"
)
def _uri(self, instruments, fields, start_time, end_time, freq, disk_cache=1, **kwargs):
def _uri(self, instruments, fields, start_time, end_time, freq, disk_cache=1, inst_processors=[], **kwargs):
instruments, fields, freq = self.normalize_uri_args(instruments, fields, freq)
local_cache_path = str(Path(self.local_cache_path).expanduser().resolve())
return hash_args(instruments, fields, start_time, end_time, freq, disk_cache, local_cache_path)
return hash_args(
instruments, fields, start_time, end_time, freq, disk_cache, str(self.local_cache_path), inst_processors
)
def _dataset(self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1):
def _dataset(
self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1, inst_processors=[]
):
if disk_cache == 0:
# In this case, data_set cache is configured but will not be used.
return self.provider.dataset(instruments, fields, start_time, end_time, freq)
os.makedirs(os.path.expanduser(self.local_cache_path), exist_ok=True)
cache_file = os.path.join(
self.local_cache_path, self._uri(instruments, fields, start_time, end_time, freq, disk_cache=disk_cache)
self.local_cache_path.mkdir(exist_ok=True, parents=True)
cache_file = self.local_cache_path.joinpath(
self._uri(
instruments, fields, start_time, end_time, freq, disk_cache=disk_cache, inst_processors=inst_processors
)
)
gen_flag = False
if os.path.exists(cache_file):
if cache_file.exists():
if disk_cache == 1:
# use cache
df = pd.read_pickle(cache_file)
@@ -1037,7 +1108,9 @@ class SimpleDatasetCache(DatasetCache):
gen_flag = True
if gen_flag:
data = self.provider.dataset(instruments, normalize_cache_fields(fields), start_time, end_time, freq)
data = self.provider.dataset(
instruments, normalize_cache_fields(fields), start_time, end_time, freq, inst_processors=inst_processors
)
data.to_pickle(cache_file)
return self.cache_to_origin_data(data, fields)
@@ -1045,26 +1118,53 @@ class SimpleDatasetCache(DatasetCache):
class DatasetURICache(DatasetCache):
"""Prepared cache mechanism for server."""
def _uri(self, instruments, fields, start_time, end_time, freq, disk_cache=1, **kwargs):
return hash_args(*self.normalize_uri_args(instruments, fields, freq), disk_cache)
def _uri(self, instruments, fields, start_time, end_time, freq, disk_cache=1, inst_processors=[], **kwargs):
return hash_args(*self.normalize_uri_args(instruments, fields, freq), disk_cache, inst_processors)
def dataset(self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=0):
def dataset(
self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=0, inst_processors=[]
):
if "local" in C.dataset_provider.lower():
# use LocalDatasetProvider
return self.provider.dataset(instruments, fields, start_time, end_time, freq)
return self.provider.dataset(
instruments, fields, start_time, end_time, freq, inst_processors=inst_processors
)
if disk_cache == 0:
# do not use data_set cache, load data from remote expression cache directly
return self.provider.dataset(instruments, fields, start_time, end_time, freq, disk_cache, return_uri=False)
return self.provider.dataset(
instruments,
fields,
start_time,
end_time,
freq,
disk_cache,
return_uri=False,
inst_processors=inst_processors,
)
# FIXME: The cache after resample, when read again and intercepted with end_time, results in incomplete data date
if inst_processors:
raise ValueError(
f"{self.__class__.__name__} does not support inst_processor. "
f"Please use `D.features(disk_cache=0)` or `qlib.init(dataset_cache=None)`"
)
# use ClientDatasetProvider
feature_uri = self._uri(instruments, fields, None, None, freq, disk_cache=disk_cache)
feature_uri = self._uri(
instruments, fields, None, None, freq, disk_cache=disk_cache, inst_processors=inst_processors
)
value, expire = MemCacheExpire.get_cache(H["f"], feature_uri)
mnt_feature_uri = os.path.join(C.get_data_path(), C.dataset_cache_dir_name, feature_uri)
if value is None or expire or not os.path.exists(mnt_feature_uri):
mnt_feature_uri = C.dpm.get_data_path(freq).joinpath(C.dataset_cache_dir_name).joinpath(feature_uri)
if value is None or expire or not mnt_feature_uri.exists():
df, uri = self.provider.dataset(
instruments, fields, start_time, end_time, freq, disk_cache, return_uri=True
instruments,
fields,
start_time,
end_time,
freq,
disk_cache,
return_uri=True,
inst_processors=inst_processors,
)
# cache uri
MemCacheExpire.set_cache(H["f"], uri, uri)
@@ -1072,7 +1172,6 @@ class DatasetURICache(DatasetCache):
# HZ['f'][uri] = df.copy()
get_module_logger("cache").debug(f"get feature from {C.dataset_provider}")
else:
mnt_feature_uri = os.path.join(C.get_data_path(), C.dataset_cache_dir_name, feature_uri)
df = DiskDatasetCache.read_data_from_cache(mnt_feature_uri, start_time, end_time, fields)
get_module_logger("cache").debug("get feature from uri cache")

View File

@@ -5,28 +5,34 @@
from __future__ import division
from __future__ import print_function
import os
import re
import abc
import copy
import time
import queue
import bisect
import logging
import importlib
import traceback
import numpy as np
import pandas as pd
from multiprocessing import Pool
from typing import Iterable, Union
from .cache import H
from ..config import C
from .ops import Operators
from ..log import get_module_logger
from ..utils import parse_field, read_bin, hash_args, normalize_cache_fields, code_to_fname
from .base import Feature
from .ops import Operators
from .inst_processor import InstProcessor
from ..log import get_module_logger
from .cache import DiskDatasetCache, DiskExpressionCache
from ..utils import Wrapper, init_instance_by_config, register_wrapper, get_module_by_module_path
from ..utils import (
Wrapper,
init_instance_by_config,
register_wrapper,
get_module_by_module_path,
parse_field,
hash_args,
normalize_cache_fields,
code_to_fname,
)
class ProviderBackendMixin:
@@ -48,8 +54,14 @@ class ProviderBackendMixin:
# default provider_uri map
if "provider_uri" not in backend_kwargs:
# if the user has no uri configured, use: uri = uri_map[freq]
# NOTE: provider_uri priority
# 1. backend_config: backend_obj["kwargs"]["provider_uri"]
# 2. backend_config: backend_obj["kwargs"]["provider_uri_map"]
# 3. qlib.init: provider_uri
provider_uri_map = backend_kwargs.setdefault("provider_uri_map", {})
freq = kwargs.get("freq", "day")
provider_uri_map = backend_kwargs.setdefault("provider_uri_map", {freq: C.get_data_path()})
if freq not in provider_uri_map:
provider_uri_map[freq] = C.dpm.get_data_path(freq)
backend_kwargs["provider_uri"] = provider_uri_map[freq]
backend.setdefault("kwargs", {}).update(**kwargs)
return init_instance_by_config(backend)
@@ -199,13 +211,23 @@ class InstrumentProvider(abc.ABC, ProviderBackendMixin):
'filter_start_time': None,
'filter_end_time': None}]}
"""
from .filter import SeriesDFilter
if filter_pipe is None:
filter_pipe = []
config = {"market": market, "filter_pipe": []}
# the order of the filters will affect the result, so we need to keep
# the order
for filter_t in filter_pipe:
config["filter_pipe"].append(filter_t.to_config())
if isinstance(filter_t, dict):
_config = filter_t
elif isinstance(filter_t, SeriesDFilter):
_config = filter_t.to_config()
else:
raise TypeError(
f"Unsupported filter types: {type(filter_t)}! Filter only supports dict or isinstance(filter, SeriesDFilter)"
)
config["filter_pipe"].append(_config)
return config
@abc.abstractmethod
@@ -341,7 +363,7 @@ class DatasetProvider(abc.ABC):
"""
@abc.abstractmethod
def dataset(self, instruments, fields, start_time=None, end_time=None, freq="day"):
def dataset(self, instruments, fields, start_time=None, end_time=None, freq="day", inst_processors=[]):
"""Get dataset data.
Parameters
@@ -356,6 +378,8 @@ class DatasetProvider(abc.ABC):
end of the time range.
freq : str
time frequency.
inst_processors: Iterable[Union[dict, InstProcessor]]
the operations performed on each instrument
Returns
----------
@@ -372,6 +396,7 @@ class DatasetProvider(abc.ABC):
end_time=None,
freq="day",
disk_cache=1,
inst_processors=[],
**kwargs,
):
"""Get task uri, used when generating rabbitmq task in qlib_server
@@ -392,7 +417,8 @@ class DatasetProvider(abc.ABC):
whether to skip(0)/use(1)/replace(2) disk_cache.
"""
return DiskDatasetCache._uri(instruments, fields, start_time, end_time, freq, disk_cache)
# TODO: qlib-server support inst_processors
return DiskDatasetCache._uri(instruments, fields, start_time, end_time, freq, disk_cache, inst_processors)
@staticmethod
def get_instruments_d(instruments, freq):
@@ -433,7 +459,7 @@ class DatasetProvider(abc.ABC):
return [ExpressionD.get_expression_instance(f) for f in fields]
@staticmethod
def dataset_processor(instruments_d, column_names, start_time, end_time, freq):
def dataset_processor(instruments_d, column_names, start_time, end_time, freq, inst_processors=[]):
"""
Load and process the data, return the data set.
- default using multi-kernel method.
@@ -459,6 +485,7 @@ class DatasetProvider(abc.ABC):
normalize_column_names,
spans,
C,
inst_processors,
),
)
else:
@@ -473,6 +500,7 @@ class DatasetProvider(abc.ABC):
normalize_column_names,
None,
C,
inst_processors,
),
)
@@ -494,7 +522,9 @@ class DatasetProvider(abc.ABC):
return data
@staticmethod
def expression_calculator(inst, start_time, end_time, freq, column_names, spans=None, g_config=None):
def expression_calculator(
inst, start_time, end_time, freq, column_names, spans=None, g_config=None, inst_processors=[]
):
"""
Calculate the expressions for one instrument, return a df result.
If the expression has been calculated before, load from cache.
@@ -518,13 +548,17 @@ class DatasetProvider(abc.ABC):
data.index = _calendar[data.index.values.astype(int)]
data.index.names = ["datetime"]
if spans is None:
return data
else:
if spans is not None:
mask = np.zeros(len(data), dtype=bool)
for begin, end in spans:
mask |= (data.index >= begin) & (data.index <= end)
return data[mask]
data = data[mask]
for _processor in inst_processors:
if _processor:
_processor_obj = init_instance_by_config(_processor, accept_types=InstProcessor)
data = _processor_obj(data)
return data
class LocalCalendarProvider(CalendarProvider):
@@ -537,11 +571,6 @@ class LocalCalendarProvider(CalendarProvider):
super(LocalCalendarProvider, self).__init__(**kwargs)
self.remote = kwargs.get("remote", False)
@property
def _uri_cal(self):
"""Calendar file uri."""
return os.path.join(C.get_data_path(), "calendars", "{}.txt")
def load_calendar(self, freq, future):
"""Load original calendar timestamp from file.
@@ -601,11 +630,6 @@ class LocalInstrumentProvider(InstrumentProvider):
Provide instrument data from local data source.
"""
@property
def _uri_inst(self):
"""Instrument file uri."""
return os.path.join(C.get_data_path(), "instruments", "{}.txt")
def _load_instruments(self, market, freq):
return self.backend_obj(market=market, freq=freq).data
@@ -654,14 +678,9 @@ class LocalFeatureProvider(FeatureProvider):
super(LocalFeatureProvider, self).__init__(**kwargs)
self.remote = kwargs.get("remote", False)
@property
def _uri_data(self):
"""Static feature file uri."""
return os.path.join(C.get_data_path(), "features", "{}", "{}.{}.bin")
def feature(self, instrument, field, start_index, end_index, freq):
# validate
field = str(field).lower()[1:]
field = str(field)[1:]
instrument = code_to_fname(instrument)
return self.backend_obj(instrument=instrument, field=field, freq=freq)[start_index : end_index + 1]
@@ -703,7 +722,15 @@ class LocalDatasetProvider(DatasetProvider):
def __init__(self):
pass
def dataset(self, instruments, fields, start_time=None, end_time=None, freq="day"):
def dataset(
self,
instruments,
fields,
start_time=None,
end_time=None,
freq="day",
inst_processors=[],
):
instruments_d = self.get_instruments_d(instruments, freq)
column_names = self.get_column_names(fields)
cal = Cal.calendar(start_time, end_time, freq)
@@ -712,7 +739,9 @@ class LocalDatasetProvider(DatasetProvider):
start_time = cal[0]
end_time = cal[-1]
data = self.dataset_processor(instruments_d, column_names, start_time, end_time, freq)
data = self.dataset_processor(
instruments_d, column_names, start_time, end_time, freq, inst_processors=inst_processors
)
return data
@@ -855,6 +884,7 @@ class ClientDatasetProvider(DatasetProvider):
freq="day",
disk_cache=0,
return_uri=False,
inst_processors=[],
):
if Inst.get_inst_type(instruments) == Inst.DICT:
get_module_logger("data").warning(
@@ -894,7 +924,7 @@ class ClientDatasetProvider(DatasetProvider):
start_time = cal[0]
end_time = cal[-1]
data = self.dataset_processor(instruments_d, column_names, start_time, end_time, freq)
data = self.dataset_processor(instruments_d, column_names, start_time, end_time, freq, inst_processors)
if return_uri:
return data, feature_uri
else:
@@ -907,6 +937,13 @@ class ClientDatasetProvider(DatasetProvider):
- using single-process implementation.
"""
# TODO: support inst_processors, need to change the code of qlib-server at the same time
# FIXME: The cache after resample, when read again and intercepted with end_time, results in incomplete data date
if inst_processors:
raise ValueError(
f"{self.__class__.__name__} does not support inst_processor. "
f"Please use `D.features(disk_cache=0)` or `qlib.init(dataset_cache=None)`"
)
self.conn.send_request(
request_type="feature",
request_content={
@@ -926,7 +963,7 @@ class ClientDatasetProvider(DatasetProvider):
get_module_logger("data").debug("get result")
try:
# pre-mound nfs, used for demo
mnt_feature_uri = os.path.join(C.get_data_path(), C.dataset_cache_dir_name, feature_uri)
mnt_feature_uri = C.dpm.get_data_path(freq).joinpath(C.dataset_cache_dir_name, feature_uri)
df = DiskDatasetCache.read_data_from_cache(mnt_feature_uri, start_time, end_time, fields)
get_module_logger("data").debug("finish slicing data")
if return_uri:
@@ -964,6 +1001,7 @@ class BaseProvider:
end_time=None,
freq="day",
disk_cache=None,
inst_processors=[],
):
"""
Parameters:
@@ -978,9 +1016,11 @@ class BaseProvider:
disk_cache = C.default_disk_cache if disk_cache is None else disk_cache
fields = list(fields) # In case of tuple.
try:
return DatasetD.dataset(instruments, fields, start_time, end_time, freq, disk_cache)
return DatasetD.dataset(
instruments, fields, start_time, end_time, freq, disk_cache, inst_processors=inst_processors
)
except TypeError:
return DatasetD.dataset(instruments, fields, start_time, end_time, freq)
return DatasetD.dataset(instruments, fields, start_time, end_time, freq, inst_processors=inst_processors)
class LocalProvider(BaseProvider):
@@ -1028,13 +1068,21 @@ class ClientProvider(BaseProvider):
"""
def __init__(self):
def is_instance_of_provider(instance: object, cls: type):
if isinstance(instance, Wrapper):
p = getattr(instance, "_provider", None)
return False if p is None else isinstance(p, cls)
return isinstance(instance, cls)
from .client import Client
self.client = Client(C.flask_server, C.flask_port)
self.logger = get_module_logger(self.__class__.__name__)
if isinstance(Cal, ClientCalendarProvider):
if is_instance_of_provider(Cal, ClientCalendarProvider):
Cal.set_conn(self.client)
if isinstance(Inst, ClientInstrumentProvider):
if is_instance_of_provider(Inst, ClientInstrumentProvider):
Inst.set_conn(self.client)
if hasattr(DatasetD, "provider"):
DatasetD.provider.set_conn(self.client)

View File

@@ -18,6 +18,7 @@ from ...config import C
from ...utils import parse_config, transform_end_date, init_instance_by_config
from ...utils.serial import Serializable
from .utils import fetch_df_by_index
from ...utils import lazy_sort_index
from pathlib import Path
from .loader import DataLoader
@@ -146,7 +147,8 @@ class DataHandler(Serializable):
# Setup data.
# _data may be with multiple column index level. The outer level indicates the feature set name
with TimeInspector.logt("Loading data"):
self._data = self.data_loader.load(self.instruments, self.start_time, self.end_time)
# make sure the fetch method is based on a index-sorted pd.DataFrame
self._data = lazy_sort_index(self.data_loader.load(self.instruments, self.start_time, self.end_time))
# TODO: cache
CS_ALL = "__all" # return all columns with single-level index column
@@ -293,11 +295,14 @@ class DataHandlerLP(DataHandler):
# process type
PTYPE_I = "independent"
# - self._infer will be processed by infer_processors
# - self._learn will be processed by learn_processors
# - self._infer will be processed by shared_processors + infer_processors
# - self._learn will be processed by shared_processors + learn_processors
# NOTE:
PTYPE_A = "append"
# - self._infer will be processed by infer_processors
# - self._learn will be processed by infer_processors + learn_processors
# - self._infer will be processed by shared_processors + infer_processors
# - self._learn will be processed by shared_processors + infer_processors + learn_processors
# - (e.g. self._infer processed by learn_processors )
def __init__(
@@ -306,8 +311,9 @@ class DataHandlerLP(DataHandler):
start_time=None,
end_time=None,
data_loader: Union[dict, str, DataLoader] = None,
infer_processors=[],
learn_processors=[],
infer_processors: List = [],
learn_processors: List = [],
shared_processors: List = [],
process_type=PTYPE_A,
drop_raw=False,
**kwargs,
@@ -358,7 +364,8 @@ class DataHandlerLP(DataHandler):
# Setup preprocessor
self.infer_processors = [] # for lint
self.learn_processors = [] # for lint
for pname in "infer_processors", "learn_processors":
self.shared_processors = [] # for lint
for pname in "infer_processors", "learn_processors", "shared_processors":
for proc in locals()[pname]:
getattr(self, pname).append(
init_instance_by_config(
@@ -373,9 +380,12 @@ class DataHandlerLP(DataHandler):
super().__init__(instruments, start_time, end_time, data_loader, **kwargs)
def get_all_processors(self):
return self.infer_processors + self.learn_processors
return self.shared_processors + self.infer_processors + self.learn_processors
def fit(self):
"""
fit data without processing the data
"""
for proc in self.get_all_processors():
with TimeInspector.logt(f"{proc.__class__.__name__}"):
proc.fit(self._data)
@@ -388,30 +398,68 @@ class DataHandlerLP(DataHandler):
"""
self.process_data(with_fit=True)
@staticmethod
def _run_proc_l(
df: pd.DataFrame, proc_l: List[processor_module.Processor], with_fit: bool, check_for_infer: bool
) -> pd.DataFrame:
for proc in proc_l:
if check_for_infer and not proc.is_for_infer():
raise TypeError("Only processors usable for inference can be used in `infer_processors` ")
with TimeInspector.logt(f"{proc.__class__.__name__}"):
if with_fit:
proc.fit(df)
df = proc(df)
return df
@staticmethod
def _is_proc_readonly(proc_l: List[processor_module.Processor]):
"""
NOTE: it will return True if `len(proc_l) == 0`
"""
for p in proc_l:
if not p.readonly():
return False
return True
def process_data(self, with_fit: bool = False):
"""
process_data data. Fun `processor.fit` if necessary
Notation: (data) [processor]
# data processing flow of self.process_type == DataHandlerLP.PTYPE_I
(self._data)-[shared_processors]-(_shared_df)-[learn_processors]-(_learn_df)
\
-[infer_processors]-(_infer_df)
# data processing flow of self.process_type == DataHandlerLP.PTYPE_A
(self._data)-[shared_processors]-(_shared_df)-[infer_processors]-(_infer_df)-[learn_processors]-(_learn_df)
Parameters
----------
with_fit : bool
The input of the `fit` will be the output of the previous processor
"""
# data for inference
_infer_df = self._data
if len(self.infer_processors) > 0 and not self.drop_raw: # avoid modifying the original data
_infer_df = _infer_df.copy()
# shared data processors
# 1) assign
_shared_df = self._data
if not self._is_proc_readonly(self.shared_processors): # avoid modifying the original data
_shared_df = _shared_df.copy()
# 2) process
_shared_df = self._run_proc_l(_shared_df, self.shared_processors, with_fit=with_fit, check_for_infer=True)
# data for inference
# 1) assign
_infer_df = _shared_df
if not self._is_proc_readonly(self.infer_processors): # avoid modifying the original data
_infer_df = _infer_df.copy()
# 2) process
_infer_df = self._run_proc_l(_infer_df, self.infer_processors, with_fit=with_fit, check_for_infer=True)
for proc in self.infer_processors:
if not proc.is_for_infer():
raise TypeError("Only processors usable for inference can be used in `infer_processors` ")
with TimeInspector.logt(f"{proc.__class__.__name__}"):
if with_fit:
proc.fit(_infer_df)
_infer_df = proc(_infer_df)
self._infer = _infer_df
# data for learning
# 1) assign
if self.process_type == DataHandlerLP.PTYPE_I:
_learn_df = self._data
elif self.process_type == DataHandlerLP.PTYPE_A:
@@ -419,14 +467,11 @@ class DataHandlerLP(DataHandler):
_learn_df = _infer_df
else:
raise NotImplementedError(f"This type of input is not supported")
if len(self.learn_processors) > 0: # avoid modifying the original data
if not self._is_proc_readonly(self.learn_processors): # avoid modifying the original data
_learn_df = _learn_df.copy()
for proc in self.learn_processors:
with TimeInspector.logt(f"{proc.__class__.__name__}"):
if with_fit:
proc.fit(_learn_df)
_learn_df = proc(_learn_df)
# 2) process
_learn_df = self._run_proc_l(_learn_df, self.learn_processors, with_fit=with_fit, check_for_infer=False)
self._learn = _learn_df
if self.drop_raw:

View File

@@ -1,17 +1,13 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
import abc
import warnings
import numpy as np
import pandas as pd
from typing import Tuple, Union
from typing import Tuple, Union, List
from qlib.data import D
from qlib.data import filter as filter_module
from qlib.data.filter import BaseDFilter
from qlib.utils import load_dataset, init_instance_by_config, time_to_slc_point
from qlib.log import get_module_logger
@@ -62,11 +58,11 @@ class DLWParser(DataLoader):
Extracting this class so that QlibDataLoader and other dataloaders(such as QdbDataLoader) can share the fields.
"""
def __init__(self, config: Tuple[list, tuple, dict]):
def __init__(self, config: Union[list, tuple, dict]):
"""
Parameters
----------
config : Tuple[list, tuple, dict]
config : Union[list, tuple, dict]
Config will be used to describe the fields and column names
.. code-block::
@@ -88,7 +84,7 @@ class DLWParser(DataLoader):
else:
self.fields = self._parse_fields_info(config)
def _parse_fields_info(self, fields_info: Tuple[list, tuple]) -> Tuple[list, list]:
def _parse_fields_info(self, fields_info: Union[list, tuple]) -> Tuple[list, list]:
if len(fields_info) == 0:
raise ValueError("The size of fields must be greater than 0")
@@ -104,7 +100,15 @@ class DLWParser(DataLoader):
return exprs, names
@abc.abstractmethod
def load_group_df(self, instruments, exprs: list, names: list, start_time=None, end_time=None) -> pd.DataFrame:
def load_group_df(
self,
instruments,
exprs: list,
names: list,
start_time: Union[str, pd.Timestamp] = None,
end_time: Union[str, pd.Timestamp] = None,
gp_name: str = None,
) -> pd.DataFrame:
"""
load the dataframe for specific group
@@ -128,7 +132,7 @@ class DLWParser(DataLoader):
if self.is_group:
df = pd.concat(
{
grp: self.load_group_df(instruments, exprs, names, start_time, end_time)
grp: self.load_group_df(instruments, exprs, names, start_time, end_time, grp)
for grp, (exprs, names) in self.fields.items()
},
axis=1,
@@ -142,7 +146,14 @@ class DLWParser(DataLoader):
class QlibDataLoader(DLWParser):
"""Same as QlibDataLoader. The fields can be define by config"""
def __init__(self, config: Tuple[list, tuple, dict], filter_pipe=None, swap_level=True, freq="day"):
def __init__(
self,
config: Tuple[list, tuple, dict],
filter_pipe: List = None,
swap_level: bool = True,
freq: Union[str, dict] = "day",
inst_processor: dict = None,
):
"""
Parameters
----------
@@ -152,20 +163,41 @@ class QlibDataLoader(DLWParser):
Filter pipe for the instruments
swap_level :
Whether to swap level of MultiIndex
freq: dict or str
If type(config) == dict and type(freq) == str, load config data using freq.
If type(config) == dict and type(freq) == dict, load config[<group_name>] data using freq[<group_name>]
inst_processor: dict
If inst_processor is not None and type(config) == dict; load config[<group_name>] data using inst_processor[<group_name>]
"""
if filter_pipe is not None:
assert isinstance(filter_pipe, list), "The type of `filter_pipe` must be list."
filter_pipe = [
init_instance_by_config(fp, None if "module_path" in fp else filter_module, accept_types=BaseDFilter)
for fp in filter_pipe
]
self.filter_pipe = filter_pipe
self.swap_level = swap_level
self.freq = freq
# sample
self.inst_processor = inst_processor if inst_processor is not None else {}
assert isinstance(self.inst_processor, dict), f"inst_processor(={self.inst_processor}) must be dict"
super().__init__(config)
def load_group_df(self, instruments, exprs: list, names: list, start_time=None, end_time=None) -> pd.DataFrame:
if self.is_group:
# check sample config
if isinstance(freq, dict):
for _gp in config.keys():
if _gp not in freq:
raise ValueError(f"freq(={freq}) missing group(={_gp})")
assert (
self.inst_processor
), f"freq(={self.freq}), inst_processor(={self.inst_processor}) cannot be None/empty"
def load_group_df(
self,
instruments,
exprs: list,
names: list,
start_time: Union[str, pd.Timestamp] = None,
end_time: Union[str, pd.Timestamp] = None,
gp_name: str = None,
) -> pd.DataFrame:
if instruments is None:
warnings.warn("`instruments` is not set, will load all stocks")
instruments = "all"
@@ -174,7 +206,10 @@ class QlibDataLoader(DLWParser):
elif self.filter_pipe is not None:
warnings.warn("`filter_pipe` is not None, but it will not be used with `instruments` as list")
df = D.features(instruments, exprs, start_time, end_time, self.freq)
freq = self.freq[gp_name] if isinstance(self.freq, dict) else self.freq
df = D.features(
instruments, exprs, start_time, end_time, freq=freq, inst_processors=self.inst_processor.get(gp_name, [])
)
df.columns = names
if self.swap_level:
df = df.swaplevel().sort_index() # NOTE: if swaplevel, return <datetime, instrument>
@@ -199,6 +234,10 @@ class StaticDataLoader(DataLoader):
self.join = join
self._data = None
def __getstate__(self) -> dict:
# avoid pickling `self._data`
return {k: v for k, v in self.__dict__.items() if not k.startswith("_")}
def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame:
self._maybe_load_raw_data()
if instruments is None:

View File

@@ -73,6 +73,14 @@ class Processor(Serializable):
"""
return True
def readonly(self) -> bool:
"""
Does the processor treat the input data readonly (i.e. does not write the input data) when processsing
Knowning the readonly information is helpful to the Handler to avoid uncessary copy
"""
return False
def config(self, **kwargs):
attr_list = {"fit_start_time", "fit_end_time"}
for k, v in kwargs.items():
@@ -92,6 +100,9 @@ class DropnaProcessor(Processor):
def __call__(self, df):
return df.dropna(subset=get_group_columns(df, self.fields_group))
def readonly(self):
return True
class DropnaLabel(DropnaProcessor):
def __init__(self, fields_group="label"):
@@ -113,6 +124,9 @@ class DropCol(Processor):
mask = df.columns.isin(self.col_list)
return df.loc[:, ~mask]
def readonly(self):
return True
class FilterCol(Processor):
def __init__(self, fields_group="feature", col_list=[]):
@@ -128,6 +142,9 @@ class FilterCol(Processor):
mask = df.columns.get_level_values(-1).isin(self.col_list)
return df.loc[:, mask]
def readonly(self):
return True
class TanhProcess(Processor):
"""Use tanh to process noise data"""

View File

@@ -0,0 +1,23 @@
import abc
import json
import pandas as pd
class InstProcessor:
@abc.abstractmethod
def __call__(self, df: pd.DataFrame, *args, **kwargs):
"""
process the data
NOTE: **The processor could change the content of `df` inplace !!!!! **
User should keep a copy of data outside
Parameters
----------
df : pd.DataFrame
The raw_df of handler or result from previous processor.
"""
pass
def __str__(self):
return f"{self.__class__.__name__}:{json.dumps(self.__dict__, sort_keys=True, default=str)}"

View File

@@ -15,7 +15,7 @@ from scipy.stats import percentileofscore
from .base import Expression, ExpressionOps
from ..log import get_module_logger
from ..utils import get_cls_kwargs
from ..utils import get_callable_kwargs
try:
from ._libs.rolling import rolling_slope, rolling_rsquare, rolling_resi
@@ -1513,7 +1513,7 @@ class OpsWrapper:
"""
for _operator in ops_list:
if isinstance(_operator, dict):
_ops_class, _ = get_cls_kwargs(_operator)
_ops_class, _ = get_callable_kwargs(_operator)
else:
_ops_class = _operator

View File

@@ -105,6 +105,20 @@ class AverageEnsemble(Ensemble):
"""
def __call__(self, ensemble_dict: dict) -> pd.DataFrame:
"""using sample:
from qlib.model.ens.ensemble import AverageEnsemble
pred_res['new_key_name'] = AverageEnsemble()(predict_dict)
Parameters
----------
ensemble_dict : dict
Dictionary you want to ensemble
Returns
-------
pd.DataFrame
The dictionary including ensenbling result
"""
# need to flatten the nested dict
ensemble_dict = flatten_dict(ensemble_dict, sep=FLATTEN_TUPLE)
values = list(ensemble_dict.values())

View File

@@ -12,13 +12,12 @@ In ``DelayTrainer``, the first step is only to save some necessary info to model
"""
import socket
import time
from typing import Callable, List
from qlib.data.dataset import Dataset
from qlib.log import get_module_logger
from qlib.model.base import Model
from qlib.utils import flatten_dict, get_cls_kwargs, init_instance_by_config
from qlib.utils import flatten_dict, get_callable_kwargs, init_instance_by_config
from qlib.workflow import R
from qlib.workflow.record_temp import SignalRecord
from qlib.workflow.recorder import Recorder
@@ -72,7 +71,7 @@ def end_task_train(rec: Recorder, experiment_name: str) -> Recorder:
if isinstance(records, dict): # prevent only one dict
records = [records]
for record in records:
cls, kwargs = get_cls_kwargs(record, default_module="qlib.workflow.record_temp")
cls, kwargs = get_callable_kwargs(record, default_module="qlib.workflow.record_temp")
if cls is SignalRecord:
rconf = {"model": model, "dataset": dataset, "recorder": rec}
else:

View File

@@ -43,8 +43,9 @@ def get_redis_connection():
#################### Data ####################
def read_bin(file_path, start_index, end_index):
with open(file_path, "rb") as f:
def read_bin(file_path: Union[str, Path], start_index, end_index):
file_path = Path(file_path.expanduser().resolve())
with file_path.open("rb") as f:
# read start_index
ref_start_index = int(np.frombuffer(f.read(4), dtype="<f")[0])
si = max(ref_start_index, start_index)
@@ -189,9 +190,9 @@ def get_module_by_module_path(module_path: Union[str, ModuleType]):
return module
def get_cls_kwargs(config: Union[dict, str], default_module: Union[str, ModuleType] = None) -> (type, dict):
def get_callable_kwargs(config: Union[dict, str], default_module: Union[str, ModuleType] = None) -> (type, dict):
"""
extract class and kwargs from config info
extract class/func and kwargs from config info
Parameters
----------
@@ -206,22 +207,22 @@ def get_cls_kwargs(config: Union[dict, str], default_module: Union[str, ModuleTy
Returns
-------
(type, dict):
the class object and it's arguments.
the class/func object and it's arguments.
"""
if isinstance(config, dict):
module = get_module_by_module_path(config.get("module_path", default_module))
# raise AttributeError
klass = getattr(module, config["class"])
_callable = getattr(module, config["class" if "class" in config else "func"])
kwargs = config.get("kwargs", {})
elif isinstance(config, str):
module = get_module_by_module_path(default_module)
klass = getattr(module, config)
_callable = getattr(module, config)
kwargs = {}
else:
raise NotImplementedError(f"This type of input is not supported")
return klass, kwargs
return _callable, kwargs
def init_instance_by_config(
@@ -272,7 +273,7 @@ def init_instance_by_config(
with open(os.path.join(pr.netloc, pr.path), "rb") as f:
return pickle.load(f)
klass, cls_kwargs = get_cls_kwargs(config, default_module=default_module)
klass, cls_kwargs = get_callable_kwargs(config, default_module=default_module)
return klass(**cls_kwargs, **kwargs)
@@ -570,9 +571,11 @@ def get_pre_trading_date(trading_date, future=False):
def transform_end_date(end_date=None, freq="day"):
"""get previous trading date
"""handle the end date with various format
If end_date is -1, None, or end_date is greater than the maximum trading day, the last trading date is returned.
Otherwise, returns the end_date
----------
end_date: str
end trading date
@@ -738,7 +741,8 @@ def lazy_sort_index(df: pd.DataFrame, axis=0) -> pd.DataFrame:
sorted dataframe
"""
idx = df.index if axis == 0 else df.columns
if idx.is_monotonic_increasing:
# NOTE: MultiIndex.is_lexsorted() is a deprecated method in Pandas 1.3.0 and is suggested to be replaced by MultiIndex.is_monotonic_increasing (see discussion here: https://github.com/pandas-dev/pandas/issues/32259). However, in case older versions of Pandas is implemented, MultiIndex.is_lexsorted() is necessary to prevent certain fatal errors.
if idx.is_monotonic_increasing and not (isinstance(idx, pd.MultiIndex) and not idx.is_lexsorted()):
return df
else:
return df.sort_index(axis=axis)
@@ -792,7 +796,7 @@ class Wrapper:
return "{name}(provider={provider})".format(name=self.__class__.__name__, provider=self._provider)
def __getattr__(self, key):
if self._provider is None:
if self.__dict__.get("_provider", None) is None:
raise AttributeError("Please run qlib.init() first using qlib")
return getattr(self._provider, key)

View File

@@ -1,10 +1,9 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from pathlib import Path
import pickle
import typing
import dill
from pathlib import Path
from typing import Union
@@ -18,6 +17,7 @@ class Serializable:
pickle_backend = "pickle" # another optional value is "dill" which can pickle more things of python.
default_dump_all = False # if dump all things
FLAG_KEY = "_qlib_serial_flag"
def __init__(self):
self._dump_all = self.default_dump_all
@@ -45,8 +45,6 @@ class Serializable:
"""
return getattr(self, "_exclude", [])
FLAG_KEY = "_qlib_serial_flag"
def config(self, dump_all: bool = None, exclude: list = None, recursive=False):
"""
configure the serializable object
@@ -124,3 +122,22 @@ class Serializable:
return dill
else:
raise ValueError("Unknown pickle backend, please use 'pickle' or 'dill'.")
@staticmethod
def general_dump(obj, path: Union[Path, str]):
"""
A general dumping method for object
Parameters
----------
obj : object
the object to be dumped
path : Union[Path, str]
the target path the data will be dumped
"""
path = Path(path)
if isinstance(obj, Serializable):
obj.to_pickle(path)
else:
with path.open("wb") as f:
pickle.dump(obj, f)

View File

@@ -38,13 +38,13 @@ class QlibRecorder:
.. code-block:: Python
# start new experiment and recorder
with R.start('test', 'recorder_1'):
with R.start(experiment_name='test', recorder_name='recorder_1'):
model.fit(dataset)
R.log...
... # further operations
# resume previous experiment and recorder
with R.start('test', 'recorder_1', resume=True): # if users want to resume recorder, they have to specify the exact same name for experiment and recorder.
with R.start(experiment_name='test', recorder_name='recorder_1', resume=True): # if users want to resume recorder, they have to specify the exact same name for experiment and recorder.
... # further operations
Parameters

View File

@@ -53,7 +53,8 @@ def workflow(config_path, experiment_name="workflow", uri_folder="mlruns"):
exp_manager["kwargs"]["uri"] = "file:" + str(Path(os.getcwd()).resolve() / uri_folder)
qlib.init(**config.get("qlib_init"), exp_manager=exp_manager)
task_train(config.get("task"), experiment_name=experiment_name)
recorder = task_train(config.get("task"), experiment_name=experiment_name)
recorder.save_objects(config=config)
# function to run worklflow by config

View File

@@ -325,7 +325,7 @@ class MLflowExperiment(Experiment):
UNLIMITED = 50000 # FIXME: Mlflow can only list 50000 records at most!!!!!!!
def list_recorders(self, max_results: int = UNLIMITED, status: Union[str, None] = None):
def list_recorders(self, max_results: int = UNLIMITED, status: Union[str, None] = None, filter_string: str = ""):
"""
Parameters
----------
@@ -334,8 +334,12 @@ class MLflowExperiment(Experiment):
status : str
the criteria based on status to filter results.
`None` indicates no filtering.
filter_string : str
mlflow supported filter string like 'params."my_param"="a" and tags."my_tag"="b"', use this will help to reduce too much run number.
"""
runs = self._client.search_runs(self.id, run_view_type=ViewType.ACTIVE_ONLY, max_results=max_results)
runs = self._client.search_runs(
self.id, run_view_type=ViewType.ACTIVE_ONLY, max_results=max_results, filter_string=filter_string
)
recorders = dict()
for i in range(len(runs)):
recorder = MLflowRecorder(self.id, self._uri, mlflow_run=runs[i])

View File

@@ -1,7 +1,9 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from urllib.parse import urlparse
import mlflow
from filelock import FileLock
from mlflow.exceptions import MlflowException
from mlflow.entities import ViewType
import os, logging
@@ -191,6 +193,13 @@ class ExpManager:
if experiment_name is None:
experiment_name = self._default_exp_name
logger.warning(f"No valid experiment found. Create a new experiment with name {experiment_name}.")
# NOTE: mlflow doesn't consider the lock for recording multiple runs
# So we supported it in the interface wrapper
pr = urlparse(self.uri)
if pr.scheme == "file":
with FileLock(os.path.join(pr.netloc, pr.path, "filelock")) as f:
return self.create_exp(experiment_name), True
return self.create_exp(experiment_name), True
def _get_exp(self, experiment_id=None, experiment_name=None) -> Experiment:

View File

@@ -10,6 +10,7 @@ from typing import List, Tuple, Union
from qlib.data.data import D
from qlib.log import get_module_logger
from qlib.model.ens.group import RollingGroup
from qlib.utils import transform_end_date
from qlib.workflow.online.utils import OnlineTool, OnlineToolR
from qlib.workflow.recorder import Recorder
from qlib.workflow.task.collect import Collector, RecorderCollector
@@ -118,6 +119,7 @@ class RollingStrategy(OnlineStrategy):
task_template = [task_template]
self.task_template = task_template
self.rg = rolling_gen
assert issubclass(self.rg.__class__, RollingGen), "The rolling strategy relies on the feature if RollingGen"
self.tool = OnlineToolR(self.exp_name)
self.ta = TimeAdjuster()
@@ -174,28 +176,20 @@ class RollingStrategy(OnlineStrategy):
Returns:
List[dict]: a list of new tasks.
"""
# TODO: filter recorders by latest test segments is not a necessary
latest_records, max_test = self._list_latest(self.tool.online_models())
if max_test is None:
self.logger.warn(f"No latest online recorders, no new tasks.")
return []
calendar_latest = D.calendar(end_time=cur_time)[-1] if cur_time is None else cur_time
calendar_latest = transform_end_date(cur_time)
self.logger.info(
f"The interval between current time {calendar_latest} and last rolling test begin time {max_test[0]} is {self.ta.cal_interval(calendar_latest, max_test[0])}, the rolling step is {self.rg.step}"
)
if self.ta.cal_interval(calendar_latest, max_test[0]) >= self.rg.step:
old_tasks = []
tasks_tmp = []
for rec in latest_records:
task = rec.load_object("task")
old_tasks.append(deepcopy(task))
test_begin = task["dataset"]["kwargs"]["segments"]["test"][0]
# modify the test segment to generate new tasks
task["dataset"]["kwargs"]["segments"]["test"] = (test_begin, calendar_latest)
tasks_tmp.append(task)
new_tasks_tmp = task_generator(tasks_tmp, self.rg)
new_tasks = [task for task in new_tasks_tmp if task not in old_tasks]
return new_tasks
return []
res = []
for rec in latest_records:
task = rec.load_object("task")
res.extend(self.rg.gen_following_tasks(task, calendar_latest))
return res
def _list_latest(self, rec_list: List[Recorder]):
"""

View File

@@ -1,6 +1,5 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
Updater is a module to update artifacts such as predictions when the stock data is updating.
"""
@@ -10,11 +9,12 @@ from abc import ABCMeta, abstractmethod
import pandas as pd
from qlib import get_module_logger
from qlib.data import D
from qlib.data.dataset import DatasetH
from qlib.data.dataset import Dataset, DatasetH
from qlib.data.dataset.handler import DataHandlerLP
from qlib.model import Model
from qlib.utils import get_date_by_shift
from qlib.workflow.recorder import Recorder
from qlib.workflow.record_temp import SignalRecord
class RMDLoader:
@@ -72,12 +72,25 @@ class RecordUpdater(metaclass=ABCMeta):
...
class PredUpdater(RecordUpdater):
class DSBasedUpdater(RecordUpdater, metaclass=ABCMeta):
"""
Update the prediction in the Recorder
Dataset-Based Updater
- Provding updating feature for Updating data based on Qlib Dataset
Assumption
- Based on Qlib dataset
- The data to be updated is a multi-level index pd.DataFrame. For example label , prediction.
LABEL0
datetime instrument
2021-05-10 SH600000 0.006965
SH600004 0.003407
... ...
2021-05-28 SZ300498 0.015748
SZ300676 -0.001321
"""
def __init__(self, record: Recorder, to_date=None, hist_ref: int = 0, freq="day"):
def __init__(self, record: Recorder, to_date=None, hist_ref: int = 0, freq="day", fname="pred.pkl"):
"""
Init PredUpdater.
@@ -100,13 +113,27 @@ class PredUpdater(RecordUpdater):
self.to_date = to_date
self.hist_ref = hist_ref
self.freq = freq
self.fname = fname
self.rmdl = RMDLoader(rec=record)
latest_date = D.calendar(freq=freq)[-1]
if to_date == None:
to_date = D.calendar(freq=freq)[-1]
self.to_date = pd.Timestamp(to_date)
self.old_pred = record.load_object("pred.pkl")
self.last_end = self.old_pred.index.get_level_values("datetime").max()
to_date = latest_date
to_date = pd.Timestamp(to_date)
if to_date >= latest_date:
self.logger.warning(
f"The given `to_date`({to_date}) is later than `latest_date`({latest_date}). So `to_date` is clipped to `latest_date`."
)
to_date = latest_date
self.to_date = to_date
# FIXME: it will raise error when running routine with delay trainer
# should we use another prediction updater for delay trainer?
self.old_data: pd.DataFrame = record.load_object(fname)
# dropna is for being compatible to some data with future information(e.g. label)
# The recent label data should be updated together
self.last_end = self.old_data.dropna().index.get_level_values("datetime").max()
def prepare_data(self) -> DatasetH:
"""
@@ -125,7 +152,7 @@ class PredUpdater(RecordUpdater):
def update(self, dataset: DatasetH = None):
"""
Update the prediction in a recorder.
Update the data in a recorder.
Args:
DatasetH: the instance of DatasetH. None for reprepare.
@@ -137,7 +164,7 @@ class PredUpdater(RecordUpdater):
if self.last_end >= self.to_date:
self.logger.info(
f"The prediction in {self.record.info['id']} are latest ({self.last_end}). No need to update to {self.to_date}."
f"The data in {self.record.info['id']} are latest ({self.last_end}). No need to update to {self.to_date}."
)
return
@@ -146,14 +173,49 @@ class PredUpdater(RecordUpdater):
# For reusing the dataset
dataset = self.prepare_data()
self.record.save_objects(**{self.fname: self.get_update_data(dataset)})
@abstractmethod
def get_update_data(self, dataset: Dataset) -> pd.DataFrame:
"""
return the updated data based on the given dataset
The difference between `get_update_data` and `update`
- `update_date` only include some data specific feature
- `update` include some general routine steps(e.g. prepare dataset, checking)
"""
...
class PredUpdater(DSBasedUpdater):
"""
Update the prediction in the Recorder
"""
def get_update_data(self, dataset: Dataset) -> pd.DataFrame:
# Load model
model = self.rmdl.get_model()
new_pred: pd.Series = model.predict(dataset)
cb_pred = pd.concat([self.old_pred, new_pred.to_frame("score")], axis=0)
cb_pred = pd.concat([self.old_data, new_pred.to_frame("score")], axis=0)
cb_pred = cb_pred.sort_index()
self.record.save_objects(**{"pred.pkl": cb_pred})
self.logger.info(f"Finish updating new {new_pred.shape[0]} predictions in {self.record.info['id']}.")
return cb_pred
class LabelUpdater(DSBasedUpdater):
"""
Update the label in the recorder
Assumption
- The label is generated from record_temp.SignalRecord.
"""
def __init__(self, record: Recorder, to_date=None, **kwargs):
super().__init__(record, to_date=to_date, fname="label.pkl", **kwargs)
def get_update_data(self, dataset: Dataset) -> pd.DataFrame:
new_label = SignalRecord.generate_label(dataset)
cb_data = pd.concat([self.old_data, new_label], axis=0)
cb_data = cb_data[~cb_data.index.duplicated(keep="last")].sort_index()
return cb_data

View File

@@ -11,7 +11,7 @@ from typing import List, Union
from qlib.data.dataset import TSDatasetH
from qlib.log import get_module_logger
from qlib.utils import get_cls_kwargs
from qlib.utils import get_callable_kwargs
from qlib.utils.exceptions import LoadObjectError
from qlib.workflow.online.update import PredUpdater
from qlib.workflow.recorder import Recorder
@@ -172,7 +172,7 @@ class OnlineToolR(OnlineTool):
hist_ref = 0
task = rec.load_object("task")
# Special treatment of historical dependencies
cls, kwargs = get_cls_kwargs(task["dataset"], default_module="qlib.data.dataset")
cls, kwargs = get_callable_kwargs(task["dataset"], default_module="qlib.data.dataset")
if issubclass(cls, TSDatasetH):
hist_ref = kwargs.get("step_len", TSDatasetH.DEFAULT_STEP_LEN)
try:

View File

@@ -121,6 +121,30 @@ class SignalRecord(RecordTemp):
self.model = model
self.dataset = dataset
@staticmethod
def generate_label(dataset):
# NOTE:
# Python doesn't provide the downcasting mechanism.
# We use the trick here to downcast the class
orig_cls = dataset.__class__
dataset.__class__ = DatasetH
params = dict(segments="test", col_set="label", data_key=DataHandlerLP.DK_R)
try:
# Assume the backend handler is DataHandlerLP
raw_label = dataset.prepare(**params)
except TypeError:
# The argument number is not right
del params["data_key"]
# The backend handler should be DataHandler
raw_label = dataset.prepare(**params)
except AttributeError:
# The data handler is initialize with `drop_raw=True`...
# So raw_label is not available
raw_label = None
dataset.__class__ = orig_cls
return raw_label
def generate(self, **kwargs):
# generate prediciton
pred = self.model.predict(self.dataset)
@@ -136,28 +160,8 @@ class SignalRecord(RecordTemp):
pprint(pred.head(5))
if isinstance(self.dataset, DatasetH):
# NOTE:
# Python doesn't provide the downcasting mechanism.
# We use the trick here to downcast the class
orig_cls = self.dataset.__class__
self.dataset.__class__ = DatasetH
params = dict(segments="test", col_set="label", data_key=DataHandlerLP.DK_R)
try:
# Assume the backend handler is DataHandlerLP
raw_label = self.dataset.prepare(**params)
except TypeError:
# The argument number is not right
del params["data_key"]
# The backend handler should be DataHandler
raw_label = self.dataset.prepare(**params)
except AttributeError:
# The data handler is initialize with `drop_raw=True`...
# So raw_label is not available
raw_label = None
raw_label = self.generate_label(self.dataset)
self.recorder.save_objects(**{"label.pkl": raw_label})
self.dataset.__class__ = orig_cls
def list(self):
return ["pred.pkl", "label.pkl"]

View File

@@ -1,6 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from qlib.utils.serial import Serializable
import mlflow, logging
import shutil, os, pickle, tempfile, codecs, pickle
from pathlib import Path
@@ -299,12 +300,16 @@ class MLflowRecorder(Recorder):
def save_objects(self, local_path=None, artifact_path=None, **kwargs):
assert self.uri is not None, "Please start the experiment and recorder first before using recorder directly."
if local_path is not None:
self.client.log_artifacts(self.id, local_path, artifact_path)
path = Path(local_path)
if path.is_dir():
self.client.log_artifacts(self.id, local_path, artifact_path)
else:
self.client.log_artifact(self.id, local_path, artifact_path)
else:
temp_dir = Path(tempfile.mkdtemp()).resolve()
for name, data in kwargs.items():
with (temp_dir / name).open("wb") as f:
pickle.dump(data, f)
path = temp_dir / name
Serializable.general_dump(data, path)
self.client.log_artifact(self.id, temp_dir / name, artifact_path)
shutil.rmtree(temp_dir)

View File

@@ -139,6 +139,7 @@ class RecorderCollector(Collector):
rec_filter_func=None,
artifacts_path={"pred": "pred.pkl"},
artifacts_key=None,
list_kwargs={},
):
"""
Init RecorderCollector.
@@ -150,6 +151,7 @@ class RecorderCollector(Collector):
rec_filter_func (Callable, optional): filter the recorder by return True or False. Defaults to None.
artifacts_path (dict, optional): The artifacts name and its path in Recorder. Defaults to {"pred": "pred.pkl", "IC": "sig_analysis/ic.pkl"}.
artifacts_key (str or List, optional): the artifacts key you want to get. If None, get all artifacts.
list_kwargs (str): arguments for list_recorders function.
"""
super().__init__(process_list=process_list)
if isinstance(experiment, str):
@@ -163,6 +165,7 @@ class RecorderCollector(Collector):
self.rec_key_func = rec_key_func
self.artifacts_key = artifacts_key
self.rec_filter_func = rec_filter_func
self.list_kwargs = list_kwargs
def collect(self, artifacts_key=None, rec_filter_func=None, only_exist=True) -> dict:
"""
@@ -187,7 +190,7 @@ class RecorderCollector(Collector):
collect_dict = {}
# filter records
recs = self.experiment.list_recorders()
recs = self.experiment.list_recorders(**self.list_kwargs)
recs_flt = {}
for rid, rec in recs.items():
if rec_filter_func is None or rec_filter_func(rec):

View File

@@ -5,6 +5,7 @@ TaskGenerator module can generate many tasks based on TaskGen and some task temp
"""
import abc
import copy
import pandas as pd
from typing import List, Union, Callable
from qlib.utils import transform_end_date
@@ -139,6 +140,53 @@ class RollingGen(TaskGen):
self.test_key = "test"
self.train_key = "train"
def _update_task_segs(self, task, segs):
# update segments of this task
task["dataset"]["kwargs"]["segments"] = copy.deepcopy(segs)
if self.ds_extra_mod_func is not None:
self.ds_extra_mod_func(task, self)
def gen_following_tasks(self, task: dict, test_end: pd.Timestamp) -> List[dict]:
"""
generating following rolling tasks for `task` until test_end
Parameters
----------
task : dict
Qlib task format
test_end : pd.Timestamp
the latest rolling task includes `test_end`
Returns
-------
List[dict]:
the following tasks of `task`(`task` itself is excluded)
"""
prev_seg = task["dataset"]["kwargs"]["segments"]
while True:
segments = {}
try:
for k, seg in prev_seg.items():
# decide how to shift
# expanding only for train data, the segments size of test data and valid data won't change
if k == self.train_key and self.rtype == self.ROLL_EX:
rtype = self.ta.SHIFT_EX
else:
rtype = self.ta.SHIFT_SD
# shift the segments data
segments[k] = self.ta.shift(seg, step=self.step, rtype=rtype)
if segments[self.test_key][0] > test_end:
break
except KeyError:
# We reach the end of tasks
# No more rolling
break
prev_seg = segments
t = copy.deepcopy(task) # deepcopy is necessary to avoid modify task inplace
self._update_task_segs(t, segments)
yield t
def generate(self, task: dict) -> List[dict]:
"""
Converting the task into a rolling task.
@@ -191,43 +239,23 @@ class RollingGen(TaskGen):
"""
res = []
prev_seg = None
test_end = None
while True:
t = copy.deepcopy(task)
t = copy.deepcopy(task)
# calculate segments
if prev_seg is None:
# First rolling
# 1) prepare the end point
segments: dict = copy.deepcopy(self.ta.align_seg(t["dataset"]["kwargs"]["segments"]))
test_end = transform_end_date(segments[self.test_key][1])
# 2) and init test segments
test_start_idx = self.ta.align_idx(segments[self.test_key][0])
segments[self.test_key] = (self.ta.get(test_start_idx), self.ta.get(test_start_idx + self.step - 1))
else:
segments = {}
try:
for k, seg in prev_seg.items():
# decide how to shift
# expanding only for train data, the segments size of test data and valid data won't change
if k == self.train_key and self.rtype == self.ROLL_EX:
rtype = self.ta.SHIFT_EX
else:
rtype = self.ta.SHIFT_SD
# shift the segments data
segments[k] = self.ta.shift(seg, step=self.step, rtype=rtype)
if segments[self.test_key][0] > test_end:
break
except KeyError:
# We reach the end of tasks
# No more rolling
break
# calculate segments
# update segments of this task
t["dataset"]["kwargs"]["segments"] = copy.deepcopy(segments)
prev_seg = segments
if self.ds_extra_mod_func is not None:
self.ds_extra_mod_func(t, self)
res.append(t)
# First rolling
# 1) prepare the end point
segments: dict = copy.deepcopy(self.ta.align_seg(t["dataset"]["kwargs"]["segments"]))
test_end = transform_end_date(segments[self.test_key][1])
# 2) and init test segments
test_start_idx = self.ta.align_idx(segments[self.test_key][0])
segments[self.test_key] = (self.ta.get(test_start_idx), self.ta.get(test_start_idx + self.step - 1))
# update segments of this task
self._update_task_segs(t, segments)
res.append(t)
# Update the following rolling
res.extend(self.gen_following_tasks(t, test_end))
return res

View File

@@ -47,6 +47,14 @@ class TaskManager:
The tasks manager assumes that you will only update the tasks you fetched.
The mongo fetch one and update will make it date updating secure.
This class can be used as a tool from commandline. Here are serveral examples
.. code-block:: shell
python -m qlib.workflow.task.manage -t <pool_name> wait
python -m qlib.workflow.task.manage -t <pool_name> task_stat
.. note::
Assumption: the data in MongoDB was encoded and the data out of MongoDB was decoded
@@ -80,7 +88,7 @@ class TaskManager:
task_pool: str
the name of Collection in MongoDB
"""
self.task_pool = getattr(get_mongodb(), task_pool)
self.task_pool: pymongo.collection.Collection = getattr(get_mongodb(), task_pool)
self.logger = get_module_logger(self.__class__.__name__)
@staticmethod
@@ -101,6 +109,20 @@ class TaskManager:
return task
def _decode_task(self, task):
"""
_decode_task is Serialization tool.
Mongodb needs JSON, so it needs to convert Python objects into JSON objects through pickle
Parameters
----------
task : dict
task information
Returns
-------
dict
JSON required by mongodb
"""
for prefix in self.ENCODE_FIELDS_PREFIX:
for k in list(task.keys()):
if k.startswith(prefix):
@@ -211,6 +233,7 @@ class TaskManager:
r = self.task_pool.find_one({"filter": t})
except InvalidDocument:
r = self.task_pool.find_one({"filter": self._dict_to_str(t)})
# When r is none, it indicates that r s a new task
if r is None:
new_tasks.append(t)
if not dry_run:
@@ -461,11 +484,11 @@ def run_task(
After running this method, here are 4 situations (before_status -> after_status):
STATUS_WAITING -> STATUS_DONE: use task["def"] as `task_func` param
STATUS_WAITING -> STATUS_DONE: use task["def"] as `task_func` param, it means that the task has not been started
STATUS_WAITING -> STATUS_PART_DONE: use task["def"] as `task_func` param
STATUS_PART_DONE -> STATUS_PART_DONE: use task["res"] as `task_func` param
STATUS_PART_DONE -> STATUS_PART_DONE: use task["res"] as `task_func` param, it means that the task has been started but not completed
STATUS_PART_DONE -> STATUS_DONE: use task["res"] as `task_func` param

View File

@@ -1,10 +1,14 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import sys, traceback, signal, atexit, logging
import atexit
import logging
import sys
import traceback
from ..log import get_module_logger
from . import R
from .recorder import Recorder
from ..log import get_module_logger
logger = get_module_logger("workflow", logging.INFO)

View File

@@ -1,7 +1,7 @@
- [Download Qlib Data](#Download-Qlib-Data)
- [Download CN Data](#Download-CN-Data)
- [Downlaod US Data](#Downlaod-US-Data)
- [Download US Data](#Download-US-Data)
- [Download CN Simple Data](#Download-CN-Simple-Data)
- [Help](#Help)
- [Using in Qlib](#Using-in-Qlib)

View File

@@ -78,6 +78,7 @@ def future_calendar_collector(qlib_dir: [str, Path], freq: str = "day"):
data_list.append(_row_data[0])
data_list = sorted(data_list)
date_list = generate_qlib_calendar(data_list, freq=freq)
date_list = sorted(set(daily_calendar.loc[:, 0].values.tolist() + date_list))
write_calendar_to_qlib(qlib_dir, date_list, freq=freq)
bs.logout()
logger.info(f"get trading dates success: {start_year}-01-01 to {end_year}-12-31")

View File

@@ -32,6 +32,7 @@ CALENDAR_BENCH_URL_MAP = {
"ALL": CALENDAR_URL_BASE.format(market=1, bench_code="000905"),
# NOTE: Use the time series of ^GSPC(SP500) as the sequence of all stocks
"US_ALL": "^GSPC",
"IN_ALL": "^NSEI",
}
@@ -39,6 +40,7 @@ _BENCH_CALENDAR_LIST = None
_ALL_CALENDAR_LIST = None
_HS_SYMBOLS = None
_US_SYMBOLS = None
_IN_SYMBOLS = None
_EN_FUND_SYMBOLS = None
_CALENDAR_MAP = {}
@@ -67,7 +69,7 @@ def get_calendar_list(bench_code="CSI300") -> List[pd.Timestamp]:
calendar = _CALENDAR_MAP.get(bench_code, None)
if calendar is None:
if bench_code.startswith("US_"):
if bench_code.startswith("US_") or bench_code.startswith("IN_"):
df = Ticker(CALENDAR_BENCH_URL_MAP[bench_code]).history(interval="1d", period="max")
calendar = df.index.get_level_values(level="date").map(pd.Timestamp).unique().tolist()
else:
@@ -298,6 +300,47 @@ def get_us_stock_symbols(qlib_data_path: [str, Path] = None) -> list:
return _US_SYMBOLS
def get_in_stock_symbols(qlib_data_path: [str, Path] = None) -> list:
"""get IN stock symbols
Returns
-------
stock symbols
"""
global _IN_SYMBOLS
@deco_retry
def _get_nifty():
url = f"https://www1.nseindia.com/content/equities/EQUITY_L.csv"
df = pd.read_csv(url)
df = df.rename(columns={"SYMBOL": "Symbol"})
df["Symbol"] = df["Symbol"] + ".NS"
_symbols = df["Symbol"].dropna()
_symbols = _symbols.unique().tolist()
return _symbols
if _IN_SYMBOLS is None:
_all_symbols = _get_nifty()
if qlib_data_path is not None:
for _index in ["nifty"]:
ins_df = pd.read_csv(
Path(qlib_data_path).joinpath(f"instruments/{_index}.txt"),
sep="\t",
names=["symbol", "start_date", "end_date"],
)
_all_symbols += ins_df["symbol"].unique().tolist()
def _format(s_):
s_ = s_.replace(".", "-")
s_ = s_.strip("$")
s_ = s_.strip("*")
return s_
_IN_SYMBOLS = sorted(set(_all_symbols))
return _IN_SYMBOLS
def get_en_fund_symbols(qlib_data_path: [str, Path] = None) -> list:
"""get en fund symbols

View File

@@ -37,7 +37,7 @@ pip install -r requirements.txt
- user can append data to `v2`: [automatic update of daily frequency data](#automatic-update-of-daily-frequency-datafrom-yahoo-finance)
- **the [benchmarks](https://github.com/microsoft/qlib/tree/main/examples/benchmarks) for qlib use `v1`**, *due to the unstable access to historical data by YahooFinance, there are some differences between `v2` and `v1`*
- `interval`: `1d` or `1min`, by default `1d`
- `region`: `cn` or `us`, by default `cn`
- `region`: `cn` or `us` or `in`, by default `cn`
- `delete_old`: delete existing data from `target_dir`(*features, calendars, instruments, dataset_cache, features_cache*), value from [`True`, `False`], by default `True`
- `exists_skip`: traget_dir data already exists, skip `get_data`, value from [`True`, `False`], by default `False`
- examples:
@@ -50,6 +50,10 @@ pip install -r requirements.txt
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/qlib_us_1d --region us --interval 1d
# us 1min
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/qlib_us_1min --region us --interval 1min
# in 1d
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/qlib_in_1d --region in --interval 1d
# in 1min
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/qlib_in_1min --region in --interval 1min
```
### Collector *YahooFinance* data to qlib
@@ -60,7 +64,7 @@ pip install -r requirements.txt
- `source_dir`: save the directory
- `interval`: `1d` or `1min`, by default `1d`
> **due to the limitation of the *YahooFinance API*, only the last month's data is available in `1min`**
- `region`: `CN` or `US`, by default `CN`
- `region`: `CN` or `US` or `IN`, by default `CN`
- `delay`: `time.sleep(delay)`, by default *0.5*
- `start`: start datetime, by default *"2000-01-01"*; *closed interval(including start)*
- `end`: end datetime, by default `pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))`; *open interval(excluding end)*
@@ -71,13 +75,17 @@ pip install -r requirements.txt
- examples:
```bash
# cn 1d data
python collector.py download_data --source_dir ~/.qlib/stock_data/source/cn_1d --start 2020-01-01 --end 2020-12-31 --delay 1 --interval 1d --region US
python collector.py download_data --source_dir ~/.qlib/stock_data/source/cn_1d --start 2020-01-01 --end 2020-12-31 --delay 1 --interval 1d --region CN
# cn 1min data
python collector.py download_data --source_dir ~/.qlib/stock_data/source/cn_1min --delay 1 --interval 1min --region CN
# us 1d data
python collector.py download_data --source_dir ~/.qlib/stock_data/source/us_1d --start 2020-01-01 --end 2020-12-31 --delay 1 --interval 1d --region US
# us 1min data
python collector.py download_data --source_dir ~/.qlib/stock_data/source/us_1min --delay 1 --interval 1min --region US
# in 1d data
python collector.py download_data --source_dir ~/.qlib/stock_data/source/in_1d --start 2020-01-01 --end 2020-12-31 --delay 1 --interval 1d --region IN
# in 1min data
python collector.py download_data --source_dir ~/.qlib/stock_data/source/in_1min --delay 1 --interval 1min --region IN
```
2. normalize data: `python scripts/data_collector/yahoo/collector.py normalize_data`
@@ -87,7 +95,7 @@ pip install -r requirements.txt
- `max_workers`: number of concurrent, by default *1*
- `interval`: `1d` or `1min`, by default `1d`
> if **`interval == 1min`**, `qlib_data_1d_dir` cannot be `None`
- `region`: `CN` or `US`, by default `CN`
- `region`: `CN` or `US` or `IN`, by default `CN`
- `date_field_name`: column *name* identifying time in csv files, by default `date`
- `symbol_field_name`: column *name* identifying symbol in csv files, by default `symbol`
- `end_date`: if not `None`, normalize the last date saved (*including end_date*); if `None`, it will ignore this parameter; by default `None`

View File

@@ -34,6 +34,7 @@ from data_collector.utils import (
get_calendar_list,
get_hs_stock_symbols,
get_us_stock_symbols,
get_in_stock_symbols,
generate_minutes_calendar_from_daily,
)
@@ -279,10 +280,46 @@ class YahooCollectorUS1min(YahooCollectorUS):
pass
class YahooCollectorIN(YahooCollector, ABC):
def get_instrument_list(self):
logger.info("get INDIA stock symbols......")
symbols = get_in_stock_symbols()
logger.info(f"get {len(symbols)} symbols.")
return symbols
def download_index_data(self):
pass
def normalize_symbol(self, symbol):
return code_to_fname(symbol).upper()
@property
def _timezone(self):
return "Asia/Kolkata"
class YahooCollectorIN1d(YahooCollectorIN):
pass
class YahooCollectorIN1min(YahooCollectorIN):
pass
class YahooNormalize(BaseNormalize):
COLUMNS = ["open", "close", "high", "low", "volume"]
DAILY_FORMAT = "%Y-%m-%d"
@staticmethod
def calc_change(df: pd.DataFrame, last_close: float) -> pd.Series:
df = df.copy()
_tmp_series = df["close"].fillna(method="ffill")
_tmp_shift_series = _tmp_series.shift(1)
if last_close is not None:
_tmp_shift_series.iloc[0] = float(last_close)
change_series = _tmp_series / _tmp_shift_series - 1
return change_series
@staticmethod
def normalize_yahoo(
df: pd.DataFrame,
@@ -310,11 +347,29 @@ class YahooNormalize(BaseNormalize):
)
df.sort_index(inplace=True)
df.loc[(df["volume"] <= 0) | np.isnan(df["volume"]), set(df.columns) - {symbol_field_name}] = np.nan
_tmp_series = df["close"].fillna(method="ffill")
_tmp_shift_series = _tmp_series.shift(1)
if last_close is not None:
_tmp_shift_series.iloc[0] = float(last_close)
df["change"] = _tmp_series / _tmp_shift_series - 1
change_series = YahooNormalize.calc_change(df, last_close)
# NOTE: The data obtained by Yahoo finance sometimes has exceptions
# WARNING: If it is normal for a `symbol(exchange)` to differ by a factor of *89* to *111* for consecutive trading days,
# WARNING: the logic in the following line needs to be modified
_count = 0
while True:
# NOTE: may appear unusual for many days in a row
change_series = YahooNormalize.calc_change(df, last_close)
_mask = (change_series >= 89) & (change_series <= 111)
if not _mask.any():
break
_tmp_cols = ["high", "close", "low", "open", "adjclose"]
df.loc[_mask, _tmp_cols] = df.loc[_mask, _tmp_cols] / 100
_count += 1
if _count >= 10:
_symbol = df.loc[df[symbol_field_name].first_valid_index()]["symbol"]
logger.warning(
f"{_symbol} `change` is abnormal for {_count} consecutive days, please check the specific data file carefully"
)
df["change"] = YahooNormalize.calc_change(df, last_close)
columns += ["change"]
df.loc[(df["volume"] <= 0) | np.isnan(df["volume"]), columns] = np.nan
@@ -710,6 +765,29 @@ class YahooNormalizeUS1min(YahooNormalizeUS, YahooNormalize1minOffline):
return fname_to_code(symbol)
class YahooNormalizeIN:
def _get_calendar_list(self) -> Iterable[pd.Timestamp]:
return get_calendar_list("IN_ALL")
class YahooNormalizeIN1d(YahooNormalizeIN, YahooNormalize1d):
pass
class YahooNormalizeIN1min(YahooNormalizeIN, YahooNormalize1minOffline):
CALC_PAUSED_NUM = False
def _get_calendar_list(self) -> Iterable[pd.Timestamp]:
# TODO: support 1min
raise ValueError("Does not support 1min")
def _get_1d_calendar_list(self):
return get_calendar_list("IN_ALL")
def symbol_to_yahoo(self, symbol):
return fname_to_code(symbol)
class YahooNormalizeCN:
def _get_calendar_list(self) -> Iterable[pd.Timestamp]:
# TODO: from MSN
@@ -852,7 +930,7 @@ class Run(BaseRun):
if self.interval.lower() == "1min":
if qlib_data_1d_dir is None or not Path(qlib_data_1d_dir).expanduser().exists():
raise ValueError(
"If normalize 1min, the qlib_data_1d_dir parameter must be set: --qlib_data_1d_dir <user qlib 1d data >, Reference: https://github.com/zhupr/qlib/tree/support_extend_data/scripts/data_collector/yahoo#automatic-update-of-daily-frequency-datafrom-yahoo-finance"
"If normalize 1min, the qlib_data_1d_dir parameter must be set: --qlib_data_1d_dir <user qlib 1d data >, Reference: https://github.com/microsoft/qlib/tree/main/scripts/data_collector/yahoo#automatic-update-of-daily-frequency-datafrom-yahoo-finance"
)
super(Run, self).normalize_data(
date_field_name, symbol_field_name, end_date=end_date, qlib_data_1d_dir=qlib_data_1d_dir

View File

@@ -244,6 +244,10 @@ class DumpDataBase:
if df is None or df.empty:
logger.warning(f"{code} data is None or empty")
return
# try to remove dup rows or it will cause exception when reindex.
df = df.drop_duplicates(self.date_field_name)
# features save dir
features_dir = self._features_dir.joinpath(code_to_fname(code).lower())
features_dir.mkdir(parents=True, exist_ok=True)

View File

@@ -11,7 +11,14 @@ NAME = "pyqlib"
DESCRIPTION = "A Quantitative-research Platform"
REQUIRES_PYTHON = ">=3.5.0"
VERSION = "0.7.0"
from pathlib import Path
from shutil import copyfile
CURRENT_DIR = Path(__file__).absolute().parent
_version_src = CURRENT_DIR / "VERSION.txt"
_version_dst = CURRENT_DIR / "qlib" / "VERSION.txt"
copyfile(_version_src, _version_dst)
VERSION = _version_dst.read_text(encoding="utf-8").strip()
# Detect Cython
try:
@@ -39,13 +46,13 @@ REQUIRED = [
"redis>=3.0.1",
"python-redis-lock>=3.3.1",
"schedule>=0.6.0",
"cvxpy==1.0.21",
"cvxpy>=1.0.21",
"hyperopt==0.1.1",
"fire>=0.3.1",
"statsmodels",
"xlrd>=1.0.0",
"plotly==4.12.0",
"matplotlib==3.1.3",
"matplotlib>=3.3",
"tables>=3.6.1",
"pyyaml>=5.3.1",
"mlflow>=1.12.1",
@@ -58,6 +65,7 @@ REQUIRED = [
"pymongo==3.7.2", # For task management
"scikit-learn>=0.22",
"dill",
"filelock",
]
# Numpy include
@@ -121,5 +129,6 @@ setup(
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
],
)

View File

@@ -0,0 +1,117 @@
import copy
import unittest
import fire
import pandas as pd
import qlib
from qlib.config import REG_CN
from qlib.data import D
from qlib.model.trainer import task_train
from qlib.tests import TestAutoData
from qlib.tests.config import CSI300_GBDT_TASK
from qlib.workflow.online.utils import OnlineToolR
from qlib.workflow.online.update import LabelUpdater
class TestRolling(TestAutoData):
_setup_kwargs = dict(expression_cache=None, dataset_cache=None)
def test_update_pred(self):
"""
This test is for testing if it will raise error if the `to_date` is out of the boundary.
"""
task = copy.deepcopy(CSI300_GBDT_TASK)
task["record"] = {
"class": "SignalRecord",
"module_path": "qlib.workflow.record_temp",
}
exp_name = "online_srv_test"
cal = D.calendar()
latest_date = cal[-1]
train_start = latest_date - pd.Timedelta(days=61)
train_end = latest_date - pd.Timedelta(days=41)
task["dataset"]["kwargs"]["segments"] = {
"train": (train_start, train_end),
"valid": (latest_date - pd.Timedelta(days=40), latest_date - pd.Timedelta(days=21)),
"test": (latest_date - pd.Timedelta(days=20), latest_date),
}
task["dataset"]["kwargs"]["handler"]["kwargs"] = {
"start_time": train_start,
"end_time": latest_date,
"fit_start_time": train_start,
"fit_end_time": train_end,
"instruments": "csi300",
}
rec = task_train(task, exp_name)
pred = rec.load_object("pred.pkl")
online_tool = OnlineToolR(exp_name)
online_tool.reset_online_tag(rec) # set to online model
online_tool.update_online_pred(to_date=latest_date + pd.Timedelta(days=10))
def test_update_label(self):
task = copy.deepcopy(CSI300_GBDT_TASK)
task["record"] = {
"class": "SignalRecord",
"module_path": "qlib.workflow.record_temp",
}
exp_name = "online_srv_test"
cal = D.calendar()
shift = 10
latest_date = cal[-1 - shift]
train_start = latest_date - pd.Timedelta(days=61)
train_end = latest_date - pd.Timedelta(days=41)
task["dataset"]["kwargs"]["segments"] = {
"train": (train_start, train_end),
"valid": (latest_date - pd.Timedelta(days=40), latest_date - pd.Timedelta(days=21)),
"test": (latest_date - pd.Timedelta(days=20), latest_date),
}
task["dataset"]["kwargs"]["handler"]["kwargs"] = {
"start_time": train_start,
"end_time": latest_date,
"fit_start_time": train_start,
"fit_end_time": train_end,
"instruments": "csi300",
}
rec = task_train(task, exp_name)
pred = rec.load_object("pred.pkl")
online_tool = OnlineToolR(exp_name)
online_tool.reset_online_tag(rec) # set to online model
online_tool.update_online_pred()
new_pred = rec.load_object("pred.pkl")
label = rec.load_object("label.pkl")
label_date = label.dropna().index.get_level_values("datetime").max()
pred_date = new_pred.dropna().index.get_level_values("datetime").max()
# The prediction is updated, but the label is not updated.
self.assertTrue(label_date < pred_date)
# Update label now
lu = LabelUpdater(rec)
lu.update()
new_label = rec.load_object("label.pkl")
new_label_date = new_label.index.get_level_values("datetime").max()
self.assertTrue(new_label_date == pred_date) # make sure the label is updated now
if __name__ == "__main__":
unittest.main()

View File

@@ -5,7 +5,6 @@
from pathlib import Path
from collections.abc import Iterable
import pytest
import numpy as np
from qlib.tests import TestAutoData
@@ -33,13 +32,13 @@ class TestStorage(TestAutoData):
print(f"calendar[-1]: {calendar[-1]}")
calendar = CalendarStorage(freq="1min", future=False, provider_uri="not_found")
with pytest.raises(ValueError):
with self.assertRaises(ValueError):
print(calendar.data)
with pytest.raises(ValueError):
with self.assertRaises(ValueError):
print(calendar[:])
with pytest.raises(ValueError):
with self.assertRaises(ValueError):
print(calendar[0])
def test_instrument_storage(self):
@@ -90,10 +89,10 @@ class TestStorage(TestAutoData):
print(f"instrument['SH600000']: {instrument['SH600000']}")
instrument = InstrumentStorage(market="csi300", provider_uri="not_found")
with pytest.raises(ValueError):
with self.assertRaises(ValueError):
print(instrument.data)
with pytest.raises(ValueError):
with self.assertRaises(ValueError):
print(instrument["sSH600000"])
def test_feature_storage(self):
@@ -150,15 +149,15 @@ class TestStorage(TestAutoData):
"""
feature = FeatureStorage(instrument="SH600004", field="close", freq="day", provider_uri=self.provider_uri)
feature = FeatureStorage(instrument="SZ300677", field="close", freq="day", provider_uri=self.provider_uri)
with pytest.raises(IndexError):
with self.assertRaises(IndexError):
print(feature[0])
assert isinstance(
feature[815][1], (float, np.float32)
feature[3049][1], (float, np.float32)
), f"{feature.__class__.__name__}.__getitem__(i: int) error"
assert len(feature[815:818]) == 3, f"{feature.__class__.__name__}.__getitem__(s: slice) error"
print(f"feature[815: 818]: \n{feature[815: 818]}")
assert len(feature[3049:3052]) == 3, f"{feature.__class__.__name__}.__getitem__(s: slice) error"
print(f"feature[3049: 3052]: \n{feature[3049: 3052]}")
print(f"feature[:].tail(): \n{feature[:].tail()}")

View File

@@ -36,7 +36,7 @@ port_analysis_config = {
}
def train():
def train(uri_path: str = None):
"""train model
Returns
@@ -55,7 +55,7 @@ def train():
print(R)
# start exp
with R.start(experiment_name="workflow"):
with R.start(experiment_name="workflow", uri=uri_path):
R.log_params(**flatten_dict(CSI300_GBDT_TASK))
model.fit(dataset)
@@ -79,7 +79,7 @@ def train():
return pred_score, {"ic": ic, "ric": ric}, rid
def train_with_sigana():
def train_with_sigana(uri_path: str = None):
"""train model followed by SigAnaRecord
Returns
@@ -91,9 +91,8 @@ def train_with_sigana():
"""
model = init_instance_by_config(CSI300_GBDT_TASK["model"])
dataset = init_instance_by_config(CSI300_GBDT_TASK["dataset"])
# start exp
with R.start(experiment_name="workflow_with_sigana"):
with R.start(experiment_name="workflow_with_sigana", uri=uri_path):
R.log_params(**flatten_dict(CSI300_GBDT_TASK))
model.fit(dataset)
@@ -130,7 +129,7 @@ def fake_experiment():
return default_uri == default_uri_to_check, current_uri == current_uri_to_check, current_uri
def backtest_analysis(pred, rid):
def backtest_analysis(pred, rid, uri_path: str = None):
"""backtest and analysis
Parameters
@@ -139,6 +138,8 @@ def backtest_analysis(pred, rid):
predict scores
rid : str
the id of the recorder to be used in this function
uri_path: str
mlflow uri path
Returns
-------
@@ -146,7 +147,8 @@ def backtest_analysis(pred, rid):
the analysis result
"""
recorder = R.get_recorder(experiment_name="workflow", recorder_id=rid)
with R.start(experiment_name="workflow", recorder_id=rid, uri=uri_path):
recorder = R.get_recorder(experiment_name="workflow", recorder_id=rid)
# backtest
par = PortAnaRecord(recorder, port_analysis_config)
par.generate()
@@ -160,24 +162,24 @@ class TestAllFlow(TestAutoData):
REPORT_NORMAL = None
POSITIONS = None
RID = None
URI_PATH = "file:" + str(Path(__file__).parent.joinpath("test_all_flow_mlruns").resolve())
@classmethod
def tearDownClass(cls) -> None:
shutil.rmtree(str(Path(C["exp_manager"]["kwargs"]["uri"].strip("file:")).resolve()))
shutil.rmtree(cls.URI_PATH.lstrip("file:"))
def test_0_train_with_sigana(self):
TestAllFlow.PRED_SCORE, ic_ric, uri_path = train_with_sigana()
TestAllFlow.PRED_SCORE, ic_ric, uri_path = train_with_sigana(self.URI_PATH)
self.assertGreaterEqual(ic_ric["ic"].all(), 0, "train failed")
self.assertGreaterEqual(ic_ric["ric"].all(), 0, "train failed")
shutil.rmtree(str(Path(uri_path.strip("file:")).resolve()))
def test_1_train(self):
TestAllFlow.PRED_SCORE, ic_ric, TestAllFlow.RID = train()
TestAllFlow.PRED_SCORE, ic_ric, TestAllFlow.RID = train(self.URI_PATH)
self.assertGreaterEqual(ic_ric["ic"].all(), 0, "train failed")
self.assertGreaterEqual(ic_ric["ric"].all(), 0, "train failed")
def test_2_backtest(self):
analyze_df = backtest_analysis(TestAllFlow.PRED_SCORE, TestAllFlow.RID)
analyze_df = backtest_analysis(TestAllFlow.PRED_SCORE, TestAllFlow.RID, self.URI_PATH)
self.assertGreaterEqual(
analyze_df.loc(axis=0)["excess_return_with_cost", "annualized_return"].values[0],
0.10,

View File

@@ -12,10 +12,10 @@ from qlib.tests import TestAutoData
from qlib.tests.config import CSI300_GBDT_TASK
def train_multiseg():
def train_multiseg(uri_path: str = None):
model = init_instance_by_config(CSI300_GBDT_TASK["model"])
dataset = init_instance_by_config(CSI300_GBDT_TASK["dataset"])
with R.start(experiment_name="workflow"):
with R.start(experiment_name="workflow", uri=uri_path):
R.log_params(**flatten_dict(CSI300_GBDT_TASK))
model.fit(dataset)
recorder = R.get_recorder()
@@ -25,10 +25,10 @@ def train_multiseg():
return uri
def train_mse():
def train_mse(uri_path: str = None):
model = init_instance_by_config(CSI300_GBDT_TASK["model"])
dataset = init_instance_by_config(CSI300_GBDT_TASK["dataset"])
with R.start(experiment_name="workflow"):
with R.start(experiment_name="workflow", uri=uri_path):
R.log_params(**flatten_dict(CSI300_GBDT_TASK))
model.fit(dataset)
recorder = R.get_recorder()
@@ -39,13 +39,17 @@ def train_mse():
class TestAllFlow(TestAutoData):
URI_PATH = "file:" + str(Path(__file__).parent.joinpath("test_contrib_mlruns").resolve())
@classmethod
def tearDownClass(cls) -> None:
shutil.rmtree(cls.URI_PATH.lstrip("file:"))
def test_0_multiseg(self):
uri_path = train_multiseg()
shutil.rmtree(str(Path(uri_path.strip("file:")).resolve()))
uri_path = train_multiseg(self.URI_PATH)
def test_1_mse(self):
uri_path = train_mse()
shutil.rmtree(str(Path(uri_path.strip("file:")).resolve()))
uri_path = train_mse(self.URI_PATH)
def suite():