mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 14:01:28 +08:00
Compare commits
199 Commits
qlib_monit
...
v0.7.2
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
92055d64ec | ||
|
|
b9809a4c33 | ||
|
|
fc243fd29b | ||
|
|
b6a8bd5b80 | ||
|
|
6ee0fe366c | ||
|
|
55b6ff123e | ||
|
|
45ea4bae4e | ||
|
|
17d472cf01 | ||
|
|
c500a01226 | ||
|
|
114c38b4c3 | ||
|
|
414c3082c0 | ||
|
|
3fc2f8c93c | ||
|
|
66ff3e5bf6 | ||
|
|
8ff68a182e | ||
|
|
a105ef1d76 | ||
|
|
d02965ea70 | ||
|
|
b8d1e08010 | ||
|
|
51709c20d8 | ||
|
|
28c99c77be | ||
|
|
bb5cdfe050 | ||
|
|
fb21c591bb | ||
|
|
5279e71423 | ||
|
|
f35254c288 | ||
|
|
5e82c18cb2 | ||
|
|
2759e8c28d | ||
|
|
2461575d30 | ||
|
|
867667531d | ||
|
|
0fc52333b7 | ||
|
|
ab9b6dc47a | ||
|
|
4c5a4d5cd7 | ||
|
|
e84cc23589 | ||
|
|
707399a245 | ||
|
|
6e88ccca88 | ||
|
|
ee5f3de800 | ||
|
|
3605cd7b96 | ||
|
|
d1cbf4c3d9 | ||
|
|
6011a21308 | ||
|
|
76a05f37a9 | ||
|
|
c99494eb76 | ||
|
|
e8126b0c39 | ||
|
|
8f4d320832 | ||
|
|
e2739ac72c | ||
|
|
19d15ddc38 | ||
|
|
12af8f304b | ||
|
|
25b771ddf1 | ||
|
|
1158472489 | ||
|
|
84d2cb3226 | ||
|
|
509bfcb02e | ||
|
|
6608a40965 | ||
|
|
3e75cead93 | ||
|
|
6697f209d4 | ||
|
|
e3b57b1901 | ||
|
|
82a5223166 | ||
|
|
398131cff7 | ||
|
|
e71e2f941c | ||
|
|
0483406c12 | ||
|
|
da1f4db968 | ||
|
|
a7c41b6969 | ||
|
|
5b7b48e376 | ||
|
|
4f9f978909 | ||
|
|
319a2f38cc | ||
|
|
a2c38c979e | ||
|
|
07655f2d5b | ||
|
|
9303415666 | ||
|
|
05d28469ad | ||
|
|
dc6859bdd9 | ||
|
|
a6f9dde006 | ||
|
|
1d22ee56d3 | ||
|
|
3810a4cd33 | ||
|
|
48af7126b6 | ||
|
|
025b1dcff9 | ||
|
|
29e66b2dea | ||
|
|
698e59ac72 | ||
|
|
e006ef40ad | ||
|
|
59d4bc9394 | ||
|
|
b07e0bffb1 | ||
|
|
161343018f | ||
|
|
bee031af68 | ||
|
|
35840606a8 | ||
|
|
2df9b6e076 | ||
|
|
0c3eaf3f16 | ||
|
|
2eee064eb8 | ||
|
|
096ef5a62b | ||
|
|
dd0eebed53 | ||
|
|
7b20abeda1 | ||
|
|
5519420efd | ||
|
|
eb3c5b3088 | ||
|
|
f03df874bf | ||
|
|
8fa22bd2e1 | ||
|
|
d1c8d885aa | ||
|
|
bf7732e284 | ||
|
|
3f5334ab39 | ||
|
|
c97a96363d | ||
|
|
2023f714c9 | ||
|
|
f8a2b0533b | ||
|
|
3183a232df | ||
|
|
8b715268bd | ||
|
|
28cb827a23 | ||
|
|
b723f14619 | ||
|
|
47535ba530 | ||
|
|
d70e5a4f88 | ||
|
|
3b8087677c | ||
|
|
4ec41ea0e7 | ||
|
|
cfcd9fb1f8 | ||
|
|
457dcaa466 | ||
|
|
3c740fc2de | ||
|
|
6d91f28474 | ||
|
|
be8653c505 | ||
|
|
a8974ce535 | ||
|
|
79026e5390 | ||
|
|
4610e16ac2 | ||
|
|
b504cc6ac8 | ||
|
|
d5059e609f | ||
|
|
215f7e0d22 | ||
|
|
dafef0ac08 | ||
|
|
1cb43ea69b | ||
|
|
7ca9cf79f7 | ||
|
|
35f090a6e4 | ||
|
|
ace7484304 | ||
|
|
2d4f0e80f9 | ||
|
|
946c9392a1 | ||
|
|
b523b27d5a | ||
|
|
0b83fb3564 | ||
|
|
d96f7a67c6 | ||
|
|
a7862387a2 | ||
|
|
c4c438249c | ||
|
|
8709dde65b | ||
|
|
d66733c358 | ||
|
|
9cf574b697 | ||
|
|
107e40f3ee | ||
|
|
4837ba8db3 | ||
|
|
2ab4a9adb3 | ||
|
|
8d0b673341 | ||
|
|
8ebdb1e873 | ||
|
|
39340fbf06 | ||
|
|
0e277723a3 | ||
|
|
1418417034 | ||
|
|
b261f7b501 | ||
|
|
bab50e8837 | ||
|
|
0eee4a0f2e | ||
|
|
21eb71d4a9 | ||
|
|
46714adf4c | ||
|
|
99fb49650a | ||
|
|
985fd0816c | ||
|
|
d0f54343c7 | ||
|
|
a3679e6758 | ||
|
|
b6c31540e8 | ||
|
|
a4f6e04199 | ||
|
|
0aee46ee79 | ||
|
|
9c8d423a86 | ||
|
|
b4efbd53b2 | ||
|
|
5a50d7c952 | ||
|
|
0fe8b281ba | ||
|
|
5331ab93f8 | ||
|
|
64582e9d46 | ||
|
|
9e0e2ff736 | ||
|
|
973c4137e4 | ||
|
|
730f6258d6 | ||
|
|
5850490b24 | ||
|
|
d4b36bdab4 | ||
|
|
40416d8c30 | ||
|
|
567e42840c | ||
|
|
65ddca133f | ||
|
|
d199256d34 | ||
|
|
073fe4668e | ||
|
|
89d53853e5 | ||
|
|
bb6c1572ca | ||
|
|
4c4e77b11f | ||
|
|
38c7b7303a | ||
|
|
02d0eedd68 | ||
|
|
5a3dde93a8 | ||
|
|
177f6a59d2 | ||
|
|
492a62a569 | ||
|
|
9a44fbf9c1 | ||
|
|
03eb0882de | ||
|
|
a845a2271b | ||
|
|
ba021f6007 | ||
|
|
7d9544fb91 | ||
|
|
12b7be333d | ||
|
|
ed54f1213c | ||
|
|
554b9c7826 | ||
|
|
6f150f3fd6 | ||
|
|
2a0d991d9b | ||
|
|
1320e53f81 | ||
|
|
8222795ac4 | ||
|
|
616a742db7 | ||
|
|
811d2c975e | ||
|
|
6272ce108f | ||
|
|
64896745d0 | ||
|
|
b2fe2385d5 | ||
|
|
8d05cd2daf | ||
|
|
231bdf8608 | ||
|
|
ab6b88ce14 | ||
|
|
94ab4bbf3f | ||
|
|
ca0363ded8 | ||
|
|
a467e10974 | ||
|
|
6dfbf00a23 | ||
|
|
b24af7fff6 | ||
|
|
45f73361e3 |
2
.github/workflows/python-publish.yml
vendored
2
.github/workflows/python-publish.yml
vendored
@@ -13,7 +13,7 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
os: [windows-latest, macos-latest]
|
||||
python-version: [3.6, 3.7, 3.8]
|
||||
python-version: [3.6, 3.7, 3.8, 3.9]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
|
||||
95
.github/workflows/test.yml
vendored
95
.github/workflows/test.yml
vendored
@@ -1,4 +1,4 @@
|
||||
name: Test
|
||||
name: Test
|
||||
|
||||
on:
|
||||
push:
|
||||
@@ -12,8 +12,8 @@ jobs:
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [windows-latest, ubuntu-16.04, ubuntu-18.04, ubuntu-20.04, macos-latest]
|
||||
python-version: [3.6, 3.7, 3.8, 3.9]
|
||||
os: [windows-latest, ubuntu-18.04, ubuntu-20.04]
|
||||
python-version: [3.6, 3.7, 3.8]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
@@ -25,96 +25,41 @@ jobs:
|
||||
|
||||
- name: Lint with Black
|
||||
run: |
|
||||
cd ..
|
||||
if [ "$RUNNER_OS" == "Windows" ]; then
|
||||
$CONDA\\python.exe -m pip install black
|
||||
$CONDA\\python.exe -m black qlib -l 120 --check --diff
|
||||
else
|
||||
sudo $CONDA/bin/python -m pip install black
|
||||
$CONDA/bin/python -m black qlib -l 120 --check --diff
|
||||
fi
|
||||
shell: bash
|
||||
pip install --upgrade pip
|
||||
pip install black wheel
|
||||
black qlib -l 120 --check --diff
|
||||
|
||||
# Test Qlib installed with pip
|
||||
- name: Install Qlib with pip
|
||||
run: |
|
||||
if [ "$RUNNER_OS" == "Windows" ]; then
|
||||
$CONDA\\python.exe -m pip install numpy==1.19.5
|
||||
$CONDA\\python.exe -m pip install pyqlib --ignore-installed ruamel.yaml numpy --user
|
||||
else
|
||||
sudo $CONDA/bin/python -m pip install numpy==1.19.5
|
||||
sudo $CONDA/bin/python -m pip install pyqlib --ignore-installed ruamel.yaml numpy
|
||||
fi
|
||||
shell: bash
|
||||
|
||||
- name: Install Lightgbm for MacOS
|
||||
if: runner.os == 'macOS'
|
||||
run: |
|
||||
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Microsoft/qlib/main/.github/brew_install.sh)"
|
||||
HOMEBREW_NO_AUTO_UPDATE=1 brew install lightgbm
|
||||
pip install numpy==1.19.5 ruamel.yaml
|
||||
pip install pyqlib --ignore-installed
|
||||
|
||||
- name: Test data downloads
|
||||
run: |
|
||||
if [ "$RUNNER_OS" == "Windows" ]; then
|
||||
$CONDA\\python.exe scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
|
||||
else
|
||||
$CONDA/bin/python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
|
||||
fi
|
||||
shell: bash
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
|
||||
|
||||
- name: Test workflow by config (install from pip)
|
||||
run: |
|
||||
if [ "$RUNNER_OS" == "Windows" ]; then
|
||||
$CONDA\\python.exe qlib\\workflow\\cli.py examples\\benchmarks\\LightGBM\\workflow_config_lightgbm_Alpha158.yaml
|
||||
$CONDA\\python.exe -m pip uninstall -y pyqlib
|
||||
else
|
||||
$CONDA/bin/python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
sudo $CONDA/bin/python -m pip uninstall -y pyqlib
|
||||
fi
|
||||
shell: bash
|
||||
|
||||
# Test Qlib installed from source
|
||||
python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
python -m pip uninstall -y pyqlib
|
||||
|
||||
# Test Qlib installed from source
|
||||
- name: Install Qlib from source
|
||||
run: |
|
||||
if [ "$RUNNER_OS" == "Windows" ]; then
|
||||
$CONDA\\python.exe -m pip install --upgrade cython
|
||||
$CONDA\\python.exe -m pip install numpy jupyter jupyter_contrib_nbextensions
|
||||
$CONDA\\python.exe -m pip install -U scipy scikit-learn # installing without this line will cause errors on GitHub Actions, while instsalling locally won't
|
||||
$CONDA\\python.exe setup.py install
|
||||
else
|
||||
sudo $CONDA/bin/python -m pip install --upgrade cython
|
||||
sudo $CONDA/bin/python -m pip install numpy jupyter jupyter_contrib_nbextensions
|
||||
sudo $CONDA/bin/python -m pip install -U scipy scikit-learn # installing without this line will cause errors on GitHub Actions, while instsalling locally won't
|
||||
sudo $CONDA/bin/python setup.py install
|
||||
fi
|
||||
shell: bash
|
||||
pip install --upgrade cython jupyter jupyter_contrib_nbextensions numpy scipy scikit-learn # installing without this line will cause errors on GitHub Actions, while instsalling locally won't
|
||||
pip install -e .
|
||||
|
||||
- name: Install test dependencies
|
||||
run: |
|
||||
if [ "$RUNNER_OS" == "Windows" ]; then
|
||||
$CONDA\\python.exe -m pip install --upgrade pip
|
||||
$CONDA\\python.exe -m pip install black pytest
|
||||
else
|
||||
sudo $CONDA/bin/python -m pip install --upgrade pip
|
||||
sudo $CONDA/bin/python -m pip install black pytest
|
||||
fi
|
||||
shell: bash
|
||||
pip install --upgrade pip
|
||||
pip install black pytest
|
||||
|
||||
- name: Unit tests with Pytest
|
||||
run: |
|
||||
cd tests
|
||||
if [ "$RUNNER_OS" == "Windows" ]; then
|
||||
$CONDA\\python.exe -m pytest . --durations=0
|
||||
else
|
||||
$CONDA/bin/python -m pytest . --durations=0
|
||||
fi
|
||||
shell: bash
|
||||
python -m pytest . --durations=10
|
||||
|
||||
- name: Test workflow by config (install from source)
|
||||
run: |
|
||||
if [ "$RUNNER_OS" == "Windows" ]; then
|
||||
$CONDA\\python.exe qlib\\workflow\\cli.py examples\\benchmarks\\LightGBM\\workflow_config_lightgbm_Alpha158.yaml
|
||||
else
|
||||
$CONDA/bin/python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
fi
|
||||
shell: bash
|
||||
python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
|
||||
|
||||
72
.github/workflows/test_macos.yml
vendored
Normal file
72
.github/workflows/test_macos.yml
vendored
Normal file
@@ -0,0 +1,72 @@
|
||||
# There are some issues (in the downloading data phase) on MacOS when running with other tests. So we split it into an individual config.
|
||||
name: Test MacOS
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
|
||||
jobs:
|
||||
build:
|
||||
|
||||
runs-on: macos-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: [3.6, 3.7, 3.8]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Lint with Black
|
||||
run: |
|
||||
cd ..
|
||||
python -m pip install pip --upgrade
|
||||
python -m pip install wheel --upgrade
|
||||
python -m pip install black
|
||||
python -m black qlib -l 120 --check --diff
|
||||
# Test Qlib installed with pip
|
||||
- name: Install Qlib with pip
|
||||
run: |
|
||||
python -m pip install numpy==1.19.5
|
||||
python -m pip install pyqlib --ignore-installed ruamel.yaml numpy
|
||||
- name: Install Lightgbm for MacOS
|
||||
run: |
|
||||
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Microsoft/qlib/main/.github/brew_install.sh)"
|
||||
HOMEBREW_NO_AUTO_UPDATE=1 brew install lightgbm
|
||||
# FIX MacOS error: Segmentation fault
|
||||
# reference: https://github.com/microsoft/LightGBM/issues/4229
|
||||
wget https://raw.githubusercontent.com/Homebrew/homebrew-core/fb8323f2b170bd4ae97e1bac9bf3e2983af3fdb0/Formula/libomp.rb
|
||||
brew unlink libomp
|
||||
brew install libomp.rb
|
||||
- name: Test data downloads
|
||||
run: |
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
|
||||
- name: Test workflow by config (install from pip)
|
||||
run: |
|
||||
python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
python -m pip uninstall -y pyqlib
|
||||
# Test Qlib installed from source
|
||||
- name: Install Qlib from source
|
||||
run: |
|
||||
python -m pip install --upgrade cython
|
||||
python -m pip install numpy jupyter jupyter_contrib_nbextensions
|
||||
python -m pip install -U scipy scikit-learn # installing without this line will cause errors on GitHub Actions, while instsalling locally won't
|
||||
python setup.py install
|
||||
- name: Install test dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install -U pyopenssl idna
|
||||
python -m pip install black pytest
|
||||
- name: Unit tests with Pytest
|
||||
run: |
|
||||
cd tests
|
||||
python -m pytest . --durations=0
|
||||
- name: Test workflow by config (install from source)
|
||||
run: |
|
||||
python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
1
MANIFEST.in
Normal file
1
MANIFEST.in
Normal file
@@ -0,0 +1 @@
|
||||
include qlib/VERSION.txt
|
||||
76
README.md
76
README.md
@@ -11,6 +11,10 @@
|
||||
Recent released features
|
||||
| Feature | Status |
|
||||
| -- | ------ |
|
||||
|Temporal Routing Adaptor (TRA) | [Released](https://github.com/microsoft/qlib/pull/531) on July 30, 2021 |
|
||||
| Transformer & Localformer | [Released](https://github.com/microsoft/qlib/pull/508) on July 22, 2021 |
|
||||
| Release Qlib v0.7.0 | [Released](https://github.com/microsoft/qlib/releases/tag/v0.7.0) on July 12, 2021 |
|
||||
| TCTS Model | [Released](https://github.com/microsoft/qlib/pull/491) on July 1, 2021 |
|
||||
| Online serving and automatic model rolling | :star: [Released](https://github.com/microsoft/qlib/pull/290) on May 17, 2021 |
|
||||
| DoubleEnsemble Model | [Released](https://github.com/microsoft/qlib/pull/286) on Mar 2, 2021 |
|
||||
| High-frequency data processing example | [Released](https://github.com/microsoft/qlib/pull/257) on Feb 5, 2021 |
|
||||
@@ -20,8 +24,6 @@ Recent released features
|
||||
|
||||
Features released before 2021 are not listed here.
|
||||
|
||||
|
||||
|
||||
<p align="center">
|
||||
<img src="http://fintech.msra.cn/images_v060/logo/1.png" />
|
||||
</p>
|
||||
@@ -42,7 +44,7 @@ For more details, please refer to our paper ["Qlib: An AI-oriented Quantitative
|
||||
- [Data Preparation](#data-preparation)
|
||||
- [Auto Quant Research Workflow](#auto-quant-research-workflow)
|
||||
- [Building Customized Quant Research Workflow by Code](#building-customized-quant-research-workflow-by-code)
|
||||
- [**Quant Model Zoo**](#quant-model-zoo)
|
||||
- [**Quant Model(Paper) Zoo**](#quant-model-paper-zoo)
|
||||
- [Run a single model](#run-a-single-model)
|
||||
- [Run multiple models](#run-multiple-models)
|
||||
- [**Quant Dataset Zoo**](#quant-dataset-zoo)
|
||||
@@ -68,7 +70,7 @@ Your feedbacks about the features are very important.
|
||||
# Framework of Qlib
|
||||
|
||||
<div style="align: center">
|
||||
<img src="http://fintech.msra.cn/images_v060/framework.png?v=0.1" />
|
||||
<img src="http://fintech.msra.cn/images_v060/framework.png?v=0.2" />
|
||||
</div>
|
||||
|
||||
|
||||
@@ -104,8 +106,9 @@ This table demonstrates the supported Python version of `Qlib`:
|
||||
| Python 3.9 | :x: | :heavy_check_mark: | :x: |
|
||||
|
||||
**Note**:
|
||||
1. **Conda** is suggested for managing your Python environment.
|
||||
1. Please pay attention that installing cython in Python 3.6 will raise some error when installing ``Qlib`` from source. If users use Python 3.6 on their machines, it is recommended to *upgrade* Python to version 3.7 or use `conda`'s Python to install ``Qlib`` from source.
|
||||
2. For Python 3.9, `Qlib` supports running workflows such as training models, doing backtest and plot most of the related figures (those included in [notebook](examples/workflow_by_code.ipynb)). However, plotting for the *model performance* is not supported for now and we will fix this when the dependent packages are upgraded in the future.
|
||||
1. For Python 3.9, `Qlib` supports running workflows such as training models, doing backtest and plot most of the related figures (those included in [notebook](examples/workflow_by_code.ipynb)). However, plotting for the *model performance* is not supported for now and we will fix this when the dependent packages are upgraded in the future.
|
||||
|
||||
### Install with pip
|
||||
Users can easily install ``Qlib`` by pip according to the following command.
|
||||
@@ -159,6 +162,28 @@ Users could create the same dataset with it.
|
||||
*Please pay **ATTENTION** that the data is collected from [Yahoo Finance](https://finance.yahoo.com/lookup), and the data might not be perfect.
|
||||
We recommend users to prepare their own data if they have a high-quality dataset. For more information, users can refer to the [related document](https://qlib.readthedocs.io/en/latest/component/data.html#converting-csv-format-into-qlib-format)*.
|
||||
|
||||
### Automatic update of daily frequency data (from yahoo finance)
|
||||
> It is recommended that users update the data manually once (--trading_date 2021-05-25) and then set it to update automatically.
|
||||
|
||||
> For more information refer to: [yahoo collector](https://github.com/microsoft/qlib/tree/main/scripts/data_collector/yahoo#automatic-update-of-daily-frequency-datafrom-yahoo-finance)
|
||||
|
||||
* Automatic update of data to the "qlib" directory each trading day(Linux)
|
||||
* use *crontab*: `crontab -e`
|
||||
* set up timed tasks:
|
||||
|
||||
```
|
||||
* * * * 1-5 python <script path> update_data_to_bin --qlib_data_1d_dir <user data dir>
|
||||
```
|
||||
* **script path**: *scripts/data_collector/yahoo/collector.py*
|
||||
|
||||
* Manual update of data
|
||||
```
|
||||
python scripts/data_collector/yahoo/collector.py update_data_to_bin --qlib_data_1d_dir <user data dir> --trading_date <start date> --end_date <end date>
|
||||
```
|
||||
* *trading_date*: start of trading day
|
||||
* *end_date*: end of trading day(not included)
|
||||
|
||||
|
||||
<!--
|
||||
- Run the initialization code and get stock data:
|
||||
|
||||
@@ -251,21 +276,25 @@ Qlib provides a tool named `qrun` to run the whole workflow automatically (inclu
|
||||
The automatic workflow may not suit the research workflow of all Quant researchers. To support a flexible Quant research workflow, Qlib also provides a modularized interface to allow researchers to build their own workflow by code. [Here](examples/workflow_by_code.ipynb) is a demo for customized Quant research workflow by code.
|
||||
|
||||
|
||||
# [Quant Model Zoo](examples/benchmarks)
|
||||
# [Quant Model (Paper) Zoo](examples/benchmarks)
|
||||
|
||||
Here is a list of models built on `Qlib`.
|
||||
- [GBDT based on XGBoost (Tianqi Chen, et al. 2016)](qlib/contrib/model/xgboost.py)
|
||||
- [GBDT based on LightGBM (Guolin Ke, et al. 2017)](qlib/contrib/model/gbdt.py)
|
||||
- [GBDT based on Catboost (Liudmila Prokhorenkova, et al. 2017)](qlib/contrib/model/catboost_model.py)
|
||||
- [GBDT based on XGBoost (Tianqi Chen, et al. KDD 2016)](qlib/contrib/model/xgboost.py)
|
||||
- [GBDT based on LightGBM (Guolin Ke, et al. NIPS 2017)](qlib/contrib/model/gbdt.py)
|
||||
- [GBDT based on Catboost (Liudmila Prokhorenkova, et al. NIPS 2018)](qlib/contrib/model/catboost_model.py)
|
||||
- [MLP based on pytorch](qlib/contrib/model/pytorch_nn.py)
|
||||
- [LSTM based on pytorch (Sepp Hochreiter, et al. 1997)](qlib/contrib/model/pytorch_lstm.py)
|
||||
- [LSTM based on pytorch (Sepp Hochreiter, et al. Neural omputation 1997)](qlib/contrib/model/pytorch_lstm.py)
|
||||
- [GRU based on pytorch (Kyunghyun Cho, et al. 2014)](qlib/contrib/model/pytorch_gru.py)
|
||||
- [ALSTM based on pytorch (Yao Qin, et al. 2017)](qlib/contrib/model/pytorch_alstm.py)
|
||||
- [ALSTM based on pytorch (Yao Qin, et al. IJCAI 2017)](qlib/contrib/model/pytorch_alstm.py)
|
||||
- [GATs based on pytorch (Petar Velickovic, et al. 2017)](qlib/contrib/model/pytorch_gats.py)
|
||||
- [SFM based on pytorch (Liheng Zhang, et al. 2017)](qlib/contrib/model/pytorch_sfm.py)
|
||||
- [TFT based on tensorflow (Bryan Lim, et al. 2019)](examples/benchmarks/TFT/tft.py)
|
||||
- [TabNet based on pytorch (Sercan O. Arik, et al. 2019)](qlib/contrib/model/pytorch_tabnet.py)
|
||||
- [DoubleEnsemble based on LightGBM (Chuheng Zhang, et al. 2020)](qlib/contrib/model/double_ensemble.py)
|
||||
- [SFM based on pytorch (Liheng Zhang, et al. KDD 2017)](qlib/contrib/model/pytorch_sfm.py)
|
||||
- [TFT based on tensorflow (Bryan Lim, et al. International Journal of Forecasting 2019)](examples/benchmarks/TFT/tft.py)
|
||||
- [TabNet based on pytorch (Sercan O. Arik, et al. AAAI 2019)](qlib/contrib/model/pytorch_tabnet.py)
|
||||
- [DoubleEnsemble based on LightGBM (Chuheng Zhang, et al. ICDM 2020)](qlib/contrib/model/double_ensemble.py)
|
||||
- [TCTS based on pytorch (Xueqing Wu, et al. ICML 2021)](qlib/contrib/model/pytorch_tcts.py)
|
||||
- [Transformer based on pytorch (Ashish Vaswani, et al. NeurIPS 2017)](qlib/contrib/model/pytorch_transformer.py)
|
||||
- [Localformer based on pytorch (Juyong Jiang, et al.)](qlib/contrib/model/pytorch_localformer.py)
|
||||
- [TRA based on pytorch (Hengxu, Dong, et al. KDD 2021)](qlib/contrib/model/pytorch_tra.py)
|
||||
|
||||
Your PR of new Quant models is highly welcomed.
|
||||
|
||||
@@ -279,9 +308,10 @@ All the models listed above are runnable with ``Qlib``. Users can find the confi
|
||||
- Users can create a `workflow_by_code` python script based on the [one](examples/workflow_by_code.py) listed in the `examples` folder.
|
||||
|
||||
- Users can use the script [`run_all_model.py`](examples/run_all_model.py) listed in the `examples` folder to run a model. Here is an example of the specific shell command to be used: `python run_all_model.py --models=lightgbm`, where the `--models` arguments can take any number of models listed above(the available models can be found in [benchmarks](examples/benchmarks/)). For more use cases, please refer to the file's [docstrings](examples/run_all_model.py).
|
||||
- **NOTE**: Each baseline has different environment dependencies, please make sure that your python version aligns with the requirements(e.g. TFT only supports Python 3.6~3.7 due to the limitation of `tensorflow==1.15.0`)
|
||||
|
||||
## Run multiple models
|
||||
`Qlib` also provides a script [`run_all_model.py`](examples/run_all_model.py) which can run multiple models for several iterations. (**Note**: the script only support *Linux* for now. Other OS will be supported in the future. Besides, it doesn't support parrallel running the same model for multiple times as well, and this will be fixed in the future development too.)
|
||||
`Qlib` also provides a script [`run_all_model.py`](examples/run_all_model.py) which can run multiple models for several iterations. (**Note**: the script only support *Linux* for now. Other OS will be supported in the future. Besides, it doesn't support parallel running the same model for multiple times as well, and this will be fixed in the future development too.)
|
||||
|
||||
The script will create a unique virtual environment for each model, and delete the environments after training. Thus, only experiment results such as `IC` and `backtest` results will be generated and stored.
|
||||
|
||||
@@ -346,9 +376,7 @@ Such overheads greatly slow down the data loading process.
|
||||
Qlib data are stored in a compact format, which is efficient to be combined into arrays for scientific computation.
|
||||
|
||||
# Related Reports
|
||||
- [【华泰金工林晓明团队】图神经网络选股与Qlib实践——华泰人工智能系列之四十二](https://mp.weixin.qq.com/s/w5fDB6oAv9dO6vlhf1kmhA)
|
||||
- [Guide To Qlib: Microsoft’s AI Investment Platform](https://analyticsindiamag.com/qlib/)
|
||||
- [【华泰金工林晓明团队】微软AI量化投资平台Qlib体验——华泰人工智能系列之四十](https://mp.weixin.qq.com/s/Brcd7im4NibJOJzZfMn6tQ)
|
||||
- [微软也搞AI量化平台?还是开源的!](https://mp.weixin.qq.com/s/47bP5YwxfTp2uTHjUBzJQQ)
|
||||
- [微矿Qlib:业内首个AI量化投资开源平台](https://mp.weixin.qq.com/s/vsJv7lsgjEi-ALYUz4CvtQ)
|
||||
|
||||
@@ -365,7 +393,17 @@ Join IM discussion groups:
|
||||
|
||||
# Contributing
|
||||
|
||||
This project welcomes contributions and suggestions. Most contributions require you to agree to a
|
||||
This project welcomes contributions and suggestions.
|
||||
**Here are some
|
||||
[code standards](docs/developer/code_standard.rst) when you submit a pull request.**
|
||||
|
||||
If you want to contribute to Qlib's document, you can follow the steps in the figure below.
|
||||
<p align="center">
|
||||
<img src="https://github.com/demon143/qlib/blob/main/docs/_static/img/change%20doc.gif" />
|
||||
</p>
|
||||
|
||||
|
||||
Most contributions require you to agree to a
|
||||
Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
|
||||
the right to use your contribution. For details, visit https://cla.opensource.microsoft.com.
|
||||
|
||||
|
||||
1
VERSION.txt
Normal file
1
VERSION.txt
Normal file
@@ -0,0 +1 @@
|
||||
0.7.2
|
||||
@@ -97,4 +97,57 @@ Also, feel free to post a new issue in our GitHub repository. We always check ea
|
||||
|
||||
python setup.py build_ext --inplace
|
||||
|
||||
- If the error occurs when importing ``qlib`` package with command ``python`` , users need to change the running directory to ensure that the script does not run in the project directory.
|
||||
- If the error occurs when importing ``qlib`` package with command ``python`` , users need to change the running directory to ensure that the script does not run in the project directory.
|
||||
|
||||
|
||||
4. BadNamespaceError: / is not a connected namespace
|
||||
------------------------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
File "qlib_online.py", line 35, in <module>
|
||||
cal = D.calendar()
|
||||
File "e:\code\python\microsoft\qlib_latest\qlib\qlib\data\data.py", line 973, in calendar
|
||||
return Cal.calendar(start_time, end_time, freq, future=future)
|
||||
File "e:\code\python\microsoft\qlib_latest\qlib\qlib\data\data.py", line 798, in calendar
|
||||
self.conn.send_request(
|
||||
File "e:\code\python\microsoft\qlib_latest\qlib\qlib\data\client.py", line 101, in send_request
|
||||
self.sio.emit(request_type + "_request", request_content)
|
||||
File "G:\apps\miniconda\envs\qlib\lib\site-packages\python_socketio-5.3.0-py3.8.egg\socketio\client.py", line 369, in emit
|
||||
raise exceptions.BadNamespaceError(
|
||||
BadNamespaceError: / is not a connected namespace.
|
||||
|
||||
- The version of ``python-socketio`` in qlib needs to be the same as the version of ``python-socketio`` in qlib-server:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install -U python-socketio==<qlib-server python-socketio version>
|
||||
|
||||
|
||||
5. TypeError: send() got an unexpected keyword argument 'binary'
|
||||
------------------------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
File "qlib_online.py", line 35, in <module>
|
||||
cal = D.calendar()
|
||||
File "e:\code\python\microsoft\qlib_latest\qlib\qlib\data\data.py", line 973, in calendar
|
||||
return Cal.calendar(start_time, end_time, freq, future=future)
|
||||
File "e:\code\python\microsoft\qlib_latest\qlib\qlib\data\data.py", line 798, in calendar
|
||||
self.conn.send_request(
|
||||
File "e:\code\python\microsoft\qlib_latest\qlib\qlib\data\client.py", line 101, in send_request
|
||||
self.sio.emit(request_type + "_request", request_content)
|
||||
File "G:\apps\miniconda\envs\qlib\lib\site-packages\socketio\client.py", line 263, in emit
|
||||
self._send_packet(packet.Packet(packet.EVENT, namespace=namespace,
|
||||
File "G:\apps\miniconda\envs\qlib\lib\site-packages\socketio\client.py", line 339, in _send_packet
|
||||
self.eio.send(ep, binary=binary)
|
||||
TypeError: send() got an unexpected keyword argument 'binary'
|
||||
|
||||
|
||||
- The ``python-engineio`` version needs to be compatible with the ``python-socketio`` version, reference: https://github.com/miguelgrinberg/python-socketio#version-compatibility
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install -U python-engineio==<compatible python-socketio version>
|
||||
# or
|
||||
pip install -U python-socketio==3.1.2 python-engineio==3.13.2
|
||||
|
||||
BIN
docs/_static/img/change doc.gif
vendored
Normal file
BIN
docs/_static/img/change doc.gif
vendored
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 1.3 MiB |
BIN
docs/_static/img/framework.png
vendored
BIN
docs/_static/img/framework.png
vendored
Binary file not shown.
|
Before Width: | Height: | Size: 271 KiB After Width: | Height: | Size: 208 KiB |
@@ -67,6 +67,34 @@ After running the above command, users can find china-stock and us-stock data in
|
||||
|
||||
When ``Qlib`` is initialized with this dataset, users could build and evaluate their own models with it. Please refer to `Initialization <../start/initialization.html>`_ for more details.
|
||||
|
||||
Automatic update of daily frequency data
|
||||
----------------------------------------
|
||||
|
||||
**It is recommended that users update the data manually once (\-\-trading_date 2021-05-25) and then set it to update automatically.**
|
||||
|
||||
For more information refer to: `yahoo collector <https://github.com/microsoft/qlib/tree/main/scripts/data_collector/yahoo#Automatic-update-of-daily-frequency-data>`_
|
||||
|
||||
- Automatic update of data to the "qlib" directory each trading day(Linux)
|
||||
- use *crontab*: `crontab -e`
|
||||
- set up timed tasks:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
* * * * 1-5 python <script path> update_data_to_bin --qlib_data_1d_dir <user data dir>
|
||||
|
||||
- **script path**: *scripts/data_collector/yahoo/collector.py*
|
||||
|
||||
- Manual update of data
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python scripts/data_collector/yahoo/collector.py update_data_to_bin --qlib_data_1d_dir <user data dir> --trading_date <start date> --end_date <end date>
|
||||
|
||||
- *trading_date*: start of trading day
|
||||
- *end_date*: end of trading day(not included)
|
||||
|
||||
|
||||
|
||||
Converting CSV Format into Qlib Format
|
||||
-------------------------------------------
|
||||
|
||||
@@ -151,6 +179,7 @@ After conversion, users can find their Qlib format data in the directory `~/.qli
|
||||
The Restoration factor. Normally, ``factor = adjusted_price / original_price``, `adjusted price` reference: `split adjusted <https://www.investopedia.com/terms/s/splitadjusted.asp>`_
|
||||
|
||||
In the convention of `Qlib` data processing, `open, close, high, low, volume, money and factor` will be set to NaN if the stock is suspended.
|
||||
If you want to use your own alpha-factor which can't be calculate by OCHLV, like PE, EPS and so on, you could add it to the CSV files with OHCLV together and then dump it to the Qlib format data.
|
||||
|
||||
Stock Pool (Market)
|
||||
--------------------------------
|
||||
|
||||
@@ -21,6 +21,8 @@ which including `Online Manager <#Online Manager>`_, `Online Strategy <#Online S
|
||||
If you have many models or `task` needs to be managed, please consider `Task Management <../advanced/task_management.html>`_.
|
||||
The `examples <https://github.com/microsoft/qlib/tree/main/examples/online_srv>`_ are based on some components in `Task Management <../advanced/task_management.html>`_ such as ``TrainerRM`` or ``Collector``.
|
||||
|
||||
**NOTE**: User should keep his data source updated to support online serving. For example, Qlib provides `a batch of scripts <https://github.com/microsoft/qlib/blob/main/scripts/data_collector/yahoo/README.md#automatic-update-of-daily-frequency-datafrom-yahoo-finance>`_ to help users update Yahoo daily data.
|
||||
|
||||
Online Manager
|
||||
=============
|
||||
|
||||
@@ -43,4 +45,4 @@ Updater
|
||||
=============
|
||||
|
||||
.. automodule:: qlib.workflow.online.update
|
||||
:members:
|
||||
:members:
|
||||
|
||||
@@ -90,12 +90,12 @@ Below is a typical config file of ``qrun``.
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
|
||||
After saving the config into `configuration.yaml`, users could start the workflow and test their ideas with a single command below.
|
||||
|
||||
@@ -142,7 +142,7 @@ The meaning of each field is as follows:
|
||||
|
||||
- `region`
|
||||
- If `region` == "us", ``Qlib`` will be initialized in US-stock mode.
|
||||
- If `region` == "cn", ``Qlib`` will be initialized in china-stock mode.
|
||||
- If `region` == "cn", ``Qlib`` will be initialized in China-stock mode.
|
||||
|
||||
.. note::
|
||||
|
||||
|
||||
22
docs/developer/code_standard.rst
Normal file
22
docs/developer/code_standard.rst
Normal file
@@ -0,0 +1,22 @@
|
||||
.. _code_standard:
|
||||
|
||||
=================================
|
||||
Code Standard
|
||||
=================================
|
||||
|
||||
Docstring
|
||||
=================================
|
||||
Please use the `Numpydoc Style <https://stackoverflow.com/a/24385103>`_.
|
||||
|
||||
Continuous Integration
|
||||
=================================
|
||||
Continuous Integration (CI) tools help you stick to the quality standards by running tests every time you push a new commit and reporting the results to a pull request.
|
||||
|
||||
When you submit a PR request, you can check whether your code passes the CI tests in the "check" section at the bottom of the web page.
|
||||
|
||||
A common error is the mixed use of space and tab. You can fix the bug by inputing the following code in the command line.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
pip install black
|
||||
python -m black . -l 120
|
||||
@@ -241,6 +241,7 @@ Online Tool
|
||||
.. automodule:: qlib.workflow.online.utils
|
||||
:members:
|
||||
|
||||
|
||||
RecordUpdater
|
||||
--------------------
|
||||
.. automodule:: qlib.workflow.online.update
|
||||
@@ -257,4 +258,4 @@ Serializable
|
||||
:members:
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -77,7 +77,8 @@ Besides `provider_uri` and `region`, `qlib.init` has other parameters. The follo
|
||||
})
|
||||
- `mongo`
|
||||
Type: dict, optional parameter, the setting of `MongoDB <https://www.mongodb.com/>`_ which will be used in some features such as `Task Management <../advanced/task_management.html>`_, with high performance and clustered processing.
|
||||
Users need finished `installation <https://www.mongodb.com/try/download/community>`_ firstly, and run it in a fixed URL.
|
||||
Users need to follow the steps in `installation <https://www.mongodb.com/try/download/community>`_ to install MongoDB firstly and then access it via a URI.
|
||||
Users can access mongodb with credential by setting "task_url" to a string like `"mongodb://%s:%s@%s" % (user, pwd, host + ":" + port)`.
|
||||
|
||||
.. code-block:: Python
|
||||
|
||||
|
||||
@@ -61,7 +61,6 @@ task:
|
||||
metric: loss
|
||||
loss: mse
|
||||
base_model: LSTM
|
||||
with_pretrain: True
|
||||
model_path: "benchmarks/LSTM/csi300_lstm_ts.pkl"
|
||||
GPU: 0
|
||||
dataset:
|
||||
|
||||
@@ -54,7 +54,6 @@ task:
|
||||
metric: loss
|
||||
loss: mse
|
||||
base_model: LSTM
|
||||
with_pretrain: True
|
||||
model_path: "benchmarks/LSTM/model_lstm_csi300.pkl"
|
||||
GPU: 0
|
||||
dataset:
|
||||
@@ -81,4 +80,4 @@ task:
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
config: *port_analysis_config
|
||||
|
||||
16
examples/benchmarks/LightGBM/features_sample.py
Normal file
16
examples/benchmarks/LightGBM/features_sample.py
Normal file
@@ -0,0 +1,16 @@
|
||||
import datetime
|
||||
import pandas as pd
|
||||
|
||||
from qlib.data.inst_processor import InstProcessor
|
||||
|
||||
|
||||
class Resample1minProcessor(InstProcessor):
|
||||
def __init__(self, hour: int, minute: int, **kwargs):
|
||||
self.hour = hour
|
||||
self.minute = minute
|
||||
|
||||
def __call__(self, df: pd.DataFrame, *args, **kwargs):
|
||||
df.index = pd.to_datetime(df.index)
|
||||
df = df.loc[df.index.time == datetime.time(self.hour, self.minute)]
|
||||
df.index = df.index.normalize()
|
||||
return df
|
||||
@@ -63,4 +63,4 @@ task:
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
config: *port_analysis_config
|
||||
|
||||
@@ -0,0 +1,83 @@
|
||||
qlib_init:
|
||||
provider_uri:
|
||||
day: "~/.qlib/qlib_data/cn_data"
|
||||
1min: "~/.qlib/qlib_data/cn_data_1min"
|
||||
region: cn
|
||||
dataset_cache: null
|
||||
maxtasksperchild: 1
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
# 1min closing time is 15:00:00
|
||||
end_time: "2020-08-01 15:00:00"
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
freq:
|
||||
label: day
|
||||
feature: 1min
|
||||
# with label as reference
|
||||
inst_processor:
|
||||
feature:
|
||||
- class: Resample1minProcessor
|
||||
module_path: features_sample.py
|
||||
kwargs:
|
||||
hour: 14
|
||||
minute: 56
|
||||
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: LGBModel
|
||||
module_path: qlib.contrib.model.gbdt
|
||||
kwargs:
|
||||
loss: mse
|
||||
colsample_bytree: 0.8879
|
||||
learning_rate: 0.2
|
||||
subsample: 0.8789
|
||||
lambda_l1: 205.6999
|
||||
lambda_l2: 580.9768
|
||||
max_depth: 8
|
||||
num_leaves: 210
|
||||
num_threads: 20
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha158
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -78,4 +78,4 @@ task:
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
config: *port_analysis_config
|
||||
|
||||
3
examples/benchmarks/Localformer/requirements.txt
Normal file
3
examples/benchmarks/Localformer/requirements.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
numpy==1.17.4
|
||||
pandas==1.1.2
|
||||
torch==1.2.0
|
||||
@@ -0,0 +1,82 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: FilterCol
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
col_list: ["RESI5", "WVMA5", "RSQR5", "KLEN", "RSQR10", "CORR5", "CORD5", "CORR10",
|
||||
"ROC60", "RESI10", "VSTD5", "RSQR60", "CORR60", "WVMA60", "STD5",
|
||||
"RSQR20", "CORD60", "CORD10", "CORR20", "KLOW"
|
||||
]
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: LocalformerModel
|
||||
module_path: qlib.contrib.model.pytorch_localformer_ts
|
||||
kwargs:
|
||||
seed: 0
|
||||
n_jobs: 20
|
||||
dataset:
|
||||
class: TSDatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha158
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
step_len: 20
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -0,0 +1,73 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: LocalformerModel
|
||||
module_path: qlib.contrib.model.pytorch_localformer
|
||||
kwargs:
|
||||
d_feat: 6
|
||||
seed: 0
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha360
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -1,9 +1,13 @@
|
||||
# Benchmarks Performance
|
||||
|
||||
Here are the results of each benchmark model running on Qlib's `Alpha360` and `Alpha158` dataset with China's A shared-stock & CSI300 data respectively. The values of each metric are the mean and std calculated based on 20 runs.
|
||||
Here are the results of each benchmark model running on Qlib's `Alpha360` and `Alpha158` dataset with China's A shared-stock & CSI300 data respectively. The values of each metric are the mean and std calculated based on 20 runs with different random seeds.
|
||||
|
||||
The numbers shown below demonstrate the performance of the entire `workflow` of each model. We will update the `workflow` as well as models in the near future for better results.
|
||||
|
||||
> If you need to reproduce the results below, please use the **v1** dataset: `python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/qlib_cn_1d --region cn --version v1`
|
||||
>
|
||||
> In the new version of qlib, the default dataset is **v2**. Since the data is collected from the YahooFinance API (which is not very stable), the results of *v2* and *v1* may differ
|
||||
|
||||
## Alpha360 dataset
|
||||
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |
|
||||
|---|---|---|---|---|---|---|---|---|
|
||||
@@ -18,6 +22,10 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
|
||||
| GATs (Petar Velickovic, et al.) | Alpha360 | 0.0475±0.00 | 0.3515±0.02| 0.0592±0.00 | 0.4585±0.01 | 0.0876±0.02 | 1.1513±0.27| -0.0795±0.02 |
|
||||
| DoubleEnsemble (Chuheng Zhang, et al.) | Alpha360 | 0.0407±0.00| 0.3053±0.00 | 0.0490±0.00 | 0.3840±0.00 | 0.0380±0.02 | 0.5000±0.21 | -0.0984±0.02 |
|
||||
| TabNet (Sercan O. Arik, et al.)| Alpha360 | 0.0192±0.00 | 0.1401±0.00| 0.0291±0.00 | 0.2163±0.00 | -0.0258±0.00 | -0.2961±0.00| -0.1429±0.00 |
|
||||
| TCTS (Xueqing Wu, et al.)| Alpha360 | 0.0485±0.00 | 0.3689±0.04| 0.0586±0.00 | 0.4669±0.02 | 0.0816±0.02 | 1.1572±0.30| -0.0689±0.02 |
|
||||
| Transformer (Ashish Vaswani, et al.)| Alpha360 | 0.0141±0.00 | 0.0917±0.02| 0.0331±0.00 | 0.2357±0.03 | -0.0259±0.03 | -0.3323±0.43| -0.1763±0.07 |
|
||||
| Localformer (Juyong Jiang, et al.)| Alpha360 | 0.0408±0.00 | 0.2988±0.03| 0.0538±0.00 | 0.4105±0.02 | 0.0275±0.03 | 0.3464±0.37| -0.1182±0.03 |
|
||||
| TRA (Hengxu Lin, et al.)| Alpha360 | 0.0491±0.01 | 0.3868±0.06 | 0.0589±0.00 | 0.4802±0.04 | 0.0898±0.02 | 1.2490±0.32 | -0.0778±0.02 |
|
||||
|
||||
## Alpha158 dataset
|
||||
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |
|
||||
@@ -34,6 +42,10 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
|
||||
| GATs (Petar Velickovic, et al.) | Alpha158 (with selected 20 features) | 0.0349±0.00 | 0.2511±0.01| 0.0457±0.00 | 0.3537±0.01 | 0.0578±0.02 | 0.8221±0.25| -0.0824±0.02 |
|
||||
| DoubleEnsemble (Chuheng Zhang, et al.) | Alpha158 | 0.0544±0.00 | 0.4338±0.01 | 0.0523±0.00 | 0.4257±0.01 | 0.1253±0.01 | 1.4105±0.14 | -0.0902±0.01 |
|
||||
| TabNet (Sercan O. Arik, et al.)| Alpha158 | 0.0383±0.00 | 0.3414±0.00| 0.0388±0.00 | 0.3460±0.00 | 0.0226±0.00 | 0.2652±0.00| -0.1072±0.00 |
|
||||
| Transformer (Ashish Vaswani, et al.)| Alpha158 | 0.0274±0.00 | 0.2166±0.04| 0.0409±0.00 | 0.3342±0.04 | 0.0204±0.03 | 0.2888±0.40| -0.1216±0.04 |
|
||||
| Localformer (Juyong Jiang, et al.)| Alpha158 | 0.0355±0.00 | 0.2747±0.04| 0.0466±0.00 | 0.3762±0.03 | 0.0506±0.02 | 0.7447±0.34| -0.0875±0.02 |
|
||||
| TRA (Hengxu Lin, et al.)| Alpha158 (with selected 20 features)| 0.0409±0.00 | 0.3253±0.04 | 0.0488±0.00 | 0.4045±0.02 | 0.0673±0.02 | 1.0389±0.39 | -0.0830±0.02 |
|
||||
| TRA (Hengxu Lin, et al.)| Alpha158 | 0.0442±0.00 | 0.3426±0.03 | 0.0555±0.00 | 0.4395±0.03 | 0.0833±0.03 | 1.2064±0.36 | -0.0849±0.02 |
|
||||
|
||||
- The selected 20 features are based on the feature importance of a lightgbm-based model.
|
||||
- The base model of DoubleEnsemble is LGBM.
|
||||
|
||||
52
examples/benchmarks/TCTS/README.md
Normal file
52
examples/benchmarks/TCTS/README.md
Normal file
@@ -0,0 +1,52 @@
|
||||
# Temporally Correlated Task Scheduling for Sequence Learning
|
||||
We provide the [code](https://github.com/microsoft/qlib/blob/main/qlib/contrib/model/pytorch_tcts.py) for reproducing the stock trend forecasting experiments.
|
||||
|
||||
### Background
|
||||
Sequence learning has attracted much research attention from the machine learning community in recent years. In many applications, a sequence learning task is usually associated with multiple temporally correlated auxiliary tasks, which are different in terms of how much input information to use or which future step to predict. In stock trend forecasting, as demonstrated in Figure1, one can predict the price of a stock in different future days (e.g., tomorrow, the day after tomorrow). In this paper, we propose a framework to make use of those temporally correlated tasks to help each other.
|
||||
|
||||
<p align="center">
|
||||
<img src="task_description.png" width="600" height="200"/>
|
||||
</p>
|
||||
|
||||
|
||||
### Method
|
||||
Given that there are usually multiple temporally correlated tasks, the key challenge lies in which tasks to use and when to use them in the training process. In this work, we introduce a learnable task scheduler for sequence learning, which adaptively selects temporally correlated tasks during the training process. The scheduler accesses the model status and the current training data (e.g., in current minibatch), and selects the best auxiliary task to help the training of the main task. The scheduler and the model for the main task are jointly trained through bi-level optimization: the scheduler is trained to maximize the validation performance of the model, and the model is trained to minimize the training loss guided by the scheduler. The process is demonstrated in Figure2.
|
||||
|
||||
<p align="center">
|
||||
<img src="workflow.png"/>
|
||||
</p>
|
||||
|
||||
At step <img src="https://render.githubusercontent.com/render/math?math=s">, with training data <img src="https://render.githubusercontent.com/render/math?math=x_s,y_s">, the scheduler <img src="https://render.githubusercontent.com/render/math?math=\varphi"> chooses a suitable task <img src="https://render.githubusercontent.com/render/math?math=T_{i_s}"> (green solid lines) to update the model <img src="https://render.githubusercontent.com/render/math?math=f"> (blue solid lines). After <img src="https://render.githubusercontent.com/render/math?math=S"> steps, we evaluate the model <img src="https://render.githubusercontent.com/render/math?math=f"> on the validation set and update the scheduler <img src="https://render.githubusercontent.com/render/math?math=\varphi"> (green dashed lines).
|
||||
|
||||
### DataSet
|
||||
* We use the historical transaction data for 300 stocks on [CSI300](http://www.csindex.com.cn/en/indices/index-detail/000300) from 01/01/2008 to 08/01/2020.
|
||||
* We split the data into training (01/01/2008-12/31/2013), validation (01/01/2014-12/31/2015), and test sets (01/01/2016-08/01/2020) based on the transaction time.
|
||||
|
||||
### Experiments
|
||||
#### Task Description
|
||||
* The main tasks <img src="https://render.githubusercontent.com/render/math?math=T_k"> (<img src="https://render.githubusercontent.com/render/math?math=task_k"> in Figure1) refers to forecasting return of stock <img src="https://render.githubusercontent.com/render/math?math=i"> as following,
|
||||
<div align=center>
|
||||
<img src="https://render.githubusercontent.com/render/math?math=r_{i}^k = \frac{\price_i^{t+k}}{\price_i^{t+k-1}} - 1">
|
||||
</div>
|
||||
|
||||
* Temporally correlated task sets <img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_k = \{T_1, T_2, ... , T_k\}">, in this paper, <img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_3">, <img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_5"> and <img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_10"> are used.
|
||||
#### Baselines
|
||||
* GRU/MLP/LightGBM (LGB)/Graph Attention Networks (GAT)
|
||||
* Multi-task learning (MTL): In multi-task learning, multiple tasks are jointly trained and mutually boosted. Each task is treated equally, while in our setting, we focus on the main task.
|
||||
* Curriculum transfer learning (CL): Transfer learning also leverages auxiliary tasks to boost the main task. [Curriculum transfer learning](https://arxiv.org/pdf/1804.00810.pdf) is one kind of transfer learning which schedules auxiliary tasks according to certain rules. Our problem can also be regarded as a special kind of transfer learning, where the auxiliary tasks are temporally correlated with the main task. Our learning process is dynamically controlled by a scheduler rather than some pre-defined rules. In the CL baseline, we start from the task <img src="https://render.githubusercontent.com/render/math?math=T_1" >, then <img src="https://render.githubusercontent.com/render/math?math=T_2" >, and gradually move to the last one.
|
||||
#### Result
|
||||
| Methods | <img src="https://render.githubusercontent.com/render/math?math=T_1" > | <img src="https://render.githubusercontent.com/render/math?math=T_2"> | <img src="https://render.githubusercontent.com/render/math?math=T_3"> |
|
||||
| :----: | :----: | :----: | :----: |
|
||||
| GRU | 0.049 / 1.903 | 0.018 / 1.972 | 0.014 / 1.989 |
|
||||
| MLP | 0.023 / 1.961 | 0.022 / 1.962 | 0.015 / 1.978 |
|
||||
| LGB | 0.038 / 1.883 | 0.023 / 1.952 | 0.007 / 1.987 |
|
||||
| GAT | 0.052 / 1.898 | 0.024 / 1.954 | 0.015 / 1.973 |
|
||||
| MTL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_3">) | 0.061 / 1.862 | 0.023 / 1.942 | 0.012 / 1.956 |
|
||||
| CL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_3">) | 0.051 / 1.880 | 0.028 / 1.941 | 0.016 / 1.962 |
|
||||
| Ours(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_3">) | 0.071 / 1.851 | 0.030 / 1.939 | 0.017 / 1.963 |
|
||||
| MTL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_5">) | 0.057 / 1.875 | 0.021 / 1.939 | 0.017 / 1.959 |
|
||||
| CL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_5">) | 0.056 / 1.877 | 0.028 / 1.942 | 0.015 / 1.962 |
|
||||
| Ours(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_5">) | 0.075 / 1.849 | 0.032 /1.939 | 0.021 / 1.955 |
|
||||
| MTL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_{10}">) | 0.052 / 1.882 | 0.020 / 1.947 | 0.019 / 1.952 |
|
||||
| CL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_{10}">) | 0.051 / 1.882 | 0.028 / 1.950 | 0.016 / 1.961 |
|
||||
| Ours(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_{10}">) | 0.067 / 1.867 | 0.030 / 1.960 | 0.022 / 1.942|
|
||||
BIN
examples/benchmarks/TCTS/task_description.png
Normal file
BIN
examples/benchmarks/TCTS/task_description.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 25 KiB |
BIN
examples/benchmarks/TCTS/workflow.png
Normal file
BIN
examples/benchmarks/TCTS/workflow.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 29 KiB |
93
examples/benchmarks/TCTS/workflow_config_tcts_Alpha360.yaml
Normal file
93
examples/benchmarks/TCTS/workflow_config_tcts_Alpha360.yaml
Normal file
@@ -0,0 +1,93 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -1) / $close - 1",
|
||||
"Ref($close, -2) / Ref($close, -1) - 1",
|
||||
"Ref($close, -3) / Ref($close, -2) - 1"]
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: TCTS
|
||||
module_path: qlib.contrib.model.pytorch_tcts
|
||||
kwargs:
|
||||
d_feat: 6
|
||||
hidden_size: 64
|
||||
num_layers: 2
|
||||
dropout: 0.0
|
||||
n_epochs: 200
|
||||
lr: 1e-3
|
||||
early_stop: 20
|
||||
batch_size: 800
|
||||
metric: loss
|
||||
loss: mse
|
||||
GPU: 0
|
||||
fore_optimizer: adam
|
||||
weight_optimizer: adam
|
||||
output_dim: 3
|
||||
fore_lr: 5e-4
|
||||
weight_lr: 5e-4
|
||||
steps: 3
|
||||
target_label: 1
|
||||
lowest_valid_performance: 0.993
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha360
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
label_col: 1
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -8,7 +8,7 @@
|
||||
Users can follow the ``workflow_by_code_tft.py`` to run the benchmark.
|
||||
|
||||
### Notes
|
||||
1. Please be **aware** that this script can only support `Python 3.5 - 3.8`.
|
||||
1. Please be **aware** that this script can only support `Python 3.6 - 3.7`.
|
||||
2. If the CUDA version on your machine is not 10.0, please remember to run the following commands `conda install anaconda cudatoolkit=10.0` and `conda install cudnn` on your machine.
|
||||
3. The model must run in GPU, or an error will be raised.
|
||||
4. New datasets should be registered in ``data_formatters``, for detail please visit the source.
|
||||
|
||||
@@ -1,3 +1,2 @@
|
||||
tensorflow-gpu==1.15.0
|
||||
numpy == 1.19.4
|
||||
pandas==1.1.0
|
||||
pandas==1.1.0
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import tensorflow.compat.v1 as tf
|
||||
@@ -243,7 +245,7 @@ class TFTModel(ModelFT):
|
||||
# extract_numerical_data(targets), extract_numerical_data(p90_forecast),
|
||||
# 0.9)
|
||||
tf.keras.backend.set_session(default_keras_session)
|
||||
print("Training completed.".format(dte.datetime.now()))
|
||||
print("Training completed at {}.".format(dte.datetime.now()))
|
||||
# ===========================Training Process===========================
|
||||
|
||||
def predict(self, dataset):
|
||||
@@ -289,3 +291,24 @@ class TFTModel(ModelFT):
|
||||
dataset for finetuning
|
||||
"""
|
||||
pass
|
||||
|
||||
def to_pickle(self, path: Union[Path, str]):
|
||||
"""
|
||||
Tensorflow model can't be dumped directly.
|
||||
So the data should be save seperatedly
|
||||
|
||||
**TODO**: Please implement the function to load the files
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : Union[Path, str]
|
||||
the target path to be dumped
|
||||
"""
|
||||
# save tensorflow model
|
||||
# path = Path(path)
|
||||
# path.mkdir(parents=True)
|
||||
# self.model.save(path)
|
||||
|
||||
# save qlib model wrapper
|
||||
self.model = None
|
||||
super(TFTModel, self).to_pickle(path / "qlib_model")
|
||||
|
||||
92
examples/benchmarks/TRA/README.md
Normal file
92
examples/benchmarks/TRA/README.md
Normal file
@@ -0,0 +1,92 @@
|
||||
# Learning Multiple Stock Trading Patterns with Temporal Routing Adaptor and Optimal Transport
|
||||
|
||||
Temporal Routing Adaptor (TRA) is designed to capture multiple trading patterns in the stock market data. Please refer to [our paper](http://arxiv.org/abs/2106.12950) for more details.
|
||||
|
||||
If you find our work useful in your research, please cite:
|
||||
```
|
||||
@inproceedings{HengxuKDD2021,
|
||||
author = {Hengxu Lin and Dong Zhou and Weiqing Liu and Jiang Bian},
|
||||
title = {Learning Multiple Stock Trading Patterns with Temporal Routing Adaptor and Optimal Transport},
|
||||
booktitle = {Proceedings of the 27th ACM SIGKDD Conference on Knowledge Discovery \& Data Mining},
|
||||
series = {KDD '21},
|
||||
year = {2021},
|
||||
publisher = {ACM},
|
||||
}
|
||||
|
||||
@article{yang2020qlib,
|
||||
title={Qlib: An AI-oriented Quantitative Investment Platform},
|
||||
author={Yang, Xiao and Liu, Weiqing and Zhou, Dong and Bian, Jiang and Liu, Tie-Yan},
|
||||
journal={arXiv preprint arXiv:2009.11189},
|
||||
year={2020}
|
||||
}
|
||||
```
|
||||
|
||||
## Usage (Recommended)
|
||||
|
||||
**Update**: `TRA` has been moved to `qlib.contrib.model.pytorch_tra` to support other `Qlib` components like `qlib.workflow` and `Alpha158/Alpha360` dataset.
|
||||
|
||||
Please follow the official [doc](https://qlib.readthedocs.io/en/latest/component/workflow.html) to use `TRA` with `workflow`. Here we also provide several example config files:
|
||||
|
||||
- `workflow_config_tra_Alpha360.yaml`: running `TRA` with `Alpha360` dataset
|
||||
- `workflow_config_tra_Alpha158.yaml`: running `TRA` with `Alpha158` dataset (with feature subsampling)
|
||||
- `workflow_config_tra_Alpha158_full.yaml`: running `TRA` with `Alpha158` dataset (without feature subsampling)
|
||||
|
||||
The performances of `TRA` are reported in [Benchmarks](https://github.com/microsoft/qlib/tree/main/examples/benchmarks).
|
||||
|
||||
## Usage (Not Maintained)
|
||||
|
||||
This section is used to reproduce the results in the paper.
|
||||
|
||||
### Running
|
||||
|
||||
We attach our running scripts for the paper in `run.sh`.
|
||||
|
||||
And here are two ways to run the model:
|
||||
|
||||
* Running from scripts with default parameters
|
||||
|
||||
You can directly run from Qlib command `qrun`:
|
||||
```
|
||||
qrun configs/config_alstm.yaml
|
||||
```
|
||||
|
||||
* Running from code with self-defined parameters
|
||||
|
||||
Setting different parameters is also allowed. See codes in `example.py`:
|
||||
```
|
||||
python example.py --config_file configs/config_alstm.yaml
|
||||
```
|
||||
|
||||
Here we trained TRA on a pretrained backbone model. Therefore we run `*_init.yaml` before TRA's scipts.
|
||||
|
||||
### Results
|
||||
|
||||
After running the scripts, you can find result files in path `./output`:
|
||||
|
||||
* `info.json` - config settings and result metrics.
|
||||
* `log.csv` - running logs.
|
||||
* `model.bin` - the model parameter dictionary.
|
||||
* `pred.pkl` - the prediction scores and output for inference.
|
||||
|
||||
Evaluation metrics reported in the paper:
|
||||
|
||||
| Methods | MSE| MAE| IC | ICIR | AR | AV | SR | MDD |
|
||||
|-------|-------|------|-----|-----|-----|-----|-----|-----|
|
||||
|Linear|0.163|0.327|0.020|0.132|-3.2%|16.8%|-0.191|32.1%|
|
||||
|LightGBM|0.160(0.000)|0.323(0.000)|0.041|0.292|7.8%|15.5%|0.503|25.7%|
|
||||
|MLP|0.160(0.002)|0.323(0.003)|0.037|0.273|3.7%|15.3%|0.264|26.2%|
|
||||
|SFM|0.159(0.001) |0.321(0.001) |0.047 |0.381 |7.1% |14.3% |0.497 |22.9%|
|
||||
|ALSTM|0.158(0.001) |0.320(0.001) |0.053 |0.419 |12.3% |13.7% |0.897 |20.2%|
|
||||
|Trans.|0.158(0.001) |0.322(0.001) |0.051 |0.400 |14.5% |14.2% |1.028 |22.5%|
|
||||
|ALSTM+TS|0.160(0.002) |0.321(0.002) |0.039 |0.291 |6.7% |14.6% |0.480|22.3%|
|
||||
|Trans.+TS|0.160(0.004) |0.324(0.005) |0.037 |0.278 |10.4% |14.7% |0.722 |23.7%|
|
||||
|ALSTM+TRA(Ours)|0.157(0.000) |0.318(0.000) |0.059 |0.460 |12.4% |14.0% |0.885 |20.4%|
|
||||
|Trans.+TRA(Ours)|0.157(0.000) |0.320(0.000) |0.056 |0.442 |16.1% |14.2% |1.133 |23.1%|
|
||||
|
||||
A more detailed demo for our experiment results in the paper can be found in `Report.ipynb`.
|
||||
|
||||
## Common Issues
|
||||
|
||||
For help or issues using TRA, please submit a GitHub issue.
|
||||
|
||||
Sometimes we might encounter situation where the loss is `NaN`, please check the `epsilon` parameter in the sinkhorn algorithm, adjusting the `epsilon` according to input's scale is important.
|
||||
796
examples/benchmarks/TRA/Reports.ipynb
Normal file
796
examples/benchmarks/TRA/Reports.ipynb
Normal file
File diff suppressed because one or more lines are too long
63
examples/benchmarks/TRA/configs/config_alstm.yaml
Normal file
63
examples/benchmarks/TRA/configs/config_alstm.yaml
Normal file
@@ -0,0 +1,63 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
|
||||
data_loader_config: &data_loader_config
|
||||
class: StaticDataLoader
|
||||
module_path: qlib.data.dataset.loader
|
||||
kwargs:
|
||||
config:
|
||||
feature: data/feature.pkl
|
||||
label: data/label.pkl
|
||||
|
||||
model_config: &model_config
|
||||
input_size: 16
|
||||
hidden_size: 256
|
||||
num_layers: 2
|
||||
num_heads: 2
|
||||
use_attn: True
|
||||
dropout: 0.1
|
||||
|
||||
num_states: &num_states 1
|
||||
|
||||
tra_config: &tra_config
|
||||
num_states: *num_states
|
||||
hidden_size: 16
|
||||
tau: 1.0
|
||||
src_info: LR_TPE
|
||||
|
||||
task:
|
||||
model:
|
||||
class: TRAModel
|
||||
module_path: src/model.py
|
||||
kwargs:
|
||||
lr: 0.0002
|
||||
n_epochs: 500
|
||||
max_steps_per_epoch: 100
|
||||
early_stop: 20
|
||||
seed: 1000
|
||||
logdir: output/test/alstm
|
||||
model_type: LSTM
|
||||
model_config: *model_config
|
||||
tra_config: *tra_config
|
||||
lamb: 1.0
|
||||
rho: 0.99
|
||||
freeze_model: False
|
||||
model_init_state:
|
||||
dataset:
|
||||
class: MTSDatasetH
|
||||
module_path: src/dataset.py
|
||||
kwargs:
|
||||
handler:
|
||||
class: DataHandler
|
||||
module_path: qlib.data.dataset.handler
|
||||
kwargs:
|
||||
data_loader: *data_loader_config
|
||||
segments:
|
||||
train: [2007-10-30, 2016-05-27]
|
||||
valid: [2016-09-26, 2018-05-29]
|
||||
test: [2018-09-21, 2020-06-30]
|
||||
seq_len: 60
|
||||
horizon: 21
|
||||
num_states: *num_states
|
||||
batch_size: 1024
|
||||
63
examples/benchmarks/TRA/configs/config_alstm_tra.yaml
Normal file
63
examples/benchmarks/TRA/configs/config_alstm_tra.yaml
Normal file
@@ -0,0 +1,63 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
|
||||
data_loader_config: &data_loader_config
|
||||
class: StaticDataLoader
|
||||
module_path: qlib.data.dataset.loader
|
||||
kwargs:
|
||||
config:
|
||||
feature: data/feature.pkl
|
||||
label: data/label.pkl
|
||||
|
||||
model_config: &model_config
|
||||
input_size: 16
|
||||
hidden_size: 256
|
||||
num_layers: 2
|
||||
num_heads: 2
|
||||
use_attn: True
|
||||
dropout: 0.1
|
||||
|
||||
num_states: &num_states 10
|
||||
|
||||
tra_config: &tra_config
|
||||
num_states: *num_states
|
||||
hidden_size: 16
|
||||
tau: 1.0
|
||||
src_info: LR_TPE
|
||||
|
||||
task:
|
||||
model:
|
||||
class: TRAModel
|
||||
module_path: src/model.py
|
||||
kwargs:
|
||||
lr: 0.0001
|
||||
n_epochs: 500
|
||||
max_steps_per_epoch: 100
|
||||
early_stop: 20
|
||||
seed: 1000
|
||||
logdir: output/test/alstm_tra
|
||||
model_type: LSTM
|
||||
model_config: *model_config
|
||||
tra_config: *tra_config
|
||||
lamb: 2.0
|
||||
rho: 0.99
|
||||
freeze_model: True
|
||||
model_init_state: output/test/alstm_tra_init/model.bin
|
||||
dataset:
|
||||
class: MTSDatasetH
|
||||
module_path: src/dataset.py
|
||||
kwargs:
|
||||
handler:
|
||||
class: DataHandler
|
||||
module_path: qlib.data.dataset.handler
|
||||
kwargs:
|
||||
data_loader: *data_loader_config
|
||||
segments:
|
||||
train: [2007-10-30, 2016-05-27]
|
||||
valid: [2016-09-26, 2018-05-29]
|
||||
test: [2018-09-21, 2020-06-30]
|
||||
seq_len: 60
|
||||
horizon: 21
|
||||
num_states: *num_states
|
||||
batch_size: 1024
|
||||
63
examples/benchmarks/TRA/configs/config_alstm_tra_init.yaml
Normal file
63
examples/benchmarks/TRA/configs/config_alstm_tra_init.yaml
Normal file
@@ -0,0 +1,63 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
|
||||
data_loader_config: &data_loader_config
|
||||
class: StaticDataLoader
|
||||
module_path: qlib.data.dataset.loader
|
||||
kwargs:
|
||||
config:
|
||||
feature: data/feature.pkl
|
||||
label: data/label.pkl
|
||||
|
||||
model_config: &model_config
|
||||
input_size: 16
|
||||
hidden_size: 256
|
||||
num_layers: 2
|
||||
num_heads: 2
|
||||
use_attn: True
|
||||
dropout: 0.1
|
||||
|
||||
num_states: &num_states 3
|
||||
|
||||
tra_config: &tra_config
|
||||
num_states: *num_states
|
||||
hidden_size: 16
|
||||
tau: 1.0
|
||||
src_info: LR_TPE
|
||||
|
||||
task:
|
||||
model:
|
||||
class: TRAModel
|
||||
module_path: src/model.py
|
||||
kwargs:
|
||||
lr: 0.0002
|
||||
n_epochs: 500
|
||||
max_steps_per_epoch: 100
|
||||
early_stop: 20
|
||||
seed: 1000
|
||||
logdir: output/test/alstm_tra_init
|
||||
model_type: LSTM
|
||||
model_config: *model_config
|
||||
tra_config: *tra_config
|
||||
lamb: 1.0
|
||||
rho: 0.99
|
||||
freeze_model: False
|
||||
model_init_state:
|
||||
dataset:
|
||||
class: MTSDatasetH
|
||||
module_path: src/dataset.py
|
||||
kwargs:
|
||||
handler:
|
||||
class: DataHandler
|
||||
module_path: qlib.data.dataset.handler
|
||||
kwargs:
|
||||
data_loader: *data_loader_config
|
||||
segments:
|
||||
train: [2007-10-30, 2016-05-27]
|
||||
valid: [2016-09-26, 2018-05-29]
|
||||
test: [2018-09-21, 2020-06-30]
|
||||
seq_len: 60
|
||||
horizon: 21
|
||||
num_states: *num_states
|
||||
batch_size: 512
|
||||
63
examples/benchmarks/TRA/configs/config_transformer.yaml
Normal file
63
examples/benchmarks/TRA/configs/config_transformer.yaml
Normal file
@@ -0,0 +1,63 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
|
||||
data_loader_config: &data_loader_config
|
||||
class: StaticDataLoader
|
||||
module_path: qlib.data.dataset.loader
|
||||
kwargs:
|
||||
config:
|
||||
feature: data/feature.pkl
|
||||
label: data/label.pkl
|
||||
|
||||
model_config: &model_config
|
||||
input_size: 16
|
||||
hidden_size: 64
|
||||
num_layers: 2
|
||||
num_heads: 4
|
||||
use_attn: False
|
||||
dropout: 0.1
|
||||
|
||||
num_states: &num_states 1
|
||||
|
||||
tra_config: &tra_config
|
||||
num_states: *num_states
|
||||
hidden_size: 16
|
||||
tau: 1.0
|
||||
src_info: LR_TPE
|
||||
|
||||
task:
|
||||
model:
|
||||
class: TRAModel
|
||||
module_path: src/model.py
|
||||
kwargs:
|
||||
lr: 0.0002
|
||||
n_epochs: 500
|
||||
max_steps_per_epoch: 100
|
||||
early_stop: 20
|
||||
seed: 1000
|
||||
logdir: output/test/transformer
|
||||
model_type: Transformer
|
||||
model_config: *model_config
|
||||
tra_config: *tra_config
|
||||
lamb: 1.0
|
||||
rho: 0.99
|
||||
freeze_model: False
|
||||
model_init_state:
|
||||
dataset:
|
||||
class: MTSDatasetH
|
||||
module_path: src/dataset.py
|
||||
kwargs:
|
||||
handler:
|
||||
class: DataHandler
|
||||
module_path: qlib.data.dataset.handler
|
||||
kwargs:
|
||||
data_loader: *data_loader_config
|
||||
segments:
|
||||
train: [2007-10-30, 2016-05-27]
|
||||
valid: [2016-09-26, 2018-05-29]
|
||||
test: [2018-09-21, 2020-06-30]
|
||||
seq_len: 60
|
||||
horizon: 21
|
||||
num_states: *num_states
|
||||
batch_size: 1024
|
||||
63
examples/benchmarks/TRA/configs/config_transformer_tra.yaml
Normal file
63
examples/benchmarks/TRA/configs/config_transformer_tra.yaml
Normal file
@@ -0,0 +1,63 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
|
||||
data_loader_config: &data_loader_config
|
||||
class: StaticDataLoader
|
||||
module_path: qlib.data.dataset.loader
|
||||
kwargs:
|
||||
config:
|
||||
feature: data/feature.pkl
|
||||
label: data/label.pkl
|
||||
|
||||
model_config: &model_config
|
||||
input_size: 16
|
||||
hidden_size: 64
|
||||
num_layers: 2
|
||||
num_heads: 4
|
||||
use_attn: False
|
||||
dropout: 0.1
|
||||
|
||||
num_states: &num_states 3
|
||||
|
||||
tra_config: &tra_config
|
||||
num_states: *num_states
|
||||
hidden_size: 16
|
||||
tau: 1.0
|
||||
src_info: LR_TPE
|
||||
|
||||
task:
|
||||
model:
|
||||
class: TRAModel
|
||||
module_path: src/model.py
|
||||
kwargs:
|
||||
lr: 0.0005
|
||||
n_epochs: 500
|
||||
max_steps_per_epoch: 100
|
||||
early_stop: 20
|
||||
seed: 1000
|
||||
logdir: output/test/transformer_tra
|
||||
model_type: Transformer
|
||||
model_config: *model_config
|
||||
tra_config: *tra_config
|
||||
lamb: 1.0
|
||||
rho: 0.99
|
||||
freeze_model: True
|
||||
model_init_state: output/test/transformer_tra_init/model.bin
|
||||
dataset:
|
||||
class: MTSDatasetH
|
||||
module_path: src/dataset.py
|
||||
kwargs:
|
||||
handler:
|
||||
class: DataHandler
|
||||
module_path: qlib.data.dataset.handler
|
||||
kwargs:
|
||||
data_loader: *data_loader_config
|
||||
segments:
|
||||
train: [2007-10-30, 2016-05-27]
|
||||
valid: [2016-09-26, 2018-05-29]
|
||||
test: [2018-09-21, 2020-06-30]
|
||||
seq_len: 60
|
||||
horizon: 21
|
||||
num_states: *num_states
|
||||
batch_size: 512
|
||||
@@ -0,0 +1,63 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
|
||||
data_loader_config: &data_loader_config
|
||||
class: StaticDataLoader
|
||||
module_path: qlib.data.dataset.loader
|
||||
kwargs:
|
||||
config:
|
||||
feature: data/feature.pkl
|
||||
label: data/label.pkl
|
||||
|
||||
model_config: &model_config
|
||||
input_size: 16
|
||||
hidden_size: 64
|
||||
num_layers: 2
|
||||
num_heads: 4
|
||||
use_attn: False
|
||||
dropout: 0.1
|
||||
|
||||
num_states: &num_states 3
|
||||
|
||||
tra_config: &tra_config
|
||||
num_states: *num_states
|
||||
hidden_size: 16
|
||||
tau: 1.0
|
||||
src_info: LR_TPE
|
||||
|
||||
task:
|
||||
model:
|
||||
class: TRAModel
|
||||
module_path: src/model.py
|
||||
kwargs:
|
||||
lr: 0.0002
|
||||
n_epochs: 500
|
||||
max_steps_per_epoch: 100
|
||||
early_stop: 20
|
||||
seed: 1000
|
||||
logdir: output/test/transformer_tra_init
|
||||
model_type: Transformer
|
||||
model_config: *model_config
|
||||
tra_config: *tra_config
|
||||
lamb: 1.0
|
||||
rho: 0.99
|
||||
freeze_model: False
|
||||
model_init_state:
|
||||
dataset:
|
||||
class: MTSDatasetH
|
||||
module_path: src/dataset.py
|
||||
kwargs:
|
||||
handler:
|
||||
class: DataHandler
|
||||
module_path: qlib.data.dataset.handler
|
||||
kwargs:
|
||||
data_loader: *data_loader_config
|
||||
segments:
|
||||
train: [2007-10-30, 2016-05-27]
|
||||
valid: [2016-09-26, 2018-05-29]
|
||||
test: [2018-09-21, 2020-06-30]
|
||||
seq_len: 60
|
||||
horizon: 21
|
||||
num_states: *num_states
|
||||
batch_size: 512
|
||||
1
examples/benchmarks/TRA/data/README.md
Normal file
1
examples/benchmarks/TRA/data/README.md
Normal file
@@ -0,0 +1 @@
|
||||
Data Link: https://drive.google.com/drive/folders/1fMqZYSeLyrHiWmVzygeI4sw3vp5Gt8cY?usp=sharing
|
||||
39
examples/benchmarks/TRA/example.py
Normal file
39
examples/benchmarks/TRA/example.py
Normal file
@@ -0,0 +1,39 @@
|
||||
import argparse
|
||||
|
||||
import qlib
|
||||
import ruamel.yaml as yaml
|
||||
from qlib.utils import init_instance_by_config
|
||||
|
||||
|
||||
def main(seed, config_file="configs/config_alstm.yaml"):
|
||||
|
||||
# set random seed
|
||||
with open(config_file) as f:
|
||||
config = yaml.safe_load(f)
|
||||
|
||||
# seed_suffix = "/seed1000" if "init" in config_file else f"/seed{seed}"
|
||||
seed_suffix = ""
|
||||
config["task"]["model"]["kwargs"].update(
|
||||
{"seed": seed, "logdir": config["task"]["model"]["kwargs"]["logdir"] + seed_suffix}
|
||||
)
|
||||
|
||||
# initialize workflow
|
||||
qlib.init(
|
||||
provider_uri=config["qlib_init"]["provider_uri"],
|
||||
region=config["qlib_init"]["region"],
|
||||
)
|
||||
dataset = init_instance_by_config(config["task"]["dataset"])
|
||||
model = init_instance_by_config(config["task"]["model"])
|
||||
|
||||
# train model
|
||||
model.fit(dataset)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
# set params from cmd
|
||||
parser = argparse.ArgumentParser(allow_abbrev=False)
|
||||
parser.add_argument("--seed", type=int, default=1000, help="random seed")
|
||||
parser.add_argument("--config_file", type=str, default="configs/config_alstm.yaml", help="config file")
|
||||
args = parser.parse_args()
|
||||
main(**vars(args))
|
||||
29
examples/benchmarks/TRA/run.sh
Normal file
29
examples/benchmarks/TRA/run.sh
Normal file
@@ -0,0 +1,29 @@
|
||||
#!/bin/bash
|
||||
|
||||
# we used random seed(1 1000 2000 3000 4000 5000) in our experiments
|
||||
|
||||
# Directly run from Qlib command `qrun`
|
||||
qrun configs/config_alstm.yaml
|
||||
|
||||
qrun configs/config_transformer.yaml
|
||||
|
||||
qrun configs/config_transformer_tra_init.yaml
|
||||
qrun configs/config_transformer_tra.yaml
|
||||
|
||||
qrun configs/config_alstm_tra_init.yaml
|
||||
qrun configs/config_alstm_tra.yaml
|
||||
|
||||
|
||||
# Or setting different parameters with example.py
|
||||
python example.py --config_file configs/config_alstm.yaml
|
||||
|
||||
python example.py --config_file configs/config_transformer.yaml
|
||||
|
||||
python example.py --config_file configs/config_transformer_tra_init.yaml
|
||||
python example.py --config_file configs/config_transformer_tra.yaml
|
||||
|
||||
python example.py --config_file configs/config_alstm_tra_init.yaml
|
||||
python example.py --config_file configs/config_alstm_tra.yaml
|
||||
|
||||
|
||||
|
||||
253
examples/benchmarks/TRA/src/dataset.py
Normal file
253
examples/benchmarks/TRA/src/dataset.py
Normal file
@@ -0,0 +1,253 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import copy
|
||||
import torch
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from qlib.utils import init_instance_by_config
|
||||
from qlib.data.dataset import DatasetH, DataHandler
|
||||
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
|
||||
def _to_tensor(x):
|
||||
if not isinstance(x, torch.Tensor):
|
||||
return torch.tensor(x, dtype=torch.float, device=device)
|
||||
return x
|
||||
|
||||
|
||||
def _create_ts_slices(index, seq_len):
|
||||
"""
|
||||
create time series slices from pandas index
|
||||
|
||||
Args:
|
||||
index (pd.MultiIndex): pandas multiindex with <instrument, datetime> order
|
||||
seq_len (int): sequence length
|
||||
"""
|
||||
assert index.is_lexsorted(), "index should be sorted"
|
||||
|
||||
# number of dates for each code
|
||||
sample_count_by_codes = pd.Series(0, index=index).groupby(level=0).size().values
|
||||
|
||||
# start_index for each code
|
||||
start_index_of_codes = np.roll(np.cumsum(sample_count_by_codes), 1)
|
||||
start_index_of_codes[0] = 0
|
||||
|
||||
# all the [start, stop) indices of features
|
||||
# features btw [start, stop) are used to predict the `stop - 1` label
|
||||
slices = []
|
||||
for cur_loc, cur_cnt in zip(start_index_of_codes, sample_count_by_codes):
|
||||
for stop in range(1, cur_cnt + 1):
|
||||
end = cur_loc + stop
|
||||
start = max(end - seq_len, 0)
|
||||
slices.append(slice(start, end))
|
||||
slices = np.array(slices)
|
||||
|
||||
return slices
|
||||
|
||||
|
||||
def _get_date_parse_fn(target):
|
||||
"""get date parse function
|
||||
|
||||
This method is used to parse date arguments as target type.
|
||||
|
||||
Example:
|
||||
get_date_parse_fn('20120101')('2017-01-01') => '20170101'
|
||||
get_date_parse_fn(20120101)('2017-01-01') => 20170101
|
||||
"""
|
||||
if isinstance(target, pd.Timestamp):
|
||||
_fn = lambda x: pd.Timestamp(x) # Timestamp('2020-01-01')
|
||||
elif isinstance(target, str) and len(target) == 8:
|
||||
_fn = lambda x: str(x).replace("-", "")[:8] # '20200201'
|
||||
elif isinstance(target, int):
|
||||
_fn = lambda x: int(str(x).replace("-", "")[:8]) # 20200201
|
||||
else:
|
||||
_fn = lambda x: x
|
||||
return _fn
|
||||
|
||||
|
||||
class MTSDatasetH(DatasetH):
|
||||
"""Memory Augmented Time Series Dataset
|
||||
|
||||
Args:
|
||||
handler (DataHandler): data handler
|
||||
segments (dict): data split segments
|
||||
seq_len (int): time series sequence length
|
||||
horizon (int): label horizon (to mask historical loss for TRA)
|
||||
num_states (int): how many memory states to be added (for TRA)
|
||||
batch_size (int): batch size (<0 means daily batch)
|
||||
shuffle (bool): whether shuffle data
|
||||
pin_memory (bool): whether pin data to gpu memory
|
||||
drop_last (bool): whether drop last batch < batch_size
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
handler,
|
||||
segments,
|
||||
seq_len=60,
|
||||
horizon=0,
|
||||
num_states=1,
|
||||
batch_size=-1,
|
||||
shuffle=True,
|
||||
pin_memory=False,
|
||||
drop_last=False,
|
||||
**kwargs
|
||||
):
|
||||
|
||||
assert horizon > 0, "please specify `horizon` to avoid data leakage"
|
||||
|
||||
self.seq_len = seq_len
|
||||
self.horizon = horizon
|
||||
self.num_states = num_states
|
||||
self.batch_size = batch_size
|
||||
self.shuffle = shuffle
|
||||
self.drop_last = drop_last
|
||||
self.pin_memory = pin_memory
|
||||
self.params = (batch_size, drop_last, shuffle) # for train/eval switch
|
||||
|
||||
super().__init__(handler, segments, **kwargs)
|
||||
|
||||
def setup_data(self, handler_kwargs: dict = None, **kwargs):
|
||||
|
||||
super().setup_data()
|
||||
|
||||
# change index to <code, date>
|
||||
# NOTE: we will use inplace sort to reduce memory use
|
||||
df = self.handler._data
|
||||
df.index = df.index.swaplevel()
|
||||
df.sort_index(inplace=True)
|
||||
|
||||
self._data = df["feature"].values.astype("float32")
|
||||
self._label = df["label"].squeeze().astype("float32")
|
||||
self._index = df.index
|
||||
|
||||
# add memory to feature
|
||||
self._data = np.c_[self._data, np.zeros((len(self._data), self.num_states), dtype=np.float32)]
|
||||
|
||||
# padding tensor
|
||||
self.zeros = np.zeros((self.seq_len, self._data.shape[1]), dtype=np.float32)
|
||||
|
||||
# pin memory
|
||||
if self.pin_memory:
|
||||
self._data = _to_tensor(self._data)
|
||||
self._label = _to_tensor(self._label)
|
||||
self.zeros = _to_tensor(self.zeros)
|
||||
|
||||
# create batch slices
|
||||
self.batch_slices = _create_ts_slices(self._index, self.seq_len)
|
||||
|
||||
# create daily slices
|
||||
index = [slc.stop - 1 for slc in self.batch_slices]
|
||||
act_index = self.restore_index(index)
|
||||
daily_slices = {date: [] for date in sorted(act_index.unique(level=1))}
|
||||
for i, (code, date) in enumerate(act_index):
|
||||
daily_slices[date].append(self.batch_slices[i])
|
||||
self.daily_slices = list(daily_slices.values())
|
||||
|
||||
def _prepare_seg(self, slc, **kwargs):
|
||||
fn = _get_date_parse_fn(self._index[0][1])
|
||||
start_date = fn(slc.start)
|
||||
end_date = fn(slc.stop)
|
||||
obj = copy.copy(self) # shallow copy
|
||||
# NOTE: Seriable will disable copy `self._data` so we manually assign them here
|
||||
obj._data = self._data
|
||||
obj._label = self._label
|
||||
obj._index = self._index
|
||||
new_batch_slices = []
|
||||
for batch_slc in self.batch_slices:
|
||||
date = self._index[batch_slc.stop - 1][1]
|
||||
if start_date <= date <= end_date:
|
||||
new_batch_slices.append(batch_slc)
|
||||
obj.batch_slices = np.array(new_batch_slices)
|
||||
new_daily_slices = []
|
||||
for daily_slc in self.daily_slices:
|
||||
date = self._index[daily_slc[0].stop - 1][1]
|
||||
if start_date <= date <= end_date:
|
||||
new_daily_slices.append(daily_slc)
|
||||
obj.daily_slices = new_daily_slices
|
||||
return obj
|
||||
|
||||
def restore_index(self, index):
|
||||
if isinstance(index, torch.Tensor):
|
||||
index = index.cpu().numpy()
|
||||
return self._index[index]
|
||||
|
||||
def assign_data(self, index, vals):
|
||||
if isinstance(self._data, torch.Tensor):
|
||||
vals = _to_tensor(vals)
|
||||
elif isinstance(vals, torch.Tensor):
|
||||
vals = vals.detach().cpu().numpy()
|
||||
index = index.detach().cpu().numpy()
|
||||
self._data[index, -self.num_states :] = vals
|
||||
|
||||
def clear_memory(self):
|
||||
self._data[:, -self.num_states :] = 0
|
||||
|
||||
# TODO: better train/eval mode design
|
||||
def train(self):
|
||||
"""enable traning mode"""
|
||||
self.batch_size, self.drop_last, self.shuffle = self.params
|
||||
|
||||
def eval(self):
|
||||
"""enable evaluation mode"""
|
||||
self.batch_size = -1
|
||||
self.drop_last = False
|
||||
self.shuffle = False
|
||||
|
||||
def _get_slices(self):
|
||||
if self.batch_size < 0:
|
||||
slices = self.daily_slices.copy()
|
||||
batch_size = -1 * self.batch_size
|
||||
else:
|
||||
slices = self.batch_slices.copy()
|
||||
batch_size = self.batch_size
|
||||
return slices, batch_size
|
||||
|
||||
def __len__(self):
|
||||
slices, batch_size = self._get_slices()
|
||||
if self.drop_last:
|
||||
return len(slices) // batch_size
|
||||
return (len(slices) + batch_size - 1) // batch_size
|
||||
|
||||
def __iter__(self):
|
||||
slices, batch_size = self._get_slices()
|
||||
if self.shuffle:
|
||||
np.random.shuffle(slices)
|
||||
|
||||
for i in range(len(slices))[::batch_size]:
|
||||
if self.drop_last and i + batch_size > len(slices):
|
||||
break
|
||||
# get slices for this batch
|
||||
slices_subset = slices[i : i + batch_size]
|
||||
if self.batch_size < 0:
|
||||
slices_subset = np.concatenate(slices_subset)
|
||||
# collect data
|
||||
data = []
|
||||
label = []
|
||||
index = []
|
||||
for slc in slices_subset:
|
||||
_data = self._data[slc].clone() if self.pin_memory else self._data[slc].copy()
|
||||
if len(_data) != self.seq_len:
|
||||
if self.pin_memory:
|
||||
_data = torch.cat([self.zeros[: self.seq_len - len(_data)], _data], axis=0)
|
||||
else:
|
||||
_data = np.concatenate([self.zeros[: self.seq_len - len(_data)], _data], axis=0)
|
||||
if self.num_states > 0:
|
||||
_data[-self.horizon :, -self.num_states :] = 0
|
||||
data.append(_data)
|
||||
label.append(self._label[slc.stop - 1])
|
||||
index.append(slc.stop - 1)
|
||||
# concate
|
||||
index = torch.tensor(index, device=device)
|
||||
if isinstance(data[0], torch.Tensor):
|
||||
data = torch.stack(data)
|
||||
label = torch.stack(label)
|
||||
else:
|
||||
data = _to_tensor(np.stack(data))
|
||||
label = _to_tensor(np.stack(label))
|
||||
# yield -> generator
|
||||
yield {"data": data, "label": label, "index": index}
|
||||
603
examples/benchmarks/TRA/src/model.py
Normal file
603
examples/benchmarks/TRA/src/model.py
Normal file
@@ -0,0 +1,603 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
import copy
|
||||
import math
|
||||
import json
|
||||
import collections
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import torch.nn.functional as F
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from qlib.utils import get_or_create_path
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.model.base import Model
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
|
||||
class TRAModel(Model):
|
||||
def __init__(
|
||||
self,
|
||||
model_config,
|
||||
tra_config,
|
||||
model_type="LSTM",
|
||||
lr=1e-3,
|
||||
n_epochs=500,
|
||||
early_stop=50,
|
||||
smooth_steps=5,
|
||||
max_steps_per_epoch=None,
|
||||
freeze_model=False,
|
||||
model_init_state=None,
|
||||
lamb=0.0,
|
||||
rho=0.99,
|
||||
seed=0,
|
||||
logdir=None,
|
||||
eval_train=True,
|
||||
eval_test=False,
|
||||
avg_params=True,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
|
||||
self.logger = get_module_logger("TRA")
|
||||
self.logger.info("TRA Model...")
|
||||
|
||||
self.model = eval(model_type)(**model_config).to(device)
|
||||
if model_init_state:
|
||||
self.model.load_state_dict(torch.load(model_init_state, map_location="cpu")["model"])
|
||||
if freeze_model:
|
||||
for param in self.model.parameters():
|
||||
param.requires_grad_(False)
|
||||
else:
|
||||
self.logger.info("# model params: %d" % sum([p.numel() for p in self.model.parameters()]))
|
||||
|
||||
self.tra = TRA(self.model.output_size, **tra_config).to(device)
|
||||
self.logger.info("# tra params: %d" % sum([p.numel() for p in self.tra.parameters()]))
|
||||
|
||||
self.optimizer = optim.Adam(list(self.model.parameters()) + list(self.tra.parameters()), lr=lr)
|
||||
|
||||
self.model_config = model_config
|
||||
self.tra_config = tra_config
|
||||
self.lr = lr
|
||||
self.n_epochs = n_epochs
|
||||
self.early_stop = early_stop
|
||||
self.smooth_steps = smooth_steps
|
||||
self.max_steps_per_epoch = max_steps_per_epoch
|
||||
self.lamb = lamb
|
||||
self.rho = rho
|
||||
self.seed = seed
|
||||
self.logdir = logdir
|
||||
self.eval_train = eval_train
|
||||
self.eval_test = eval_test
|
||||
self.avg_params = avg_params
|
||||
|
||||
if self.tra.num_states > 1 and not self.eval_train:
|
||||
self.logger.warn("`eval_train` will be ignored when using TRA")
|
||||
|
||||
if self.logdir is not None:
|
||||
if os.path.exists(self.logdir):
|
||||
self.logger.warn(f"logdir {self.logdir} is not empty")
|
||||
os.makedirs(self.logdir, exist_ok=True)
|
||||
|
||||
self.fitted = False
|
||||
self.global_step = -1
|
||||
|
||||
def train_epoch(self, data_set):
|
||||
|
||||
self.model.train()
|
||||
self.tra.train()
|
||||
|
||||
data_set.train()
|
||||
|
||||
max_steps = self.n_epochs
|
||||
if self.max_steps_per_epoch is not None:
|
||||
max_steps = min(self.max_steps_per_epoch, self.n_epochs)
|
||||
|
||||
count = 0
|
||||
total_loss = 0
|
||||
total_count = 0
|
||||
for batch in tqdm(data_set, total=max_steps):
|
||||
count += 1
|
||||
if count > max_steps:
|
||||
break
|
||||
|
||||
self.global_step += 1
|
||||
|
||||
data, label, index = batch["data"], batch["label"], batch["index"]
|
||||
|
||||
feature = data[:, :, : -self.tra.num_states]
|
||||
hist_loss = data[:, : -data_set.horizon, -self.tra.num_states :]
|
||||
|
||||
hidden = self.model(feature)
|
||||
pred, all_preds, prob = self.tra(hidden, hist_loss)
|
||||
|
||||
loss = (pred - label).pow(2).mean()
|
||||
|
||||
L = (all_preds.detach() - label[:, None]).pow(2)
|
||||
L -= L.min(dim=-1, keepdim=True).values # normalize & ensure postive input
|
||||
|
||||
data_set.assign_data(index, L) # save loss to memory
|
||||
|
||||
if prob is not None:
|
||||
P = sinkhorn(-L, epsilon=0.01) # sample assignment matrix
|
||||
lamb = self.lamb * (self.rho ** self.global_step)
|
||||
reg = prob.log().mul(P).sum(dim=-1).mean()
|
||||
loss = loss - lamb * reg
|
||||
|
||||
loss.backward()
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
total_loss += loss.item()
|
||||
total_count += len(pred)
|
||||
|
||||
total_loss /= total_count
|
||||
|
||||
return total_loss
|
||||
|
||||
def test_epoch(self, data_set, return_pred=False):
|
||||
|
||||
self.model.eval()
|
||||
self.tra.eval()
|
||||
data_set.eval()
|
||||
|
||||
preds = []
|
||||
metrics = []
|
||||
for batch in tqdm(data_set):
|
||||
data, label, index = batch["data"], batch["label"], batch["index"]
|
||||
|
||||
feature = data[:, :, : -self.tra.num_states]
|
||||
hist_loss = data[:, : -data_set.horizon, -self.tra.num_states :]
|
||||
|
||||
with torch.no_grad():
|
||||
hidden = self.model(feature)
|
||||
pred, all_preds, prob = self.tra(hidden, hist_loss)
|
||||
|
||||
L = (all_preds - label[:, None]).pow(2)
|
||||
|
||||
L -= L.min(dim=-1, keepdim=True).values # normalize & ensure postive input
|
||||
|
||||
data_set.assign_data(index, L) # save loss to memory
|
||||
|
||||
X = np.c_[
|
||||
pred.cpu().numpy(),
|
||||
label.cpu().numpy(),
|
||||
]
|
||||
columns = ["score", "label"]
|
||||
if prob is not None:
|
||||
X = np.c_[X, all_preds.cpu().numpy(), prob.cpu().numpy()]
|
||||
columns += ["score_%d" % d for d in range(all_preds.shape[1])] + [
|
||||
"prob_%d" % d for d in range(all_preds.shape[1])
|
||||
]
|
||||
|
||||
pred = pd.DataFrame(X, index=index.cpu().numpy(), columns=columns)
|
||||
|
||||
metrics.append(evaluate(pred))
|
||||
|
||||
if return_pred:
|
||||
preds.append(pred)
|
||||
|
||||
metrics = pd.DataFrame(metrics)
|
||||
metrics = {
|
||||
"MSE": metrics.MSE.mean(),
|
||||
"MAE": metrics.MAE.mean(),
|
||||
"IC": metrics.IC.mean(),
|
||||
"ICIR": metrics.IC.mean() / metrics.IC.std(),
|
||||
}
|
||||
|
||||
if return_pred:
|
||||
preds = pd.concat(preds, axis=0)
|
||||
preds.index = data_set.restore_index(preds.index)
|
||||
preds.index = preds.index.swaplevel()
|
||||
preds.sort_index(inplace=True)
|
||||
|
||||
return metrics, preds
|
||||
|
||||
def fit(self, dataset, evals_result=dict()):
|
||||
|
||||
train_set, valid_set, test_set = dataset.prepare(["train", "valid", "test"])
|
||||
|
||||
best_score = -1
|
||||
best_epoch = 0
|
||||
stop_rounds = 0
|
||||
best_params = {
|
||||
"model": copy.deepcopy(self.model.state_dict()),
|
||||
"tra": copy.deepcopy(self.tra.state_dict()),
|
||||
}
|
||||
params_list = {
|
||||
"model": collections.deque(maxlen=self.smooth_steps),
|
||||
"tra": collections.deque(maxlen=self.smooth_steps),
|
||||
}
|
||||
evals_result["train"] = []
|
||||
evals_result["valid"] = []
|
||||
evals_result["test"] = []
|
||||
|
||||
# train
|
||||
self.fitted = True
|
||||
self.global_step = -1
|
||||
|
||||
if self.tra.num_states > 1:
|
||||
self.logger.info("init memory...")
|
||||
self.test_epoch(train_set)
|
||||
|
||||
for epoch in range(self.n_epochs):
|
||||
self.logger.info("Epoch %d:", epoch)
|
||||
|
||||
self.logger.info("training...")
|
||||
self.train_epoch(train_set)
|
||||
|
||||
self.logger.info("evaluating...")
|
||||
# average params for inference
|
||||
params_list["model"].append(copy.deepcopy(self.model.state_dict()))
|
||||
params_list["tra"].append(copy.deepcopy(self.tra.state_dict()))
|
||||
self.model.load_state_dict(average_params(params_list["model"]))
|
||||
self.tra.load_state_dict(average_params(params_list["tra"]))
|
||||
|
||||
# NOTE: during evaluating, the whole memory will be refreshed
|
||||
if self.tra.num_states > 1 or self.eval_train:
|
||||
train_set.clear_memory() # NOTE: clear the shared memory
|
||||
train_metrics = self.test_epoch(train_set)[0]
|
||||
evals_result["train"].append(train_metrics)
|
||||
self.logger.info("\ttrain metrics: %s" % train_metrics)
|
||||
|
||||
valid_metrics = self.test_epoch(valid_set)[0]
|
||||
evals_result["valid"].append(valid_metrics)
|
||||
self.logger.info("\tvalid metrics: %s" % valid_metrics)
|
||||
|
||||
if self.eval_test:
|
||||
test_metrics = self.test_epoch(test_set)[0]
|
||||
evals_result["test"].append(test_metrics)
|
||||
self.logger.info("\ttest metrics: %s" % test_metrics)
|
||||
|
||||
if valid_metrics["IC"] > best_score:
|
||||
best_score = valid_metrics["IC"]
|
||||
stop_rounds = 0
|
||||
best_epoch = epoch
|
||||
best_params = {
|
||||
"model": copy.deepcopy(self.model.state_dict()),
|
||||
"tra": copy.deepcopy(self.tra.state_dict()),
|
||||
}
|
||||
else:
|
||||
stop_rounds += 1
|
||||
if stop_rounds >= self.early_stop:
|
||||
self.logger.info("early stop @ %s" % epoch)
|
||||
break
|
||||
|
||||
# restore parameters
|
||||
self.model.load_state_dict(params_list["model"][-1])
|
||||
self.tra.load_state_dict(params_list["tra"][-1])
|
||||
|
||||
self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
|
||||
self.model.load_state_dict(best_params["model"])
|
||||
self.tra.load_state_dict(best_params["tra"])
|
||||
|
||||
metrics, preds = self.test_epoch(test_set, return_pred=True)
|
||||
self.logger.info("test metrics: %s" % metrics)
|
||||
|
||||
if self.logdir:
|
||||
self.logger.info("save model & pred to local directory")
|
||||
|
||||
pd.concat({name: pd.DataFrame(evals_result[name]) for name in evals_result}, axis=1).to_csv(
|
||||
self.logdir + "/logs.csv", index=False
|
||||
)
|
||||
|
||||
torch.save(best_params, self.logdir + "/model.bin")
|
||||
|
||||
preds.to_pickle(self.logdir + "/pred.pkl")
|
||||
|
||||
info = {
|
||||
"config": {
|
||||
"model_config": self.model_config,
|
||||
"tra_config": self.tra_config,
|
||||
"lr": self.lr,
|
||||
"n_epochs": self.n_epochs,
|
||||
"early_stop": self.early_stop,
|
||||
"smooth_steps": self.smooth_steps,
|
||||
"max_steps_per_epoch": self.max_steps_per_epoch,
|
||||
"lamb": self.lamb,
|
||||
"rho": self.rho,
|
||||
"seed": self.seed,
|
||||
"logdir": self.logdir,
|
||||
},
|
||||
"best_eval_metric": -best_score, # NOTE: minux -1 for minimize
|
||||
"metric": metrics,
|
||||
}
|
||||
with open(self.logdir + "/info.json", "w") as f:
|
||||
json.dump(info, f)
|
||||
|
||||
def predict(self, dataset, segment="test"):
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
test_set = dataset.prepare(segment)
|
||||
|
||||
metrics, preds = self.test_epoch(test_set, return_pred=True)
|
||||
self.logger.info("test metrics: %s" % metrics)
|
||||
|
||||
return preds
|
||||
|
||||
|
||||
class LSTM(nn.Module):
|
||||
|
||||
"""LSTM Model
|
||||
|
||||
Args:
|
||||
input_size (int): input size (# features)
|
||||
hidden_size (int): hidden size
|
||||
num_layers (int): number of hidden layers
|
||||
use_attn (bool): whether use attention layer.
|
||||
we use concat attention as https://github.com/fulifeng/Adv-ALSTM/
|
||||
dropout (float): dropout rate
|
||||
input_drop (float): input dropout for data augmentation
|
||||
noise_level (float): add gaussian noise to input for data augmentation
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size=16,
|
||||
hidden_size=64,
|
||||
num_layers=2,
|
||||
use_attn=True,
|
||||
dropout=0.0,
|
||||
input_drop=0.0,
|
||||
noise_level=0.0,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.input_size = input_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_layers = num_layers
|
||||
self.use_attn = use_attn
|
||||
self.noise_level = noise_level
|
||||
|
||||
self.input_drop = nn.Dropout(input_drop)
|
||||
|
||||
self.rnn = nn.LSTM(
|
||||
input_size=input_size,
|
||||
hidden_size=hidden_size,
|
||||
num_layers=num_layers,
|
||||
batch_first=True,
|
||||
dropout=dropout,
|
||||
)
|
||||
|
||||
if self.use_attn:
|
||||
self.W = nn.Linear(hidden_size, hidden_size)
|
||||
self.u = nn.Linear(hidden_size, 1, bias=False)
|
||||
self.output_size = hidden_size * 2
|
||||
else:
|
||||
self.output_size = hidden_size
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
x = self.input_drop(x)
|
||||
|
||||
if self.training and self.noise_level > 0:
|
||||
noise = torch.randn_like(x).to(x)
|
||||
x = x + noise * self.noise_level
|
||||
|
||||
rnn_out, _ = self.rnn(x)
|
||||
last_out = rnn_out[:, -1]
|
||||
|
||||
if self.use_attn:
|
||||
laten = self.W(rnn_out).tanh()
|
||||
scores = self.u(laten).softmax(dim=1)
|
||||
att_out = (rnn_out * scores).sum(dim=1).squeeze()
|
||||
last_out = torch.cat([last_out, att_out], dim=1)
|
||||
|
||||
return last_out
|
||||
|
||||
|
||||
class PositionalEncoding(nn.Module):
|
||||
# reference: https://pytorch.org/tutorials/beginner/transformer_tutorial.html
|
||||
def __init__(self, d_model, dropout=0.1, max_len=5000):
|
||||
super(PositionalEncoding, self).__init__()
|
||||
self.dropout = nn.Dropout(p=dropout)
|
||||
|
||||
pe = torch.zeros(max_len, d_model)
|
||||
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
|
||||
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
|
||||
pe[:, 0::2] = torch.sin(position * div_term)
|
||||
pe[:, 1::2] = torch.cos(position * div_term)
|
||||
pe = pe.unsqueeze(0).transpose(0, 1)
|
||||
self.register_buffer("pe", pe)
|
||||
|
||||
def forward(self, x):
|
||||
x = x + self.pe[: x.size(0), :]
|
||||
return self.dropout(x)
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
|
||||
"""Transformer Model
|
||||
|
||||
Args:
|
||||
input_size (int): input size (# features)
|
||||
hidden_size (int): hidden size
|
||||
num_layers (int): number of transformer layers
|
||||
num_heads (int): number of heads in transformer
|
||||
dropout (float): dropout rate
|
||||
input_drop (float): input dropout for data augmentation
|
||||
noise_level (float): add gaussian noise to input for data augmentation
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size=16,
|
||||
hidden_size=64,
|
||||
num_layers=2,
|
||||
num_heads=2,
|
||||
dropout=0.0,
|
||||
input_drop=0.0,
|
||||
noise_level=0.0,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.input_size = input_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_layers = num_layers
|
||||
self.num_heads = num_heads
|
||||
self.noise_level = noise_level
|
||||
|
||||
self.input_drop = nn.Dropout(input_drop)
|
||||
|
||||
self.input_proj = nn.Linear(input_size, hidden_size)
|
||||
|
||||
self.pe = PositionalEncoding(input_size, dropout)
|
||||
layer = nn.TransformerEncoderLayer(
|
||||
nhead=num_heads, dropout=dropout, d_model=hidden_size, dim_feedforward=hidden_size * 4
|
||||
)
|
||||
self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)
|
||||
|
||||
self.output_size = hidden_size
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
x = self.input_drop(x)
|
||||
|
||||
if self.training and self.noise_level > 0:
|
||||
noise = torch.randn_like(x).to(x)
|
||||
x = x + noise * self.noise_level
|
||||
|
||||
x = x.permute(1, 0, 2).contiguous() # the first dim need to be sequence
|
||||
x = self.pe(x)
|
||||
|
||||
x = self.input_proj(x)
|
||||
out = self.encoder(x)
|
||||
|
||||
return out[-1]
|
||||
|
||||
|
||||
class TRA(nn.Module):
|
||||
|
||||
"""Temporal Routing Adaptor (TRA)
|
||||
|
||||
TRA takes historical prediction erros & latent representation as inputs,
|
||||
then routes the input sample to a specific predictor for training & inference.
|
||||
|
||||
Args:
|
||||
input_size (int): input size (RNN/Transformer's hidden size)
|
||||
num_states (int): number of latent states (i.e., trading patterns)
|
||||
If `num_states=1`, then TRA falls back to traditional methods
|
||||
hidden_size (int): hidden size of the router
|
||||
tau (float): gumbel softmax temperature
|
||||
"""
|
||||
|
||||
def __init__(self, input_size, num_states=1, hidden_size=8, tau=1.0, src_info="LR_TPE"):
|
||||
super().__init__()
|
||||
|
||||
self.num_states = num_states
|
||||
self.tau = tau
|
||||
self.src_info = src_info
|
||||
|
||||
if num_states > 1:
|
||||
self.router = nn.LSTM(
|
||||
input_size=num_states,
|
||||
hidden_size=hidden_size,
|
||||
num_layers=1,
|
||||
batch_first=True,
|
||||
)
|
||||
self.fc = nn.Linear(hidden_size + input_size, num_states)
|
||||
|
||||
self.predictors = nn.Linear(input_size, num_states)
|
||||
|
||||
def forward(self, hidden, hist_loss):
|
||||
|
||||
preds = self.predictors(hidden)
|
||||
|
||||
if self.num_states == 1:
|
||||
return preds.squeeze(-1), preds, None
|
||||
|
||||
# information type
|
||||
router_out, _ = self.router(hist_loss)
|
||||
if "LR" in self.src_info:
|
||||
latent_representation = hidden
|
||||
else:
|
||||
latent_representation = torch.randn(hidden.shape).to(hidden)
|
||||
if "TPE" in self.src_info:
|
||||
temporal_pred_error = router_out[:, -1]
|
||||
else:
|
||||
temporal_pred_error = torch.randn(router_out[:, -1].shape).to(hidden)
|
||||
|
||||
out = self.fc(torch.cat([temporal_pred_error, latent_representation], dim=-1))
|
||||
prob = F.gumbel_softmax(out, dim=-1, tau=self.tau, hard=False)
|
||||
|
||||
if self.training:
|
||||
final_pred = (preds * prob).sum(dim=-1)
|
||||
else:
|
||||
final_pred = preds[range(len(preds)), prob.argmax(dim=-1)]
|
||||
|
||||
return final_pred, preds, prob
|
||||
|
||||
|
||||
def evaluate(pred):
|
||||
pred = pred.rank(pct=True) # transform into percentiles
|
||||
score = pred.score
|
||||
label = pred.label
|
||||
diff = score - label
|
||||
MSE = (diff ** 2).mean()
|
||||
MAE = (diff.abs()).mean()
|
||||
IC = score.corr(label)
|
||||
return {"MSE": MSE, "MAE": MAE, "IC": IC}
|
||||
|
||||
|
||||
def average_params(params_list):
|
||||
assert isinstance(params_list, (tuple, list, collections.deque))
|
||||
n = len(params_list)
|
||||
if n == 1:
|
||||
return params_list[0]
|
||||
new_params = collections.OrderedDict()
|
||||
keys = None
|
||||
for i, params in enumerate(params_list):
|
||||
if keys is None:
|
||||
keys = params.keys()
|
||||
for k, v in params.items():
|
||||
if k not in keys:
|
||||
raise ValueError("the %d-th model has different params" % i)
|
||||
if k not in new_params:
|
||||
new_params[k] = v / n
|
||||
else:
|
||||
new_params[k] += v / n
|
||||
return new_params
|
||||
|
||||
|
||||
def shoot_infs(inp_tensor):
|
||||
"""Replaces inf by maximum of tensor"""
|
||||
mask_inf = torch.isinf(inp_tensor)
|
||||
ind_inf = torch.nonzero(mask_inf, as_tuple=False)
|
||||
if len(ind_inf) > 0:
|
||||
for ind in ind_inf:
|
||||
if len(ind) == 2:
|
||||
inp_tensor[ind[0], ind[1]] = 0
|
||||
elif len(ind) == 1:
|
||||
inp_tensor[ind[0]] = 0
|
||||
m = torch.max(inp_tensor)
|
||||
for ind in ind_inf:
|
||||
if len(ind) == 2:
|
||||
inp_tensor[ind[0], ind[1]] = m
|
||||
elif len(ind) == 1:
|
||||
inp_tensor[ind[0]] = m
|
||||
return inp_tensor
|
||||
|
||||
|
||||
def sinkhorn(Q, n_iters=3, epsilon=0.01):
|
||||
# epsilon should be adjusted according to logits value's scale
|
||||
with torch.no_grad():
|
||||
Q = shoot_infs(Q)
|
||||
Q = torch.exp(Q / epsilon)
|
||||
for i in range(n_iters):
|
||||
Q /= Q.sum(dim=0, keepdim=True)
|
||||
Q /= Q.sum(dim=1, keepdim=True)
|
||||
return Q
|
||||
129
examples/benchmarks/TRA/workflow_config_tra_Alpha158.yaml
Normal file
129
examples/benchmarks/TRA/workflow_config_tra_Alpha158.yaml
Normal file
@@ -0,0 +1,129 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: FilterCol
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
col_list: ["RESI5", "WVMA5", "RSQR5", "KLEN", "RSQR10", "CORR5", "CORD5", "CORR10",
|
||||
"ROC60", "RESI10", "VSTD5", "RSQR60", "CORR60", "WVMA60", "STD5",
|
||||
"RSQR20", "CORD60", "CORD10", "CORR20", "KLOW"]
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
|
||||
num_states: &num_states 3
|
||||
|
||||
memory_mode: &memory_mode sample
|
||||
|
||||
tra_config: &tra_config
|
||||
num_states: *num_states
|
||||
rnn_arch: LSTM
|
||||
hidden_size: 32
|
||||
num_layers: 1
|
||||
dropout: 0.0
|
||||
tau: 1.0
|
||||
src_info: LR_TPE
|
||||
|
||||
model_config: &model_config
|
||||
input_size: 20
|
||||
hidden_size: 64
|
||||
num_layers: 2
|
||||
rnn_arch: LSTM
|
||||
use_attn: True
|
||||
dropout: 0.0
|
||||
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
|
||||
task:
|
||||
model:
|
||||
class: TRAModel
|
||||
module_path: qlib.contrib.model.pytorch_tra
|
||||
kwargs:
|
||||
tra_config: *tra_config
|
||||
model_config: *model_config
|
||||
model_type: RNN
|
||||
lr: 1e-3
|
||||
n_epochs: 100
|
||||
max_steps_per_epoch:
|
||||
early_stop: 20
|
||||
logdir: output/Alpha158
|
||||
seed: 0
|
||||
lamb: 1.0
|
||||
rho: 0.99
|
||||
alpha: 0.5
|
||||
transport_method: router
|
||||
memory_mode: *memory_mode
|
||||
eval_train: False
|
||||
eval_test: True
|
||||
pretrain: True
|
||||
init_state:
|
||||
freeze_model: False
|
||||
freeze_predictors: False
|
||||
dataset:
|
||||
class: MTSDatasetH
|
||||
module_path: qlib.contrib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha158
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
seq_len: 60
|
||||
horizon: 2
|
||||
input_size:
|
||||
num_states: *num_states
|
||||
batch_size: 1024
|
||||
n_samples:
|
||||
memory_mode: *memory_mode
|
||||
drop_last: True
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
123
examples/benchmarks/TRA/workflow_config_tra_Alpha158_full.yaml
Normal file
123
examples/benchmarks/TRA/workflow_config_tra_Alpha158_full.yaml
Normal file
@@ -0,0 +1,123 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
|
||||
num_states: &num_states 3
|
||||
|
||||
memory_mode: &memory_mode sample
|
||||
|
||||
tra_config: &tra_config
|
||||
num_states: *num_states
|
||||
rnn_arch: LSTM
|
||||
hidden_size: 32
|
||||
num_layers: 1
|
||||
dropout: 0.0
|
||||
tau: 1.0
|
||||
src_info: LR_TPE
|
||||
|
||||
model_config: &model_config
|
||||
input_size: 158
|
||||
hidden_size: 256
|
||||
num_layers: 2
|
||||
rnn_arch: LSTM
|
||||
use_attn: True
|
||||
dropout: 0.2
|
||||
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
|
||||
task:
|
||||
model:
|
||||
class: TRAModel
|
||||
module_path: qlib.contrib.model.pytorch_tra
|
||||
kwargs:
|
||||
tra_config: *tra_config
|
||||
model_config: *model_config
|
||||
model_type: RNN
|
||||
lr: 1e-3
|
||||
n_epochs: 100
|
||||
max_steps_per_epoch:
|
||||
early_stop: 20
|
||||
logdir: output/Alpha158_full
|
||||
seed: 0
|
||||
lamb: 1.0
|
||||
rho: 0.99
|
||||
alpha: 0.5
|
||||
transport_method: router
|
||||
memory_mode: *memory_mode
|
||||
eval_train: False
|
||||
eval_test: True
|
||||
pretrain: True
|
||||
init_state:
|
||||
freeze_model: False
|
||||
freeze_predictors: False
|
||||
dataset:
|
||||
class: MTSDatasetH
|
||||
module_path: qlib.contrib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha158
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
seq_len: 60
|
||||
horizon: 2
|
||||
input_size:
|
||||
num_states: *num_states
|
||||
batch_size: 1024
|
||||
n_samples:
|
||||
memory_mode: *memory_mode
|
||||
drop_last: True
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
123
examples/benchmarks/TRA/workflow_config_tra_Alpha360.yaml
Normal file
123
examples/benchmarks/TRA/workflow_config_tra_Alpha360.yaml
Normal file
@@ -0,0 +1,123 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
|
||||
num_states: &num_states 3
|
||||
|
||||
memory_mode: &memory_mode sample
|
||||
|
||||
tra_config: &tra_config
|
||||
num_states: *num_states
|
||||
rnn_arch: LSTM
|
||||
hidden_size: 32
|
||||
num_layers: 1
|
||||
dropout: 0.0
|
||||
tau: 1.0
|
||||
src_info: LR_TPE
|
||||
|
||||
model_config: &model_config
|
||||
input_size: 6
|
||||
hidden_size: 64
|
||||
num_layers: 2
|
||||
rnn_arch: LSTM
|
||||
use_attn: True
|
||||
dropout: 0.0
|
||||
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
|
||||
task:
|
||||
model:
|
||||
class: TRAModel
|
||||
module_path: qlib.contrib.model.pytorch_tra
|
||||
kwargs:
|
||||
tra_config: *tra_config
|
||||
model_config: *model_config
|
||||
model_type: RNN
|
||||
lr: 1e-3
|
||||
n_epochs: 100
|
||||
max_steps_per_epoch:
|
||||
early_stop: 20
|
||||
logdir: output/Alpha360
|
||||
seed: 0
|
||||
lamb: 1.0
|
||||
rho: 0.99
|
||||
alpha: 0.5
|
||||
transport_method: router
|
||||
memory_mode: *memory_mode
|
||||
eval_train: False
|
||||
eval_test: True
|
||||
pretrain: True
|
||||
init_state:
|
||||
freeze_model: False
|
||||
freeze_predictors: False
|
||||
dataset:
|
||||
class: MTSDatasetH
|
||||
module_path: qlib.contrib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha360
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
seq_len: 60
|
||||
horizon: 2
|
||||
input_size: 6
|
||||
num_states: *num_states
|
||||
batch_size: 1024
|
||||
n_samples:
|
||||
memory_mode: *memory_mode
|
||||
drop_last: True
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
3
examples/benchmarks/Transformer/requirements.txt
Normal file
3
examples/benchmarks/Transformer/requirements.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
numpy==1.17.4
|
||||
pandas==1.1.2
|
||||
torch==1.2.0
|
||||
@@ -0,0 +1,82 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: FilterCol
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
col_list: ["RESI5", "WVMA5", "RSQR5", "KLEN", "RSQR10", "CORR5", "CORD5", "CORR10",
|
||||
"ROC60", "RESI10", "VSTD5", "RSQR60", "CORR60", "WVMA60", "STD5",
|
||||
"RSQR20", "CORD60", "CORD10", "CORR20", "KLOW"
|
||||
]
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: TransformerModel
|
||||
module_path: qlib.contrib.model.pytorch_transformer_ts
|
||||
kwargs:
|
||||
seed: 0
|
||||
n_jobs: 20
|
||||
dataset:
|
||||
class: TSDatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha158
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
step_len: 20
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -0,0 +1,73 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: TransformerModel
|
||||
module_path: qlib.contrib.model.pytorch_transformer
|
||||
kwargs:
|
||||
d_feat: 6
|
||||
seed: 0
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha360
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -1,7 +1,5 @@
|
||||
from qlib.data.dataset.handler import DataHandler, DataHandlerLP
|
||||
from qlib.data.dataset.processor import Processor
|
||||
from qlib.utils import get_cls_kwargs
|
||||
from qlib.log import TimeInspector
|
||||
from qlib.contrib.data.handler import check_transform_proc
|
||||
|
||||
|
||||
class HighFreqHandler(DataHandlerLP):
|
||||
@@ -16,20 +14,9 @@ class HighFreqHandler(DataHandlerLP):
|
||||
fit_end_time=None,
|
||||
drop_raw=True,
|
||||
):
|
||||
def check_transform_proc(proc_l):
|
||||
new_l = []
|
||||
for p in proc_l:
|
||||
p["kwargs"].update(
|
||||
{
|
||||
"fit_start_time": fit_start_time,
|
||||
"fit_end_time": fit_end_time,
|
||||
}
|
||||
)
|
||||
new_l.append(p)
|
||||
return new_l
|
||||
|
||||
infer_processors = check_transform_proc(infer_processors)
|
||||
learn_processors = check_transform_proc(learn_processors)
|
||||
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
|
||||
learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)
|
||||
|
||||
data_loader = {
|
||||
"class": "QlibDataLoader",
|
||||
@@ -112,8 +99,6 @@ class HighFreqHandler(DataHandlerLP):
|
||||
]
|
||||
names += ["$volume_1"]
|
||||
|
||||
fields += ["Cut({0}, 240, None)".format(template_paused.format("Date($close)"))]
|
||||
names += ["date"]
|
||||
return fields, names
|
||||
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ def get_calendar_day(freq="day", future=False):
|
||||
if flag in H["c"]:
|
||||
_calendar = H["c"][flag]
|
||||
else:
|
||||
_calendar = np.array(list(map(lambda x: x.date(), Cal.load_calendar(freq, future))))
|
||||
_calendar = np.array(list(map(lambda x: pd.Timestamp(x.date()), Cal.load_calendar(freq, future))))
|
||||
H["c"][flag] = _calendar
|
||||
return _calendar
|
||||
|
||||
|
||||
@@ -33,6 +33,9 @@ class HighFreqNorm(Processor):
|
||||
self.feature_vmin[name] = np.nanmin(part_values)
|
||||
|
||||
def __call__(self, df_features):
|
||||
df_features["date"] = pd.to_datetime(
|
||||
df_features.index.get_level_values(level="datetime").to_series().dt.date.values
|
||||
)
|
||||
df_features.set_index("date", append=True, drop=True, inplace=True)
|
||||
df_values = df_features.values
|
||||
names = {
|
||||
|
||||
@@ -33,7 +33,7 @@ class HighfreqWorkflow:
|
||||
"fit_start_time": start_time,
|
||||
"fit_end_time": train_end_time,
|
||||
"instruments": MARKET,
|
||||
"infer_processors": [{"class": "HighFreqNorm", "module_path": "highfreq_processor", "kwargs": {}}],
|
||||
"infer_processors": [{"class": "HighFreqNorm", "module_path": "highfreq_processor"}],
|
||||
}
|
||||
DATA_HANDLER_CONFIG1 = {
|
||||
"start_time": start_time,
|
||||
|
||||
1
examples/model_rolling/requirements.txt
Normal file
1
examples/model_rolling/requirements.txt
Normal file
@@ -0,0 +1 @@
|
||||
xgboost
|
||||
@@ -4,6 +4,7 @@
|
||||
"""
|
||||
This example shows how a TrainerRM works based on TaskManager with rolling tasks.
|
||||
After training, how to collect the rolling results will be shown in task_collecting.
|
||||
Based on the ability of TaskManager, `worker` method offer a simple way for multiprocessing.
|
||||
"""
|
||||
|
||||
from pprint import pprint
|
||||
@@ -13,10 +14,10 @@ import qlib
|
||||
from qlib.config import REG_CN
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.task.gen import RollingGen, task_generator
|
||||
from qlib.workflow.task.manage import TaskManager
|
||||
from qlib.workflow.task.manage import TaskManager, run_task
|
||||
from qlib.workflow.task.collect import RecorderCollector
|
||||
from qlib.model.ens.group import RollingGroup
|
||||
from qlib.model.trainer import TrainerRM
|
||||
from qlib.model.trainer import TrainerRM, task_train
|
||||
from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG, CSI100_RECORD_XGBOOST_TASK_CONFIG
|
||||
|
||||
|
||||
@@ -68,6 +69,11 @@ class RollingTaskExample:
|
||||
trainer = TrainerRM(self.experiment_name, self.task_pool)
|
||||
trainer.train(tasks)
|
||||
|
||||
def worker(self):
|
||||
# train tasks by other progress or machines for multiprocessing. It is same as TrainerRM.worker.
|
||||
print("========== worker ==========")
|
||||
run_task(task_train, self.task_pool, experiment_name=self.experiment_name)
|
||||
|
||||
def task_collecting(self):
|
||||
print("========== task_collecting ==========")
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
This example is about how can simulate the OnlineManager based on rolling tasks.
|
||||
"""
|
||||
|
||||
from pprint import pprint
|
||||
import fire
|
||||
import qlib
|
||||
from qlib.model.trainer import DelayTrainerR, DelayTrainerRM, TrainerR, TrainerRM
|
||||
@@ -13,7 +14,7 @@ from qlib.workflow.online.manager import OnlineManager
|
||||
from qlib.workflow.online.strategy import RollingStrategy
|
||||
from qlib.workflow.task.gen import RollingGen
|
||||
from qlib.workflow.task.manage import TaskManager
|
||||
from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG, CSI100_RECORD_XGBOOST_TASK_CONFIG
|
||||
from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG_ONLINE, CSI100_RECORD_XGBOOST_TASK_CONFIG_ONLINE
|
||||
|
||||
|
||||
class OnlineSimulationExample:
|
||||
@@ -22,8 +23,8 @@ class OnlineSimulationExample:
|
||||
provider_uri="~/.qlib/qlib_data/cn_data",
|
||||
region="cn",
|
||||
exp_name="rolling_exp",
|
||||
task_url="mongodb://10.0.0.4:27017/",
|
||||
task_db_name="rolling_db",
|
||||
task_url="mongodb://10.0.0.4:27017/", # not necessary when using TrainerR or DelayTrainerR
|
||||
task_db_name="rolling_db", # not necessary when using TrainerR or DelayTrainerR
|
||||
task_pool="rolling_task",
|
||||
rolling_step=80,
|
||||
start_time="2018-09-10",
|
||||
@@ -46,7 +47,7 @@ class OnlineSimulationExample:
|
||||
tasks (dict or list[dict]): a set of the task config waiting for rolling and training
|
||||
"""
|
||||
if tasks is None:
|
||||
tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG, CSI100_RECORD_LGB_TASK_CONFIG]
|
||||
tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG_ONLINE, CSI100_RECORD_LGB_TASK_CONFIG_ONLINE]
|
||||
self.exp_name = exp_name
|
||||
self.task_pool = task_pool
|
||||
self.start_time = start_time
|
||||
@@ -59,7 +60,7 @@ class OnlineSimulationExample:
|
||||
self.rolling_gen = RollingGen(
|
||||
step=rolling_step, rtype=RollingGen.ROLL_SD, ds_extra_mod_func=None
|
||||
) # The rolling tasks generator, ds_extra_mod_func is None because we just need to simulate to 2018-10-31 and needn't change the handler end time.
|
||||
self.trainer = DelayTrainerRM(self.exp_name, self.task_pool) # Also can be TrainerR, TrainerRM, DelayTrainerR
|
||||
self.trainer = TrainerRM(self.exp_name, self.task_pool) # Also can be TrainerR, TrainerRM, DelayTrainerR
|
||||
self.rolling_online_manager = OnlineManager(
|
||||
RollingStrategy(exp_name, task_template=tasks, rolling_gen=self.rolling_gen),
|
||||
trainer=self.trainer,
|
||||
@@ -85,6 +86,15 @@ class OnlineSimulationExample:
|
||||
print("========== signals ==========")
|
||||
print(self.rolling_online_manager.get_signals())
|
||||
|
||||
def worker(self):
|
||||
# train tasks by other progress or machines for multiprocessing
|
||||
# FIXME: only can call after finishing simulation when using DelayTrainerRM, or there will be some exception.
|
||||
print("========== worker ==========")
|
||||
if isinstance(self.trainer, TrainerRM):
|
||||
self.trainer.worker()
|
||||
else:
|
||||
print(f"{type(self.trainer)} is not supported for worker.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
## to run all workflow automatically with your own parameters, use the command below
|
||||
|
||||
@@ -13,11 +13,13 @@ Finally, the OnlineManager will finish second routine and update all strategies.
|
||||
import os
|
||||
import fire
|
||||
import qlib
|
||||
from qlib.model.trainer import DelayTrainerR, DelayTrainerRM, TrainerR, TrainerRM, end_task_train, task_train
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.online.strategy import RollingStrategy
|
||||
from qlib.workflow.task.gen import RollingGen
|
||||
from qlib.workflow.online.manager import OnlineManager
|
||||
from qlib.tests.config import CSI100_RECORD_XGBOOST_TASK_CONFIG, CSI100_RECORD_LGB_TASK_CONFIG
|
||||
from qlib.tests.config import CSI100_RECORD_XGBOOST_TASK_CONFIG_ROLLING, CSI100_RECORD_LGB_TASK_CONFIG_ROLLING
|
||||
from qlib.workflow.task.manage import TaskManager
|
||||
|
||||
|
||||
class RollingOnlineExample:
|
||||
@@ -25,16 +27,17 @@ class RollingOnlineExample:
|
||||
self,
|
||||
provider_uri="~/.qlib/qlib_data/cn_data",
|
||||
region="cn",
|
||||
task_url="mongodb://10.0.0.4:27017/",
|
||||
task_db_name="rolling_db",
|
||||
trainer=DelayTrainerRM(), # you can choose from TrainerR, TrainerRM, DelayTrainerR, DelayTrainerRM
|
||||
task_url="mongodb://10.0.0.4:27017/", # not necessary when using TrainerR or DelayTrainerR
|
||||
task_db_name="rolling_db", # not necessary when using TrainerR or DelayTrainerR
|
||||
rolling_step=550,
|
||||
tasks=None,
|
||||
add_tasks=None,
|
||||
):
|
||||
if add_tasks is None:
|
||||
add_tasks = [CSI100_RECORD_LGB_TASK_CONFIG]
|
||||
add_tasks = [CSI100_RECORD_LGB_TASK_CONFIG_ROLLING]
|
||||
if tasks is None:
|
||||
tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG]
|
||||
tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG_ROLLING]
|
||||
mongo_conf = {
|
||||
"task_url": task_url, # your MongoDB url
|
||||
"task_db_name": task_db_name, # database name
|
||||
@@ -53,17 +56,28 @@ class RollingOnlineExample:
|
||||
RollingGen(step=rolling_step, rtype=RollingGen.ROLL_SD),
|
||||
)
|
||||
)
|
||||
|
||||
self.rolling_online_manager = OnlineManager(strategies)
|
||||
self.trainer = trainer
|
||||
self.rolling_online_manager = OnlineManager(strategies, trainer=self.trainer)
|
||||
|
||||
_ROLLING_MANAGER_PATH = (
|
||||
".RollingOnlineExample" # the OnlineManager will dump to this file, for it can be loaded when calling routine.
|
||||
)
|
||||
|
||||
def worker(self):
|
||||
# train tasks by other progress or machines for multiprocessing
|
||||
print("========== worker ==========")
|
||||
if isinstance(self.trainer, TrainerRM):
|
||||
for task in self.tasks + self.add_tasks:
|
||||
name_id = task["model"]["class"]
|
||||
self.trainer.worker(experiment_name=name_id)
|
||||
else:
|
||||
print(f"{type(self.trainer)} is not supported for worker.")
|
||||
|
||||
# Reset all things to the first status, be careful to save important data
|
||||
def reset(self):
|
||||
for task in self.tasks + self.add_tasks:
|
||||
name_id = task["model"]["class"]
|
||||
TaskManager(task_pool=name_id).remove()
|
||||
exp = R.get_exp(experiment_name=name_id)
|
||||
for rid in exp.list_recorders():
|
||||
exp.delete_recorder(rid)
|
||||
|
||||
@@ -23,7 +23,6 @@ from qlib.config import REG_CN
|
||||
from qlib.workflow import R
|
||||
from qlib.tests.data import GetData
|
||||
|
||||
|
||||
# init qlib
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data"
|
||||
exp_folder_name = "run_all_model_records"
|
||||
@@ -40,6 +39,7 @@ exp_manager = {
|
||||
GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True)
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN, exp_manager=exp_manager)
|
||||
|
||||
|
||||
# decorator to check the arguments
|
||||
def only_allow_defined_args(function_to_decorate):
|
||||
@functools.wraps(function_to_decorate)
|
||||
@@ -92,7 +92,8 @@ def create_env():
|
||||
|
||||
|
||||
# function to execute the cmd
|
||||
def execute(cmd):
|
||||
def execute(cmd, wait_when_err=False):
|
||||
print("Running CMD:", cmd)
|
||||
with subprocess.Popen(cmd, stdout=subprocess.PIPE, bufsize=1, universal_newlines=True, shell=True) as p:
|
||||
for line in p.stdout:
|
||||
sys.stdout.write(line.split("\b")[0])
|
||||
@@ -102,6 +103,8 @@ def execute(cmd):
|
||||
sys.stdout.write("\b" * 10 + "\b".join(line.split("\b")[1:-1]))
|
||||
|
||||
if p.returncode != 0:
|
||||
if wait_when_err:
|
||||
input("Press Enter to Continue")
|
||||
return p.stderr
|
||||
else:
|
||||
return None
|
||||
@@ -184,7 +187,15 @@ def gen_and_save_md_table(metrics, dataset):
|
||||
|
||||
# function to run the all the models
|
||||
@only_allow_defined_args
|
||||
def run(times=1, models=None, dataset="Alpha360", exclude=False):
|
||||
def run(
|
||||
times=1,
|
||||
models=None,
|
||||
dataset="Alpha360",
|
||||
exclude=False,
|
||||
qlib_uri: str = "git+https://github.com/microsoft/qlib#egg=pyqlib",
|
||||
wait_before_rm_env: bool = False,
|
||||
wait_when_err: bool = False,
|
||||
):
|
||||
"""
|
||||
Please be aware that this function can only work under Linux. MacOS and Windows will be supported in the future.
|
||||
Any PR to enhance this method is highly welcomed. Besides, this script doesn't support parrallel running the same model
|
||||
@@ -200,6 +211,13 @@ def run(times=1, models=None, dataset="Alpha360", exclude=False):
|
||||
determines whether the model being used is excluded or included.
|
||||
dataset : str
|
||||
determines the dataset to be used for each model.
|
||||
qlib_uri : str
|
||||
the uri to install qlib with pip
|
||||
it could be url on the we or local path
|
||||
wait_before_rm_env : bool
|
||||
wait before remove environment.
|
||||
wait_when_err : bool
|
||||
wait when errors raised when executing commands
|
||||
|
||||
Usage:
|
||||
-------
|
||||
@@ -240,32 +258,36 @@ def run(times=1, models=None, dataset="Alpha360", exclude=False):
|
||||
sys.stderr.write("\n")
|
||||
# install requirements.txt
|
||||
sys.stderr.write("Installing requirements.txt...\n")
|
||||
execute(f"{python_path} -m pip install -r {req_path}")
|
||||
execute(f"{python_path} -m pip install -r {req_path}", wait_when_err=wait_when_err)
|
||||
sys.stderr.write("\n")
|
||||
# setup gpu for tft
|
||||
if fn == "TFT":
|
||||
execute(
|
||||
f"conda install -y --prefix {env_path} anaconda cudatoolkit=10.0 && conda install -y --prefix {env_path} cudnn"
|
||||
f"conda install -y --prefix {env_path} anaconda cudatoolkit=10.0 && conda install -y --prefix {env_path} cudnn",
|
||||
wait_when_err=wait_when_err,
|
||||
)
|
||||
sys.stderr.write("\n")
|
||||
# install qlib
|
||||
sys.stderr.write("Installing qlib...\n")
|
||||
execute(f"{python_path} -m pip install --upgrade pip") # TODO: FIX ME!
|
||||
execute(f"{python_path} -m pip install --upgrade cython") # TODO: FIX ME!
|
||||
execute(f"{python_path} -m pip install --upgrade pip", wait_when_err=wait_when_err) # TODO: FIX ME!
|
||||
execute(f"{python_path} -m pip install --upgrade cython", wait_when_err=wait_when_err) # TODO: FIX ME!
|
||||
if fn == "TFT":
|
||||
execute(
|
||||
f"cd {env_path} && {python_path} -m pip install --upgrade --force-reinstall --ignore-installed PyYAML -e git+https://github.com/microsoft/qlib#egg=pyqlib"
|
||||
f"cd {env_path} && {python_path} -m pip install --upgrade --force-reinstall --ignore-installed PyYAML -e {qlib_uri}",
|
||||
wait_when_err=wait_when_err,
|
||||
) # TODO: FIX ME!
|
||||
else:
|
||||
execute(
|
||||
f"cd {env_path} && {python_path} -m pip install --upgrade --force-reinstall -e git+https://github.com/microsoft/qlib#egg=pyqlib"
|
||||
f"cd {env_path} && {python_path} -m pip install --upgrade --force-reinstall -e {qlib_uri}",
|
||||
wait_when_err=wait_when_err,
|
||||
) # TODO: FIX ME!
|
||||
sys.stderr.write("\n")
|
||||
# run workflow_by_config for multiple times
|
||||
for i in range(times):
|
||||
sys.stderr.write(f"Running the model: {fn} for iteration {i+1}...\n")
|
||||
errs = execute(
|
||||
f"{python_path} {env_path / 'src/pyqlib/qlib/workflow/cli.py'} {yaml_path} {fn} {exp_folder_name}"
|
||||
f"{python_path} {env_path / 'bin' / 'qrun'} {yaml_path} {fn} {exp_folder_name}",
|
||||
wait_when_err=wait_when_err,
|
||||
)
|
||||
if errs is not None:
|
||||
_errs = errors.get(fn, {})
|
||||
@@ -274,6 +296,8 @@ def run(times=1, models=None, dataset="Alpha360", exclude=False):
|
||||
sys.stderr.write("\n")
|
||||
# remove env
|
||||
sys.stderr.write(f"Deleting the environment: {env_path}...\n")
|
||||
if wait_before_rm_env:
|
||||
input("Press Enter to Continue")
|
||||
shutil.rmtree(env_path)
|
||||
# getting all results
|
||||
sys.stderr.write(f"Retrieving results...\n")
|
||||
|
||||
@@ -220,7 +220,7 @@
|
||||
"\n",
|
||||
"# backtest and analysis\n",
|
||||
"with R.start(experiment_name=\"backtest_analysis\"):\n",
|
||||
" recorder = R.get_recorder(rid, experiment_name=\"train_model\")\n",
|
||||
" recorder = R.get_recorder(recorder_id=rid, experiment_name=\"train_model\")\n",
|
||||
" model = recorder.load_object(\"trained_model\")\n",
|
||||
"\n",
|
||||
" # prediction\n",
|
||||
@@ -249,7 +249,7 @@
|
||||
"source": [
|
||||
"from qlib.contrib.report import analysis_model, analysis_position\n",
|
||||
"from qlib.data import D\n",
|
||||
"recorder = R.get_recorder(ba_rid, experiment_name=\"backtest_analysis\")\n",
|
||||
"recorder = R.get_recorder(recorder_id=ba_rid, experiment_name=\"backtest_analysis\")\n",
|
||||
"pred_df = recorder.load_object(\"pred.pkl\")\n",
|
||||
"pred_df_dates = pred_df.index.get_level_values(level='datetime')\n",
|
||||
"report_normal_df = recorder.load_object(\"portfolio_analysis/report_normal.pkl\")\n",
|
||||
|
||||
@@ -1,17 +1,15 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
__version__ = "0.6.3.99"
|
||||
_version_path = Path(__file__).absolute().parent / "VERSION.txt" # This file is copyed from setup.py
|
||||
__version__ = _version_path.read_text(encoding="utf-8").strip()
|
||||
__version__bak = __version__ # This version is backup for QlibConfig.reset_qlib_version
|
||||
|
||||
|
||||
import os
|
||||
import yaml
|
||||
import logging
|
||||
import platform
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from .log import get_module_logger
|
||||
|
||||
|
||||
@@ -20,76 +18,83 @@ def init(default_conf="client", **kwargs):
|
||||
from .config import C
|
||||
from .data.cache import H
|
||||
|
||||
H.clear()
|
||||
|
||||
# FIXME: this logger ignored the level in config
|
||||
logger = get_module_logger("Initialization", level=logging.INFO)
|
||||
|
||||
skip_if_reg = kwargs.pop("skip_if_reg", False)
|
||||
if skip_if_reg and C.registered:
|
||||
# if we reinitialize Qlib during running an experiment `R.start`.
|
||||
# it will result in loss of the recorder
|
||||
logger.warning("Skip initialization because `skip_if_reg is True`")
|
||||
return
|
||||
|
||||
H.clear()
|
||||
C.set(default_conf, **kwargs)
|
||||
|
||||
# check path if server/local
|
||||
if C.get_uri_type() == C.LOCAL_URI:
|
||||
if not os.path.exists(C["provider_uri"]):
|
||||
if C["auto_mount"]:
|
||||
logger.error(
|
||||
f"Invalid provider uri: {C['provider_uri']}, please check if a valid provider uri has been set. This path does not exist."
|
||||
)
|
||||
else:
|
||||
logger.warning(f"auto_path is False, please make sure {C['mount_path']} is mounted")
|
||||
elif C.get_uri_type() == C.NFS_URI:
|
||||
_mount_nfs_uri(C)
|
||||
else:
|
||||
raise NotImplementedError(f"This type of URI is not supported")
|
||||
# mount nfs
|
||||
for _freq, provider_uri in C.provider_uri.items():
|
||||
mount_path = C["mount_path"][_freq]
|
||||
# check path if server/local
|
||||
uri_type = C.dpm.get_uri_type(provider_uri)
|
||||
if uri_type == C.LOCAL_URI:
|
||||
if not Path(provider_uri).exists():
|
||||
if C["auto_mount"]:
|
||||
logger.error(
|
||||
f"Invalid provider uri: {provider_uri}, please check if a valid provider uri has been set. This path does not exist."
|
||||
)
|
||||
else:
|
||||
logger.warning(f"auto_path is False, please make sure {mount_path} is mounted")
|
||||
elif uri_type == C.NFS_URI:
|
||||
_mount_nfs_uri(provider_uri, mount_path, C["auto_mount"])
|
||||
else:
|
||||
raise NotImplementedError(f"This type of URI is not supported")
|
||||
|
||||
C.register()
|
||||
|
||||
if "flask_server" in C:
|
||||
logger.info(f"flask_server={C['flask_server']}, flask_port={C['flask_port']}")
|
||||
logger.info("qlib successfully initialized based on %s settings." % default_conf)
|
||||
logger.info(f"data_path={C.get_data_path()}")
|
||||
data_path = {_freq: C.dpm.get_data_path(_freq) for _freq in C.dpm.provider_uri.keys()}
|
||||
logger.info(f"data_path={data_path}")
|
||||
|
||||
|
||||
def _mount_nfs_uri(C):
|
||||
def _mount_nfs_uri(provider_uri, mount_path, auto_mount: bool = False):
|
||||
|
||||
LOG = get_module_logger("mount nfs", level=logging.INFO)
|
||||
|
||||
# FIXME: the C["provider_uri"] is modified in this function
|
||||
# If it is not modified, we can pass only provider_uri or mount_path instead of C
|
||||
mount_command = "sudo mount.nfs %s %s" % (C["provider_uri"], C["mount_path"])
|
||||
mount_command = "sudo mount.nfs %s %s" % (provider_uri, mount_path)
|
||||
# If the provider uri looks like this 172.23.233.89//data/csdesign'
|
||||
# It will be a nfs path. The client provider will be used
|
||||
if not C["auto_mount"]:
|
||||
if not os.path.exists(C["mount_path"]):
|
||||
if not auto_mount:
|
||||
if not Path(mount_path).exists():
|
||||
raise FileNotFoundError(
|
||||
f"Invalid mount path: {C['mount_path']}! Please mount manually: {mount_command} or Set init parameter `auto_mount=True`"
|
||||
f"Invalid mount path: {mount_path}! Please mount manually: {mount_command} or Set init parameter `auto_mount=True`"
|
||||
)
|
||||
else:
|
||||
# Judging system type
|
||||
sys_type = platform.system()
|
||||
if "win" in sys_type.lower():
|
||||
# system: window
|
||||
exec_result = os.popen("mount -o anon %s %s" % (C["provider_uri"], C["mount_path"] + ":"))
|
||||
exec_result = os.popen("mount -o anon %s %s" % (provider_uri, mount_path + ":"))
|
||||
result = exec_result.read()
|
||||
if "85" in result:
|
||||
LOG.warning("already mounted or window mount path already exists")
|
||||
LOG.warning(f"{provider_uri} on Windows:{mount_path} is already mounted")
|
||||
elif "53" in result:
|
||||
raise OSError("not find network path")
|
||||
elif "error" in result or "错误" in result:
|
||||
raise OSError("Invalid mount path")
|
||||
elif C["provider_uri"] in result:
|
||||
elif provider_uri in result:
|
||||
LOG.info("window success mount..")
|
||||
else:
|
||||
raise OSError(f"unknown error: {result}")
|
||||
|
||||
# config mount path
|
||||
C["mount_path"] = C["mount_path"] + ":\\"
|
||||
else:
|
||||
# system: linux/Unix/Mac
|
||||
# check mount
|
||||
_remote_uri = C["provider_uri"]
|
||||
_remote_uri = _remote_uri[:-1] if _remote_uri.endswith("/") else _remote_uri
|
||||
_mount_path = C["mount_path"]
|
||||
_mount_path = _mount_path[:-1] if _mount_path.endswith("/") else _mount_path
|
||||
_remote_uri = provider_uri[:-1] if provider_uri.endswith("/") else provider_uri
|
||||
_mount_path = mount_path[:-1] if mount_path.endswith("/") else mount_path
|
||||
_check_level_num = 2
|
||||
_is_mount = False
|
||||
while _check_level_num:
|
||||
@@ -115,11 +120,9 @@ def _mount_nfs_uri(C):
|
||||
|
||||
if not _is_mount:
|
||||
try:
|
||||
os.makedirs(C["mount_path"], exist_ok=True)
|
||||
Path(mount_path).mkdir(parents=True, exist_ok=True)
|
||||
except Exception:
|
||||
raise OSError(
|
||||
f"Failed to create directory {C['mount_path']}, please create {C['mount_path']} manually!"
|
||||
)
|
||||
raise OSError(f"Failed to create directory {mount_path}, please create {mount_path} manually!")
|
||||
|
||||
# check nfs-common
|
||||
command_res = os.popen("dpkg -l | grep nfs-common")
|
||||
@@ -130,11 +133,11 @@ def _mount_nfs_uri(C):
|
||||
command_status = os.system(mount_command)
|
||||
if command_status == 256:
|
||||
raise OSError(
|
||||
f"mount {C['provider_uri']} on {C['mount_path']} error! Needs SUDO! Please mount manually: {mount_command}"
|
||||
f"mount {provider_uri} on {mount_path} error! Needs SUDO! Please mount manually: {mount_command}"
|
||||
)
|
||||
elif command_status == 32512:
|
||||
# LOG.error("Command error")
|
||||
raise OSError(f"mount {C['provider_uri']} on {C['mount_path']} error! Command error")
|
||||
raise OSError(f"mount {provider_uri} on {mount_path} error! Command error")
|
||||
elif command_status == 0:
|
||||
LOG.info("Mount finished")
|
||||
else:
|
||||
@@ -197,14 +200,15 @@ def auto_init(**kwargs):
|
||||
- Find the project configuration and init qlib
|
||||
- The parsing process will be affected by the `conf_type` of the configuration file
|
||||
- Init qlib with default config
|
||||
- Skip initialization if already initialized
|
||||
"""
|
||||
kwargs["skip_if_reg"] = kwargs.get("skip_if_reg", True)
|
||||
|
||||
try:
|
||||
pp = get_project_path(cur_path=kwargs.pop("cur_path", None))
|
||||
except FileNotFoundError:
|
||||
init(**kwargs)
|
||||
else:
|
||||
|
||||
conf_pp = pp / "config.yaml"
|
||||
with conf_pp.open() as f:
|
||||
conf = yaml.safe_load(f)
|
||||
|
||||
137
qlib/config.py
137
qlib/config.py
@@ -15,8 +15,10 @@ import os
|
||||
import re
|
||||
import copy
|
||||
import logging
|
||||
import platform
|
||||
import multiprocessing
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
|
||||
class Config:
|
||||
@@ -73,6 +75,12 @@ REG_US = "us"
|
||||
|
||||
NUM_USABLE_CPU = max(multiprocessing.cpu_count() - 2, 1)
|
||||
|
||||
DISK_DATASET_CACHE = "DiskDatasetCache"
|
||||
SIMPLE_DATASET_CACHE = "SimpleDatasetCache"
|
||||
DISK_EXPRESSION_CACHE = "DiskExpressionCache"
|
||||
|
||||
DEPENDENCY_REDIS_CACHE = (DISK_DATASET_CACHE, DISK_EXPRESSION_CACHE)
|
||||
|
||||
_default_config = {
|
||||
# data provider config
|
||||
"calendar_provider": "LocalCalendarProvider",
|
||||
@@ -82,6 +90,15 @@ _default_config = {
|
||||
"dataset_provider": "LocalDatasetProvider",
|
||||
"provider": "LocalProvider",
|
||||
# config it in qlib.init()
|
||||
# "provider_uri" str or dict:
|
||||
# # str
|
||||
# "~/.qlib/stock_data/cn_data"
|
||||
# # dict
|
||||
# {"day": "~/.qlib/stock_data/cn_data", "1min": "~/.qlib/stock_data/cn_data_1min"}
|
||||
# NOTE: provider_uri priority:
|
||||
# 1. backend_config: backend_obj["kwargs"]["provider_uri"]
|
||||
# 2. backend_config: backend_obj["kwargs"]["provider_uri_map"]
|
||||
# 3. qlib.init: provider_uri
|
||||
"provider_uri": "",
|
||||
# cache
|
||||
"expression_cache": None,
|
||||
@@ -167,8 +184,9 @@ MODE_CONF = {
|
||||
"redis_task_db": 1,
|
||||
"kernels": NUM_USABLE_CPU,
|
||||
# cache
|
||||
"expression_cache": "DiskExpressionCache",
|
||||
"dataset_cache": "DiskDatasetCache",
|
||||
"expression_cache": DISK_EXPRESSION_CACHE,
|
||||
"dataset_cache": DISK_DATASET_CACHE,
|
||||
"local_cache_path": Path("~/.cache/qlib_simple_cache").expanduser().resolve(),
|
||||
"mount_path": None,
|
||||
},
|
||||
"client": {
|
||||
@@ -183,8 +201,10 @@ MODE_CONF = {
|
||||
"provider_uri": "~/.qlib/qlib_data/cn_data",
|
||||
# cache
|
||||
# Using parameter 'remote' to announce the client is using server_cache, and the writing access will be disabled.
|
||||
"expression_cache": "DiskExpressionCache",
|
||||
"dataset_cache": "DiskDatasetCache",
|
||||
"expression_cache": DISK_EXPRESSION_CACHE,
|
||||
"dataset_cache": DISK_DATASET_CACHE,
|
||||
# SimpleDatasetCache directory
|
||||
"local_cache_path": Path("~/.cache/qlib_simple_cache").expanduser().resolve(),
|
||||
"calendar_cache": None,
|
||||
# client config
|
||||
"kernels": NUM_USABLE_CPU,
|
||||
@@ -195,7 +215,10 @@ MODE_CONF = {
|
||||
"timeout": 100,
|
||||
"logging_level": logging.INFO,
|
||||
"region": REG_CN,
|
||||
## Custom Operator
|
||||
# custom operator
|
||||
# each element of custom_ops should be Type[ExpressionOps] or dict
|
||||
# if element of custom_ops is Type[ExpressionOps], it represents the custom operator class
|
||||
# if element of custom_ops is dict, it represents the config of custom operator and should include `class` and `module_path` keys.
|
||||
"custom_ops": [],
|
||||
},
|
||||
}
|
||||
@@ -225,11 +248,43 @@ class QlibConfig(Config):
|
||||
# URI_TYPE
|
||||
LOCAL_URI = "local"
|
||||
NFS_URI = "nfs"
|
||||
DEFAULT_FREQ = "__DEFAULT_FREQ"
|
||||
|
||||
def __init__(self, default_conf):
|
||||
super().__init__(default_conf)
|
||||
self._registered = False
|
||||
|
||||
class DataPathManager:
|
||||
def __init__(self, provider_uri: Union[str, Path, dict], mount_path: Union[str, Path, dict]):
|
||||
self.provider_uri = provider_uri
|
||||
self.mount_path = mount_path
|
||||
|
||||
@staticmethod
|
||||
def get_uri_type(uri: Union[str, Path]):
|
||||
uri = uri if isinstance(uri, str) else str(uri.expanduser().resolve())
|
||||
is_win = re.match("^[a-zA-Z]:.*", uri) is not None # such as 'C:\\data', 'D:'
|
||||
# such as 'host:/data/' (User may define short hostname by themselves or use localhost)
|
||||
is_nfs_or_win = re.match("^[^/]+:.+", uri) is not None
|
||||
|
||||
if is_nfs_or_win and not is_win:
|
||||
return QlibConfig.NFS_URI
|
||||
else:
|
||||
return QlibConfig.LOCAL_URI
|
||||
|
||||
def get_data_path(self, freq: str = None) -> Path:
|
||||
if freq is None or freq not in self.provider_uri:
|
||||
freq = QlibConfig.DEFAULT_FREQ
|
||||
_provider_uri = self.provider_uri[freq]
|
||||
if self.get_uri_type(_provider_uri) == QlibConfig.LOCAL_URI:
|
||||
return Path(_provider_uri)
|
||||
elif self.get_uri_type(_provider_uri) == QlibConfig.NFS_URI:
|
||||
if "win" in platform.system().lower():
|
||||
# windows, mount_path is the drive
|
||||
return Path(f"{self.mount_path[freq]}:\\")
|
||||
return Path(self.mount_path[freq])
|
||||
else:
|
||||
raise NotImplementedError(f"This type of uri is not supported")
|
||||
|
||||
def set_mode(self, mode):
|
||||
# raise KeyError
|
||||
self.update(MODE_CONF[mode])
|
||||
@@ -239,32 +294,43 @@ class QlibConfig(Config):
|
||||
# raise KeyError
|
||||
self.update(_default_region_config[region])
|
||||
|
||||
@staticmethod
|
||||
def is_depend_redis(cache_name: str):
|
||||
return cache_name in DEPENDENCY_REDIS_CACHE
|
||||
|
||||
@property
|
||||
def dpm(self):
|
||||
return self.DataPathManager(self["provider_uri"], self["mount_path"])
|
||||
|
||||
def resolve_path(self):
|
||||
# resolve path
|
||||
if self["mount_path"] is not None:
|
||||
self["mount_path"] = str(Path(self["mount_path"]).expanduser().resolve())
|
||||
_mount_path = self["mount_path"]
|
||||
_provider_uri = self["provider_uri"]
|
||||
if _provider_uri is None:
|
||||
raise ValueError("provider_uri cannot be None")
|
||||
if not isinstance(_provider_uri, dict):
|
||||
_provider_uri = {self.DEFAULT_FREQ: _provider_uri}
|
||||
if not isinstance(_mount_path, dict):
|
||||
_mount_path = {_freq: _mount_path for _freq in _provider_uri.keys()}
|
||||
|
||||
if self.get_uri_type() == QlibConfig.LOCAL_URI:
|
||||
self["provider_uri"] = str(Path(self["provider_uri"]).expanduser().resolve())
|
||||
# check provider_uri and mount_path
|
||||
_miss_freq = set(_provider_uri.keys()) - set(_mount_path.keys())
|
||||
assert len(_miss_freq) == 0, f"mount_path is missing freq: {_miss_freq}"
|
||||
|
||||
def get_uri_type(self):
|
||||
is_win = re.match("^[a-zA-Z]:.*", self["provider_uri"]) is not None # such as 'C:\\data', 'D:'
|
||||
is_nfs_or_win = (
|
||||
re.match("^[^/]+:.+", self["provider_uri"]) is not None
|
||||
) # such as 'host:/data/' (User may define short hostname by themselves or use localhost)
|
||||
# resolve
|
||||
for _freq, _uri in _provider_uri.items():
|
||||
# provider_uri
|
||||
if self.DataPathManager.get_uri_type(_uri) == QlibConfig.LOCAL_URI:
|
||||
_provider_uri[_freq] = str(Path(_uri).expanduser().resolve())
|
||||
# mount_path
|
||||
_mount_path[_freq] = (
|
||||
_mount_path[_freq]
|
||||
if _mount_path[_freq] is None
|
||||
else str(Path(_mount_path[_freq]).expanduser().resolve())
|
||||
)
|
||||
|
||||
if is_nfs_or_win and not is_win:
|
||||
return QlibConfig.NFS_URI
|
||||
else:
|
||||
return QlibConfig.LOCAL_URI
|
||||
|
||||
def get_data_path(self):
|
||||
if self.get_uri_type() == QlibConfig.LOCAL_URI:
|
||||
return self["provider_uri"]
|
||||
elif self.get_uri_type() == QlibConfig.NFS_URI:
|
||||
return self["mount_path"]
|
||||
else:
|
||||
raise NotImplementedError(f"This type of uri is not supported")
|
||||
self["provider_uri"] = _provider_uri
|
||||
self["mount_path"] = _mount_path
|
||||
|
||||
def set(self, default_conf="client", **kwargs):
|
||||
from .utils import set_log_with_config, get_module_logger, can_use_cache
|
||||
@@ -296,11 +362,20 @@ class QlibConfig(Config):
|
||||
if not (self["expression_cache"] is None and self["dataset_cache"] is None):
|
||||
# check redis
|
||||
if not can_use_cache():
|
||||
logger.warning(
|
||||
f"redis connection failed(host={self['redis_host']} port={self['redis_port']}), cache will not be used!"
|
||||
)
|
||||
self["expression_cache"] = None
|
||||
self["dataset_cache"] = None
|
||||
log_str = ""
|
||||
# check expression cache
|
||||
if self.is_depend_redis(self["expression_cache"]):
|
||||
log_str += self["expression_cache"]
|
||||
self["expression_cache"] = None
|
||||
# check dataset cache
|
||||
if self.is_depend_redis(self["dataset_cache"]):
|
||||
log_str += f" and {self['dataset_cache']}" if log_str else self["dataset_cache"]
|
||||
self["dataset_cache"] = None
|
||||
if log_str:
|
||||
logger.warning(
|
||||
f"redis connection failed(host={self['redis_host']} port={self['redis_port']}), "
|
||||
f"{log_str} will not be used!"
|
||||
)
|
||||
|
||||
def register(self):
|
||||
from .utils import init_instance_by_config
|
||||
|
||||
@@ -16,6 +16,7 @@ def get_benchmark_weight(
|
||||
start_date=None,
|
||||
end_date=None,
|
||||
path=None,
|
||||
freq="day",
|
||||
):
|
||||
"""get_benchmark_weight
|
||||
|
||||
@@ -25,6 +26,7 @@ def get_benchmark_weight(
|
||||
:param start_date:
|
||||
:param end_date:
|
||||
:param path:
|
||||
:param freq:
|
||||
|
||||
:return: The weight distribution of the the benchmark described by a pandas dataframe
|
||||
Every row corresponds to a trading day.
|
||||
@@ -33,7 +35,7 @@ def get_benchmark_weight(
|
||||
|
||||
"""
|
||||
if not path:
|
||||
path = Path(C.get_data_path()).expanduser() / "raw" / "AIndexMembers" / "weights.csv"
|
||||
path = Path(C.dpm.get_data_path(freq)).expanduser() / "raw" / "AIndexMembers" / "weights.csv"
|
||||
# TODO: the storage of weights should be implemented in a more elegent way
|
||||
# TODO: The benchmark is not consistant with the filename in instruments.
|
||||
bench_weight_df = pd.read_csv(path, usecols=["code", "date", "index", "weight"])
|
||||
@@ -222,6 +224,7 @@ def brinson_pa(
|
||||
group_method="category",
|
||||
group_n=None,
|
||||
deal_price="vwap",
|
||||
freq="day",
|
||||
):
|
||||
"""brinson profit attribution
|
||||
|
||||
@@ -243,7 +246,7 @@ def brinson_pa(
|
||||
|
||||
start_date, end_date = min(dates), max(dates)
|
||||
|
||||
bench_stock_weight = get_benchmark_weight(bench, start_date, end_date)
|
||||
bench_stock_weight = get_benchmark_weight(bench, start_date, end_date, freq)
|
||||
|
||||
# The attributes for allocation will not
|
||||
if not group_field.startswith("$"):
|
||||
@@ -259,13 +262,14 @@ def brinson_pa(
|
||||
start_time=shift_start_date,
|
||||
end_time=end_date,
|
||||
as_list=True,
|
||||
freq=freq,
|
||||
)
|
||||
stock_df = D.features(
|
||||
instruments,
|
||||
[group_field, deal_price],
|
||||
start_time=shift_start_date,
|
||||
end_time=end_date,
|
||||
freq="day",
|
||||
freq=freq,
|
||||
)
|
||||
stock_df.columns = [group_field, "deal_price"]
|
||||
|
||||
|
||||
346
qlib/contrib/data/dataset.py
Normal file
346
qlib/contrib/data/dataset.py
Normal file
@@ -0,0 +1,346 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import copy
|
||||
import torch
|
||||
import warnings
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from qlib.utils import init_instance_by_config
|
||||
from qlib.data.dataset import DatasetH, DataHandler
|
||||
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
|
||||
def _to_tensor(x):
|
||||
if not isinstance(x, torch.Tensor):
|
||||
return torch.tensor(x, dtype=torch.float, device=device)
|
||||
return x
|
||||
|
||||
|
||||
def _create_ts_slices(index, seq_len):
|
||||
"""
|
||||
create time series slices from pandas index
|
||||
|
||||
Args:
|
||||
index (pd.MultiIndex): pandas multiindex with <instrument, datetime> order
|
||||
seq_len (int): sequence length
|
||||
"""
|
||||
assert isinstance(index, pd.MultiIndex), "unsupported index type"
|
||||
assert seq_len > 0, "sequence length should be larger than 0"
|
||||
assert index.is_monotonic_increasing, "index should be sorted"
|
||||
|
||||
# number of dates for each instrument
|
||||
sample_count_by_insts = index.to_series().groupby(level=0).size().values
|
||||
|
||||
# start index for each instrument
|
||||
start_index_of_insts = np.roll(np.cumsum(sample_count_by_insts), 1)
|
||||
start_index_of_insts[0] = 0
|
||||
|
||||
# all the [start, stop) indices of features
|
||||
# features between [start, stop) will be used to predict label at `stop - 1`
|
||||
slices = []
|
||||
for cur_loc, cur_cnt in zip(start_index_of_insts, sample_count_by_insts):
|
||||
for stop in range(1, cur_cnt + 1):
|
||||
end = cur_loc + stop
|
||||
start = max(end - seq_len, 0)
|
||||
slices.append(slice(start, end))
|
||||
slices = np.array(slices, dtype="object")
|
||||
|
||||
assert len(slices) == len(index) # the i-th slice = index[i]
|
||||
|
||||
return slices
|
||||
|
||||
|
||||
def _get_date_parse_fn(target):
|
||||
"""get date parse function
|
||||
|
||||
This method is used to parse date arguments as target type.
|
||||
|
||||
Example:
|
||||
get_date_parse_fn('20120101')('2017-01-01') => '20170101'
|
||||
get_date_parse_fn(20120101)('2017-01-01') => 20170101
|
||||
"""
|
||||
if isinstance(target, pd.Timestamp):
|
||||
_fn = lambda x: pd.Timestamp(x) # Timestamp('2020-01-01')
|
||||
elif isinstance(target, int):
|
||||
_fn = lambda x: int(str(x).replace("-", "")[:8]) # 20200201
|
||||
elif isinstance(target, str) and len(target) == 8:
|
||||
_fn = lambda x: str(x).replace("-", "")[:8] # '20200201'
|
||||
else:
|
||||
_fn = lambda x: x # '2021-01-01'
|
||||
return _fn
|
||||
|
||||
|
||||
def _maybe_padding(x, seq_len, zeros=None):
|
||||
"""padding 2d <time * feature> data with zeros
|
||||
|
||||
Args:
|
||||
x (np.ndarray): 2d data with shape <time * feature>
|
||||
seq_len (int): target sequence length
|
||||
zeros (np.ndarray): zeros with shape <seq_len * feature>
|
||||
"""
|
||||
assert seq_len > 0, "sequence length should be larger than 0"
|
||||
if zeros is None:
|
||||
zeros = np.zeros((seq_len, x.shape[1]), dtype=np.float32)
|
||||
else:
|
||||
assert len(zeros) >= seq_len, "zeros matrix is not large enough for padding"
|
||||
if len(x) != seq_len: # padding zeros
|
||||
x = np.concatenate([zeros[: seq_len - len(x), : x.shape[1]], x], axis=0)
|
||||
return x
|
||||
|
||||
|
||||
class MTSDatasetH(DatasetH):
|
||||
"""Memory Augmented Time Series Dataset
|
||||
|
||||
Args:
|
||||
handler (DataHandler): data handler
|
||||
segments (dict): data split segments
|
||||
seq_len (int): time series sequence length
|
||||
horizon (int): label horizon
|
||||
num_states (int): how many memory states to be added
|
||||
memory_mode (str): memory mode (daily or sample)
|
||||
batch_size (int): batch size (<0 will use daily sampling)
|
||||
n_samples (int): number of samples in the same day
|
||||
shuffle (bool): whether shuffle data
|
||||
drop_last (bool): whether drop last batch < batch_size
|
||||
input_size (int): reshape flatten rows as this input_size (backward compatibility)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
handler,
|
||||
segments,
|
||||
seq_len=60,
|
||||
horizon=0,
|
||||
num_states=0,
|
||||
memory_mode="sample",
|
||||
batch_size=-1,
|
||||
n_samples=None,
|
||||
shuffle=True,
|
||||
drop_last=False,
|
||||
input_size=None,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
assert num_states == 0 or horizon > 0, "please specify `horizon` to avoid data leakage"
|
||||
assert memory_mode in ["sample", "daily"], "unsupported memory mode"
|
||||
assert memory_mode == "sample" or batch_size < 0, "daily memory requires daily sampling (`batch_size < 0`)"
|
||||
assert batch_size != 0, "invalid batch size"
|
||||
|
||||
if batch_size > 0 and n_samples is not None:
|
||||
warnings.warn("`n_samples` can only be used for daily sampling (`batch_size < 0`)")
|
||||
|
||||
self.seq_len = seq_len
|
||||
self.horizon = horizon
|
||||
self.num_states = num_states
|
||||
self.memory_mode = memory_mode
|
||||
self.batch_size = batch_size
|
||||
self.n_samples = n_samples
|
||||
self.shuffle = shuffle
|
||||
self.drop_last = drop_last
|
||||
self.input_size = input_size
|
||||
self.params = (batch_size, n_samples, drop_last, shuffle) # for train/eval switch
|
||||
|
||||
super().__init__(handler, segments, **kwargs)
|
||||
|
||||
def setup_data(self, handler_kwargs: dict = None, **kwargs):
|
||||
|
||||
super().setup_data(**kwargs)
|
||||
|
||||
if handler_kwargs is not None:
|
||||
self.handler.setup_data(**handler_kwargs)
|
||||
|
||||
# pre-fetch data and change index to <code, date>
|
||||
# NOTE: we will use inplace sort to reduce memory use
|
||||
try:
|
||||
df = self.handler._learn.copy() # use copy otherwise recorder will fail
|
||||
# FIXME: currently we cannot support switching from `_learn` to `_infer` for inference
|
||||
except:
|
||||
warnings.warn("cannot access `_learn`, will load raw data")
|
||||
df = self.handler._data.copy()
|
||||
df.index = df.index.swaplevel()
|
||||
df.sort_index(inplace=True)
|
||||
|
||||
# convert to numpy
|
||||
self._data = df["feature"].values.astype("float32")
|
||||
np.nan_to_num(self._data, copy=False) # NOTE: fillna in case users forget using the fillna processor
|
||||
self._label = df["label"].squeeze().values.astype("float32")
|
||||
self._index = df.index
|
||||
|
||||
if self.input_size is not None and self.input_size != self._data.shape[1]:
|
||||
warnings.warn("the data has different shape from input_size and the data will be reshaped")
|
||||
assert self._data.shape[1] % self.input_size == 0, "data mismatch, please check `input_size`"
|
||||
|
||||
# create batch slices
|
||||
self._batch_slices = _create_ts_slices(self._index, self.seq_len)
|
||||
|
||||
# create daily slices
|
||||
daily_slices = {date: [] for date in sorted(self._index.unique(level=1))} # sorted by date
|
||||
for i, (code, date) in enumerate(self._index):
|
||||
daily_slices[date].append(self._batch_slices[i])
|
||||
self._daily_slices = np.array(list(daily_slices.values()), dtype="object")
|
||||
self._daily_index = pd.Series(list(daily_slices.keys())) # index is the original date index
|
||||
|
||||
# add memory (sample wise and daily)
|
||||
if self.memory_mode == "sample":
|
||||
self._memory = np.zeros((len(self._data), self.num_states), dtype=np.float32)
|
||||
elif self.memory_mode == "daily":
|
||||
self._memory = np.zeros((len(self._daily_index), self.num_states), dtype=np.float32)
|
||||
else:
|
||||
raise ValueError(f"invalid memory_mode `{self.memory_mode}`")
|
||||
|
||||
# padding tensor
|
||||
self._zeros = np.zeros((self.seq_len, max(self.num_states, self._data.shape[1])), dtype=np.float32)
|
||||
|
||||
def _prepare_seg(self, slc, **kwargs):
|
||||
fn = _get_date_parse_fn(self._index[0][1])
|
||||
start_date = fn(slc.start)
|
||||
end_date = fn(slc.stop)
|
||||
obj = copy.copy(self) # shallow copy
|
||||
# NOTE: Seriable will disable copy `self._data` so we manually assign them here
|
||||
obj._data = self._data # reference (no copy)
|
||||
obj._label = self._label
|
||||
obj._index = self._index
|
||||
obj._memory = self._memory
|
||||
obj._zeros = self._zeros
|
||||
# update index for this batch
|
||||
date_index = self._index.get_level_values(1)
|
||||
obj._batch_slices = self._batch_slices[(date_index >= start_date) & (date_index <= end_date)]
|
||||
mask = (self._daily_index.values >= start_date) & (self._daily_index.values <= end_date)
|
||||
obj._daily_slices = self._daily_slices[mask]
|
||||
obj._daily_index = self._daily_index[mask]
|
||||
return obj
|
||||
|
||||
def restore_index(self, index):
|
||||
return self._index[index]
|
||||
|
||||
def restore_daily_index(self, daily_index):
|
||||
return pd.Index(self._daily_index.loc[daily_index])
|
||||
|
||||
def assign_data(self, index, vals):
|
||||
if self.num_states == 0:
|
||||
raise ValueError("cannot assign data as `num_states==0`")
|
||||
if isinstance(vals, torch.Tensor):
|
||||
vals = vals.detach().cpu().numpy()
|
||||
self._memory[index] = vals
|
||||
|
||||
def clear_memory(self):
|
||||
if self.num_states == 0:
|
||||
raise ValueError("cannot clear memory as `num_states==0`")
|
||||
self._memory[:] = 0
|
||||
|
||||
def train(self):
|
||||
"""enable traning mode"""
|
||||
self.batch_size, self.n_samples, self.drop_last, self.shuffle = self.params
|
||||
|
||||
def eval(self):
|
||||
"""enable evaluation mode"""
|
||||
self.batch_size = -1
|
||||
self.n_samples = None
|
||||
self.drop_last = False
|
||||
self.shuffle = False
|
||||
|
||||
def _get_slices(self):
|
||||
if self.batch_size < 0: # daily sampling
|
||||
slices = self._daily_slices.copy()
|
||||
batch_size = -1 * self.batch_size
|
||||
else: # normal sampling
|
||||
slices = self._batch_slices.copy()
|
||||
batch_size = self.batch_size
|
||||
return slices, batch_size
|
||||
|
||||
def __len__(self):
|
||||
slices, batch_size = self._get_slices()
|
||||
if self.drop_last:
|
||||
return len(slices) // batch_size
|
||||
return (len(slices) + batch_size - 1) // batch_size
|
||||
|
||||
def __iter__(self):
|
||||
slices, batch_size = self._get_slices()
|
||||
indices = np.arange(len(slices))
|
||||
if self.shuffle:
|
||||
np.random.shuffle(indices)
|
||||
|
||||
for i in range(len(indices))[::batch_size]:
|
||||
if self.drop_last and i + batch_size > len(indices):
|
||||
break
|
||||
|
||||
data = [] # store features
|
||||
label = [] # store labels
|
||||
index = [] # store index
|
||||
state = [] # store memory states
|
||||
daily_index = [] # store daily index
|
||||
daily_count = [] # store number of samples for each day
|
||||
|
||||
for j in indices[i : i + batch_size]:
|
||||
|
||||
# normal sampling: self.batch_size > 0 => slices is a list => slices_subset is a slice
|
||||
# daily sampling: self.batch_size < 0 => slices is a nested list => slices_subset is a list
|
||||
slices_subset = slices[j]
|
||||
|
||||
# daily sampling
|
||||
# each slices_subset contains a list of slices for multiple stocks
|
||||
# NOTE: daily sampling is used in 1) eval mode, 2) train mode with self.batch_size < 0
|
||||
if self.batch_size < 0:
|
||||
|
||||
# store daily index
|
||||
idx = self._daily_index.index[j] # daily_index.index is the index of the original data
|
||||
daily_index.append(idx)
|
||||
|
||||
# store daily memory if specified
|
||||
# NOTE: daily memory always requires daily sampling (self.batch_size < 0)
|
||||
if self.memory_mode == "daily":
|
||||
slc = slice(max(idx - self.seq_len - self.horizon, 0), max(idx - self.horizon, 0))
|
||||
state.append(_maybe_padding(self._memory[slc], self.seq_len, self._zeros))
|
||||
|
||||
# down-sample stocks and store count
|
||||
if self.n_samples and 0 < self.n_samples < len(slices_subset): # intraday subsample
|
||||
slices_subset = np.random.choice(slices_subset, self.n_samples, replace=False)
|
||||
daily_count.append(len(slices_subset))
|
||||
|
||||
# normal sampling
|
||||
# each slices_subset is a single slice
|
||||
# NOTE: normal sampling is used in train mode with self.batch_size > 0
|
||||
else:
|
||||
slices_subset = [slices_subset]
|
||||
|
||||
for slc in slices_subset:
|
||||
|
||||
# legacy support for Alpha360 data by `input_size`
|
||||
if self.input_size:
|
||||
data.append(self._data[slc.stop - 1].reshape(self.input_size, -1).T)
|
||||
else:
|
||||
data.append(_maybe_padding(self._data[slc], self.seq_len, self._zeros))
|
||||
|
||||
if self.memory_mode == "sample":
|
||||
state.append(_maybe_padding(self._memory[slc], self.seq_len, self._zeros)[: -self.horizon])
|
||||
|
||||
label.append(self._label[slc.stop - 1])
|
||||
index.append(slc.stop - 1)
|
||||
|
||||
# end slices loop
|
||||
|
||||
# end indices batch loop
|
||||
|
||||
# concate
|
||||
data = _to_tensor(np.stack(data))
|
||||
state = _to_tensor(np.stack(state))
|
||||
label = _to_tensor(np.stack(label))
|
||||
index = np.array(index)
|
||||
daily_index = np.array(daily_index)
|
||||
daily_count = np.array(daily_count)
|
||||
|
||||
# yield -> generator
|
||||
yield {
|
||||
"data": data,
|
||||
"label": label,
|
||||
"state": state,
|
||||
"index": index,
|
||||
"daily_index": daily_index,
|
||||
"daily_count": daily_count,
|
||||
}
|
||||
|
||||
# end indice loop
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
from ...data.dataset.processor import Processor
|
||||
from ...utils import get_cls_kwargs
|
||||
from ...utils import get_callable_kwargs
|
||||
from ...data.dataset import processor as processor_module
|
||||
from ...log import TimeInspector
|
||||
from inspect import getfullargspec
|
||||
@@ -14,7 +14,7 @@ def check_transform_proc(proc_l, fit_start_time, fit_end_time):
|
||||
new_l = []
|
||||
for p in proc_l:
|
||||
if not isinstance(p, Processor):
|
||||
klass, pkwargs = get_cls_kwargs(p, processor_module)
|
||||
klass, pkwargs = get_callable_kwargs(p, processor_module)
|
||||
args = getfullargspec(klass).args
|
||||
if "fit_start_time" in args and "fit_end_time" in args:
|
||||
assert (
|
||||
@@ -26,8 +26,10 @@ def check_transform_proc(proc_l, fit_start_time, fit_end_time):
|
||||
"fit_end_time": fit_end_time,
|
||||
}
|
||||
)
|
||||
# FIXME: the `module_path` parameter is missed.
|
||||
new_l.append({"class": klass.__name__, "kwargs": pkwargs})
|
||||
proc_config = {"class": klass.__name__, "kwargs": pkwargs}
|
||||
if isinstance(p, dict) and "module_path" in p:
|
||||
proc_config["module_path"] = p["module_path"]
|
||||
new_l.append(proc_config)
|
||||
else:
|
||||
new_l.append(p)
|
||||
return new_l
|
||||
@@ -56,6 +58,7 @@ class Alpha360(DataHandlerLP):
|
||||
fit_start_time=None,
|
||||
fit_end_time=None,
|
||||
filter_pipe=None,
|
||||
inst_processor=None,
|
||||
**kwargs,
|
||||
):
|
||||
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
|
||||
@@ -70,6 +73,7 @@ class Alpha360(DataHandlerLP):
|
||||
},
|
||||
"filter_pipe": filter_pipe,
|
||||
"freq": freq,
|
||||
"inst_processor": inst_processor,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -142,6 +146,7 @@ class Alpha158(DataHandlerLP):
|
||||
fit_end_time=None,
|
||||
process_type=DataHandlerLP.PTYPE_A,
|
||||
filter_pipe=None,
|
||||
inst_processor=None,
|
||||
**kwargs,
|
||||
):
|
||||
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
|
||||
@@ -156,6 +161,7 @@ class Alpha158(DataHandlerLP):
|
||||
},
|
||||
"filter_pipe": filter_pipe,
|
||||
"freq": freq,
|
||||
"inst_processor": inst_processor,
|
||||
},
|
||||
}
|
||||
super().__init__(
|
||||
|
||||
@@ -53,7 +53,6 @@ class GATs(Model):
|
||||
early_stop=20,
|
||||
loss="mse",
|
||||
base_model="GRU",
|
||||
with_pretrain=True,
|
||||
model_path=None,
|
||||
optimizer="adam",
|
||||
GPU=0,
|
||||
@@ -76,7 +75,6 @@ class GATs(Model):
|
||||
self.optimizer = optimizer.lower()
|
||||
self.loss = loss
|
||||
self.base_model = base_model
|
||||
self.with_pretrain = with_pretrain
|
||||
self.model_path = model_path
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.seed = seed
|
||||
@@ -94,7 +92,6 @@ class GATs(Model):
|
||||
"\noptimizer : {}"
|
||||
"\nloss_type : {}"
|
||||
"\nbase_model : {}"
|
||||
"\nwith_pretrain : {}"
|
||||
"\nmodel_path : {}"
|
||||
"\ndevice : {}"
|
||||
"\nuse_GPU : {}"
|
||||
@@ -110,7 +107,6 @@ class GATs(Model):
|
||||
optimizer.lower(),
|
||||
loss,
|
||||
base_model,
|
||||
with_pretrain,
|
||||
model_path,
|
||||
self.device,
|
||||
self.use_gpu,
|
||||
@@ -253,24 +249,22 @@ class GATs(Model):
|
||||
evals_result["valid"] = []
|
||||
|
||||
# load pretrained base_model
|
||||
if self.with_pretrain:
|
||||
if self.model_path == None:
|
||||
raise ValueError("the path of the pretrained model should be given first!")
|
||||
self.logger.info("Loading pretrained model...")
|
||||
if self.base_model == "LSTM":
|
||||
pretrained_model = LSTMModel()
|
||||
pretrained_model.load_state_dict(torch.load(self.model_path))
|
||||
elif self.base_model == "GRU":
|
||||
pretrained_model = GRUModel()
|
||||
pretrained_model.load_state_dict(torch.load(self.model_path))
|
||||
else:
|
||||
raise ValueError("unknown base model name `%s`" % self.base_model)
|
||||
if self.base_model == "LSTM":
|
||||
pretrained_model = LSTMModel()
|
||||
elif self.base_model == "GRU":
|
||||
pretrained_model = GRUModel()
|
||||
else:
|
||||
raise ValueError("unknown base model name `%s`" % self.base_model)
|
||||
|
||||
model_dict = self.GAT_model.state_dict()
|
||||
pretrained_dict = {k: v for k, v in pretrained_model.state_dict().items() if k in model_dict}
|
||||
model_dict.update(pretrained_dict)
|
||||
self.GAT_model.load_state_dict(model_dict)
|
||||
self.logger.info("Loading pretrained model Done...")
|
||||
if self.model_path is not None:
|
||||
self.logger.info("Loading pretrained model...")
|
||||
pretrained_model.load_state_dict(torch.load(self.model_path))
|
||||
|
||||
model_dict = self.GAT_model.state_dict()
|
||||
pretrained_dict = {k: v for k, v in pretrained_model.state_dict().items() if k in model_dict}
|
||||
model_dict.update(pretrained_dict)
|
||||
self.GAT_model.load_state_dict(model_dict)
|
||||
self.logger.info("Loading pretrained model Done...")
|
||||
|
||||
# train
|
||||
self.logger.info("training...")
|
||||
|
||||
@@ -27,10 +27,9 @@ from ...contrib.model.pytorch_gru import GRUModel
|
||||
|
||||
class DailyBatchSampler(Sampler):
|
||||
def __init__(self, data_source):
|
||||
|
||||
self.data_source = data_source
|
||||
self.data = self.data_source.data.loc[self.data_source.get_index()]
|
||||
self.daily_count = self.data.groupby(level=0).size().values # calculate number of samples in each batch
|
||||
# calculate number of samples in each batch
|
||||
self.daily_count = pd.Series(index=self.data_source.get_index()).groupby("datetime").size().values
|
||||
self.daily_index = np.roll(np.cumsum(self.daily_count), 1) # calculate begin index of each batch
|
||||
self.daily_index[0] = 0
|
||||
|
||||
@@ -72,7 +71,6 @@ class GATs(Model):
|
||||
early_stop=20,
|
||||
loss="mse",
|
||||
base_model="GRU",
|
||||
with_pretrain=True,
|
||||
model_path=None,
|
||||
optimizer="adam",
|
||||
GPU="0",
|
||||
@@ -96,7 +94,6 @@ class GATs(Model):
|
||||
self.optimizer = optimizer.lower()
|
||||
self.loss = loss
|
||||
self.base_model = base_model
|
||||
self.with_pretrain = with_pretrain
|
||||
self.model_path = model_path
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.n_jobs = n_jobs
|
||||
@@ -115,7 +112,6 @@ class GATs(Model):
|
||||
"\noptimizer : {}"
|
||||
"\nloss_type : {}"
|
||||
"\nbase_model : {}"
|
||||
"\nwith_pretrain : {}"
|
||||
"\nmodel_path : {}"
|
||||
"\nvisible_GPU : {}"
|
||||
"\nuse_GPU : {}"
|
||||
@@ -131,7 +127,6 @@ class GATs(Model):
|
||||
optimizer.lower(),
|
||||
loss,
|
||||
base_model,
|
||||
with_pretrain,
|
||||
model_path,
|
||||
GPU,
|
||||
self.use_gpu,
|
||||
@@ -270,28 +265,22 @@ class GATs(Model):
|
||||
evals_result["valid"] = []
|
||||
|
||||
# load pretrained base_model
|
||||
if self.with_pretrain:
|
||||
if self.model_path == None:
|
||||
raise ValueError("the path of the pretrained model should be given first!")
|
||||
self.logger.info("Loading pretrained model...")
|
||||
if self.base_model == "LSTM":
|
||||
pretrained_model = LSTMModel(
|
||||
d_feat=self.d_feat, hidden_size=self.hidden_size, num_layers=self.num_layers
|
||||
)
|
||||
pretrained_model.load_state_dict(torch.load(self.model_path))
|
||||
elif self.base_model == "GRU":
|
||||
pretrained_model = GRUModel(
|
||||
d_feat=self.d_feat, hidden_size=self.hidden_size, num_layers=self.num_layers
|
||||
)
|
||||
pretrained_model.load_state_dict(torch.load(self.model_path))
|
||||
else:
|
||||
raise ValueError("unknown base model name `%s`" % self.base_model)
|
||||
if self.base_model == "LSTM":
|
||||
pretrained_model = LSTMModel(d_feat=self.d_feat, hidden_size=self.hidden_size, num_layers=self.num_layers)
|
||||
elif self.base_model == "GRU":
|
||||
pretrained_model = GRUModel(d_feat=self.d_feat, hidden_size=self.hidden_size, num_layers=self.num_layers)
|
||||
else:
|
||||
raise ValueError("unknown base model name `%s`" % self.base_model)
|
||||
|
||||
model_dict = self.GAT_model.state_dict()
|
||||
pretrained_dict = {k: v for k, v in pretrained_model.state_dict().items() if k in model_dict}
|
||||
model_dict.update(pretrained_dict)
|
||||
self.GAT_model.load_state_dict(model_dict)
|
||||
self.logger.info("Loading pretrained model Done...")
|
||||
if self.model_path is not None:
|
||||
self.logger.info("Loading pretrained model...")
|
||||
pretrained_model.load_state_dict(torch.load(self.model_path))
|
||||
|
||||
model_dict = self.GAT_model.state_dict()
|
||||
pretrained_dict = {k: v for k, v in pretrained_model.state_dict().items() if k in model_dict}
|
||||
model_dict.update(pretrained_dict)
|
||||
self.GAT_model.load_state_dict(model_dict)
|
||||
self.logger.info("Loading pretrained model Done...")
|
||||
|
||||
# train
|
||||
self.logger.info("training...")
|
||||
|
||||
331
qlib/contrib/model/pytorch_localformer.py
Normal file
331
qlib/contrib/model/pytorch_localformer.py
Normal file
@@ -0,0 +1,331 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Text, Union
|
||||
import copy
|
||||
import math
|
||||
from ...utils import get_or_create_path
|
||||
from ...log import get_module_logger
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from .pytorch_utils import count_parameters
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH, TSDatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
from torch.nn.modules.container import ModuleList
|
||||
|
||||
# qrun examples/benchmarks/Localformer/workflow_config_localformer_Alpha360.yaml ”
|
||||
|
||||
|
||||
class LocalformerModel(Model):
|
||||
def __init__(
|
||||
self,
|
||||
d_feat: int = 20,
|
||||
d_model: int = 64,
|
||||
batch_size: int = 2048,
|
||||
nhead: int = 2,
|
||||
num_layers: int = 2,
|
||||
dropout: float = 0,
|
||||
n_epochs=100,
|
||||
lr=0.0001,
|
||||
metric="",
|
||||
early_stop=5,
|
||||
loss="mse",
|
||||
optimizer="adam",
|
||||
reg=1e-3,
|
||||
n_jobs=10,
|
||||
GPU=0,
|
||||
seed=None,
|
||||
**kwargs
|
||||
):
|
||||
|
||||
# set hyper-parameters.
|
||||
self.d_model = d_model
|
||||
self.dropout = dropout
|
||||
self.n_epochs = n_epochs
|
||||
self.lr = lr
|
||||
self.reg = reg
|
||||
self.metric = metric
|
||||
self.batch_size = batch_size
|
||||
self.early_stop = early_stop
|
||||
self.optimizer = optimizer.lower()
|
||||
self.loss = loss
|
||||
self.n_jobs = n_jobs
|
||||
self.device = torch.device("cuda:%d" % GPU if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.seed = seed
|
||||
self.logger = get_module_logger("TransformerModel")
|
||||
self.logger.info("Naive Transformer:" "\nbatch_size : {}" "\ndevice : {}".format(self.batch_size, self.device))
|
||||
|
||||
if self.seed is not None:
|
||||
np.random.seed(self.seed)
|
||||
torch.manual_seed(self.seed)
|
||||
|
||||
self.model = Transformer(d_feat, d_model, nhead, num_layers, dropout, self.device)
|
||||
if optimizer.lower() == "adam":
|
||||
self.train_optimizer = optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=self.reg)
|
||||
elif optimizer.lower() == "gd":
|
||||
self.train_optimizer = optim.SGD(self.model.parameters(), lr=self.lr, weight_decay=self.reg)
|
||||
else:
|
||||
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
|
||||
|
||||
self.fitted = False
|
||||
self.model.to(self.device)
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
return self.device != torch.device("cpu")
|
||||
|
||||
def mse(self, pred, label):
|
||||
loss = (pred.float() - label.float()) ** 2
|
||||
return torch.mean(loss)
|
||||
|
||||
def loss_fn(self, pred, label):
|
||||
mask = ~torch.isnan(label)
|
||||
|
||||
if self.loss == "mse":
|
||||
return self.mse(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown loss `%s`" % self.loss)
|
||||
|
||||
def metric_fn(self, pred, label):
|
||||
|
||||
mask = torch.isfinite(label)
|
||||
|
||||
if self.metric == "" or self.metric == "loss":
|
||||
return -self.loss_fn(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown metric `%s`" % self.metric)
|
||||
|
||||
def train_epoch(self, x_train, y_train):
|
||||
|
||||
x_train_values = x_train.values
|
||||
y_train_values = np.squeeze(y_train.values)
|
||||
|
||||
self.model.train()
|
||||
|
||||
indices = np.arange(len(x_train_values))
|
||||
np.random.shuffle(indices)
|
||||
|
||||
for i in range(len(indices))[:: self.batch_size]:
|
||||
|
||||
if len(indices) - i < self.batch_size:
|
||||
break
|
||||
|
||||
feature = torch.from_numpy(x_train_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
label = torch.from_numpy(y_train_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
|
||||
pred = self.model(feature)
|
||||
loss = self.loss_fn(pred, label)
|
||||
|
||||
self.train_optimizer.zero_grad()
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_value_(self.model.parameters(), 3.0)
|
||||
self.train_optimizer.step()
|
||||
|
||||
def test_epoch(self, data_x, data_y):
|
||||
|
||||
# prepare training data
|
||||
x_values = data_x.values
|
||||
y_values = np.squeeze(data_y.values)
|
||||
|
||||
self.model.eval()
|
||||
|
||||
scores = []
|
||||
losses = []
|
||||
|
||||
indices = np.arange(len(x_values))
|
||||
|
||||
for i in range(len(indices))[:: self.batch_size]:
|
||||
|
||||
if len(indices) - i < self.batch_size:
|
||||
break
|
||||
|
||||
feature = torch.from_numpy(x_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
label = torch.from_numpy(y_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
pred = self.model(feature)
|
||||
loss = self.loss_fn(pred, label)
|
||||
losses.append(loss.item())
|
||||
|
||||
score = self.metric_fn(pred, label)
|
||||
scores.append(score.item())
|
||||
|
||||
return np.mean(losses), np.mean(scores)
|
||||
|
||||
def fit(
|
||||
self,
|
||||
dataset: DatasetH,
|
||||
evals_result=dict(),
|
||||
save_path=None,
|
||||
):
|
||||
|
||||
df_train, df_valid, df_test = dataset.prepare(
|
||||
["train", "valid", "test"],
|
||||
col_set=["feature", "label"],
|
||||
data_key=DataHandlerLP.DK_L,
|
||||
)
|
||||
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
|
||||
save_path = get_or_create_path(save_path)
|
||||
stop_steps = 0
|
||||
train_loss = 0
|
||||
best_score = -np.inf
|
||||
best_epoch = 0
|
||||
evals_result["train"] = []
|
||||
evals_result["valid"] = []
|
||||
|
||||
# train
|
||||
self.logger.info("training...")
|
||||
self.fitted = True
|
||||
|
||||
for step in range(self.n_epochs):
|
||||
self.logger.info("Epoch%d:", step)
|
||||
self.logger.info("training...")
|
||||
self.train_epoch(x_train, y_train)
|
||||
self.logger.info("evaluating...")
|
||||
train_loss, train_score = self.test_epoch(x_train, y_train)
|
||||
val_loss, val_score = self.test_epoch(x_valid, y_valid)
|
||||
self.logger.info("train %.6f, valid %.6f" % (train_score, val_score))
|
||||
evals_result["train"].append(train_score)
|
||||
evals_result["valid"].append(val_score)
|
||||
|
||||
if val_score > best_score:
|
||||
best_score = val_score
|
||||
stop_steps = 0
|
||||
best_epoch = step
|
||||
best_param = copy.deepcopy(self.model.state_dict())
|
||||
else:
|
||||
stop_steps += 1
|
||||
if stop_steps >= self.early_stop:
|
||||
self.logger.info("early stop")
|
||||
break
|
||||
|
||||
self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
|
||||
self.model.load_state_dict(best_param)
|
||||
torch.save(best_param, save_path)
|
||||
|
||||
if self.use_gpu:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
index = x_test.index
|
||||
self.model.eval()
|
||||
x_values = x_test.values
|
||||
sample_num = x_values.shape[0]
|
||||
preds = []
|
||||
|
||||
for begin in range(sample_num)[:: self.batch_size]:
|
||||
|
||||
if sample_num - begin < self.batch_size:
|
||||
end = sample_num
|
||||
else:
|
||||
end = begin + self.batch_size
|
||||
|
||||
x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
pred = self.model(x_batch).detach().cpu().numpy()
|
||||
|
||||
preds.append(pred)
|
||||
|
||||
return pd.Series(np.concatenate(preds), index=index)
|
||||
|
||||
|
||||
class PositionalEncoding(nn.Module):
|
||||
def __init__(self, d_model, max_len=1000):
|
||||
super(PositionalEncoding, self).__init__()
|
||||
pe = torch.zeros(max_len, d_model)
|
||||
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
|
||||
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
|
||||
pe[:, 0::2] = torch.sin(position * div_term)
|
||||
pe[:, 1::2] = torch.cos(position * div_term)
|
||||
pe = pe.unsqueeze(0).transpose(0, 1)
|
||||
self.register_buffer("pe", pe)
|
||||
|
||||
def forward(self, x):
|
||||
# [T, N, F]
|
||||
return x + self.pe[: x.size(0), :]
|
||||
|
||||
|
||||
def _get_clones(module, N):
|
||||
return ModuleList([copy.deepcopy(module) for i in range(N)])
|
||||
|
||||
|
||||
class LocalformerEncoder(nn.Module):
|
||||
__constants__ = ["norm"]
|
||||
|
||||
def __init__(self, encoder_layer, num_layers, d_model):
|
||||
super(LocalformerEncoder, self).__init__()
|
||||
self.layers = _get_clones(encoder_layer, num_layers)
|
||||
self.conv = _get_clones(nn.Conv1d(d_model, d_model, 3, 1, 1), num_layers)
|
||||
self.num_layers = num_layers
|
||||
|
||||
def forward(self, src, mask):
|
||||
output = src
|
||||
out = src
|
||||
|
||||
for i, mod in enumerate(self.layers):
|
||||
# [T, N, F] --> [N, T, F] --> [N, F, T]
|
||||
out = output.transpose(1, 0).transpose(2, 1)
|
||||
out = self.conv[i](out).transpose(2, 1).transpose(1, 0)
|
||||
|
||||
output = mod(output + out, src_mask=mask)
|
||||
|
||||
return output + out
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, d_feat=6, d_model=8, nhead=4, num_layers=2, dropout=0.5, device=None):
|
||||
super(Transformer, self).__init__()
|
||||
self.rnn = nn.GRU(
|
||||
input_size=d_model,
|
||||
hidden_size=d_model,
|
||||
num_layers=num_layers,
|
||||
batch_first=False,
|
||||
dropout=dropout,
|
||||
)
|
||||
self.feature_layer = nn.Linear(d_feat, d_model)
|
||||
self.pos_encoder = PositionalEncoding(d_model)
|
||||
self.encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dropout=dropout)
|
||||
self.transformer_encoder = LocalformerEncoder(self.encoder_layer, num_layers=num_layers, d_model=d_model)
|
||||
self.decoder_layer = nn.Linear(d_model, 1)
|
||||
self.device = device
|
||||
self.d_feat = d_feat
|
||||
|
||||
def forward(self, src):
|
||||
# src [N, F*T] --> [N, T, F]
|
||||
src = src.reshape(len(src), self.d_feat, -1).permute(0, 2, 1)
|
||||
src = self.feature_layer(src)
|
||||
|
||||
# src [N, T, F] --> [T, N, F], [60, 512, 8]
|
||||
src = src.transpose(1, 0) # not batch first
|
||||
|
||||
mask = None
|
||||
|
||||
src = self.pos_encoder(src)
|
||||
output = self.transformer_encoder(src, mask) # [60, 512, 8]
|
||||
|
||||
output, _ = self.rnn(output)
|
||||
|
||||
# [T, N, F] --> [N, T*F]
|
||||
output = self.decoder_layer(output.transpose(1, 0)[:, -1, :]) # [512, 1]
|
||||
|
||||
return output.squeeze()
|
||||
308
qlib/contrib/model/pytorch_localformer_ts.py
Normal file
308
qlib/contrib/model/pytorch_localformer_ts.py
Normal file
@@ -0,0 +1,308 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import copy
|
||||
import math
|
||||
from ...utils import get_or_create_path
|
||||
from ...log import get_module_logger
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from .pytorch_utils import count_parameters
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH, TSDatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
from torch.nn.modules.container import ModuleList
|
||||
|
||||
|
||||
class LocalformerModel(Model):
|
||||
def __init__(
|
||||
self,
|
||||
d_feat: int = 20,
|
||||
d_model: int = 64,
|
||||
batch_size: int = 8192,
|
||||
nhead: int = 2,
|
||||
num_layers: int = 2,
|
||||
dropout: float = 0,
|
||||
n_epochs=100,
|
||||
lr=0.0001,
|
||||
metric="",
|
||||
early_stop=5,
|
||||
loss="mse",
|
||||
optimizer="adam",
|
||||
reg=1e-3,
|
||||
n_jobs=10,
|
||||
GPU=0,
|
||||
seed=None,
|
||||
**kwargs
|
||||
):
|
||||
|
||||
# set hyper-parameters.
|
||||
self.d_model = d_model
|
||||
self.dropout = dropout
|
||||
self.n_epochs = n_epochs
|
||||
self.lr = lr
|
||||
self.reg = reg
|
||||
self.metric = metric
|
||||
self.batch_size = batch_size
|
||||
self.early_stop = early_stop
|
||||
self.optimizer = optimizer.lower()
|
||||
self.loss = loss
|
||||
self.n_jobs = n_jobs
|
||||
self.device = torch.device("cuda:%d" % GPU if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.seed = seed
|
||||
self.logger = get_module_logger("TransformerModel")
|
||||
self.logger.info(
|
||||
"Improved Transformer:" "\nbatch_size : {}" "\ndevice : {}".format(self.batch_size, self.device)
|
||||
)
|
||||
|
||||
if self.seed is not None:
|
||||
np.random.seed(self.seed)
|
||||
torch.manual_seed(self.seed)
|
||||
|
||||
self.model = Transformer(d_feat, d_model, nhead, num_layers, dropout, self.device)
|
||||
if optimizer.lower() == "adam":
|
||||
self.train_optimizer = optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=self.reg)
|
||||
elif optimizer.lower() == "gd":
|
||||
self.train_optimizer = optim.SGD(self.model.parameters(), lr=self.lr, weight_decay=self.reg)
|
||||
else:
|
||||
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
|
||||
|
||||
self.fitted = False
|
||||
self.model.to(self.device)
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
return self.device != torch.device("cpu")
|
||||
|
||||
def mse(self, pred, label):
|
||||
loss = (pred.float() - label.float()) ** 2
|
||||
return torch.mean(loss)
|
||||
|
||||
def loss_fn(self, pred, label):
|
||||
mask = ~torch.isnan(label)
|
||||
|
||||
if self.loss == "mse":
|
||||
return self.mse(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown loss `%s`" % self.loss)
|
||||
|
||||
def metric_fn(self, pred, label):
|
||||
|
||||
mask = torch.isfinite(label)
|
||||
|
||||
if self.metric == "" or self.metric == "loss":
|
||||
return -self.loss_fn(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown metric `%s`" % self.metric)
|
||||
|
||||
def train_epoch(self, data_loader):
|
||||
|
||||
self.model.train()
|
||||
|
||||
for data in data_loader:
|
||||
feature = data[:, :, 0:-1].to(self.device)
|
||||
label = data[:, -1, -1].to(self.device)
|
||||
|
||||
pred = self.model(feature.float()) # .float()
|
||||
loss = self.loss_fn(pred, label)
|
||||
|
||||
self.train_optimizer.zero_grad()
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_value_(self.model.parameters(), 3.0)
|
||||
self.train_optimizer.step()
|
||||
|
||||
def test_epoch(self, data_loader):
|
||||
|
||||
self.model.eval()
|
||||
|
||||
scores = []
|
||||
losses = []
|
||||
|
||||
for data in data_loader:
|
||||
|
||||
feature = data[:, :, 0:-1].to(self.device)
|
||||
label = data[:, -1, -1].to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
pred = self.model(feature.float()) # .float()
|
||||
loss = self.loss_fn(pred, label)
|
||||
losses.append(loss.item())
|
||||
|
||||
score = self.metric_fn(pred, label)
|
||||
scores.append(score.item())
|
||||
|
||||
return np.mean(losses), np.mean(scores)
|
||||
|
||||
def fit(
|
||||
self,
|
||||
dataset: DatasetH,
|
||||
evals_result=dict(),
|
||||
save_path=None,
|
||||
):
|
||||
|
||||
dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
|
||||
dl_valid = dataset.prepare("valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
|
||||
|
||||
dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
|
||||
train_loader = DataLoader(
|
||||
dl_train, batch_size=self.batch_size, shuffle=True, num_workers=self.n_jobs, drop_last=True
|
||||
)
|
||||
valid_loader = DataLoader(
|
||||
dl_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.n_jobs, drop_last=True
|
||||
)
|
||||
|
||||
save_path = get_or_create_path(save_path)
|
||||
|
||||
stop_steps = 0
|
||||
train_loss = 0
|
||||
best_score = -np.inf
|
||||
best_epoch = 0
|
||||
evals_result["train"] = []
|
||||
evals_result["valid"] = []
|
||||
|
||||
# train
|
||||
self.logger.info("training...")
|
||||
self.fitted = True
|
||||
|
||||
for step in range(self.n_epochs):
|
||||
self.logger.info("Epoch%d:", step)
|
||||
self.logger.info("training...")
|
||||
self.train_epoch(train_loader)
|
||||
self.logger.info("evaluating...")
|
||||
train_loss, train_score = self.test_epoch(train_loader)
|
||||
val_loss, val_score = self.test_epoch(valid_loader)
|
||||
self.logger.info("train %.6f, valid %.6f" % (train_score, val_score))
|
||||
evals_result["train"].append(train_score)
|
||||
evals_result["valid"].append(val_score)
|
||||
|
||||
if val_score > best_score:
|
||||
best_score = val_score
|
||||
stop_steps = 0
|
||||
best_epoch = step
|
||||
best_param = copy.deepcopy(self.model.state_dict())
|
||||
else:
|
||||
stop_steps += 1
|
||||
if stop_steps >= self.early_stop:
|
||||
self.logger.info("early stop")
|
||||
break
|
||||
|
||||
self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
|
||||
self.model.load_state_dict(best_param)
|
||||
torch.save(best_param, save_path)
|
||||
|
||||
if self.use_gpu:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def predict(self, dataset):
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
dl_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)
|
||||
dl_test.config(fillna_type="ffill+bfill")
|
||||
test_loader = DataLoader(dl_test, batch_size=self.batch_size, num_workers=self.n_jobs)
|
||||
self.model.eval()
|
||||
preds = []
|
||||
|
||||
for data in test_loader:
|
||||
feature = data[:, :, 0:-1].to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
pred = self.model(feature.float()).detach().cpu().numpy()
|
||||
|
||||
preds.append(pred)
|
||||
|
||||
return pd.Series(np.concatenate(preds), index=dl_test.get_index())
|
||||
|
||||
|
||||
class PositionalEncoding(nn.Module):
|
||||
def __init__(self, d_model, max_len=1000):
|
||||
super(PositionalEncoding, self).__init__()
|
||||
pe = torch.zeros(max_len, d_model)
|
||||
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
|
||||
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
|
||||
pe[:, 0::2] = torch.sin(position * div_term)
|
||||
pe[:, 1::2] = torch.cos(position * div_term)
|
||||
pe = pe.unsqueeze(0).transpose(0, 1)
|
||||
self.register_buffer("pe", pe)
|
||||
|
||||
def forward(self, x):
|
||||
# [T, N, F]
|
||||
return x + self.pe[: x.size(0), :]
|
||||
|
||||
|
||||
def _get_clones(module, N):
|
||||
return ModuleList([copy.deepcopy(module) for i in range(N)])
|
||||
|
||||
|
||||
class LocalformerEncoder(nn.Module):
|
||||
__constants__ = ["norm"]
|
||||
|
||||
def __init__(self, encoder_layer, num_layers, d_model):
|
||||
super(LocalformerEncoder, self).__init__()
|
||||
self.layers = _get_clones(encoder_layer, num_layers)
|
||||
self.conv = _get_clones(nn.Conv1d(d_model, d_model, 3, 1, 1), num_layers)
|
||||
self.num_layers = num_layers
|
||||
|
||||
def forward(self, src, mask):
|
||||
output = src
|
||||
out = src
|
||||
|
||||
for i, mod in enumerate(self.layers):
|
||||
# [T, N, F] --> [N, T, F] --> [N, F, T]
|
||||
out = output.transpose(1, 0).transpose(2, 1)
|
||||
out = self.conv[i](out).transpose(2, 1).transpose(1, 0)
|
||||
|
||||
output = mod(output + out, src_mask=mask)
|
||||
|
||||
return output + out
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, d_feat=6, d_model=8, nhead=4, num_layers=2, dropout=0.5, device=None):
|
||||
super(Transformer, self).__init__()
|
||||
self.rnn = nn.GRU(
|
||||
input_size=d_model,
|
||||
hidden_size=d_model,
|
||||
num_layers=num_layers,
|
||||
batch_first=False,
|
||||
dropout=dropout,
|
||||
)
|
||||
self.feature_layer = nn.Linear(d_feat, d_model)
|
||||
self.pos_encoder = PositionalEncoding(d_model)
|
||||
self.encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dropout=dropout)
|
||||
self.transformer_encoder = LocalformerEncoder(self.encoder_layer, num_layers=num_layers, d_model=d_model)
|
||||
self.decoder_layer = nn.Linear(d_model, 1)
|
||||
self.device = device
|
||||
self.d_feat = d_feat
|
||||
|
||||
def forward(self, src):
|
||||
# src [N, T, F], [512, 60, 6]
|
||||
src = self.feature_layer(src) # [512, 60, 8]
|
||||
|
||||
# src [N, T, F] --> [T, N, F], [60, 512, 8]
|
||||
src = src.transpose(1, 0) # not batch first
|
||||
|
||||
mask = None
|
||||
|
||||
src = self.pos_encoder(src)
|
||||
output = self.transformer_encoder(src, mask) # [60, 512, 8]
|
||||
|
||||
output, _ = self.rnn(output)
|
||||
|
||||
# [T, N, F] --> [N, T*F]
|
||||
output = self.decoder_layer(output.transpose(1, 0)[:, -1, :]) # [512, 1]
|
||||
|
||||
return output.squeeze()
|
||||
@@ -297,7 +297,7 @@ class DNNModelPytorch(Model):
|
||||
_model_path = os.path.join(model_dir, _model_name)
|
||||
# Load model
|
||||
self.dnn_model.load_state_dict(torch.load(_model_path))
|
||||
self._fitted = True
|
||||
self.fitted = True
|
||||
|
||||
|
||||
class AverageMeter:
|
||||
|
||||
@@ -564,7 +564,7 @@ class FeatureTransformer(nn.Module):
|
||||
self.shared = None
|
||||
self.independ = nn.ModuleList()
|
||||
if first:
|
||||
self.independ.append(GLU(inp, out_dim, vbs=vbs))
|
||||
self.independ.append(GLU(inp_dim, out_dim, vbs=vbs))
|
||||
for x in range(first, n_ind):
|
||||
self.independ.append(GLU(out_dim, out_dim, vbs=vbs))
|
||||
self.scale = float(np.sqrt(0.5))
|
||||
|
||||
420
qlib/contrib/model/pytorch_tcts.py
Normal file
420
qlib/contrib/model/pytorch_tcts.py
Normal file
@@ -0,0 +1,420 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import copy
|
||||
import random
|
||||
from sklearn.metrics import roc_auc_score, mean_squared_error
|
||||
import logging
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
get_or_create_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...log import get_module_logger, TimeInspector
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
|
||||
|
||||
class TCTS(Model):
|
||||
"""TCTS Model
|
||||
|
||||
Parameters
|
||||
----------
|
||||
d_feat : int
|
||||
input dimension for each time step
|
||||
metric: str
|
||||
the evaluate metric used in early stop
|
||||
optimizer : str
|
||||
optimizer name
|
||||
GPU : str
|
||||
the GPU ID(s) used for training
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
d_feat=6,
|
||||
hidden_size=64,
|
||||
num_layers=2,
|
||||
dropout=0.0,
|
||||
n_epochs=200,
|
||||
batch_size=2000,
|
||||
early_stop=20,
|
||||
loss="mse",
|
||||
fore_optimizer="adam",
|
||||
weight_optimizer="adam",
|
||||
output_dim=5,
|
||||
fore_lr=5e-7,
|
||||
weight_lr=5e-7,
|
||||
steps=3,
|
||||
GPU=0,
|
||||
seed=0,
|
||||
target_label=0,
|
||||
lowest_valid_performance=0.993,
|
||||
**kwargs
|
||||
):
|
||||
# Set logger.
|
||||
self.logger = get_module_logger("TCTS")
|
||||
self.logger.info("TCTS pytorch version...")
|
||||
|
||||
# set hyper-parameters.
|
||||
self.d_feat = d_feat
|
||||
self.hidden_size = hidden_size
|
||||
self.num_layers = num_layers
|
||||
self.dropout = dropout
|
||||
self.n_epochs = n_epochs
|
||||
self.batch_size = batch_size
|
||||
self.early_stop = early_stop
|
||||
self.loss = loss
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() else "cpu")
|
||||
self.use_gpu = torch.cuda.is_available()
|
||||
self.seed = seed
|
||||
self.output_dim = output_dim
|
||||
self.fore_lr = fore_lr
|
||||
self.weight_lr = weight_lr
|
||||
self.steps = steps
|
||||
self.target_label = target_label
|
||||
self.lowest_valid_performance = lowest_valid_performance
|
||||
self._fore_optimizer = fore_optimizer
|
||||
self._weight_optimizer = weight_optimizer
|
||||
|
||||
self.logger.info(
|
||||
"TCTS parameters setting:"
|
||||
"\nd_feat : {}"
|
||||
"\nhidden_size : {}"
|
||||
"\nnum_layers : {}"
|
||||
"\ndropout : {}"
|
||||
"\nn_epochs : {}"
|
||||
"\nbatch_size : {}"
|
||||
"\nearly_stop : {}"
|
||||
"\nloss_type : {}"
|
||||
"\nvisible_GPU : {}"
|
||||
"\nuse_GPU : {}"
|
||||
"\nseed : {}".format(
|
||||
d_feat,
|
||||
hidden_size,
|
||||
num_layers,
|
||||
dropout,
|
||||
n_epochs,
|
||||
batch_size,
|
||||
early_stop,
|
||||
loss,
|
||||
GPU,
|
||||
self.use_gpu,
|
||||
seed,
|
||||
)
|
||||
)
|
||||
|
||||
def loss_fn(self, pred, label, weight):
|
||||
|
||||
loc = torch.argmax(weight, 1)
|
||||
loss = (pred - label[np.arange(weight.shape[0]), loc]) ** 2
|
||||
return torch.mean(loss)
|
||||
|
||||
def train_epoch(self, x_train, y_train, x_valid, y_valid):
|
||||
|
||||
x_train_values = x_train.values
|
||||
y_train_values = np.squeeze(y_train.values)
|
||||
|
||||
indices = np.arange(len(x_train_values))
|
||||
np.random.shuffle(indices)
|
||||
|
||||
init_fore_model = copy.deepcopy(self.fore_model)
|
||||
for p in init_fore_model.parameters():
|
||||
p.init_fore_model = False
|
||||
|
||||
self.fore_model.train()
|
||||
self.weight_model.train()
|
||||
|
||||
for p in self.weight_model.parameters():
|
||||
p.requires_grad = False
|
||||
for p in self.fore_model.parameters():
|
||||
p.requires_grad = True
|
||||
|
||||
for i in range(self.steps):
|
||||
for i in range(len(indices))[:: self.batch_size]:
|
||||
|
||||
if len(indices) - i < self.batch_size:
|
||||
break
|
||||
|
||||
feature = torch.from_numpy(x_train_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
label = torch.from_numpy(y_train_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
|
||||
init_pred = init_fore_model(feature)
|
||||
pred = self.fore_model(feature)
|
||||
|
||||
dis = init_pred - label.transpose(0, 1)
|
||||
weight_feature = torch.cat((feature, dis.transpose(0, 1), label, init_pred.view(-1, 1)), 1)
|
||||
weight = self.weight_model(weight_feature)
|
||||
|
||||
loss = self.loss_fn(pred, label, weight) # hard
|
||||
|
||||
self.fore_optimizer.zero_grad()
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_value_(self.fore_model.parameters(), 3.0)
|
||||
self.fore_optimizer.step()
|
||||
|
||||
x_valid_values = x_valid.values
|
||||
y_valid_values = np.squeeze(y_valid.values)
|
||||
|
||||
indices = np.arange(len(x_valid_values))
|
||||
np.random.shuffle(indices)
|
||||
for p in self.weight_model.parameters():
|
||||
p.requires_grad = True
|
||||
for p in self.fore_model.parameters():
|
||||
p.requires_grad = False
|
||||
|
||||
# fix forecasting model and valid weight model
|
||||
for i in range(len(indices))[:: self.batch_size]:
|
||||
|
||||
if len(indices) - i < self.batch_size:
|
||||
break
|
||||
|
||||
feature = torch.from_numpy(x_valid_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
label = torch.from_numpy(y_valid_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
|
||||
pred = self.fore_model(feature)
|
||||
dis = pred - label.transpose(0, 1)
|
||||
weight_feature = torch.cat((feature, dis.transpose(0, 1), label, pred.view(-1, 1)), 1)
|
||||
weight = self.weight_model(weight_feature)
|
||||
loc = torch.argmax(weight, 1)
|
||||
valid_loss = torch.mean((pred - label[:, 0]) ** 2)
|
||||
loss = torch.mean(-valid_loss * torch.log(weight[np.arange(weight.shape[0]), loc]))
|
||||
|
||||
self.weight_optimizer.zero_grad()
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_value_(self.weight_model.parameters(), 3.0)
|
||||
self.weight_optimizer.step()
|
||||
|
||||
def test_epoch(self, data_x, data_y):
|
||||
|
||||
# prepare training data
|
||||
x_values = data_x.values
|
||||
y_values = np.squeeze(data_y.values)
|
||||
|
||||
self.fore_model.eval()
|
||||
|
||||
scores = []
|
||||
losses = []
|
||||
|
||||
indices = np.arange(len(x_values))
|
||||
|
||||
for i in range(len(indices))[:: self.batch_size]:
|
||||
|
||||
if len(indices) - i < self.batch_size:
|
||||
break
|
||||
|
||||
feature = torch.from_numpy(x_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
label = torch.from_numpy(y_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
|
||||
pred = self.fore_model(feature)
|
||||
loss = torch.mean((pred - label[:, abs(self.target_label)]) ** 2)
|
||||
losses.append(loss.item())
|
||||
|
||||
return np.mean(losses)
|
||||
|
||||
def fit(
|
||||
self,
|
||||
dataset: DatasetH,
|
||||
verbose=True,
|
||||
save_path=None,
|
||||
):
|
||||
df_train, df_valid, df_test = dataset.prepare(
|
||||
["train", "valid", "test"],
|
||||
col_set=["feature", "label"],
|
||||
data_key=DataHandlerLP.DK_L,
|
||||
)
|
||||
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
x_test, y_test = df_test["feature"], df_test["label"]
|
||||
|
||||
if save_path == None:
|
||||
save_path = get_or_create_path(save_path)
|
||||
best_loss = np.inf
|
||||
while best_loss > self.lowest_valid_performance:
|
||||
if best_loss < np.inf:
|
||||
print("Failed! Start retraining.")
|
||||
self.seed = random.randint(0, 1000) # reset random seed
|
||||
|
||||
if self.seed is not None:
|
||||
np.random.seed(self.seed)
|
||||
torch.manual_seed(self.seed)
|
||||
|
||||
best_loss = self.training(
|
||||
x_train, y_train, x_valid, y_valid, x_test, y_test, verbose=verbose, save_path=save_path
|
||||
)
|
||||
|
||||
def training(
|
||||
self,
|
||||
x_train,
|
||||
y_train,
|
||||
x_valid,
|
||||
y_valid,
|
||||
x_test,
|
||||
y_test,
|
||||
verbose=True,
|
||||
save_path=None,
|
||||
):
|
||||
|
||||
self.fore_model = GRUModel(
|
||||
d_feat=self.d_feat,
|
||||
hidden_size=self.hidden_size,
|
||||
num_layers=self.num_layers,
|
||||
dropout=self.dropout,
|
||||
)
|
||||
self.weight_model = MLPModel(
|
||||
d_feat=360 + 2 * self.output_dim + 1,
|
||||
hidden_size=self.hidden_size,
|
||||
num_layers=self.num_layers,
|
||||
dropout=self.dropout,
|
||||
output_dim=self.output_dim,
|
||||
)
|
||||
if self._fore_optimizer.lower() == "adam":
|
||||
self.fore_optimizer = optim.Adam(self.fore_model.parameters(), lr=self.fore_lr)
|
||||
elif self._fore_optimizer.lower() == "gd":
|
||||
self.fore_optimizer = optim.SGD(self.fore_model.parameters(), lr=self.fore_lr)
|
||||
else:
|
||||
raise NotImplementedError("optimizer {} is not supported!".format(self._fore_optimizer))
|
||||
if self._weight_optimizer.lower() == "adam":
|
||||
self.weight_optimizer = optim.Adam(self.weight_model.parameters(), lr=self.weight_lr)
|
||||
elif self._weight_optimizer.lower() == "gd":
|
||||
self.weight_optimizer = optim.SGD(self.weight_model.parameters(), lr=self.weight_lr)
|
||||
else:
|
||||
raise NotImplementedError("optimizer {} is not supported!".format(self._weight_optimizer))
|
||||
|
||||
self.fitted = False
|
||||
self.fore_model.to(self.device)
|
||||
self.weight_model.to(self.device)
|
||||
|
||||
best_loss = np.inf
|
||||
best_epoch = 0
|
||||
stop_round = 0
|
||||
fore_best_param = copy.deepcopy(self.fore_optimizer.state_dict())
|
||||
weight_best_param = copy.deepcopy(self.weight_optimizer.state_dict())
|
||||
|
||||
for epoch in range(self.n_epochs):
|
||||
print("Epoch:", epoch)
|
||||
|
||||
print("training...")
|
||||
self.train_epoch(x_train, y_train, x_valid, y_valid)
|
||||
print("evaluating...")
|
||||
val_loss = self.test_epoch(x_valid, y_valid)
|
||||
test_loss = self.test_epoch(x_test, y_test)
|
||||
|
||||
if verbose:
|
||||
print("valid %.6f, test %.6f" % (val_loss, test_loss))
|
||||
|
||||
if val_loss < best_loss:
|
||||
best_loss = val_loss
|
||||
stop_round = 0
|
||||
best_epoch = epoch
|
||||
torch.save(copy.deepcopy(self.fore_model.state_dict()), save_path + "_fore_model.bin")
|
||||
torch.save(copy.deepcopy(self.weight_model.state_dict()), save_path + "_weight_model.bin")
|
||||
|
||||
else:
|
||||
stop_round += 1
|
||||
if stop_round >= self.early_stop:
|
||||
print("early stop")
|
||||
break
|
||||
|
||||
print("best loss:", best_loss, "@", best_epoch)
|
||||
best_param = torch.load(save_path + "_fore_model.bin")
|
||||
self.fore_model.load_state_dict(best_param)
|
||||
best_param = torch.load(save_path + "_weight_model.bin")
|
||||
self.weight_model.load_state_dict(best_param)
|
||||
self.fitted = True
|
||||
|
||||
if self.use_gpu:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return best_loss
|
||||
|
||||
def predict(self, dataset):
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
x_test = dataset.prepare("test", col_set="feature")
|
||||
index = x_test.index
|
||||
self.fore_model.eval()
|
||||
x_values = x_test.values
|
||||
sample_num = x_values.shape[0]
|
||||
preds = []
|
||||
|
||||
for begin in range(sample_num)[:: self.batch_size]:
|
||||
|
||||
if sample_num - begin < self.batch_size:
|
||||
end = sample_num
|
||||
else:
|
||||
end = begin + self.batch_size
|
||||
|
||||
x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
if self.use_gpu:
|
||||
pred = self.fore_model(x_batch).detach().cpu().numpy()
|
||||
else:
|
||||
pred = self.fore_model(x_batch).detach().numpy()
|
||||
|
||||
preds.append(pred)
|
||||
|
||||
return pd.Series(np.concatenate(preds), index=index)
|
||||
|
||||
|
||||
class MLPModel(nn.Module):
|
||||
def __init__(self, d_feat, hidden_size=256, num_layers=3, dropout=0.0, output_dim=1):
|
||||
super().__init__()
|
||||
|
||||
self.mlp = nn.Sequential()
|
||||
self.softmax = nn.Softmax(dim=1)
|
||||
|
||||
for i in range(num_layers):
|
||||
if i > 0:
|
||||
self.mlp.add_module("drop_%d" % i, nn.Dropout(dropout))
|
||||
self.mlp.add_module("fc_%d" % i, nn.Linear(d_feat if i == 0 else hidden_size, hidden_size))
|
||||
self.mlp.add_module("relu_%d" % i, nn.ReLU())
|
||||
|
||||
self.mlp.add_module("fc_out", nn.Linear(hidden_size, output_dim))
|
||||
|
||||
def forward(self, x):
|
||||
# feature
|
||||
# [N, F]
|
||||
out = self.mlp(x).squeeze()
|
||||
out = self.softmax(out)
|
||||
return out
|
||||
|
||||
|
||||
class GRUModel(nn.Module):
|
||||
def __init__(self, d_feat=6, hidden_size=64, num_layers=2, dropout=0.0):
|
||||
super().__init__()
|
||||
|
||||
self.rnn = nn.GRU(
|
||||
input_size=d_feat,
|
||||
hidden_size=hidden_size,
|
||||
num_layers=num_layers,
|
||||
batch_first=True,
|
||||
dropout=dropout,
|
||||
)
|
||||
self.fc_out = nn.Linear(hidden_size, 1)
|
||||
|
||||
self.d_feat = d_feat
|
||||
|
||||
def forward(self, x):
|
||||
# x: [N, F*T]
|
||||
x = x.reshape(len(x), self.d_feat, -1) # [N, F, T]
|
||||
x = x.permute(0, 2, 1) # [N, T, F]
|
||||
out, _ = self.rnn(x)
|
||||
return self.fc_out(out[:, -1, :]).squeeze()
|
||||
944
qlib/contrib/model/pytorch_tra.py
Normal file
944
qlib/contrib/model/pytorch_tra.py
Normal file
@@ -0,0 +1,944 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import io
|
||||
import os
|
||||
import copy
|
||||
import math
|
||||
import json
|
||||
import collections
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import seaborn as sns
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import torch.nn.functional as F
|
||||
|
||||
try:
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
except:
|
||||
SummaryWriter = None
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from qlib.utils import get_or_create_path
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.model.base import Model
|
||||
from qlib.contrib.data.dataset import MTSDatasetH
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
|
||||
class TRAModel(Model):
|
||||
"""
|
||||
TRA Model
|
||||
|
||||
Args:
|
||||
model_config (dict): model config (will be used by RNN or Transformer)
|
||||
tra_config (dict): TRA config (will be used by TRA)
|
||||
model_type (str): which backbone model to use (RNN/Transformer)
|
||||
lr (float): learning rate
|
||||
n_epochs (int): number of total epochs
|
||||
early_stop (int): early stop when performance not improved at this step
|
||||
update_freq (int): gradient update frequency
|
||||
max_steps_per_epoch (int): maximum number of steps in one epoch
|
||||
lamb (float): regularization parameter
|
||||
rho (float): exponential decay rate for `lamb`
|
||||
alpha (float): fusion parameter for calculating transport loss matrix
|
||||
seed (int): random seed
|
||||
logdir (str): local log directory
|
||||
eval_train (bool): whether evaluate train set between epochs
|
||||
eval_test (bool): whether evaluate test set between epochs
|
||||
pretrain (bool): whether pretrain the backbone model before training TRA.
|
||||
Note that only TRA will be optimized after pretraining
|
||||
init_state (str): model init state path
|
||||
freeze_model (bool): whether freeze backbone model parameters
|
||||
freeze_predictors (bool): whether freeze predictors parameters
|
||||
transport_method (str): transport method, can be none/router/oracle
|
||||
memory_mode (str): memory mode, the same argument for MTSDatasetH
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_config,
|
||||
tra_config,
|
||||
model_type="RNN",
|
||||
lr=1e-3,
|
||||
n_epochs=500,
|
||||
early_stop=50,
|
||||
update_freq=1,
|
||||
max_steps_per_epoch=None,
|
||||
lamb=0.0,
|
||||
rho=0.99,
|
||||
alpha=1.0,
|
||||
seed=0,
|
||||
logdir=None,
|
||||
eval_train=False,
|
||||
eval_test=False,
|
||||
pretrain=False,
|
||||
init_state=None,
|
||||
reset_router=False,
|
||||
freeze_model=False,
|
||||
freeze_predictors=False,
|
||||
transport_method="none",
|
||||
memory_mode="sample",
|
||||
):
|
||||
|
||||
self.logger = get_module_logger("TRA")
|
||||
|
||||
assert memory_mode in ["sample", "daily"], "invalid memory mode"
|
||||
assert transport_method in ["none", "router", "oracle"], f"invalid transport method {transport_method}"
|
||||
assert transport_method == "none" or tra_config["num_states"] > 1, "optimal transport requires `num_states` > 1"
|
||||
assert (
|
||||
memory_mode != "daily" or tra_config["src_info"] == "TPE"
|
||||
), "daily transport can only support TPE as `src_info`"
|
||||
|
||||
if transport_method == "router" and not eval_train:
|
||||
self.logger.warning("`eval_train` will be ignored when using TRA.router")
|
||||
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
|
||||
self.model_config = model_config
|
||||
self.tra_config = tra_config
|
||||
self.model_type = model_type
|
||||
self.lr = lr
|
||||
self.n_epochs = n_epochs
|
||||
self.early_stop = early_stop
|
||||
self.update_freq = update_freq
|
||||
self.max_steps_per_epoch = max_steps_per_epoch
|
||||
self.lamb = lamb
|
||||
self.rho = rho
|
||||
self.alpha = alpha
|
||||
self.seed = seed
|
||||
self.logdir = logdir
|
||||
self.eval_train = eval_train
|
||||
self.eval_test = eval_test
|
||||
self.pretrain = pretrain
|
||||
self.init_state = init_state
|
||||
self.reset_router = reset_router
|
||||
self.freeze_model = freeze_model
|
||||
self.freeze_predictors = freeze_predictors
|
||||
self.transport_method = transport_method
|
||||
self.use_daily_transport = memory_mode == "daily"
|
||||
self.transport_fn = transport_daily if self.use_daily_transport else transport_sample
|
||||
|
||||
self._writer = None
|
||||
if self.logdir is not None:
|
||||
if os.path.exists(self.logdir):
|
||||
self.logger.warning(f"logdir {self.logdir} is not empty")
|
||||
os.makedirs(self.logdir, exist_ok=True)
|
||||
if SummaryWriter is not None:
|
||||
self._writer = SummaryWriter(log_dir=self.logdir)
|
||||
|
||||
self._init_model()
|
||||
|
||||
def _init_model(self):
|
||||
|
||||
self.logger.info("init TRAModel...")
|
||||
|
||||
self.model = eval(self.model_type)(**self.model_config).to(device)
|
||||
print(self.model)
|
||||
|
||||
self.tra = TRA(self.model.output_size, **self.tra_config).to(device)
|
||||
print(self.tra)
|
||||
|
||||
if self.init_state:
|
||||
self.logger.warning(f"load state dict from `init_state`")
|
||||
state_dict = torch.load(self.init_state, map_location="cpu")
|
||||
self.model.load_state_dict(state_dict["model"])
|
||||
res = load_state_dict_unsafe(self.tra, state_dict["tra"])
|
||||
self.logger.warning(str(res))
|
||||
|
||||
if self.reset_router:
|
||||
self.logger.warning(f"reset TRA.router parameters")
|
||||
self.tra.fc.reset_parameters()
|
||||
self.tra.router.reset_parameters()
|
||||
|
||||
if self.freeze_model:
|
||||
self.logger.warning(f"freeze model parameters")
|
||||
for param in self.model.parameters():
|
||||
param.requires_grad_(False)
|
||||
|
||||
if self.freeze_predictors:
|
||||
self.logger.warning(f"freeze TRA.predictors parameters")
|
||||
for param in self.tra.predictors.parameters():
|
||||
param.requires_grad_(False)
|
||||
|
||||
self.logger.info("# model params: %d" % sum([p.numel() for p in self.model.parameters() if p.requires_grad]))
|
||||
self.logger.info("# tra params: %d" % sum([p.numel() for p in self.tra.parameters() if p.requires_grad]))
|
||||
|
||||
self.optimizer = optim.Adam(list(self.model.parameters()) + list(self.tra.parameters()), lr=self.lr)
|
||||
|
||||
self.fitted = False
|
||||
self.global_step = -1
|
||||
|
||||
def train_epoch(self, epoch, data_set, is_pretrain=False):
|
||||
|
||||
self.model.train()
|
||||
self.tra.train()
|
||||
data_set.train()
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
P_all = []
|
||||
prob_all = []
|
||||
choice_all = []
|
||||
max_steps = len(data_set)
|
||||
if self.max_steps_per_epoch is not None:
|
||||
if epoch == 0 and self.max_steps_per_epoch < max_steps:
|
||||
self.logger.info(f"max steps updated from {max_steps} to {self.max_steps_per_epoch}")
|
||||
max_steps = min(self.max_steps_per_epoch, max_steps)
|
||||
|
||||
cur_step = 0
|
||||
total_loss = 0
|
||||
total_count = 0
|
||||
for batch in tqdm(data_set, total=max_steps):
|
||||
cur_step += 1
|
||||
if cur_step > max_steps:
|
||||
break
|
||||
|
||||
if not is_pretrain:
|
||||
self.global_step += 1
|
||||
|
||||
data, state, label, count = batch["data"], batch["state"], batch["label"], batch["daily_count"]
|
||||
index = batch["daily_index"] if self.use_daily_transport else batch["index"]
|
||||
|
||||
with torch.set_grad_enabled(not self.freeze_model):
|
||||
hidden = self.model(data)
|
||||
|
||||
all_preds, choice, prob = self.tra(hidden, state)
|
||||
|
||||
if is_pretrain or self.transport_method != "none":
|
||||
# NOTE: use oracle transport for pre-training
|
||||
loss, pred, L, P = self.transport_fn(
|
||||
all_preds,
|
||||
label,
|
||||
choice,
|
||||
prob,
|
||||
state.mean(dim=1),
|
||||
count,
|
||||
self.transport_method if not is_pretrain else "oracle",
|
||||
self.alpha,
|
||||
training=True,
|
||||
)
|
||||
data_set.assign_data(index, L) # save loss to memory
|
||||
if self.use_daily_transport: # only save for daily transport
|
||||
P_all.append(pd.DataFrame(P.detach().cpu().numpy(), index=index))
|
||||
prob_all.append(pd.DataFrame(prob.detach().cpu().numpy(), index=index))
|
||||
choice_all.append(pd.DataFrame(choice.detach().cpu().numpy(), index=index))
|
||||
decay = self.rho ** (self.global_step // 100) # decay every 100 steps
|
||||
lamb = 0 if is_pretrain else self.lamb * decay
|
||||
reg = prob.log().mul(P).sum(dim=1).mean() # train router to predict OT assignment
|
||||
if self._writer is not None and not is_pretrain:
|
||||
self._writer.add_scalar("training/router_loss", -reg.item(), self.global_step)
|
||||
self._writer.add_scalar("training/reg_loss", loss.item(), self.global_step)
|
||||
self._writer.add_scalar("training/lamb", lamb, self.global_step)
|
||||
if not self.use_daily_transport:
|
||||
P_mean = P.mean(axis=0).detach()
|
||||
self._writer.add_scalar("training/P", P_mean.max() / P_mean.min(), self.global_step)
|
||||
loss = loss - lamb * reg
|
||||
else:
|
||||
pred = all_preds.mean(dim=1)
|
||||
loss = loss_fn(pred, label)
|
||||
|
||||
(loss / self.update_freq).backward()
|
||||
if cur_step % self.update_freq == 0:
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
if self._writer is not None and not is_pretrain:
|
||||
self._writer.add_scalar("training/total_loss", loss.item(), self.global_step)
|
||||
|
||||
total_loss += loss.item()
|
||||
total_count += 1
|
||||
|
||||
if self.use_daily_transport and len(P_all):
|
||||
P_all = pd.concat(P_all, axis=0)
|
||||
prob_all = pd.concat(prob_all, axis=0)
|
||||
choice_all = pd.concat(choice_all, axis=0)
|
||||
P_all.index = data_set.restore_daily_index(P_all.index)
|
||||
prob_all.index = P_all.index
|
||||
choice_all.index = P_all.index
|
||||
if not is_pretrain:
|
||||
self._writer.add_image("P", plot(P_all), epoch, dataformats="HWC")
|
||||
self._writer.add_image("prob", plot(prob_all), epoch, dataformats="HWC")
|
||||
self._writer.add_image("choice", plot(choice_all), epoch, dataformats="HWC")
|
||||
|
||||
total_loss /= total_count
|
||||
|
||||
if self._writer is not None and not is_pretrain:
|
||||
self._writer.add_scalar("training/loss", total_loss, epoch)
|
||||
|
||||
return total_loss
|
||||
|
||||
def test_epoch(self, epoch, data_set, return_pred=False, prefix="test", is_pretrain=False):
|
||||
|
||||
self.model.eval()
|
||||
self.tra.eval()
|
||||
data_set.eval()
|
||||
|
||||
preds = []
|
||||
probs = []
|
||||
P_all = []
|
||||
metrics = []
|
||||
for batch in tqdm(data_set):
|
||||
data, state, label, count = batch["data"], batch["state"], batch["label"], batch["daily_count"]
|
||||
index = batch["daily_index"] if self.use_daily_transport else batch["index"]
|
||||
|
||||
with torch.no_grad():
|
||||
hidden = self.model(data)
|
||||
all_preds, choice, prob = self.tra(hidden, state)
|
||||
|
||||
if is_pretrain or self.transport_method != "none":
|
||||
loss, pred, L, P = self.transport_fn(
|
||||
all_preds,
|
||||
label,
|
||||
choice,
|
||||
prob,
|
||||
state.mean(dim=1),
|
||||
count,
|
||||
self.transport_method if not is_pretrain else "oracle",
|
||||
self.alpha,
|
||||
training=False,
|
||||
)
|
||||
data_set.assign_data(index, L) # save loss to memory
|
||||
if P is not None and return_pred:
|
||||
P_all.append(pd.DataFrame(P.cpu().numpy(), index=index))
|
||||
else:
|
||||
pred = all_preds.mean(dim=1)
|
||||
|
||||
X = np.c_[pred.cpu().numpy(), label.cpu().numpy(), all_preds.cpu().numpy()]
|
||||
columns = ["score", "label"] + ["score_%d" % d for d in range(all_preds.shape[1])]
|
||||
pred = pd.DataFrame(X, index=batch["index"], columns=columns)
|
||||
|
||||
metrics.append(evaluate(pred))
|
||||
|
||||
if return_pred:
|
||||
preds.append(pred)
|
||||
if prob is not None:
|
||||
columns = ["prob_%d" % d for d in range(all_preds.shape[1])]
|
||||
probs.append(pd.DataFrame(prob.cpu().numpy(), index=index, columns=columns))
|
||||
|
||||
metrics = pd.DataFrame(metrics)
|
||||
metrics = {
|
||||
"MSE": metrics.MSE.mean(),
|
||||
"MAE": metrics.MAE.mean(),
|
||||
"IC": metrics.IC.mean(),
|
||||
"ICIR": metrics.IC.mean() / metrics.IC.std(),
|
||||
}
|
||||
|
||||
if self._writer is not None and epoch >= 0 and not is_pretrain:
|
||||
for key, value in metrics.items():
|
||||
self._writer.add_scalar(prefix + "/" + key, value, epoch)
|
||||
|
||||
if return_pred:
|
||||
preds = pd.concat(preds, axis=0)
|
||||
preds.index = data_set.restore_index(preds.index)
|
||||
preds.index = preds.index.swaplevel()
|
||||
preds.sort_index(inplace=True)
|
||||
|
||||
if probs:
|
||||
probs = pd.concat(probs, axis=0)
|
||||
if self.use_daily_transport:
|
||||
probs.index = data_set.restore_daily_index(probs.index)
|
||||
else:
|
||||
probs.index = data_set.restore_index(probs.index)
|
||||
probs.index = probs.index.swaplevel()
|
||||
probs.sort_index(inplace=True)
|
||||
|
||||
if len(P_all):
|
||||
P_all = pd.concat(P_all, axis=0)
|
||||
if self.use_daily_transport:
|
||||
P_all.index = data_set.restore_daily_index(P_all.index)
|
||||
else:
|
||||
P_all.index = data_set.restore_index(P_all.index)
|
||||
P_all.index = P_all.index.swaplevel()
|
||||
P_all.sort_index(inplace=True)
|
||||
|
||||
return metrics, preds, probs, P_all
|
||||
|
||||
def _fit(self, train_set, valid_set, test_set, evals_result, is_pretrain=True):
|
||||
|
||||
best_score = -1
|
||||
best_epoch = 0
|
||||
stop_rounds = 0
|
||||
best_params = {
|
||||
"model": copy.deepcopy(self.model.state_dict()),
|
||||
"tra": copy.deepcopy(self.tra.state_dict()),
|
||||
}
|
||||
# train
|
||||
if not is_pretrain and self.transport_method != "none":
|
||||
self.logger.info("init memory...")
|
||||
self.test_epoch(-1, train_set)
|
||||
|
||||
for epoch in range(self.n_epochs):
|
||||
self.logger.info("Epoch %d:", epoch)
|
||||
|
||||
self.logger.info("training...")
|
||||
self.train_epoch(epoch, train_set, is_pretrain=is_pretrain)
|
||||
|
||||
self.logger.info("evaluating...")
|
||||
# NOTE: during evaluating, the whole memory will be refreshed
|
||||
if not is_pretrain and (self.transport_method == "router" or self.eval_train):
|
||||
train_set.clear_memory() # NOTE: clear the shared memory
|
||||
train_metrics = self.test_epoch(epoch, train_set, is_pretrain=is_pretrain, prefix="train")[0]
|
||||
evals_result["train"].append(train_metrics)
|
||||
self.logger.info("train metrics: %s" % train_metrics)
|
||||
|
||||
valid_metrics = self.test_epoch(epoch, valid_set, is_pretrain=is_pretrain, prefix="valid")[0]
|
||||
evals_result["valid"].append(valid_metrics)
|
||||
self.logger.info("valid metrics: %s" % valid_metrics)
|
||||
|
||||
if self.eval_test:
|
||||
test_metrics = self.test_epoch(epoch, test_set, is_pretrain=is_pretrain, prefix="test")[0]
|
||||
evals_result["test"].append(test_metrics)
|
||||
self.logger.info("test metrics: %s" % test_metrics)
|
||||
|
||||
if valid_metrics["IC"] > best_score:
|
||||
best_score = valid_metrics["IC"]
|
||||
stop_rounds = 0
|
||||
best_epoch = epoch
|
||||
best_params = {
|
||||
"model": copy.deepcopy(self.model.state_dict()),
|
||||
"tra": copy.deepcopy(self.tra.state_dict()),
|
||||
}
|
||||
if self.logdir is not None:
|
||||
torch.save(best_params, self.logdir + "/model.bin")
|
||||
else:
|
||||
stop_rounds += 1
|
||||
if stop_rounds >= self.early_stop:
|
||||
self.logger.info("early stop @ %s" % epoch)
|
||||
break
|
||||
|
||||
self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
|
||||
self.model.load_state_dict(best_params["model"])
|
||||
self.tra.load_state_dict(best_params["tra"])
|
||||
|
||||
return best_score
|
||||
|
||||
def fit(self, dataset, evals_result=dict()):
|
||||
|
||||
assert isinstance(dataset, MTSDatasetH), "TRAModel only supports `qlib.contrib.data.dataset.MTSDatasetH`"
|
||||
|
||||
train_set, valid_set, test_set = dataset.prepare(["train", "valid", "test"])
|
||||
|
||||
self.fitted = True
|
||||
self.global_step = -1
|
||||
|
||||
evals_result["train"] = []
|
||||
evals_result["valid"] = []
|
||||
evals_result["test"] = []
|
||||
|
||||
if self.pretrain:
|
||||
self.logger.info("pretraining...")
|
||||
self.optimizer = optim.Adam(
|
||||
list(self.model.parameters()) + list(self.tra.predictors.parameters()), lr=self.lr
|
||||
)
|
||||
self._fit(train_set, valid_set, test_set, evals_result, is_pretrain=True)
|
||||
|
||||
# reset optimizer
|
||||
self.optimizer = optim.Adam(list(self.model.parameters()) + list(self.tra.parameters()), lr=self.lr)
|
||||
|
||||
self.logger.info("training...")
|
||||
best_score = self._fit(train_set, valid_set, test_set, evals_result, is_pretrain=False)
|
||||
|
||||
self.logger.info("inference")
|
||||
train_metrics, train_preds, train_probs, train_P = self.test_epoch(-1, train_set, return_pred=True)
|
||||
self.logger.info("train metrics: %s" % train_metrics)
|
||||
|
||||
valid_metrics, valid_preds, valid_probs, valid_P = self.test_epoch(-1, valid_set, return_pred=True)
|
||||
self.logger.info("valid metrics: %s" % valid_metrics)
|
||||
|
||||
test_metrics, test_preds, test_probs, test_P = self.test_epoch(-1, test_set, return_pred=True)
|
||||
self.logger.info("test metrics: %s" % test_metrics)
|
||||
|
||||
if self.logdir:
|
||||
self.logger.info("save model & pred to local directory")
|
||||
|
||||
pd.concat({name: pd.DataFrame(evals_result[name]) for name in evals_result}, axis=1).to_csv(
|
||||
self.logdir + "/logs.csv", index=False
|
||||
)
|
||||
|
||||
torch.save({"model": self.model.state_dict(), "tra": self.tra.state_dict()}, self.logdir + "/model.bin")
|
||||
|
||||
train_preds.to_pickle(self.logdir + "/train_pred.pkl")
|
||||
valid_preds.to_pickle(self.logdir + "/valid_pred.pkl")
|
||||
test_preds.to_pickle(self.logdir + "/test_pred.pkl")
|
||||
|
||||
if len(train_probs):
|
||||
train_probs.to_pickle(self.logdir + "/train_prob.pkl")
|
||||
valid_probs.to_pickle(self.logdir + "/valid_prob.pkl")
|
||||
test_probs.to_pickle(self.logdir + "/test_prob.pkl")
|
||||
|
||||
if len(train_P):
|
||||
train_P.to_pickle(self.logdir + "/train_P.pkl")
|
||||
valid_P.to_pickle(self.logdir + "/valid_P.pkl")
|
||||
test_P.to_pickle(self.logdir + "/test_P.pkl")
|
||||
|
||||
info = {
|
||||
"config": {
|
||||
"model_config": self.model_config,
|
||||
"tra_config": self.tra_config,
|
||||
"model_type": self.model_type,
|
||||
"lr": self.lr,
|
||||
"n_epochs": self.n_epochs,
|
||||
"early_stop": self.early_stop,
|
||||
"max_steps_per_epoch": self.max_steps_per_epoch,
|
||||
"lamb": self.lamb,
|
||||
"rho": self.rho,
|
||||
"alpha": self.alpha,
|
||||
"seed": self.seed,
|
||||
"logdir": self.logdir,
|
||||
"pretrain": self.pretrain,
|
||||
"init_state": self.init_state,
|
||||
"transport_method": self.transport_method,
|
||||
"use_daily_transport": self.use_daily_transport,
|
||||
},
|
||||
"best_eval_metric": -best_score, # NOTE: -1 for minimize
|
||||
"metrics": {"train": train_metrics, "valid": valid_metrics, "test": test_metrics},
|
||||
}
|
||||
with open(self.logdir + "/info.json", "w") as f:
|
||||
json.dump(info, f)
|
||||
|
||||
def predict(self, dataset, segment="test"):
|
||||
|
||||
assert isinstance(dataset, MTSDatasetH), "TRAModel only supports `qlib.contrib.data.dataset.MTSDatasetH`"
|
||||
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
test_set = dataset.prepare(segment)
|
||||
|
||||
metrics, preds, _, _ = self.test_epoch(-1, test_set, return_pred=True)
|
||||
self.logger.info("test metrics: %s" % metrics)
|
||||
|
||||
return preds
|
||||
|
||||
|
||||
class RNN(nn.Module):
|
||||
|
||||
"""RNN Model
|
||||
|
||||
Args:
|
||||
input_size (int): input size (# features)
|
||||
hidden_size (int): hidden size
|
||||
num_layers (int): number of hidden layers
|
||||
rnn_arch (str): rnn architecture
|
||||
use_attn (bool): whether use attention layer.
|
||||
we use concat attention as https://github.com/fulifeng/Adv-ALSTM/
|
||||
dropout (float): dropout rate
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size=16,
|
||||
hidden_size=64,
|
||||
num_layers=2,
|
||||
rnn_arch="GRU",
|
||||
use_attn=True,
|
||||
dropout=0.0,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.input_size = input_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_layers = num_layers
|
||||
self.rnn_arch = rnn_arch
|
||||
self.use_attn = use_attn
|
||||
|
||||
if hidden_size < input_size:
|
||||
# compression
|
||||
self.input_proj = nn.Linear(input_size, hidden_size)
|
||||
else:
|
||||
self.input_proj = None
|
||||
|
||||
self.rnn = getattr(nn, rnn_arch)(
|
||||
input_size=min(input_size, hidden_size),
|
||||
hidden_size=hidden_size,
|
||||
num_layers=num_layers,
|
||||
batch_first=True,
|
||||
dropout=dropout,
|
||||
)
|
||||
|
||||
if self.use_attn:
|
||||
self.W = nn.Linear(hidden_size, hidden_size)
|
||||
self.u = nn.Linear(hidden_size, 1, bias=False)
|
||||
self.output_size = hidden_size * 2
|
||||
else:
|
||||
self.output_size = hidden_size
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
if self.input_proj is not None:
|
||||
x = self.input_proj(x)
|
||||
|
||||
rnn_out, last_out = self.rnn(x)
|
||||
if self.rnn_arch == "LSTM":
|
||||
last_out = last_out[0]
|
||||
last_out = last_out.mean(dim=0)
|
||||
|
||||
if self.use_attn:
|
||||
laten = self.W(rnn_out).tanh()
|
||||
scores = self.u(laten).softmax(dim=1)
|
||||
att_out = (rnn_out * scores).sum(dim=1)
|
||||
last_out = torch.cat([last_out, att_out], dim=1)
|
||||
|
||||
return last_out
|
||||
|
||||
|
||||
class PositionalEncoding(nn.Module):
|
||||
# reference: https://pytorch.org/tutorials/beginner/transformer_tutorial.html
|
||||
def __init__(self, d_model, dropout=0.1, max_len=5000):
|
||||
super(PositionalEncoding, self).__init__()
|
||||
self.dropout = nn.Dropout(p=dropout)
|
||||
|
||||
pe = torch.zeros(max_len, d_model)
|
||||
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
|
||||
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
|
||||
pe[:, 0::2] = torch.sin(position * div_term)
|
||||
pe[:, 1::2] = torch.cos(position * div_term)
|
||||
pe = pe.unsqueeze(0).transpose(0, 1)
|
||||
self.register_buffer("pe", pe)
|
||||
|
||||
def forward(self, x):
|
||||
x = x + self.pe[: x.size(0), :]
|
||||
return self.dropout(x)
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
|
||||
"""Transformer Model
|
||||
|
||||
Args:
|
||||
input_size (int): input size (# features)
|
||||
hidden_size (int): hidden size
|
||||
num_layers (int): number of transformer layers
|
||||
num_heads (int): number of heads in transformer
|
||||
dropout (float): dropout rate
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size=16,
|
||||
hidden_size=64,
|
||||
num_layers=2,
|
||||
num_heads=2,
|
||||
dropout=0.0,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.input_size = input_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_layers = num_layers
|
||||
self.num_heads = num_heads
|
||||
|
||||
self.input_proj = nn.Linear(input_size, hidden_size)
|
||||
|
||||
self.pe = PositionalEncoding(input_size, dropout)
|
||||
layer = nn.TransformerEncoderLayer(
|
||||
nhead=num_heads, dropout=dropout, d_model=hidden_size, dim_feedforward=hidden_size * 4
|
||||
)
|
||||
self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)
|
||||
|
||||
self.output_size = hidden_size
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
x = x.permute(1, 0, 2).contiguous() # the first dim need to be time
|
||||
x = self.pe(x)
|
||||
|
||||
x = self.input_proj(x)
|
||||
out = self.encoder(x)
|
||||
|
||||
return out[-1]
|
||||
|
||||
|
||||
class TRA(nn.Module):
|
||||
|
||||
"""Temporal Routing Adaptor (TRA)
|
||||
|
||||
TRA takes historical prediction erros & latent representation as inputs,
|
||||
then routes the input sample to a specific predictor for training & inference.
|
||||
|
||||
Args:
|
||||
input_size (int): input size (RNN/Transformer's hidden size)
|
||||
num_states (int): number of latent states (i.e., trading patterns)
|
||||
If `num_states=1`, then TRA falls back to traditional methods
|
||||
hidden_size (int): hidden size of the router
|
||||
tau (float): gumbel softmax temperature
|
||||
src_info (str): information for the router
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size,
|
||||
num_states=1,
|
||||
hidden_size=8,
|
||||
rnn_arch="GRU",
|
||||
num_layers=1,
|
||||
dropout=0.0,
|
||||
tau=1.0,
|
||||
src_info="LR_TPE",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
assert src_info in ["LR", "TPE", "LR_TPE"], "invalid `src_info`"
|
||||
|
||||
self.num_states = num_states
|
||||
self.tau = tau
|
||||
self.rnn_arch = rnn_arch
|
||||
self.src_info = src_info
|
||||
|
||||
self.predictors = nn.Linear(input_size, num_states)
|
||||
|
||||
if self.num_states > 1:
|
||||
if "TPE" in src_info:
|
||||
self.router = getattr(nn, rnn_arch)(
|
||||
input_size=num_states,
|
||||
hidden_size=hidden_size,
|
||||
num_layers=num_layers,
|
||||
batch_first=True,
|
||||
dropout=dropout,
|
||||
)
|
||||
self.fc = nn.Linear(hidden_size + input_size if "LR" in src_info else hidden_size, num_states)
|
||||
else:
|
||||
self.fc = nn.Linear(input_size, num_states)
|
||||
|
||||
def reset_parameters(self):
|
||||
for child in self.children():
|
||||
child.reset_parameters()
|
||||
|
||||
def forward(self, hidden, hist_loss):
|
||||
|
||||
preds = self.predictors(hidden)
|
||||
|
||||
if self.num_states == 1: # no need for router when having only one prediction
|
||||
return preds, None, None
|
||||
|
||||
if "TPE" in self.src_info:
|
||||
out = self.router(hist_loss)[1] # TPE
|
||||
if self.rnn_arch == "LSTM":
|
||||
out = out[0]
|
||||
out = out.mean(dim=0)
|
||||
if "LR" in self.src_info:
|
||||
out = torch.cat([hidden, out], dim=-1) # LR_TPE
|
||||
else:
|
||||
out = hidden # LR
|
||||
|
||||
out = self.fc(out)
|
||||
|
||||
choice = F.gumbel_softmax(out, dim=-1, tau=self.tau, hard=True)
|
||||
prob = torch.softmax(out / self.tau, dim=-1)
|
||||
|
||||
return preds, choice, prob
|
||||
|
||||
|
||||
def evaluate(pred):
|
||||
pred = pred.rank(pct=True) # transform into percentiles
|
||||
score = pred.score
|
||||
label = pred.label
|
||||
diff = score - label
|
||||
MSE = (diff ** 2).mean()
|
||||
MAE = (diff.abs()).mean()
|
||||
IC = score.corr(label, method="spearman")
|
||||
return {"MSE": MSE, "MAE": MAE, "IC": IC}
|
||||
|
||||
|
||||
def shoot_infs(inp_tensor):
|
||||
"""Replaces inf by maximum of tensor"""
|
||||
mask_inf = torch.isinf(inp_tensor)
|
||||
ind_inf = torch.nonzero(mask_inf, as_tuple=False)
|
||||
if len(ind_inf) > 0:
|
||||
for ind in ind_inf:
|
||||
if len(ind) == 2:
|
||||
inp_tensor[ind[0], ind[1]] = 0
|
||||
elif len(ind) == 1:
|
||||
inp_tensor[ind[0]] = 0
|
||||
m = torch.max(inp_tensor)
|
||||
for ind in ind_inf:
|
||||
if len(ind) == 2:
|
||||
inp_tensor[ind[0], ind[1]] = m
|
||||
elif len(ind) == 1:
|
||||
inp_tensor[ind[0]] = m
|
||||
return inp_tensor
|
||||
|
||||
|
||||
def sinkhorn(Q, n_iters=3, epsilon=0.1):
|
||||
# epsilon should be adjusted according to logits value's scale
|
||||
with torch.no_grad():
|
||||
Q = torch.exp(Q / epsilon)
|
||||
Q = shoot_infs(Q)
|
||||
for i in range(n_iters):
|
||||
Q /= Q.sum(dim=0, keepdim=True)
|
||||
Q /= Q.sum(dim=1, keepdim=True)
|
||||
return Q
|
||||
|
||||
|
||||
def loss_fn(pred, label):
|
||||
mask = ~torch.isnan(label)
|
||||
if len(pred.shape) == 2:
|
||||
label = label[:, None]
|
||||
return (pred[mask] - label[mask]).pow(2).mean(dim=0)
|
||||
|
||||
|
||||
def minmax_norm(x):
|
||||
xmin = x.min(dim=-1, keepdim=True).values
|
||||
xmax = x.max(dim=-1, keepdim=True).values
|
||||
mask = (xmin == xmax).squeeze()
|
||||
x = (x - xmin) / (xmax - xmin + 1e-12)
|
||||
x[mask] = 1
|
||||
return x
|
||||
|
||||
|
||||
def transport_sample(all_preds, label, choice, prob, hist_loss, count, transport_method, alpha, training=False):
|
||||
"""
|
||||
sample-wise transport
|
||||
|
||||
Args:
|
||||
all_preds (torch.Tensor): predictions from all predictors, [sample x states]
|
||||
label (torch.Tensor): label, [sample]
|
||||
choice (torch.Tensor): gumbel softmax choice, [sample x states]
|
||||
prob (torch.Tensor): router predicted probility, [sample x states]
|
||||
hist_loss (torch.Tensor): history loss matrix, [sample x states]
|
||||
count (list): sample counts for each day, empty list for sample-wise transport
|
||||
transport_method (str): transportation method
|
||||
alpha (float): fusion parameter for calculating transport loss matrix
|
||||
training (bool): indicate training or inference
|
||||
"""
|
||||
assert all_preds.shape == choice.shape
|
||||
assert len(all_preds) == len(label)
|
||||
assert transport_method in ["oracle", "router"]
|
||||
|
||||
all_loss = torch.zeros_like(all_preds)
|
||||
mask = ~torch.isnan(label)
|
||||
all_loss[mask] = (all_preds[mask] - label[mask, None]).pow(2) # [sample x states]
|
||||
|
||||
L = minmax_norm(all_loss.detach())
|
||||
Lh = L * alpha + minmax_norm(hist_loss) * (1 - alpha) # add hist loss for transport
|
||||
Lh = minmax_norm(Lh)
|
||||
P = sinkhorn(-Lh)
|
||||
del Lh
|
||||
|
||||
if transport_method == "router":
|
||||
if training:
|
||||
pred = (all_preds * choice).sum(dim=1) # gumbel softmax
|
||||
else:
|
||||
pred = all_preds[range(len(all_preds)), prob.argmax(dim=-1)] # argmax
|
||||
else:
|
||||
pred = (all_preds * P).sum(dim=1)
|
||||
|
||||
if transport_method == "router":
|
||||
loss = loss_fn(pred, label)
|
||||
else:
|
||||
loss = (all_loss * P).sum(dim=1).mean()
|
||||
|
||||
return loss, pred, L, P
|
||||
|
||||
|
||||
def transport_daily(all_preds, label, choice, prob, hist_loss, count, transport_method, alpha, training=False):
|
||||
"""
|
||||
daily transport
|
||||
|
||||
Args:
|
||||
all_preds (torch.Tensor): predictions from all predictors, [sample x states]
|
||||
label (torch.Tensor): label, [sample]
|
||||
choice (torch.Tensor): gumbel softmax choice, [days x states]
|
||||
prob (torch.Tensor): router predicted probility, [days x states]
|
||||
hist_loss (torch.Tensor): history loss matrix, [days x states]
|
||||
count (list): sample counts for each day, [days]
|
||||
transport_method (str): transportation method
|
||||
alpha (float): fusion parameter for calculating transport loss matrix
|
||||
training (bool): indicate training or inference
|
||||
"""
|
||||
assert len(prob) == len(count)
|
||||
assert len(all_preds) == sum(count)
|
||||
assert transport_method in ["oracle", "router"]
|
||||
|
||||
all_loss = [] # loss of all predictions
|
||||
start = 0
|
||||
for i, cnt in enumerate(count):
|
||||
slc = slice(start, start + cnt) # samples from the i-th day
|
||||
start += cnt
|
||||
tloss = loss_fn(all_preds[slc], label[slc]) # loss of the i-th day
|
||||
all_loss.append(tloss)
|
||||
all_loss = torch.stack(all_loss, dim=0) # [days x states]
|
||||
|
||||
L = minmax_norm(all_loss.detach())
|
||||
Lh = L * alpha + minmax_norm(hist_loss) * (1 - alpha) # add hist loss for transport
|
||||
Lh = minmax_norm(Lh)
|
||||
P = sinkhorn(-Lh)
|
||||
del Lh
|
||||
|
||||
pred = []
|
||||
start = 0
|
||||
for i, cnt in enumerate(count):
|
||||
slc = slice(start, start + cnt) # samples from the i-th day
|
||||
start += cnt
|
||||
if transport_method == "router":
|
||||
if training:
|
||||
tpred = all_preds[slc] @ choice[i] # gumbel softmax
|
||||
else:
|
||||
tpred = all_preds[slc][:, prob[i].argmax(dim=-1)] # argmax
|
||||
else:
|
||||
tpred = all_preds[slc] @ P[i]
|
||||
pred.append(tpred)
|
||||
pred = torch.cat(pred, dim=0) # [samples]
|
||||
|
||||
if transport_method == "router":
|
||||
loss = loss_fn(pred, label)
|
||||
else:
|
||||
loss = (all_loss * P).sum(dim=1).mean()
|
||||
|
||||
return loss, pred, L, P
|
||||
|
||||
|
||||
def load_state_dict_unsafe(model, state_dict):
|
||||
"""
|
||||
Load state dict to provided model while ignore exceptions.
|
||||
"""
|
||||
|
||||
missing_keys = []
|
||||
unexpected_keys = []
|
||||
error_msgs = []
|
||||
|
||||
# copy state_dict so _load_from_state_dict can modify it
|
||||
metadata = getattr(state_dict, "_metadata", None)
|
||||
state_dict = state_dict.copy()
|
||||
if metadata is not None:
|
||||
state_dict._metadata = metadata
|
||||
|
||||
def load(module, prefix=""):
|
||||
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
|
||||
module._load_from_state_dict(
|
||||
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs
|
||||
)
|
||||
for name, child in module._modules.items():
|
||||
if child is not None:
|
||||
load(child, prefix + name + ".")
|
||||
|
||||
load(model)
|
||||
load = None # break load->load reference cycle
|
||||
|
||||
return {"unexpected_keys": unexpected_keys, "missing_keys": missing_keys, "error_msgs": error_msgs}
|
||||
|
||||
|
||||
def plot(P):
|
||||
assert isinstance(P, pd.DataFrame)
|
||||
|
||||
fig, axes = plt.subplots(1, 2, figsize=(10, 4))
|
||||
P.plot.area(ax=axes[0], xlabel="")
|
||||
P.idxmax(axis=1).value_counts().sort_index().plot.bar(ax=axes[1], xlabel="")
|
||||
plt.tight_layout()
|
||||
|
||||
with io.BytesIO() as buf:
|
||||
plt.savefig(buf, format="png")
|
||||
buf.seek(0)
|
||||
img = plt.imread(buf)
|
||||
plt.close()
|
||||
|
||||
return np.uint8(img * 255)
|
||||
294
qlib/contrib/model/pytorch_transformer.py
Normal file
294
qlib/contrib/model/pytorch_transformer.py
Normal file
@@ -0,0 +1,294 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Text, Union
|
||||
import copy
|
||||
import math
|
||||
from ...utils import get_or_create_path
|
||||
from ...log import get_module_logger
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from .pytorch_utils import count_parameters
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH, TSDatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
|
||||
# qrun examples/benchmarks/Transformer/workflow_config_transformer_Alpha360.yaml ”
|
||||
|
||||
|
||||
class TransformerModel(Model):
|
||||
def __init__(
|
||||
self,
|
||||
d_feat: int = 20,
|
||||
d_model: int = 64,
|
||||
batch_size: int = 2048,
|
||||
nhead: int = 2,
|
||||
num_layers: int = 2,
|
||||
dropout: float = 0,
|
||||
n_epochs=100,
|
||||
lr=0.0001,
|
||||
metric="",
|
||||
early_stop=5,
|
||||
loss="mse",
|
||||
optimizer="adam",
|
||||
reg=1e-3,
|
||||
n_jobs=10,
|
||||
GPU=0,
|
||||
seed=None,
|
||||
**kwargs
|
||||
):
|
||||
|
||||
# set hyper-parameters.
|
||||
self.d_model = d_model
|
||||
self.dropout = dropout
|
||||
self.n_epochs = n_epochs
|
||||
self.lr = lr
|
||||
self.reg = reg
|
||||
self.metric = metric
|
||||
self.batch_size = batch_size
|
||||
self.early_stop = early_stop
|
||||
self.optimizer = optimizer.lower()
|
||||
self.loss = loss
|
||||
self.n_jobs = n_jobs
|
||||
self.device = torch.device("cuda:%d" % GPU if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.seed = seed
|
||||
self.logger = get_module_logger("TransformerModel")
|
||||
self.logger.info("Naive Transformer:" "\nbatch_size : {}" "\ndevice : {}".format(self.batch_size, self.device))
|
||||
|
||||
if self.seed is not None:
|
||||
np.random.seed(self.seed)
|
||||
torch.manual_seed(self.seed)
|
||||
|
||||
self.model = Transformer(d_feat, d_model, nhead, num_layers, dropout, self.device)
|
||||
if optimizer.lower() == "adam":
|
||||
self.train_optimizer = optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=self.reg)
|
||||
elif optimizer.lower() == "gd":
|
||||
self.train_optimizer = optim.SGD(self.model.parameters(), lr=self.lr, weight_decay=self.reg)
|
||||
else:
|
||||
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
|
||||
|
||||
self.fitted = False
|
||||
self.model.to(self.device)
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
return self.device != torch.device("cpu")
|
||||
|
||||
def mse(self, pred, label):
|
||||
loss = (pred.float() - label.float()) ** 2
|
||||
return torch.mean(loss)
|
||||
|
||||
def loss_fn(self, pred, label):
|
||||
mask = ~torch.isnan(label)
|
||||
|
||||
if self.loss == "mse":
|
||||
return self.mse(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown loss `%s`" % self.loss)
|
||||
|
||||
def metric_fn(self, pred, label):
|
||||
|
||||
mask = torch.isfinite(label)
|
||||
|
||||
if self.metric == "" or self.metric == "loss":
|
||||
return -self.loss_fn(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown metric `%s`" % self.metric)
|
||||
|
||||
def train_epoch(self, x_train, y_train):
|
||||
|
||||
x_train_values = x_train.values
|
||||
y_train_values = np.squeeze(y_train.values)
|
||||
|
||||
self.model.train()
|
||||
|
||||
indices = np.arange(len(x_train_values))
|
||||
np.random.shuffle(indices)
|
||||
|
||||
for i in range(len(indices))[:: self.batch_size]:
|
||||
|
||||
if len(indices) - i < self.batch_size:
|
||||
break
|
||||
|
||||
feature = torch.from_numpy(x_train_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
label = torch.from_numpy(y_train_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
|
||||
pred = self.model(feature)
|
||||
loss = self.loss_fn(pred, label)
|
||||
|
||||
self.train_optimizer.zero_grad()
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_value_(self.model.parameters(), 3.0)
|
||||
self.train_optimizer.step()
|
||||
|
||||
def test_epoch(self, data_x, data_y):
|
||||
|
||||
# prepare training data
|
||||
x_values = data_x.values
|
||||
y_values = np.squeeze(data_y.values)
|
||||
|
||||
self.model.eval()
|
||||
|
||||
scores = []
|
||||
losses = []
|
||||
|
||||
indices = np.arange(len(x_values))
|
||||
|
||||
for i in range(len(indices))[:: self.batch_size]:
|
||||
|
||||
if len(indices) - i < self.batch_size:
|
||||
break
|
||||
|
||||
feature = torch.from_numpy(x_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
label = torch.from_numpy(y_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
pred = self.model(feature)
|
||||
loss = self.loss_fn(pred, label)
|
||||
losses.append(loss.item())
|
||||
|
||||
score = self.metric_fn(pred, label)
|
||||
scores.append(score.item())
|
||||
|
||||
return np.mean(losses), np.mean(scores)
|
||||
|
||||
def fit(
|
||||
self,
|
||||
dataset: DatasetH,
|
||||
evals_result=dict(),
|
||||
save_path=None,
|
||||
):
|
||||
|
||||
df_train, df_valid, df_test = dataset.prepare(
|
||||
["train", "valid", "test"],
|
||||
col_set=["feature", "label"],
|
||||
data_key=DataHandlerLP.DK_L,
|
||||
)
|
||||
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
|
||||
save_path = get_or_create_path(save_path)
|
||||
stop_steps = 0
|
||||
train_loss = 0
|
||||
best_score = -np.inf
|
||||
best_epoch = 0
|
||||
evals_result["train"] = []
|
||||
evals_result["valid"] = []
|
||||
|
||||
# train
|
||||
self.logger.info("training...")
|
||||
self.fitted = True
|
||||
|
||||
for step in range(self.n_epochs):
|
||||
self.logger.info("Epoch%d:", step)
|
||||
self.logger.info("training...")
|
||||
self.train_epoch(x_train, y_train)
|
||||
self.logger.info("evaluating...")
|
||||
train_loss, train_score = self.test_epoch(x_train, y_train)
|
||||
val_loss, val_score = self.test_epoch(x_valid, y_valid)
|
||||
self.logger.info("train %.6f, valid %.6f" % (train_score, val_score))
|
||||
evals_result["train"].append(train_score)
|
||||
evals_result["valid"].append(val_score)
|
||||
|
||||
if val_score > best_score:
|
||||
best_score = val_score
|
||||
stop_steps = 0
|
||||
best_epoch = step
|
||||
best_param = copy.deepcopy(self.model.state_dict())
|
||||
else:
|
||||
stop_steps += 1
|
||||
if stop_steps >= self.early_stop:
|
||||
self.logger.info("early stop")
|
||||
break
|
||||
|
||||
self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
|
||||
self.model.load_state_dict(best_param)
|
||||
torch.save(best_param, save_path)
|
||||
|
||||
if self.use_gpu:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
index = x_test.index
|
||||
self.model.eval()
|
||||
x_values = x_test.values
|
||||
sample_num = x_values.shape[0]
|
||||
preds = []
|
||||
|
||||
for begin in range(sample_num)[:: self.batch_size]:
|
||||
|
||||
if sample_num - begin < self.batch_size:
|
||||
end = sample_num
|
||||
else:
|
||||
end = begin + self.batch_size
|
||||
|
||||
x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
pred = self.model(x_batch).detach().cpu().numpy()
|
||||
|
||||
preds.append(pred)
|
||||
|
||||
return pd.Series(np.concatenate(preds), index=index)
|
||||
|
||||
|
||||
class PositionalEncoding(nn.Module):
|
||||
def __init__(self, d_model, max_len=1000):
|
||||
super(PositionalEncoding, self).__init__()
|
||||
pe = torch.zeros(max_len, d_model)
|
||||
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
|
||||
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
|
||||
pe[:, 0::2] = torch.sin(position * div_term)
|
||||
pe[:, 1::2] = torch.cos(position * div_term)
|
||||
pe = pe.unsqueeze(0).transpose(0, 1)
|
||||
self.register_buffer("pe", pe)
|
||||
|
||||
def forward(self, x):
|
||||
# [T, N, F]
|
||||
return x + self.pe[: x.size(0), :]
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, d_feat=6, d_model=8, nhead=4, num_layers=2, dropout=0.5, device=None):
|
||||
super(Transformer, self).__init__()
|
||||
self.feature_layer = nn.Linear(d_feat, d_model)
|
||||
self.pos_encoder = PositionalEncoding(d_model)
|
||||
self.encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dropout=dropout)
|
||||
self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers)
|
||||
self.decoder_layer = nn.Linear(d_model, 1)
|
||||
self.device = device
|
||||
self.d_feat = d_feat
|
||||
|
||||
def forward(self, src):
|
||||
# src [N, F*T] --> [N, T, F]
|
||||
src = src.reshape(len(src), self.d_feat, -1).permute(0, 2, 1)
|
||||
src = self.feature_layer(src)
|
||||
|
||||
# src [N, T, F] --> [T, N, F], [60, 512, 8]
|
||||
src = src.transpose(1, 0) # not batch first
|
||||
|
||||
mask = None
|
||||
|
||||
src = self.pos_encoder(src)
|
||||
output = self.transformer_encoder(src, mask) # [60, 512, 8]
|
||||
|
||||
# [T, N, F] --> [N, T*F]
|
||||
output = self.decoder_layer(output.transpose(1, 0)[:, -1, :]) # [512, 1]
|
||||
|
||||
return output.squeeze()
|
||||
269
qlib/contrib/model/pytorch_transformer_ts.py
Normal file
269
qlib/contrib/model/pytorch_transformer_ts.py
Normal file
@@ -0,0 +1,269 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import copy
|
||||
import math
|
||||
from ...utils import get_or_create_path
|
||||
from ...log import get_module_logger
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from .pytorch_utils import count_parameters
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH, TSDatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
|
||||
|
||||
class TransformerModel(Model):
|
||||
def __init__(
|
||||
self,
|
||||
d_feat: int = 20,
|
||||
d_model: int = 64,
|
||||
batch_size: int = 8192,
|
||||
nhead: int = 2,
|
||||
num_layers: int = 2,
|
||||
dropout: float = 0,
|
||||
n_epochs=100,
|
||||
lr=0.0001,
|
||||
metric="",
|
||||
early_stop=5,
|
||||
loss="mse",
|
||||
optimizer="adam",
|
||||
reg=1e-3,
|
||||
n_jobs=10,
|
||||
GPU=0,
|
||||
seed=None,
|
||||
**kwargs
|
||||
):
|
||||
|
||||
# set hyper-parameters.
|
||||
self.d_model = d_model
|
||||
self.dropout = dropout
|
||||
self.n_epochs = n_epochs
|
||||
self.lr = lr
|
||||
self.reg = reg
|
||||
self.metric = metric
|
||||
self.batch_size = batch_size
|
||||
self.early_stop = early_stop
|
||||
self.optimizer = optimizer.lower()
|
||||
self.loss = loss
|
||||
self.n_jobs = n_jobs
|
||||
self.device = torch.device("cuda:%d" % GPU if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.seed = seed
|
||||
self.logger = get_module_logger("TransformerModel")
|
||||
self.logger.info("Naive Transformer:" "\nbatch_size : {}" "\ndevice : {}".format(self.batch_size, self.device))
|
||||
|
||||
if self.seed is not None:
|
||||
np.random.seed(self.seed)
|
||||
torch.manual_seed(self.seed)
|
||||
|
||||
self.model = Transformer(d_feat, d_model, nhead, num_layers, dropout, self.device)
|
||||
if optimizer.lower() == "adam":
|
||||
self.train_optimizer = optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=self.reg)
|
||||
elif optimizer.lower() == "gd":
|
||||
self.train_optimizer = optim.SGD(self.model.parameters(), lr=self.lr, weight_decay=self.reg)
|
||||
else:
|
||||
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
|
||||
|
||||
self.fitted = False
|
||||
self.model.to(self.device)
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
return self.device != torch.device("cpu")
|
||||
|
||||
def mse(self, pred, label):
|
||||
loss = (pred.float() - label.float()) ** 2
|
||||
return torch.mean(loss)
|
||||
|
||||
def loss_fn(self, pred, label):
|
||||
mask = ~torch.isnan(label)
|
||||
|
||||
if self.loss == "mse":
|
||||
return self.mse(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown loss `%s`" % self.loss)
|
||||
|
||||
def metric_fn(self, pred, label):
|
||||
|
||||
mask = torch.isfinite(label)
|
||||
|
||||
if self.metric == "" or self.metric == "loss":
|
||||
return -self.loss_fn(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown metric `%s`" % self.metric)
|
||||
|
||||
def train_epoch(self, data_loader):
|
||||
|
||||
self.model.train()
|
||||
|
||||
for data in data_loader:
|
||||
feature = data[:, :, 0:-1].to(self.device)
|
||||
label = data[:, -1, -1].to(self.device)
|
||||
|
||||
pred = self.model(feature.float()) # .float()
|
||||
loss = self.loss_fn(pred, label)
|
||||
|
||||
self.train_optimizer.zero_grad()
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_value_(self.model.parameters(), 3.0)
|
||||
self.train_optimizer.step()
|
||||
|
||||
def test_epoch(self, data_loader):
|
||||
|
||||
self.model.eval()
|
||||
|
||||
scores = []
|
||||
losses = []
|
||||
|
||||
for data in data_loader:
|
||||
|
||||
feature = data[:, :, 0:-1].to(self.device)
|
||||
label = data[:, -1, -1].to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
pred = self.model(feature.float()) # .float()
|
||||
loss = self.loss_fn(pred, label)
|
||||
losses.append(loss.item())
|
||||
|
||||
score = self.metric_fn(pred, label)
|
||||
scores.append(score.item())
|
||||
|
||||
return np.mean(losses), np.mean(scores)
|
||||
|
||||
def fit(
|
||||
self,
|
||||
dataset: DatasetH,
|
||||
evals_result=dict(),
|
||||
save_path=None,
|
||||
):
|
||||
|
||||
dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
|
||||
dl_valid = dataset.prepare("valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
|
||||
|
||||
dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
|
||||
train_loader = DataLoader(
|
||||
dl_train, batch_size=self.batch_size, shuffle=True, num_workers=self.n_jobs, drop_last=True
|
||||
)
|
||||
valid_loader = DataLoader(
|
||||
dl_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.n_jobs, drop_last=True
|
||||
)
|
||||
|
||||
save_path = get_or_create_path(save_path)
|
||||
|
||||
stop_steps = 0
|
||||
train_loss = 0
|
||||
best_score = -np.inf
|
||||
best_epoch = 0
|
||||
evals_result["train"] = []
|
||||
evals_result["valid"] = []
|
||||
|
||||
# train
|
||||
self.logger.info("training...")
|
||||
self.fitted = True
|
||||
|
||||
for step in range(self.n_epochs):
|
||||
self.logger.info("Epoch%d:", step)
|
||||
self.logger.info("training...")
|
||||
self.train_epoch(train_loader)
|
||||
self.logger.info("evaluating...")
|
||||
train_loss, train_score = self.test_epoch(train_loader)
|
||||
val_loss, val_score = self.test_epoch(valid_loader)
|
||||
self.logger.info("train %.6f, valid %.6f" % (train_score, val_score))
|
||||
evals_result["train"].append(train_score)
|
||||
evals_result["valid"].append(val_score)
|
||||
|
||||
if val_score > best_score:
|
||||
best_score = val_score
|
||||
stop_steps = 0
|
||||
best_epoch = step
|
||||
best_param = copy.deepcopy(self.model.state_dict())
|
||||
else:
|
||||
stop_steps += 1
|
||||
if stop_steps >= self.early_stop:
|
||||
self.logger.info("early stop")
|
||||
break
|
||||
|
||||
self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
|
||||
self.model.load_state_dict(best_param)
|
||||
torch.save(best_param, save_path)
|
||||
|
||||
if self.use_gpu:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def predict(self, dataset):
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
dl_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)
|
||||
dl_test.config(fillna_type="ffill+bfill")
|
||||
test_loader = DataLoader(dl_test, batch_size=self.batch_size, num_workers=self.n_jobs)
|
||||
self.model.eval()
|
||||
preds = []
|
||||
|
||||
for data in test_loader:
|
||||
feature = data[:, :, 0:-1].to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
pred = self.model(feature.float()).detach().cpu().numpy()
|
||||
|
||||
preds.append(pred)
|
||||
|
||||
return pd.Series(np.concatenate(preds), index=dl_test.get_index())
|
||||
|
||||
|
||||
class PositionalEncoding(nn.Module):
|
||||
def __init__(self, d_model, max_len=1000):
|
||||
super(PositionalEncoding, self).__init__()
|
||||
pe = torch.zeros(max_len, d_model)
|
||||
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
|
||||
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
|
||||
pe[:, 0::2] = torch.sin(position * div_term)
|
||||
pe[:, 1::2] = torch.cos(position * div_term)
|
||||
pe = pe.unsqueeze(0).transpose(0, 1)
|
||||
self.register_buffer("pe", pe)
|
||||
|
||||
def forward(self, x):
|
||||
# [T, N, F]
|
||||
return x + self.pe[: x.size(0), :]
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, d_feat=6, d_model=8, nhead=4, num_layers=2, dropout=0.5, device=None):
|
||||
super(Transformer, self).__init__()
|
||||
self.feature_layer = nn.Linear(d_feat, d_model)
|
||||
self.pos_encoder = PositionalEncoding(d_model)
|
||||
self.encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dropout=dropout)
|
||||
self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers)
|
||||
self.decoder_layer = nn.Linear(d_model, 1)
|
||||
self.device = device
|
||||
self.d_feat = d_feat
|
||||
|
||||
def forward(self, src):
|
||||
# src [N, T, F], [512, 60, 6]
|
||||
src = self.feature_layer(src) # [512, 60, 8]
|
||||
|
||||
# src [N, T, F] --> [T, N, F], [60, 512, 8]
|
||||
src = src.transpose(1, 0) # not batch first
|
||||
|
||||
mask = None
|
||||
|
||||
src = self.pos_encoder(src)
|
||||
output = self.transformer_encoder(src, mask) # [60, 512, 8]
|
||||
|
||||
# [T, N, F] --> [N, T*F]
|
||||
output = self.decoder_layer(output.transpose(1, 0)[:, -1, :]) # [512, 1]
|
||||
|
||||
return output.squeeze()
|
||||
@@ -62,7 +62,7 @@ class XGBModel(Model, FeatureInt):
|
||||
if self.model is None:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
return pd.Series(self.model.predict(xgb.DMatrix(x_test.values)), index=x_test.index)
|
||||
return pd.Series(self.model.predict(xgb.DMatrix(x_test)), index=x_test.index)
|
||||
|
||||
def get_feature_importance(self, *args, **kwargs) -> pd.Series:
|
||||
"""get feature importance
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
|
||||
import pandas as pd
|
||||
|
||||
import plotly.tools as tls
|
||||
import plotly.graph_objs as go
|
||||
|
||||
import statsmodels.api as sm
|
||||
@@ -80,9 +79,35 @@ def _plot_qq(data: pd.Series = None, dist=stats.norm) -> go.Figure:
|
||||
:param dist:
|
||||
:return:
|
||||
"""
|
||||
fig, ax = plt.subplots(figsize=(8, 5))
|
||||
_mpl_fig = sm.qqplot(data.dropna(), dist, fit=True, line="45", ax=ax)
|
||||
return tls.mpl_to_plotly(_mpl_fig)
|
||||
# NOTE: plotly.tools.mpl_to_plotly not actively maintained, resulting in errors in the new version of matplotlib,
|
||||
# ref: https://github.com/plotly/plotly.py/issues/2913#issuecomment-730071567
|
||||
# removing plotly.tools.mpl_to_plotly for greater compatibility with matplotlib versions
|
||||
_plt_fig = sm.qqplot(data.dropna(), dist=dist, fit=True, line="45")
|
||||
plt.close(_plt_fig)
|
||||
qqplot_data = _plt_fig.gca().lines
|
||||
fig = go.Figure()
|
||||
|
||||
fig.add_trace(
|
||||
{
|
||||
"type": "scatter",
|
||||
"x": qqplot_data[0].get_xdata(),
|
||||
"y": qqplot_data[0].get_ydata(),
|
||||
"mode": "markers",
|
||||
"marker": {"color": "#19d3f3"},
|
||||
}
|
||||
)
|
||||
|
||||
fig.add_trace(
|
||||
{
|
||||
"type": "scatter",
|
||||
"x": qqplot_data[1].get_xdata(),
|
||||
"y": qqplot_data[1].get_ydata(),
|
||||
"mode": "lines",
|
||||
"line": {"color": "#636efa"},
|
||||
}
|
||||
)
|
||||
del qqplot_data
|
||||
return fig
|
||||
|
||||
|
||||
def _pred_ic(pred_label: pd.DataFrame = None, rank: bool = False, **kwargs) -> tuple:
|
||||
|
||||
@@ -148,7 +148,6 @@ class WeightStrategyBase(BaseStrategy, AdjustTimer):
|
||||
pred score for this trade date, index is stock_id, contain 'score' column.
|
||||
current : Position()
|
||||
current position.
|
||||
trade_exchange : Exchange()
|
||||
trade_date : pd.Timestamp
|
||||
trade date.
|
||||
"""
|
||||
@@ -222,9 +221,9 @@ class TopkDropoutStrategy(BaseStrategy, ListAdjustTimer):
|
||||
only_tradable : bool
|
||||
will the strategy only consider the tradable stock when buying and selling.
|
||||
if only_tradable:
|
||||
strategy will make buy sell decision without checking the tradable state of the stock.
|
||||
the strategy will peek at the information in the short future to avoid untradable stocks (untradable stocks include stocks that meet suspension, or hit limit up or limit down).
|
||||
else:
|
||||
strategy will make decision with the tradable state of the stock info and avoid buy and sell them.
|
||||
the strategy will generate orders without peeking any information in the future, so the order generated by the strategies may fail.
|
||||
"""
|
||||
super(TopkDropoutStrategy, self).__init__()
|
||||
ListAdjustTimer.__init__(self, kwargs.get("adjust_dates", None))
|
||||
|
||||
@@ -196,9 +196,9 @@ class Feature(Expression):
|
||||
|
||||
def __init__(self, name=None):
|
||||
if name:
|
||||
self._name = name.lower()
|
||||
self._name = name
|
||||
else:
|
||||
self._name = type(self).__name__.lower()
|
||||
self._name = type(self).__name__
|
||||
|
||||
def __str__(self):
|
||||
return "$" + self._name
|
||||
|
||||
@@ -17,6 +17,7 @@ import abc
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Union, Iterable
|
||||
from collections import OrderedDict
|
||||
|
||||
from ..config import C
|
||||
@@ -216,12 +217,14 @@ class CacheUtils:
|
||||
redis_lock.reset_all(r)
|
||||
|
||||
@staticmethod
|
||||
def visit(cache_path):
|
||||
def visit(cache_path: Union[str, Path]):
|
||||
# FIXME: Because read_lock was canceled when reading the cache, multiple processes may have read and write exceptions here
|
||||
try:
|
||||
with open(cache_path + ".meta", "rb") as f:
|
||||
cache_path = Path(cache_path)
|
||||
meta_path = cache_path.with_suffix(".meta")
|
||||
with meta_path.open("rb") as f:
|
||||
d = pickle.load(f)
|
||||
with open(cache_path + ".meta", "wb") as f:
|
||||
with meta_path.open("wb") as f:
|
||||
try:
|
||||
d["meta"]["last_visit"] = str(time.time())
|
||||
d["meta"]["visits"] = d["meta"]["visits"] + 1
|
||||
@@ -237,7 +240,7 @@ class CacheUtils:
|
||||
lock.acquire()
|
||||
except redis_lock.AlreadyAcquired:
|
||||
raise QlibCacheException(
|
||||
f"""It sees the key(lock:{repr(lock_name)[1:-1]}-wlock) of the redis lock has existed in your redis db now.
|
||||
f"""It sees the key(lock:{repr(lock_name)[1:-1]}-wlock) of the redis lock has existed in your redis db now.
|
||||
You can use the following command to clear your redis keys and rerun your commands:
|
||||
$ redis-cli
|
||||
> select {C.redis_task_db}
|
||||
@@ -249,17 +252,17 @@ class CacheUtils:
|
||||
|
||||
@staticmethod
|
||||
@contextlib.contextmanager
|
||||
def reader_lock(redis_t, lock_name):
|
||||
lock_name = f"{C.provider_uri}:{lock_name}"
|
||||
current_cache_rlock = redis_lock.Lock(redis_t, "%s-rlock" % lock_name)
|
||||
current_cache_wlock = redis_lock.Lock(redis_t, "%s-wlock" % lock_name)
|
||||
def reader_lock(redis_t, lock_name: str):
|
||||
current_cache_rlock = redis_lock.Lock(redis_t, f"{lock_name}-rlock")
|
||||
current_cache_wlock = redis_lock.Lock(redis_t, f"{lock_name}-wlock")
|
||||
lock_reader = f"{lock_name}-reader"
|
||||
# make sure only one reader is entering
|
||||
current_cache_rlock.acquire(timeout=60)
|
||||
try:
|
||||
current_cache_readers = redis_t.get("%s-reader" % lock_name)
|
||||
current_cache_readers = redis_t.get(lock_reader)
|
||||
if current_cache_readers is None or int(current_cache_readers) == 0:
|
||||
CacheUtils.acquire(current_cache_wlock, lock_name)
|
||||
redis_t.incr("%s-reader" % lock_name)
|
||||
redis_t.incr(lock_reader)
|
||||
finally:
|
||||
current_cache_rlock.release()
|
||||
try:
|
||||
@@ -268,9 +271,9 @@ class CacheUtils:
|
||||
# make sure only one reader is leaving
|
||||
current_cache_rlock.acquire(timeout=60)
|
||||
try:
|
||||
redis_t.decr("%s-reader" % lock_name)
|
||||
if int(redis_t.get("%s-reader" % lock_name)) == 0:
|
||||
redis_t.delete("%s-reader" % lock_name)
|
||||
redis_t.decr(lock_reader)
|
||||
if int(redis_t.get(lock_reader)) == 0:
|
||||
redis_t.delete(lock_reader)
|
||||
current_cache_wlock.reset()
|
||||
finally:
|
||||
current_cache_rlock.release()
|
||||
@@ -278,8 +281,7 @@ class CacheUtils:
|
||||
@staticmethod
|
||||
@contextlib.contextmanager
|
||||
def writer_lock(redis_t, lock_name):
|
||||
lock_name = f"{C.provider_uri}:{lock_name}"
|
||||
current_cache_wlock = redis_lock.Lock(redis_t, "%s-wlock" % lock_name, id=CacheUtils.LOCK_ID)
|
||||
current_cache_wlock = redis_lock.Lock(redis_t, f"{lock_name}-wlock", id=CacheUtils.LOCK_ID)
|
||||
CacheUtils.acquire(current_cache_wlock, lock_name)
|
||||
try:
|
||||
yield
|
||||
@@ -297,6 +299,30 @@ class BaseProviderCache:
|
||||
def __getattr__(self, attr):
|
||||
return getattr(self.provider, attr)
|
||||
|
||||
@staticmethod
|
||||
def check_cache_exists(cache_path: Union[str, Path], suffix_list: Iterable = (".index", ".meta")) -> bool:
|
||||
cache_path = Path(cache_path)
|
||||
for p in [cache_path] + [cache_path.with_suffix(_s) for _s in suffix_list]:
|
||||
if not p.exists():
|
||||
return False
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def clear_cache(cache_path: Union[str, Path]):
|
||||
for p in [
|
||||
cache_path,
|
||||
cache_path.with_suffix(".meta"),
|
||||
cache_path.with_suffix(".index"),
|
||||
]:
|
||||
if p.exists():
|
||||
p.unlink()
|
||||
|
||||
@staticmethod
|
||||
def get_cache_dir(dir_name: str, freq: str = None) -> Path:
|
||||
cache_dir = Path(C.dpm.get_data_path(freq)).joinpath(dir_name)
|
||||
cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
return cache_dir
|
||||
|
||||
|
||||
class ExpressionCache(BaseProviderCache):
|
||||
"""Expression cache mechanism base class.
|
||||
@@ -330,15 +356,16 @@ class ExpressionCache(BaseProviderCache):
|
||||
"""
|
||||
raise NotImplementedError("Implement this method if you want to use expression cache")
|
||||
|
||||
def update(self, cache_uri):
|
||||
def update(self, cache_uri: Union[str, Path], freq: str = "day"):
|
||||
"""Update expression cache to latest calendar.
|
||||
|
||||
Overide this method to define how to update expression cache corresponding to users' own cache mechanism.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
cache_uri : str
|
||||
cache_uri : str or Path
|
||||
the complete uri of expression cache file (include dir path).
|
||||
freq : str
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -358,7 +385,9 @@ class DatasetCache(BaseProviderCache):
|
||||
|
||||
HDF_KEY = "df"
|
||||
|
||||
def dataset(self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1):
|
||||
def dataset(
|
||||
self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1, inst_processors=[]
|
||||
):
|
||||
"""Get feature dataset.
|
||||
|
||||
.. note:: Same interface as `dataset` method in dataset provider
|
||||
@@ -369,13 +398,19 @@ class DatasetCache(BaseProviderCache):
|
||||
"""
|
||||
if disk_cache == 0:
|
||||
# skip cache
|
||||
return self.provider.dataset(instruments, fields, start_time, end_time, freq)
|
||||
return self.provider.dataset(
|
||||
instruments, fields, start_time, end_time, freq, inst_processors=inst_processors
|
||||
)
|
||||
else:
|
||||
# use and replace cache
|
||||
try:
|
||||
return self._dataset(instruments, fields, start_time, end_time, freq, disk_cache)
|
||||
return self._dataset(
|
||||
instruments, fields, start_time, end_time, freq, disk_cache, inst_processors=inst_processors
|
||||
)
|
||||
except NotImplementedError:
|
||||
return self.provider.dataset(instruments, fields, start_time, end_time, freq)
|
||||
return self.provider.dataset(
|
||||
instruments, fields, start_time, end_time, freq, inst_processors=inst_processors
|
||||
)
|
||||
|
||||
def _uri(self, instruments, fields, start_time, end_time, freq, **kwargs):
|
||||
"""Get dataset cache file uri.
|
||||
@@ -384,14 +419,18 @@ class DatasetCache(BaseProviderCache):
|
||||
"""
|
||||
raise NotImplementedError("Implement this function to match your own cache mechanism")
|
||||
|
||||
def _dataset(self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1):
|
||||
def _dataset(
|
||||
self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1, inst_processors=[]
|
||||
):
|
||||
"""Get feature dataset using cache.
|
||||
|
||||
Override this method to define how to get feature dataset corresponding to users' own cache mechanism.
|
||||
"""
|
||||
raise NotImplementedError("Implement this method if you want to use dataset feature cache")
|
||||
|
||||
def _dataset_uri(self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1):
|
||||
def _dataset_uri(
|
||||
self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1, inst_processors=[]
|
||||
):
|
||||
"""Get a uri of feature dataset using cache.
|
||||
specially:
|
||||
disk_cache=1 means using data set cache and return the uri of cache file.
|
||||
@@ -403,15 +442,16 @@ class DatasetCache(BaseProviderCache):
|
||||
"Implement this method if you want to use dataset feature cache as a cache file for client"
|
||||
)
|
||||
|
||||
def update(self, cache_uri):
|
||||
def update(self, cache_uri: Union[str, Path], freq: str = "day"):
|
||||
"""Update dataset cache to latest calendar.
|
||||
|
||||
Overide this method to define how to update dataset cache corresponding to users' own cache mechanism.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
cache_uri : str
|
||||
cache_uri : str or Path
|
||||
the complete uri of dataset cache file (include dir path).
|
||||
freq : str
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -452,25 +492,19 @@ class DiskExpressionCache(ExpressionCache):
|
||||
self.r = get_redis_connection()
|
||||
# remote==True means client is using this module, writing behaviour will not be allowed.
|
||||
self.remote = kwargs.get("remote", False)
|
||||
self.expr_cache_path = os.path.join(C.get_data_path(), C.features_cache_dir_name)
|
||||
os.makedirs(self.expr_cache_path, exist_ok=True)
|
||||
|
||||
def get_cache_dir(self, freq: str = None) -> Path:
|
||||
return super(DiskExpressionCache, self).get_cache_dir(C.features_cache_dir_name, freq)
|
||||
|
||||
def _uri(self, instrument, field, start_time, end_time, freq):
|
||||
field = remove_fields_space(field)
|
||||
instrument = str(instrument).lower()
|
||||
return hash_args(instrument, field, freq)
|
||||
|
||||
@staticmethod
|
||||
def check_cache_exists(cache_path):
|
||||
for p in [cache_path, cache_path + ".meta"]:
|
||||
if not Path(p).exists():
|
||||
return False
|
||||
return True
|
||||
|
||||
def _expression(self, instrument, field, start_time=None, end_time=None, freq="day"):
|
||||
_cache_uri = self._uri(instrument=instrument, field=field, start_time=None, end_time=None, freq=freq)
|
||||
_instrument_dir = os.path.join(self.expr_cache_path, instrument.lower())
|
||||
cache_path = os.path.join(_instrument_dir, _cache_uri)
|
||||
_instrument_dir = self.get_cache_dir(freq).joinpath(instrument.lower())
|
||||
cache_path = _instrument_dir.joinpath(_cache_uri)
|
||||
# get calendar
|
||||
from .data import Cal
|
||||
|
||||
@@ -478,7 +512,7 @@ class DiskExpressionCache(ExpressionCache):
|
||||
|
||||
_, _, start_index, end_index = Cal.locate_index(start_time, end_time, freq, future=False)
|
||||
|
||||
if self.check_cache_exists(cache_path):
|
||||
if self.check_cache_exists(cache_path, suffix_list=[".meta"]):
|
||||
"""
|
||||
In most cases, we do not need reader_lock.
|
||||
Because updating data is a small probability event compare to reading data.
|
||||
@@ -502,8 +536,7 @@ class DiskExpressionCache(ExpressionCache):
|
||||
# normalize field
|
||||
field = remove_fields_space(field)
|
||||
# cache unavailable, generate the cache
|
||||
if not os.path.exists(_instrument_dir):
|
||||
os.makedirs(_instrument_dir, exist_ok=True)
|
||||
_instrument_dir.mkdir(parents=True, exist_ok=True)
|
||||
if not isinstance(eval(parse_field(field)), Feature):
|
||||
# When the expression is not a raw feature
|
||||
# generate expression cache if the feature is not a Feature
|
||||
@@ -511,7 +544,7 @@ class DiskExpressionCache(ExpressionCache):
|
||||
series = self.provider.expression(instrument, field, _calendar[0], _calendar[-1], freq)
|
||||
if not series.empty:
|
||||
# This expresion is empty, we don't generate any cache for it.
|
||||
with CacheUtils.writer_lock(self.r, "expression-%s" % _cache_uri):
|
||||
with CacheUtils.writer_lock(self.r, f"{str(C.dpm.get_data_path(freq))}:expression-{_cache_uri}"):
|
||||
self.gen_expression_cache(
|
||||
expression_data=series,
|
||||
cache_path=cache_path,
|
||||
@@ -527,14 +560,6 @@ class DiskExpressionCache(ExpressionCache):
|
||||
# If the expression is a raw feature(such as $close, $open)
|
||||
return self.provider.expression(instrument, field, start_time, end_time, freq)
|
||||
|
||||
@staticmethod
|
||||
def clear_cache(cache_path):
|
||||
meta_path = cache_path + ".meta"
|
||||
for p in [cache_path, meta_path]:
|
||||
p = Path(p)
|
||||
if p.exists():
|
||||
p.unlink()
|
||||
|
||||
def gen_expression_cache(self, expression_data, cache_path, instrument, field, freq, last_update):
|
||||
"""use bin file to save like feature-data."""
|
||||
# Make sure the cache runs right when the directory is deleted
|
||||
@@ -544,27 +569,28 @@ class DiskExpressionCache(ExpressionCache):
|
||||
"meta": {"last_visit": time.time(), "visits": 1},
|
||||
}
|
||||
self.logger.debug(f"generating expression cache: {meta}")
|
||||
os.makedirs(self.expr_cache_path, exist_ok=True)
|
||||
self.clear_cache(cache_path)
|
||||
meta_path = cache_path + ".meta"
|
||||
meta_path = cache_path.with_suffix(".meta")
|
||||
|
||||
with open(meta_path, "wb") as f:
|
||||
with meta_path.open("wb") as f:
|
||||
pickle.dump(meta, f)
|
||||
os.chmod(meta_path, stat.S_IRWXU | stat.S_IRGRP | stat.S_IROTH)
|
||||
meta_path.chmod(stat.S_IRWXU | stat.S_IRGRP | stat.S_IROTH)
|
||||
df = expression_data.to_frame()
|
||||
|
||||
r = np.hstack([df.index[0], expression_data]).astype("<f")
|
||||
r.tofile(str(cache_path))
|
||||
|
||||
def update(self, sid, cache_uri):
|
||||
cp_cache_uri = os.path.join(self.expr_cache_path, sid, cache_uri)
|
||||
if not self.check_cache_exists(cp_cache_uri):
|
||||
def update(self, sid, cache_uri, freq: str = "day"):
|
||||
|
||||
cp_cache_uri = self.get_cache_dir(freq).joinpath(sid).joinpath(cache_uri)
|
||||
meta_path = cp_cache_uri.with_suffix(".meta")
|
||||
if not self.check_cache_exists(cp_cache_uri, suffix_list=[".meta"]):
|
||||
self.logger.info(f"The cache {cp_cache_uri} has corrupted. It will be removed")
|
||||
self.clear_cache(cp_cache_uri)
|
||||
return 2
|
||||
|
||||
with CacheUtils.writer_lock(self.r, "expression-%s" % cache_uri):
|
||||
with open(cp_cache_uri + ".meta", "rb") as f:
|
||||
with CacheUtils.writer_lock(self.r, f"{str(C.dpm.get_data_path())}:expression-{cache_uri}"):
|
||||
with meta_path.open("rb") as f:
|
||||
d = pickle.load(f)
|
||||
instrument = d["info"]["instrument"]
|
||||
field = d["info"]["field"]
|
||||
@@ -611,7 +637,7 @@ class DiskExpressionCache(ExpressionCache):
|
||||
f.write(data)
|
||||
# update meta file
|
||||
d["info"]["last_update"] = str(new_calendar[-1])
|
||||
with open(cp_cache_uri + ".meta", "wb") as f:
|
||||
with meta_path.open("wb") as f:
|
||||
pickle.dump(d, f)
|
||||
return 0
|
||||
|
||||
@@ -623,22 +649,16 @@ class DiskDatasetCache(DatasetCache):
|
||||
super(DiskDatasetCache, self).__init__(provider)
|
||||
self.r = get_redis_connection()
|
||||
self.remote = kwargs.get("remote", False)
|
||||
self.dtst_cache_path = os.path.join(C.get_data_path(), C.dataset_cache_dir_name)
|
||||
os.makedirs(self.dtst_cache_path, exist_ok=True)
|
||||
|
||||
@staticmethod
|
||||
def _uri(instruments, fields, start_time, end_time, freq, disk_cache=1, **kwargs):
|
||||
return hash_args(*DatasetCache.normalize_uri_args(instruments, fields, freq), disk_cache)
|
||||
def _uri(instruments, fields, start_time, end_time, freq, disk_cache=1, inst_processors=[], **kwargs):
|
||||
return hash_args(*DatasetCache.normalize_uri_args(instruments, fields, freq), disk_cache, inst_processors)
|
||||
|
||||
@staticmethod
|
||||
def check_cache_exists(cache_path):
|
||||
for p in [cache_path, cache_path + ".index", cache_path + ".meta"]:
|
||||
if not Path(p).exists():
|
||||
return False
|
||||
return True
|
||||
def get_cache_dir(self, freq: str = None) -> Path:
|
||||
return super(DiskDatasetCache, self).get_cache_dir(C.dataset_cache_dir_name, freq)
|
||||
|
||||
@classmethod
|
||||
def read_data_from_cache(cls, cache_path, start_time, end_time, fields):
|
||||
def read_data_from_cache(cls, cache_path: Union[str, Path], start_time, end_time, fields):
|
||||
"""read_cache_from
|
||||
|
||||
This function can read data from the disk cache dataset
|
||||
@@ -671,17 +691,32 @@ class DiskDatasetCache(DatasetCache):
|
||||
df = pd.DataFrame(columns=fields)
|
||||
return df
|
||||
|
||||
def _dataset(self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=0):
|
||||
def _dataset(
|
||||
self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=0, inst_processors=[]
|
||||
):
|
||||
|
||||
if disk_cache == 0:
|
||||
# In this case, data_set cache is configured but will not be used.
|
||||
return self.provider.dataset(instruments, fields, start_time, end_time, freq)
|
||||
|
||||
return self.provider.dataset(
|
||||
instruments, fields, start_time, end_time, freq, inst_processors=inst_processors
|
||||
)
|
||||
# FIXME: The cache after resample, when read again and intercepted with end_time, results in incomplete data date
|
||||
if inst_processors:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} does not support inst_processor. "
|
||||
f"Please use `D.features(disk_cache=0)` or `qlib.init(dataset_cache=None)`"
|
||||
)
|
||||
_cache_uri = self._uri(
|
||||
instruments=instruments, fields=fields, start_time=None, end_time=None, freq=freq, disk_cache=disk_cache
|
||||
instruments=instruments,
|
||||
fields=fields,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
freq=freq,
|
||||
disk_cache=disk_cache,
|
||||
inst_processors=inst_processors,
|
||||
)
|
||||
|
||||
cache_path = os.path.join(self.dtst_cache_path, _cache_uri)
|
||||
cache_path = self.get_cache_dir(freq).joinpath(_cache_uri)
|
||||
|
||||
features = pd.DataFrame()
|
||||
gen_flag = False
|
||||
@@ -689,7 +724,7 @@ class DiskDatasetCache(DatasetCache):
|
||||
if self.check_cache_exists(cache_path):
|
||||
if disk_cache == 1:
|
||||
# use cache
|
||||
with CacheUtils.reader_lock(self.r, "dataset-%s" % _cache_uri):
|
||||
with CacheUtils.reader_lock(self.r, f"{str(C.dpm.get_data_path(freq))}:dataset-{_cache_uri}"):
|
||||
CacheUtils.visit(cache_path)
|
||||
features = self.read_data_from_cache(cache_path, start_time, end_time, fields)
|
||||
elif disk_cache == 2:
|
||||
@@ -699,15 +734,21 @@ class DiskDatasetCache(DatasetCache):
|
||||
|
||||
if gen_flag:
|
||||
# cache unavailable, generate the cache
|
||||
with CacheUtils.writer_lock(self.r, "dataset-%s" % _cache_uri):
|
||||
with CacheUtils.writer_lock(self.r, f"{str(C.dpm.get_data_path(freq))}:dataset-{_cache_uri}"):
|
||||
features = self.gen_dataset_cache(
|
||||
cache_path=cache_path, instruments=instruments, fields=fields, freq=freq
|
||||
cache_path=cache_path,
|
||||
instruments=instruments,
|
||||
fields=fields,
|
||||
freq=freq,
|
||||
inst_processors=inst_processors,
|
||||
)
|
||||
if not features.empty:
|
||||
features = features.sort_index().loc(axis=0)[:, start_time:end_time]
|
||||
return features
|
||||
|
||||
def _dataset_uri(self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=0):
|
||||
def _dataset_uri(
|
||||
self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=0, inst_processors=[]
|
||||
):
|
||||
if disk_cache == 0:
|
||||
# In this case, server only checks the expression cache.
|
||||
# The client will load the cache data by itself.
|
||||
@@ -715,21 +756,38 @@ class DiskDatasetCache(DatasetCache):
|
||||
|
||||
LocalDatasetProvider.multi_cache_walker(instruments, fields, start_time, end_time, freq)
|
||||
return ""
|
||||
|
||||
# FIXME: The cache after resample, when read again and intercepted with end_time, results in incomplete data date
|
||||
if inst_processors:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} does not support inst_processor. "
|
||||
f"Please use `D.features(disk_cache=0)` or `qlib.init(dataset_cache=None)`"
|
||||
)
|
||||
_cache_uri = self._uri(
|
||||
instruments=instruments, fields=fields, start_time=None, end_time=None, freq=freq, disk_cache=disk_cache
|
||||
instruments=instruments,
|
||||
fields=fields,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
freq=freq,
|
||||
disk_cache=disk_cache,
|
||||
inst_processors=inst_processors,
|
||||
)
|
||||
cache_path = os.path.join(self.dtst_cache_path, _cache_uri)
|
||||
cache_path = self.get_cache_dir(freq).joinpath(_cache_uri)
|
||||
|
||||
if self.check_cache_exists(cache_path):
|
||||
self.logger.debug(f"The cache dataset has already existed {cache_path}. Return the uri directly")
|
||||
with CacheUtils.reader_lock(self.r, "dataset-%s" % _cache_uri):
|
||||
with CacheUtils.reader_lock(self.r, f"{str(C.dpm.get_data_path(freq))}:dataset-{_cache_uri}"):
|
||||
CacheUtils.visit(cache_path)
|
||||
return _cache_uri
|
||||
else:
|
||||
# cache unavailable, generate the cache
|
||||
with CacheUtils.writer_lock(self.r, "dataset-%s" % _cache_uri):
|
||||
self.gen_dataset_cache(cache_path=cache_path, instruments=instruments, fields=fields, freq=freq)
|
||||
with CacheUtils.writer_lock(self.r, f"{str(C.dpm.get_data_path(freq))}:dataset-{_cache_uri}"):
|
||||
self.gen_dataset_cache(
|
||||
cache_path=cache_path,
|
||||
instruments=instruments,
|
||||
fields=fields,
|
||||
freq=freq,
|
||||
inst_processors=inst_processors,
|
||||
)
|
||||
return _cache_uri
|
||||
|
||||
class IndexManager:
|
||||
@@ -740,8 +798,9 @@ class DiskDatasetCache(DatasetCache):
|
||||
|
||||
KEY = "df"
|
||||
|
||||
def __init__(self, cache_path):
|
||||
self.index_path = cache_path + ".index"
|
||||
def __init__(self, cache_path: Union[str, Path]):
|
||||
|
||||
self.index_path = cache_path.with_suffix(".index")
|
||||
self._data = None
|
||||
self.logger = get_module_logger(self.__class__.__name__)
|
||||
|
||||
@@ -757,7 +816,7 @@ class DiskDatasetCache(DatasetCache):
|
||||
self._data.sort_index(inplace=True)
|
||||
self._data.to_hdf(self.index_path, key=self.KEY, mode="w", format="table")
|
||||
# The index should be readable for all users
|
||||
os.chmod(self.index_path, stat.S_IRWXU | stat.S_IRGRP | stat.S_IROTH)
|
||||
self.index_path.chmod(stat.S_IRWXU | stat.S_IRGRP | stat.S_IROTH)
|
||||
|
||||
def sync_from_disk(self):
|
||||
# The file will not be closed directly if we read_hdf from the disk directly
|
||||
@@ -784,10 +843,10 @@ class DiskDatasetCache(DatasetCache):
|
||||
def build_index_from_data(data, start_index=0):
|
||||
if data.empty:
|
||||
return pd.DataFrame()
|
||||
line_data = data.iloc[:, 0].fillna(0).groupby("datetime").count()
|
||||
line_data = data.groupby("datetime").size()
|
||||
line_data.sort_index(inplace=True)
|
||||
index_end = line_data.cumsum()
|
||||
index_start = index_end.shift(1).fillna(0)
|
||||
index_start = index_end.shift(1, fill_value=0)
|
||||
|
||||
index_data = pd.DataFrame()
|
||||
index_data["start"] = index_start
|
||||
@@ -795,15 +854,7 @@ class DiskDatasetCache(DatasetCache):
|
||||
index_data += start_index
|
||||
return index_data
|
||||
|
||||
@staticmethod
|
||||
def clear_cache(cache_path):
|
||||
meta_path = cache_path + ".meta"
|
||||
for p in [cache_path, meta_path, cache_path + ".index", cache_path + ".data"]:
|
||||
p = Path(p)
|
||||
if p.exists():
|
||||
p.unlink()
|
||||
|
||||
def gen_dataset_cache(self, cache_path, instruments, fields, freq):
|
||||
def gen_dataset_cache(self, cache_path: Union[str, Path], instruments, fields, freq, inst_processors=[]):
|
||||
"""gen_dataset_cache
|
||||
|
||||
.. note:: This function does not consider the cache read write lock. Please
|
||||
@@ -838,20 +889,23 @@ class DiskDatasetCache(DatasetCache):
|
||||
:param instruments: The instruments to store the cache.
|
||||
:param fields: The fields to store the cache.
|
||||
:param freq: The freq to store the cache.
|
||||
:param inst_processors: Instrument processors.
|
||||
|
||||
:return type pd.DataFrame; The fields of the returned DataFrame are consistent with the parameters of the function.
|
||||
"""
|
||||
# get calendar
|
||||
from .data import Cal
|
||||
|
||||
cache_path = Path(cache_path)
|
||||
_calendar = Cal.calendar(freq=freq)
|
||||
self.logger.debug(f"Generating dataset cache {cache_path}")
|
||||
# Make sure the cache runs right when the directory is deleted
|
||||
# while running
|
||||
os.makedirs(self.dtst_cache_path, exist_ok=True)
|
||||
self.clear_cache(cache_path)
|
||||
|
||||
features = self.provider.dataset(instruments, fields, _calendar[0], _calendar[-1], freq)
|
||||
features = self.provider.dataset(
|
||||
instruments, fields, _calendar[0], _calendar[-1], freq, inst_processors=inst_processors
|
||||
)
|
||||
|
||||
if features.empty:
|
||||
return features
|
||||
@@ -860,7 +914,7 @@ class DiskDatasetCache(DatasetCache):
|
||||
features = features.swaplevel("instrument", "datetime").sort_index()
|
||||
|
||||
# write cache data
|
||||
with pd.HDFStore(cache_path + ".data") as store:
|
||||
with pd.HDFStore(str(cache_path.with_suffix(".data"))) as store:
|
||||
cache_to_orig_map = dict(zip(remove_fields_space(features.columns), features.columns))
|
||||
orig_to_cache_map = dict(zip(features.columns, remove_fields_space(features.columns)))
|
||||
cache_features = features[list(cache_to_orig_map.values())].rename(columns=orig_to_cache_map)
|
||||
@@ -876,12 +930,13 @@ class DiskDatasetCache(DatasetCache):
|
||||
"fields": cache_columns,
|
||||
"freq": freq,
|
||||
"last_update": str(_calendar[-1]), # The last_update to store the cache
|
||||
"inst_processors": inst_processors, # The last_update to store the cache
|
||||
},
|
||||
"meta": {"last_visit": time.time(), "visits": 1},
|
||||
}
|
||||
with open(cache_path + ".meta", "wb") as f:
|
||||
with cache_path.with_suffix(".meta").open("wb") as f:
|
||||
pickle.dump(meta, f)
|
||||
os.chmod(cache_path + ".meta", stat.S_IRWXU | stat.S_IRGRP | stat.S_IROTH)
|
||||
cache_path.with_suffix(".meta").chmod(stat.S_IRWXU | stat.S_IRGRP | stat.S_IROTH)
|
||||
# write index file
|
||||
im = DiskDatasetCache.IndexManager(cache_path)
|
||||
index_data = im.build_index_from_data(features)
|
||||
@@ -890,26 +945,27 @@ class DiskDatasetCache(DatasetCache):
|
||||
# rename the file after the cache has been generated
|
||||
# this doesn't work well on windows, but our server won't use windows
|
||||
# temporarily
|
||||
os.replace(cache_path + ".data", cache_path)
|
||||
cache_path.with_suffix(".data").rename(cache_path)
|
||||
# the fields of the cached features are converted to the original fields
|
||||
return features.swaplevel("datetime", "instrument")
|
||||
|
||||
def update(self, cache_uri):
|
||||
cp_cache_uri = os.path.join(self.dtst_cache_path, cache_uri)
|
||||
|
||||
def update(self, cache_uri, freq: str = "day"):
|
||||
cp_cache_uri = self.get_cache_dir(freq).joinpath(cache_uri)
|
||||
meta_path = cp_cache_uri.with_suffix(".meta")
|
||||
if not self.check_cache_exists(cp_cache_uri):
|
||||
self.logger.info(f"The cache {cp_cache_uri} has corrupted. It will be removed")
|
||||
self.clear_cache(cp_cache_uri)
|
||||
return 2
|
||||
|
||||
im = DiskDatasetCache.IndexManager(cp_cache_uri)
|
||||
with CacheUtils.writer_lock(self.r, "dataset-%s" % cache_uri):
|
||||
with open(cp_cache_uri + ".meta", "rb") as f:
|
||||
with CacheUtils.writer_lock(self.r, f"{str(C.dpm.get_data_path())}:dataset-{cache_uri}"):
|
||||
with meta_path.open("rb") as f:
|
||||
d = pickle.load(f)
|
||||
instruments = d["info"]["instruments"]
|
||||
fields = d["info"]["fields"]
|
||||
freq = d["info"]["freq"]
|
||||
last_update_time = d["info"]["last_update"]
|
||||
inst_processors = d["info"]["inst_processors"]
|
||||
index_data = im.get_index()
|
||||
|
||||
self.logger.debug("Updating dataset: {}".format(d))
|
||||
@@ -960,7 +1016,12 @@ class DiskDatasetCache(DatasetCache):
|
||||
)
|
||||
|
||||
data = self.provider.dataset(
|
||||
instruments, fields, whole_calendar[current_index - rm_n_period], new_calendar[-1], freq
|
||||
instruments,
|
||||
fields,
|
||||
whole_calendar[current_index - rm_n_period],
|
||||
new_calendar[-1],
|
||||
freq,
|
||||
inst_processors=inst_processors,
|
||||
)
|
||||
|
||||
if not data.empty:
|
||||
@@ -995,7 +1056,7 @@ class DiskDatasetCache(DatasetCache):
|
||||
|
||||
# update meta file
|
||||
d["info"]["last_update"] = str(new_calendar[-1])
|
||||
with open(cp_cache_uri + ".meta", "wb") as f:
|
||||
with meta_path.open("wb") as f:
|
||||
pickle.dump(d, f)
|
||||
return 0
|
||||
|
||||
@@ -1006,26 +1067,36 @@ class SimpleDatasetCache(DatasetCache):
|
||||
def __init__(self, provider):
|
||||
super(SimpleDatasetCache, self).__init__(provider)
|
||||
try:
|
||||
self.local_cache_path = C["local_cache_path"]
|
||||
except KeyError as e:
|
||||
self.local_cache_path: Path = Path(C["local_cache_path"]).expanduser().resolve()
|
||||
except (KeyError, TypeError) as e:
|
||||
self.logger.error("Assign a local_cache_path in config if you want to use this cache mechanism")
|
||||
raise
|
||||
self.logger.info(
|
||||
f"DatasetCache directory: {self.local_cache_path}, "
|
||||
f"modify the cache directory via the local_cache_path in the config"
|
||||
)
|
||||
|
||||
def _uri(self, instruments, fields, start_time, end_time, freq, disk_cache=1, **kwargs):
|
||||
def _uri(self, instruments, fields, start_time, end_time, freq, disk_cache=1, inst_processors=[], **kwargs):
|
||||
instruments, fields, freq = self.normalize_uri_args(instruments, fields, freq)
|
||||
local_cache_path = str(Path(self.local_cache_path).expanduser().resolve())
|
||||
return hash_args(instruments, fields, start_time, end_time, freq, disk_cache, local_cache_path)
|
||||
return hash_args(
|
||||
instruments, fields, start_time, end_time, freq, disk_cache, str(self.local_cache_path), inst_processors
|
||||
)
|
||||
|
||||
def _dataset(self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1):
|
||||
def _dataset(
|
||||
self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1, inst_processors=[]
|
||||
):
|
||||
if disk_cache == 0:
|
||||
# In this case, data_set cache is configured but will not be used.
|
||||
return self.provider.dataset(instruments, fields, start_time, end_time, freq)
|
||||
os.makedirs(os.path.expanduser(self.local_cache_path), exist_ok=True)
|
||||
cache_file = os.path.join(
|
||||
self.local_cache_path, self._uri(instruments, fields, start_time, end_time, freq, disk_cache=disk_cache)
|
||||
self.local_cache_path.mkdir(exist_ok=True, parents=True)
|
||||
cache_file = self.local_cache_path.joinpath(
|
||||
self._uri(
|
||||
instruments, fields, start_time, end_time, freq, disk_cache=disk_cache, inst_processors=inst_processors
|
||||
)
|
||||
)
|
||||
gen_flag = False
|
||||
|
||||
if os.path.exists(cache_file):
|
||||
if cache_file.exists():
|
||||
if disk_cache == 1:
|
||||
# use cache
|
||||
df = pd.read_pickle(cache_file)
|
||||
@@ -1037,7 +1108,9 @@ class SimpleDatasetCache(DatasetCache):
|
||||
gen_flag = True
|
||||
|
||||
if gen_flag:
|
||||
data = self.provider.dataset(instruments, normalize_cache_fields(fields), start_time, end_time, freq)
|
||||
data = self.provider.dataset(
|
||||
instruments, normalize_cache_fields(fields), start_time, end_time, freq, inst_processors=inst_processors
|
||||
)
|
||||
data.to_pickle(cache_file)
|
||||
return self.cache_to_origin_data(data, fields)
|
||||
|
||||
@@ -1045,26 +1118,53 @@ class SimpleDatasetCache(DatasetCache):
|
||||
class DatasetURICache(DatasetCache):
|
||||
"""Prepared cache mechanism for server."""
|
||||
|
||||
def _uri(self, instruments, fields, start_time, end_time, freq, disk_cache=1, **kwargs):
|
||||
return hash_args(*self.normalize_uri_args(instruments, fields, freq), disk_cache)
|
||||
def _uri(self, instruments, fields, start_time, end_time, freq, disk_cache=1, inst_processors=[], **kwargs):
|
||||
return hash_args(*self.normalize_uri_args(instruments, fields, freq), disk_cache, inst_processors)
|
||||
|
||||
def dataset(self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=0):
|
||||
def dataset(
|
||||
self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=0, inst_processors=[]
|
||||
):
|
||||
|
||||
if "local" in C.dataset_provider.lower():
|
||||
# use LocalDatasetProvider
|
||||
return self.provider.dataset(instruments, fields, start_time, end_time, freq)
|
||||
return self.provider.dataset(
|
||||
instruments, fields, start_time, end_time, freq, inst_processors=inst_processors
|
||||
)
|
||||
|
||||
if disk_cache == 0:
|
||||
# do not use data_set cache, load data from remote expression cache directly
|
||||
return self.provider.dataset(instruments, fields, start_time, end_time, freq, disk_cache, return_uri=False)
|
||||
|
||||
return self.provider.dataset(
|
||||
instruments,
|
||||
fields,
|
||||
start_time,
|
||||
end_time,
|
||||
freq,
|
||||
disk_cache,
|
||||
return_uri=False,
|
||||
inst_processors=inst_processors,
|
||||
)
|
||||
# FIXME: The cache after resample, when read again and intercepted with end_time, results in incomplete data date
|
||||
if inst_processors:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} does not support inst_processor. "
|
||||
f"Please use `D.features(disk_cache=0)` or `qlib.init(dataset_cache=None)`"
|
||||
)
|
||||
# use ClientDatasetProvider
|
||||
feature_uri = self._uri(instruments, fields, None, None, freq, disk_cache=disk_cache)
|
||||
feature_uri = self._uri(
|
||||
instruments, fields, None, None, freq, disk_cache=disk_cache, inst_processors=inst_processors
|
||||
)
|
||||
value, expire = MemCacheExpire.get_cache(H["f"], feature_uri)
|
||||
mnt_feature_uri = os.path.join(C.get_data_path(), C.dataset_cache_dir_name, feature_uri)
|
||||
if value is None or expire or not os.path.exists(mnt_feature_uri):
|
||||
mnt_feature_uri = C.dpm.get_data_path(freq).joinpath(C.dataset_cache_dir_name).joinpath(feature_uri)
|
||||
if value is None or expire or not mnt_feature_uri.exists():
|
||||
df, uri = self.provider.dataset(
|
||||
instruments, fields, start_time, end_time, freq, disk_cache, return_uri=True
|
||||
instruments,
|
||||
fields,
|
||||
start_time,
|
||||
end_time,
|
||||
freq,
|
||||
disk_cache,
|
||||
return_uri=True,
|
||||
inst_processors=inst_processors,
|
||||
)
|
||||
# cache uri
|
||||
MemCacheExpire.set_cache(H["f"], uri, uri)
|
||||
@@ -1072,7 +1172,6 @@ class DatasetURICache(DatasetCache):
|
||||
# HZ['f'][uri] = df.copy()
|
||||
get_module_logger("cache").debug(f"get feature from {C.dataset_provider}")
|
||||
else:
|
||||
mnt_feature_uri = os.path.join(C.get_data_path(), C.dataset_cache_dir_name, feature_uri)
|
||||
df = DiskDatasetCache.read_data_from_cache(mnt_feature_uri, start_time, end_time, fields)
|
||||
get_module_logger("cache").debug("get feature from uri cache")
|
||||
|
||||
|
||||
@@ -5,28 +5,34 @@
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import re
|
||||
import abc
|
||||
import copy
|
||||
import time
|
||||
import queue
|
||||
import bisect
|
||||
import logging
|
||||
import importlib
|
||||
import traceback
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from multiprocessing import Pool
|
||||
from typing import Iterable, Union
|
||||
|
||||
from .cache import H
|
||||
from ..config import C
|
||||
from .ops import Operators
|
||||
from ..log import get_module_logger
|
||||
from ..utils import parse_field, read_bin, hash_args, normalize_cache_fields, code_to_fname
|
||||
from .base import Feature
|
||||
from .ops import Operators
|
||||
from .inst_processor import InstProcessor
|
||||
|
||||
from ..log import get_module_logger
|
||||
from .cache import DiskDatasetCache, DiskExpressionCache
|
||||
from ..utils import Wrapper, init_instance_by_config, register_wrapper, get_module_by_module_path
|
||||
from ..utils import (
|
||||
Wrapper,
|
||||
init_instance_by_config,
|
||||
register_wrapper,
|
||||
get_module_by_module_path,
|
||||
parse_field,
|
||||
hash_args,
|
||||
normalize_cache_fields,
|
||||
code_to_fname,
|
||||
)
|
||||
|
||||
|
||||
class ProviderBackendMixin:
|
||||
@@ -48,8 +54,14 @@ class ProviderBackendMixin:
|
||||
# default provider_uri map
|
||||
if "provider_uri" not in backend_kwargs:
|
||||
# if the user has no uri configured, use: uri = uri_map[freq]
|
||||
# NOTE: provider_uri priority:
|
||||
# 1. backend_config: backend_obj["kwargs"]["provider_uri"]
|
||||
# 2. backend_config: backend_obj["kwargs"]["provider_uri_map"]
|
||||
# 3. qlib.init: provider_uri
|
||||
provider_uri_map = backend_kwargs.setdefault("provider_uri_map", {})
|
||||
freq = kwargs.get("freq", "day")
|
||||
provider_uri_map = backend_kwargs.setdefault("provider_uri_map", {freq: C.get_data_path()})
|
||||
if freq not in provider_uri_map:
|
||||
provider_uri_map[freq] = C.dpm.get_data_path(freq)
|
||||
backend_kwargs["provider_uri"] = provider_uri_map[freq]
|
||||
backend.setdefault("kwargs", {}).update(**kwargs)
|
||||
return init_instance_by_config(backend)
|
||||
@@ -199,13 +211,23 @@ class InstrumentProvider(abc.ABC, ProviderBackendMixin):
|
||||
'filter_start_time': None,
|
||||
'filter_end_time': None}]}
|
||||
"""
|
||||
from .filter import SeriesDFilter
|
||||
|
||||
if filter_pipe is None:
|
||||
filter_pipe = []
|
||||
config = {"market": market, "filter_pipe": []}
|
||||
# the order of the filters will affect the result, so we need to keep
|
||||
# the order
|
||||
for filter_t in filter_pipe:
|
||||
config["filter_pipe"].append(filter_t.to_config())
|
||||
if isinstance(filter_t, dict):
|
||||
_config = filter_t
|
||||
elif isinstance(filter_t, SeriesDFilter):
|
||||
_config = filter_t.to_config()
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Unsupported filter types: {type(filter_t)}! Filter only supports dict or isinstance(filter, SeriesDFilter)"
|
||||
)
|
||||
config["filter_pipe"].append(_config)
|
||||
return config
|
||||
|
||||
@abc.abstractmethod
|
||||
@@ -341,7 +363,7 @@ class DatasetProvider(abc.ABC):
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def dataset(self, instruments, fields, start_time=None, end_time=None, freq="day"):
|
||||
def dataset(self, instruments, fields, start_time=None, end_time=None, freq="day", inst_processors=[]):
|
||||
"""Get dataset data.
|
||||
|
||||
Parameters
|
||||
@@ -356,6 +378,8 @@ class DatasetProvider(abc.ABC):
|
||||
end of the time range.
|
||||
freq : str
|
||||
time frequency.
|
||||
inst_processors: Iterable[Union[dict, InstProcessor]]
|
||||
the operations performed on each instrument
|
||||
|
||||
Returns
|
||||
----------
|
||||
@@ -372,6 +396,7 @@ class DatasetProvider(abc.ABC):
|
||||
end_time=None,
|
||||
freq="day",
|
||||
disk_cache=1,
|
||||
inst_processors=[],
|
||||
**kwargs,
|
||||
):
|
||||
"""Get task uri, used when generating rabbitmq task in qlib_server
|
||||
@@ -392,7 +417,8 @@ class DatasetProvider(abc.ABC):
|
||||
whether to skip(0)/use(1)/replace(2) disk_cache.
|
||||
|
||||
"""
|
||||
return DiskDatasetCache._uri(instruments, fields, start_time, end_time, freq, disk_cache)
|
||||
# TODO: qlib-server support inst_processors
|
||||
return DiskDatasetCache._uri(instruments, fields, start_time, end_time, freq, disk_cache, inst_processors)
|
||||
|
||||
@staticmethod
|
||||
def get_instruments_d(instruments, freq):
|
||||
@@ -433,7 +459,7 @@ class DatasetProvider(abc.ABC):
|
||||
return [ExpressionD.get_expression_instance(f) for f in fields]
|
||||
|
||||
@staticmethod
|
||||
def dataset_processor(instruments_d, column_names, start_time, end_time, freq):
|
||||
def dataset_processor(instruments_d, column_names, start_time, end_time, freq, inst_processors=[]):
|
||||
"""
|
||||
Load and process the data, return the data set.
|
||||
- default using multi-kernel method.
|
||||
@@ -459,6 +485,7 @@ class DatasetProvider(abc.ABC):
|
||||
normalize_column_names,
|
||||
spans,
|
||||
C,
|
||||
inst_processors,
|
||||
),
|
||||
)
|
||||
else:
|
||||
@@ -473,6 +500,7 @@ class DatasetProvider(abc.ABC):
|
||||
normalize_column_names,
|
||||
None,
|
||||
C,
|
||||
inst_processors,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -494,7 +522,9 @@ class DatasetProvider(abc.ABC):
|
||||
return data
|
||||
|
||||
@staticmethod
|
||||
def expression_calculator(inst, start_time, end_time, freq, column_names, spans=None, g_config=None):
|
||||
def expression_calculator(
|
||||
inst, start_time, end_time, freq, column_names, spans=None, g_config=None, inst_processors=[]
|
||||
):
|
||||
"""
|
||||
Calculate the expressions for one instrument, return a df result.
|
||||
If the expression has been calculated before, load from cache.
|
||||
@@ -518,13 +548,17 @@ class DatasetProvider(abc.ABC):
|
||||
data.index = _calendar[data.index.values.astype(int)]
|
||||
data.index.names = ["datetime"]
|
||||
|
||||
if spans is None:
|
||||
return data
|
||||
else:
|
||||
if spans is not None:
|
||||
mask = np.zeros(len(data), dtype=bool)
|
||||
for begin, end in spans:
|
||||
mask |= (data.index >= begin) & (data.index <= end)
|
||||
return data[mask]
|
||||
data = data[mask]
|
||||
|
||||
for _processor in inst_processors:
|
||||
if _processor:
|
||||
_processor_obj = init_instance_by_config(_processor, accept_types=InstProcessor)
|
||||
data = _processor_obj(data)
|
||||
return data
|
||||
|
||||
|
||||
class LocalCalendarProvider(CalendarProvider):
|
||||
@@ -537,11 +571,6 @@ class LocalCalendarProvider(CalendarProvider):
|
||||
super(LocalCalendarProvider, self).__init__(**kwargs)
|
||||
self.remote = kwargs.get("remote", False)
|
||||
|
||||
@property
|
||||
def _uri_cal(self):
|
||||
"""Calendar file uri."""
|
||||
return os.path.join(C.get_data_path(), "calendars", "{}.txt")
|
||||
|
||||
def load_calendar(self, freq, future):
|
||||
"""Load original calendar timestamp from file.
|
||||
|
||||
@@ -601,11 +630,6 @@ class LocalInstrumentProvider(InstrumentProvider):
|
||||
Provide instrument data from local data source.
|
||||
"""
|
||||
|
||||
@property
|
||||
def _uri_inst(self):
|
||||
"""Instrument file uri."""
|
||||
return os.path.join(C.get_data_path(), "instruments", "{}.txt")
|
||||
|
||||
def _load_instruments(self, market, freq):
|
||||
return self.backend_obj(market=market, freq=freq).data
|
||||
|
||||
@@ -654,14 +678,9 @@ class LocalFeatureProvider(FeatureProvider):
|
||||
super(LocalFeatureProvider, self).__init__(**kwargs)
|
||||
self.remote = kwargs.get("remote", False)
|
||||
|
||||
@property
|
||||
def _uri_data(self):
|
||||
"""Static feature file uri."""
|
||||
return os.path.join(C.get_data_path(), "features", "{}", "{}.{}.bin")
|
||||
|
||||
def feature(self, instrument, field, start_index, end_index, freq):
|
||||
# validate
|
||||
field = str(field).lower()[1:]
|
||||
field = str(field)[1:]
|
||||
instrument = code_to_fname(instrument)
|
||||
return self.backend_obj(instrument=instrument, field=field, freq=freq)[start_index : end_index + 1]
|
||||
|
||||
@@ -703,7 +722,15 @@ class LocalDatasetProvider(DatasetProvider):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def dataset(self, instruments, fields, start_time=None, end_time=None, freq="day"):
|
||||
def dataset(
|
||||
self,
|
||||
instruments,
|
||||
fields,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
freq="day",
|
||||
inst_processors=[],
|
||||
):
|
||||
instruments_d = self.get_instruments_d(instruments, freq)
|
||||
column_names = self.get_column_names(fields)
|
||||
cal = Cal.calendar(start_time, end_time, freq)
|
||||
@@ -712,7 +739,9 @@ class LocalDatasetProvider(DatasetProvider):
|
||||
start_time = cal[0]
|
||||
end_time = cal[-1]
|
||||
|
||||
data = self.dataset_processor(instruments_d, column_names, start_time, end_time, freq)
|
||||
data = self.dataset_processor(
|
||||
instruments_d, column_names, start_time, end_time, freq, inst_processors=inst_processors
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
@@ -855,6 +884,7 @@ class ClientDatasetProvider(DatasetProvider):
|
||||
freq="day",
|
||||
disk_cache=0,
|
||||
return_uri=False,
|
||||
inst_processors=[],
|
||||
):
|
||||
if Inst.get_inst_type(instruments) == Inst.DICT:
|
||||
get_module_logger("data").warning(
|
||||
@@ -894,7 +924,7 @@ class ClientDatasetProvider(DatasetProvider):
|
||||
start_time = cal[0]
|
||||
end_time = cal[-1]
|
||||
|
||||
data = self.dataset_processor(instruments_d, column_names, start_time, end_time, freq)
|
||||
data = self.dataset_processor(instruments_d, column_names, start_time, end_time, freq, inst_processors)
|
||||
if return_uri:
|
||||
return data, feature_uri
|
||||
else:
|
||||
@@ -907,6 +937,13 @@ class ClientDatasetProvider(DatasetProvider):
|
||||
- using single-process implementation.
|
||||
|
||||
"""
|
||||
# TODO: support inst_processors, need to change the code of qlib-server at the same time
|
||||
# FIXME: The cache after resample, when read again and intercepted with end_time, results in incomplete data date
|
||||
if inst_processors:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} does not support inst_processor. "
|
||||
f"Please use `D.features(disk_cache=0)` or `qlib.init(dataset_cache=None)`"
|
||||
)
|
||||
self.conn.send_request(
|
||||
request_type="feature",
|
||||
request_content={
|
||||
@@ -926,7 +963,7 @@ class ClientDatasetProvider(DatasetProvider):
|
||||
get_module_logger("data").debug("get result")
|
||||
try:
|
||||
# pre-mound nfs, used for demo
|
||||
mnt_feature_uri = os.path.join(C.get_data_path(), C.dataset_cache_dir_name, feature_uri)
|
||||
mnt_feature_uri = C.dpm.get_data_path(freq).joinpath(C.dataset_cache_dir_name, feature_uri)
|
||||
df = DiskDatasetCache.read_data_from_cache(mnt_feature_uri, start_time, end_time, fields)
|
||||
get_module_logger("data").debug("finish slicing data")
|
||||
if return_uri:
|
||||
@@ -964,6 +1001,7 @@ class BaseProvider:
|
||||
end_time=None,
|
||||
freq="day",
|
||||
disk_cache=None,
|
||||
inst_processors=[],
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
@@ -978,9 +1016,11 @@ class BaseProvider:
|
||||
disk_cache = C.default_disk_cache if disk_cache is None else disk_cache
|
||||
fields = list(fields) # In case of tuple.
|
||||
try:
|
||||
return DatasetD.dataset(instruments, fields, start_time, end_time, freq, disk_cache)
|
||||
return DatasetD.dataset(
|
||||
instruments, fields, start_time, end_time, freq, disk_cache, inst_processors=inst_processors
|
||||
)
|
||||
except TypeError:
|
||||
return DatasetD.dataset(instruments, fields, start_time, end_time, freq)
|
||||
return DatasetD.dataset(instruments, fields, start_time, end_time, freq, inst_processors=inst_processors)
|
||||
|
||||
|
||||
class LocalProvider(BaseProvider):
|
||||
@@ -1028,13 +1068,21 @@ class ClientProvider(BaseProvider):
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def is_instance_of_provider(instance: object, cls: type):
|
||||
if isinstance(instance, Wrapper):
|
||||
p = getattr(instance, "_provider", None)
|
||||
|
||||
return False if p is None else isinstance(p, cls)
|
||||
|
||||
return isinstance(instance, cls)
|
||||
|
||||
from .client import Client
|
||||
|
||||
self.client = Client(C.flask_server, C.flask_port)
|
||||
self.logger = get_module_logger(self.__class__.__name__)
|
||||
if isinstance(Cal, ClientCalendarProvider):
|
||||
if is_instance_of_provider(Cal, ClientCalendarProvider):
|
||||
Cal.set_conn(self.client)
|
||||
if isinstance(Inst, ClientInstrumentProvider):
|
||||
if is_instance_of_provider(Inst, ClientInstrumentProvider):
|
||||
Inst.set_conn(self.client)
|
||||
if hasattr(DatasetD, "provider"):
|
||||
DatasetD.provider.set_conn(self.client)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from ...utils.serial import Serializable
|
||||
from typing import Union, List, Tuple, Dict, Text, Optional
|
||||
from ...utils import init_instance_by_config, np_ffill
|
||||
from ...utils import init_instance_by_config, np_ffill, time_to_slc_point
|
||||
from ...log import get_module_logger
|
||||
from .handler import DataHandler, DataHandlerLP
|
||||
from copy import deepcopy
|
||||
@@ -243,6 +243,8 @@ class TSDataSampler:
|
||||
|
||||
It works like `torch.data.utils.Dataset`, it provides a very convenient interface for constructing time-series
|
||||
dataset based on tabular data.
|
||||
- On time step dimension, the smaller index indicates the historical data and the larger index indicates the future
|
||||
data.
|
||||
|
||||
If user have further requirements for processing data, user could process them based on `TSDataSampler` or create
|
||||
more powerful subclasses.
|
||||
@@ -309,11 +311,19 @@ class TSDataSampler:
|
||||
self.data_index = deepcopy(self.data.index)
|
||||
|
||||
if flt_data is not None:
|
||||
self.flt_data = np.array(flt_data.reindex(self.data_index)).reshape(-1)
|
||||
if isinstance(flt_data, pd.DataFrame):
|
||||
assert len(flt_data.columns) == 1
|
||||
flt_data = flt_data.iloc[:, 0]
|
||||
# NOTE: bool(np.nan) is True !!!!!!!!
|
||||
# make sure reindex comes first. Otherwise extra NaN may appear.
|
||||
flt_data = flt_data.reindex(self.data_index).fillna(False).astype(np.bool)
|
||||
self.flt_data = flt_data.values
|
||||
self.idx_map = self.flt_idx_map(self.flt_data, self.idx_map)
|
||||
self.data_index = self.data_index[np.where(self.flt_data == True)[0]]
|
||||
|
||||
self.start_idx, self.end_idx = self.data_index.slice_locs(start=pd.Timestamp(start), end=pd.Timestamp(end))
|
||||
self.start_idx, self.end_idx = self.data_index.slice_locs(
|
||||
start=time_to_slc_point(start), end=time_to_slc_point(end)
|
||||
)
|
||||
self.idx_arr = np.array(self.idx_df.values, dtype=np.float64) # for better performance
|
||||
|
||||
del self.data # save memory
|
||||
@@ -341,7 +351,7 @@ class TSDataSampler:
|
||||
setattr(self, k, v)
|
||||
|
||||
@staticmethod
|
||||
def build_index(data: pd.DataFrame) -> dict:
|
||||
def build_index(data: pd.DataFrame) -> Tuple[pd.DataFrame, dict]:
|
||||
"""
|
||||
The relation of the data
|
||||
|
||||
@@ -352,9 +362,15 @@ class TSDataSampler:
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict:
|
||||
{<index>: <prev_index or None>}
|
||||
# get the previous index of a line given index
|
||||
Tuple[pd.DataFrame, dict]:
|
||||
1) the first element: reshape the original index into a <datetime(row), instrument(column)> 2D dataframe
|
||||
instrument SH600000 SH600004 SH600006 SH600007 SH600008 SH600009 ...
|
||||
datetime
|
||||
2021-01-11 0 1 2 3 4 5 ...
|
||||
2021-01-12 4146 4147 4148 4149 4150 4151 ...
|
||||
2021-01-13 8293 8294 8295 8296 8297 8298 ...
|
||||
2021-01-14 12441 12442 12443 12444 12445 12446 ...
|
||||
2) the second element: {<original index>: <row, col>}
|
||||
"""
|
||||
# object incase of pandas converting int to flaot
|
||||
idx_df = pd.Series(range(data.shape[0]), index=data.index, dtype=object)
|
||||
@@ -491,7 +507,9 @@ class TSDatasetH(DatasetH):
|
||||
- The dimension of a batch of data <batch_idx, feature, timestep>
|
||||
"""
|
||||
|
||||
def __init__(self, step_len=30, **kwargs):
|
||||
DEFAULT_STEP_LEN = 30
|
||||
|
||||
def __init__(self, step_len=DEFAULT_STEP_LEN, **kwargs):
|
||||
self.step_len = step_len
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ from ...config import C
|
||||
from ...utils import parse_config, transform_end_date, init_instance_by_config
|
||||
from ...utils.serial import Serializable
|
||||
from .utils import fetch_df_by_index
|
||||
from ...utils import lazy_sort_index
|
||||
from pathlib import Path
|
||||
from .loader import DataLoader
|
||||
|
||||
@@ -146,7 +147,8 @@ class DataHandler(Serializable):
|
||||
# Setup data.
|
||||
# _data may be with multiple column index level. The outer level indicates the feature set name
|
||||
with TimeInspector.logt("Loading data"):
|
||||
self._data = self.data_loader.load(self.instruments, self.start_time, self.end_time)
|
||||
# make sure the fetch method is based on a index-sorted pd.DataFrame
|
||||
self._data = lazy_sort_index(self.data_loader.load(self.instruments, self.start_time, self.end_time))
|
||||
# TODO: cache
|
||||
|
||||
CS_ALL = "__all" # return all columns with single-level index column
|
||||
@@ -293,11 +295,14 @@ class DataHandlerLP(DataHandler):
|
||||
|
||||
# process type
|
||||
PTYPE_I = "independent"
|
||||
# - self._infer will be processed by infer_processors
|
||||
# - self._learn will be processed by learn_processors
|
||||
# - self._infer will be processed by shared_processors + infer_processors
|
||||
# - self._learn will be processed by shared_processors + learn_processors
|
||||
|
||||
# NOTE:
|
||||
PTYPE_A = "append"
|
||||
# - self._infer will be processed by infer_processors
|
||||
# - self._learn will be processed by infer_processors + learn_processors
|
||||
|
||||
# - self._infer will be processed by shared_processors + infer_processors
|
||||
# - self._learn will be processed by shared_processors + infer_processors + learn_processors
|
||||
# - (e.g. self._infer processed by learn_processors )
|
||||
|
||||
def __init__(
|
||||
@@ -306,8 +311,9 @@ class DataHandlerLP(DataHandler):
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
data_loader: Union[dict, str, DataLoader] = None,
|
||||
infer_processors=[],
|
||||
learn_processors=[],
|
||||
infer_processors: List = [],
|
||||
learn_processors: List = [],
|
||||
shared_processors: List = [],
|
||||
process_type=PTYPE_A,
|
||||
drop_raw=False,
|
||||
**kwargs,
|
||||
@@ -358,7 +364,8 @@ class DataHandlerLP(DataHandler):
|
||||
# Setup preprocessor
|
||||
self.infer_processors = [] # for lint
|
||||
self.learn_processors = [] # for lint
|
||||
for pname in "infer_processors", "learn_processors":
|
||||
self.shared_processors = [] # for lint
|
||||
for pname in "infer_processors", "learn_processors", "shared_processors":
|
||||
for proc in locals()[pname]:
|
||||
getattr(self, pname).append(
|
||||
init_instance_by_config(
|
||||
@@ -373,9 +380,12 @@ class DataHandlerLP(DataHandler):
|
||||
super().__init__(instruments, start_time, end_time, data_loader, **kwargs)
|
||||
|
||||
def get_all_processors(self):
|
||||
return self.infer_processors + self.learn_processors
|
||||
return self.shared_processors + self.infer_processors + self.learn_processors
|
||||
|
||||
def fit(self):
|
||||
"""
|
||||
fit data without processing the data
|
||||
"""
|
||||
for proc in self.get_all_processors():
|
||||
with TimeInspector.logt(f"{proc.__class__.__name__}"):
|
||||
proc.fit(self._data)
|
||||
@@ -388,30 +398,68 @@ class DataHandlerLP(DataHandler):
|
||||
"""
|
||||
self.process_data(with_fit=True)
|
||||
|
||||
@staticmethod
|
||||
def _run_proc_l(
|
||||
df: pd.DataFrame, proc_l: List[processor_module.Processor], with_fit: bool, check_for_infer: bool
|
||||
) -> pd.DataFrame:
|
||||
for proc in proc_l:
|
||||
if check_for_infer and not proc.is_for_infer():
|
||||
raise TypeError("Only processors usable for inference can be used in `infer_processors` ")
|
||||
with TimeInspector.logt(f"{proc.__class__.__name__}"):
|
||||
if with_fit:
|
||||
proc.fit(df)
|
||||
df = proc(df)
|
||||
return df
|
||||
|
||||
@staticmethod
|
||||
def _is_proc_readonly(proc_l: List[processor_module.Processor]):
|
||||
"""
|
||||
NOTE: it will return True if `len(proc_l) == 0`
|
||||
"""
|
||||
for p in proc_l:
|
||||
if not p.readonly():
|
||||
return False
|
||||
return True
|
||||
|
||||
def process_data(self, with_fit: bool = False):
|
||||
"""
|
||||
process_data data. Fun `processor.fit` if necessary
|
||||
|
||||
Notation: (data) [processor]
|
||||
|
||||
# data processing flow of self.process_type == DataHandlerLP.PTYPE_I
|
||||
(self._data)-[shared_processors]-(_shared_df)-[learn_processors]-(_learn_df)
|
||||
\
|
||||
-[infer_processors]-(_infer_df)
|
||||
|
||||
# data processing flow of self.process_type == DataHandlerLP.PTYPE_A
|
||||
(self._data)-[shared_processors]-(_shared_df)-[infer_processors]-(_infer_df)-[learn_processors]-(_learn_df)
|
||||
|
||||
Parameters
|
||||
----------
|
||||
with_fit : bool
|
||||
The input of the `fit` will be the output of the previous processor
|
||||
"""
|
||||
# data for inference
|
||||
_infer_df = self._data
|
||||
if len(self.infer_processors) > 0 and not self.drop_raw: # avoid modifying the original data
|
||||
_infer_df = _infer_df.copy()
|
||||
# shared data processors
|
||||
# 1) assign
|
||||
_shared_df = self._data
|
||||
if not self._is_proc_readonly(self.shared_processors): # avoid modifying the original data
|
||||
_shared_df = _shared_df.copy()
|
||||
# 2) process
|
||||
_shared_df = self._run_proc_l(_shared_df, self.shared_processors, with_fit=with_fit, check_for_infer=True)
|
||||
|
||||
# data for inference
|
||||
# 1) assign
|
||||
_infer_df = _shared_df
|
||||
if not self._is_proc_readonly(self.infer_processors): # avoid modifying the original data
|
||||
_infer_df = _infer_df.copy()
|
||||
# 2) process
|
||||
_infer_df = self._run_proc_l(_infer_df, self.infer_processors, with_fit=with_fit, check_for_infer=True)
|
||||
|
||||
for proc in self.infer_processors:
|
||||
if not proc.is_for_infer():
|
||||
raise TypeError("Only processors usable for inference can be used in `infer_processors` ")
|
||||
with TimeInspector.logt(f"{proc.__class__.__name__}"):
|
||||
if with_fit:
|
||||
proc.fit(_infer_df)
|
||||
_infer_df = proc(_infer_df)
|
||||
self._infer = _infer_df
|
||||
|
||||
# data for learning
|
||||
# 1) assign
|
||||
if self.process_type == DataHandlerLP.PTYPE_I:
|
||||
_learn_df = self._data
|
||||
elif self.process_type == DataHandlerLP.PTYPE_A:
|
||||
@@ -419,14 +467,11 @@ class DataHandlerLP(DataHandler):
|
||||
_learn_df = _infer_df
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
|
||||
if len(self.learn_processors) > 0: # avoid modifying the original data
|
||||
if not self._is_proc_readonly(self.learn_processors): # avoid modifying the original data
|
||||
_learn_df = _learn_df.copy()
|
||||
for proc in self.learn_processors:
|
||||
with TimeInspector.logt(f"{proc.__class__.__name__}"):
|
||||
if with_fit:
|
||||
proc.fit(_learn_df)
|
||||
_learn_df = proc(_learn_df)
|
||||
# 2) process
|
||||
_learn_df = self._run_proc_l(_learn_df, self.learn_processors, with_fit=with_fit, check_for_infer=False)
|
||||
|
||||
self._learn = _learn_df
|
||||
|
||||
if self.drop_raw:
|
||||
|
||||
@@ -1,18 +1,14 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
import abc
|
||||
import warnings
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from typing import Tuple, Union
|
||||
from typing import Tuple, Union, List
|
||||
|
||||
from qlib.data import D
|
||||
from qlib.data import filter as filter_module
|
||||
from qlib.data.filter import BaseDFilter
|
||||
from qlib.utils import load_dataset, init_instance_by_config
|
||||
from qlib.utils import load_dataset, init_instance_by_config, time_to_slc_point
|
||||
from qlib.log import get_module_logger
|
||||
|
||||
|
||||
@@ -62,11 +58,11 @@ class DLWParser(DataLoader):
|
||||
Extracting this class so that QlibDataLoader and other dataloaders(such as QdbDataLoader) can share the fields.
|
||||
"""
|
||||
|
||||
def __init__(self, config: Tuple[list, tuple, dict]):
|
||||
def __init__(self, config: Union[list, tuple, dict]):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
config : Tuple[list, tuple, dict]
|
||||
config : Union[list, tuple, dict]
|
||||
Config will be used to describe the fields and column names
|
||||
|
||||
.. code-block::
|
||||
@@ -88,7 +84,7 @@ class DLWParser(DataLoader):
|
||||
else:
|
||||
self.fields = self._parse_fields_info(config)
|
||||
|
||||
def _parse_fields_info(self, fields_info: Tuple[list, tuple]) -> Tuple[list, list]:
|
||||
def _parse_fields_info(self, fields_info: Union[list, tuple]) -> Tuple[list, list]:
|
||||
if len(fields_info) == 0:
|
||||
raise ValueError("The size of fields must be greater than 0")
|
||||
|
||||
@@ -104,7 +100,15 @@ class DLWParser(DataLoader):
|
||||
return exprs, names
|
||||
|
||||
@abc.abstractmethod
|
||||
def load_group_df(self, instruments, exprs: list, names: list, start_time=None, end_time=None) -> pd.DataFrame:
|
||||
def load_group_df(
|
||||
self,
|
||||
instruments,
|
||||
exprs: list,
|
||||
names: list,
|
||||
start_time: Union[str, pd.Timestamp] = None,
|
||||
end_time: Union[str, pd.Timestamp] = None,
|
||||
gp_name: str = None,
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
load the dataframe for specific group
|
||||
|
||||
@@ -128,7 +132,7 @@ class DLWParser(DataLoader):
|
||||
if self.is_group:
|
||||
df = pd.concat(
|
||||
{
|
||||
grp: self.load_group_df(instruments, exprs, names, start_time, end_time)
|
||||
grp: self.load_group_df(instruments, exprs, names, start_time, end_time, grp)
|
||||
for grp, (exprs, names) in self.fields.items()
|
||||
},
|
||||
axis=1,
|
||||
@@ -142,7 +146,14 @@ class DLWParser(DataLoader):
|
||||
class QlibDataLoader(DLWParser):
|
||||
"""Same as QlibDataLoader. The fields can be define by config"""
|
||||
|
||||
def __init__(self, config: Tuple[list, tuple, dict], filter_pipe=None, swap_level=True, freq="day"):
|
||||
def __init__(
|
||||
self,
|
||||
config: Tuple[list, tuple, dict],
|
||||
filter_pipe: List = None,
|
||||
swap_level: bool = True,
|
||||
freq: Union[str, dict] = "day",
|
||||
inst_processor: dict = None,
|
||||
):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
@@ -152,20 +163,41 @@ class QlibDataLoader(DLWParser):
|
||||
Filter pipe for the instruments
|
||||
swap_level :
|
||||
Whether to swap level of MultiIndex
|
||||
freq: dict or str
|
||||
If type(config) == dict and type(freq) == str, load config data using freq.
|
||||
If type(config) == dict and type(freq) == dict, load config[<group_name>] data using freq[<group_name>]
|
||||
inst_processor: dict
|
||||
If inst_processor is not None and type(config) == dict; load config[<group_name>] data using inst_processor[<group_name>]
|
||||
"""
|
||||
if filter_pipe is not None:
|
||||
assert isinstance(filter_pipe, list), "The type of `filter_pipe` must be list."
|
||||
filter_pipe = [
|
||||
init_instance_by_config(fp, None if "module_path" in fp else filter_module, accept_types=BaseDFilter)
|
||||
for fp in filter_pipe
|
||||
]
|
||||
|
||||
self.filter_pipe = filter_pipe
|
||||
self.swap_level = swap_level
|
||||
self.freq = freq
|
||||
|
||||
# sample
|
||||
self.inst_processor = inst_processor if inst_processor is not None else {}
|
||||
assert isinstance(self.inst_processor, dict), f"inst_processor(={self.inst_processor}) must be dict"
|
||||
|
||||
super().__init__(config)
|
||||
|
||||
def load_group_df(self, instruments, exprs: list, names: list, start_time=None, end_time=None) -> pd.DataFrame:
|
||||
if self.is_group:
|
||||
# check sample config
|
||||
if isinstance(freq, dict):
|
||||
for _gp in config.keys():
|
||||
if _gp not in freq:
|
||||
raise ValueError(f"freq(={freq}) missing group(={_gp})")
|
||||
assert (
|
||||
self.inst_processor
|
||||
), f"freq(={self.freq}), inst_processor(={self.inst_processor}) cannot be None/empty"
|
||||
|
||||
def load_group_df(
|
||||
self,
|
||||
instruments,
|
||||
exprs: list,
|
||||
names: list,
|
||||
start_time: Union[str, pd.Timestamp] = None,
|
||||
end_time: Union[str, pd.Timestamp] = None,
|
||||
gp_name: str = None,
|
||||
) -> pd.DataFrame:
|
||||
if instruments is None:
|
||||
warnings.warn("`instruments` is not set, will load all stocks")
|
||||
instruments = "all"
|
||||
@@ -174,7 +206,10 @@ class QlibDataLoader(DLWParser):
|
||||
elif self.filter_pipe is not None:
|
||||
warnings.warn("`filter_pipe` is not None, but it will not be used with `instruments` as list")
|
||||
|
||||
df = D.features(instruments, exprs, start_time, end_time, self.freq)
|
||||
freq = self.freq[gp_name] if isinstance(self.freq, dict) else self.freq
|
||||
df = D.features(
|
||||
instruments, exprs, start_time, end_time, freq=freq, inst_processors=self.inst_processor.get(gp_name, [])
|
||||
)
|
||||
df.columns = names
|
||||
if self.swap_level:
|
||||
df = df.swaplevel().sort_index() # NOTE: if swaplevel, return <datetime, instrument>
|
||||
@@ -199,6 +234,10 @@ class StaticDataLoader(DataLoader):
|
||||
self.join = join
|
||||
self._data = None
|
||||
|
||||
def __getstate__(self) -> dict:
|
||||
# avoid pickling `self._data`
|
||||
return {k: v for k, v in self.__dict__.items() if not k.startswith("_")}
|
||||
|
||||
def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame:
|
||||
self._maybe_load_raw_data()
|
||||
if instruments is None:
|
||||
@@ -207,7 +246,10 @@ class StaticDataLoader(DataLoader):
|
||||
df = self._data.loc(axis=0)[:, instruments]
|
||||
if start_time is None and end_time is None:
|
||||
return df # NOTE: avoid copy by loc
|
||||
return df.loc[pd.Timestamp(start_time) : pd.Timestamp(end_time)]
|
||||
# pd.Timestamp(None) == NaT, use NaT as index can not fetch correct thing, so do not change None.
|
||||
start_time = time_to_slc_point(start_time)
|
||||
end_time = time_to_slc_point(end_time)
|
||||
return df.loc[start_time:end_time]
|
||||
|
||||
def _maybe_load_raw_data(self):
|
||||
if self._data is not None:
|
||||
|
||||
@@ -73,6 +73,14 @@ class Processor(Serializable):
|
||||
"""
|
||||
return True
|
||||
|
||||
def readonly(self) -> bool:
|
||||
"""
|
||||
Does the processor treat the input data readonly (i.e. does not write the input data) when processsing
|
||||
|
||||
Knowning the readonly information is helpful to the Handler to avoid uncessary copy
|
||||
"""
|
||||
return False
|
||||
|
||||
def config(self, **kwargs):
|
||||
attr_list = {"fit_start_time", "fit_end_time"}
|
||||
for k, v in kwargs.items():
|
||||
@@ -92,6 +100,9 @@ class DropnaProcessor(Processor):
|
||||
def __call__(self, df):
|
||||
return df.dropna(subset=get_group_columns(df, self.fields_group))
|
||||
|
||||
def readonly(self):
|
||||
return True
|
||||
|
||||
|
||||
class DropnaLabel(DropnaProcessor):
|
||||
def __init__(self, fields_group="label"):
|
||||
@@ -113,6 +124,9 @@ class DropCol(Processor):
|
||||
mask = df.columns.isin(self.col_list)
|
||||
return df.loc[:, ~mask]
|
||||
|
||||
def readonly(self):
|
||||
return True
|
||||
|
||||
|
||||
class FilterCol(Processor):
|
||||
def __init__(self, fields_group="feature", col_list=[]):
|
||||
@@ -128,6 +142,9 @@ class FilterCol(Processor):
|
||||
mask = df.columns.get_level_values(-1).isin(self.col_list)
|
||||
return df.loc[:, mask]
|
||||
|
||||
def readonly(self):
|
||||
return True
|
||||
|
||||
|
||||
class TanhProcess(Processor):
|
||||
"""Use tanh to process noise data"""
|
||||
|
||||
23
qlib/data/inst_processor.py
Normal file
23
qlib/data/inst_processor.py
Normal file
@@ -0,0 +1,23 @@
|
||||
import abc
|
||||
import json
|
||||
import pandas as pd
|
||||
|
||||
|
||||
class InstProcessor:
|
||||
@abc.abstractmethod
|
||||
def __call__(self, df: pd.DataFrame, *args, **kwargs):
|
||||
"""
|
||||
process the data
|
||||
|
||||
NOTE: **The processor could change the content of `df` inplace !!!!! **
|
||||
User should keep a copy of data outside
|
||||
|
||||
Parameters
|
||||
----------
|
||||
df : pd.DataFrame
|
||||
The raw_df of handler or result from previous processor.
|
||||
"""
|
||||
pass
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.__class__.__name__}:{json.dumps(self.__dict__, sort_keys=True, default=str)}"
|
||||
@@ -10,10 +10,12 @@ import abc
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from typing import Union, List, Type
|
||||
from scipy.stats import percentileofscore
|
||||
|
||||
from .base import Expression, ExpressionOps
|
||||
from ..log import get_module_logger
|
||||
from ..utils import get_callable_kwargs
|
||||
|
||||
try:
|
||||
from ._libs.rolling import rolling_slope, rolling_rsquare, rolling_resi
|
||||
@@ -1495,16 +1497,34 @@ class OpsWrapper:
|
||||
def reset(self):
|
||||
self._ops = {}
|
||||
|
||||
def register(self, ops_list):
|
||||
for operator in ops_list:
|
||||
if not issubclass(operator, ExpressionOps):
|
||||
raise TypeError("operator must be subclass of ExpressionOps, not {}".format(operator))
|
||||
def register(self, ops_list: List[Union[Type[ExpressionOps], dict]]):
|
||||
"""register operator
|
||||
|
||||
if operator.__name__ in self._ops:
|
||||
Parameters
|
||||
----------
|
||||
ops_list : List[Union[Type[ExpressionOps], dict]]
|
||||
- if type(ops_list) is List[Type[ExpressionOps]], each element of ops_list represents the operator class, which should be the subclass of `ExpressionOps`.
|
||||
- if type(ops_list) is List[dict], each element of ops_list represents the config of operator, which has the following format:
|
||||
{
|
||||
"class": class_name,
|
||||
"module_path": path,
|
||||
}
|
||||
Note: `class` should be the class name of operator, `module_path` should be a python module or path of file.
|
||||
"""
|
||||
for _operator in ops_list:
|
||||
if isinstance(_operator, dict):
|
||||
_ops_class, _ = get_callable_kwargs(_operator)
|
||||
else:
|
||||
_ops_class = _operator
|
||||
|
||||
if not issubclass(_ops_class, ExpressionOps):
|
||||
raise TypeError("operator must be subclass of ExpressionOps, not {}".format(_ops_class))
|
||||
|
||||
if _ops_class.__name__ in self._ops:
|
||||
get_module_logger(self.__class__.__name__).warning(
|
||||
"The custom operator [{}] will override the qlib default definition".format(operator.__name__)
|
||||
"The custom operator [{}] will override the qlib default definition".format(_ops_class.__name__)
|
||||
)
|
||||
self._ops[operator.__name__] = operator
|
||||
self._ops[_ops_class.__name__] = _ops_class
|
||||
|
||||
def __getattr__(self, key):
|
||||
if key not in self._ops:
|
||||
|
||||
10
qlib/log.py
10
qlib/log.py
@@ -28,16 +28,18 @@ class QlibLogger(metaclass=MetaLogger):
|
||||
|
||||
def __init__(self, module_name):
|
||||
self.module_name = module_name
|
||||
self.level = 0
|
||||
# this feature name conflicts with the attribute with Logger
|
||||
# rename it to avoid some corner cases that result in comparing `str` and `int`
|
||||
self.__level = 0
|
||||
|
||||
@property
|
||||
def logger(self):
|
||||
logger = logging.getLogger(self.module_name)
|
||||
logger.setLevel(self.level)
|
||||
logger.setLevel(self.__level)
|
||||
return logger
|
||||
|
||||
def setLevel(self, level):
|
||||
self.level = level
|
||||
self.__level = level
|
||||
|
||||
def __getattr__(self, name):
|
||||
# During unpickling, python will call __getattr__. Use this line to avoid maximum recursion error.
|
||||
@@ -68,7 +70,7 @@ def get_module_logger(module_name, level: Optional[int] = None) -> logging.Logge
|
||||
|
||||
class TimeInspector:
|
||||
|
||||
timer_logger = get_module_logger("timer", level=logging.WARNING)
|
||||
timer_logger = get_module_logger("timer", level=logging.INFO)
|
||||
|
||||
time_marks = []
|
||||
|
||||
|
||||
@@ -97,7 +97,7 @@ class ModelFT(Model):
|
||||
|
||||
# Finetune model based on previous trained model
|
||||
with R.start(experiment_name="finetune model"):
|
||||
recorder = R.get_recorder(rid, experiment_name="init models")
|
||||
recorder = R.get_recorder(recorder_id=rid, experiment_name="init models")
|
||||
model = recorder.load_object("init_model")
|
||||
model.finetune(dataset, num_boost_round=10)
|
||||
|
||||
|
||||
@@ -105,6 +105,20 @@ class AverageEnsemble(Ensemble):
|
||||
"""
|
||||
|
||||
def __call__(self, ensemble_dict: dict) -> pd.DataFrame:
|
||||
"""using sample:
|
||||
from qlib.model.ens.ensemble import AverageEnsemble
|
||||
pred_res['new_key_name'] = AverageEnsemble()(predict_dict)
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ensemble_dict : dict
|
||||
Dictionary you want to ensemble
|
||||
|
||||
Returns
|
||||
-------
|
||||
pd.DataFrame
|
||||
The dictionary including ensenbling result
|
||||
"""
|
||||
# need to flatten the nested dict
|
||||
ensemble_dict = flatten_dict(ensemble_dict, sep=FLATTEN_TUPLE)
|
||||
values = list(ensemble_dict.values())
|
||||
|
||||
@@ -8,15 +8,16 @@ There are two steps in each Trainer including ``train``(make model recorder) and
|
||||
This is a concept called ``DelayTrainer``, which can be used in online simulating for parallel training.
|
||||
In ``DelayTrainer``, the first step is only to save some necessary info to model recorders, and the second step which will be finished in the end can do some concurrent and time-consuming operations such as model fitting.
|
||||
|
||||
``Qlib`` offer two kinds of Trainer, ``TrainerR`` is the simplest way and ``TrainerRM`` is based on TaskManager to help manager tasks lifecycle automatically.
|
||||
``Qlib`` offer two kinds of Trainer, ``TrainerR`` is the simplest way and ``TrainerRM`` is based on TaskManager to help manager tasks lifecycle automatically.
|
||||
"""
|
||||
|
||||
import socket
|
||||
from typing import Callable, List
|
||||
|
||||
from qlib.data.dataset import Dataset
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.model.base import Model
|
||||
from qlib.utils import flatten_dict, get_cls_kwargs, init_instance_by_config
|
||||
from qlib.utils import flatten_dict, get_callable_kwargs, init_instance_by_config
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.record_temp import SignalRecord
|
||||
from qlib.workflow.recorder import Recorder
|
||||
@@ -70,7 +71,7 @@ def end_task_train(rec: Recorder, experiment_name: str) -> Recorder:
|
||||
if isinstance(records, dict): # prevent only one dict
|
||||
records = [records]
|
||||
for record in records:
|
||||
cls, kwargs = get_cls_kwargs(record, default_module="qlib.workflow.record_temp")
|
||||
cls, kwargs = get_callable_kwargs(record, default_module="qlib.workflow.record_temp")
|
||||
if cls is SignalRecord:
|
||||
rconf = {"model": model, "dataset": dataset, "recorder": rec}
|
||||
else:
|
||||
@@ -151,6 +152,9 @@ class Trainer:
|
||||
"""
|
||||
return self.delay
|
||||
|
||||
def __call__(self, *args, **kwargs) -> list:
|
||||
return self.end_train(self.train(*args, **kwargs))
|
||||
|
||||
|
||||
class TrainerR(Trainer):
|
||||
"""
|
||||
@@ -190,6 +194,8 @@ class TrainerR(Trainer):
|
||||
Returns:
|
||||
List[Recorder]: a list of Recorders
|
||||
"""
|
||||
if isinstance(tasks, dict):
|
||||
tasks = [tasks]
|
||||
if len(tasks) == 0:
|
||||
return []
|
||||
if train_func is None:
|
||||
@@ -213,6 +219,8 @@ class TrainerR(Trainer):
|
||||
Returns:
|
||||
List[Recorder]: the same list as the param.
|
||||
"""
|
||||
if isinstance(recs, Recorder):
|
||||
recs = [recs]
|
||||
for rec in recs:
|
||||
rec.set_tags(**{self.STATUS_KEY: self.STATUS_END})
|
||||
return recs
|
||||
@@ -250,6 +258,8 @@ class DelayTrainerR(TrainerR):
|
||||
Returns:
|
||||
List[Recorder]: a list of Recorders
|
||||
"""
|
||||
if isinstance(recs, Recorder):
|
||||
recs = [recs]
|
||||
if end_train_func is None:
|
||||
end_train_func = self.end_train_func
|
||||
if experiment_name is None:
|
||||
@@ -275,7 +285,12 @@ class TrainerRM(Trainer):
|
||||
STATUS_BEGIN = "begin_task_train"
|
||||
STATUS_END = "end_task_train"
|
||||
|
||||
def __init__(self, experiment_name: str = None, task_pool: str = None, train_func=task_train):
|
||||
# This tag is the _id in TaskManager to distinguish tasks.
|
||||
TM_ID = "_id in TaskManager"
|
||||
|
||||
def __init__(
|
||||
self, experiment_name: str = None, task_pool: str = None, train_func=task_train, skip_run_task: bool = False
|
||||
):
|
||||
"""
|
||||
Init TrainerR.
|
||||
|
||||
@@ -283,11 +298,16 @@ class TrainerRM(Trainer):
|
||||
experiment_name (str): the default name of experiment.
|
||||
task_pool (str): task pool name in TaskManager. None for use same name as experiment_name.
|
||||
train_func (Callable, optional): default training method. Defaults to `task_train`.
|
||||
skip_run_task (bool):
|
||||
If skip_run_task == True:
|
||||
Only run_task in the worker. Otherwise skip run_task.
|
||||
"""
|
||||
|
||||
super().__init__()
|
||||
self.experiment_name = experiment_name
|
||||
self.task_pool = task_pool
|
||||
self.train_func = train_func
|
||||
self.skip_run_task = skip_run_task
|
||||
|
||||
def train(
|
||||
self,
|
||||
@@ -315,6 +335,8 @@ class TrainerRM(Trainer):
|
||||
Returns:
|
||||
List[Recorder]: a list of Recorders
|
||||
"""
|
||||
if isinstance(tasks, dict):
|
||||
tasks = [tasks]
|
||||
if len(tasks) == 0:
|
||||
return []
|
||||
if train_func is None:
|
||||
@@ -326,19 +348,26 @@ class TrainerRM(Trainer):
|
||||
task_pool = experiment_name
|
||||
tm = TaskManager(task_pool=task_pool)
|
||||
_id_list = tm.create_task(tasks) # all tasks will be saved to MongoDB
|
||||
run_task(
|
||||
train_func,
|
||||
task_pool,
|
||||
experiment_name=experiment_name,
|
||||
before_status=before_status,
|
||||
after_status=after_status,
|
||||
**kwargs,
|
||||
)
|
||||
query = {"_id": {"$in": _id_list}}
|
||||
if not self.skip_run_task:
|
||||
run_task(
|
||||
train_func,
|
||||
task_pool,
|
||||
query=query, # only train these tasks
|
||||
experiment_name=experiment_name,
|
||||
before_status=before_status,
|
||||
after_status=after_status,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if not self.is_delay():
|
||||
tm.wait(query=query)
|
||||
|
||||
recs = []
|
||||
for _id in _id_list:
|
||||
rec = tm.re_query(_id)["res"]
|
||||
rec.set_tags(**{self.STATUS_KEY: self.STATUS_BEGIN})
|
||||
rec.set_tags(**{self.TM_ID: _id})
|
||||
recs.append(rec)
|
||||
return recs
|
||||
|
||||
@@ -352,10 +381,33 @@ class TrainerRM(Trainer):
|
||||
Returns:
|
||||
List[Recorder]: the same list as the param.
|
||||
"""
|
||||
if isinstance(recs, Recorder):
|
||||
recs = [recs]
|
||||
for rec in recs:
|
||||
rec.set_tags(**{self.STATUS_KEY: self.STATUS_END})
|
||||
return recs
|
||||
|
||||
def worker(
|
||||
self,
|
||||
train_func: Callable = None,
|
||||
experiment_name: str = None,
|
||||
):
|
||||
"""
|
||||
The multiprocessing method for `train`. It can share a same task_pool with `train` and can run in other progress or other machines.
|
||||
|
||||
Args:
|
||||
train_func (Callable): the training method which needs at least `task`s and `experiment_name`. None for the default training method.
|
||||
experiment_name (str): the experiment name, None for use default name.
|
||||
"""
|
||||
if train_func is None:
|
||||
train_func = self.train_func
|
||||
if experiment_name is None:
|
||||
experiment_name = self.experiment_name
|
||||
task_pool = self.task_pool
|
||||
if task_pool is None:
|
||||
task_pool = experiment_name
|
||||
run_task(train_func, task_pool=task_pool, experiment_name=experiment_name)
|
||||
|
||||
|
||||
class DelayTrainerRM(TrainerRM):
|
||||
"""
|
||||
@@ -369,6 +421,7 @@ class DelayTrainerRM(TrainerRM):
|
||||
task_pool: str = None,
|
||||
train_func=begin_task_train,
|
||||
end_train_func=end_task_train,
|
||||
skip_run_task: bool = False,
|
||||
):
|
||||
"""
|
||||
Init DelayTrainerRM.
|
||||
@@ -378,10 +431,15 @@ class DelayTrainerRM(TrainerRM):
|
||||
task_pool (str): task pool name in TaskManager. None for use same name as experiment_name.
|
||||
train_func (Callable, optional): default train method. Defaults to `begin_task_train`.
|
||||
end_train_func (Callable, optional): default end_train method. Defaults to `end_task_train`.
|
||||
skip_run_task (bool):
|
||||
If skip_run_task == True:
|
||||
Only run_task in the worker. Otherwise skip run_task.
|
||||
E.g. Starting trainer on a CPU VM and then waiting tasks to be finished on GPU VMs.
|
||||
"""
|
||||
super().__init__(experiment_name, task_pool, train_func)
|
||||
self.end_train_func = end_train_func
|
||||
self.delay = True
|
||||
self.skip_run_task = skip_run_task
|
||||
|
||||
def train(self, tasks: list, train_func=None, experiment_name: str = None, **kwargs) -> List[Recorder]:
|
||||
"""
|
||||
@@ -395,6 +453,8 @@ class DelayTrainerRM(TrainerRM):
|
||||
Returns:
|
||||
List[Recorder]: a list of Recorders
|
||||
"""
|
||||
if isinstance(tasks, dict):
|
||||
tasks = [tasks]
|
||||
if len(tasks) == 0:
|
||||
return []
|
||||
return super().train(
|
||||
@@ -410,8 +470,6 @@ class DelayTrainerRM(TrainerRM):
|
||||
Given a list of Recorder and return a list of trained Recorder.
|
||||
This class will finish real data loading and model fitting.
|
||||
|
||||
NOTE: This method will train all STATUS_PART_DONE tasks in the task pool, not only the ``recs``.
|
||||
|
||||
Args:
|
||||
recs (list): a list of Recorder, the tasks have been saved to them.
|
||||
end_train_func (Callable, optional): the end_train method which need at least `recorder`s and `experiment_name`. Defaults to None for using self.end_train_func.
|
||||
@@ -421,7 +479,8 @@ class DelayTrainerRM(TrainerRM):
|
||||
Returns:
|
||||
List[Recorder]: a list of Recorders
|
||||
"""
|
||||
|
||||
if isinstance(recs, Recorder):
|
||||
recs = [recs]
|
||||
if end_train_func is None:
|
||||
end_train_func = self.end_train_func
|
||||
if experiment_name is None:
|
||||
@@ -429,18 +488,45 @@ class DelayTrainerRM(TrainerRM):
|
||||
task_pool = self.task_pool
|
||||
if task_pool is None:
|
||||
task_pool = experiment_name
|
||||
tasks = []
|
||||
_id_list = []
|
||||
for rec in recs:
|
||||
tasks.append(rec.load_object("task"))
|
||||
_id_list.append(rec.list_tags()[self.TM_ID])
|
||||
|
||||
query = {"_id": {"$in": _id_list}}
|
||||
if not self.skip_run_task:
|
||||
run_task(
|
||||
end_train_func,
|
||||
task_pool,
|
||||
query=query, # only train these tasks
|
||||
experiment_name=experiment_name,
|
||||
before_status=TaskManager.STATUS_PART_DONE,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
TaskManager(task_pool=task_pool).wait(query=query)
|
||||
|
||||
run_task(
|
||||
end_train_func,
|
||||
task_pool,
|
||||
query={"filter": {"$in": tasks}}, # only train these tasks
|
||||
experiment_name=experiment_name,
|
||||
before_status=TaskManager.STATUS_PART_DONE,
|
||||
**kwargs,
|
||||
)
|
||||
for rec in recs:
|
||||
rec.set_tags(**{self.STATUS_KEY: self.STATUS_END})
|
||||
return recs
|
||||
|
||||
def worker(self, end_train_func=None, experiment_name: str = None):
|
||||
"""
|
||||
The multiprocessing method for `end_train`. It can share a same task_pool with `end_train` and can run in other progress or other machines.
|
||||
|
||||
Args:
|
||||
end_train_func (Callable, optional): the end_train method which need at least `recorder`s and `experiment_name`. Defaults to None for using self.end_train_func.
|
||||
experiment_name (str): the experiment name, None for use default name.
|
||||
"""
|
||||
if end_train_func is None:
|
||||
end_train_func = self.end_train_func
|
||||
if experiment_name is None:
|
||||
experiment_name = self.experiment_name
|
||||
task_pool = self.task_pool
|
||||
if task_pool is None:
|
||||
task_pool = experiment_name
|
||||
run_task(
|
||||
end_train_func,
|
||||
task_pool=task_pool,
|
||||
experiment_name=experiment_name,
|
||||
before_status=TaskManager.STATUS_PART_DONE,
|
||||
)
|
||||
|
||||
@@ -43,17 +43,29 @@ RECORD_CONFIG = [
|
||||
]
|
||||
|
||||
|
||||
def get_data_handler_config(market=CSI300_MARKET):
|
||||
def get_data_handler_config(
|
||||
start_time="2008-01-01",
|
||||
end_time="2020-08-01",
|
||||
fit_start_time="2008-01-01",
|
||||
fit_end_time="2014-12-31",
|
||||
instruments=CSI300_MARKET,
|
||||
):
|
||||
return {
|
||||
"start_time": "2008-01-01",
|
||||
"end_time": "2020-08-01",
|
||||
"fit_start_time": "2008-01-01",
|
||||
"fit_end_time": "2014-12-31",
|
||||
"instruments": market,
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
"fit_start_time": fit_start_time,
|
||||
"fit_end_time": fit_end_time,
|
||||
"instruments": instruments,
|
||||
}
|
||||
|
||||
|
||||
def get_dataset_config(market=CSI300_MARKET, dataset_class=DATASET_ALPHA158_CLASS):
|
||||
def get_dataset_config(
|
||||
dataset_class=DATASET_ALPHA158_CLASS,
|
||||
train=("2008-01-01", "2014-12-31"),
|
||||
valid=("2015-01-01", "2016-12-31"),
|
||||
test=("2017-01-01", "2020-08-01"),
|
||||
handler_kwargs={"instruments": CSI300_MARKET},
|
||||
):
|
||||
return {
|
||||
"class": "DatasetH",
|
||||
"module_path": "qlib.data.dataset",
|
||||
@@ -61,48 +73,88 @@ def get_dataset_config(market=CSI300_MARKET, dataset_class=DATASET_ALPHA158_CLAS
|
||||
"handler": {
|
||||
"class": dataset_class,
|
||||
"module_path": "qlib.contrib.data.handler",
|
||||
"kwargs": get_data_handler_config(market),
|
||||
"kwargs": get_data_handler_config(**handler_kwargs),
|
||||
},
|
||||
"segments": {
|
||||
"train": ("2008-01-01", "2014-12-31"),
|
||||
"valid": ("2015-01-01", "2016-12-31"),
|
||||
"test": ("2017-01-01", "2020-08-01"),
|
||||
"train": train,
|
||||
"valid": valid,
|
||||
"test": test,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def get_gbdt_task(market=CSI300_MARKET):
|
||||
def get_gbdt_task(dataset_kwargs={}, handler_kwargs={"instruments": CSI300_MARKET}):
|
||||
return {
|
||||
"model": GBDT_MODEL,
|
||||
"dataset": get_dataset_config(market),
|
||||
"dataset": get_dataset_config(**dataset_kwargs, handler_kwargs=handler_kwargs),
|
||||
}
|
||||
|
||||
|
||||
def get_record_lgb_config(market=CSI300_MARKET):
|
||||
def get_record_lgb_config(dataset_kwargs={}, handler_kwargs={"instruments": CSI300_MARKET}):
|
||||
return {
|
||||
"model": {
|
||||
"class": "LGBModel",
|
||||
"module_path": "qlib.contrib.model.gbdt",
|
||||
},
|
||||
"dataset": get_dataset_config(market),
|
||||
"dataset": get_dataset_config(**dataset_kwargs, handler_kwargs=handler_kwargs),
|
||||
"record": RECORD_CONFIG,
|
||||
}
|
||||
|
||||
|
||||
def get_record_xgboost_config(market=CSI300_MARKET):
|
||||
def get_record_xgboost_config(dataset_kwargs={}, handler_kwargs={"instruments": CSI300_MARKET}):
|
||||
return {
|
||||
"model": {
|
||||
"class": "XGBModel",
|
||||
"module_path": "qlib.contrib.model.xgboost",
|
||||
},
|
||||
"dataset": get_dataset_config(market),
|
||||
"dataset": get_dataset_config(**dataset_kwargs, handler_kwargs=handler_kwargs),
|
||||
"record": RECORD_CONFIG,
|
||||
}
|
||||
|
||||
|
||||
CSI300_DATASET_CONFIG = get_dataset_config(market=CSI300_MARKET)
|
||||
CSI300_GBDT_TASK = get_gbdt_task(market=CSI300_MARKET)
|
||||
CSI300_DATASET_CONFIG = get_dataset_config(handler_kwargs={"instruments": CSI300_MARKET})
|
||||
CSI300_GBDT_TASK = get_gbdt_task(handler_kwargs={"instruments": CSI300_MARKET})
|
||||
|
||||
CSI100_RECORD_XGBOOST_TASK_CONFIG = get_record_xgboost_config(market=CSI100_MARKET)
|
||||
CSI100_RECORD_LGB_TASK_CONFIG = get_record_lgb_config(market=CSI100_MARKET)
|
||||
CSI100_RECORD_XGBOOST_TASK_CONFIG = get_record_xgboost_config(handler_kwargs={"instruments": CSI100_MARKET})
|
||||
CSI100_RECORD_LGB_TASK_CONFIG = get_record_lgb_config(handler_kwargs={"instruments": CSI100_MARKET})
|
||||
|
||||
# use for rolling_online_managment.py
|
||||
ROLLING_HANDLER_CONFIG = {
|
||||
"start_time": "2013-01-01",
|
||||
"end_time": "2020-09-25",
|
||||
"fit_start_time": "2013-01-01",
|
||||
"fit_end_time": "2014-12-31",
|
||||
"instruments": CSI100_MARKET,
|
||||
}
|
||||
ROLLING_DATASET_CONFIG = {
|
||||
"train": ("2013-01-01", "2014-12-31"),
|
||||
"valid": ("2015-01-01", "2015-12-31"),
|
||||
"test": ("2016-01-01", "2020-07-10"),
|
||||
}
|
||||
CSI100_RECORD_XGBOOST_TASK_CONFIG_ROLLING = get_record_xgboost_config(
|
||||
dataset_kwargs=ROLLING_DATASET_CONFIG, handler_kwargs=ROLLING_HANDLER_CONFIG
|
||||
)
|
||||
CSI100_RECORD_LGB_TASK_CONFIG_ROLLING = get_record_lgb_config(
|
||||
dataset_kwargs=ROLLING_DATASET_CONFIG, handler_kwargs=ROLLING_HANDLER_CONFIG
|
||||
)
|
||||
|
||||
# use for online_management_simulate.py
|
||||
ONLINE_HANDLER_CONFIG = {
|
||||
"start_time": "2018-01-01",
|
||||
"end_time": "2018-10-31",
|
||||
"fit_start_time": "2018-01-01",
|
||||
"fit_end_time": "2018-03-31",
|
||||
"instruments": CSI100_MARKET,
|
||||
}
|
||||
ONLINE_DATASET_CONFIG = {
|
||||
"train": ("2018-01-01", "2018-03-31"),
|
||||
"valid": ("2018-04-01", "2018-05-31"),
|
||||
"test": ("2018-06-01", "2018-09-10"),
|
||||
}
|
||||
CSI100_RECORD_XGBOOST_TASK_CONFIG_ONLINE = get_record_xgboost_config(
|
||||
dataset_kwargs=ONLINE_DATASET_CONFIG, handler_kwargs=ONLINE_HANDLER_CONFIG
|
||||
)
|
||||
CSI100_RECORD_LGB_TASK_CONFIG_ONLINE = get_record_lgb_config(
|
||||
dataset_kwargs=ONLINE_DATASET_CONFIG, handler_kwargs=ONLINE_HANDLER_CONFIG
|
||||
)
|
||||
|
||||
@@ -43,8 +43,9 @@ def get_redis_connection():
|
||||
|
||||
|
||||
#################### Data ####################
|
||||
def read_bin(file_path, start_index, end_index):
|
||||
with open(file_path, "rb") as f:
|
||||
def read_bin(file_path: Union[str, Path], start_index, end_index):
|
||||
file_path = Path(file_path.expanduser().resolve())
|
||||
with file_path.open("rb") as f:
|
||||
# read start_index
|
||||
ref_start_index = int(np.frombuffer(f.read(4), dtype="<f")[0])
|
||||
si = max(ref_start_index, start_index)
|
||||
@@ -189,9 +190,9 @@ def get_module_by_module_path(module_path: Union[str, ModuleType]):
|
||||
return module
|
||||
|
||||
|
||||
def get_cls_kwargs(config: Union[dict, str], default_module: Union[str, ModuleType] = None) -> (type, dict):
|
||||
def get_callable_kwargs(config: Union[dict, str], default_module: Union[str, ModuleType] = None) -> (type, dict):
|
||||
"""
|
||||
extract class and kwargs from config info
|
||||
extract class/func and kwargs from config info
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -206,22 +207,22 @@ def get_cls_kwargs(config: Union[dict, str], default_module: Union[str, ModuleTy
|
||||
Returns
|
||||
-------
|
||||
(type, dict):
|
||||
the class object and it's arguments.
|
||||
the class/func object and it's arguments.
|
||||
"""
|
||||
if isinstance(config, dict):
|
||||
module = get_module_by_module_path(config.get("module_path", default_module))
|
||||
|
||||
# raise AttributeError
|
||||
klass = getattr(module, config["class"])
|
||||
_callable = getattr(module, config["class" if "class" in config else "func"])
|
||||
kwargs = config.get("kwargs", {})
|
||||
elif isinstance(config, str):
|
||||
module = get_module_by_module_path(default_module)
|
||||
|
||||
klass = getattr(module, config)
|
||||
_callable = getattr(module, config)
|
||||
kwargs = {}
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
return klass, kwargs
|
||||
return _callable, kwargs
|
||||
|
||||
|
||||
def init_instance_by_config(
|
||||
@@ -272,7 +273,7 @@ def init_instance_by_config(
|
||||
with open(os.path.join(pr.netloc, pr.path), "rb") as f:
|
||||
return pickle.load(f)
|
||||
|
||||
klass, cls_kwargs = get_cls_kwargs(config, default_module=default_module)
|
||||
klass, cls_kwargs = get_callable_kwargs(config, default_module=default_module)
|
||||
return klass(**cls_kwargs, **kwargs)
|
||||
|
||||
|
||||
@@ -570,9 +571,11 @@ def get_pre_trading_date(trading_date, future=False):
|
||||
|
||||
|
||||
def transform_end_date(end_date=None, freq="day"):
|
||||
"""get previous trading date
|
||||
"""handle the end date with various format
|
||||
|
||||
If end_date is -1, None, or end_date is greater than the maximum trading day, the last trading date is returned.
|
||||
Otherwise, returns the end_date
|
||||
|
||||
----------
|
||||
end_date: str
|
||||
end trading date
|
||||
@@ -642,6 +645,28 @@ def split_pred(pred, number=None, split_date=None):
|
||||
return pred_left, pred_right
|
||||
|
||||
|
||||
def time_to_slc_point(t: Union[None, str, pd.Timestamp]) -> Union[None, pd.Timestamp]:
|
||||
"""
|
||||
Time slicing in Qlib or Pandas is a frequently-used action.
|
||||
However, user often input all kinds of data format to represent time.
|
||||
This function will help user to convert these inputs into a uniform format which is friendly to time slicing.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
t : Union[None, str, pd.Timestamp]
|
||||
original time
|
||||
|
||||
Returns
|
||||
-------
|
||||
Union[None, pd.Timestamp]:
|
||||
"""
|
||||
if t is None:
|
||||
# None represents unbounded in Qlib or Pandas(e.g. df.loc[slice(None, "20210303")]).
|
||||
return t
|
||||
else:
|
||||
return pd.Timestamp(t)
|
||||
|
||||
|
||||
def can_use_cache():
|
||||
res = True
|
||||
r = get_redis_connection()
|
||||
@@ -716,7 +741,8 @@ def lazy_sort_index(df: pd.DataFrame, axis=0) -> pd.DataFrame:
|
||||
sorted dataframe
|
||||
"""
|
||||
idx = df.index if axis == 0 else df.columns
|
||||
if idx.is_monotonic_increasing:
|
||||
# NOTE: MultiIndex.is_lexsorted() is a deprecated method in Pandas 1.3.0 and is suggested to be replaced by MultiIndex.is_monotonic_increasing (see discussion here: https://github.com/pandas-dev/pandas/issues/32259). However, in case older versions of Pandas is implemented, MultiIndex.is_lexsorted() is necessary to prevent certain fatal errors.
|
||||
if idx.is_monotonic_increasing and not (isinstance(idx, pd.MultiIndex) and not idx.is_lexsorted()):
|
||||
return df
|
||||
else:
|
||||
return df.sort_index(axis=axis)
|
||||
@@ -770,7 +796,7 @@ class Wrapper:
|
||||
return "{name}(provider={provider})".format(name=self.__class__.__name__, provider=self._provider)
|
||||
|
||||
def __getattr__(self, key):
|
||||
if self._provider is None:
|
||||
if self.__dict__.get("_provider", None) is None:
|
||||
raise AttributeError("Please run qlib.init() first using qlib")
|
||||
return getattr(self._provider, key)
|
||||
|
||||
|
||||
17
qlib/utils/exceptions.py
Normal file
17
qlib/utils/exceptions.py
Normal file
@@ -0,0 +1,17 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
# Base exception class
|
||||
class QlibException(Exception):
|
||||
def __init__(self, message):
|
||||
super(QlibException, self).__init__(message)
|
||||
|
||||
|
||||
# Error type for reinitialization when starting an experiment
|
||||
class RecorderInitializationError(QlibException):
|
||||
pass
|
||||
|
||||
|
||||
# Error type for Recorder when can not load object
|
||||
class LoadObjectError(QlibException):
|
||||
pass
|
||||
@@ -1,10 +1,9 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from pathlib import Path
|
||||
import pickle
|
||||
import typing
|
||||
import dill
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
|
||||
@@ -18,6 +17,7 @@ class Serializable:
|
||||
|
||||
pickle_backend = "pickle" # another optional value is "dill" which can pickle more things of python.
|
||||
default_dump_all = False # if dump all things
|
||||
FLAG_KEY = "_qlib_serial_flag"
|
||||
|
||||
def __init__(self):
|
||||
self._dump_all = self.default_dump_all
|
||||
@@ -45,8 +45,6 @@ class Serializable:
|
||||
"""
|
||||
return getattr(self, "_exclude", [])
|
||||
|
||||
FLAG_KEY = "_qlib_serial_flag"
|
||||
|
||||
def config(self, dump_all: bool = None, exclude: list = None, recursive=False):
|
||||
"""
|
||||
configure the serializable object
|
||||
@@ -92,16 +90,16 @@ class Serializable:
|
||||
@classmethod
|
||||
def load(cls, filepath):
|
||||
"""
|
||||
Load the collector from a filepath.
|
||||
Load the serializable class from a filepath.
|
||||
|
||||
Args:
|
||||
filepath (str): the path of file
|
||||
|
||||
Raises:
|
||||
TypeError: the pickled file must be `Collector`
|
||||
TypeError: the pickled file must be `type(cls)`
|
||||
|
||||
Returns:
|
||||
Collector: the instance of Collector
|
||||
`type(cls)`: the instance of `type(cls)`
|
||||
"""
|
||||
with open(filepath, "rb") as f:
|
||||
object = cls.get_backend().load(f)
|
||||
@@ -124,3 +122,22 @@ class Serializable:
|
||||
return dill
|
||||
else:
|
||||
raise ValueError("Unknown pickle backend, please use 'pickle' or 'dill'.")
|
||||
|
||||
@staticmethod
|
||||
def general_dump(obj, path: Union[Path, str]):
|
||||
"""
|
||||
A general dumping method for object
|
||||
|
||||
Parameters
|
||||
----------
|
||||
obj : object
|
||||
the object to be dumped
|
||||
path : Union[Path, str]
|
||||
the target path the data will be dumped
|
||||
"""
|
||||
path = Path(path)
|
||||
if isinstance(obj, Serializable):
|
||||
obj.to_pickle(path)
|
||||
else:
|
||||
with path.open("wb") as f:
|
||||
pickle.dump(obj, f)
|
||||
|
||||
@@ -7,6 +7,7 @@ from .expm import MLflowExpManager
|
||||
from .exp import Experiment
|
||||
from .recorder import Recorder
|
||||
from ..utils import Wrapper
|
||||
from ..utils.exceptions import RecorderInitializationError
|
||||
|
||||
|
||||
class QlibRecorder:
|
||||
@@ -37,13 +38,13 @@ class QlibRecorder:
|
||||
.. code-block:: Python
|
||||
|
||||
# start new experiment and recorder
|
||||
with R.start('test', 'recorder_1'):
|
||||
with R.start(experiment_name='test', recorder_name='recorder_1'):
|
||||
model.fit(dataset)
|
||||
R.log...
|
||||
... # further operations
|
||||
|
||||
# resume previous experiment and recorder
|
||||
with R.start('test', 'recorder_1', resume=True): # if users want to resume recorder, they have to specify the exact same name for experiment and recorder.
|
||||
with R.start(experiment_name='test', recorder_name='recorder_1', resume=True): # if users want to resume recorder, they have to specify the exact same name for experiment and recorder.
|
||||
... # further operations
|
||||
|
||||
Parameters
|
||||
@@ -215,9 +216,9 @@ class QlibRecorder:
|
||||
-------
|
||||
A dictionary (id -> recorder) of recorder information that being stored.
|
||||
"""
|
||||
return self.get_exp(experiment_id, experiment_name).list_recorders()
|
||||
return self.get_exp(experiment_id=experiment_id, experiment_name=experiment_name).list_recorders()
|
||||
|
||||
def get_exp(self, experiment_id=None, experiment_name=None, create: bool = True) -> Experiment:
|
||||
def get_exp(self, *, experiment_id=None, experiment_name=None, create: bool = True) -> Experiment:
|
||||
"""
|
||||
Method for retrieving an experiment with given id or name. Once the `create` argument is set to
|
||||
True, if no valid experiment is found, this method will create one for you. Otherwise, it will
|
||||
@@ -262,7 +263,7 @@ class QlibRecorder:
|
||||
|
||||
# Case 2
|
||||
with R.start('test'):
|
||||
exp = R.get_exp('test1')
|
||||
exp = R.get_exp(experiment_name='test1')
|
||||
|
||||
# Case 3
|
||||
exp = R.get_exp() -> a default experiment.
|
||||
@@ -287,7 +288,9 @@ class QlibRecorder:
|
||||
-------
|
||||
An experiment instance with given id or name.
|
||||
"""
|
||||
return self.exp_manager.get_exp(experiment_id, experiment_name, create, start=False)
|
||||
return self.exp_manager.get_exp(
|
||||
experiment_id=experiment_id, experiment_name=experiment_name, create=create, start=False
|
||||
)
|
||||
|
||||
def delete_exp(self, experiment_id=None, experiment_name=None):
|
||||
"""
|
||||
@@ -331,7 +334,9 @@ class QlibRecorder:
|
||||
"""
|
||||
self.exp_manager.set_uri(uri)
|
||||
|
||||
def get_recorder(self, recorder_id=None, recorder_name=None, experiment_name=None) -> Recorder:
|
||||
def get_recorder(
|
||||
self, *, recorder_id=None, recorder_name=None, experiment_id=None, experiment_name=None
|
||||
) -> Recorder:
|
||||
"""
|
||||
Method for retrieving a recorder.
|
||||
|
||||
@@ -384,7 +389,7 @@ class QlibRecorder:
|
||||
-------
|
||||
A recorder instance.
|
||||
"""
|
||||
return self.get_exp(experiment_name=experiment_name, create=False).get_recorder(
|
||||
return self.get_exp(experiment_name=experiment_name, experiment_id=experiment_id, create=False).get_recorder(
|
||||
recorder_id, recorder_name, create=False, start=False
|
||||
)
|
||||
|
||||
@@ -525,14 +530,29 @@ class QlibRecorder:
|
||||
self.get_exp().get_recorder().set_tags(**kwargs)
|
||||
|
||||
|
||||
class RecorderWrapper(Wrapper):
|
||||
"""
|
||||
Wrapper class for QlibRecorder, which detects whether users reinitialize qlib when already starting an experiment.
|
||||
"""
|
||||
|
||||
def register(self, provider):
|
||||
if self._provider is not None:
|
||||
expm = getattr(self._provider, "exp_manager")
|
||||
if expm.active_experiment is not None:
|
||||
raise RecorderInitializationError(
|
||||
"Please don't reinitialize Qlib if QlibRecorder is already acivated. Otherwise, the experiment stored location will be modified."
|
||||
)
|
||||
self._provider = provider
|
||||
|
||||
|
||||
import sys
|
||||
|
||||
if sys.version_info >= (3, 9):
|
||||
from typing import Annotated
|
||||
|
||||
QlibRecorderWrapper = Annotated[QlibRecorder, Wrapper]
|
||||
QlibRecorderWrapper = Annotated[QlibRecorder, RecorderWrapper]
|
||||
else:
|
||||
QlibRecorderWrapper = QlibRecorder
|
||||
|
||||
# global record
|
||||
R: QlibRecorderWrapper = Wrapper()
|
||||
R: QlibRecorderWrapper = RecorderWrapper()
|
||||
|
||||
@@ -53,7 +53,8 @@ def workflow(config_path, experiment_name="workflow", uri_folder="mlruns"):
|
||||
exp_manager["kwargs"]["uri"] = "file:" + str(Path(os.getcwd()).resolve() / uri_folder)
|
||||
qlib.init(**config.get("qlib_init"), exp_manager=exp_manager)
|
||||
|
||||
task_train(config.get("task"), experiment_name=experiment_name)
|
||||
recorder = task_train(config.get("task"), experiment_name=experiment_name)
|
||||
recorder.save_objects(config=config)
|
||||
|
||||
|
||||
# function to run worklflow by config
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from typing import Union
|
||||
import mlflow, logging
|
||||
from mlflow.entities import ViewType
|
||||
from mlflow.exceptions import MlflowException
|
||||
@@ -213,11 +214,15 @@ class Experiment:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `_get_recorder` method")
|
||||
|
||||
def list_recorders(self):
|
||||
def list_recorders(self, **flt_kwargs):
|
||||
"""
|
||||
List all the existing recorders of this experiment. Please first get the experiment instance before calling this method.
|
||||
If user want to use the method `R.list_recorders()`, please refer to the related API document in `QlibRecorder`.
|
||||
|
||||
flt_kwargs : dict
|
||||
filter recorders by conditions
|
||||
e.g. list_recorders(status=Recorder.STATUS_FI)
|
||||
|
||||
Returns
|
||||
-------
|
||||
A dictionary (id -> recorder) of recorder information that being stored.
|
||||
@@ -320,11 +325,25 @@ class MLflowExperiment(Experiment):
|
||||
|
||||
UNLIMITED = 50000 # FIXME: Mlflow can only list 50000 records at most!!!!!!!
|
||||
|
||||
def list_recorders(self, max_results=UNLIMITED):
|
||||
runs = self._client.search_runs(self.id, run_view_type=ViewType.ACTIVE_ONLY, max_results=max_results)
|
||||
def list_recorders(self, max_results: int = UNLIMITED, status: Union[str, None] = None, filter_string: str = ""):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
max_results : int
|
||||
the number limitation of the results
|
||||
status : str
|
||||
the criteria based on status to filter results.
|
||||
`None` indicates no filtering.
|
||||
filter_string : str
|
||||
mlflow supported filter string like 'params."my_param"="a" and tags."my_tag"="b"', use this will help to reduce too much run number.
|
||||
"""
|
||||
runs = self._client.search_runs(
|
||||
self.id, run_view_type=ViewType.ACTIVE_ONLY, max_results=max_results, filter_string=filter_string
|
||||
)
|
||||
recorders = dict()
|
||||
for i in range(len(runs)):
|
||||
recorder = MLflowRecorder(self.id, self._uri, mlflow_run=runs[i])
|
||||
recorders[runs[i].info.run_id] = recorder
|
||||
if status is None or recorder.status == status:
|
||||
recorders[runs[i].info.run_id] = recorder
|
||||
|
||||
return recorders
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from urllib.parse import urlparse
|
||||
import mlflow
|
||||
from filelock import FileLock
|
||||
from mlflow.exceptions import MlflowException
|
||||
from mlflow.entities import ViewType
|
||||
import os, logging
|
||||
@@ -109,7 +111,7 @@ class ExpManager:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `search_records` method.")
|
||||
|
||||
def get_exp(self, experiment_id=None, experiment_name=None, create: bool = True, start: bool = False):
|
||||
def get_exp(self, *, experiment_id=None, experiment_name=None, create: bool = True, start: bool = False):
|
||||
"""
|
||||
Retrieve an experiment. This method includes getting an active experiment, and get_or_create a specific experiment.
|
||||
|
||||
@@ -190,7 +192,14 @@ class ExpManager:
|
||||
except ValueError:
|
||||
if experiment_name is None:
|
||||
experiment_name = self._default_exp_name
|
||||
logger.info(f"No valid experiment found. Create a new experiment with name {experiment_name}.")
|
||||
logger.warning(f"No valid experiment found. Create a new experiment with name {experiment_name}.")
|
||||
|
||||
# NOTE: mlflow doesn't consider the lock for recording multiple runs
|
||||
# So we supported it in the interface wrapper
|
||||
pr = urlparse(self.uri)
|
||||
if pr.scheme == "file":
|
||||
with FileLock(os.path.join(pr.netloc, pr.path, "filelock")) as f:
|
||||
return self.create_exp(experiment_name), True
|
||||
return self.create_exp(experiment_name), True
|
||||
|
||||
def _get_exp(self, experiment_id=None, experiment_name=None) -> Experiment:
|
||||
@@ -352,6 +361,8 @@ class MLflowExpManager(ExpManager):
|
||||
), "Please input at least one of experiment/recorder id or name before retrieving experiment/recorder."
|
||||
if experiment_id is not None:
|
||||
try:
|
||||
# NOTE: the mlflow's experiment_id must be str type...
|
||||
# https://www.mlflow.org/docs/latest/python_api/mlflow.tracking.html#mlflow.tracking.MlflowClient.get_experiment
|
||||
exp = self.client.get_experiment(experiment_id)
|
||||
if exp.lifecycle_stage.upper() == "DELETED":
|
||||
raise MlflowException("No valid experiment has been found.")
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user