mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-29 00:51:19 +08:00
Compare commits
253 Commits
v0.6.0
...
high-freq-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
56edc16089 | ||
|
|
2b8462d137 | ||
|
|
1979cac50a | ||
|
|
424a48d0fb | ||
|
|
202bbea272 | ||
|
|
6a22136366 | ||
|
|
603c282415 | ||
|
|
22abe852f7 | ||
|
|
e3f463010b | ||
|
|
80aa08215f | ||
|
|
b3893067f7 | ||
|
|
e6dfccce2f | ||
|
|
f9c30f9834 | ||
|
|
f164bf8411 | ||
|
|
1f28044d84 | ||
|
|
3cf0d27a07 | ||
|
|
bcae4bb22e | ||
|
|
f680a564a0 | ||
|
|
828993b397 | ||
|
|
8ef89b4fa8 | ||
|
|
76cf9dad99 | ||
|
|
f3eb02a0bd | ||
|
|
ffa68fd010 | ||
|
|
f6dd006c35 | ||
|
|
9cd41e5a81 | ||
|
|
e23022e9d8 | ||
|
|
ebbbec2a6c | ||
|
|
13d39e6bbc | ||
|
|
b96aab6bef | ||
|
|
700eef4164 | ||
|
|
31c7d72485 | ||
|
|
30ad1967a2 | ||
|
|
0c6cad1d7b | ||
|
|
a0f22571de | ||
|
|
6835b2f67e | ||
|
|
7c4971e566 | ||
|
|
70a9d42c7d | ||
|
|
bcadf47f32 | ||
|
|
4dc14a2489 | ||
|
|
a03b08bb4c | ||
|
|
98086e4fdc | ||
|
|
8c29105bca | ||
|
|
948b829ff4 | ||
|
|
304a0c3d7a | ||
|
|
02dea2aeb6 | ||
|
|
6fc4f2b249 | ||
|
|
2a5f06ee9e | ||
|
|
7f9216dc90 | ||
|
|
263ccdfe6f | ||
|
|
1a8f1bfc57 | ||
|
|
9dc11a9e3c | ||
|
|
3bdd54308b | ||
|
|
1b569d371d | ||
|
|
36e5c601de | ||
|
|
ae45711e2b | ||
|
|
bcc47aa4cb | ||
|
|
ee94634b23 | ||
|
|
2016ebbbb2 | ||
|
|
1eaf09cce1 | ||
|
|
7579f4b4c0 | ||
|
|
1a1c45981c | ||
|
|
e4ecea55e4 | ||
|
|
58616fced9 | ||
|
|
8e9ca22b07 | ||
|
|
6a145df87c | ||
|
|
06dbd02b99 | ||
|
|
ffedb6382f | ||
|
|
3f9f295a87 | ||
|
|
84d77f4585 | ||
|
|
afdf58b4fa | ||
|
|
2b6d16feb1 | ||
|
|
0a86a6f392 | ||
|
|
5da5ad4b9f | ||
|
|
dd07810b66 | ||
|
|
a762248d98 | ||
|
|
80c9a47e51 | ||
|
|
784e73bceb | ||
|
|
5ad1b4cc33 | ||
|
|
e85646762c | ||
|
|
fc81a39317 | ||
|
|
d44c5bb2b2 | ||
|
|
c622d3f6f8 | ||
|
|
6daaa79519 | ||
|
|
3dda2cb379 | ||
|
|
4fcfde7cfb | ||
|
|
3403c00b6b | ||
|
|
ecdfe49fd1 | ||
|
|
cc214a3462 | ||
|
|
65d8af41e7 | ||
|
|
0e0970f06e | ||
|
|
917261dbf6 | ||
|
|
6a9105e065 | ||
|
|
570bb272eb | ||
|
|
0524a47cf4 | ||
|
|
9abc0b0d4f | ||
|
|
fe60e40927 | ||
|
|
740c297618 | ||
|
|
b4a088efe8 | ||
|
|
b34890772f | ||
|
|
054ffa29f6 | ||
|
|
74e08c9e37 | ||
|
|
ea96c9e22d | ||
|
|
86e7c44c6b | ||
|
|
64cf2e2df8 | ||
|
|
4361a4049a | ||
|
|
231f37376b | ||
|
|
328cdeda4a | ||
|
|
4dbc8e52ec | ||
|
|
ba447d3448 | ||
|
|
df556532d0 | ||
|
|
18e040f506 | ||
|
|
aefc98b1d7 | ||
|
|
46c8d791ac | ||
|
|
afcd91a2d0 | ||
|
|
4a30d9d1ec | ||
|
|
2da2e9bd9e | ||
|
|
3e6877ff0f | ||
|
|
a0f32036a6 | ||
|
|
d8f36df7f4 | ||
|
|
cb3b6c5bde | ||
|
|
b11712fa54 | ||
|
|
660edeb94f | ||
|
|
95de4088df | ||
|
|
e8d7a22651 | ||
|
|
4a62b929ad | ||
|
|
5efe82fb56 | ||
|
|
40bbafcaab | ||
|
|
4c4f0f3c5e | ||
|
|
ae0e0eca3d | ||
|
|
7e37fa710a | ||
|
|
e0c460c33c | ||
|
|
53f501ac19 | ||
|
|
132df027a5 | ||
|
|
7d97fd39ce | ||
|
|
995fa98fc6 | ||
|
|
824de921d1 | ||
|
|
66d9bd1a68 | ||
|
|
1c0bb2f827 | ||
|
|
ea018ed4dc | ||
|
|
f3f1867b14 | ||
|
|
8bbfd8810c | ||
|
|
3f84c3768a | ||
|
|
7372a3a598 | ||
|
|
4b4cd38ca6 | ||
|
|
7d40ba753a | ||
|
|
9b60214e0c | ||
|
|
f7e775f941 | ||
|
|
aefbf3b5f1 | ||
|
|
3f85af05e5 | ||
|
|
192c2dc5ef | ||
|
|
911edd7839 | ||
|
|
3d47dd78c8 | ||
|
|
8f6ab0af54 | ||
|
|
cb0b6fcdaa | ||
|
|
6b8824dd29 | ||
|
|
c217e7c479 | ||
|
|
ea4fe1577b | ||
|
|
1bab07e419 | ||
|
|
422d1d8c93 | ||
|
|
c8f9b1162d | ||
|
|
e2bdef7ffe | ||
|
|
e49b590322 | ||
|
|
9d19294f15 | ||
|
|
b0e7a85601 | ||
|
|
8ea45802df | ||
|
|
bba94d72dc | ||
|
|
d6dd423dc2 | ||
|
|
c10955d026 | ||
|
|
d642c7b6ea | ||
|
|
9307bcc8d1 | ||
|
|
99f3820e42 | ||
|
|
b04d2c39c8 | ||
|
|
0cdc5e125a | ||
|
|
2de812f262 | ||
|
|
16450c2876 | ||
|
|
729b57e4a7 | ||
|
|
87cc52cd05 | ||
|
|
0be57d51be | ||
|
|
9c482ebbe2 | ||
|
|
eb67f1037a | ||
|
|
59282c8965 | ||
|
|
03ab67ad5c | ||
|
|
e2d862bfb2 | ||
|
|
936d5abb1f | ||
|
|
7296780149 | ||
|
|
97c053ba73 | ||
|
|
2c5864204e | ||
|
|
6562c9aaa4 | ||
|
|
85a217c121 | ||
|
|
f156280a51 | ||
|
|
e8eb034a97 | ||
|
|
7763cf5a5c | ||
|
|
053736c0ea | ||
|
|
74ac230edb | ||
|
|
303021cd47 | ||
|
|
c0f1696adb | ||
|
|
361d168890 | ||
|
|
73669de392 | ||
|
|
89ec87e45b | ||
|
|
15cdfeb121 | ||
|
|
1bbd026195 | ||
|
|
a5c098de92 | ||
|
|
a63ba3e819 | ||
|
|
56e579e20f | ||
|
|
2873813562 | ||
|
|
a8ac56a82f | ||
|
|
6ef339b1ec | ||
|
|
579caa757c | ||
|
|
a1e579ff39 | ||
|
|
217019a640 | ||
|
|
c14404afe1 | ||
|
|
4596a7e000 | ||
|
|
ec40845513 | ||
|
|
dcfa8110e8 | ||
|
|
666e1ffcbd | ||
|
|
70fb760830 | ||
|
|
4a748525bc | ||
|
|
fb4a2e65cc | ||
|
|
71ad651514 | ||
|
|
65a9a72a88 | ||
|
|
ec0d7838ac | ||
|
|
752f17e51e | ||
|
|
8d42092a7e | ||
|
|
412c9eee2e | ||
|
|
abb90ca2f6 | ||
|
|
a7c6aea386 | ||
|
|
a88697151a | ||
|
|
d2107c9957 | ||
|
|
65902e424c | ||
|
|
bf8de72605 | ||
|
|
60f62482b7 | ||
|
|
d2d865fb7a | ||
|
|
5d5f8c8868 | ||
|
|
d093afd684 | ||
|
|
46396c229a | ||
|
|
eef90c7901 | ||
|
|
895b1e7944 | ||
|
|
2fb7774927 | ||
|
|
86b0b63771 | ||
|
|
99adc514a5 | ||
|
|
07fb9031c6 | ||
|
|
f237a344c3 | ||
|
|
2cb888c8b9 | ||
|
|
ab762b3cd7 | ||
|
|
703ae5d4aa | ||
|
|
91c3dfddf5 | ||
|
|
745b93138d | ||
|
|
7f385345bb | ||
|
|
d109d3d44e | ||
|
|
a2603fe27a | ||
|
|
e5590de2a4 | ||
|
|
77884db3a5 | ||
|
|
bb5f3cb33d |
5
.github/ISSUE_TEMPLATE/bug-report.md
vendored
5
.github/ISSUE_TEMPLATE/bug-report.md
vendored
@@ -28,7 +28,8 @@ Steps to reproduce the behavior:
|
||||
|
||||
## Environment
|
||||
|
||||
**Note**: One could run `python scripts/collect_info.py` under the `qlib` directory to get the following information.
|
||||
**Note**: User could run `cd scripts && python collect_info.py all` under project directory to get system information
|
||||
and paste them here directly.
|
||||
|
||||
- Qlib version:
|
||||
- Python version:
|
||||
@@ -37,4 +38,4 @@ Steps to reproduce the behavior:
|
||||
|
||||
## Additional Notes
|
||||
|
||||
<!-- Add any other information about the problem here. -->
|
||||
<!-- Add any other information about the problem here. -->
|
||||
|
||||
62
.github/stale.yml
vendored
Normal file
62
.github/stale.yml
vendored
Normal file
@@ -0,0 +1,62 @@
|
||||
# Configuration for probot-stale - https://github.com/probot/stale
|
||||
|
||||
# Number of days of inactivity before an Issue or Pull Request becomes stale
|
||||
daysUntilStale: 60
|
||||
|
||||
# Number of days of inactivity before an Issue or Pull Request with the stale label is closed.
|
||||
# Set to false to disable. If disabled, issues still need to be closed manually, but will remain marked as stale.
|
||||
daysUntilClose: 7
|
||||
|
||||
# Only issues or pull requests with all of these labels are check if stale. Defaults to `[]` (disabled)
|
||||
onlyLabels: []
|
||||
|
||||
# Issues or Pull Requests with these labels will never be considered stale. Set to `[]` to disable
|
||||
exemptLabels:
|
||||
- bug
|
||||
- pinned
|
||||
- security
|
||||
- "[Status] Maybe Later"
|
||||
|
||||
# Set to true to ignore issues in a project (defaults to false)
|
||||
exemptProjects: false
|
||||
|
||||
# Set to true to ignore issues in a milestone (defaults to false)
|
||||
exemptMilestones: false
|
||||
|
||||
# Set to true to ignore issues with an assignee (defaults to false)
|
||||
exemptAssignees: false
|
||||
|
||||
# Label to use when marking as stale
|
||||
staleLabel: wontfix
|
||||
|
||||
# Comment to post when marking as stale. Set to `false` to disable
|
||||
markComment: >
|
||||
This issue has been automatically marked as stale because it has not had
|
||||
recent activity. It will be closed if no further activity occurs. Thank you
|
||||
for your contributions.
|
||||
|
||||
# Comment to post when removing the stale label.
|
||||
# unmarkComment: >
|
||||
# Your comment here.
|
||||
|
||||
# Comment to post when closing a stale Issue or Pull Request.
|
||||
# closeComment: >
|
||||
# Your comment here.
|
||||
|
||||
# Limit the number of actions per hour, from 1-30. Default is 30
|
||||
limitPerRun: 30
|
||||
|
||||
# Limit to only `issues` or `pulls`
|
||||
# only: issues
|
||||
|
||||
# Optionally, specify configuration settings that are specific to just 'issues' or 'pulls':
|
||||
# pulls:
|
||||
# daysUntilStale: 30
|
||||
# markComment: >
|
||||
# This pull request has been automatically marked as stale because it has not had
|
||||
# recent activity. It will be closed if no further activity occurs. Thank you
|
||||
# for your contributions.
|
||||
|
||||
# issues:
|
||||
# exemptLabels:
|
||||
# - confirmed
|
||||
99
.github/workflows/test.yml
vendored
99
.github/workflows/test.yml
vendored
@@ -12,8 +12,8 @@ jobs:
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [windows-latest, ubuntu-16.04, ubuntu-18.04, macos-latest]
|
||||
python-version: [3.6, 3.7, 3.8]
|
||||
os: [windows-latest, ubuntu-16.04, ubuntu-18.04, ubuntu-20.04, macos-latest]
|
||||
python-version: [3.6, 3.7, 3.8, 3.9]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
@@ -23,37 +23,96 @@ jobs:
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Install dependencies
|
||||
- name: Lint with Black
|
||||
run: |
|
||||
pip install --upgrade cython
|
||||
pip install numpy jupyter jupyter_contrib_nbextensions
|
||||
python setup.py install
|
||||
cd ..
|
||||
if [ "$RUNNER_OS" == "Windows" ]; then
|
||||
$CONDA\\python.exe -m pip install black
|
||||
$CONDA\\python.exe -m black qlib -l 120 --check --diff
|
||||
else
|
||||
sudo $CONDA/bin/python -m pip install black
|
||||
$CONDA/bin/python -m black qlib -l 120 --check --diff
|
||||
fi
|
||||
shell: bash
|
||||
|
||||
# Test Qlib installed with pip
|
||||
- name: Install Qlib with pip
|
||||
run: |
|
||||
if [ "$RUNNER_OS" == "Windows" ]; then
|
||||
$CONDA\\python.exe -m pip install pyqlib --ignore-installed ruamel.yaml --user
|
||||
else
|
||||
sudo $CONDA/bin/python -m pip install pyqlib --ignore-installed ruamel.yaml
|
||||
fi
|
||||
shell: bash
|
||||
|
||||
- name: Install Lightgbm for MacOS
|
||||
if: runner.os == 'macOS'
|
||||
run: |
|
||||
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Microsoft/qlib/main/.github/brew_install.sh)"
|
||||
HOMEBREW_NO_AUTO_UPDATE=1 brew install lightgbm
|
||||
|
||||
- name: Test data downloads
|
||||
run: |
|
||||
if [ "$RUNNER_OS" == "Windows" ]; then
|
||||
$CONDA\\python.exe scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
|
||||
else
|
||||
$CONDA/bin/python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
|
||||
fi
|
||||
shell: bash
|
||||
|
||||
- name: Test workflow by config (install from pip)
|
||||
run: |
|
||||
if [ "$RUNNER_OS" == "Windows" ]; then
|
||||
$CONDA\\python.exe qlib\\workflow\\cli.py examples\\benchmarks\\LightGBM\\workflow_config_lightgbm_Alpha158.yaml
|
||||
$CONDA\\python.exe -m pip uninstall -y pyqlib
|
||||
else
|
||||
$CONDA/bin/python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
sudo $CONDA/bin/python -m pip uninstall -y pyqlib
|
||||
fi
|
||||
shell: bash
|
||||
|
||||
# Test Qlib installed from source
|
||||
- name: Install Qlib from source
|
||||
run: |
|
||||
if [ "$RUNNER_OS" == "Windows" ]; then
|
||||
$CONDA\\python.exe -m pip install --upgrade cython
|
||||
$CONDA\\python.exe -m pip install numpy jupyter jupyter_contrib_nbextensions
|
||||
$CONDA\\python.exe -m pip install -U scipy scikit-learn # installing without this line will cause errors on GitHub Actions, while instsalling locally won't
|
||||
$CONDA\\python.exe setup.py install
|
||||
else
|
||||
sudo $CONDA/bin/python -m pip install --upgrade cython
|
||||
sudo $CONDA/bin/python -m pip install numpy jupyter jupyter_contrib_nbextensions
|
||||
sudo $CONDA/bin/python -m pip install -U scipy scikit-learn # installing without this line will cause errors on GitHub Actions, while instsalling locally won't
|
||||
sudo $CONDA/bin/python setup.py install
|
||||
fi
|
||||
shell: bash
|
||||
|
||||
- name: Install test dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install black pytest
|
||||
|
||||
- name: Lint with Black
|
||||
run: |
|
||||
cd ..
|
||||
python -m black qlib -l 120 --check --diff
|
||||
if [ "$RUNNER_OS" == "Windows" ]; then
|
||||
$CONDA\\python.exe -m pip install --upgrade pip
|
||||
$CONDA\\python.exe -m pip install black pytest
|
||||
else
|
||||
sudo $CONDA/bin/python -m pip install --upgrade pip
|
||||
sudo $CONDA/bin/python -m pip install black pytest
|
||||
fi
|
||||
shell: bash
|
||||
|
||||
- name: Unit tests with Pytest
|
||||
run: |
|
||||
cd tests
|
||||
pytest . --durations=0
|
||||
if [ "$RUNNER_OS" == "Windows" ]; then
|
||||
$CONDA\\python.exe -m pytest . --durations=0
|
||||
else
|
||||
$CONDA/bin/python -m pytest . --durations=0
|
||||
fi
|
||||
shell: bash
|
||||
|
||||
- name: Test data downloads
|
||||
- name: Test workflow by config (install from source)
|
||||
run: |
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
|
||||
|
||||
- name: Test workflow by config
|
||||
run: |
|
||||
qrun examples/benchmarks/LightGBM/workflow_config_lightgbm.yaml
|
||||
if [ "$RUNNER_OS" == "Windows" ]; then
|
||||
$CONDA\\python.exe qlib\\workflow\\cli.py examples\\benchmarks\\LightGBM\\workflow_config_lightgbm_Alpha158.yaml
|
||||
else
|
||||
$CONDA/bin/python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
fi
|
||||
shell: bash
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -2,6 +2,7 @@
|
||||
__pycache__/
|
||||
|
||||
*.pyc
|
||||
*.pyd
|
||||
*.so
|
||||
*.ipynb
|
||||
.ipynb_checkpoints
|
||||
|
||||
21
.readthedocs.yml
Normal file
21
.readthedocs.yml
Normal file
@@ -0,0 +1,21 @@
|
||||
# .readthedocs.yml
|
||||
# Read the Docs configuration file
|
||||
# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
|
||||
|
||||
# Required
|
||||
version: 2
|
||||
|
||||
# Build documentation in the docs/ directory with Sphinx
|
||||
sphinx:
|
||||
configuration: docs/conf.py
|
||||
|
||||
# Build all formats
|
||||
formats: all
|
||||
|
||||
# Optionally set the version of Python and requirements required to build your docs
|
||||
python:
|
||||
version: 3.7
|
||||
install:
|
||||
- requirements: docs/requirements.txt
|
||||
- method: setuptools
|
||||
path: .
|
||||
@@ -114,7 +114,7 @@ Version 0.4.1
|
||||
Version 0.4.2
|
||||
--------------------
|
||||
- Refactor DataHandler
|
||||
- Add ``ALPHA360`` DataHandler
|
||||
- Add ``Alpha360`` DataHandler
|
||||
|
||||
|
||||
Version 0.4.3
|
||||
|
||||
102
README.md
102
README.md
@@ -31,9 +31,11 @@ For more details, please refer to our paper ["Qlib: An AI-oriented Quantitative
|
||||
- [Run a single model](#run-a-single-model)
|
||||
- [Run multiple models](#run-multiple-models)
|
||||
- [**Quant Dataset Zoo**](#quant-dataset-zoo)
|
||||
- [High-frequency execution](#high-frequency-execution)
|
||||
- [More About Qlib](#more-about-qlib)
|
||||
- [Offline Mode and Online Mode](#offline-mode-and-online-mode)
|
||||
- [Performance of Qlib Data Server](#performance-of-qlib-data-server)
|
||||
- [Related Reports](#related-reports)
|
||||
- [Contributing](#contributing)
|
||||
|
||||
|
||||
@@ -61,17 +63,36 @@ At the module level, Qlib is a platform that consists of the above components. T
|
||||
|
||||
This quick start guide tries to demonstrate
|
||||
1. It's very easy to build a complete Quant research workflow and try your ideas with _Qlib_.
|
||||
1. Though with *public data* and *simple models*, machine learning technologies **work very well** in practical Quant investment.
|
||||
2. Though with *public data* and *simple models*, machine learning technologies **work very well** in practical Quant investment.
|
||||
|
||||
Here is a quick **[demo](https://terminalizer.com/view/3f24561a4470)** shows how to install ``Qlib``, and run LightGBM with ``qrun``. **But**, please make sure you have already prepared the data following the [instruction](#data-preparation).
|
||||
|
||||
|
||||
## Installation
|
||||
|
||||
Users can easily install ``Qlib`` by pip according to the following command
|
||||
This table demonstrates the supported Python version of `Qlib`:
|
||||
| | install with pip | install from source | plot |
|
||||
| ------------- |:---------------------:|:--------------------:|:----:|
|
||||
| Python 3.6 | :heavy_check_mark: | :heavy_check_mark: (only with `Anaconda`) | :heavy_check_mark: |
|
||||
| Python 3.7 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
|
||||
| Python 3.8 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
|
||||
| Python 3.9 | :x: | :heavy_check_mark: | :x: |
|
||||
|
||||
**Note**:
|
||||
1. Please pay attention that installing cython in Python 3.6 will raise some error when installing ``Qlib`` from source. If users use Python 3.6 on their machines, it is recommended to *upgrade* Python to version 3.7 or use `conda`'s Python to install ``Qlib`` from source.
|
||||
2. For Python 3.9, `Qlib` supports running workflows such as training models, doing backtest and plot most of the related figures (those included in [notebook](examples/workflow_by_code.ipynb)). However, plotting for the *model performance* is not supported for now and we will fix this when the dependent packages are upgraded in the future.
|
||||
|
||||
### Install with pip
|
||||
Users can easily install ``Qlib`` by pip according to the following command.
|
||||
|
||||
```bash
|
||||
pip install pyqlib
|
||||
```
|
||||
|
||||
Also, users can install ``Qlib`` by the source code according to the following steps:
|
||||
**Note**: pip will install the latest stable qlib. However, the main branch of qlib is in active development. If you want to test the latest scripts or functions in the main branch. Please install qlib with the methods below.
|
||||
|
||||
### Install from source
|
||||
Also, users can install the latest dev version ``Qlib`` by the source code according to the following steps:
|
||||
|
||||
* Before installing ``Qlib`` from source, users need to install some dependencies:
|
||||
|
||||
@@ -80,13 +101,20 @@ Also, users can install ``Qlib`` by the source code according to the following s
|
||||
pip install --upgrade cython
|
||||
```
|
||||
|
||||
* Clone the repository and install ``Qlib``:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/microsoft/qlib.git && cd qlib
|
||||
python setup.py install
|
||||
```
|
||||
* Clone the repository and install ``Qlib`` as follows.
|
||||
* If you haven't installed qlib by the command ``pip install pyqlib`` before:
|
||||
```bash
|
||||
git clone https://github.com/microsoft/qlib.git && cd qlib
|
||||
python setup.py install
|
||||
```
|
||||
* If you have already installed the stable version by the command ``pip install pyqlib``:
|
||||
```bash
|
||||
git clone https://github.com/microsoft/qlib.git && cd qlib
|
||||
pip install .
|
||||
```
|
||||
**Note**: **Only** the command ``pip install .`` **can** overwrite the stable version installed by ``pip install pyqlib``, while the command ``python setup.py install`` **can't**.
|
||||
|
||||
**Tips**: If you fail to install `Qlib` or run the examples in your environment, comparing your steps and the [CI workflow](.github/workflows/test.yml) may help you find the problem.
|
||||
|
||||
## Data Preparation
|
||||
Load and prepare data by running the following code:
|
||||
@@ -130,12 +158,16 @@ Users could create the same dataset with it.
|
||||
## Auto Quant Research Workflow
|
||||
Qlib provides a tool named `qrun` to run the whole workflow automatically (including building dataset, training models, backtest and evaluation). You can start an auto quant research workflow and have a graphical reports analysis according to the following steps:
|
||||
|
||||
1. Quant Research Workflow: Run `qrun` with lightgbm workflow config ([workflow_config_lightgbm.yaml](examples/benchmarks/LightGBM/workflow_config_lightgbm.yaml)) as following.
|
||||
1. Quant Research Workflow: Run `qrun` with lightgbm workflow config ([workflow_config_lightgbm_Alpha158.yaml](examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml) as following.
|
||||
```bash
|
||||
cd examples # Avoid running program under the directory contains `qlib`
|
||||
qrun benchmarks/LightGBM/workflow_config_lightgbm.yaml
|
||||
qrun benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
```
|
||||
The result of `qrun` is as follows, please refer to please refer to [Intraday Trading](https://qlib.readthedocs.io/en/latest/component/backtest.html) for more details about the result.
|
||||
If users want to use `qrun` under debug mode, please use the following command:
|
||||
```bash
|
||||
python -m pdb qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
```
|
||||
The result of `qrun` is as follows, please refer to [Intraday Trading](https://qlib.readthedocs.io/en/latest/component/backtest.html) for more details about the result.
|
||||
|
||||
```bash
|
||||
|
||||
@@ -153,9 +185,6 @@ Qlib provides a tool named `qrun` to run the whole workflow automatically (inclu
|
||||
annualized_return 0.128982
|
||||
information_ratio 1.444287
|
||||
max_drawdown -0.091078
|
||||
|
||||
|
||||
|
||||
```
|
||||
Here are detailed documents for `qrun` and [workflow](https://qlib.readthedocs.io/en/latest/component/workflow.html).
|
||||
|
||||
@@ -193,19 +222,22 @@ The automatic workflow may not suite the research workflow of all Quant research
|
||||
# [Quant Model Zoo](examples/benchmarks)
|
||||
|
||||
Here is a list of models built on `Qlib`.
|
||||
- [GBDT based on LightGBM](qlib/contrib/model/gbdt.py)
|
||||
- [GBDT based on Catboost](qlib/contrib/model/catboost_model.py)
|
||||
- [GBDT based on XGBoost](qlib/contrib/model/xgboost.py)
|
||||
- [GBDT based on XGBoost (Tianqi Chen, et al. 2016)](qlib/contrib/model/xgboost.py)
|
||||
- [GBDT based on LightGBM (Guolin Ke, et al. 2017)](qlib/contrib/model/gbdt.py)
|
||||
- [GBDT based on Catboost (Liudmila Prokhorenkova, et al. 2017)](qlib/contrib/model/catboost_model.py)
|
||||
- [MLP based on pytorch](qlib/contrib/model/pytorch_nn.py)
|
||||
- [GRU based on pytorch](qlib/contrib/model/pytorch_gru.py)
|
||||
- [LSTM based on pytorcn](qlib/contrib/model/pytorch_lstm.py)
|
||||
- [ALSTM based on pytorcn](qlib/contrib/model/pytorch_alstm.py)
|
||||
- [GATs based on pytorch](qlib/contrib/model/pytorch_gats.py)
|
||||
- [SFM based on pytorch](qlib/contrib/model/pytorch_sfm.py)
|
||||
<!-- - [TFT based on tensorflow](examples/benchmarks/TFT/tft.py) -->
|
||||
- [LSTM based on pytorch (Sepp Hochreiter, et al. 1997)](qlib/contrib/model/pytorch_lstm.py)
|
||||
- [GRU based on pytorch (Kyunghyun Cho, et al. 2014)](qlib/contrib/model/pytorch_gru.py)
|
||||
- [ALSTM based on pytorch (Yao Qin, et al. 2017)](qlib/contrib/model/pytorch_alstm.py)
|
||||
- [GATs based on pytorch (Petar Velickovic, et al. 2017)](qlib/contrib/model/pytorch_gats.py)
|
||||
- [SFM based on pytorch (Liheng Zhang, et al. 2017)](qlib/contrib/model/pytorch_sfm.py)
|
||||
- [TFT based on tensorflow (Bryan Lim, et al. 2019)](examples/benchmarks/TFT/tft.py)
|
||||
- [TabNet based on pytorch (Sercan O. Arik, et al. 2019)](qlib/contrib/model/pytorch_tabnet.py)
|
||||
|
||||
Your PR of new Quant models is highly welcomed.
|
||||
|
||||
The performance of each model on the `Alpha158` and `Alpha360` dataset can be found [here](examples/benchmarks/README.md).
|
||||
|
||||
## Run a single model
|
||||
All the models listed above are runnable with ``Qlib``. Users can find the config files we provide and some details about the model through the [benchmarks](examples/benchmarks) folder. More information can be retrieved at the model files listed above.
|
||||
|
||||
@@ -216,9 +248,9 @@ All the models listed above are runnable with ``Qlib``. Users can find the confi
|
||||
- User can use the script [`run_all_model.py`](examples/run_all_model.py) listed in the `examples` folder to run a model. Here is an example of the specific shell command to be used: `python run_all_model.py --models=lightgbm`, where the `--models` arguments can take any number of models listed above(the available models can be found in [benchmarks](examples/benchmarks/)). For more use cases, please refer to the file's [docstrings](examples/run_all_model.py).
|
||||
|
||||
## Run multiple models
|
||||
`Qlib` also provides a script [`run_all_model.py`](examples/run_all_model.py) which can run multiple models for several iterations. (**Note**: the script only supprots *Linux* now. Other OS will be supported in the future.)
|
||||
`Qlib` also provides a script [`run_all_model.py`](examples/run_all_model.py) which can run multiple models for several iterations. (**Note**: the script only support *Linux* for now. Other OS will be supported in the future. Besides, it doesn't support parrallel running the same model for multiple times as well, and this will be fixed in the future development too.)
|
||||
|
||||
The script will create a unique virtual environment for each model, and delete the environments after training. Thus, only experiment results such as `IC` and `backtest` results will be generated and stored. (**Note**: the script will erase your previous experiment records created by running itself.)
|
||||
The script will create a unique virtual environment for each model, and delete the environments after training. Thus, only experiment results such as `IC` and `backtest` results will be generated and stored.
|
||||
|
||||
Here is an example of running all the models for 10 iterations:
|
||||
```python
|
||||
@@ -229,16 +261,24 @@ It also provides the API to run specific models at once. For more use cases, ple
|
||||
|
||||
|
||||
# Quant Dataset Zoo
|
||||
Dataset plays a very important role in Quant. Here is a list of the datasets built on `Qlib`.
|
||||
Dataset plays a very important role in Quant. Here is a list of the datasets built on `Qlib`:
|
||||
|
||||
| Dataset | US Market | China Market |
|
||||
| -- | -- | -- |
|
||||
| [Alpha360](./qlib/contrib/data/handler.py) | √ | √ |
|
||||
| [Alpha158](./qlib/contrib/data/handler.py) | √ | √ |
|
||||
| [Alpha158](./qlib/contrib/data/handler.py) | √ | √ |
|
||||
|
||||
[Here](https://qlib.readthedocs.io/en/latest/advanced/alpha.html) is a tutorial to build dataset with `Qlib`.
|
||||
Your PR to build new Quant dataset is highly welcomed.
|
||||
|
||||
# High-Frequency Execution
|
||||
High-frequency order execution is a fundamental problem in quantitative finance.
|
||||
It aims at fulfilling a specific trading order, either liquidation or acquirement, for a given instrument.
|
||||
AI has the potential to mine patterns from a huge mass of high-frequency market data and helps traders make better decisions during order execution.
|
||||
Here is a list of solutions built on `Qlib`.
|
||||
- [Universal Trading for Order Execution with Oracle Policy Distillation](examples/trade/)
|
||||
|
||||
|
||||
# More About Qlib
|
||||
The detailed documents are organized in [docs](docs/).
|
||||
[Sphinx](http://www.sphinx-doc.org) and the readthedocs theme is required to build the documentation in html formats.
|
||||
@@ -281,7 +321,11 @@ Such overheads greatly slow down the data loading process.
|
||||
Qlib data are stored in a compact format, which is efficient to be combined into arrays for scientific computation.
|
||||
|
||||
|
||||
|
||||
# Related Reports
|
||||
- [Guide To Qlib: Microsoft’s AI Investment Platform](https://analyticsindiamag.com/qlib/)
|
||||
- [【华泰金工林晓明团队】微软AI量化投资平台Qlib体验——华泰人工智能系列之四十](https://mp.weixin.qq.com/s/Brcd7im4NibJOJzZfMn6tQ)
|
||||
- [微软也搞AI量化平台?还是开源的!](https://mp.weixin.qq.com/s/47bP5YwxfTp2uTHjUBzJQQ)
|
||||
- [微矿Qlib:业内首个AI量化投资开源平台](https://mp.weixin.qq.com/s/vsJv7lsgjEi-ALYUz4CvtQ)
|
||||
|
||||
|
||||
# Contributing
|
||||
|
||||
12
docs/_static/demo.sh
vendored
Normal file
12
docs/_static/demo.sh
vendored
Normal file
@@ -0,0 +1,12 @@
|
||||
#!/bin/sh
|
||||
git clone https://github.com/microsoft/qlib.git
|
||||
cd qlib
|
||||
ls
|
||||
pip install pyqlib
|
||||
# or
|
||||
# pip install numpy
|
||||
# pip install --upgrade cython
|
||||
# python setup.py install
|
||||
cd examples
|
||||
ls
|
||||
qrun benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
@@ -50,57 +50,37 @@ Users can use ``Data Handler`` to build formulaic alphas `MACD` in qlib:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
>> from qlib.data.dataset.handler import QLibDataHandler
|
||||
>> from qlib.data.dataset.loader import QlibDataLoader
|
||||
>> MACD_EXP = '(EMA($close, 12) - EMA($close, 26))/$close - EMA((EMA($close, 12) - EMA($close, 26))/$close, 9)/$close'
|
||||
>> fields = [MACD_EXP] # MACD
|
||||
>> names = ['MACD']
|
||||
>> labels = ['$close'] # label
|
||||
>> labels = ['Ref($close, -2)/Ref($close, -1) - 1'] # label
|
||||
>> label_names = ['LABEL']
|
||||
>> data_handler = QLibDataHandler(start_date='2010-01-01', end_date='2017-12-31', fields=fields, names=names, labels=labels, label_names=label_names)
|
||||
>> TRAINER_CONFIG = {
|
||||
.. "train_start_date": "2007-01-01",
|
||||
.. "train_end_date": "2014-12-31",
|
||||
.. "validate_start_date": "2015-01-01",
|
||||
.. "validate_end_date": "2016-12-31",
|
||||
.. "test_start_date": "2017-01-01",
|
||||
.. "test_end_date": "2020-08-01",
|
||||
>> data_loader_config = {
|
||||
.. "feature": (fields, names),
|
||||
.. "label": (labels, label_names)
|
||||
.. }
|
||||
>> feature_train, label_train, feature_validate, label_validate, feature_test, label_test = data_handler.get_split_data(**TRAINER_CONFIG)
|
||||
>> print(feature_train, label_train)
|
||||
MACD
|
||||
instrument datetime
|
||||
SH600000 2010-01-04 -0.008625
|
||||
2010-01-05 -0.007234
|
||||
2010-01-06 -0.007693
|
||||
2010-01-07 -0.009633
|
||||
2010-01-08 -0.009891
|
||||
... ...
|
||||
SZ300251 2014-12-25 0.043072
|
||||
2014-12-26 0.041345
|
||||
2014-12-29 0.042733
|
||||
2014-12-30 0.042066
|
||||
2014-12-31 0.036299
|
||||
|
||||
[322025 rows x 1 columns]
|
||||
LABEL
|
||||
instrument datetime
|
||||
SH600000 2010-01-04 4.260015
|
||||
2010-01-05 4.292182
|
||||
2010-01-06 4.207747
|
||||
2010-01-07 4.113258
|
||||
2010-01-08 4.159496
|
||||
... ...
|
||||
SZ300251 2014-12-25 4.343212
|
||||
2014-12-26 4.470587
|
||||
2014-12-29 4.762474
|
||||
2014-12-30 4.369748
|
||||
2014-12-31 4.182222
|
||||
|
||||
[322025 rows x 1 columns]
|
||||
>> data_loader = QlibDataLoader(config=data_loader_config)
|
||||
>> df = data_loader.load(instruments='csi300', start_time='2010-01-01', end_time='2017-12-31')
|
||||
>> print(df)
|
||||
feature label
|
||||
MACD LABEL
|
||||
datetime instrument
|
||||
2010-01-04 SH600000 -0.011547 -0.019672
|
||||
SH600004 0.002745 -0.014721
|
||||
SH600006 0.010133 0.002911
|
||||
SH600008 -0.001113 0.009818
|
||||
SH600009 0.025878 -0.017758
|
||||
... ... ...
|
||||
2017-12-29 SZ300124 0.007306 -0.005074
|
||||
SZ300136 -0.013492 0.056352
|
||||
SZ300144 -0.000966 0.011853
|
||||
SZ300251 0.004383 0.021739
|
||||
SZ300315 -0.030557 0.012455
|
||||
|
||||
Reference
|
||||
===========
|
||||
|
||||
To learn more about ``Data Handler``, please refer to `Data Handler <../component/data.html>`_
|
||||
To learn more about ``Data Loader``, please refer to `Data Loader <../component/data.html#data-loader>`_
|
||||
|
||||
To learn more about ``Data API``, please refer to `Data API <../component/data.html>`_
|
||||
|
||||
@@ -126,17 +126,17 @@ After conversion, users can find their Qlib format data in the directory `~/.qli
|
||||
The arguments of `--include_fields` should correspond with the column names of CSV files. The columns names of dataset provided by ``Qlib`` should include open, close, high, low, volume and factor at least.
|
||||
|
||||
- `open`
|
||||
The opening price
|
||||
The adjusted opening price
|
||||
- `close`
|
||||
The closing price
|
||||
The adjusted closing price
|
||||
- `high`
|
||||
The highest price
|
||||
The adjusted highest price
|
||||
- `low`
|
||||
The lowest price
|
||||
The adjusted lowest price
|
||||
- `volume`
|
||||
The trading volume
|
||||
The adjusted trading volume
|
||||
- `factor`
|
||||
The Restoration factor
|
||||
The Restoration factor. Normally, ``factor = adjusted_price / original_price``, `adjusted price` reference: `split adjusted <https://www.investopedia.com/terms/s/splitadjusted.asp>`_
|
||||
|
||||
In the convention of `Qlib` data processing, `open, close, high, low, volume, money and factor` will be set to NaN if the stock is suspended.
|
||||
|
||||
@@ -195,6 +195,7 @@ Feature
|
||||
- `ExpressionOps`
|
||||
`ExpressionOps` will use operator for feature construction.
|
||||
To know more about ``Operator``, please refer to `Operator API <../reference/api.html#module-qlib.data.ops>`_.
|
||||
Also, ``Qlib`` supports users to define their own custom ``Operator``, an example has been given in ``tests/test_register_ops.py``.
|
||||
|
||||
To know more about ``Feature``, please refer to `Feature API <../reference/api.html#module-qlib.data.base>`_.
|
||||
|
||||
@@ -295,6 +296,7 @@ The ``Processor`` module in ``Qlib`` is designed to be learnable and it is respo
|
||||
- ``RobustZScoreNorm``: `processor` that applies robust z-score normalization.
|
||||
- ``CSZScoreNorm``: `processor` that applies cross sectional z-score normalization.
|
||||
- ``CSRankNorm``: `processor` that applies cross sectional rank normalization.
|
||||
- ``CSZFillna``: `processor` that fills N/A values in a cross sectional way by the mean of the column.
|
||||
|
||||
Users can also create their own `processor` by inheriting the base class of ``Processor``. Please refer to the implementation of all the processors for more information (`Processor Link <https://github.com/microsoft/qlib/blob/main/qlib/data/dataset/processor.py>`_).
|
||||
|
||||
|
||||
@@ -34,8 +34,9 @@ Here is a general view of the structure of the system:
|
||||
- Recorder 2
|
||||
- ...
|
||||
- ...
|
||||
This experiment management system defines a set of interface and provided a concrete implementation based on the machine learning platform: ``MLFlow`` (`link <https://mlflow.org/>`_).
|
||||
This experiment management system defines a set of interface and provided a concrete implementation ``MLflowExpManager``, which is based on the machine learning platform: ``MLFlow`` (`link <https://mlflow.org/>`_).
|
||||
|
||||
If users set the implementation of ``ExpManager`` to be ``MLflowExpManager``, they can use the command `mlflow ui` to visualize and check the experiment results. For more information, pleaes refer to the related documents `here <https://www.mlflow.org/docs/latest/cli.html#mlflow-ui>`_.
|
||||
|
||||
Qlib Recorder
|
||||
===================
|
||||
@@ -91,7 +92,7 @@ Record Template
|
||||
|
||||
The ``RecordTemp`` class is a class that enables generate experiment results such as IC and backtest in a certain format. We have provided three different `Record Template` class:
|
||||
|
||||
- ``SignalRecord``: This class generates the `preidction` results of the model.
|
||||
- ``SignalRecord``: This class generates the `prediction` results of the model.
|
||||
- ``SigAnaRecord``: This class generates the `IC`, `ICIR`, `Rank IC` and `Rank ICIR` of the model.
|
||||
- ``PortAnaRecord``: This class generates the results of `backtest`. The detailed information about `backtest` as well as the available `strategy`, users can refer to `Strategy <../component/strategy.html>`_ and `Backtest <../component/backtest.html>`_.
|
||||
|
||||
|
||||
@@ -103,6 +103,12 @@ After saving the config into `configuration.yaml`, users could start the workflo
|
||||
|
||||
qrun configuration.yaml
|
||||
|
||||
If users want to use ``qrun`` under debug mode, please use the following command:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python -m pdb qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
|
||||
.. note::
|
||||
|
||||
`qrun` will be placed in your $PATH directory when installing ``Qlib``.
|
||||
|
||||
@@ -226,3 +226,8 @@ epub_exclude_files = ["search.html"]
|
||||
|
||||
autodoc_member_order = "bysource"
|
||||
autodoc_default_flags = ["members"]
|
||||
autodoc_default_options = {
|
||||
"members": True,
|
||||
"member-order": "bysource",
|
||||
"special-members": "__init__",
|
||||
}
|
||||
|
||||
@@ -1 +1,5 @@
|
||||
Cython==0.29.21
|
||||
Cython
|
||||
cmake
|
||||
numpy
|
||||
scipy
|
||||
scikit-learn
|
||||
|
||||
@@ -63,6 +63,7 @@ Besides `provider_uri` and `region`, `qlib.init` has other parameters. The follo
|
||||
If Qlib fails to connect redis via `redis_host` and `redis_port`, cache mechanism will not be used! Please refer to `Cache <../component/data.html#cache>`_ for details.
|
||||
- `exp_manager`
|
||||
Type: dict, optional parameter, the setting of `experiment manager` to be used in qlib. Users can specify an experiment manager class, as well as the tracking URI for all the experiments. However, please be aware that we only support input of a dictionary in the following style for `exp_manager`. For more information about `exp_manager`, users can refer to `Recorder: Experiment Management <../component/recorder.html>`_.
|
||||
|
||||
.. code-block:: Python
|
||||
|
||||
# For example, if you want to set your tracking_uri to a <specific folder>, you can initialize qlib below
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# Requirements
|
||||
|
||||
Here is the minimal hardware requirements to run the example.
|
||||
Here is the minimal hardware requirements to run the `workflow_by_code` example.
|
||||
- Memory: 16G
|
||||
- Free Disk: 5G
|
||||
|
||||
|
||||
93
examples/benchmarks/ALSTM/workflow_config_alstm_Alpha158.yaml
Executable file
93
examples/benchmarks/ALSTM/workflow_config_alstm_Alpha158.yaml
Executable file
@@ -0,0 +1,93 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: FilterCol
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
col_list: ["RESI5", "WVMA5", "RSQR5", "KLEN", "RSQR10", "CORR5", "CORD5", "CORR10",
|
||||
"ROC60", "RESI10", "VSTD5", "RSQR60", "CORR60", "WVMA60", "STD5",
|
||||
"RSQR20", "CORD60", "CORD10", "CORR20", "KLOW"
|
||||
]
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: ALSTM
|
||||
module_path: qlib.contrib.model.pytorch_alstm_ts
|
||||
kwargs:
|
||||
d_feat: 20
|
||||
hidden_size: 64
|
||||
num_layers: 2
|
||||
dropout: 0.0
|
||||
n_epochs: 200
|
||||
lr: 1e-3
|
||||
early_stop: 10
|
||||
batch_size: 800
|
||||
metric: loss
|
||||
loss: mse
|
||||
n_jobs: 20
|
||||
GPU: 0
|
||||
rnn_type: GRU
|
||||
dataset:
|
||||
class: TSDatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha158
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
step_len: 20
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -54,7 +54,6 @@ task:
|
||||
batch_size: 800
|
||||
metric: loss
|
||||
loss: mse
|
||||
seed: 0
|
||||
GPU: 0
|
||||
rnn_type: GRU
|
||||
dataset:
|
||||
@@ -62,7 +61,7 @@ task:
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: ALPHA360
|
||||
class: Alpha360
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
@@ -0,0 +1,72 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors: []
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: CatBoostModel
|
||||
module_path: qlib.contrib.model.catboost_model
|
||||
kwargs:
|
||||
loss: RMSE
|
||||
learning_rate: 0.0421
|
||||
subsample: 0.8789
|
||||
max_depth: 6
|
||||
num_leaves: 100
|
||||
thread_count: 20
|
||||
grow_policy: Lossguide
|
||||
bootstrap_type: Poisson
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha360
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
92
examples/benchmarks/GATs/workflow_config_gats_Alpha158.yaml
Normal file
92
examples/benchmarks/GATs/workflow_config_gats_Alpha158.yaml
Normal file
@@ -0,0 +1,92 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: FilterCol
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
col_list: ["RESI5", "WVMA5", "RSQR5", "KLEN", "RSQR10", "CORR5", "CORD5", "CORR10",
|
||||
"ROC60", "RESI10", "VSTD5", "RSQR60", "CORR60", "WVMA60", "STD5",
|
||||
"RSQR20", "CORD60", "CORD10", "CORR20", "KLOW"
|
||||
]
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: GATs
|
||||
module_path: qlib.contrib.model.pytorch_gats_ts
|
||||
kwargs:
|
||||
d_feat: 20
|
||||
hidden_size: 64
|
||||
num_layers: 2
|
||||
dropout: 0.7
|
||||
n_epochs: 200
|
||||
lr: 1e-4
|
||||
early_stop: 10
|
||||
metric: loss
|
||||
loss: mse
|
||||
base_model: LSTM
|
||||
with_pretrain: True
|
||||
model_path: "benchmarks/LSTM/csi300_lstm_ts.pkl"
|
||||
GPU: 0
|
||||
dataset:
|
||||
class: TSDatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha158
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
step_len: 20
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -56,14 +56,13 @@ task:
|
||||
base_model: LSTM
|
||||
with_pretrain: True
|
||||
model_path: "benchmarks/LSTM/model_lstm_csi300.pkl"
|
||||
seed: 0
|
||||
GPU: 0
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: ALPHA360
|
||||
class: Alpha360
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
@@ -74,6 +73,11 @@ task:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
BIN
examples/benchmarks/GRU/csi300_gru_ts.pkl
Normal file
BIN
examples/benchmarks/GRU/csi300_gru_ts.pkl
Normal file
Binary file not shown.
92
examples/benchmarks/GRU/workflow_config_gru_Alpha158.yaml
Executable file
92
examples/benchmarks/GRU/workflow_config_gru_Alpha158.yaml
Executable file
@@ -0,0 +1,92 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: FilterCol
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
col_list: ["RESI5", "WVMA5", "RSQR5", "KLEN", "RSQR10", "CORR5", "CORD5", "CORR10",
|
||||
"ROC60", "RESI10", "VSTD5", "RSQR60", "CORR60", "WVMA60", "STD5",
|
||||
"RSQR20", "CORD60", "CORD10", "CORR20", "KLOW"
|
||||
]
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: GRU
|
||||
module_path: qlib.contrib.model.pytorch_gru_ts
|
||||
kwargs:
|
||||
d_feat: 20
|
||||
hidden_size: 64
|
||||
num_layers: 2
|
||||
dropout: 0.0
|
||||
n_epochs: 200
|
||||
lr: 2e-4
|
||||
early_stop: 10
|
||||
batch_size: 800
|
||||
metric: loss
|
||||
loss: mse
|
||||
n_jobs: 20
|
||||
GPU: 0
|
||||
dataset:
|
||||
class: TSDatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha158
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
step_len: 20
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -54,14 +54,13 @@ task:
|
||||
batch_size: 800
|
||||
metric: loss
|
||||
loss: mse
|
||||
seed: 0
|
||||
GPU: 0
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: ALPHA360
|
||||
class: Alpha360
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
BIN
examples/benchmarks/LSTM/csi300_lstm_ts.pkl
Normal file
BIN
examples/benchmarks/LSTM/csi300_lstm_ts.pkl
Normal file
Binary file not shown.
92
examples/benchmarks/LSTM/workflow_config_lstm_Alpha158.yaml
Executable file
92
examples/benchmarks/LSTM/workflow_config_lstm_Alpha158.yaml
Executable file
@@ -0,0 +1,92 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: FilterCol
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
col_list: ["RESI5", "WVMA5", "RSQR5", "KLEN", "RSQR10", "CORR5", "CORD5", "CORR10",
|
||||
"ROC60", "RESI10", "VSTD5", "RSQR60", "CORR60", "WVMA60", "STD5",
|
||||
"RSQR20", "CORD60", "CORD10", "CORR20", "KLOW"
|
||||
]
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: LSTM
|
||||
module_path: qlib.contrib.model.pytorch_lstm_ts
|
||||
kwargs:
|
||||
d_feat: 20
|
||||
hidden_size: 64
|
||||
num_layers: 2
|
||||
dropout: 0.0
|
||||
n_epochs: 200
|
||||
lr: 1e-3
|
||||
early_stop: 10
|
||||
batch_size: 800
|
||||
metric: loss
|
||||
loss: mse
|
||||
n_jobs: 20
|
||||
GPU: 0
|
||||
dataset:
|
||||
class: TSDatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha158
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
step_len: 20
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -54,14 +54,13 @@ task:
|
||||
batch_size: 800
|
||||
metric: loss
|
||||
loss: mse
|
||||
seed: 0
|
||||
GPU: 0
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: ALPHA360
|
||||
class: Alpha360
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
@@ -32,7 +32,7 @@ task:
|
||||
kwargs:
|
||||
loss: mse
|
||||
colsample_bytree: 0.8879
|
||||
learning_rate: 0.0421
|
||||
learning_rate: 0.2
|
||||
subsample: 0.8789
|
||||
lambda_l1: 205.6999
|
||||
lambda_l2: 580.9768
|
||||
@@ -0,0 +1,73 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors: []
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: LGBModel
|
||||
module_path: qlib.contrib.model.gbdt
|
||||
kwargs:
|
||||
loss: mse
|
||||
colsample_bytree: 0.8879
|
||||
learning_rate: 0.0421
|
||||
subsample: 0.8789
|
||||
lambda_l1: 205.6999
|
||||
lambda_l2: 580.9768
|
||||
max_depth: 8
|
||||
num_leaves: 210
|
||||
num_threads: 20
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha360
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -65,8 +65,9 @@ task:
|
||||
lr_decay_steps: 100
|
||||
optimizer: adam
|
||||
max_steps: 8000
|
||||
batch_size: 4096
|
||||
batch_size: 8192
|
||||
GPU: 0
|
||||
weight_decay: 0.0002
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
82
examples/benchmarks/MLP/workflow_config_mlp_Alpha360.yaml
Normal file
82
examples/benchmarks/MLP/workflow_config_mlp_Alpha360.yaml
Normal file
@@ -0,0 +1,82 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: DNNModelPytorch
|
||||
module_path: qlib.contrib.model.pytorch_nn
|
||||
kwargs:
|
||||
loss: mse
|
||||
input_dim: 360
|
||||
output_dim: 1
|
||||
lr: 0.002
|
||||
lr_decay: 0.96
|
||||
lr_decay_steps: 100
|
||||
optimizer: adam
|
||||
max_steps: 8000
|
||||
batch_size: 4096
|
||||
GPU: 0
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha360
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
35
examples/benchmarks/README.md
Normal file
35
examples/benchmarks/README.md
Normal file
@@ -0,0 +1,35 @@
|
||||
# Benchmarks Performance
|
||||
|
||||
Here are the results of each benchmark model running on Qlib's `Alpha360` and `Alpha158` dataset with China's A shared-stock & CSI300 data respectively. The values of each metric are the mean and std calculated based on 20 runs.
|
||||
|
||||
The numbers shown below demonstrate the performance of the entire `workflow` of each model. We will update the `workflow` as well as models in the near future for better results.
|
||||
|
||||
## Alpha360 dataset
|
||||
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |
|
||||
|---|---|---|---|---|---|---|---|---|
|
||||
| Linear | Alpha360 | 0.0150±0.00 | 0.1049±0.00| 0.0284±0.00 | 0.1970±0.00 | -0.0659±0.00 | -0.7072±0.00| -0.2955±0.00 |
|
||||
| CatBoost (Liudmila Prokhorenkova, et al.) | Alpha360 | 0.0397±0.00 | 0.2878±0.00| 0.0470±0.00 | 0.3703±0.00 | 0.0342±0.00 | 0.4092±0.00| -0.1057±0.00 |
|
||||
| XGBoost (Tianqi Chen, et al.) | Alpha360 | 0.0400±0.00 | 0.3031±0.00| 0.0461±0.00 | 0.3862±0.00 | 0.0528±0.00 | 0.6307±0.00| -0.1113±0.00 |
|
||||
| LightGBM (Guolin Ke, et al.) | Alpha360 | 0.0399±0.00 | 0.3075±0.00| 0.0492±0.00 | 0.4019±0.00 | 0.0323±0.00 | 0.4370±0.00| -0.0917±0.00 |
|
||||
| MLP | Alpha360 | 0.0285±0.00 | 0.1981±0.02| 0.0402±0.00 | 0.2993±0.02 | 0.0073±0.02 | 0.0880±0.22| -0.1446±0.03 |
|
||||
| GRU (Kyunghyun Cho, et al.) | Alpha360 | 0.0490±0.01 | 0.3787±0.05| 0.0581±0.00 | 0.4664±0.04 | 0.0726±0.02 | 0.9817±0.34| -0.0902±0.03 |
|
||||
| LSTM (Sepp Hochreiter, et al.) | Alpha360 | 0.0443±0.01 | 0.3401±0.05| 0.0536±0.01 | 0.4248±0.05 | 0.0627±0.03 | 0.8441±0.48| -0.0882±0.03 |
|
||||
| ALSTM (Yao Qin, et al.) | Alpha360 | 0.0493±0.01 | 0.3778±0.06| 0.0585±0.00 | 0.4606±0.04 | 0.0513±0.03 | 0.6727±0.38| -0.1085±0.02 |
|
||||
| GATs (Petar Velickovic, et al.) | Alpha360 | 0.0475±0.00 | 0.3515±0.02| 0.0592±0.00 | 0.4585±0.01 | 0.0876±0.02 | 1.1513±0.27| -0.0795±0.02 |
|
||||
|
||||
## Alpha158 dataset
|
||||
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |
|
||||
|---|---|---|---|---|---|---|---|---|
|
||||
| Linear | Alpha158 | 0.0393±0.00 | 0.2980±0.00| 0.0475±0.00 | 0.3546±0.00 | 0.0795±0.00 | 1.0712±0.00| -0.1449±0.00 |
|
||||
| CatBoost (Liudmila Prokhorenkova, et al.) | Alpha158 | 0.0503±0.00 | 0.3586±0.00| 0.0483±0.00 | 0.3667±0.00 | 0.1080±0.00 | 1.1561±0.00| -0.0787±0.00 |
|
||||
| XGBoost (Tianqi Chen, et al.) | Alpha158 | 0.0481±0.00 | 0.3659±0.00| 0.0495±0.00 | 0.4033±0.00 | 0.1111±0.00 | 1.2915±0.00| -0.0893±0.00 |
|
||||
| LightGBM (Guolin Ke, et al.) | Alpha158 | 0.0475±0.00 | 0.3979±0.00| 0.0485±0.00 | 0.4123±0.00 | 0.1143±0.00 | 1.2744±0.00| -0.0800±0.00 |
|
||||
| MLP | Alpha158 | 0.0358±0.00 | 0.2738±0.03| 0.0425±0.00 | 0.3221±0.01 | 0.0836±0.02 | 1.0323±0.25| -0.1127±0.02 |
|
||||
| TabNet with pretrain (Sercan O. Arikm et al) | Alpha158 | 0.0344±0.00|0.205±0.11|0.0398±0.00 |0.3479±0.01|0.0827±0.02|1.1141±0.32 |-0.0925±0.02 |
|
||||
| TFT (Bryan Lim, et al.) | Alpha158 (with selected 20 features) | 0.0343±0.00 | 0.2071±0.02| 0.0107±0.00 | 0.0660±0.02 | 0.0623±0.02 | 0.5818±0.20| -0.1762±0.01 |
|
||||
| GRU (Kyunghyun Cho, et al.) | Alpha158 (with selected 20 features) | 0.0311±0.00 | 0.2418±0.04| 0.0425±0.00 | 0.3434±0.02 | 0.0330±0.02 | 0.4805±0.30| -0.1021±0.02 |
|
||||
| LSTM (Sepp Hochreiter, et al.) | Alpha158 (with selected 20 features) | 0.0312±0.00 | 0.2394±0.04| 0.0418±0.00 | 0.3324±0.03 | 0.0298±0.02 | 0.4198±0.33| -0.1348±0.03 |
|
||||
| ALSTM (Yao Qin, et al.) | Alpha158 (with selected 20 features) | 0.0385±0.01 | 0.3022±0.06| 0.0478±0.00 | 0.3874±0.04 | 0.0486±0.03 | 0.7141±0.45| -0.1088±0.03 |
|
||||
| GATs (Petar Velickovic, et al.) | Alpha158 (with selected 20 features) | 0.0349±0.00 | 0.2511±0.01| 0.0457±0.00 | 0.3537±0.01 | 0.0578±0.02 | 0.8221±0.25| -0.0824±0.02 |
|
||||
|
||||
- The selected 20 features are based on the feature importance of a lightgbm-based model.
|
||||
@@ -1,3 +1,3 @@
|
||||
# State-Frequency-Memory
|
||||
- State Frequency Memory (SFM) is a novel recurrent network that uses Discrete Fourier Transform to decompose the hidden states of memory cells and capture the multi-frequency trading patterns from past market data to make stock price predictions.
|
||||
- Paper: Stock Price Prediction via Discovering Multi-Frequency Trading Patterns. [https://www.cs.ucf.edu/~gqi/publications/kdd2017_stock.pdf.](https://www.cs.ucf.edu/~gqi/publications/kdd2017_stock.pdf.)
|
||||
- Paper: Stock Price Prediction via Discovering Multi-Frequency Trading Patterns. [http://www.eecs.ucf.edu/~gqi/publications/kdd2017_stock.pdf.](http://www.eecs.ucf.edu/~gqi/publications/kdd2017_stock.pdf)
|
||||
@@ -57,14 +57,13 @@ task:
|
||||
eval_steps: 5
|
||||
loss: mse
|
||||
optimizer: adam
|
||||
GPU: 1
|
||||
seed: 710
|
||||
GPU: 0
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: ALPHA360
|
||||
class: Alpha360
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
@@ -1,219 +1,229 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The Google Research Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Lint as: python3
|
||||
"""Custom formatting functions for Alpha158 dataset.
|
||||
|
||||
Defines dataset specific column definitions and data transformations.
|
||||
"""
|
||||
|
||||
import data_formatters.base
|
||||
import libs.utils as utils
|
||||
import sklearn.preprocessing
|
||||
|
||||
GenericDataFormatter = data_formatters.base.GenericDataFormatter
|
||||
DataTypes = data_formatters.base.DataTypes
|
||||
InputTypes = data_formatters.base.InputTypes
|
||||
|
||||
|
||||
class Alpha158Formatter(GenericDataFormatter):
|
||||
"""Defines and formats data for the Alpha158 dataset.
|
||||
|
||||
Attributes:
|
||||
column_definition: Defines input and data type of column used in the
|
||||
experiment.
|
||||
identifiers: Entity identifiers used in experiments.
|
||||
"""
|
||||
|
||||
_column_definition = [
|
||||
("instrument", DataTypes.CATEGORICAL, InputTypes.ID),
|
||||
("LABEL0", DataTypes.REAL_VALUED, InputTypes.TARGET),
|
||||
("date", DataTypes.DATE, InputTypes.TIME),
|
||||
("month", DataTypes.CATEGORICAL, InputTypes.KNOWN_INPUT),
|
||||
("day_of_week", DataTypes.CATEGORICAL, InputTypes.KNOWN_INPUT),
|
||||
# Selected 10 features
|
||||
("RESI5", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
|
||||
("WVMA5", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
|
||||
("RSQR5", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
|
||||
("KLEN", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
|
||||
("RSQR10", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
|
||||
("CORR5", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
|
||||
("CORD5", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
|
||||
("CORR10", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
|
||||
("ROC60", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
|
||||
("RESI10", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
|
||||
("const", DataTypes.CATEGORICAL, InputTypes.STATIC_INPUT),
|
||||
]
|
||||
|
||||
def __init__(self):
|
||||
"""Initialises formatter."""
|
||||
|
||||
self.identifiers = None
|
||||
self._real_scalers = None
|
||||
self._cat_scalers = None
|
||||
self._target_scaler = None
|
||||
self._num_classes_per_cat_input = None
|
||||
|
||||
def split_data(self, df, valid_boundary=2016, test_boundary=2018):
|
||||
"""Splits data frame into training-validation-test data frames.
|
||||
|
||||
This also calibrates scaling object, and transforms data for each split.
|
||||
|
||||
Args:
|
||||
df: Source data frame to split.
|
||||
valid_boundary: Starting year for validation data
|
||||
test_boundary: Starting year for test data
|
||||
|
||||
Returns:
|
||||
Tuple of transformed (train, valid, test) data.
|
||||
"""
|
||||
|
||||
print("Formatting train-valid-test splits.")
|
||||
|
||||
index = df["year"]
|
||||
train = df.loc[index < valid_boundary]
|
||||
valid = df.loc[(index >= valid_boundary) & (index < test_boundary)]
|
||||
test = df.loc[index >= test_boundary]
|
||||
|
||||
self.set_scalers(train)
|
||||
|
||||
return (self.transform_inputs(data) for data in [train, valid, test])
|
||||
|
||||
def set_scalers(self, df):
|
||||
"""Calibrates scalers using the data supplied.
|
||||
|
||||
Args:
|
||||
df: Data to use to calibrate scalers.
|
||||
"""
|
||||
print("Setting scalers with training data...")
|
||||
|
||||
column_definitions = self.get_column_definition()
|
||||
id_column = utils.get_single_col_by_input_type(InputTypes.ID, column_definitions)
|
||||
target_column = utils.get_single_col_by_input_type(InputTypes.TARGET, column_definitions)
|
||||
|
||||
# Extract identifiers in case required
|
||||
self.identifiers = list(df[id_column].unique())
|
||||
|
||||
# Format real scalers
|
||||
real_inputs = utils.extract_cols_from_data_type(
|
||||
DataTypes.REAL_VALUED, column_definitions, {InputTypes.ID, InputTypes.TIME}
|
||||
)
|
||||
|
||||
data = df[real_inputs].values
|
||||
self._real_scalers = sklearn.preprocessing.StandardScaler().fit(data)
|
||||
self._target_scaler = sklearn.preprocessing.StandardScaler().fit(
|
||||
df[[target_column]].values
|
||||
) # used for predictions
|
||||
|
||||
# Format categorical scalers
|
||||
categorical_inputs = utils.extract_cols_from_data_type(
|
||||
DataTypes.CATEGORICAL, column_definitions, {InputTypes.ID, InputTypes.TIME}
|
||||
)
|
||||
|
||||
categorical_scalers = {}
|
||||
num_classes = []
|
||||
for col in categorical_inputs:
|
||||
# Set all to str so that we don't have mixed integer/string columns
|
||||
srs = df[col].apply(str)
|
||||
categorical_scalers[col] = sklearn.preprocessing.LabelEncoder().fit(srs.values)
|
||||
num_classes.append(srs.nunique())
|
||||
|
||||
# Set categorical scaler outputs
|
||||
self._cat_scalers = categorical_scalers
|
||||
self._num_classes_per_cat_input = num_classes
|
||||
|
||||
def transform_inputs(self, df):
|
||||
"""Performs feature transformations.
|
||||
|
||||
This includes both feature engineering, preprocessing and normalisation.
|
||||
|
||||
Args:
|
||||
df: Data frame to transform.
|
||||
|
||||
Returns:
|
||||
Transformed data frame.
|
||||
|
||||
"""
|
||||
output = df.copy()
|
||||
|
||||
if self._real_scalers is None and self._cat_scalers is None:
|
||||
raise ValueError("Scalers have not been set!")
|
||||
|
||||
column_definitions = self.get_column_definition()
|
||||
|
||||
real_inputs = utils.extract_cols_from_data_type(
|
||||
DataTypes.REAL_VALUED, column_definitions, {InputTypes.ID, InputTypes.TIME}
|
||||
)
|
||||
categorical_inputs = utils.extract_cols_from_data_type(
|
||||
DataTypes.CATEGORICAL, column_definitions, {InputTypes.ID, InputTypes.TIME}
|
||||
)
|
||||
|
||||
# Format real inputs
|
||||
output[real_inputs] = self._real_scalers.transform(df[real_inputs].values)
|
||||
|
||||
# Format categorical inputs
|
||||
for col in categorical_inputs:
|
||||
string_df = df[col].apply(str)
|
||||
output[col] = self._cat_scalers[col].transform(string_df)
|
||||
|
||||
return output
|
||||
|
||||
def format_predictions(self, predictions):
|
||||
"""Reverts any normalisation to give predictions in original scale.
|
||||
|
||||
Args:
|
||||
predictions: Dataframe of model predictions.
|
||||
|
||||
Returns:
|
||||
Data frame of unnormalised predictions.
|
||||
"""
|
||||
output = predictions.copy()
|
||||
|
||||
column_names = predictions.columns
|
||||
|
||||
for col in column_names:
|
||||
if col not in {"forecast_time", "identifier"}:
|
||||
output[col] = self._target_scaler.inverse_transform(predictions[col])
|
||||
|
||||
return output
|
||||
|
||||
# Default params
|
||||
def get_fixed_params(self):
|
||||
"""Returns fixed model parameters for experiments."""
|
||||
|
||||
fixed_params = {
|
||||
"total_time_steps": 6 + 6,
|
||||
"num_encoder_steps": 6,
|
||||
"num_epochs": 100,
|
||||
"early_stopping_patience": 10,
|
||||
"multiprocessing_workers": 5,
|
||||
}
|
||||
|
||||
return fixed_params
|
||||
|
||||
def get_default_model_params(self):
|
||||
"""Returns default optimised model parameters."""
|
||||
|
||||
model_params = {
|
||||
"dropout_rate": 0.4,
|
||||
"hidden_layer_size": 16,
|
||||
"learning_rate": 0.0001,
|
||||
"minibatch_size": 128,
|
||||
"max_gradient_norm": 0.0135,
|
||||
"num_heads": 1,
|
||||
"stack_size": 1,
|
||||
}
|
||||
|
||||
return model_params
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The Google Research Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Lint as: python3
|
||||
"""Custom formatting functions for Alpha158 dataset.
|
||||
|
||||
Defines dataset specific column definitions and data transformations.
|
||||
"""
|
||||
|
||||
import data_formatters.base
|
||||
import libs.utils as utils
|
||||
import sklearn.preprocessing
|
||||
|
||||
GenericDataFormatter = data_formatters.base.GenericDataFormatter
|
||||
DataTypes = data_formatters.base.DataTypes
|
||||
InputTypes = data_formatters.base.InputTypes
|
||||
|
||||
|
||||
class Alpha158Formatter(GenericDataFormatter):
|
||||
"""Defines and formats data for the Alpha158 dataset.
|
||||
|
||||
Attributes:
|
||||
column_definition: Defines input and data type of column used in the
|
||||
experiment.
|
||||
identifiers: Entity identifiers used in experiments.
|
||||
"""
|
||||
|
||||
_column_definition = [
|
||||
("instrument", DataTypes.CATEGORICAL, InputTypes.ID),
|
||||
("LABEL0", DataTypes.REAL_VALUED, InputTypes.TARGET),
|
||||
("date", DataTypes.DATE, InputTypes.TIME),
|
||||
("month", DataTypes.CATEGORICAL, InputTypes.KNOWN_INPUT),
|
||||
("day_of_week", DataTypes.CATEGORICAL, InputTypes.KNOWN_INPUT),
|
||||
# Selected features
|
||||
("RESI5", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
|
||||
("WVMA5", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
|
||||
("RSQR5", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
|
||||
("KLEN", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
|
||||
("RSQR10", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
|
||||
("CORR5", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
|
||||
("CORD5", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
|
||||
("CORR10", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
|
||||
("ROC60", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
|
||||
("RESI10", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
|
||||
("VSTD5", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
|
||||
("RSQR60", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
|
||||
("CORR60", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
|
||||
("WVMA60", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
|
||||
("STD5", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
|
||||
("RSQR20", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
|
||||
("CORD60", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
|
||||
("CORD10", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
|
||||
("CORR20", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
|
||||
("KLOW", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
|
||||
("const", DataTypes.CATEGORICAL, InputTypes.STATIC_INPUT),
|
||||
]
|
||||
|
||||
def __init__(self):
|
||||
"""Initialises formatter."""
|
||||
|
||||
self.identifiers = None
|
||||
self._real_scalers = None
|
||||
self._cat_scalers = None
|
||||
self._target_scaler = None
|
||||
self._num_classes_per_cat_input = None
|
||||
|
||||
def split_data(self, df, valid_boundary=2016, test_boundary=2018):
|
||||
"""Splits data frame into training-validation-test data frames.
|
||||
|
||||
This also calibrates scaling object, and transforms data for each split.
|
||||
|
||||
Args:
|
||||
df: Source data frame to split.
|
||||
valid_boundary: Starting year for validation data
|
||||
test_boundary: Starting year for test data
|
||||
|
||||
Returns:
|
||||
Tuple of transformed (train, valid, test) data.
|
||||
"""
|
||||
|
||||
print("Formatting train-valid-test splits.")
|
||||
|
||||
index = df["year"]
|
||||
train = df.loc[index < valid_boundary]
|
||||
valid = df.loc[(index >= valid_boundary) & (index < test_boundary)]
|
||||
test = df.loc[index >= test_boundary]
|
||||
|
||||
self.set_scalers(train)
|
||||
|
||||
return (self.transform_inputs(data) for data in [train, valid, test])
|
||||
|
||||
def set_scalers(self, df):
|
||||
"""Calibrates scalers using the data supplied.
|
||||
|
||||
Args:
|
||||
df: Data to use to calibrate scalers.
|
||||
"""
|
||||
print("Setting scalers with training data...")
|
||||
|
||||
column_definitions = self.get_column_definition()
|
||||
id_column = utils.get_single_col_by_input_type(InputTypes.ID, column_definitions)
|
||||
target_column = utils.get_single_col_by_input_type(InputTypes.TARGET, column_definitions)
|
||||
|
||||
# Extract identifiers in case required
|
||||
self.identifiers = list(df[id_column].unique())
|
||||
|
||||
# Format real scalers
|
||||
real_inputs = utils.extract_cols_from_data_type(
|
||||
DataTypes.REAL_VALUED, column_definitions, {InputTypes.ID, InputTypes.TIME}
|
||||
)
|
||||
|
||||
data = df[real_inputs].values
|
||||
self._real_scalers = sklearn.preprocessing.StandardScaler().fit(data)
|
||||
self._target_scaler = sklearn.preprocessing.StandardScaler().fit(
|
||||
df[[target_column]].values
|
||||
) # used for predictions
|
||||
|
||||
# Format categorical scalers
|
||||
categorical_inputs = utils.extract_cols_from_data_type(
|
||||
DataTypes.CATEGORICAL, column_definitions, {InputTypes.ID, InputTypes.TIME}
|
||||
)
|
||||
|
||||
categorical_scalers = {}
|
||||
num_classes = []
|
||||
for col in categorical_inputs:
|
||||
# Set all to str so that we don't have mixed integer/string columns
|
||||
srs = df[col].apply(str)
|
||||
categorical_scalers[col] = sklearn.preprocessing.LabelEncoder().fit(srs.values)
|
||||
num_classes.append(srs.nunique())
|
||||
|
||||
# Set categorical scaler outputs
|
||||
self._cat_scalers = categorical_scalers
|
||||
self._num_classes_per_cat_input = num_classes
|
||||
|
||||
def transform_inputs(self, df):
|
||||
"""Performs feature transformations.
|
||||
|
||||
This includes both feature engineering, preprocessing and normalisation.
|
||||
|
||||
Args:
|
||||
df: Data frame to transform.
|
||||
|
||||
Returns:
|
||||
Transformed data frame.
|
||||
|
||||
"""
|
||||
output = df.copy()
|
||||
|
||||
if self._real_scalers is None and self._cat_scalers is None:
|
||||
raise ValueError("Scalers have not been set!")
|
||||
|
||||
column_definitions = self.get_column_definition()
|
||||
|
||||
real_inputs = utils.extract_cols_from_data_type(
|
||||
DataTypes.REAL_VALUED, column_definitions, {InputTypes.ID, InputTypes.TIME}
|
||||
)
|
||||
categorical_inputs = utils.extract_cols_from_data_type(
|
||||
DataTypes.CATEGORICAL, column_definitions, {InputTypes.ID, InputTypes.TIME}
|
||||
)
|
||||
|
||||
# Format real inputs
|
||||
output[real_inputs] = self._real_scalers.transform(df[real_inputs].values)
|
||||
|
||||
# Format categorical inputs
|
||||
for col in categorical_inputs:
|
||||
string_df = df[col].apply(str)
|
||||
output[col] = self._cat_scalers[col].transform(string_df)
|
||||
|
||||
return output
|
||||
|
||||
def format_predictions(self, predictions):
|
||||
"""Reverts any normalisation to give predictions in original scale.
|
||||
|
||||
Args:
|
||||
predictions: Dataframe of model predictions.
|
||||
|
||||
Returns:
|
||||
Data frame of unnormalised predictions.
|
||||
"""
|
||||
output = predictions.copy()
|
||||
|
||||
column_names = predictions.columns
|
||||
|
||||
for col in column_names:
|
||||
if col not in {"forecast_time", "identifier"}:
|
||||
output[col] = self._target_scaler.inverse_transform(predictions[col])
|
||||
|
||||
return output
|
||||
|
||||
# Default params
|
||||
def get_fixed_params(self):
|
||||
"""Returns fixed model parameters for experiments."""
|
||||
|
||||
fixed_params = {
|
||||
"total_time_steps": 6 + 6,
|
||||
"num_encoder_steps": 6,
|
||||
"num_epochs": 100,
|
||||
"early_stopping_patience": 10,
|
||||
"multiprocessing_workers": 5,
|
||||
}
|
||||
|
||||
return fixed_params
|
||||
|
||||
def get_default_model_params(self):
|
||||
"""Returns default optimised model parameters."""
|
||||
|
||||
model_params = {
|
||||
"dropout_rate": 0.4,
|
||||
"hidden_layer_size": 160,
|
||||
"learning_rate": 0.0001,
|
||||
"minibatch_size": 128,
|
||||
"max_gradient_norm": 0.0135,
|
||||
"num_heads": 1,
|
||||
"stack_size": 1,
|
||||
}
|
||||
|
||||
return model_params
|
||||
|
||||
@@ -25,7 +25,7 @@ import os
|
||||
import data_formatters.qlib_Alpha158
|
||||
|
||||
|
||||
class ExperimentConfig(object):
|
||||
class ExperimentConfig:
|
||||
"""Defines experiment configs and paths to outputs.
|
||||
|
||||
Attributes:
|
||||
|
||||
@@ -320,7 +320,7 @@ class InterpretableMultiHeadAttention:
|
||||
return outputs, attn
|
||||
|
||||
|
||||
class TFTDataCache(object):
|
||||
class TFTDataCache:
|
||||
"""Caches data for the TFT."""
|
||||
|
||||
_data_cache = {}
|
||||
@@ -348,7 +348,7 @@ class TFTDataCache(object):
|
||||
|
||||
|
||||
# TFT model definitions.
|
||||
class TemporalFusionTransformer(object):
|
||||
class TemporalFusionTransformer:
|
||||
"""Defines Temporal Fusion Transformer.
|
||||
|
||||
Attributes:
|
||||
@@ -972,7 +972,7 @@ class TemporalFusionTransformer(object):
|
||||
valid_quantiles = self.quantiles
|
||||
output_size = self.output_size
|
||||
|
||||
class QuantileLossCalculator(object):
|
||||
class QuantileLossCalculator:
|
||||
"""Computes the combined quantile loss for prespecified quantiles.
|
||||
|
||||
Attributes:
|
||||
|
||||
@@ -1,249 +1,291 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import tensorflow.compat.v1 as tf
|
||||
import data_formatters.base
|
||||
import expt_settings.configs
|
||||
import libs.hyperparam_opt
|
||||
import libs.tft_model
|
||||
import libs.utils as utils
|
||||
import os
|
||||
import datetime as dte
|
||||
|
||||
|
||||
from qlib.model.base import ModelFT
|
||||
from qlib.data.dataset import DatasetH
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
|
||||
|
||||
# To register new datasets, please add them here.
|
||||
ALLOW_DATASET = ["Alpha158"]
|
||||
DATASET_SETTING = {
|
||||
"Alpha158": {
|
||||
"feature_col": ["RESI5", "WVMA5", "RSQR5", "KLEN", "RSQR10", "CORR5", "CORD5", "CORR10", "ROC60", "RESI10"],
|
||||
"label_col": ["LABEL0"],
|
||||
},
|
||||
}
|
||||
# To register new datasets, please add their configurations here.
|
||||
|
||||
|
||||
def get_shifted_label(data_df, shifts=5, col_shift="LABEL0"):
|
||||
return data_df[[col_shift]].groupby("instrument").apply(lambda df: df.shift(shifts))
|
||||
|
||||
|
||||
def fill_test_na(test_df):
|
||||
test_df_res = test_df.copy()
|
||||
feature_cols = ~test_df_res.columns.str.contains("label", case=False)
|
||||
test_feature_fna = test_df_res.loc[:, feature_cols].groupby("datetime").apply(lambda df: df.fillna(df.mean()))
|
||||
test_df_res.loc[:, feature_cols] = test_feature_fna
|
||||
return test_df_res
|
||||
|
||||
|
||||
def process_qlib_data(df, dataset, fillna=False):
|
||||
"""Prepare data to fit the TFT model.
|
||||
|
||||
Args:
|
||||
df: Original DataFrame.
|
||||
fillna: Whether to fill the data with the mean values.
|
||||
|
||||
Returns:
|
||||
Transformed DataFrame.
|
||||
|
||||
"""
|
||||
# Several features selected manually
|
||||
feature_col = DATASET_SETTING[dataset]["feature_col"]
|
||||
label_col = DATASET_SETTING[dataset]["label_col"]
|
||||
temp_df = df.loc[:, feature_col + label_col]
|
||||
if fillna:
|
||||
temp_df = fill_test_na(temp_df)
|
||||
temp_df = temp_df.swaplevel()
|
||||
temp_df = temp_df.sort_index()
|
||||
temp_df = temp_df.reset_index(level=0)
|
||||
dates = pd.to_datetime(temp_df.index)
|
||||
temp_df["date"] = dates
|
||||
temp_df["day_of_week"] = dates.dayofweek
|
||||
temp_df["month"] = dates.month
|
||||
temp_df["year"] = dates.year
|
||||
temp_df["const"] = 1.0
|
||||
return temp_df
|
||||
|
||||
|
||||
def process_predicted(df, col_name):
|
||||
"""Transform the TFT predicted data into Qlib format.
|
||||
|
||||
Args:
|
||||
df: Original DataFrame.
|
||||
fillna: New column name.
|
||||
|
||||
Returns:
|
||||
Transformed DataFrame.
|
||||
|
||||
"""
|
||||
df_res = df.copy()
|
||||
df_res = df_res.rename(columns={"forecast_time": "datetime", "identifier": "instrument", "t+4": col_name})
|
||||
df_res = df_res.set_index(["datetime", "instrument"]).sort_index()
|
||||
df_res = df_res[[col_name]]
|
||||
return df_res
|
||||
|
||||
|
||||
def format_score(forecast_df, col_name="pred", label_shift=5):
|
||||
pred = process_predicted(forecast_df, col_name=col_name)
|
||||
pred = get_shifted_label(pred, shifts=-label_shift, col_shift=col_name)
|
||||
pred = pred.dropna()[col_name]
|
||||
return pred
|
||||
|
||||
|
||||
def transform_df(df, col_name="LABEL0"):
|
||||
df_res = df["feature"]
|
||||
df_res[col_name] = df["label"]
|
||||
return df_res
|
||||
|
||||
|
||||
class TFTModel(ModelFT):
|
||||
"""TFT Model"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.model = None
|
||||
|
||||
def _prepare_data(self, dataset: DatasetH):
|
||||
df_train, df_valid = dataset.prepare(
|
||||
["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
|
||||
)
|
||||
return transform_df(df_train), transform_df(df_valid)
|
||||
|
||||
def fit(
|
||||
self,
|
||||
dataset: DatasetH,
|
||||
DATASET="Alpha158",
|
||||
MODEL_FOLDER="qlib_alpha158_model",
|
||||
LABEL_COL="LABEL0",
|
||||
LABEL_SHIFT=5,
|
||||
USE_GPU_ID=0,
|
||||
**kwargs
|
||||
):
|
||||
|
||||
if DATASET not in ALLOW_DATASET:
|
||||
raise AssertionError("The dataset is not supported, please make a new formatter to fit this dataset")
|
||||
|
||||
dtrain, dvalid = self._prepare_data(dataset)
|
||||
dtrain.loc[:, LABEL_COL] = get_shifted_label(dtrain, shifts=LABEL_SHIFT, col_shift=LABEL_COL)
|
||||
dvalid.loc[:, LABEL_COL] = get_shifted_label(dvalid, shifts=LABEL_SHIFT, col_shift=LABEL_COL)
|
||||
|
||||
train = process_qlib_data(dtrain, DATASET, fillna=True).dropna()
|
||||
valid = process_qlib_data(dvalid, DATASET, fillna=True).dropna()
|
||||
|
||||
ExperimentConfig = expt_settings.configs.ExperimentConfig
|
||||
config = ExperimentConfig(DATASET)
|
||||
self.data_formatter = config.make_data_formatter()
|
||||
self.model_folder = MODEL_FOLDER
|
||||
self.gpu_id = USE_GPU_ID
|
||||
self.label_shift = LABEL_SHIFT
|
||||
self.expt_name = DATASET
|
||||
self.label_col = LABEL_COL
|
||||
|
||||
use_gpu = (True, self.gpu_id)
|
||||
# ===========================Training Process===========================
|
||||
ModelClass = libs.tft_model.TemporalFusionTransformer
|
||||
if not isinstance(self.data_formatter, data_formatters.base.GenericDataFormatter):
|
||||
raise ValueError(
|
||||
"Data formatters should inherit from"
|
||||
+ "AbstractDataFormatter! Type={}".format(type(self.data_formatter))
|
||||
)
|
||||
|
||||
default_keras_session = tf.keras.backend.get_session()
|
||||
|
||||
if use_gpu[0]:
|
||||
self.tf_config = utils.get_default_tensorflow_config(tf_device="gpu", gpu_id=use_gpu[1])
|
||||
else:
|
||||
self.tf_config = utils.get_default_tensorflow_config(tf_device="cpu")
|
||||
|
||||
self.data_formatter.set_scalers(train)
|
||||
|
||||
# Sets up default params
|
||||
fixed_params = self.data_formatter.get_experiment_params()
|
||||
params = self.data_formatter.get_default_model_params()
|
||||
|
||||
# Wendi: 合并调优的参数和非调优的参数
|
||||
params = {**params, **fixed_params}
|
||||
|
||||
if not os.path.exists(self.model_folder):
|
||||
os.makedirs(self.model_folder)
|
||||
params["model_folder"] = self.model_folder
|
||||
|
||||
print("*** Begin training ***")
|
||||
best_loss = np.Inf
|
||||
|
||||
tf.reset_default_graph()
|
||||
|
||||
self.tf_graph = tf.Graph()
|
||||
with self.tf_graph.as_default():
|
||||
self.sess = tf.Session(config=self.tf_config)
|
||||
tf.keras.backend.set_session(self.sess)
|
||||
self.model = ModelClass(params, use_cudnn=use_gpu[0])
|
||||
self.sess.run(tf.global_variables_initializer())
|
||||
self.model.fit(train_df=train, valid_df=valid)
|
||||
print("*** Finished training ***")
|
||||
saved_model_dir = self.model_folder + "/" + "saved_model"
|
||||
if not os.path.exists(saved_model_dir):
|
||||
os.makedirs(saved_model_dir)
|
||||
self.model.save(saved_model_dir)
|
||||
|
||||
def extract_numerical_data(data):
|
||||
"""Strips out forecast time and identifier columns."""
|
||||
return data[[col for col in data.columns if col not in {"forecast_time", "identifier"}]]
|
||||
|
||||
# p50_loss = utils.numpy_normalised_quantile_loss(
|
||||
# extract_numerical_data(targets), extract_numerical_data(p50_forecast),
|
||||
# 0.5)
|
||||
# p90_loss = utils.numpy_normalised_quantile_loss(
|
||||
# extract_numerical_data(targets), extract_numerical_data(p90_forecast),
|
||||
# 0.9)
|
||||
tf.keras.backend.set_session(default_keras_session)
|
||||
print("Training completed.".format(dte.datetime.now()))
|
||||
# ===========================Training Process===========================
|
||||
|
||||
def predict(self, dataset):
|
||||
if self.model is None:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
d_test = dataset.prepare("test", col_set=["feature", "label"])
|
||||
d_test = transform_df(d_test)
|
||||
d_test.loc[:, self.label_col] = get_shifted_label(d_test, shifts=self.label_shift, col_shift=self.label_col)
|
||||
test = process_qlib_data(d_test, self.expt_name, fillna=True).dropna()
|
||||
|
||||
use_gpu = (True, self.gpu_id)
|
||||
# ===========================Predicting Process===========================
|
||||
default_keras_session = tf.keras.backend.get_session()
|
||||
|
||||
# Sets up default params
|
||||
fixed_params = self.data_formatter.get_experiment_params()
|
||||
params = self.data_formatter.get_default_model_params()
|
||||
params = {**params, **fixed_params}
|
||||
|
||||
print("*** Begin predicting ***")
|
||||
tf.reset_default_graph()
|
||||
|
||||
with self.tf_graph.as_default():
|
||||
tf.keras.backend.set_session(self.sess)
|
||||
output_map = self.model.predict(test, return_targets=True)
|
||||
targets = self.data_formatter.format_predictions(output_map["targets"])
|
||||
p50_forecast = self.data_formatter.format_predictions(output_map["p50"])
|
||||
p90_forecast = self.data_formatter.format_predictions(output_map["p90"])
|
||||
tf.keras.backend.set_session(default_keras_session)
|
||||
|
||||
predict50 = format_score(p50_forecast, "pred", 1)
|
||||
predict90 = format_score(p90_forecast, "pred", 1)
|
||||
predict = (predict50 + predict90) / 2 # self.label_shift
|
||||
# ===========================Predicting Process===========================
|
||||
return predict
|
||||
|
||||
def finetune(self, dataset: DatasetH):
|
||||
"""
|
||||
finetune model
|
||||
Parameters
|
||||
----------
|
||||
dataset : DatasetH
|
||||
dataset for finetuning
|
||||
"""
|
||||
pass
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import tensorflow.compat.v1 as tf
|
||||
import data_formatters.base
|
||||
import expt_settings.configs
|
||||
import libs.hyperparam_opt
|
||||
import libs.tft_model
|
||||
import libs.utils as utils
|
||||
import os
|
||||
import datetime as dte
|
||||
|
||||
|
||||
from qlib.model.base import ModelFT
|
||||
from qlib.data.dataset import DatasetH
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
|
||||
|
||||
# To register new datasets, please add them here.
|
||||
ALLOW_DATASET = ["Alpha158", "Alpha360"]
|
||||
# To register new datasets, please add their configurations here.
|
||||
DATASET_SETTING = {
|
||||
"Alpha158": {
|
||||
"feature_col": [
|
||||
"RESI5",
|
||||
"WVMA5",
|
||||
"RSQR5",
|
||||
"KLEN",
|
||||
"RSQR10",
|
||||
"CORR5",
|
||||
"CORD5",
|
||||
"CORR10",
|
||||
"ROC60",
|
||||
"RESI10",
|
||||
"VSTD5",
|
||||
"RSQR60",
|
||||
"CORR60",
|
||||
"WVMA60",
|
||||
"STD5",
|
||||
"RSQR20",
|
||||
"CORD60",
|
||||
"CORD10",
|
||||
"CORR20",
|
||||
"KLOW",
|
||||
],
|
||||
"label_col": "LABEL0",
|
||||
},
|
||||
"Alpha360": {
|
||||
"feature_col": [
|
||||
"HIGH0",
|
||||
"LOW0",
|
||||
"OPEN0",
|
||||
"CLOSE1",
|
||||
"HIGH1",
|
||||
"VOLUME1",
|
||||
"LOW1",
|
||||
"VOLUME3",
|
||||
"OPEN1",
|
||||
"VOLUME4",
|
||||
"CLOSE2",
|
||||
"CLOSE4",
|
||||
"VOLUME5",
|
||||
"LOW2",
|
||||
"CLOSE3",
|
||||
"VOLUME2",
|
||||
"HIGH2",
|
||||
"LOW4",
|
||||
"VOLUME8",
|
||||
"VOLUME11",
|
||||
],
|
||||
"label_col": "LABEL0",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def get_shifted_label(data_df, shifts=5, col_shift="LABEL0"):
|
||||
return data_df[[col_shift]].groupby("instrument").apply(lambda df: df.shift(shifts))
|
||||
|
||||
|
||||
def fill_test_na(test_df):
|
||||
test_df_res = test_df.copy()
|
||||
feature_cols = ~test_df_res.columns.str.contains("label", case=False)
|
||||
test_feature_fna = test_df_res.loc[:, feature_cols].groupby("datetime").apply(lambda df: df.fillna(df.mean()))
|
||||
test_df_res.loc[:, feature_cols] = test_feature_fna
|
||||
return test_df_res
|
||||
|
||||
|
||||
def process_qlib_data(df, dataset, fillna=False):
|
||||
"""Prepare data to fit the TFT model.
|
||||
|
||||
Args:
|
||||
df: Original DataFrame.
|
||||
fillna: Whether to fill the data with the mean values.
|
||||
|
||||
Returns:
|
||||
Transformed DataFrame.
|
||||
|
||||
"""
|
||||
# Several features selected manually
|
||||
feature_col = DATASET_SETTING[dataset]["feature_col"]
|
||||
label_col = [DATASET_SETTING[dataset]["label_col"]]
|
||||
temp_df = df.loc[:, feature_col + label_col]
|
||||
if fillna:
|
||||
temp_df = fill_test_na(temp_df)
|
||||
temp_df = temp_df.swaplevel()
|
||||
temp_df = temp_df.sort_index()
|
||||
temp_df = temp_df.reset_index(level=0)
|
||||
dates = pd.to_datetime(temp_df.index)
|
||||
temp_df["date"] = dates
|
||||
temp_df["day_of_week"] = dates.dayofweek
|
||||
temp_df["month"] = dates.month
|
||||
temp_df["year"] = dates.year
|
||||
temp_df["const"] = 1.0
|
||||
return temp_df
|
||||
|
||||
|
||||
def process_predicted(df, col_name):
|
||||
"""Transform the TFT predicted data into Qlib format.
|
||||
|
||||
Args:
|
||||
df: Original DataFrame.
|
||||
fillna: New column name.
|
||||
|
||||
Returns:
|
||||
Transformed DataFrame.
|
||||
|
||||
"""
|
||||
df_res = df.copy()
|
||||
df_res = df_res.rename(columns={"forecast_time": "datetime", "identifier": "instrument", "t+4": col_name})
|
||||
df_res = df_res.set_index(["datetime", "instrument"]).sort_index()
|
||||
df_res = df_res[[col_name]]
|
||||
return df_res
|
||||
|
||||
|
||||
def format_score(forecast_df, col_name="pred", label_shift=5):
|
||||
pred = process_predicted(forecast_df, col_name=col_name)
|
||||
pred = get_shifted_label(pred, shifts=-label_shift, col_shift=col_name)
|
||||
pred = pred.dropna()[col_name]
|
||||
return pred
|
||||
|
||||
|
||||
def transform_df(df, col_name="LABEL0"):
|
||||
df_res = df["feature"]
|
||||
df_res[col_name] = df["label"]
|
||||
return df_res
|
||||
|
||||
|
||||
class TFTModel(ModelFT):
|
||||
"""TFT Model"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.model = None
|
||||
self.params = {"DATASET": "Alpha158", "label_shift": 5}
|
||||
self.params.update(kwargs)
|
||||
|
||||
def _prepare_data(self, dataset: DatasetH):
|
||||
df_train, df_valid = dataset.prepare(
|
||||
["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
|
||||
)
|
||||
return transform_df(df_train), transform_df(df_valid)
|
||||
|
||||
def fit(self, dataset: DatasetH, MODEL_FOLDER="qlib_tft_model", USE_GPU_ID=0, **kwargs):
|
||||
DATASET = self.params["DATASET"]
|
||||
LABEL_SHIFT = self.params["label_shift"]
|
||||
LABEL_COL = DATASET_SETTING[DATASET]["label_col"]
|
||||
|
||||
if DATASET not in ALLOW_DATASET:
|
||||
raise AssertionError("The dataset is not supported, please make a new formatter to fit this dataset")
|
||||
|
||||
dtrain, dvalid = self._prepare_data(dataset)
|
||||
dtrain.loc[:, LABEL_COL] = get_shifted_label(dtrain, shifts=LABEL_SHIFT, col_shift=LABEL_COL)
|
||||
dvalid.loc[:, LABEL_COL] = get_shifted_label(dvalid, shifts=LABEL_SHIFT, col_shift=LABEL_COL)
|
||||
|
||||
train = process_qlib_data(dtrain, DATASET, fillna=True).dropna()
|
||||
valid = process_qlib_data(dvalid, DATASET, fillna=True).dropna()
|
||||
|
||||
ExperimentConfig = expt_settings.configs.ExperimentConfig
|
||||
config = ExperimentConfig(DATASET)
|
||||
self.data_formatter = config.make_data_formatter()
|
||||
self.model_folder = MODEL_FOLDER
|
||||
self.gpu_id = USE_GPU_ID
|
||||
self.label_shift = LABEL_SHIFT
|
||||
self.expt_name = DATASET
|
||||
self.label_col = LABEL_COL
|
||||
|
||||
use_gpu = (True, self.gpu_id)
|
||||
# ===========================Training Process===========================
|
||||
ModelClass = libs.tft_model.TemporalFusionTransformer
|
||||
if not isinstance(self.data_formatter, data_formatters.base.GenericDataFormatter):
|
||||
raise ValueError(
|
||||
"Data formatters should inherit from"
|
||||
+ "AbstractDataFormatter! Type={}".format(type(self.data_formatter))
|
||||
)
|
||||
|
||||
default_keras_session = tf.keras.backend.get_session()
|
||||
|
||||
if use_gpu[0]:
|
||||
self.tf_config = utils.get_default_tensorflow_config(tf_device="gpu", gpu_id=use_gpu[1])
|
||||
else:
|
||||
self.tf_config = utils.get_default_tensorflow_config(tf_device="cpu")
|
||||
|
||||
self.data_formatter.set_scalers(train)
|
||||
|
||||
# Sets up default params
|
||||
fixed_params = self.data_formatter.get_experiment_params()
|
||||
params = self.data_formatter.get_default_model_params()
|
||||
|
||||
# Wendi: 合并调优的参数和非调优的参数
|
||||
params = {**params, **fixed_params}
|
||||
|
||||
if not os.path.exists(self.model_folder):
|
||||
os.makedirs(self.model_folder)
|
||||
params["model_folder"] = self.model_folder
|
||||
|
||||
print("*** Begin training ***")
|
||||
best_loss = np.Inf
|
||||
|
||||
tf.reset_default_graph()
|
||||
|
||||
self.tf_graph = tf.Graph()
|
||||
with self.tf_graph.as_default():
|
||||
self.sess = tf.Session(config=self.tf_config)
|
||||
tf.keras.backend.set_session(self.sess)
|
||||
self.model = ModelClass(params, use_cudnn=use_gpu[0])
|
||||
self.sess.run(tf.global_variables_initializer())
|
||||
self.model.fit(train_df=train, valid_df=valid)
|
||||
print("*** Finished training ***")
|
||||
saved_model_dir = self.model_folder + "/" + "saved_model"
|
||||
if not os.path.exists(saved_model_dir):
|
||||
os.makedirs(saved_model_dir)
|
||||
self.model.save(saved_model_dir)
|
||||
|
||||
def extract_numerical_data(data):
|
||||
"""Strips out forecast time and identifier columns."""
|
||||
return data[[col for col in data.columns if col not in {"forecast_time", "identifier"}]]
|
||||
|
||||
# p50_loss = utils.numpy_normalised_quantile_loss(
|
||||
# extract_numerical_data(targets), extract_numerical_data(p50_forecast),
|
||||
# 0.5)
|
||||
# p90_loss = utils.numpy_normalised_quantile_loss(
|
||||
# extract_numerical_data(targets), extract_numerical_data(p90_forecast),
|
||||
# 0.9)
|
||||
tf.keras.backend.set_session(default_keras_session)
|
||||
print("Training completed.".format(dte.datetime.now()))
|
||||
# ===========================Training Process===========================
|
||||
|
||||
def predict(self, dataset):
|
||||
if self.model is None:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
d_test = dataset.prepare("test", col_set=["feature", "label"])
|
||||
d_test = transform_df(d_test)
|
||||
d_test.loc[:, self.label_col] = get_shifted_label(d_test, shifts=self.label_shift, col_shift=self.label_col)
|
||||
test = process_qlib_data(d_test, self.expt_name, fillna=True).dropna()
|
||||
|
||||
use_gpu = (True, self.gpu_id)
|
||||
# ===========================Predicting Process===========================
|
||||
default_keras_session = tf.keras.backend.get_session()
|
||||
|
||||
# Sets up default params
|
||||
fixed_params = self.data_formatter.get_experiment_params()
|
||||
params = self.data_formatter.get_default_model_params()
|
||||
params = {**params, **fixed_params}
|
||||
|
||||
print("*** Begin predicting ***")
|
||||
tf.reset_default_graph()
|
||||
|
||||
with self.tf_graph.as_default():
|
||||
tf.keras.backend.set_session(self.sess)
|
||||
output_map = self.model.predict(test, return_targets=True)
|
||||
targets = self.data_formatter.format_predictions(output_map["targets"])
|
||||
p50_forecast = self.data_formatter.format_predictions(output_map["p50"])
|
||||
p90_forecast = self.data_formatter.format_predictions(output_map["p90"])
|
||||
tf.keras.backend.set_session(default_keras_session)
|
||||
|
||||
predict50 = format_score(p50_forecast, "pred", 1)
|
||||
predict90 = format_score(p90_forecast, "pred", 1)
|
||||
predict = (predict50 + predict90) / 2 # self.label_shift
|
||||
# ===========================Predicting Process===========================
|
||||
return predict
|
||||
|
||||
def finetune(self, dataset: DatasetH):
|
||||
"""
|
||||
finetune model
|
||||
Parameters
|
||||
----------
|
||||
dataset : DatasetH
|
||||
dataset for finetuning
|
||||
"""
|
||||
pass
|
||||
|
||||
BIN
examples/benchmarks/TabNet/pretrain/best.model
Normal file
BIN
examples/benchmarks/TabNet/pretrain/best.model
Normal file
Binary file not shown.
4
examples/benchmarks/TabNet/requirements.txt
Normal file
4
examples/benchmarks/TabNet/requirements.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
pandas==1.1.2
|
||||
numpy==1.17.4
|
||||
scikit_learn==0.23.2
|
||||
torch==1.7.0
|
||||
@@ -0,0 +1,74 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: TabnetModel
|
||||
module_path: qlib.contrib.model.pytorch_tabnet
|
||||
kwargs:
|
||||
pretrain: True
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha158
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
pretrain: [2008-01-01, 2014-12-31]
|
||||
pretrain_validation: [2015-01-01, 2020-08-01]
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -0,0 +1,71 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors: []
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: XGBModel
|
||||
module_path: qlib.contrib.model.xgboost
|
||||
kwargs:
|
||||
eval_metric: rmse
|
||||
colsample_bytree: 0.8879
|
||||
eta: 0.0421
|
||||
max_depth: 8
|
||||
n_estimators: 647
|
||||
subsample: 0.8789
|
||||
nthread: 20
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha360
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
28
examples/highfreq/README.md
Normal file
28
examples/highfreq/README.md
Normal file
@@ -0,0 +1,28 @@
|
||||
# High-Frequency Dataset
|
||||
|
||||
This dataset is an example for RL high frequency trading.
|
||||
|
||||
## Get High-Frequency Data
|
||||
|
||||
Get high-frequency data by running the following command:
|
||||
```bash
|
||||
python workflow.py get_data
|
||||
```
|
||||
|
||||
## Dump & Reload & Reinitialize the Dataset
|
||||
|
||||
|
||||
The High-Frequency Dataset is implemented as `qlib.data.dataset.DatasetH` in the `workflow.py`. `DatatsetH` is the subclass of [`qlib.utils.serial.Serializable`](https://qlib.readthedocs.io/en/latest/advanced/serial.html), whose state can be dumped in or loaded from disk in `pickle` format.
|
||||
|
||||
### About Reinitialization
|
||||
|
||||
After reloading `Dataset` from disk, `Qlib` also support reinitializing the dataset. It means that users can reset some states of `Dataset` or `DataHandler` such as `instruments`, `start_time`, `end_time` and `segments`, etc., and generate new data according to the states.
|
||||
|
||||
The example is given in `workflow.py`, users can run the code as follows.
|
||||
|
||||
### Run the Code
|
||||
|
||||
Run the example by running the following command:
|
||||
```bash
|
||||
python workflow.py dump_and_load_dataset
|
||||
```
|
||||
174
examples/highfreq/highfreq_handler.py
Normal file
174
examples/highfreq/highfreq_handler.py
Normal file
@@ -0,0 +1,174 @@
|
||||
from qlib.data.dataset.handler import DataHandler, DataHandlerLP
|
||||
from qlib.data.dataset.processor import Processor
|
||||
from qlib.utils import get_cls_kwargs
|
||||
from qlib.log import TimeInspector
|
||||
|
||||
|
||||
class HighFreqHandler(DataHandlerLP):
|
||||
def __init__(
|
||||
self,
|
||||
instruments="csi300",
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
infer_processors=[],
|
||||
learn_processors=[],
|
||||
fit_start_time=None,
|
||||
fit_end_time=None,
|
||||
drop_raw=True,
|
||||
):
|
||||
def check_transform_proc(proc_l):
|
||||
new_l = []
|
||||
for p in proc_l:
|
||||
p["kwargs"].update(
|
||||
{
|
||||
"fit_start_time": fit_start_time,
|
||||
"fit_end_time": fit_end_time,
|
||||
}
|
||||
)
|
||||
new_l.append(p)
|
||||
return new_l
|
||||
|
||||
infer_processors = check_transform_proc(infer_processors)
|
||||
learn_processors = check_transform_proc(learn_processors)
|
||||
|
||||
data_loader = {
|
||||
"class": "QlibDataLoader",
|
||||
"kwargs": {
|
||||
"config": self.get_feature_config(),
|
||||
"swap_level": False,
|
||||
"freq": "1min",
|
||||
},
|
||||
}
|
||||
super().__init__(
|
||||
instruments=instruments,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
data_loader=data_loader,
|
||||
infer_processors=infer_processors,
|
||||
learn_processors=learn_processors,
|
||||
drop_raw=drop_raw,
|
||||
)
|
||||
|
||||
def get_feature_config(self):
|
||||
fields = []
|
||||
names = []
|
||||
|
||||
template_if = "If(IsNull({1}), {0}, {1})"
|
||||
template_paused = "Select(Or(IsNull($paused), Eq($paused, 0.0)), {0})"
|
||||
template_fillnan = "BFillNan(FFillNan({0}))"
|
||||
# Because there is no vwap field in the yahoo data, a method similar to Simpson integration is used to approximate vwap
|
||||
simpson_vwap = "($open + 2*$high + 2*$low + $close)/6"
|
||||
|
||||
def get_normalized_price_feature(price_field, shift=0):
|
||||
"""Get normalized price feature ops"""
|
||||
if shift == 0:
|
||||
template_norm = "Cut({0}/Ref(DayLast({1}), 240), 240, None)"
|
||||
else:
|
||||
template_norm = "Cut(Ref({0}, " + str(shift) + ")/Ref(DayLast({1}), 240), 240, None)"
|
||||
|
||||
feature_ops = template_norm.format(
|
||||
template_if.format(
|
||||
template_fillnan.format(template_paused.format("$close")),
|
||||
template_paused.format(price_field),
|
||||
),
|
||||
template_fillnan.format(template_paused.format("$close")),
|
||||
)
|
||||
return feature_ops
|
||||
|
||||
fields += [get_normalized_price_feature("$open", 0)]
|
||||
fields += [get_normalized_price_feature("$high", 0)]
|
||||
fields += [get_normalized_price_feature("$low", 0)]
|
||||
fields += [get_normalized_price_feature("$close", 0)]
|
||||
fields += [get_normalized_price_feature(simpson_vwap, 0)]
|
||||
names += ["$open", "$high", "$low", "$close", "$vwap"]
|
||||
|
||||
fields += [get_normalized_price_feature("$open", 240)]
|
||||
fields += [get_normalized_price_feature("$high", 240)]
|
||||
fields += [get_normalized_price_feature("$low", 240)]
|
||||
fields += [get_normalized_price_feature("$close", 240)]
|
||||
fields += [get_normalized_price_feature(simpson_vwap, 240)]
|
||||
names += ["$open_1", "$high_1", "$low_1", "$close_1", "$vwap_1"]
|
||||
|
||||
fields += [
|
||||
"Cut({0}/Ref(DayLast(Mean({0}, 7200)), 240), 240, None)".format(
|
||||
"If(IsNull({0}), 0, If(Or(Gt({1}, Mul(1.001, {3})), Lt({1}, Mul(0.999, {2}))), 0, {0}))".format(
|
||||
template_paused.format("$volume"),
|
||||
template_paused.format(simpson_vwap),
|
||||
template_paused.format("$low"),
|
||||
template_paused.format("$high"),
|
||||
)
|
||||
)
|
||||
]
|
||||
names += ["$volume"]
|
||||
fields += [
|
||||
"Cut(Ref({0}, 240)/Ref(DayLast(Mean({0}, 7200)), 240), 240, None)".format(
|
||||
"If(IsNull({0}), 0, If(Or(Gt({1}, Mul(1.001, {3})), Lt({1}, Mul(0.999, {2}))), 0, {0}))".format(
|
||||
template_paused.format("$volume"),
|
||||
template_paused.format(simpson_vwap),
|
||||
template_paused.format("$low"),
|
||||
template_paused.format("$high"),
|
||||
)
|
||||
)
|
||||
]
|
||||
names += ["$volume_1"]
|
||||
|
||||
fields += ["Cut({0}, 240, None)".format(template_paused.format("Date($close)"))]
|
||||
names += ["date"]
|
||||
return fields, names
|
||||
|
||||
|
||||
class HighFreqBacktestHandler(DataHandler):
|
||||
def __init__(
|
||||
self,
|
||||
instruments="csi300",
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
):
|
||||
data_loader = {
|
||||
"class": "QlibDataLoader",
|
||||
"kwargs": {
|
||||
"config": self.get_feature_config(),
|
||||
"swap_level": False,
|
||||
"freq": "1min",
|
||||
},
|
||||
}
|
||||
super().__init__(
|
||||
instruments=instruments,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
data_loader=data_loader,
|
||||
)
|
||||
|
||||
def get_feature_config(self):
|
||||
fields = []
|
||||
names = []
|
||||
|
||||
template_if = "If(IsNull({1}), {0}, {1})"
|
||||
template_paused = "Select(Or(IsNull($paused), Eq($paused, 0.0)), {0})"
|
||||
template_fillnan = "BFillNan(FFillNan({0}))"
|
||||
# Because there is no vwap field in the yahoo data, a method similar to Simpson integration is used to approximate vwap
|
||||
simpson_vwap = "($open + 2*$high + 2*$low + $close)/6"
|
||||
fields += [
|
||||
"Cut({0}, 240, None)".format(template_fillnan.format(template_paused.format("$close"))),
|
||||
]
|
||||
names += ["$close0"]
|
||||
fields += [
|
||||
"Cut({0}, 240, None)".format(
|
||||
template_if.format(
|
||||
template_fillnan.format(template_paused.format("$close")),
|
||||
template_paused.format(simpson_vwap),
|
||||
)
|
||||
)
|
||||
]
|
||||
names += ["$vwap0"]
|
||||
fields += [
|
||||
"Cut(If(IsNull({0}), 0, If(Or(Gt({1}, Mul(1.001, {3})), Lt({1}, Mul(0.999, {2}))), 0, {0})), 240, None)".format(
|
||||
template_paused.format("$volume"),
|
||||
template_paused.format(simpson_vwap),
|
||||
template_paused.format("$low"),
|
||||
template_paused.format("$high"),
|
||||
)
|
||||
]
|
||||
names += ["$volume0"]
|
||||
|
||||
return fields, names
|
||||
190
examples/highfreq/highfreq_ops.py
Normal file
190
examples/highfreq/highfreq_ops.py
Normal file
@@ -0,0 +1,190 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import importlib
|
||||
from qlib.data.ops import ElemOperator, PairOperator
|
||||
from qlib.config import C
|
||||
from qlib.data.cache import H
|
||||
from qlib.data.data import Cal
|
||||
|
||||
|
||||
def get_calendar_day(freq="day", future=False):
|
||||
"""Load High-Freq Calendar Date Using Memcache.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
freq : str
|
||||
frequency of read calendar file.
|
||||
future : bool
|
||||
whether including future trading day.
|
||||
|
||||
Returns
|
||||
-------
|
||||
_calendar:
|
||||
array of date.
|
||||
"""
|
||||
flag = f"{freq}_future_{future}_day"
|
||||
if flag in H["c"]:
|
||||
_calendar = H["c"][flag]
|
||||
else:
|
||||
_calendar = np.array(list(map(lambda x: x.date(), Cal.load_calendar(freq, future))))
|
||||
H["c"][flag] = _calendar
|
||||
return _calendar
|
||||
|
||||
|
||||
class DayLast(ElemOperator):
|
||||
"""DayLast Operator
|
||||
|
||||
Parameters
|
||||
----------
|
||||
feature : Expression
|
||||
feature instance
|
||||
|
||||
Returns
|
||||
----------
|
||||
feature:
|
||||
a series of that each value equals the last value of its day
|
||||
"""
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
_calendar = get_calendar_day(freq=freq)
|
||||
series = self.feature.load(instrument, start_index, end_index, freq)
|
||||
return series.groupby(_calendar[series.index]).transform("last")
|
||||
|
||||
|
||||
class FFillNan(ElemOperator):
|
||||
"""FFillNan Operator
|
||||
|
||||
Parameters
|
||||
----------
|
||||
feature : Expression
|
||||
feature instance
|
||||
|
||||
Returns
|
||||
----------
|
||||
feature:
|
||||
a forward fill nan feature
|
||||
"""
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
series = self.feature.load(instrument, start_index, end_index, freq)
|
||||
return series.fillna(method="ffill")
|
||||
|
||||
|
||||
class BFillNan(ElemOperator):
|
||||
"""BFillNan Operator
|
||||
|
||||
Parameters
|
||||
----------
|
||||
feature : Expression
|
||||
feature instance
|
||||
|
||||
Returns
|
||||
----------
|
||||
feature:
|
||||
a backfoward fill nan feature
|
||||
"""
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
series = self.feature.load(instrument, start_index, end_index, freq)
|
||||
return series.fillna(method="bfill")
|
||||
|
||||
|
||||
class Date(ElemOperator):
|
||||
"""Date Operator
|
||||
|
||||
Parameters
|
||||
----------
|
||||
feature : Expression
|
||||
feature instance
|
||||
|
||||
Returns
|
||||
----------
|
||||
feature:
|
||||
a series of that each value is the date corresponding to feature.index
|
||||
"""
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
_calendar = get_calendar_day(freq=freq)
|
||||
series = self.feature.load(instrument, start_index, end_index, freq)
|
||||
return pd.Series(_calendar[series.index], index=series.index)
|
||||
|
||||
|
||||
class Select(PairOperator):
|
||||
"""Select Operator
|
||||
|
||||
Parameters
|
||||
----------
|
||||
feature_left : Expression
|
||||
feature instance, select condition
|
||||
feature_right : Expression
|
||||
feature instance, select value
|
||||
|
||||
Returns
|
||||
----------
|
||||
feature:
|
||||
value(feature_right) that meets the condition(feature_left)
|
||||
|
||||
"""
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
series_condition = self.feature_left.load(instrument, start_index, end_index, freq)
|
||||
series_feature = self.feature_right.load(instrument, start_index, end_index, freq)
|
||||
return series_feature.loc[series_condition]
|
||||
|
||||
|
||||
class IsNull(ElemOperator):
|
||||
"""IsNull Operator
|
||||
|
||||
Parameters
|
||||
----------
|
||||
feature : Expression
|
||||
feature instance
|
||||
|
||||
Returns
|
||||
----------
|
||||
feature:
|
||||
A series indicating whether the feature is nan
|
||||
"""
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
series = self.feature.load(instrument, start_index, end_index, freq)
|
||||
return series.isnull()
|
||||
|
||||
|
||||
class Cut(ElemOperator):
|
||||
"""Cut Operator
|
||||
|
||||
Parameters
|
||||
----------
|
||||
feature : Expression
|
||||
feature instance
|
||||
l : int
|
||||
l > 0, delete the first l elements of feature (default is None, which means 0)
|
||||
r : int
|
||||
r < 0, delete the last -r elements of feature (default is None, which means 0)
|
||||
Returns
|
||||
----------
|
||||
feature:
|
||||
A series with the first l and last -r elements deleted from the feature.
|
||||
Note: It is deleted from the raw data, not the sliced data
|
||||
"""
|
||||
|
||||
def __init__(self, feature, l=None, r=None):
|
||||
self.l = l
|
||||
self.r = r
|
||||
if (self.l is not None and self.l <= 0) or (self.r is not None and self.r >= 0):
|
||||
raise ValueError("Cut operator l shoud > 0 and r should < 0")
|
||||
|
||||
super(Cut, self).__init__(feature)
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
series = self.feature.load(instrument, start_index, end_index, freq)
|
||||
return series.iloc[self.l : self.r]
|
||||
|
||||
def get_extended_window_size(self):
|
||||
ll = 0 if self.l is None else self.l
|
||||
rr = 0 if self.r is None else abs(self.r)
|
||||
lft_etd, rght_etd = self.feature.get_extended_window_size()
|
||||
lft_etd = lft_etd + ll
|
||||
rght_etd = rght_etd + rr
|
||||
return lft_etd, rght_etd
|
||||
72
examples/highfreq/highfreq_processor.py
Normal file
72
examples/highfreq/highfreq_processor.py
Normal file
@@ -0,0 +1,72 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from qlib.data.dataset.processor import Processor
|
||||
from qlib.data.dataset.utils import fetch_df_by_index
|
||||
|
||||
|
||||
class HighFreqNorm(Processor):
|
||||
def __init__(self, fit_start_time, fit_end_time):
|
||||
self.fit_start_time = fit_start_time
|
||||
self.fit_end_time = fit_end_time
|
||||
|
||||
def fit(self, df_features):
|
||||
fetch_df = fetch_df_by_index(df_features, slice(self.fit_start_time, self.fit_end_time), level="datetime")
|
||||
del df_features
|
||||
df_values = fetch_df.values
|
||||
names = {
|
||||
"price": slice(0, 10),
|
||||
"volume": slice(10, 12),
|
||||
}
|
||||
self.feature_med = {}
|
||||
self.feature_std = {}
|
||||
self.feature_vmax = {}
|
||||
self.feature_vmin = {}
|
||||
for name, name_val in names.items():
|
||||
part_values = df_values[:, name_val].astype(np.float32)
|
||||
if name == "volume":
|
||||
part_values = np.log1p(part_values)
|
||||
self.feature_med[name] = np.nanmedian(part_values)
|
||||
part_values = part_values - self.feature_med[name]
|
||||
self.feature_std[name] = np.nanmedian(np.absolute(part_values)) * 1.4826 + 1e-12
|
||||
part_values = part_values / self.feature_std[name]
|
||||
self.feature_vmax[name] = np.nanmax(part_values)
|
||||
self.feature_vmin[name] = np.nanmin(part_values)
|
||||
|
||||
def __call__(self, df_features):
|
||||
df_features.set_index("date", append=True, drop=True, inplace=True)
|
||||
df_values = df_features.values
|
||||
names = {
|
||||
"price": slice(0, 10),
|
||||
"volume": slice(10, 12),
|
||||
}
|
||||
|
||||
for name, name_val in names.items():
|
||||
if name == "volume":
|
||||
df_values[:, name_val] = np.log1p(df_values[:, name_val])
|
||||
df_values[:, name_val] -= self.feature_med[name]
|
||||
df_values[:, name_val] /= self.feature_std[name]
|
||||
slice0 = df_values[:, name_val] > 3.0
|
||||
slice1 = df_values[:, name_val] > 3.5
|
||||
slice2 = df_values[:, name_val] < -3.0
|
||||
slice3 = df_values[:, name_val] < -3.5
|
||||
|
||||
df_values[:, name_val][slice0] = (
|
||||
3.0 + (df_values[:, name_val][slice0] - 3.0) / (self.feature_vmax[name] - 3) * 0.5
|
||||
)
|
||||
df_values[:, name_val][slice1] = 3.5
|
||||
df_values[:, name_val][slice2] = (
|
||||
-3.0 - (df_values[:, name_val][slice2] + 3.0) / (self.feature_vmin[name] + 3) * 0.5
|
||||
)
|
||||
df_values[:, name_val][slice3] = -3.5
|
||||
idx = df_features.index.droplevel("datetime").drop_duplicates()
|
||||
idx.set_names(["instrument", "datetime"], inplace=True)
|
||||
|
||||
# Reshape is specifically for adapting to RL high-freq executor
|
||||
feat = df_values[:, [0, 1, 2, 3, 4, 10]].reshape(-1, 6 * 240)
|
||||
feat_1 = df_values[:, [5, 6, 7, 8, 9, 11]].reshape(-1, 6 * 240)
|
||||
df_new_features = pd.DataFrame(
|
||||
data=np.concatenate((feat, feat_1), axis=1),
|
||||
index=idx,
|
||||
columns=["FEATURE_%d" % i for i in range(12 * 240)],
|
||||
).sort_index()
|
||||
return df_new_features
|
||||
217
examples/highfreq/workflow.py
Normal file
217
examples/highfreq/workflow.py
Normal file
@@ -0,0 +1,217 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import sys
|
||||
import fire
|
||||
from pathlib import Path
|
||||
|
||||
import qlib
|
||||
import pickle
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from qlib.config import REG_CN, HIGH_FREQ_CONFIG
|
||||
from qlib.contrib.model.gbdt import LGBModel
|
||||
from qlib.contrib.data.handler import Alpha158
|
||||
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
|
||||
from qlib.contrib.evaluate import (
|
||||
backtest as normal_backtest,
|
||||
risk_analysis,
|
||||
)
|
||||
|
||||
from qlib.utils import init_instance_by_config, exists_qlib_data
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
from qlib.data.ops import Operators
|
||||
from qlib.data.data import Cal
|
||||
from qlib.tests.data import GetData
|
||||
|
||||
from highfreq_ops import get_calendar_day, DayLast, FFillNan, BFillNan, Date, Select, IsNull, Cut
|
||||
|
||||
|
||||
class HighfreqWorkflow(object):
|
||||
|
||||
SPEC_CONF = {"custom_ops": [DayLast, FFillNan, BFillNan, Date, Select, IsNull, Cut], "expression_cache": None}
|
||||
|
||||
MARKET = "all"
|
||||
BENCHMARK = "SH000300"
|
||||
|
||||
start_time = pd.Timestamp("2020-09-15 00:00:00")
|
||||
end_time = pd.Timestamp("2021-01-18 16:00:00")
|
||||
train_end_time = pd.Timestamp("2020-11-30 16:00:00")
|
||||
test_start_time = pd.Timestamp("2020-12-01 00:00:00")
|
||||
|
||||
DATA_HANDLER_CONFIG0 = {
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
"fit_start_time": start_time,
|
||||
"fit_end_time": train_end_time,
|
||||
"instruments": MARKET,
|
||||
"infer_processors": [{"class": "HighFreqNorm", "module_path": "highfreq_processor", "kwargs": {}}],
|
||||
}
|
||||
DATA_HANDLER_CONFIG1 = {
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
"instruments": MARKET,
|
||||
}
|
||||
|
||||
task = {
|
||||
"dataset": {
|
||||
"class": "DatasetH",
|
||||
"module_path": "qlib.data.dataset",
|
||||
"kwargs": {
|
||||
"handler": {
|
||||
"class": "HighFreqHandler",
|
||||
"module_path": "highfreq_handler",
|
||||
"kwargs": DATA_HANDLER_CONFIG0,
|
||||
},
|
||||
"segments": {
|
||||
"train": (start_time, train_end_time),
|
||||
"test": (
|
||||
test_start_time,
|
||||
end_time,
|
||||
),
|
||||
},
|
||||
},
|
||||
},
|
||||
"dataset_backtest": {
|
||||
"class": "DatasetH",
|
||||
"module_path": "qlib.data.dataset",
|
||||
"kwargs": {
|
||||
"handler": {
|
||||
"class": "HighFreqBacktestHandler",
|
||||
"module_path": "highfreq_handler",
|
||||
"kwargs": DATA_HANDLER_CONFIG1,
|
||||
},
|
||||
"segments": {
|
||||
"train": (start_time, train_end_time),
|
||||
"test": (
|
||||
test_start_time,
|
||||
end_time,
|
||||
),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
def _init_qlib(self):
|
||||
"""initialize qlib"""
|
||||
# use yahoo_cn_1min data
|
||||
QLIB_INIT_CONFIG = {**HIGH_FREQ_CONFIG, **self.SPEC_CONF}
|
||||
provider_uri = QLIB_INIT_CONFIG.get("provider_uri")
|
||||
if not exists_qlib_data(provider_uri):
|
||||
print(f"Qlib data is not found in {provider_uri}")
|
||||
GetData().qlib_data(target_dir=provider_uri, interval="1min", region=REG_CN)
|
||||
qlib.init(**QLIB_INIT_CONFIG)
|
||||
|
||||
def _prepare_calender_cache(self):
|
||||
"""preload the calendar for cache"""
|
||||
|
||||
# This code used the copy-on-write feature of Linux to avoid calculating the calendar multiple times in the subprocess
|
||||
# This code may accelerate, but may be not useful on Windows and Mac Os
|
||||
Cal.calendar(freq="1min")
|
||||
get_calendar_day(freq="1min")
|
||||
|
||||
def get_data(self):
|
||||
"""use dataset to get highreq data"""
|
||||
self._init_qlib()
|
||||
self._prepare_calender_cache()
|
||||
|
||||
dataset = init_instance_by_config(self.task["dataset"])
|
||||
xtrain, xtest = dataset.prepare(["train", "test"])
|
||||
print(xtrain, xtest)
|
||||
|
||||
dataset_backtest = init_instance_by_config(self.task["dataset_backtest"])
|
||||
backtest_train, backtest_test = dataset_backtest.prepare(["train", "test"])
|
||||
print(backtest_train, backtest_test)
|
||||
|
||||
return
|
||||
|
||||
def dump_and_load_dataset(self):
|
||||
"""dump and load dataset state on disk"""
|
||||
self._init_qlib()
|
||||
self._prepare_calender_cache()
|
||||
dataset = init_instance_by_config(self.task["dataset"])
|
||||
dataset_backtest = init_instance_by_config(self.task["dataset_backtest"])
|
||||
|
||||
##=============dump dataset=============
|
||||
dataset.to_pickle(path="dataset.pkl")
|
||||
dataset_backtest.to_pickle(path="dataset_backtest.pkl")
|
||||
|
||||
del dataset, dataset_backtest
|
||||
##=============reload dataset=============
|
||||
with open("dataset.pkl", "rb") as file_dataset:
|
||||
dataset = pickle.load(file_dataset)
|
||||
|
||||
with open("dataset_backtest.pkl", "rb") as file_dataset_backtest:
|
||||
dataset_backtest = pickle.load(file_dataset_backtest)
|
||||
|
||||
self._prepare_calender_cache()
|
||||
##=============reinit dataset=============
|
||||
dataset.init(
|
||||
handler_kwargs={
|
||||
"init_type": DataHandlerLP.IT_LS,
|
||||
"start_time": "2021-01-19 00:00:00",
|
||||
"end_time": "2021-01-25 16:00:00",
|
||||
},
|
||||
segment_kwargs={
|
||||
"test": (
|
||||
"2021-01-19 00:00:00",
|
||||
"2021-01-25 16:00:00",
|
||||
),
|
||||
},
|
||||
)
|
||||
dataset_backtest.init(
|
||||
handler_kwargs={
|
||||
"start_time": "2021-01-19 00:00:00",
|
||||
"end_time": "2021-01-25 16:00:00",
|
||||
},
|
||||
segment_kwargs={
|
||||
"test": (
|
||||
"2021-01-19 00:00:00",
|
||||
"2021-01-25 16:00:00",
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
##=============get data=============
|
||||
xtest = dataset.prepare(["test"])
|
||||
backtest_test = dataset_backtest.prepare(["test"])
|
||||
|
||||
print(xtest, backtest_test)
|
||||
return
|
||||
|
||||
|
||||
def get_high_freq_data(self, data_path):
|
||||
self._init_qlib()
|
||||
self._prepare_calender_cache()
|
||||
|
||||
import os
|
||||
dataset = init_instance_by_config(self.task["dataset"])
|
||||
xtrain, xtest = dataset.prepare(["train", "test"])
|
||||
normed_feature = pd.concat([xtrain, xtest]).sort_index()
|
||||
dic = dict(tuple(normed_feature.groupby("instrument")))
|
||||
feature_path = os.path.join(data_path, "normed_feature/")
|
||||
if not os.path.exists(feature_path):
|
||||
os.makedirs(feature_path)
|
||||
for k, v in dic.items():
|
||||
v.to_pickle(feature_path + f"{k}.pkl")
|
||||
|
||||
|
||||
dataset_backtest = init_instance_by_config(self.task["dataset_backtest"])
|
||||
backtest_train, backtest_test = dataset_backtest.prepare(["train", "test"])
|
||||
backtest = pd.concat([backtest_train, backtest_test]).sort_index()
|
||||
backtest['date'] = backtest.index.map(lambda x: x[1].date())
|
||||
backtest.set_index('date', append=True, drop=True, inplace=True)
|
||||
dic = dict(tuple(backtest.groupby("instrument")))
|
||||
backtest_path = os.path.join(data_path, "backtest/")
|
||||
if not os.path.exists(backtest_path):
|
||||
os.makedirs(backtest_path)
|
||||
for k, v in dic.items():
|
||||
v.to_pickle(backtest_path + f"{k}.pkl.backtest")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
#fire.Fire(HighfreqWorkflow)
|
||||
data_path = '../data/'
|
||||
workflow = HighfreqWorkflow()
|
||||
workflow.get_high_freq_data(data_path)
|
||||
|
||||
@@ -15,6 +15,7 @@ import traceback
|
||||
import functools
|
||||
import statistics
|
||||
import subprocess
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from operator import xor
|
||||
from pprint import pprint
|
||||
@@ -45,8 +46,6 @@ if not exists_qlib_data(provider_uri):
|
||||
|
||||
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN, exp_manager=exp_manager)
|
||||
if os.path.isdir(exp_path):
|
||||
shutil.rmtree(exp_path)
|
||||
|
||||
# decorator to check the arguments
|
||||
def only_allow_defined_args(function_to_decorate):
|
||||
@@ -70,9 +69,9 @@ def handler(signum, frame):
|
||||
os.system("kill -9 %d" % os.getpid())
|
||||
|
||||
|
||||
signal.signal(signal.SIGTSTP, handler)
|
||||
signal.signal(signal.SIGINT, handler)
|
||||
|
||||
|
||||
# function to calculate the mean and std of a list in the results dictionary
|
||||
def cal_mean_std(results) -> dict:
|
||||
mean_std = dict()
|
||||
@@ -136,9 +135,9 @@ def get_all_folders(models, exclude) -> dict:
|
||||
|
||||
|
||||
# function to get all the files under the model folder
|
||||
def get_all_files(folder_path) -> (str, str):
|
||||
yaml_path = str(Path(f"{folder_path}") / "*.yaml")
|
||||
req_path = str(Path(f"{folder_path}") / "*.txt")
|
||||
def get_all_files(folder_path, dataset) -> (str, str):
|
||||
yaml_path = str(Path(f"{folder_path}") / f"*{dataset}*.yaml")
|
||||
req_path = str(Path(f"{folder_path}") / f"*.txt")
|
||||
return glob.glob(yaml_path)[0], glob.glob(req_path)[0]
|
||||
|
||||
|
||||
@@ -152,6 +151,10 @@ def get_all_results(folders) -> dict:
|
||||
result["annualized_return_with_cost"] = list()
|
||||
result["information_ratio_with_cost"] = list()
|
||||
result["max_drawdown_with_cost"] = list()
|
||||
result["ic"] = list()
|
||||
result["icir"] = list()
|
||||
result["rank_ic"] = list()
|
||||
result["rank_icir"] = list()
|
||||
for recorder_id in recorders:
|
||||
if recorders[recorder_id].status == "FINISHED":
|
||||
recorder = R.get_recorder(recorder_id=recorder_id, experiment_name=fn)
|
||||
@@ -159,19 +162,27 @@ def get_all_results(folders) -> dict:
|
||||
result["annualized_return_with_cost"].append(metrics["excess_return_with_cost.annualized_return"])
|
||||
result["information_ratio_with_cost"].append(metrics["excess_return_with_cost.information_ratio"])
|
||||
result["max_drawdown_with_cost"].append(metrics["excess_return_with_cost.max_drawdown"])
|
||||
result["ic"].append(metrics["IC"])
|
||||
result["icir"].append(metrics["ICIR"])
|
||||
result["rank_ic"].append(metrics["Rank IC"])
|
||||
result["rank_icir"].append(metrics["Rank ICIR"])
|
||||
results[fn] = result
|
||||
return results
|
||||
|
||||
|
||||
# function to generate and save markdown table
|
||||
def gen_and_save_md_table(metrics):
|
||||
table = "| Model Name | Annualized Return | Information Ratio | Max Drawdown |\n"
|
||||
table += "|---|---|---|---|\n"
|
||||
def gen_and_save_md_table(metrics, dataset):
|
||||
table = "| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |\n"
|
||||
table += "|---|---|---|---|---|---|---|---|---|\n"
|
||||
for fn in metrics:
|
||||
ic = metrics[fn]["ic"]
|
||||
icir = metrics[fn]["icir"]
|
||||
ric = metrics[fn]["rank_ic"]
|
||||
ricir = metrics[fn]["rank_icir"]
|
||||
ar = metrics[fn]["annualized_return_with_cost"]
|
||||
ir = metrics[fn]["information_ratio_with_cost"]
|
||||
md = metrics[fn]["max_drawdown_with_cost"]
|
||||
table += f"| {fn} | {ar[0]:9.4f}±{ar[1]:9.2f} | {ir[0]:9.4f}±{ir[1]:9.2f}| {md[0]:9.4f}±{md[1]:9.2f} |\n"
|
||||
table += f"| {fn} | {dataset} | {ic[0]:5.4f}±{ic[1]:2.2f} | {icir[0]:5.4f}±{icir[1]:2.2f}| {ric[0]:5.4f}±{ric[1]:2.2f} | {ricir[0]:5.4f}±{ricir[1]:2.2f} | {ar[0]:5.4f}±{ar[1]:2.2f} | {ir[0]:5.4f}±{ir[1]:2.2f}| {md[0]:5.4f}±{md[1]:2.2f} |\n"
|
||||
pprint(table)
|
||||
with open("table.md", "w") as f:
|
||||
f.write(table)
|
||||
@@ -180,10 +191,11 @@ def gen_and_save_md_table(metrics):
|
||||
|
||||
# function to run the all the models
|
||||
@only_allow_defined_args
|
||||
def run(times=1, models=None, exclude=False):
|
||||
def run(times=1, models=None, dataset="Alpha360", exclude=False):
|
||||
"""
|
||||
Please be aware that this function can only work under Linux. MacOS and Windows will be supported in the future.
|
||||
Any PR to enhance this method is highly welcomed.
|
||||
Any PR to enhance this method is highly welcomed. Besides, this script doesn't support parrallel running the same model
|
||||
for multiple times, and this will be fixed in the future development.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
@@ -193,6 +205,8 @@ def run(times=1, models=None, exclude=False):
|
||||
determines the specific model or list of models to run or exclude.
|
||||
exclude : boolean
|
||||
determines whether the model being used is excluded or included.
|
||||
dataset : str
|
||||
determines the dataset to be used for each model.
|
||||
|
||||
Usage:
|
||||
-------
|
||||
@@ -206,13 +220,16 @@ def run(times=1, models=None, exclude=False):
|
||||
# Case 2 - run specific models multiple times
|
||||
python run_all_model.py 3 mlp
|
||||
|
||||
# Case 3 - run other models except those are given as arguments for multiple times
|
||||
python run_all_model.py 3 [mlp,tft,lstm] True
|
||||
# Case 3 - run specific models multiple times with specific dataset
|
||||
python run_all_model.py 3 mlp Alpha158
|
||||
|
||||
# Case 4 - run specific models for one time
|
||||
# Case 4 - run other models except those are given as arguments for multiple times
|
||||
python run_all_model.py 3 [mlp,tft,lstm] --exclude=True
|
||||
|
||||
# Case 5 - run specific models for one time
|
||||
python run_all_model.py --models=[mlp,lightgbm]
|
||||
|
||||
# Case 5 - run other models except those are given as aruments for one time
|
||||
# Case 6 - run other models except those are given as aruments for one time
|
||||
python run_all_model.py --models=[mlp,tft,sfm] --exclude=True
|
||||
|
||||
"""
|
||||
@@ -226,7 +243,7 @@ def run(times=1, models=None, exclude=False):
|
||||
env_path, python_path, conda_activate = create_env()
|
||||
# get all files
|
||||
sys.stderr.write("Retrieving files...\n")
|
||||
yaml_path, req_path = get_all_files(folders[fn])
|
||||
yaml_path, req_path = get_all_files(folders[fn], dataset)
|
||||
sys.stderr.write("\n")
|
||||
# install requirements.txt
|
||||
sys.stderr.write("Installing requirements.txt...\n")
|
||||
@@ -240,6 +257,7 @@ def run(times=1, models=None, exclude=False):
|
||||
sys.stderr.write("\n")
|
||||
# install qlib
|
||||
sys.stderr.write("Installing qlib...\n")
|
||||
execute(f"{python_path} -m pip install --upgrade pip") # TODO: FIX ME!
|
||||
execute(f"{python_path} -m pip install --upgrade cython") # TODO: FIX ME!
|
||||
if fn == "TFT":
|
||||
execute(
|
||||
@@ -272,12 +290,15 @@ def run(times=1, models=None, exclude=False):
|
||||
results = cal_mean_std(results)
|
||||
# generating md table
|
||||
sys.stderr.write(f"Generating markdown table...\n")
|
||||
gen_and_save_md_table(results)
|
||||
gen_and_save_md_table(results, dataset)
|
||||
sys.stderr.write("\n")
|
||||
# print erros
|
||||
sys.stderr.write(f"Here are some of the errors of the models...\n")
|
||||
pprint(errors)
|
||||
sys.stderr.write("\n")
|
||||
# move results folder
|
||||
shutil.move(exp_path, exp_path + f"_{dataset}_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}")
|
||||
shutil.move("table.md", f"table_{dataset}_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}.md")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
104
examples/trade/README.md
Normal file
104
examples/trade/README.md
Normal file
@@ -0,0 +1,104 @@
|
||||
# Universal Trading for Order Execution with Oracle Policy Distillation
|
||||
This is the experiment code for our AAAI 2021 paper "[Universal Trading for Order Execution with Oracle Policy Distillation](https://arxiv.org/abs/2103.10860)", including the implementations of all the compared methods in the paper and a general reinforcement learning framework for order execution in quantitative finance.
|
||||
|
||||
## Abstract
|
||||
As a fundamental problem in algorithmic trading, order execution aims at fulfilling a specific trading order, either liquidation or acquirement, for a given instrument. Towards effective execution strategy, recent years have witnessed the shift from the analytical view with model-based market assumptions to model-free perspective, i.e., reinforcement learning, due to its nature of sequential decision optimization. However, the noisy and yet imperfect market information that can be leveraged by the policy has made it quite challenging to build up sample efficient reinforcement learning methods to achieve effective order execution. In this paper, we propose a novel universal trading policy optimization framework to bridge the gap between the noisy yet imperfect market states and the optimal action sequences for order execution. Particularly, this framework leverages a policy distillation method that can better guide the learning of the common policy towards practically optimal execution by an oracle teacher with perfect information to approximate the optimal trading strategy. The extensive experiments have shown significant improvements of our method over various strong baselines, with reasonable trading actions.
|
||||
|
||||
## Environment Dependencies
|
||||
|
||||
### Dependencies
|
||||
|
||||
```
|
||||
gym==0.17.3
|
||||
torch==1.6.0
|
||||
numba==0.51.2
|
||||
numpy==1.19.1
|
||||
pandas==1.1.3
|
||||
tqdm==4.50.2
|
||||
tianshou==0.3.0.post1
|
||||
env==0.1.0
|
||||
PyYAML==5.4.1
|
||||
redis==3.5.3
|
||||
```
|
||||
|
||||
### Environment Variable
|
||||
|
||||
`EXP_PATH` Absolute path to your config folder, we give folder `exp` as an example.
|
||||
|
||||
`OUTPUT_DIR` Absolute path to your log folder.
|
||||
|
||||
## Data Processing
|
||||
|
||||
For Feature processing, we take Yahoo dataset as an example, which can be precessed in `qlib/examples/highfreq/workflow.py` file. If you have a need to change your data storage path, you can change the `data_path` in `workflow.py`, and then do the following.
|
||||
|
||||
```
|
||||
python workflow.py
|
||||
```
|
||||
|
||||
For order generation, if you have changed change the the `data_path` in `workflow.py`, change `data_path` in `order_gen.py` again, then do the following.
|
||||
|
||||
```
|
||||
python order_gen.py
|
||||
```
|
||||
|
||||
## Training and backtest
|
||||
|
||||
### Config file
|
||||
|
||||
Config file is need to start our project, we take `PPO`, `OPDS` and `OPD` as an example in folder `exp/example`. If you want to use our given config, make sure the `data_path` you set before matches the config file.
|
||||
|
||||
### Baseline method
|
||||
|
||||
To run a method, you can do the following.
|
||||
|
||||
```
|
||||
python main.py --config={config_path}
|
||||
```
|
||||
|
||||
Where `{config_path}` means the relative path from your config.yml to `EXP_PATH`.
|
||||
|
||||
If you need to run our given method such as PPO method, you can do the following.
|
||||
|
||||
```
|
||||
python main.py --config=example/PPO/config.yml
|
||||
```
|
||||
|
||||
### OPD method
|
||||
|
||||
OPD method is a multi step method, at first you should run OPDT as the teacher in OPD method.
|
||||
|
||||
```
|
||||
python main.py --config=example/OPDT/config.yml
|
||||
```
|
||||
|
||||
After training, find the `policy_best` file in your OPDT log file and copy it to `trade` file for backtest. Also you can change `policy_path` in the `example/OPDT_b/config.yml` to your `policy_best` file. Then run the backtest method.
|
||||
|
||||
```
|
||||
python main.py --config=example/OPDT_b/config.yml
|
||||
```
|
||||
|
||||
then processed feature from teacher. Remember to change `log_path` if you have changed `log_dir` in `OPDT_b/config.yml`.
|
||||
|
||||
```
|
||||
python teacher_feature.py
|
||||
```
|
||||
|
||||
and finally start our OPD method.
|
||||
|
||||
```
|
||||
python main.py --config=example/OPD/config.yml
|
||||
```
|
||||
|
||||
## Citation
|
||||
You are more than welcome to citetmu our paper:
|
||||
```
|
||||
@inproceedings{fang2021universal,
|
||||
title={Universal Trading for Order Execution with Oracle Policy Distillation},
|
||||
author={Fang, Yuchen and Ren, Kan and Liu, Weiqing and Zhou, Dong and Zhang, Weinan and Bian, Jiang and Yu, Yong and Liu, Tie-Yan},
|
||||
booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
|
||||
volume={35},
|
||||
number={1},
|
||||
pages={107--115},
|
||||
year={2021}
|
||||
}
|
||||
```
|
||||
10
examples/trade/__init__.py
Normal file
10
examples/trade/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
# from rl4execution import env, trainer, exploration
|
||||
|
||||
# __all__ = [
|
||||
# "env",
|
||||
# "data",
|
||||
# "utils",
|
||||
# "policy",
|
||||
# "trainer",
|
||||
# "exploration",
|
||||
# ]
|
||||
4
examples/trade/action/__init__.py
Normal file
4
examples/trade/action/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .base import *
|
||||
from .action_rl import *
|
||||
from .action_rule import *
|
||||
from .action_rl import *
|
||||
27
examples/trade/action/action_rl.py
Normal file
27
examples/trade/action/action_rl.py
Normal file
@@ -0,0 +1,27 @@
|
||||
import numpy as np
|
||||
from gym.spaces import Discrete, Box, Tuple, MultiDiscrete
|
||||
|
||||
from .base import Base_Action
|
||||
|
||||
|
||||
class Static_Action(Base_Action):
|
||||
""" """
|
||||
|
||||
def __init__(self, config):
|
||||
self.action_num = config["action_num"]
|
||||
self.action_map = config["action_map"]
|
||||
|
||||
def get_space(self):
|
||||
""" """
|
||||
return Discrete(self.action_num)
|
||||
|
||||
def get_action(self, action, target, position, **kargs):
|
||||
"""
|
||||
|
||||
:param action:
|
||||
:param position:
|
||||
:param target:
|
||||
:param **kargs:
|
||||
|
||||
"""
|
||||
return min(target * self.action_map[action], position)
|
||||
46
examples/trade/action/action_rule.py
Normal file
46
examples/trade/action/action_rule.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import numpy as np
|
||||
from gym.spaces import Discrete, Box, Tuple, MultiDiscrete
|
||||
|
||||
from .base import Base_Action
|
||||
|
||||
|
||||
class Rule_Dynamic(Base_Action):
|
||||
""" """
|
||||
|
||||
def get_space(self):
|
||||
""" """
|
||||
return Box(0, np.inf, shape=(), dtype=np.float32)
|
||||
|
||||
def get_action(self, action, target, position, max_step_num, t, **kargs):
|
||||
"""
|
||||
|
||||
:param action: param target:
|
||||
:param position: param max_step_num:
|
||||
:param t: param **kargs:
|
||||
:param target:
|
||||
:param max_step_num:
|
||||
:param **kargs:
|
||||
|
||||
"""
|
||||
return position / (max_step_num - (t + 1)) * action
|
||||
|
||||
|
||||
class Rule_Static(Base_Action):
|
||||
""" """
|
||||
|
||||
def get_space(self):
|
||||
""" """
|
||||
return Box(0, np.inf, shape=(), dtype=np.float32)
|
||||
|
||||
def get_action(self, action, target, position, max_step_num, t, **kargs):
|
||||
"""
|
||||
|
||||
:param action: param target:
|
||||
:param position: param max_step_num:
|
||||
:param t: param **kargs:
|
||||
:param target:
|
||||
:param max_step_num:
|
||||
:param **kargs:
|
||||
|
||||
"""
|
||||
return target / max_step_num * action
|
||||
20
examples/trade/action/base.py
Normal file
20
examples/trade/action/base.py
Normal file
@@ -0,0 +1,20 @@
|
||||
import numpy as np
|
||||
from gym.spaces import Discrete, Box, Tuple, MultiDiscrete
|
||||
|
||||
|
||||
class Base_Action(object):
|
||||
""" """
|
||||
|
||||
def __init__(self, config):
|
||||
return
|
||||
|
||||
def __call__(self, *args, **kargs):
|
||||
return self.get_action(*args, **kargs)
|
||||
|
||||
def get_action(self, action):
|
||||
"""
|
||||
|
||||
:param action:
|
||||
|
||||
"""
|
||||
return action
|
||||
46
examples/trade/action/interval_rule.py
Normal file
46
examples/trade/action/interval_rule.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import numpy as np
|
||||
from gym.spaces import Discrete, Box, Tuple, MultiDiscrete
|
||||
|
||||
from .base import Base_Action
|
||||
|
||||
|
||||
class Rule_Static_Interval(Base_Action):
|
||||
""" """
|
||||
|
||||
def get_space(self):
|
||||
""" """
|
||||
return Box(0, np.inf, shape=(), dtype=np.float32)
|
||||
|
||||
def get_action(self, action, target, position, interval_num, interval, **kargs):
|
||||
"""
|
||||
|
||||
:param action: param target:
|
||||
:param position: param interval_num:
|
||||
:param interval: param **kargs:
|
||||
:param target:
|
||||
:param interval_num:
|
||||
:param **kargs:
|
||||
|
||||
"""
|
||||
return target / (interval_num) * action
|
||||
|
||||
|
||||
class Rule_Dynamic_Interval(Base_Action):
|
||||
""" """
|
||||
|
||||
def get_space(self):
|
||||
""" """
|
||||
return Box(0, np.inf, shape=(), dtype=np.float32)
|
||||
|
||||
def get_action(self, action, target, position, interval_num, interval, **kargs):
|
||||
"""
|
||||
|
||||
:param action: param target:
|
||||
:param position: param interval_num:
|
||||
:param interval: param **kargs:
|
||||
:param target:
|
||||
:param interval_num:
|
||||
:param **kargs:
|
||||
|
||||
"""
|
||||
return position / (interval_num - interval) * action
|
||||
1
examples/trade/agent/__init__.py
Normal file
1
examples/trade/agent/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .basic import *
|
||||
69
examples/trade/agent/basic.py
Normal file
69
examples/trade/agent/basic.py
Normal file
@@ -0,0 +1,69 @@
|
||||
from tianshou.policy import BasePolicy
|
||||
from tianshou.data import Batch
|
||||
import numpy as np
|
||||
import torch
|
||||
from env import nan_weighted_avg
|
||||
|
||||
|
||||
class TWAP(BasePolicy):
|
||||
""" The TWAP strategy. """
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.max_step_num = config["max_step_num"]
|
||||
self.num_cpus = config["num_cpus"]
|
||||
|
||||
# @njit(parallel=True)
|
||||
def forward(self, batch: Batch, state=None, **kwargs) -> Batch:
|
||||
act = [1] * len(batch.obs.private)
|
||||
return Batch(act=act, state=state)
|
||||
|
||||
def learn(self, batch, batch_size, repeat):
|
||||
pass
|
||||
|
||||
def process_fn(self, batch, buffer, indice):
|
||||
pass
|
||||
|
||||
|
||||
class VWAP(BasePolicy):
|
||||
""" The VWAP strategy."""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, batch, state, **kwargs):
|
||||
obs = batch.obs
|
||||
r = np.stack(obs.prediction).reshape(-1)
|
||||
return Batch(act=r, state=state)
|
||||
|
||||
def learn(self, batch, batch_size, repeat):
|
||||
pass
|
||||
|
||||
def process_fn(self, batch, buffer, indice):
|
||||
pass
|
||||
|
||||
|
||||
class AC(VWAP):
|
||||
"""Almgren-Chriss strategy."""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.T = config["max_step_num"]
|
||||
self.gamma = 0
|
||||
self.tau = 1
|
||||
self.lamb = config["lambda"]
|
||||
self.eps = 0.0625
|
||||
self.alpha = 0.02
|
||||
self.eta = 2.5e-6
|
||||
|
||||
def forward(self, batch, state, **kwargs):
|
||||
obs = batch.obs
|
||||
sig = np.stack(obs.prediction).reshape(-1)
|
||||
sell = ~np.stack(obs.is_buy).astype(np.bool)
|
||||
data = np.stack(obs.private)
|
||||
t = data[:, 2]
|
||||
t = t + 1
|
||||
k_tild = self.lamb / self.eta * sig * sig
|
||||
k = np.arccosh(k_tild / 2 + 1)
|
||||
act = (np.sinh(k * (self.T - t)) - np.sinh(k * (self.T - t - 1))) / np.sinh(k * self.T)
|
||||
return Batch(act=act, state=state)
|
||||
342
examples/trade/collector.py
Normal file
342
examples/trade/collector.py
Normal file
@@ -0,0 +1,342 @@
|
||||
import gym
|
||||
import time
|
||||
import torch
|
||||
import warnings
|
||||
import numpy as np
|
||||
from copy import deepcopy
|
||||
from numbers import Number
|
||||
from typing import Any, Dict, List, Union, Optional, Callable
|
||||
|
||||
from vecenv import BaseVectorEnv
|
||||
from tianshou.policy import BasePolicy
|
||||
from tianshou.data import Batch, ReplayBuffer, ListReplayBuffer, to_numpy
|
||||
from tianshou.exploration import BaseNoise
|
||||
from tianshou.env import DummyVectorEnv
|
||||
from tianshou.data.collector import _batch_set_item
|
||||
|
||||
|
||||
class Collector(object):
|
||||
def __init__(
|
||||
self,
|
||||
policy: BasePolicy,
|
||||
env: Union[gym.Env, BaseVectorEnv],
|
||||
testing=False,
|
||||
buffer: Optional[ReplayBuffer] = None,
|
||||
preprocess_fn: Optional[Callable[..., Batch]] = None,
|
||||
action_noise: Optional[BaseNoise] = None,
|
||||
reward_metric: Optional[Callable[[np.ndarray], float]] = np.sum,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
if not isinstance(env, BaseVectorEnv):
|
||||
env = DummyVectorEnv([lambda: env])
|
||||
self.env = env
|
||||
self.env_num = len(env)
|
||||
# environments that are available in step()
|
||||
# this means all environments in synchronous simulation
|
||||
# but only a subset of environments in asynchronous simulation
|
||||
self._ready_env_ids = np.arange(self.env_num)
|
||||
# self.async is a flag to indicate whether this collector works
|
||||
# with asynchronous simulation
|
||||
self.is_async = env.is_async
|
||||
self.testing = testing
|
||||
# need cache buffers before storing in the main buffer
|
||||
self._cached_buf = [ListReplayBuffer() for _ in range(self.env_num)]
|
||||
self.buffer = buffer
|
||||
self.policy = policy
|
||||
self.preprocess_fn = preprocess_fn
|
||||
self.process_fn = policy.process_fn
|
||||
# self._action_space = env.action_space
|
||||
self._action_noise = action_noise
|
||||
self._rew_metric = reward_metric or Collector._default_rew_metric
|
||||
# avoid creating attribute outside __init__
|
||||
# self.reset()
|
||||
|
||||
@staticmethod
|
||||
def _default_rew_metric(x: Union[Number, np.number]) -> Union[Number, np.number]:
|
||||
# this internal function is designed for single-agent RL
|
||||
# for multi-agent RL, a reward_metric must be provided
|
||||
assert np.asanyarray(x).size == 1, "Please specify the reward_metric " "since the reward is not a scalar."
|
||||
return x
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset all related variables in the collector."""
|
||||
# use empty Batch for ``state`` so that ``self.data`` supports slicing
|
||||
# convert empty Batch to None when passing data to policy
|
||||
self.data = Batch(state={}, obs={}, act={}, rew={}, done={}, info={}, obs_next={}, policy={})
|
||||
self.reset_env()
|
||||
self.reset_buffer()
|
||||
self.reset_stat()
|
||||
if self._action_noise is not None:
|
||||
self._action_noise.reset()
|
||||
|
||||
def reset_stat(self) -> None:
|
||||
"""Reset the statistic variables."""
|
||||
self.collect_time, self.collect_step, self.collect_episode = 0.0, 0, 0
|
||||
|
||||
def reset_buffer(self) -> None:
|
||||
"""Reset the main data buffer."""
|
||||
if self.buffer is not None:
|
||||
self.buffer.reset()
|
||||
|
||||
def get_env_num(self) -> int:
|
||||
""" """
|
||||
return self.env_num
|
||||
|
||||
def reset_env(self) -> None:
|
||||
"""Reset all of the environment(s)' states and the cache buffers."""
|
||||
self._ready_env_ids = np.arange(self.env_num)
|
||||
self.env.reset_sampler()
|
||||
obs, stop_id = self.env.reset()
|
||||
if self.preprocess_fn:
|
||||
obs = self.preprocess_fn(obs=obs).get("obs", obs)
|
||||
self.data.obs = obs
|
||||
for b in self._cached_buf:
|
||||
b.reset()
|
||||
self._ready_env_ids = np.array([x for x in self._ready_env_ids if x not in stop_id])
|
||||
|
||||
def _reset_state(self, id: Union[int, List[int]]) -> None:
|
||||
"""Reset the hidden state: self.data.state[id]."""
|
||||
state = self.data.state # it is a reference
|
||||
if isinstance(state, torch.Tensor):
|
||||
state[id].zero_()
|
||||
elif isinstance(state, np.ndarray):
|
||||
state[id] = None if state.dtype == np.object else 0
|
||||
elif isinstance(state, Batch):
|
||||
state.empty_(id)
|
||||
|
||||
def collect(
|
||||
self,
|
||||
n_step: Optional[int] = None,
|
||||
n_episode: Optional[Union[int, List[int]]] = None,
|
||||
random: bool = False,
|
||||
render: Optional[float] = None,
|
||||
log_fn=None,
|
||||
no_grad: bool = True,
|
||||
) -> Dict[str, float]:
|
||||
"""Collect a specified number of step or episode.
|
||||
|
||||
:param int: n_step: how many steps you want to collect.
|
||||
:param n_episode: how many episodes you want to collect. If it is an
|
||||
int, it means to collect at lease ``n_episode`` episodes; if it is
|
||||
a list, it means to collect exactly ``n_episode[i]`` episodes in
|
||||
the i-th environment
|
||||
:param bool: random: whether to use random policy for collecting data,
|
||||
defaults to False.
|
||||
:param float: render: the sleep time between rendering consecutive
|
||||
frames, defaults to None (no rendering).
|
||||
:param bool: no_grad: whether to retain gradient in policy.forward,
|
||||
defaults to True (no gradient retaining).
|
||||
|
||||
.. note::
|
||||
|
||||
One and only one collection number specification is permitted,
|
||||
either ``n_step`` or ``n_episode``.
|
||||
|
||||
:param n_step: Optional[int]: (Default value = None)
|
||||
:param n_episode: Optional[Union[int:List[int]]]: (Default value = None)
|
||||
:param random: bool: (Default value = False)
|
||||
:param render: Optional[float]: (Default value = None)
|
||||
:param log_fn: Default value = None)
|
||||
:param no_grad: bool: (Default value = True)
|
||||
:param n_step: Optional[int]: (Default value = None)
|
||||
:param n_episode: Optional[Union[int:
|
||||
:param List[int]]]: (Default value = None)
|
||||
:param random: bool: (Default value = False)
|
||||
:param render: Optional[float]: (Default value = None)
|
||||
:param no_grad: bool: (Default value = True)
|
||||
:param n_step: Optional[int]: (Default value = None)
|
||||
:param n_episode: Optional[Union[int:
|
||||
:param random: bool: (Default value = False)
|
||||
:param render: Optional[float]: (Default value = None)
|
||||
:param no_grad: bool: (Default value = True)
|
||||
:returns: A dict including the following keys
|
||||
|
||||
* ``n/ep`` the collected number of episodes.
|
||||
* ``n/st`` the collected number of steps.
|
||||
* ``v/st`` the speed of steps per second.
|
||||
* ``v/ep`` the speed of episode per second.
|
||||
* ``rew`` the mean reward over collected episodes.
|
||||
* ``len`` the mean length over collected episodes.
|
||||
|
||||
"""
|
||||
assert (
|
||||
(n_step is not None and n_episode is None and n_step > 0)
|
||||
or (n_step is None and n_episode is not None and np.sum(n_episode) > 0)
|
||||
or self.testing
|
||||
), "Only one of n_step or n_episode is allowed in Collector.collect, "
|
||||
f"got n_step = {n_step}, n_episode = {n_episode}."
|
||||
start_time = time.time()
|
||||
step_count = 0
|
||||
step_time = 0.0
|
||||
reset_time = 0.0
|
||||
model_time = 0.0
|
||||
# episode of each environment
|
||||
episode_count = np.zeros(self.env_num)
|
||||
# If n_episode is a list, and some envs have collected the required
|
||||
# number of episodes, these envs will be recorded in this list, and
|
||||
# they will not be stepped.
|
||||
finished_env_ids = []
|
||||
rewards = []
|
||||
whole_data = Batch()
|
||||
if isinstance(n_episode, list):
|
||||
assert len(n_episode) == self.get_env_num()
|
||||
finished_env_ids = [i for i in self._ready_env_ids if n_episode[i] <= 0]
|
||||
self._ready_env_ids = np.array([x for x in self._ready_env_ids if x not in finished_env_ids])
|
||||
while True:
|
||||
if step_count >= 100000 and episode_count.sum() == 0:
|
||||
warnings.warn(
|
||||
"There are already many steps in an episode. "
|
||||
"You should add a time limitation to your environment!",
|
||||
Warning,
|
||||
)
|
||||
|
||||
is_async = self.is_async or len(finished_env_ids) > 0
|
||||
if is_async:
|
||||
# self.data are the data for all environments in async
|
||||
# simulation or some envs have finished,
|
||||
# **only a subset of data are disposed**,
|
||||
# so we store the whole data in ``whole_data``, let self.data
|
||||
# to be the data available in ready environments, and finally
|
||||
# set these back into all the data
|
||||
whole_data = self.data
|
||||
self.data = self.data[self._ready_env_ids]
|
||||
|
||||
# restore the state and the input data
|
||||
last_state = self.data.state
|
||||
if isinstance(last_state, Batch) and last_state.is_empty():
|
||||
last_state = None
|
||||
self.data.update(state=Batch(), obs_next=Batch(), policy=Batch())
|
||||
|
||||
# calculate the next action
|
||||
start = time.time()
|
||||
if random:
|
||||
spaces = self._action_space
|
||||
result = Batch(act=[spaces[i].sample() for i in self._ready_env_ids])
|
||||
else:
|
||||
if no_grad:
|
||||
with torch.no_grad(): # faster than retain_grad version
|
||||
result = self.policy(self.data, last_state)
|
||||
else:
|
||||
result = self.policy(self.data, last_state)
|
||||
model_time += time.time() - start
|
||||
state = result.get("state", Batch())
|
||||
# convert None to Batch(), since None is reserved for 0-init
|
||||
if state is None:
|
||||
state = Batch()
|
||||
self.data.update(state=state, policy=result.get("policy", Batch()))
|
||||
# save hidden state to policy._state, in order to save into buffer
|
||||
if not (isinstance(state, Batch) and state.is_empty()):
|
||||
self.data.policy._state = self.data.state
|
||||
|
||||
self.data.act = to_numpy(result.act)
|
||||
if self._action_noise is not None:
|
||||
assert isinstance(self.data.act, np.ndarray)
|
||||
self.data.act += self._action_noise(self.data.act.shape)
|
||||
|
||||
# step in env
|
||||
start = time.time()
|
||||
if not is_async:
|
||||
obs_next, rew, done, info = self.env.step(self.data.act)
|
||||
if log_fn:
|
||||
log_fn(info)
|
||||
else:
|
||||
# store computed actions, states, etc
|
||||
_batch_set_item(whole_data, self._ready_env_ids, self.data, self.env_num)
|
||||
# fetch finished data
|
||||
obs_next, rew, done, info = self.env.step(self.data.act, id=self._ready_env_ids)
|
||||
self._ready_env_ids = np.array([i["env_id"] for i in info])
|
||||
# get the stepped data
|
||||
self.data = whole_data[self._ready_env_ids]
|
||||
if log_fn:
|
||||
log_fn(info)
|
||||
|
||||
step_time += time.time() - start
|
||||
# move data to self.data
|
||||
self.data.update(obs_next=obs_next, rew=rew, done=done, info=[{} for i in info])
|
||||
|
||||
if render:
|
||||
self.env.render()
|
||||
time.sleep(render)
|
||||
|
||||
# add data into the buffer
|
||||
if self.preprocess_fn:
|
||||
result = self.preprocess_fn(**self.data) # type: ignore
|
||||
self.data.update(result)
|
||||
|
||||
for j, i in enumerate(self._ready_env_ids):
|
||||
# j is the index in current ready_env_ids
|
||||
# i is the index in all environments
|
||||
if self.buffer is None:
|
||||
# users do not want to store data, so we store
|
||||
# small fake data here to make the code clean
|
||||
self._cached_buf[i].add(obs=0, act=0, rew=rew[j], done=0)
|
||||
else:
|
||||
self._cached_buf[i].add(**self.data[j])
|
||||
|
||||
if done[j]:
|
||||
if not (isinstance(n_episode, list) and episode_count[i] >= n_episode[i]):
|
||||
episode_count[i] += 1
|
||||
rewards.append(self._rew_metric(np.sum(self._cached_buf[i].rew, axis=0)))
|
||||
step_count += len(self._cached_buf[i])
|
||||
if self.buffer is not None:
|
||||
self.buffer.update(self._cached_buf[i])
|
||||
if isinstance(n_episode, list) and episode_count[i] >= n_episode[i]:
|
||||
# env i has collected enough data, it has finished
|
||||
finished_env_ids.append(i)
|
||||
self._cached_buf[i].reset()
|
||||
self._reset_state(j)
|
||||
obs_next = self.data.obs_next
|
||||
start = time.time()
|
||||
if sum(done):
|
||||
env_ind_local = np.where(done)[0].tolist()
|
||||
env_ind_global = self._ready_env_ids[env_ind_local]
|
||||
obs_reset, stop_id = self.env.reset(env_ind_global)
|
||||
_ready_env_ids = self._ready_env_ids.tolist()
|
||||
for i in stop_id:
|
||||
finished_env_ids.append(i)
|
||||
# env_ind_local.remove(_ready_env_ids.index(i))
|
||||
if len(env_ind_local) > 0:
|
||||
if self.preprocess_fn:
|
||||
obs_reset = self.preprocess_fn(obs=obs_reset).get("obs", obs_reset)
|
||||
obs_next[env_ind_local] = obs_reset
|
||||
reset_time += time.time() - start
|
||||
self.data.obs = obs_next
|
||||
if is_async:
|
||||
# set data back
|
||||
whole_data = deepcopy(whole_data) # avoid reference in ListBuf
|
||||
_batch_set_item(whole_data, self._ready_env_ids, self.data, self.env_num)
|
||||
# let self.data be the data in all environments again
|
||||
self.data = whole_data
|
||||
self._ready_env_ids = np.array([x for x in self._ready_env_ids if x not in finished_env_ids])
|
||||
if n_step:
|
||||
if step_count >= n_step:
|
||||
break
|
||||
else:
|
||||
if isinstance(n_episode, int) and episode_count.sum() >= n_episode:
|
||||
break
|
||||
if isinstance(n_episode, list) and (episode_count >= n_episode).all():
|
||||
break
|
||||
if len(self._ready_env_ids) == 0 and self.testing:
|
||||
break
|
||||
|
||||
# finished envs are ready, and can be used for the next collection
|
||||
self._ready_env_ids = np.array(self._ready_env_ids.tolist() + finished_env_ids)
|
||||
|
||||
# generate the statistics
|
||||
episode_count = sum(episode_count)
|
||||
duration = max(time.time() - start_time, 1e-9)
|
||||
self.collect_step += step_count
|
||||
self.collect_episode += episode_count
|
||||
self.collect_time += duration
|
||||
return {
|
||||
"n/ep": episode_count,
|
||||
"n/st": step_count,
|
||||
"v/st": step_count / duration,
|
||||
"v/ep": episode_count / duration,
|
||||
"t/st": step_time / step_count,
|
||||
"t/re": reset_time / episode_count,
|
||||
"t/mo": model_time / step_count,
|
||||
"rew": np.mean(rewards),
|
||||
"rew_std": np.std(rewards),
|
||||
"len": step_count / episode_count,
|
||||
}
|
||||
1
examples/trade/env/__init__.py
vendored
Normal file
1
examples/trade/env/__init__.py
vendored
Normal file
@@ -0,0 +1 @@
|
||||
from .env_rl import *
|
||||
481
examples/trade/env/env_rl.py
vendored
Normal file
481
examples/trade/env/env_rl.py
vendored
Normal file
@@ -0,0 +1,481 @@
|
||||
import gym
|
||||
|
||||
gym.logger.set_level(40)
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pickle as pkl
|
||||
import datetime
|
||||
import random
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
import tianshou as ts
|
||||
import copy
|
||||
from multiprocessing import Process, Pipe, Queue
|
||||
from typing import List, Tuple, Union, Optional, Callable, Any
|
||||
from tianshou.env.utils import CloudpickleWrapper
|
||||
from scipy.stats import pearsonr
|
||||
from sklearn.metrics import roc_auc_score
|
||||
|
||||
import sys
|
||||
|
||||
sys.path.append("..")
|
||||
from util import merge_dicts, nan_weighted_avg, robust_auc
|
||||
import reward
|
||||
import observation
|
||||
import action
|
||||
|
||||
ZERO = 1e-7
|
||||
|
||||
|
||||
class StockEnv(gym.Env):
|
||||
"""Single-assert environment"""
|
||||
|
||||
def __init__(self, config):
|
||||
self.max_step_num = config["max_step_num"]
|
||||
self.limit = config["limit"]
|
||||
self.time_interval = config["time_interval"]
|
||||
self.interval_num = config["interval_num"]
|
||||
self.offset = config["offset"] if "offset" in config else 0
|
||||
if "last_reward" in config:
|
||||
self.last_reward = config["last_reward"]
|
||||
else:
|
||||
self.last_reward = None
|
||||
if "log" in config:
|
||||
self.log = config["log"]
|
||||
else:
|
||||
self.log = True
|
||||
# loader_conf = config['loader']['config']
|
||||
obs_conf = config["obs"]["config"]
|
||||
obs_conf["features"] = config["features"]
|
||||
obs_conf["time_interval"] = self.time_interval
|
||||
obs_conf["max_step_num"] = self.max_step_num
|
||||
self.obs = getattr(observation, config["obs"]["name"])(obs_conf)
|
||||
self.action_func = getattr(action, config["action"]["name"])(config["action"]["config"])
|
||||
self.reward_func_list = []
|
||||
self.reward_log_dict = {}
|
||||
self.reward_coef = []
|
||||
for name, conf in config["reward"].items():
|
||||
self.reward_coef.append(conf.pop("coefficient"))
|
||||
self.reward_func_list.append(getattr(reward, name)(conf))
|
||||
self.reward_log_dict[name] = 0.0
|
||||
self.observation_space = self.obs.get_space()
|
||||
self.action_space = self.action_func.get_space()
|
||||
|
||||
def toggle_log(self, log):
|
||||
self.log = log
|
||||
|
||||
def reset(self, sample):
|
||||
"""
|
||||
|
||||
:param sample:
|
||||
|
||||
"""
|
||||
|
||||
for key in self.reward_log_dict.keys():
|
||||
self.reward_log_dict[key] = 0.0
|
||||
if not sample is None:
|
||||
(
|
||||
self.ins,
|
||||
self.date,
|
||||
self.raw_df_values,
|
||||
self.raw_df_columns,
|
||||
self.raw_df_index,
|
||||
self.feature_dfs,
|
||||
self.target,
|
||||
self.is_buy,
|
||||
) = sample
|
||||
self.raw_df = pd.DataFrame(index=self.raw_df_index, data=self.raw_df_values, columns=self.raw_df_columns,)
|
||||
del self.raw_df_values, self.raw_df_columns, self.raw_df_index
|
||||
start_time = time.time()
|
||||
self.load_time = time.time() - start_time
|
||||
self.day_vwap = nan_weighted_avg(
|
||||
self.raw_df["$vwap0"].values[self.offset : self.offset + self.max_step_num],
|
||||
self.raw_df["$volume0"].values[self.offset : self.offset + self.max_step_num],
|
||||
)
|
||||
try:
|
||||
assert not (np.isnan(self.day_vwap) or np.isinf(self.day_vwap))
|
||||
except:
|
||||
print(self.raw_df)
|
||||
print(self.ins)
|
||||
print(self.day_vwap)
|
||||
self.raw_df.to_pickle("/nfs_data1/kanren/error_df.pkl")
|
||||
self.day_twap = np.nanmean(self.raw_df["$vwap0"].values[self.offset : self.offset + self.max_step_num])
|
||||
self.t = -1 + self.offset
|
||||
self.interval = 0
|
||||
self.position = self.target
|
||||
self.eps_start = time.time()
|
||||
|
||||
self.state = self.obs(
|
||||
self.raw_df,
|
||||
self.feature_dfs,
|
||||
self.t,
|
||||
self.interval,
|
||||
self.position,
|
||||
self.target,
|
||||
self.is_buy,
|
||||
self.max_step_num,
|
||||
self.interval_num,
|
||||
)
|
||||
if self.log:
|
||||
index_array = [
|
||||
np.array([self.ins] * self.max_step_num),
|
||||
self.raw_df.index.to_numpy()[self.offset : self.offset + self.max_step_num],
|
||||
np.array([self.date] * self.max_step_num),
|
||||
]
|
||||
self.traded_log = pd.DataFrame(
|
||||
data={
|
||||
"$v_t": np.nan,
|
||||
"$max_vol_t": (self.raw_df["$volume0"] * self.limit).values[
|
||||
self.offset : self.offset + self.max_step_num
|
||||
],
|
||||
"$traded_t": np.nan,
|
||||
"$vwap_t": self.raw_df["$vwap0"].values[self.offset : self.offset + self.max_step_num],
|
||||
"action": np.nan,
|
||||
},
|
||||
index=index_array,
|
||||
)
|
||||
# v_t: The amount of shares the agent hope to trade
|
||||
# max_vol_t: The max amount of shares can be traded
|
||||
# traded_t: The amount of shares that is acually traded
|
||||
# action: the action of agent, may have various meanings in different settings.
|
||||
self.done = False
|
||||
if self.limit > 1:
|
||||
self.this_valid = np.inf
|
||||
else:
|
||||
self.this_valid = np.nansum(self.raw_df["$volume0"].values) * self.limit
|
||||
self.this_cash = 0
|
||||
|
||||
self.step_time = []
|
||||
self.action_log = [np.nan] * self.interval_num
|
||||
self.reset_time = time.time() - start_time
|
||||
self.real_eps_time = self.reset_time
|
||||
self.total_reward = 0
|
||||
self.total_instant_rew = 0
|
||||
self.last_rew = 0
|
||||
return self.state
|
||||
|
||||
def step(self, action):
|
||||
"""
|
||||
|
||||
:param action:
|
||||
|
||||
"""
|
||||
start_time = time.time()
|
||||
self.action_log[self.interval] = action
|
||||
volume_t = self.action_func(
|
||||
action,
|
||||
self.target,
|
||||
self.position,
|
||||
max_step_num=self.max_step_num,
|
||||
t=self.t - self.offset,
|
||||
interval=self.interval,
|
||||
interval_num=self.interval_num,
|
||||
)
|
||||
self.interval += 1
|
||||
reward = 0.0
|
||||
time_left = self.max_step_num - self.t - 1 + self.offset
|
||||
|
||||
for i in range(self.time_interval):
|
||||
v_t = volume_t / min(self.time_interval, time_left)
|
||||
self.t += 1
|
||||
if self.t == self.max_step_num - 1 + self.offset:
|
||||
v_t = self.position
|
||||
if self.log:
|
||||
log_index = self.t - self.offset
|
||||
self.traded_log.iat[log_index, 0] = v_t
|
||||
self.traded_log.iat[log_index, 4] = action
|
||||
vwap_t, vol_t = self.raw_df.iloc[self.t][["$vwap0", "$volume0"]]
|
||||
max_vol_t = self.limit * vol_t
|
||||
if self.limit >= 1:
|
||||
max_vol_t = np.inf
|
||||
if v_t > min(self.position, max_vol_t):
|
||||
if self.position <= max_vol_t:
|
||||
v_t = self.position
|
||||
else:
|
||||
v_t = max_vol_t
|
||||
self.position -= v_t
|
||||
self.this_cash += vwap_t * v_t
|
||||
if self.log:
|
||||
self.traded_log.iat[log_index, 2] = v_t
|
||||
|
||||
if self.is_buy:
|
||||
performance_raise = (1 - vwap_t / self.day_vwap) * 10000
|
||||
PA_t = (1 - vwap_t / self.day_twap) * 10000
|
||||
else:
|
||||
performance_raise = (vwap_t / self.day_vwap - 1) * 10000
|
||||
PA_t = (vwap_t / self.day_twap - 1) * 10000
|
||||
|
||||
for i, reward_func in enumerate(self.reward_func_list):
|
||||
if reward_func.isinstant:
|
||||
tmp_r = reward_func(performance_raise, v_t, self.target, PA_t)
|
||||
reward += tmp_r * self.reward_coef[i]
|
||||
self.reward_log_dict[type(reward_func).__name__] += tmp_r
|
||||
|
||||
if self.t == self.max_step_num - 1 + self.offset:
|
||||
break
|
||||
|
||||
if self.position < ZERO:
|
||||
self.done = True
|
||||
|
||||
if self.interval == self.interval_num:
|
||||
self.done = True
|
||||
|
||||
self.step_time.append(time.time() - start_time)
|
||||
self.real_eps_time += time.time() - start_time
|
||||
if self.done:
|
||||
this_traded = self.target - self.position
|
||||
this_vwap = (self.this_cash / this_traded) if this_traded > ZERO else self.day_vwap
|
||||
valid = min(self.target, self.this_valid)
|
||||
this_ffr = (this_traded / valid) if valid > ZERO else 1.0
|
||||
if abs(this_ffr - 1.0) < ZERO:
|
||||
this_ffr = 1.0
|
||||
this_ffr *= 100
|
||||
this_vv_ratio = this_vwap / self.day_vwap
|
||||
vwap = self.raw_df["$vwap0"].values[self.offset : self.max_step_num + self.offset]
|
||||
this_tt_ratio = this_vwap / np.nanmean(vwap)
|
||||
|
||||
if self.is_buy:
|
||||
performance_raise = (1 - this_vv_ratio) * 10000
|
||||
PA = (1 - this_tt_ratio) * 10000
|
||||
else:
|
||||
performance_raise = (this_vv_ratio - 1) * 10000
|
||||
PA = (this_tt_ratio - 1) * 10000
|
||||
|
||||
for i, reward_func in enumerate(self.reward_func_list):
|
||||
if not reward_func.isinstant:
|
||||
tmp_r = reward_func(performance_raise, this_ffr, this_tt_ratio, self.is_buy)
|
||||
reward += tmp_r * self.reward_coef[i]
|
||||
self.reward_log_dict[type(reward_func).__name__] += tmp_r
|
||||
|
||||
self.state = self.obs(
|
||||
self.raw_df,
|
||||
self.feature_dfs,
|
||||
self.t,
|
||||
self.interval,
|
||||
self.position,
|
||||
self.target,
|
||||
self.is_buy,
|
||||
self.max_step_num,
|
||||
self.interval_num,
|
||||
action,
|
||||
)
|
||||
if self.log:
|
||||
res = pd.DataFrame(
|
||||
{
|
||||
"target": self.target,
|
||||
"sell": not self.is_buy,
|
||||
"vwap": this_vwap,
|
||||
"this_vv_ratio": this_vv_ratio,
|
||||
"this_ffr": this_ffr,
|
||||
},
|
||||
index=[[self.ins], [self.date]],
|
||||
)
|
||||
money = self.target * self.day_vwap
|
||||
if self.is_buy:
|
||||
info = {
|
||||
"money": money,
|
||||
"money_buy": money,
|
||||
"action": self.action_log,
|
||||
"ffr": this_ffr,
|
||||
"obs0_PR": performance_raise,
|
||||
"ffr_buy": this_ffr,
|
||||
"PR_buy": performance_raise,
|
||||
"PA": PA,
|
||||
"PA_buy": PA,
|
||||
"vwap": this_vwap,
|
||||
}
|
||||
else:
|
||||
info = {
|
||||
"money": money,
|
||||
"money_sell": money,
|
||||
"action": self.action_log,
|
||||
"ffr": this_ffr,
|
||||
"obs0_PR": performance_raise,
|
||||
"ffr_sell": this_ffr,
|
||||
"PR_sell": performance_raise,
|
||||
"PA": PA,
|
||||
"PA_sell": PA,
|
||||
"vwap": this_vwap,
|
||||
}
|
||||
info = merge_dicts(info, self.reward_log_dict)
|
||||
if self.log:
|
||||
info["df"] = self.traded_log
|
||||
info["res"] = res
|
||||
del self.feature_dfs
|
||||
return self.state, reward, self.done, info
|
||||
|
||||
else:
|
||||
self.state = self.obs(
|
||||
self.raw_df,
|
||||
self.feature_dfs,
|
||||
self.t,
|
||||
self.interval,
|
||||
self.position,
|
||||
self.target,
|
||||
self.is_buy,
|
||||
self.max_step_num,
|
||||
self.interval_num,
|
||||
action,
|
||||
)
|
||||
return self.state, reward, self.done, {}
|
||||
|
||||
|
||||
class StockEnv_Acc(StockEnv):
|
||||
def step(self, action):
|
||||
start_time = time.time()
|
||||
self.action_log[self.interval] = action
|
||||
volume_t = self.action_func(
|
||||
action,
|
||||
self.target,
|
||||
self.position,
|
||||
max_step_num=self.max_step_num,
|
||||
t=self.t - self.offset,
|
||||
interval=self.interval,
|
||||
interval_num=self.interval_num,
|
||||
)
|
||||
self.interval += 1
|
||||
reward = 0.0
|
||||
time_left = self.max_step_num - self.t - 1 + self.offset
|
||||
time_left = min(self.time_interval, time_left)
|
||||
|
||||
v_t = np.repeat(volume_t / time_left, time_left)
|
||||
minutes = np.arange(self.t + 1, self.t + time_left + 1)
|
||||
if self.log:
|
||||
log_index = minutes - self.offset
|
||||
self.traded_log.iloc[log_index, 0] = v_t
|
||||
self.traded_log.iloc[log_index, 4] = action
|
||||
vwap_t = self.raw_df.iloc[minutes]["$vwap0"].values
|
||||
vol_t = self.raw_df.iloc[minutes]["$volume0"].values
|
||||
max_vol_t = self.limit * vol_t if self.limit < 1 else np.inf
|
||||
v_t = np.minimum(v_t, max_vol_t)
|
||||
if self.t + time_left == self.max_step_num - 1 + self.offset:
|
||||
left = self.position - v_t.sum()
|
||||
v_t[-1] += left
|
||||
v_t = np.minimum(v_t, max_vol_t)
|
||||
this_money = (v_t * vwap_t).sum()
|
||||
this_vol = v_t.sum()
|
||||
this_vwap = np.nan_to_num(this_money / this_vol)
|
||||
self.t += time_left
|
||||
self.position -= this_vol
|
||||
self.this_cash += this_money
|
||||
if self.log:
|
||||
self.traded_log.iloc[log_index, 2] = v_t
|
||||
|
||||
if self.is_buy:
|
||||
performance_raise = (1 - this_vwap / self.day_vwap) * 10000
|
||||
PA_t = (1 - this_vwap / self.day_twap) * 10000
|
||||
else:
|
||||
performance_raise = (this_vwap / self.day_vwap - 1) * 10000
|
||||
PA_t = (this_vwap / self.day_twap - 1) * 10000
|
||||
|
||||
for i, reward_func in enumerate(self.reward_func_list):
|
||||
if reward_func.isinstant:
|
||||
tmp_r = reward_func(performance_raise, v_t, self.target, PA_t)
|
||||
reward += tmp_r * self.reward_coef[i]
|
||||
self.reward_log_dict[type(reward_func).__name__] += tmp_r
|
||||
|
||||
if self.position < ZERO:
|
||||
self.done = True
|
||||
|
||||
if self.interval == self.interval_num:
|
||||
self.done = True
|
||||
|
||||
self.step_time.append(time.time() - start_time)
|
||||
self.real_eps_time += time.time() - start_time
|
||||
if self.done:
|
||||
this_traded = self.target - self.position
|
||||
this_vwap = (self.this_cash / this_traded) if this_traded > ZERO else self.day_vwap
|
||||
valid = min(self.target, self.this_valid)
|
||||
this_ffr = (this_traded / valid) if valid > ZERO else 1.0
|
||||
if abs(this_ffr - 1.0) < ZERO:
|
||||
this_ffr = 1.0
|
||||
this_ffr *= 100
|
||||
this_vv_ratio = this_vwap / self.day_vwap
|
||||
vwap = self.raw_df["$vwap0"].values[self.offset : self.max_step_num + self.offset]
|
||||
this_tt_ratio = this_vwap / np.nanmean(vwap)
|
||||
|
||||
if self.is_buy:
|
||||
performance_raise = (1 - this_vv_ratio) * 10000
|
||||
PA = (1 - this_tt_ratio) * 10000
|
||||
else:
|
||||
performance_raise = (this_vv_ratio - 1) * 10000
|
||||
PA = (this_tt_ratio - 1) * 10000
|
||||
|
||||
for i, reward_func in enumerate(self.reward_func_list):
|
||||
if not reward_func.isinstant:
|
||||
tmp_r = reward_func(performance_raise, this_ffr, this_tt_ratio, self.is_buy)
|
||||
reward += tmp_r * self.reward_coef[i]
|
||||
self.reward_log_dict[type(reward_func).__name__] += tmp_r
|
||||
|
||||
self.state = self.obs(
|
||||
self.raw_df,
|
||||
self.feature_dfs,
|
||||
self.t,
|
||||
self.interval,
|
||||
self.position,
|
||||
self.target,
|
||||
self.is_buy,
|
||||
self.max_step_num,
|
||||
self.interval_num,
|
||||
action,
|
||||
)
|
||||
if self.log:
|
||||
res = pd.DataFrame(
|
||||
{
|
||||
"target": self.target,
|
||||
"sell": not self.is_buy,
|
||||
"vwap": this_vwap,
|
||||
"this_vv_ratio": this_vv_ratio,
|
||||
"this_ffr": this_ffr,
|
||||
},
|
||||
index=[[self.ins], [self.date]],
|
||||
)
|
||||
money = self.target * self.day_vwap
|
||||
if self.is_buy:
|
||||
info = {
|
||||
"money": money,
|
||||
"money_buy": money,
|
||||
"action": self.action_log,
|
||||
"ffr": this_ffr,
|
||||
"obs0_PR": performance_raise,
|
||||
"ffr_buy": this_ffr,
|
||||
"PR_buy": performance_raise,
|
||||
"PA": PA,
|
||||
"PA_buy": PA,
|
||||
"vwap": this_vwap,
|
||||
}
|
||||
else:
|
||||
info = {
|
||||
"money": money,
|
||||
"money_sell": money,
|
||||
"action": self.action_log,
|
||||
"ffr": this_ffr,
|
||||
"obs0_PR": performance_raise,
|
||||
"ffr_sell": this_ffr,
|
||||
"PR_sell": performance_raise,
|
||||
"PA": PA,
|
||||
"PA_sell": PA,
|
||||
"vwap": this_vwap,
|
||||
}
|
||||
info = merge_dicts(info, self.reward_log_dict)
|
||||
if self.log:
|
||||
info["df"] = self.traded_log
|
||||
info["res"] = res
|
||||
del self.feature_dfs
|
||||
return self.state, reward, self.done, info
|
||||
|
||||
else:
|
||||
self.state = self.obs(
|
||||
self.raw_df,
|
||||
self.feature_dfs,
|
||||
self.t,
|
||||
self.interval,
|
||||
self.position,
|
||||
self.target,
|
||||
self.is_buy,
|
||||
self.max_step_num,
|
||||
self.interval_num,
|
||||
action,
|
||||
)
|
||||
return self.state, reward, self.done, {}
|
||||
351
examples/trade/executor.py
Normal file
351
examples/trade/executor.py
Normal file
@@ -0,0 +1,351 @@
|
||||
import env
|
||||
from vecenv import *
|
||||
import sampler
|
||||
import logger
|
||||
import json
|
||||
import os
|
||||
import agent
|
||||
import network
|
||||
import policy
|
||||
import random
|
||||
import tianshou as ts
|
||||
import tqdm
|
||||
from tianshou.utils import tqdm_config, MovAvg
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from collector import *
|
||||
import numpy as np
|
||||
|
||||
|
||||
from util import merge_dicts
|
||||
|
||||
|
||||
def get_best_gpu(force=None):
|
||||
if force is not None:
|
||||
return force
|
||||
s = os.popen("nvidia-smi --query-gpu=memory.free --format=csv")
|
||||
a = []
|
||||
ss = s.read().replace("MiB", "").replace("memory.free", "").split("\n")
|
||||
s.close()
|
||||
for i in range(1, len(ss) - 1):
|
||||
a.append(int(ss[i]))
|
||||
best = int(np.argmax(a))
|
||||
print("the best GPU is ", best, " with free memories of ", ss[best + 1])
|
||||
return best
|
||||
|
||||
|
||||
def setup_seed(seed):
|
||||
"""
|
||||
|
||||
:param seed:
|
||||
|
||||
"""
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
np.random.seed(seed)
|
||||
random.seed(seed)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
|
||||
|
||||
class BaseExecutor(object):
|
||||
def __init__(
|
||||
self,
|
||||
log_dir,
|
||||
resources,
|
||||
env_conf,
|
||||
optim=None,
|
||||
policy_conf=None,
|
||||
network_conf=None,
|
||||
policy_path=None,
|
||||
seed=None,
|
||||
):
|
||||
"""A base class for executor
|
||||
|
||||
:param log_dir: The directory to write all the logs.
|
||||
:type log_dir: string
|
||||
:param resources: A dict which describes available computational resources.
|
||||
:type resources: dict
|
||||
:param env_conf: Configurations for the envionments.
|
||||
:type env_conf: dict
|
||||
:param optim: Optimization configuration, defaults to None
|
||||
:type optim: dict, optional
|
||||
:param policy_conf: Configurations for the RL algorithm, defaults to None
|
||||
:type policy_conf: dict, optional
|
||||
:param network_conf: Configurations for policy network_conf, defaults to None
|
||||
:type network_conf: dict, optional
|
||||
:param policy_path: If is not None, would load the policy from this path, defaults to None
|
||||
:type policy_path: string, optional
|
||||
:param seed: Random seed, defaults to None
|
||||
:type seed: int, optional
|
||||
"""
|
||||
# self.config = config
|
||||
self.log_dir = log_dir
|
||||
print(self.log_dir)
|
||||
if not os.path.exists(self.log_dir):
|
||||
os.makedirs(self.log_dir)
|
||||
if resources["device"] == "cuda":
|
||||
resources["device"] = "cuda:" + str(get_best_gpu())
|
||||
self.device = torch.device(resources["device"])
|
||||
if seed:
|
||||
setup_seed(seed)
|
||||
|
||||
assert not policy_path is None or not policy_conf is None, "Policy must be defined"
|
||||
if policy_path:
|
||||
self.policy = torch.load(policy_path, map_location=self.device)
|
||||
self.policy.actor.extractor.device = self.device
|
||||
# policy.eval()
|
||||
elif hasattr(agent, policy_conf["name"]):
|
||||
policy_conf["config"] = merge_dicts(policy_conf["config"], resources)
|
||||
self.policy = getattr(agent, policy_conf["name"])(policy_conf["config"])
|
||||
# print(self.policy)
|
||||
else:
|
||||
assert not network_conf is None
|
||||
if "extractor" in network_conf.keys():
|
||||
net = getattr(network, network_conf["extractor"]["name"] + "_Extractor")(
|
||||
device=self.device, **network_conf["config"]
|
||||
)
|
||||
else:
|
||||
net = getattr(network, network_conf["name"] + "_Extractor")(
|
||||
device=self.device, **network_conf["config"]
|
||||
)
|
||||
net.to(self.device)
|
||||
actor = getattr(network, network_conf["name"] + "_Actor")(
|
||||
extractor=net, device=self.device, **network_conf["config"]
|
||||
)
|
||||
actor.to(self.device)
|
||||
critic = getattr(network, network_conf["name"] + "_Critic")(
|
||||
extractor=net, device=self.device, **network_conf["config"]
|
||||
)
|
||||
critic.to(self.device)
|
||||
self.optim = torch.optim.Adam(
|
||||
list(actor.parameters()) + list(critic.parameters()),
|
||||
lr=optim["lr"],
|
||||
weight_decay=optim["weight_decay"] if "weight_decay" in optim else 0.0,
|
||||
)
|
||||
self.dist = torch.distributions.Categorical
|
||||
try:
|
||||
self.policy = getattr(ts.policy, policy_conf["name"])(
|
||||
actor, critic, self.optim, self.dist, **policy_conf["config"]
|
||||
)
|
||||
except:
|
||||
self.policy = getattr(policy, policy_conf["name"])(
|
||||
actor, critic, self.optim, self.dist, **policy_conf["config"]
|
||||
)
|
||||
self.writer = SummaryWriter(self.log_dir)
|
||||
|
||||
def train(
|
||||
self,
|
||||
max_epoch,
|
||||
step_per_epoch,
|
||||
repeat_per_collect,
|
||||
collect_per_step,
|
||||
batch_size,
|
||||
iteration=0,
|
||||
global_step=0,
|
||||
early_stopping=5,
|
||||
*args,
|
||||
**kargs,
|
||||
):
|
||||
"""Run the whole training process.
|
||||
|
||||
:param max_epoch: The total number of epoch.
|
||||
:param step_per_epoch: The times of bp in one epoch.
|
||||
:param collect_per_step: Number of episodes to collect before one bp.
|
||||
:param repeat_per_collect: Times of bps after every rould of experience collecting.
|
||||
:param batch_size: Batch size when bp.
|
||||
:param iteration: The iteration when starting the training, used when fine tuning. (Default value = 0)
|
||||
:param global_step: The number of steps when starting the training, used when fine tuning. (Default value = 0)
|
||||
:param early_stopping: If the test reward does not reach a new high in `early_stopping` iterations, the training would stop. (Default value = 5)
|
||||
:returns: The result on test set.
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def train_round(self, repeat_per_collect, collect_per_step, batch_size, *args, **kargs):
|
||||
"""Do an round of training
|
||||
|
||||
:param collect_per_step: Number of episodes to collect before one bp.
|
||||
:param repeat_per_collect: Times of bps after every rould of experience collecting.
|
||||
:param batch_size: Batch size when bp.
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def eval(self, order_dir, save_res=False, logdir=None, *args, **kargs):
|
||||
"""Evaluate the policy on orders in order_dir
|
||||
|
||||
:param order_dir: the orders to be evaluated on.
|
||||
:param save_res: whether the result of evaluation be saved to self.logdir/res.json (Default value = False)
|
||||
:param logdir: the place to save the .log and .pkl log files to. If None, don't save logfiles. (Default value = None)
|
||||
:returns: The result of evaluation.
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class Executor(BaseExecutor):
|
||||
def __init__(
|
||||
self,
|
||||
log_dir,
|
||||
resources,
|
||||
env_conf,
|
||||
train_paths,
|
||||
valid_paths,
|
||||
test_paths,
|
||||
io_conf,
|
||||
optim=None,
|
||||
policy_conf=None,
|
||||
network_conf=None,
|
||||
policy_path=None,
|
||||
seed=None,
|
||||
share_memory=False,
|
||||
buffer_size=200000,
|
||||
q_learning=False,
|
||||
*args,
|
||||
**kargs,
|
||||
):
|
||||
"""[summary]
|
||||
|
||||
:param log_dir: The directory to write all the logs.
|
||||
:type log_dir: string
|
||||
:param resources: A dict which describes available computational resources.
|
||||
:type resources: dict
|
||||
:param env_conf: Configurations for the envionments.
|
||||
:type env_conf: dict
|
||||
:param train_paths: The paths of training datasets including orders, backtest files and features.
|
||||
:type train_paths: string
|
||||
:param valid_paths: The paths of validation datasets including orders, backtest files and features.
|
||||
:type valid_paths: string
|
||||
:param test_paths: The paths of test datasets including orders, backtest files and features.
|
||||
:type test_paths: string
|
||||
:param io_conf: Configuration for sampler and loggers.
|
||||
:type io_conf: dict
|
||||
:param share_memory: Whether to use shared memory vecnev, defaults to False
|
||||
:type share_memory: bool, optional
|
||||
:param buffer_size: The size of replay buffer, defaults to 200000
|
||||
:type buffer_size: int, optional
|
||||
"""
|
||||
super().__init__(log_dir, resources, env_conf, optim, policy_conf, network_conf, policy_path, seed)
|
||||
single_env = getattr(env, env_conf["name"])
|
||||
env_conf = merge_dicts(env_conf, train_paths)
|
||||
env_conf["log"] = True
|
||||
print("CPU_COUNT:", resources["num_cpus"])
|
||||
if share_memory:
|
||||
self.env = ShmemVectorEnv([lambda: single_env(env_conf) for _ in range(resources["num_cpus"])])
|
||||
else:
|
||||
self.env = SubprocVectorEnv([lambda: single_env(env_conf) for _ in range(resources["num_cpus"])])
|
||||
self.test_collector = Collector(policy=self.policy, env=self.env, testing=True, reward_metric=np.sum)
|
||||
self.train_collector = Collector(
|
||||
self.policy, self.env, buffer=ts.data.ReplayBuffer(buffer_size), reward_metric=np.sum,
|
||||
)
|
||||
self.train_paths = train_paths
|
||||
self.test_paths = test_paths
|
||||
self.valid_paths = valid_paths
|
||||
train_sampler_conf = train_paths
|
||||
train_sampler_conf["features"] = env_conf["features"]
|
||||
test_sampler_conf = test_paths
|
||||
test_sampler_conf["features"] = env_conf["features"]
|
||||
self.train_sampler = getattr(sampler, io_conf["train_sampler"])(train_sampler_conf)
|
||||
self.test_sampler = getattr(sampler, io_conf["test_sampler"])(test_sampler_conf)
|
||||
self.train_logger = logger.InfoLogger()
|
||||
self.test_logger = getattr(logger, io_conf["test_logger"])
|
||||
|
||||
self.q_learning = q_learning
|
||||
|
||||
def train(
|
||||
self,
|
||||
max_epoch,
|
||||
step_per_epoch,
|
||||
repeat_per_collect,
|
||||
collect_per_step,
|
||||
batch_size,
|
||||
iteration=0,
|
||||
global_step=0,
|
||||
early_stopping=5,
|
||||
train_step_min=0,
|
||||
log_valid=True,
|
||||
*args,
|
||||
**kargs,
|
||||
):
|
||||
best_epoch, best_reward = -1, -1
|
||||
stat = {}
|
||||
for epoch in range(1, 1 + max_epoch):
|
||||
with tqdm.tqdm(total=step_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config) as t:
|
||||
while t.n < t.total:
|
||||
result, losses = self.train_round(repeat_per_collect, collect_per_step, batch_size, iteration)
|
||||
global_step += result["n/st"]
|
||||
iteration += 1
|
||||
for k in result.keys():
|
||||
self.writer.add_scalar("Train/" + k, result[k], global_step=global_step)
|
||||
for k in losses.keys():
|
||||
if stat.get(k) is None:
|
||||
stat[k] = MovAvg()
|
||||
stat[k].add(losses[k])
|
||||
self.writer.add_scalar("Train/" + k, stat[k].get(), global_step=global_step)
|
||||
t.update(1)
|
||||
if t.n <= t.total:
|
||||
t.update()
|
||||
result = self.eval(
|
||||
self.valid_paths["order_dir"], logdir=f"{self.log_dir}/valid/{iteration}/" if log_valid else None,
|
||||
)
|
||||
for k in result.keys():
|
||||
self.writer.add_scalar("Valid/" + k, result[k], global_step=global_step)
|
||||
if best_epoch == -1 or best_reward < result["rew"]:
|
||||
best_reward = result["rew"]
|
||||
best_epoch = epoch
|
||||
best_state = self.policy.state_dict()
|
||||
early_stop_round = 0
|
||||
torch.save(self.policy, f"{self.log_dir}/policy_best")
|
||||
elif global_step >= train_step_min:
|
||||
early_stop_round += 1
|
||||
torch.save(self.policy, f"{self.log_dir}/policy_{epoch}")
|
||||
print(
|
||||
f'Epoch #{epoch}: test_reward: {result["rew"]:.4f}, ' # train_reward: {result_train["rew"]:.4f}, '
|
||||
f"best_reward: {best_reward:.4f} in #{best_epoch}"
|
||||
)
|
||||
if early_stop_round >= early_stopping:
|
||||
print("Early stopped")
|
||||
break
|
||||
print("Testing...")
|
||||
self.policy.load_state_dict(best_state)
|
||||
result = self.eval(self.test_paths["order_dir"], logdir=f"{self.log_dir}/test/", save_res=True)
|
||||
for k in result.keys():
|
||||
self.writer.add_scalar("Test/" + k, result[k], global_step=global_step)
|
||||
return result
|
||||
|
||||
def train_round(self, repeat_per_collect, collect_per_step, batch_size, *args, **kargs):
|
||||
self.policy.train()
|
||||
self.env.toggle_log(False)
|
||||
self.env.sampler = self.train_sampler
|
||||
if not self.q_learning:
|
||||
self.train_collector.reset()
|
||||
result = self.train_collector.collect(n_episode=collect_per_step, log_fn=self.train_logger)
|
||||
result = merge_dicts(result, self.train_logger.summary())
|
||||
if not self.q_learning:
|
||||
losses = self.policy.update(
|
||||
0, self.train_collector.buffer, batch_size=batch_size, repeat=repeat_per_collect,
|
||||
)
|
||||
else:
|
||||
losses = self.policy.update(batch_size, self.train_collector.buffer,)
|
||||
return result, losses
|
||||
|
||||
def eval(self, order_dir, save_res=False, logdir=None, *args, **kargs):
|
||||
print(f"start evaluating on {order_dir}")
|
||||
self.policy.eval()
|
||||
self.env.toggle_log(True)
|
||||
self.test_sampler.reset(order_dir)
|
||||
self.env.sampler = self.test_sampler
|
||||
self.test_collector.reset()
|
||||
if not logdir is None:
|
||||
if not os.path.exists(logdir):
|
||||
os.makedirs(logdir)
|
||||
eval_logger = self.test_logger(logdir, order_dir)
|
||||
eval_logger.reset()
|
||||
else:
|
||||
eval_logger = self.train_logger
|
||||
result = self.test_collector.collect(log_fn=eval_logger)
|
||||
result = merge_dicts(result, eval_logger.summary())
|
||||
if save_res:
|
||||
with open(self.log_dir + "/res.json", "w") as f:
|
||||
json.dump(result, f, sort_keys=True, indent=4)
|
||||
print(f"finish evaluating on {order_dir}")
|
||||
return result
|
||||
76
examples/trade/exp/example/OPD/config.yml
Normal file
76
examples/trade/exp/example/OPD/config.yml
Normal file
@@ -0,0 +1,76 @@
|
||||
seed: 42
|
||||
task: train
|
||||
log_dir: example/OPD
|
||||
buffer_size: 80000
|
||||
io_conf:
|
||||
test_sampler: TestSampler
|
||||
train_sampler: Sampler
|
||||
test_logger: DFLogger
|
||||
resources:
|
||||
num_cpus: 24
|
||||
num_gpus: 1
|
||||
device: cuda
|
||||
train_paths:
|
||||
raw_dir: ../data/backtest/
|
||||
order_dir: ../data/order/train/
|
||||
valid_paths:
|
||||
raw_dir: ../data/backtest/
|
||||
order_dir: ../data/order/valid/
|
||||
test_paths:
|
||||
raw_dir: ../data/backtest/
|
||||
order_dir: ../data/order/test/
|
||||
env_conf:
|
||||
name: StockEnv_Acc
|
||||
max_step_num: 237
|
||||
limit: 10
|
||||
time_interval: 30
|
||||
interval_num: 8
|
||||
features:
|
||||
- name: raw
|
||||
type: range
|
||||
loc: ../data/normed_feature/
|
||||
size: 180
|
||||
- name: teacher_action
|
||||
type: interval
|
||||
size: 1
|
||||
loc: ../data/feature/teacher/
|
||||
obs:
|
||||
name: RuleTeacher
|
||||
config: {}
|
||||
action:
|
||||
name: Static_Action
|
||||
config:
|
||||
action_num: 5
|
||||
action_map: [0, 0.25, 0.5, 0.75, 1]
|
||||
reward:
|
||||
VP_Penalty_small_vec:
|
||||
penalty: 100
|
||||
coefficient: 1
|
||||
policy_conf:
|
||||
name: PPO_sup
|
||||
config:
|
||||
discount_factor: 1.
|
||||
max_grad_norm: 100.
|
||||
reward_normalization: False
|
||||
eps_clip: 0.3
|
||||
value_clip: True
|
||||
vf_coef: 1.
|
||||
gae_lambda: 1.
|
||||
vf_clip_para: 0.3
|
||||
sup_coef: 0.01
|
||||
network_conf:
|
||||
name: OPD
|
||||
config:
|
||||
hidden_size: 64
|
||||
out_shape: 5
|
||||
fc_size: 32
|
||||
cnn_shape: [30, 6]
|
||||
optim:
|
||||
lr: 1e-4
|
||||
batch_size: 1024
|
||||
max_epoch: 30
|
||||
step_per_epoch: 20
|
||||
collect_per_step: 10000
|
||||
repeat_per_collect: 5
|
||||
early_stopping: 5
|
||||
weight_decay: 0.
|
||||
71
examples/trade/exp/example/OPDS/config.yml
Normal file
71
examples/trade/exp/example/OPDS/config.yml
Normal file
@@ -0,0 +1,71 @@
|
||||
seed: 42
|
||||
task: train
|
||||
log_dir: example/OPDS
|
||||
buffer_size: 80000
|
||||
io_conf:
|
||||
test_sampler: TestSampler
|
||||
train_sampler: Sampler
|
||||
test_logger: DFLogger
|
||||
resources:
|
||||
num_cpus: 24
|
||||
num_gpus: 1
|
||||
device: cuda
|
||||
train_paths:
|
||||
raw_dir: ../data/backtest/
|
||||
order_dir: ../data/order/train/
|
||||
valid_paths:
|
||||
raw_dir: ../data/backtest/
|
||||
order_dir: ../data/order/valid/
|
||||
test_paths:
|
||||
raw_dir: ../data/backtest/
|
||||
order_dir: ../data/order/test/
|
||||
env_conf:
|
||||
name: StockEnv_Acc
|
||||
max_step_num: 237
|
||||
limit: 10
|
||||
time_interval: 30
|
||||
interval_num: 8
|
||||
features:
|
||||
- name: raw
|
||||
type: range
|
||||
loc: ../data/normed_feature/
|
||||
size: 180
|
||||
obs:
|
||||
name: TeacherObs
|
||||
config: {}
|
||||
action:
|
||||
name: Static_Action
|
||||
config:
|
||||
action_num: 5
|
||||
action_map: [0, 0.25, 0.5, 0.75, 1]
|
||||
reward:
|
||||
VP_Penalty_small_vec:
|
||||
penalty: 100
|
||||
coefficient: 1
|
||||
policy_conf:
|
||||
name: PPO
|
||||
config:
|
||||
discount_factor: 1.
|
||||
max_grad_norm: 100.
|
||||
reward_normalization: False
|
||||
eps_clip: 0.3
|
||||
value_clip: True
|
||||
vf_coef: 1.
|
||||
gae_lambda: 1.
|
||||
vf_clip_para: 0.3
|
||||
network_conf:
|
||||
name: PPO
|
||||
config:
|
||||
hidden_size: 64
|
||||
out_shape: 5
|
||||
fc_size: 32
|
||||
cnn_shape: [30, 6]
|
||||
optim:
|
||||
lr: 1e-4
|
||||
batch_size: 1024
|
||||
max_epoch: 30
|
||||
step_per_epoch: 20
|
||||
collect_per_step: 10000
|
||||
repeat_per_collect: 5
|
||||
early_stopping: 5
|
||||
weight_decay: 0.
|
||||
71
examples/trade/exp/example/OPDT/config.yml
Normal file
71
examples/trade/exp/example/OPDT/config.yml
Normal file
@@ -0,0 +1,71 @@
|
||||
seed: 42
|
||||
task: train
|
||||
log_dir: example/OPDT
|
||||
buffer_size: 80000
|
||||
io_conf:
|
||||
test_sampler: TestSampler
|
||||
train_sampler: Sampler
|
||||
test_logger: DFLogger
|
||||
resources:
|
||||
num_cpus: 24
|
||||
num_gpus: 1
|
||||
device: cuda
|
||||
train_paths:
|
||||
raw_dir: ../data/backtest/
|
||||
order_dir: ../data/order/train/
|
||||
valid_paths:
|
||||
raw_dir: ../data/backtest/
|
||||
order_dir: ../data/order/valid/
|
||||
test_paths:
|
||||
raw_dir: ../data/backtest/
|
||||
order_dir: ../data/order/test/
|
||||
env_conf:
|
||||
name: StockEnv_Acc
|
||||
max_step_num: 237
|
||||
limit: 10
|
||||
time_interval: 30
|
||||
interval_num: 8
|
||||
features:
|
||||
- name: raw
|
||||
type: range
|
||||
loc: ../data/normed_feature/
|
||||
size: 180
|
||||
obs:
|
||||
name: TeacherObs
|
||||
config: {}
|
||||
action:
|
||||
name: Static_Action
|
||||
config:
|
||||
action_num: 5
|
||||
action_map: [0, 0.25, 0.5, 0.75, 1]
|
||||
reward:
|
||||
VP_Penalty_small_vec:
|
||||
penalty: 100
|
||||
coefficient: 1
|
||||
policy_conf:
|
||||
name: PPO
|
||||
config:
|
||||
discount_factor: 1.
|
||||
max_grad_norm: 100.
|
||||
reward_normalization: False
|
||||
eps_clip: 0.3
|
||||
value_clip: True
|
||||
vf_coef: 1.
|
||||
gae_lambda: 1.
|
||||
vf_clip_para: 0.3
|
||||
network_conf:
|
||||
name: Teacher
|
||||
config:
|
||||
hidden_size: 64
|
||||
out_shape: 5
|
||||
fc_size: 32
|
||||
cnn_shape: [30, 6]
|
||||
optim:
|
||||
lr: 1e-4
|
||||
batch_size: 1024
|
||||
max_epoch: 30
|
||||
step_per_epoch: 20
|
||||
collect_per_step: 10000
|
||||
repeat_per_collect: 5
|
||||
early_stopping: 5
|
||||
weight_decay: 0.
|
||||
76
examples/trade/exp/example/OPDT_b/config.yml
Normal file
76
examples/trade/exp/example/OPDT_b/config.yml
Normal file
@@ -0,0 +1,76 @@
|
||||
seed: 42
|
||||
task: eval
|
||||
log_dir: example/OPDT_b
|
||||
buffer_size: 80000
|
||||
io_conf:
|
||||
test_sampler: TestSampler
|
||||
train_sampler: Sampler
|
||||
test_logger: DFLogger
|
||||
resources:
|
||||
num_cpus: 24
|
||||
num_gpus: 1
|
||||
device: cuda
|
||||
train_paths:
|
||||
raw_dir: ../data/backtest/
|
||||
order_dir: ../data/order/train/
|
||||
valid_paths:
|
||||
raw_dir: ../data/backtest/
|
||||
order_dir: ../data/order/valid/
|
||||
test_paths:
|
||||
raw_dir: ../data/backtest/
|
||||
order_dir: ../data/order/all/
|
||||
env_conf:
|
||||
name: StockEnv_Acc
|
||||
max_step_num: 237
|
||||
limit: 10
|
||||
time_interval: 30
|
||||
interval_num: 8
|
||||
features:
|
||||
- name: raw
|
||||
type: range
|
||||
loc: ../data/normed_feature/
|
||||
size: 180
|
||||
obs:
|
||||
name: TeacherObs
|
||||
config: {}
|
||||
action:
|
||||
name: Static_Action
|
||||
config:
|
||||
action_num: 5
|
||||
action_map: [0, 0.25, 0.5, 0.75, 1]
|
||||
reward:
|
||||
VP_Penalty_small_vec:
|
||||
penalty: 100
|
||||
coefficient: 1
|
||||
policy_path: policy_best
|
||||
policy_conf:
|
||||
name: PPO
|
||||
config:
|
||||
discount_factor: 1.
|
||||
max_grad_norm: 100.
|
||||
reward_normalization: False
|
||||
eps_clip: 0.3
|
||||
value_clip: True
|
||||
vf_coef: 1.
|
||||
gae_lambda: 1.
|
||||
vf_clip_para: 0.3
|
||||
network_conf:
|
||||
name: Teacher
|
||||
config:
|
||||
hidden_size: 64
|
||||
out_shape: 5
|
||||
fc_size: 32
|
||||
cnn_shape: [30, 6]
|
||||
optim:
|
||||
lr: 1e-4
|
||||
batch_size: 1024
|
||||
max_epoch: 30
|
||||
step_per_epoch: 20
|
||||
collect_per_step: 10000
|
||||
repeat_per_collect: 5
|
||||
early_stopping: 5
|
||||
weight_decay: 0.
|
||||
search:
|
||||
optim.weight_decay:
|
||||
type: choice
|
||||
value: [0.]
|
||||
70
examples/trade/exp/example/PPO/config.yml
Normal file
70
examples/trade/exp/example/PPO/config.yml
Normal file
@@ -0,0 +1,70 @@
|
||||
seed: 42
|
||||
task: train
|
||||
log_dir: example/PPO
|
||||
buffer_size: 80000
|
||||
io_conf:
|
||||
test_sampler: TestSampler
|
||||
train_sampler: Sampler
|
||||
test_logger: DFLogger
|
||||
resources:
|
||||
num_cpus: 24
|
||||
num_gpus: 1
|
||||
device: cuda
|
||||
train_paths:
|
||||
raw_dir: ../data/backtest/
|
||||
order_dir: ../data/order/train/
|
||||
valid_paths:
|
||||
raw_dir: ../data/backtest/
|
||||
order_dir: ../data/order/valid/
|
||||
test_paths:
|
||||
raw_dir: ../data/backtest/
|
||||
order_dir: ../data/order/test/
|
||||
env_conf:
|
||||
name: StockEnv_Acc
|
||||
max_step_num: 237
|
||||
limit: 10
|
||||
time_interval: 30
|
||||
interval_num: 8
|
||||
features:
|
||||
- name: raw
|
||||
type: range
|
||||
loc: ../data/normed_feature/
|
||||
size: 180
|
||||
obs:
|
||||
name: TeacherObs
|
||||
config: {}
|
||||
action:
|
||||
name: Static_Action
|
||||
config:
|
||||
action_num: 5
|
||||
action_map: [0, 0.25, 0.5, 0.75, 1]
|
||||
reward:
|
||||
PPO_Reward:
|
||||
coefficient: 1
|
||||
policy_conf:
|
||||
name: PPO
|
||||
config:
|
||||
discount_factor: 1.
|
||||
max_grad_norm: 100.
|
||||
reward_normalization: False
|
||||
eps_clip: 0.3
|
||||
value_clip: True
|
||||
vf_coef: 1.
|
||||
gae_lambda: 1.
|
||||
vf_clip_para: 0.3
|
||||
network_conf:
|
||||
name: PPO
|
||||
config:
|
||||
hidden_size: 64
|
||||
out_shape: 5
|
||||
fc_size: 32
|
||||
cnn_shape: [30, 6]
|
||||
optim:
|
||||
lr: 1e-4
|
||||
batch_size: 1024
|
||||
max_epoch: 30
|
||||
step_per_epoch: 20
|
||||
collect_per_step: 10000
|
||||
repeat_per_collect: 5
|
||||
early_stopping: 5
|
||||
weight_decay: 0.
|
||||
1
examples/trade/logger/__init__.py
Normal file
1
examples/trade/logger/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .single_logger import *
|
||||
231
examples/trade/logger/single_logger.py
Normal file
231
examples/trade/logger/single_logger.py
Normal file
@@ -0,0 +1,231 @@
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import os
|
||||
from multiprocessing import Queue, Process
|
||||
import time
|
||||
|
||||
|
||||
def GLR(values):
|
||||
"""
|
||||
|
||||
Calculate -P(value | value > 0) / P(value | value < 0)
|
||||
|
||||
"""
|
||||
pos = []
|
||||
neg = []
|
||||
for i in values:
|
||||
if i > 0:
|
||||
pos.append(i)
|
||||
elif i < 0:
|
||||
neg.append(i)
|
||||
return -np.mean(pos) / np.mean(neg)
|
||||
|
||||
|
||||
class DFLogger(object):
|
||||
"""The logger for single-assert backtest.
|
||||
Would save .pkl and .log in log_dir
|
||||
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, log_dir, order_dir, writer=None):
|
||||
self.order_dir = order_dir + "/"
|
||||
self.log_dir = log_dir + "/"
|
||||
if not os.path.exists(log_dir):
|
||||
os.mkdir(log_dir)
|
||||
self.queue = Queue(100000)
|
||||
self.raw_log_dir = self.log_dir
|
||||
|
||||
@staticmethod
|
||||
def _worker(log_dir, order_dir, queue):
|
||||
df_cache = {}
|
||||
stat_cache = {}
|
||||
if not os.path.exists(log_dir):
|
||||
os.mkdir(log_dir)
|
||||
while True:
|
||||
info = queue.get(block=True)
|
||||
if info == "stop":
|
||||
summary = {}
|
||||
for k, v in stat_cache.items():
|
||||
if not k.startswith("money"):
|
||||
summary[k + "_std"] = np.nanstd(v)
|
||||
summary[k + "_mean"] = np.nanmean(v)
|
||||
try:
|
||||
for k in ["PR_sell", "ffr_sell", "PA_sell"]:
|
||||
summary["weighted_" + k] = np.average(stat_cache[k], weights=stat_cache["money_sell"])
|
||||
except:
|
||||
# summary["weighted_" + k] = np.average(stat_cache[k], weights=stat_cache['money_sell'])
|
||||
pass
|
||||
try:
|
||||
for k in ["PR_buy", "ffr_buy", "PA_buy"]:
|
||||
summary["weighted_" + k] = np.average(stat_cache[k], weights=stat_cache["money_buy"])
|
||||
except:
|
||||
pass
|
||||
try:
|
||||
for k in ["obs0_PR", "ffr", "PA"]:
|
||||
summary["weighted_" + k] = np.average(stat_cache[k], weights=stat_cache["money"])
|
||||
except:
|
||||
pass
|
||||
summary["GLR"] = GLR(stat_cache["PA"])
|
||||
try:
|
||||
summary["GLR_sell"] = GLR(stat_cache["PA_sell"])
|
||||
except:
|
||||
pass
|
||||
try:
|
||||
summary["GLR_buy"] = GLR(stat_cache["PA_buy"])
|
||||
except:
|
||||
pass
|
||||
queue.put(summary)
|
||||
break
|
||||
elif len(info) == 0:
|
||||
continue
|
||||
else:
|
||||
df = info.pop("df")
|
||||
res = info.pop("res")
|
||||
ins = df.index[0][0]
|
||||
if ins not in df_cache:
|
||||
df_cache[ins] = (
|
||||
[],
|
||||
[],
|
||||
(pd.read_pickle(order_dir + ins + ".pkl.target")['amount'] != 0).sum(),
|
||||
)
|
||||
df_cache[ins][0].append(df)
|
||||
df_cache[ins][1].append(res)
|
||||
if len(df_cache[ins][0]) == df_cache[ins][2]:
|
||||
pd.concat(df_cache[ins][0]).to_pickle(log_dir + ins + ".log")
|
||||
pd.concat(df_cache[ins][1]).to_pickle(log_dir + ins + ".pkl")
|
||||
del df_cache[ins]
|
||||
for k, v in info.items():
|
||||
if k not in stat_cache:
|
||||
stat_cache[k] = []
|
||||
if hasattr(v, "__len__"):
|
||||
stat_cache[k] += list(v)
|
||||
else:
|
||||
stat_cache[k].append(v)
|
||||
|
||||
def reset(self):
|
||||
""" """
|
||||
while not self.queue.empty():
|
||||
self.queue.get()
|
||||
assert self.queue.empty()
|
||||
self.child = Process(target=self._worker, args=(self.log_dir, self.order_dir, self.queue), daemon=True,)
|
||||
self.child.start()
|
||||
|
||||
def set_step(self, step):
|
||||
|
||||
self.log_dir = f"{self.raw_log_dir}{step}/"
|
||||
self.reset()
|
||||
|
||||
def __call__(self, infos):
|
||||
for info in infos:
|
||||
if "env_id" in info:
|
||||
info.pop("env_id")
|
||||
self.update(infos)
|
||||
|
||||
def update(self, infos):
|
||||
"""store values in info into the logger"""
|
||||
for info in infos:
|
||||
self.queue.put(info, block=True)
|
||||
|
||||
def summary(self):
|
||||
""":return: The mean and std of values in infos stored in logger"""
|
||||
summary = {}
|
||||
self.queue.put("stop", block=True)
|
||||
self.child.join()
|
||||
self.child.close()
|
||||
assert self.queue.qsize() == 1
|
||||
summary = self.queue.get()
|
||||
|
||||
return summary
|
||||
|
||||
|
||||
class InfoLogger(DFLogger):
|
||||
""" """
|
||||
|
||||
def __init__(self, *args):
|
||||
self.stat_cache = {}
|
||||
self.queue = Queue(10000)
|
||||
self.child = Process(target=self._worker, args=(self.queue,), daemon=True)
|
||||
self.child.start()
|
||||
|
||||
def _worker(logdir, queue):
|
||||
stat_cache = {}
|
||||
while True:
|
||||
info = queue.get(block=True)
|
||||
if info == "stop":
|
||||
summary = {}
|
||||
for k, v in stat_cache.items():
|
||||
if not k.startswith("money"):
|
||||
summary[k + "_std"] = np.nanstd(v)
|
||||
summary[k + "_mean"] = np.nanmean(v)
|
||||
try:
|
||||
for k in ["PR_sell", "ffr_sell", "PA_sell"]:
|
||||
summary["weighted_" + k] = np.average(stat_cache[k], weights=stat_cache["money_sell"])
|
||||
except:
|
||||
pass
|
||||
try:
|
||||
for k in ["PR_buy", "ffr_buy", "PA_buy"]:
|
||||
summary["weighted_" + k] = np.average(stat_cache[k], weights=stat_cache["money_buy"])
|
||||
except:
|
||||
pass
|
||||
try:
|
||||
for k in ["obs0_PR", "ffr", "PA"]:
|
||||
summary["weighted_" + k] = np.average(stat_cache[k], weights=stat_cache["money"])
|
||||
except:
|
||||
pass
|
||||
summary["GLR"] = GLR(stat_cache["PA"])
|
||||
try:
|
||||
summary["GLR_sell"] = GLR(stat_cache["PA_sell"])
|
||||
except:
|
||||
pass
|
||||
try:
|
||||
summary["GLR_buy"] = GLR(stat_cache["PA_buy"])
|
||||
except:
|
||||
pass
|
||||
queue.put(summary)
|
||||
stat_cache = {}
|
||||
time.sleep(5)
|
||||
continue
|
||||
if len(info) == 0:
|
||||
continue
|
||||
for k, v in info.items():
|
||||
if k == "res" or k == "df":
|
||||
continue
|
||||
if k not in stat_cache:
|
||||
stat_cache[k] = []
|
||||
if hasattr(v, "__len__"):
|
||||
stat_cache[k] += list(v)
|
||||
else:
|
||||
stat_cache[k].append(v)
|
||||
|
||||
def _update(self, info):
|
||||
if len(info) == 0:
|
||||
return
|
||||
ins = df.index[0][0]
|
||||
for k, v in info.items():
|
||||
if k not in self.stat_cache:
|
||||
self.stat_cache[k] = []
|
||||
if hasattr(v, "__len__"):
|
||||
self.stat_cache[k] += list(v)
|
||||
else:
|
||||
self.stat_cache[k].append(v)
|
||||
|
||||
def summary(self):
|
||||
""" """
|
||||
while not self.queue.empty():
|
||||
# print('not empty')
|
||||
# print(self.queue.qsize())
|
||||
time.sleep(1)
|
||||
self.queue.put("stop")
|
||||
# self.child.join()
|
||||
time.sleep(1)
|
||||
while not self.queue.qsize() == 1:
|
||||
# print(self.queue.qsize())
|
||||
time.sleep(1)
|
||||
assert self.queue.qsize() == 1
|
||||
summary = self.queue.get()
|
||||
|
||||
return summary
|
||||
|
||||
def set_step(self, step):
|
||||
return
|
||||
135
examples/trade/main.py
Normal file
135
examples/trade/main.py
Normal file
@@ -0,0 +1,135 @@
|
||||
import re
|
||||
import os
|
||||
import argparse
|
||||
import yaml
|
||||
from executor import Executor
|
||||
import warnings
|
||||
import redis
|
||||
import subprocess
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
from util import merge_dicts
|
||||
|
||||
loader = yaml.FullLoader
|
||||
loader.add_implicit_resolver(
|
||||
"tag:yaml.org,2002:float",
|
||||
re.compile(
|
||||
"""^(?:
|
||||
[-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
|
||||
|[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
|
||||
|\\.[0-9_]+(?:[eE][-+][0-9]+)?
|
||||
|[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
|
||||
|[-+]?\\.(?:inf|Inf|INF)
|
||||
|\\.(?:nan|NaN|NAN))$""",
|
||||
re.X,
|
||||
),
|
||||
list("-+0123456789."),
|
||||
)
|
||||
|
||||
|
||||
def get_full_config(config, dir_name):
|
||||
while "base" in config:
|
||||
base_config = os.path.normpath(os.path.join(dir_name, config.pop("base")))
|
||||
dir_name = os.path.dirname(base_config)
|
||||
with open(base_config, "r") as f:
|
||||
base_config = yaml.load(base_config, Loader=yaml.FullLoader)
|
||||
config = merge_dicts(base_config, config)
|
||||
return config
|
||||
|
||||
|
||||
def run(config):
|
||||
log_dir = config["log_dir"]
|
||||
if not os.path.exists(log_dir):
|
||||
os.makedirs(log_dir)
|
||||
with open(log_dir + "/config.yml", "w") as f:
|
||||
yaml.dump(config, f)
|
||||
executor = Executor(**config)
|
||||
if config["task"] == "train":
|
||||
return executor.train(**config["optim"])
|
||||
elif config["task"] == "eval":
|
||||
return executor.eval(config["test_paths"]["order_dir"], save_res=True, logdir=config["log_dir"] + "/test/",)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-c", "--config", type=str)
|
||||
parser.add_argument("-n", "--index", type=int, default=None)
|
||||
args = parser.parse_args()
|
||||
|
||||
print(os.cpu_count())
|
||||
|
||||
EXP_PATH = os.environ["EXP_PATH"]
|
||||
config_path = os.path.normpath(os.path.join(EXP_PATH, args.config))
|
||||
EXP_NAME = os.path.relpath(config_path, EXP_PATH)
|
||||
if os.path.isdir(config_path):
|
||||
if not args.index is None:
|
||||
with open(config_path + "/configs.yml") as f:
|
||||
config_list = list(yaml.load_all(f, Loader=loader))
|
||||
config = config_list[args.index]
|
||||
if "PT_OUTPUT_DIR" in os.environ:
|
||||
config["log_dir"] = os.environ["PT_OUTPUT_DIR"]
|
||||
else:
|
||||
log_prefix = os.environ["OUTPUT_DIR"] if "OUTPUT_DIR" in os.environ else "../log"
|
||||
config["log_dir"] = os.path.join(log_prefix, config["log_dir"])
|
||||
config = get_full_config(config, config_path)
|
||||
run(config)
|
||||
else:
|
||||
redis_server = redis.Redis(
|
||||
host=os.environ["REDIS_SERVER"],
|
||||
port=os.environ["REDIS_PORT"],
|
||||
db=0,
|
||||
charset="utf-8",
|
||||
decode_responses=True,
|
||||
)
|
||||
with open(config_path + "/configs.yml") as f:
|
||||
config_list = list(yaml.load_all(f, Loader=loader))
|
||||
config_num = len(config_list)
|
||||
if not redis_server.exists(EXP_NAME):
|
||||
for i in range(config_num):
|
||||
redis_server.rpush(EXP_NAME, i)
|
||||
redis_server.set(f"{EXP_NAME}_{i}", "Pending")
|
||||
else:
|
||||
if redis_server.llen(EXP_NAME) == 0:
|
||||
for i in range(config_num):
|
||||
if (
|
||||
not redis_server.exists(f"{EXP_NAME}_{i}")
|
||||
or redis_server.get(f"{EXP_NAME}_{i}") == "Failed"
|
||||
):
|
||||
redis_server.rpush(EXP_NAME, i)
|
||||
redis_server.set(f"{EXP_NAME}_{i}", "Pending")
|
||||
print(f"Starting..., {redis_server.llen(EXP_NAME)} trails to run")
|
||||
while True:
|
||||
index = redis_server.lpop(EXP_NAME)
|
||||
if index is None:
|
||||
print("All done")
|
||||
break
|
||||
index = int(index)
|
||||
redis_server.set(f"{EXP_NAME}_{index}", "Running")
|
||||
print(f"Trail_{index} is running")
|
||||
try:
|
||||
res = subprocess.run(["python", "main.py", "--config", args.config, "--index", str(index),],)
|
||||
except KeyboardInterrupt:
|
||||
redis_server.set(f"{EXP_NAME}_{index}", "Failed")
|
||||
print(f"Trail_{index} has failed, {redis_server.llen(EXP_NAME)} trails to run")
|
||||
break
|
||||
if res.returncode == 0:
|
||||
redis_server.set(f"{EXP_NAME}_{index}", "Finished")
|
||||
print(f"Finish running one trail, {redis_server.llen(EXP_NAME)} trails to run")
|
||||
else:
|
||||
redis_server.set(f"{EXP_NAME}_{index}", "Failed")
|
||||
print(f"Trail_{index} has failed, {redis_server.llen(EXP_NAME)} trails to run")
|
||||
|
||||
elif os.path.isfile(config_path):
|
||||
assert config_path.endswith(".yml"), "Config file should be an yaml file"
|
||||
EXP_NAME = EXP_NAME[:-4]
|
||||
with open(config_path, "r") as f:
|
||||
config = yaml.load(f, Loader=loader)
|
||||
config = get_full_config(config, os.path.dirname(config_path))
|
||||
log_prefix = os.environ["OUTPUT_DIR"] if "OUTPUT_DIR" in os.environ else "../log"
|
||||
config["log_dir"] = os.path.join(log_prefix, config["log_dir"])
|
||||
run(config)
|
||||
else:
|
||||
print("The config path should be a relative path from EXP_PATH")
|
||||
5
examples/trade/network/__init__.py
Normal file
5
examples/trade/network/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from .ppo import *
|
||||
from .qmodel import *
|
||||
from .teacher import *
|
||||
from .util import *
|
||||
from .opd import *
|
||||
74
examples/trade/network/opd.py
Normal file
74
examples/trade/network/opd.py
Normal file
@@ -0,0 +1,74 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from copy import deepcopy
|
||||
import sys
|
||||
|
||||
from tianshou.data import to_torch
|
||||
|
||||
|
||||
class OPD_Extractor(nn.Module):
|
||||
def __init__(self, device="cpu", **kargs):
|
||||
super().__init__()
|
||||
self.device = device
|
||||
hidden_size = kargs["hidden_size"]
|
||||
fc_size = kargs["fc_size"]
|
||||
self.cnn_shape = kargs["cnn_shape"]
|
||||
|
||||
self.rnn = nn.GRU(64, hidden_size, batch_first=True)
|
||||
self.rnn2 = nn.GRU(64, hidden_size, batch_first=True)
|
||||
self.dnn = nn.Sequential(nn.Linear(2, 64), nn.ReLU(),)
|
||||
self.cnn = nn.Sequential(nn.Conv1d(self.cnn_shape[1], 3, 3), nn.ReLU(),)
|
||||
self.raw_fc = nn.Sequential(nn.Linear((self.cnn_shape[0] - 2) * 3, 64), nn.ReLU(),)
|
||||
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(hidden_size * 2, hidden_size), nn.ReLU(), nn.Linear(hidden_size, 32), nn.ReLU(),
|
||||
)
|
||||
|
||||
def forward(self, inp):
|
||||
inp = to_torch(inp, dtype=torch.float32, device=self.device)
|
||||
teacher_action = inp[:, 0]
|
||||
inp = inp[:, 1:]
|
||||
seq_len = inp[:, -1].to(torch.long)
|
||||
batch_size = inp.shape[0]
|
||||
raw_in = inp[:, : 6 * 240]
|
||||
raw_in = torch.cat((torch.zeros_like(inp[:, : 6 * 30]), raw_in), dim=-1)
|
||||
raw_in = raw_in.reshape(-1, 30, 6).transpose(1, 2)
|
||||
dnn_in = inp[:, 6 * 240 : -1].reshape(batch_size, -1, 2)
|
||||
cnn_out = self.cnn(raw_in).view(batch_size, 9, -1)
|
||||
rnn_in = self.raw_fc(cnn_out)
|
||||
rnn2_in = self.dnn(dnn_in)
|
||||
rnn2_out = self.rnn2(rnn2_in)[0]
|
||||
rnn_out = self.rnn(rnn_in)[0]
|
||||
rnn_out = rnn_out[torch.arange(rnn_out.size(0)), seq_len]
|
||||
rnn2_out = rnn2_out[torch.arange(rnn2_out.size(0)), seq_len]
|
||||
# dnn_out = self.dnn(dnn_in)
|
||||
fc_in = torch.cat((rnn_out, rnn2_out), dim=-1)
|
||||
feature = self.fc(fc_in)
|
||||
return feature, teacher_action / 2
|
||||
|
||||
|
||||
class OPD_Actor(nn.Module):
|
||||
def __init__(self, extractor, out_shape, device=torch.device("cpu"), **kargs):
|
||||
super().__init__()
|
||||
self.extractor = extractor
|
||||
self.layer_out = nn.Sequential(nn.Linear(32, out_shape), nn.Softmax(dim=-1))
|
||||
self.device = device
|
||||
|
||||
def forward(self, obs, state=None, info={}):
|
||||
feature, self.teacher_action = self.extractor(obs)
|
||||
out = self.layer_out(feature)
|
||||
return out, state
|
||||
|
||||
|
||||
class OPD_Critic(nn.Module):
|
||||
def __init__(self, extractor, out_shape, device=torch.device("cpu"), **kargs):
|
||||
super().__init__()
|
||||
self.extractor = extractor
|
||||
self.value_out = nn.Linear(32, 1)
|
||||
self.device = device
|
||||
|
||||
def forward(self, obs, state=None, info={}):
|
||||
feature, self.teacher_action = self.extractor(obs)
|
||||
return self.value_out(feature).squeeze(dim=-1)
|
||||
79
examples/trade/network/ppo.py
Normal file
79
examples/trade/network/ppo.py
Normal file
@@ -0,0 +1,79 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from copy import deepcopy
|
||||
import sys
|
||||
|
||||
from tianshou.data import to_torch
|
||||
|
||||
|
||||
class PPO_Extractor(nn.Module):
|
||||
def __init__(self, device="cpu", **kargs):
|
||||
super().__init__()
|
||||
self.device = device
|
||||
hidden_size = kargs["hidden_size"]
|
||||
fc_size = kargs["fc_size"]
|
||||
self.cnn_shape = kargs["cnn_shape"]
|
||||
|
||||
self.rnn = nn.GRU(64, hidden_size, batch_first=True)
|
||||
self.rnn2 = nn.GRU(64, hidden_size, batch_first=True)
|
||||
self.dnn = nn.Sequential(nn.Linear(2, 64), nn.ReLU(),)
|
||||
self.cnn = nn.Sequential(nn.Conv1d(self.cnn_shape[1], 3, 3), nn.ReLU(),)
|
||||
self.raw_fc = nn.Sequential(nn.Linear((self.cnn_shape[0] - 2) * 3, 64), nn.ReLU(),)
|
||||
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(hidden_size * 2, hidden_size), nn.ReLU(), nn.Linear(hidden_size, 32), nn.ReLU(),
|
||||
)
|
||||
|
||||
def forward(self, inp):
|
||||
inp = to_torch(inp, dtype=torch.float32, device=self.device)
|
||||
# inp = torch.from_numpy(inp).to(torch.device('cpu'))
|
||||
seq_len = inp[:, -1].to(torch.long)
|
||||
batch_size = inp.shape[0]
|
||||
raw_in = inp[:, : 6 * 240]
|
||||
raw_in = torch.cat((torch.zeros_like(inp[:, : 6 * 30]), raw_in), dim=-1)
|
||||
raw_in = raw_in.reshape(-1, 30, 6).transpose(1, 2)
|
||||
dnn_in = inp[:, -19:-1].reshape(batch_size, -1, 2)
|
||||
cnn_out = self.cnn(raw_in).view(batch_size, 9, -1)
|
||||
assert not torch.isnan(cnn_out).any()
|
||||
rnn_in = self.raw_fc(cnn_out)
|
||||
assert not torch.isnan(rnn_in).any()
|
||||
rnn2_in = self.dnn(dnn_in)
|
||||
assert not torch.isnan(rnn2_in).any()
|
||||
rnn2_out = self.rnn2(rnn2_in)[0]
|
||||
assert not torch.isnan(rnn2_out).any()
|
||||
rnn_out = self.rnn(rnn_in)[0]
|
||||
assert not torch.isnan(rnn_out).any()
|
||||
rnn_out = rnn_out[torch.arange(rnn_out.size(0)), seq_len]
|
||||
rnn2_out = rnn2_out[torch.arange(rnn2_out.size(0)), seq_len]
|
||||
# dnn_out = self.dnn(dnn_in)
|
||||
fc_in = torch.cat((rnn_out, rnn2_out), dim=-1)
|
||||
self.feature = self.fc(fc_in)
|
||||
return self.feature
|
||||
|
||||
|
||||
class PPO_Actor(nn.Module):
|
||||
def __init__(self, extractor, out_shape, device=torch.device("cpu"), **kargs):
|
||||
super().__init__()
|
||||
self.extractor = extractor
|
||||
self.layer_out = nn.Sequential(nn.Linear(32, out_shape), nn.Softmax(dim=-1))
|
||||
self.device = device
|
||||
|
||||
def forward(self, obs, state=None, info={}):
|
||||
self.feature = self.extractor(obs)
|
||||
assert not (torch.isnan(self.feature).any() | torch.isinf(self.feature).any()), f"{self.feature}"
|
||||
out = self.layer_out(self.feature)
|
||||
return out, state
|
||||
|
||||
|
||||
class PPO_Critic(nn.Module):
|
||||
def __init__(self, extractor, out_shape, device=torch.device("cpu"), **kargs):
|
||||
super().__init__()
|
||||
self.extractor = extractor
|
||||
self.value_out = nn.Linear(32, 1)
|
||||
self.device = device
|
||||
|
||||
def forward(self, obs, state=None, info={}):
|
||||
self.feature = self.extractor(obs)
|
||||
return self.value_out(self.feature).squeeze(dim=-1)
|
||||
52
examples/trade/network/qmodel.py
Normal file
52
examples/trade/network/qmodel.py
Normal file
@@ -0,0 +1,52 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from copy import deepcopy
|
||||
import sys
|
||||
|
||||
from tianshou.data import to_torch
|
||||
|
||||
|
||||
class RNNQModel(nn.Module):
|
||||
def __init__(self, device="cpu", out_shape=10, **kargs):
|
||||
super().__init__()
|
||||
self.device = device
|
||||
hidden_size = kargs["hidden_size"]
|
||||
fc_size = kargs["fc_size"]
|
||||
self.cnn_shape = kargs["cnn_shape"]
|
||||
|
||||
self.rnn = nn.GRU(64, hidden_size, batch_first=True)
|
||||
self.rnn2 = nn.GRU(64, hidden_size, batch_first=True)
|
||||
self.dnn = nn.Sequential(nn.Linear(2, 64), nn.ReLU(),)
|
||||
self.cnn = nn.Sequential(nn.Conv1d(self.cnn_shape[1], 3, 3), nn.ReLU(),)
|
||||
self.raw_fc = nn.Sequential(nn.Linear((self.cnn_shape[0] - 2) * 3, 64), nn.ReLU(),)
|
||||
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(hidden_size * 2, hidden_size),
|
||||
nn.ReLU(),
|
||||
nn.Linear(hidden_size, 32),
|
||||
nn.ReLU(),
|
||||
nn.Linear(32, out_shape),
|
||||
)
|
||||
|
||||
def forward(self, obs, state=None, info={}):
|
||||
inp = to_torch(obs, dtype=torch.float32, device=self.device)
|
||||
inp = inp[:, 182:]
|
||||
seq_len = inp[:, -1].to(torch.long)
|
||||
batch_size = inp.shape[0]
|
||||
raw_in = inp[:, : 6 * 240]
|
||||
raw_in = torch.cat((torch.zeros_like(inp[:, : 6 * 30]), raw_in), dim=-1)
|
||||
raw_in = raw_in.reshape(-1, 30, 6).transpose(1, 2)
|
||||
dnn_in = inp[:, 6 * 240 : -1].reshape(batch_size, -1, 2)
|
||||
cnn_out = self.cnn(raw_in).view(batch_size, 9, -1)
|
||||
rnn_in = self.raw_fc(cnn_out)
|
||||
rnn2_in = self.dnn(dnn_in)
|
||||
rnn2_out = self.rnn2(rnn2_in)[0]
|
||||
rnn_out = self.rnn(rnn_in)[0]
|
||||
rnn_out = rnn_out[torch.arange(rnn_out.size(0)), seq_len]
|
||||
rnn2_out = rnn2_out[torch.arange(rnn2_out.size(0)), seq_len]
|
||||
# dnn_out = self.dnn(dnn_in)
|
||||
fc_in = torch.cat((rnn_out, rnn2_out), dim=-1)
|
||||
out = self.fc(fc_in)
|
||||
return out, state
|
||||
69
examples/trade/network/teacher.py
Normal file
69
examples/trade/network/teacher.py
Normal file
@@ -0,0 +1,69 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from copy import deepcopy
|
||||
import sys
|
||||
|
||||
from tianshou.data import to_torch
|
||||
|
||||
|
||||
class Teacher_Extractor(nn.Module):
|
||||
def __init__(self, device="cpu", feature_size=180, **kargs):
|
||||
super().__init__()
|
||||
self.device = device
|
||||
hidden_size = kargs["hidden_size"]
|
||||
fc_size = kargs["fc_size"]
|
||||
self.cnn_shape = kargs["cnn_shape"]
|
||||
|
||||
self.rnn = nn.GRU(64, hidden_size, batch_first=True)
|
||||
self.rnn2 = nn.GRU(64, hidden_size, batch_first=True)
|
||||
self.dnn = nn.Sequential(nn.Linear(2, 64), nn.ReLU(),)
|
||||
self.cnn = nn.Sequential(nn.Conv1d(self.cnn_shape[1], 3, 3), nn.ReLU(),)
|
||||
self.raw_fc = nn.Sequential(nn.Linear((self.cnn_shape[0] - 2) * 3, 64), nn.ReLU(),)
|
||||
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(hidden_size * 2, hidden_size), nn.ReLU(), nn.Linear(hidden_size, 32), nn.ReLU(),
|
||||
)
|
||||
|
||||
def forward(self, inp):
|
||||
inp = to_torch(inp, dtype=torch.float32, device=self.device)
|
||||
seq_len = inp[:, -1].to(torch.long)
|
||||
batch_size = inp.shape[0]
|
||||
raw_in = inp[:, : 6 * 240].reshape(-1, 30, 6).transpose(1, 2) ## public part of state
|
||||
dnn_in = inp[:, 6 * 240 : -1].reshape(batch_size, -1, 2) ## private part of state
|
||||
cnn_out = self.cnn(raw_in).view(batch_size, 8, -1)
|
||||
rnn_in = self.raw_fc(cnn_out)
|
||||
rnn2_in = self.dnn(dnn_in)
|
||||
rnn2_out = self.rnn2(rnn2_in)[0]
|
||||
rnn_out = self.rnn(rnn_in)[0][:, -1, :]
|
||||
rnn2_out = rnn2_out[torch.arange(rnn2_out.size(0)), seq_len]
|
||||
# dnn_out = self.dnn(dnn_in)
|
||||
fc_in = torch.cat((rnn_out, rnn2_out), dim=-1)
|
||||
self.feature = self.fc(fc_in)
|
||||
return self.feature
|
||||
|
||||
|
||||
class Teacher_Actor(nn.Module):
|
||||
def __init__(self, extractor, out_shape, device=torch.device("cpu"), **kargs):
|
||||
super().__init__()
|
||||
self.extractor = extractor
|
||||
self.layer_out = nn.Sequential(nn.Linear(32, out_shape), nn.Softmax(dim=-1))
|
||||
self.device = device
|
||||
|
||||
def forward(self, obs, state=None, info={}):
|
||||
self.feature = self.extractor(obs)
|
||||
out = self.layer_out(self.feature)
|
||||
return out, state
|
||||
|
||||
|
||||
class Teacher_Critic(nn.Module):
|
||||
def __init__(self, extractor, out_shape, device=torch.device("cpu"), **kargs):
|
||||
super().__init__()
|
||||
self.extractor = extractor
|
||||
self.value_out = nn.Linear(32, 1)
|
||||
self.device = device
|
||||
|
||||
def forward(self, obs, state=None, info={}):
|
||||
self.feature = self.extractor(obs)
|
||||
return self.value_out(self.feature).squeeze(-1)
|
||||
191
examples/trade/network/util.py
Normal file
191
examples/trade/network/util.py
Normal file
@@ -0,0 +1,191 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from copy import deepcopy
|
||||
import sys
|
||||
|
||||
from tianshou.data import to_torch
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, in_dim, out_dim):
|
||||
super().__init__()
|
||||
self.get_w = nn.Sequential(nn.Linear(in_dim * 2, in_dim), nn.ReLU(), nn.Linear(in_dim, 1))
|
||||
|
||||
self.fc = nn.Sequential(nn.Linear(in_dim, out_dim), nn.ReLU(),)
|
||||
|
||||
def forward(self, value, key):
|
||||
key = key.unsqueeze(dim=1)
|
||||
length = value.shape[1]
|
||||
key = key.repeat([1, length, 1])
|
||||
weight = self.get_w(torch.cat((key, value), dim=-1)).squeeze() # B * l
|
||||
weight = weight.softmax(dim=-1).unsqueeze(dim=-1) # B * l * 1
|
||||
out = (value * weight).sum(dim=1)
|
||||
out = self.fc(out)
|
||||
return out
|
||||
|
||||
|
||||
class MaskAttention(nn.Module):
|
||||
def __init__(self, in_dim, out_dim):
|
||||
super().__init__()
|
||||
self.get_w = nn.Sequential(nn.Linear(in_dim * 2, in_dim), nn.ReLU(), nn.Linear(in_dim, 1))
|
||||
|
||||
self.fc = nn.Sequential(nn.Linear(in_dim, out_dim), nn.ReLU(),)
|
||||
|
||||
def forward(self, value, key, seq_len, maxlen=9):
|
||||
# seq_len: (batch,)
|
||||
device = value.device
|
||||
key = key.unsqueeze(dim=1)
|
||||
length = value.shape[1]
|
||||
key = key.repeat([1, length, 1]) # (batch, 9, 64)
|
||||
weight = self.get_w(torch.cat((key, value), dim=-1)).squeeze(-1) # (batch, 9)
|
||||
mask = sequence_mask(seq_len + 1, maxlen=maxlen, device=device)
|
||||
weight[~mask] = float("-inf")
|
||||
weight = weight.softmax(dim=-1).unsqueeze(dim=-1)
|
||||
out = (value * weight).sum(dim=1)
|
||||
out = self.fc(out)
|
||||
return out
|
||||
|
||||
|
||||
class TFMaskAttention(nn.Module):
|
||||
def __init__(self, in_dim, out_dim):
|
||||
super().__init__()
|
||||
self.get_w = nn.Sequential(nn.Linear(in_dim * 2, in_dim), nn.ReLU(), nn.Linear(in_dim, 1))
|
||||
|
||||
self.fc = nn.Sequential(nn.Linear(in_dim, out_dim), nn.ReLU(),)
|
||||
|
||||
def forward(self, value, key, seq_len, maxlen=9):
|
||||
device = value.device
|
||||
key = key.unsqueeze(dim=1)
|
||||
length = value.shape[1]
|
||||
key = key.repeat([1, length, 1])
|
||||
weight = self.get_w(torch.cat((key, value), dim=-1)).squeeze(-1)
|
||||
mask = sequence_mask(seq_len + 1, maxlen=maxlen, device=device)
|
||||
mask = mask.repeat(1, 3) # (batch, 9*3)
|
||||
weight[~mask] = float("-inf")
|
||||
weight = weight.softmax(dim=-1).unsqueeze(dim=-1)
|
||||
out = (value * weight).sum(dim=1)
|
||||
out = self.fc(out)
|
||||
return out
|
||||
|
||||
|
||||
class NNAttention(nn.Module):
|
||||
def __init__(self, in_dim, out_dim):
|
||||
super().__init__()
|
||||
self.q_net = nn.Linear(in_dim, out_dim)
|
||||
self.k_net = nn.Linear(in_dim, out_dim)
|
||||
self.v_net = nn.Linear(in_dim, out_dim)
|
||||
|
||||
def forward(self, Q, K, V):
|
||||
q = self.q_net(Q)
|
||||
k = self.k_net(K)
|
||||
v = self.v_net(V)
|
||||
|
||||
attn = torch.einsum("ijk,ilk->ijl", q, k)
|
||||
attn = attn.to(Q.device)
|
||||
attn_prob = torch.softmax(attn, dim=-1)
|
||||
|
||||
attn_vec = torch.einsum("ijk,ikl->ijl", attn_prob, v)
|
||||
|
||||
return attn_vec
|
||||
|
||||
|
||||
class Reshape(nn.Module):
|
||||
def __init__(self, *args):
|
||||
super(Reshape, self).__init__()
|
||||
self.shape = args
|
||||
|
||||
def forward(self, x):
|
||||
return x.view(self.shape)
|
||||
|
||||
|
||||
class DARNN(nn.Module):
|
||||
def __init__(self, device="cpu", **kargs):
|
||||
super().__init__()
|
||||
self.emb_dim = kargs["emb_dim"]
|
||||
self.hidden_size = kargs["hidden_size"]
|
||||
self.num_layers = kargs["num_layers"]
|
||||
self.is_bidir = kargs["is_bidir"]
|
||||
self.dropout = kargs["dropout"]
|
||||
self.seq_len = kargs["seq_len"]
|
||||
self.interval = kargs["interval"]
|
||||
self.today_length = 238
|
||||
self.prev_length = 240
|
||||
self.input_length = 480
|
||||
self.input_size = 6
|
||||
|
||||
self.rnn = nn.LSTM(
|
||||
input_size=self.input_size + self.emb_dim,
|
||||
hidden_size=self.hidden_size,
|
||||
num_layers=self.num_layers,
|
||||
batch_first=True,
|
||||
bidirectional=self.is_bidir,
|
||||
dropout=self.dropout,
|
||||
)
|
||||
self.prev_rnn = nn.LSTM(
|
||||
input_size=self.input_size,
|
||||
hidden_size=self.hidden_size,
|
||||
num_layers=self.num_layers,
|
||||
batch_first=True,
|
||||
bidirectional=self.is_bidir,
|
||||
dropout=self.dropout,
|
||||
)
|
||||
self.fc_out = nn.Linear(in_features=self.hidden_size * 2, out_features=1)
|
||||
self.attention = NNAttention(self.hidden_size, self.hidden_size)
|
||||
self.act_out = nn.Sigmoid()
|
||||
if self.emb_dim != 0:
|
||||
self.pos_emb = nn.Embedding(self.input_length, self.emb_dim)
|
||||
|
||||
def forward(self, inputs):
|
||||
inputs = inputs.view(-1, self.input_length, self.input_size) # [B, T, F]
|
||||
today_input = inputs[:, : self.today_length, :]
|
||||
today_input = torch.cat((torch.zeros_like(today_input[:, :1, :]), today_input), dim=1)
|
||||
prev_input = inputs[:, 240 : 240 + self.prev_length, :]
|
||||
if self.emb_dim != 0:
|
||||
embedding = self.pos_emb(torch.arange(end=self.today_length + 1, device=inputs.device))
|
||||
embedding = embedding.repeat([today_input.size()[0], 1, 1])
|
||||
today_input = torch.cat((today_input, embedding), dim=-1)
|
||||
prev_outs, _ = self.prev_rnn(prev_input)
|
||||
today_outs, _ = self.rnn(today_input)
|
||||
|
||||
outs = self.attention(today_outs, prev_outs, prev_outs)
|
||||
outs = torch.cat((today_outs, outs), dim=-1)
|
||||
outs = outs[:, range(0, self.seq_len * self.interval, self.interval), :]
|
||||
# outs = self.fc_out(outs).squeeze()
|
||||
return self.act_out(self.fc_out(outs).squeeze(-1)), outs
|
||||
|
||||
|
||||
class Transpose(nn.Module):
|
||||
def __init__(self, dim1=0, dim2=1):
|
||||
super().__init__()
|
||||
self.dim1 = dim1
|
||||
self.dim2 = dim2
|
||||
|
||||
def forward(self, x):
|
||||
return x.transpose(self.dim1, self.dim2)
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
|
||||
def __init__(self, *args, **kargs):
|
||||
super().__init__()
|
||||
self.attention = nn.MultiheadAttention(*args, **kargs)
|
||||
|
||||
def forward(self, x):
|
||||
return self.attention(x, x, x)[0]
|
||||
|
||||
|
||||
def onehot_enc(y, len):
|
||||
y = y.unsqueeze(-1)
|
||||
y_onehot = torch.zeros(y.shape[0], len)
|
||||
# y_onehot.zero_()
|
||||
y_onehot.scatter(1, y, 1)
|
||||
return y_onehot
|
||||
|
||||
|
||||
def sequence_mask(lengths, maxlen=None, dtype=torch.bool, device=None):
|
||||
if maxlen is None:
|
||||
maxlen = lengths.max()
|
||||
mask = ~(torch.ones((len(lengths), maxlen), device=device).cumsum(dim=1).t() > lengths).t()
|
||||
mask.type(dtype)
|
||||
return mask
|
||||
3
examples/trade/observation/__init__.py
Normal file
3
examples/trade/observation/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .ppo_obs import *
|
||||
from .teacher_obs import *
|
||||
from .obs_rule import *
|
||||
136
examples/trade/observation/obs_rule.py
Normal file
136
examples/trade/observation/obs_rule.py
Normal file
@@ -0,0 +1,136 @@
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from gym.spaces import Discrete, Box, Tuple, MultiDiscrete
|
||||
import math
|
||||
import json
|
||||
|
||||
|
||||
class BaseObs(object):
|
||||
""" """
|
||||
|
||||
def __init__(self, config):
|
||||
self._observation_space = None
|
||||
|
||||
def get_space(self):
|
||||
""" """
|
||||
return self._observation_space
|
||||
|
||||
def get_obs(self, t):
|
||||
pass
|
||||
|
||||
|
||||
class RuleObs(BaseObs):
|
||||
"""The observation for minute-level rule-based agents, which consists of prediction, private state and direction information."""
|
||||
|
||||
def __init__(self, config):
|
||||
feature_size = 0
|
||||
self.features = config["features"]
|
||||
self.time_interval = config["time_interval"]
|
||||
self.max_step_num = config["max_step_num"]
|
||||
for feature in self.features:
|
||||
feature_size += feature["size"]
|
||||
|
||||
self._observation_space = Tuple(
|
||||
(
|
||||
Box(-np.inf, np.inf, shape=(feature_size,), dtype=np.float32),
|
||||
Box(-np.inf, np.inf, shape=(4,), dtype=np.float32),
|
||||
Discrete(2),
|
||||
)
|
||||
)
|
||||
|
||||
def __call__(self, *args, **kargs):
|
||||
return self.get_obs(*args, **kargs)
|
||||
|
||||
def get_feature_res(self, df_list, time, interval, whole_day=False, interval_num=8):
|
||||
"""
|
||||
This method would extract the needed feature from the feature dataframe based on the feature name
|
||||
and the description in feature config.
|
||||
|
||||
:param df_list: The dataframes of features, the order is consistent with the feature list.
|
||||
:param time: The index of current minute of the day (starting from -1).
|
||||
:param interval: The index of interval or decition making.
|
||||
:param whole_day: if True, this method would return the concatenate of all dataframe.(Default value = False)
|
||||
|
||||
"""
|
||||
predictions = []
|
||||
if whole_day:
|
||||
try:
|
||||
prediction = [df_list[i].reshape(-1) for i in range(len(df_list))]
|
||||
except:
|
||||
prediction = [df_list[i].reshape(-1) for i in range(len(df_list))]
|
||||
for i, p in enumerate(prediction):
|
||||
if len(p) < interval_num:
|
||||
prediction[i] = np.concatenate((p, np.zeros(interval_num - len(p))), axis=-1)
|
||||
# res = np.stack(prediction).transpose().reshape(-1)
|
||||
return np.concatenate(prediction)
|
||||
for i in range(len(self.features)):
|
||||
feature = self.features[i]
|
||||
df = df_list[i]
|
||||
size = feature["size"]
|
||||
if feature["type"] == "inday":
|
||||
if time == -1:
|
||||
predictions += [0.0] * size
|
||||
else:
|
||||
predictions += df[size * time : size * (time + 1)].reshape(-1).tolist()
|
||||
elif feature["type"] == "daily":
|
||||
predictions += df.reshape(-1)[:size].tolist()
|
||||
elif feature["type"] == "range":
|
||||
if time == -1:
|
||||
predictions += [0.0] * size
|
||||
else:
|
||||
predictions += df[time : size + time].reshape(-1).tolist()
|
||||
elif feature["type"] == "interval":
|
||||
if len(df[interval * size : (interval + 1) * size].reshape(-1)) == size:
|
||||
predictions += df[interval * size : (interval + 1) * size].reshape(-1).tolist()
|
||||
else:
|
||||
predictions += [0.0] * size
|
||||
elif feature["type"] == "step":
|
||||
if len(df[size * (time + 1) : size * (time + 2)].reshape(-1)) == size:
|
||||
predictions += df[size * (time + 1) : size * (time + 2)].reshape(-1).tolist()
|
||||
else:
|
||||
predictions += [0.0] * size
|
||||
|
||||
return np.array(predictions)
|
||||
|
||||
def get_obs(self, raw_df, feature_dfs, t, interval, position, target, is_buy, *args, **kargs):
|
||||
private_state = np.array([position, target, t, self.max_step_num])
|
||||
prediction_state = self.get_feature_res(feature_dfs, t, interval)
|
||||
return {
|
||||
"prediction": prediction_state,
|
||||
"private": private_state,
|
||||
"is_buy": int(is_buy),
|
||||
}
|
||||
|
||||
|
||||
class RuleInterval(RuleObs):
|
||||
"""
|
||||
The observation for interval_level rule based strategy.
|
||||
|
||||
Consist of interval prediction, private state, direction
|
||||
|
||||
|
||||
"""
|
||||
|
||||
def get_obs(
|
||||
self,
|
||||
raw_df,
|
||||
feature_dfs,
|
||||
t,
|
||||
interval,
|
||||
position,
|
||||
target,
|
||||
is_buy,
|
||||
max_step_num,
|
||||
interval_num,
|
||||
action=1.0,
|
||||
*args,
|
||||
**kargs
|
||||
):
|
||||
private_state = np.array([position, target, interval - 1, interval_num])
|
||||
prediction_state = self.get_feature_res(feature_dfs, t, interval)
|
||||
return {
|
||||
"prediction": prediction_state,
|
||||
"private": private_state,
|
||||
"is_buy": int(is_buy),
|
||||
"action": action,
|
||||
}
|
||||
28
examples/trade/observation/ppo_obs.py
Normal file
28
examples/trade/observation/ppo_obs.py
Normal file
@@ -0,0 +1,28 @@
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from gym.spaces import Discrete, Box, Tuple, MultiDiscrete
|
||||
import math
|
||||
import json
|
||||
|
||||
from .obs_rule import RuleObs
|
||||
|
||||
|
||||
class PPOObs(RuleObs):
|
||||
"""The observation defined in IJCAI 2020. The action of previous state is included in private state"""
|
||||
|
||||
def get_obs(
|
||||
self, raw_df, feature_dfs, t, interval, position, target, is_buy, max_step_num, interval_num, action=0,
|
||||
):
|
||||
if t == -1:
|
||||
self.private_states = []
|
||||
|
||||
public_state = self.get_feature_res(feature_dfs, t, interval, whole_day=True)
|
||||
# market_state = feature_dfs[0].reshape(-1)[:6*240]
|
||||
private_state = np.array([position / target, (t + 1) / max_step_num, action])
|
||||
self.private_states.append(private_state)
|
||||
list_private_state = np.concatenate(self.private_states)
|
||||
list_private_state = np.concatenate(
|
||||
(list_private_state, [0.0] * 3 * (interval_num + 1 - len(self.private_states)),)
|
||||
)
|
||||
seqlen = np.array([interval])
|
||||
return np.concatenate((public_state, list_private_state, seqlen))
|
||||
55
examples/trade/observation/teacher_obs.py
Normal file
55
examples/trade/observation/teacher_obs.py
Normal file
@@ -0,0 +1,55 @@
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from gym.spaces import Discrete, Box, Tuple, MultiDiscrete
|
||||
import math
|
||||
import json
|
||||
|
||||
from .obs_rule import RuleObs
|
||||
|
||||
|
||||
class TeacherObs(RuleObs):
|
||||
"""
|
||||
The Observation used for OPD method.
|
||||
|
||||
Consist of public state(raw feature), private state, seqlen
|
||||
|
||||
"""
|
||||
|
||||
def get_obs(
|
||||
self, raw_df, feature_dfs, t, interval, position, target, is_buy, max_step_num, interval_num, *args, **kargs,
|
||||
):
|
||||
if t == -1:
|
||||
self.private_states = []
|
||||
public_state = self.get_feature_res(feature_dfs, t, interval, whole_day=True)
|
||||
private_state = np.array([position / target, (t + 1) / max_step_num])
|
||||
self.private_states.append(private_state)
|
||||
list_private_state = np.concatenate(self.private_states)
|
||||
list_private_state = np.concatenate(
|
||||
(list_private_state, [0.0] * 2 * (interval_num + 1 - len(self.private_states)),)
|
||||
)
|
||||
seqlen = np.array([interval])
|
||||
assert not (
|
||||
np.isnan(list_private_state).any() | np.isinf(list_private_state).any()
|
||||
), f"{private_state}, {target}"
|
||||
assert not (np.isnan(public_state).any() | np.isinf(public_state).any()), f"{public_state}"
|
||||
return np.concatenate((public_state, list_private_state, seqlen))
|
||||
|
||||
|
||||
class RuleTeacher(RuleObs):
|
||||
""" """
|
||||
|
||||
def get_obs(
|
||||
self, raw_df, feature_dfs, t, interval, position, target, is_buy, max_step_num, interval_num, *args, **kargs,
|
||||
):
|
||||
if t == -1:
|
||||
self.private_states = []
|
||||
public_state = feature_dfs[0].reshape(-1)[: 6 * 240]
|
||||
private_state = np.array([position / target, (t + 1) / max_step_num])
|
||||
teacher_action = self.get_feature_res(feature_dfs, t, interval)[-self.features[1]["size"] :]
|
||||
self.private_states.append(private_state)
|
||||
list_private_state = np.concatenate(self.private_states)
|
||||
list_private_state = np.concatenate(
|
||||
(list_private_state, [0.0] * 2 * (interval_num + 1 - len(self.private_states)),)
|
||||
)
|
||||
seqlen = np.array([interval])
|
||||
return np.concatenate((teacher_action, public_state, list_private_state, seqlen))
|
||||
62
examples/trade/order_gen.py
Normal file
62
examples/trade/order_gen.py
Normal file
@@ -0,0 +1,62 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import os
|
||||
import time
|
||||
import datetime
|
||||
from joblib import Parallel, delayed
|
||||
|
||||
data_path = '../data/'
|
||||
in_dir = os.path.join(data_path, 'backtest/')
|
||||
|
||||
### create order folders ####
|
||||
|
||||
def generate_order(df, start, end):
|
||||
# df['date'] = df.index.map(lambda x: x[1].date())
|
||||
# df.set_index('date', append=True, inplace=True)
|
||||
df = df.groupby('date').take(range(start, end)).droplevel(level=0)
|
||||
div = df['$volume0'].rolling((end - start)*60).mean().shift(1).groupby(level='date').transform('first')
|
||||
order = df.groupby(level=(2, 0)).mean().dropna()
|
||||
order = pd.DataFrame(order)
|
||||
order['amount'] = np.random.lognormal(-3.28, 1.14) * order['$volume0']
|
||||
order['order_type'] = 0
|
||||
order = order.drop(columns=["$volume0", "$vwap0"])
|
||||
return order
|
||||
|
||||
def w_order(f, start, end):
|
||||
df = pd.read_pickle(in_dir + f)
|
||||
#df['date'] = df.index.get_level_values(1).map(lambda x: x.date())
|
||||
#df = df.set_index('date', append=True, drop=True)
|
||||
|
||||
order = generate_order(df, start, end)
|
||||
order_train = order[order.index.get_level_values(0) < '2020-12-01']
|
||||
order_test = order[order.index.get_level_values(0) >= '2020-12-01']
|
||||
order_valid = order_test[order_test.index.get_level_values(0) < '2021-01-01']
|
||||
order_test = order_test[order_test.index.get_level_values(0) >= '2021-01-01']
|
||||
if len(order_train) > 0:
|
||||
order_train.to_pickle(train_path + f[:-9] + '.target')
|
||||
if len(order_valid) > 0:
|
||||
order_valid.to_pickle(valid_path + f[:-9] + '.target')
|
||||
if len(order_test) > 0:
|
||||
order_test.to_pickle(test_path + f[:-9] + '.target')
|
||||
if len(order) > 0:
|
||||
order.to_pickle(all_path + f[:-9] + '.target')
|
||||
return 0
|
||||
|
||||
train_path = os.path.join(data_path, "order/train/")
|
||||
if not os.path.exists(train_path):
|
||||
os.makedirs(train_path)
|
||||
|
||||
valid_path = os.path.join(data_path, "order/valid/")
|
||||
if not os.path.exists(valid_path):
|
||||
os.makedirs(valid_path)
|
||||
|
||||
test_path = os.path.join(data_path, "order/test/")
|
||||
if not os.path.exists(test_path):
|
||||
os.makedirs(test_path)
|
||||
|
||||
all_path = os.path.join(data_path, "order/all/")
|
||||
if not os.path.exists(all_path):
|
||||
os.makedirs(all_path)
|
||||
|
||||
res = Parallel(n_jobs=64)(delayed(w_order)(f, 0, 239) for f in os.listdir(in_dir))
|
||||
print(sum(res))
|
||||
2
examples/trade/policy/__init__.py
Normal file
2
examples/trade/policy/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .ppo_supervision import *
|
||||
from .ppo import *
|
||||
255
examples/trade/policy/ppo.py
Normal file
255
examples/trade/policy/ppo.py
Normal file
@@ -0,0 +1,255 @@
|
||||
import torch
|
||||
|
||||
import numpy as np
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from typing import Dict, List, Tuple, Union, Optional
|
||||
|
||||
from tianshou.policy import PGPolicy
|
||||
from tianshou.data import Batch, ReplayBuffer
|
||||
from tianshou.data import to_torch
|
||||
from numba import njit
|
||||
import sys
|
||||
|
||||
sys.path.append("..")
|
||||
from util import to_numpy, to_torch_as
|
||||
|
||||
|
||||
def _episodic_return(
|
||||
v_s_: np.ndarray, rew: np.ndarray, done: np.ndarray, gamma: float, gae_lambda: float,
|
||||
) -> np.ndarray:
|
||||
"""Numba speedup: 4.1s -> 0.057s."""
|
||||
returns = np.roll(v_s_, 1)
|
||||
m = (1.0 - done) * gamma
|
||||
delta = rew + v_s_ * m - returns
|
||||
m *= gae_lambda
|
||||
gae = 0.0
|
||||
for i in range(len(rew) - 1, -1, -1):
|
||||
gae_new = delta[i] + m[i] * gae
|
||||
gae = gae_new
|
||||
returns[i] += gae
|
||||
return returns
|
||||
|
||||
|
||||
class PPO(PGPolicy):
|
||||
""" The PPO policy with Teacher supervision"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
actor: torch.nn.Module,
|
||||
critic: torch.nn.Module,
|
||||
optim: torch.optim.Optimizer,
|
||||
dist_fn: torch.distributions.Distribution,
|
||||
teacher=None,
|
||||
discount_factor: float = 0.99,
|
||||
max_grad_norm: Optional[float] = None,
|
||||
eps_clip: float = 0.2,
|
||||
vf_clip_para=10.0,
|
||||
vf_coef: float = 0.5,
|
||||
kl_coef=0.5,
|
||||
kl_target=0.01,
|
||||
ent_coef: float = 0.01,
|
||||
sup_coef=0.1,
|
||||
action_range: Optional[Tuple[float, float]] = None,
|
||||
gae_lambda: float = 0.95,
|
||||
dual_clip: Optional[float] = None,
|
||||
value_clip: bool = True,
|
||||
reward_normalization: bool = True,
|
||||
**kwargs
|
||||
) -> None:
|
||||
super().__init__(None, None, dist_fn, discount_factor, **kwargs)
|
||||
self._max_grad_norm = max_grad_norm
|
||||
self._eps_clip = eps_clip
|
||||
self._vf_clip_para = vf_clip_para
|
||||
self._w_vf = vf_coef
|
||||
self._w_ent = ent_coef
|
||||
self._range = action_range
|
||||
self.actor = actor
|
||||
self.critic = critic
|
||||
self.optim = optim
|
||||
self.sup_coef = sup_coef
|
||||
self.kl_target = kl_target
|
||||
self.kl_coef = kl_coef
|
||||
self._batch = 64
|
||||
assert 0 <= gae_lambda <= 1, "GAE lambda should be in [0, 1]."
|
||||
self._lambda = gae_lambda
|
||||
assert dual_clip is None or dual_clip > 1, "Dual-clip PPO parameter should greater than 1."
|
||||
self._dual_clip = dual_clip
|
||||
self._value_clip = value_clip
|
||||
self._rew_norm = reward_normalization
|
||||
if not teacher is None:
|
||||
self.teacher = torch.load(teacher, map_location=torch.device("cpu"))
|
||||
self.teacher.to(self.actor.device)
|
||||
self.teacher.actor.extractor.device = self.actor.device
|
||||
else:
|
||||
self.teacher = None
|
||||
|
||||
@staticmethod
|
||||
def compute_episodic_return(
|
||||
batch: Batch,
|
||||
v_s_: Optional[Union[np.ndarray, torch.Tensor]] = None,
|
||||
gamma: float = 0.99,
|
||||
gae_lambda: float = 0.95,
|
||||
rew_norm: bool = False,
|
||||
) -> Batch:
|
||||
"""Compute returns over given full-length episodes.
|
||||
Implementation of Generalized Advantage Estimator (arXiv:1506.02438).
|
||||
:param batch: a data batch which contains several full-episode data
|
||||
chronologically.
|
||||
:type batch: :class:`~tianshou.data.Batch`
|
||||
:param v_s_: the value function of all next states :math:`V(s')`.
|
||||
:type v_s_: numpy.ndarray
|
||||
:param float gamma: the discount factor, should be in [0, 1], defaults
|
||||
to 0.99.
|
||||
:param float gae_lambda: the parameter for Generalized Advantage
|
||||
Estimation, should be in [0, 1], defaults to 0.95.
|
||||
:param bool rew_norm: normalize the reward to Normal(0, 1), defaults
|
||||
to False.
|
||||
:return: a Batch. The result will be stored in batch.returns as a numpy
|
||||
array with shape (bsz, ).
|
||||
"""
|
||||
rew = batch.rew
|
||||
v_s_ = np.zeros_like(rew) if v_s_ is None else to_numpy(v_s_.flatten())
|
||||
assert not np.isnan(v_s_).any()
|
||||
assert not np.isnan(rew).any()
|
||||
assert not np.isnan(batch.done).any()
|
||||
returns = _episodic_return(v_s_, rew, batch.done, gamma, gae_lambda)
|
||||
assert not np.isnan(returns).any()
|
||||
if rew_norm and not np.isclose(returns.std(), 0.0, 1e-2):
|
||||
returns = (returns - returns.mean()) / returns.std()
|
||||
assert not np.isnan(returns).any()
|
||||
batch.returns = returns
|
||||
return batch
|
||||
|
||||
def process_fn(self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray) -> Batch:
|
||||
if self._rew_norm:
|
||||
mean, std = batch.rew.mean(), batch.rew.std()
|
||||
if not np.isclose(std, 0):
|
||||
batch.rew = (batch.rew - mean) / std
|
||||
assert not np.isnan(batch.rew).any()
|
||||
if self._lambda in [0, 1]:
|
||||
return self.compute_episodic_return(batch, None, gamma=self._gamma, gae_lambda=self._lambda)
|
||||
else:
|
||||
v_ = []
|
||||
with torch.no_grad():
|
||||
for b in batch.split(self._batch, shuffle=False):
|
||||
v_.append(self.critic(b.obs_next))
|
||||
v_ = to_numpy(torch.cat(v_, dim=0))
|
||||
assert not np.isnan(v_).any()
|
||||
return self.compute_episodic_return(batch, v_, gamma=self._gamma, gae_lambda=self._lambda)
|
||||
|
||||
def forward(self, batch: Batch, state: Optional[Union[dict, Batch, np.ndarray]] = None, **kwargs) -> Batch:
|
||||
"""Compute action over the given batch data."""
|
||||
logits, h = self.actor(batch.obs, state=state, info=batch.info)
|
||||
if isinstance(logits, tuple):
|
||||
dist = self.dist_fn(*logits)
|
||||
else:
|
||||
dist = self.dist_fn(logits)
|
||||
if self.training:
|
||||
try:
|
||||
act = dist.sample()
|
||||
except:
|
||||
print(logits)
|
||||
act = dist.sample()
|
||||
else:
|
||||
act = torch.argmax(logits, dim=1)
|
||||
if self._range:
|
||||
act = act.clamp(self._range[0], self._range[1])
|
||||
return Batch(logits=logits, act=act, state=h, dist=dist)
|
||||
|
||||
def learn(self, batch: Batch, batch_size: int, repeat: int, **kwargs) -> Dict[str, List[float]]:
|
||||
self._batch = batch_size
|
||||
losses, clip_losses, vf_losses, ent_losses, kl_losses = [], [], [], [], []
|
||||
if self.teacher is not None:
|
||||
supervision_losses = []
|
||||
v = []
|
||||
old_log_prob = []
|
||||
feature = []
|
||||
old_logits = []
|
||||
with torch.no_grad():
|
||||
for b in batch.split(batch_size, shuffle=False):
|
||||
v.append(self.critic(b.obs))
|
||||
b_ = self(b)
|
||||
dist = b_.dist
|
||||
logits = b_.logits
|
||||
old_log_prob.append(dist.log_prob(to_torch_as(b.act, v[0])))
|
||||
old_logits.append(logits)
|
||||
if not self.teacher is None:
|
||||
with torch.no_grad():
|
||||
for b in batch.split(batch_size, shuffle=False):
|
||||
self.teacher(b)
|
||||
feature.append(self.teacher.actor.feature)
|
||||
batch.old_feature = torch.cat(feature, dim=0)
|
||||
batch.old_logits = torch.cat(old_logits, dim=0)
|
||||
batch.v = torch.cat(v, dim=0) # old value
|
||||
batch.act = to_torch_as(batch.act, v[0])
|
||||
batch.logp_old = torch.cat(old_log_prob, dim=0)
|
||||
batch.returns = to_torch_as(batch.returns, v[0]).reshape(batch.v.shape)
|
||||
if self._rew_norm:
|
||||
mean, std = batch.returns.mean(), batch.returns.std()
|
||||
if not np.isclose(std.item(), 0):
|
||||
batch.returns = (batch.returns - mean) / std
|
||||
batch.adv = batch.returns - batch.v
|
||||
if self._rew_norm:
|
||||
mean, std = batch.adv.mean(), batch.adv.std()
|
||||
if not np.isclose(std.item(), 0):
|
||||
batch.adv = (batch.adv - mean) / std
|
||||
for _ in range(repeat):
|
||||
for b in batch.split(batch_size):
|
||||
dist = self(b).dist
|
||||
value = self.critic(b.obs)
|
||||
if not self.teacher is None:
|
||||
feature = self.actor.feature
|
||||
# print(feature.pow(2).mean())
|
||||
ratio = (dist.log_prob(b.act) - b.logp_old).exp().float()
|
||||
surr1 = ratio * b.adv
|
||||
surr2 = ratio.clamp(1.0 - self._eps_clip, 1.0 + self._eps_clip) * b.adv
|
||||
if self._dual_clip:
|
||||
clip_loss = -torch.max(torch.min(surr1, surr2), self._dual_clip * b.adv).mean()
|
||||
else:
|
||||
clip_loss = -torch.min(surr1, surr2).mean()
|
||||
clip_losses.append(clip_loss.item())
|
||||
if self._value_clip:
|
||||
v_clip = b.v + (value - b.v).clamp(-self._vf_clip_para, self._vf_clip_para)
|
||||
vf1 = (b.returns - value).pow(2)
|
||||
vf2 = (b.returns - v_clip).pow(2)
|
||||
vf_loss = torch.max(vf1, vf2).mean()
|
||||
else:
|
||||
vf_loss = (b.returns - value).pow(2).mean()
|
||||
if not self.teacher is None:
|
||||
supervision_loss = (b.old_feature - feature).pow(2).mean()
|
||||
supervision_losses.append(supervision_loss.item())
|
||||
kl = torch.distributions.kl.kl_divergence(self.dist_fn(b.old_logits), dist)
|
||||
kl_loss = kl.mean()
|
||||
kl_losses.append(kl_loss.item())
|
||||
vf_losses.append(vf_loss.item())
|
||||
e_loss = dist.entropy().mean()
|
||||
ent_losses.append(e_loss.item())
|
||||
loss = clip_loss + self._w_vf * vf_loss - self._w_ent * e_loss + self.kl_coef * kl_loss
|
||||
if self.teacher is not None:
|
||||
loss += self.sup_coef * supervision_loss
|
||||
losses.append(loss.item())
|
||||
self.optim.zero_grad()
|
||||
loss.backward()
|
||||
nn.utils.clip_grad_norm_(
|
||||
list(self.actor.parameters()) + list(self.critic.parameters()), self._max_grad_norm,
|
||||
)
|
||||
self.optim.step()
|
||||
cur_kl = np.mean(kl_losses)
|
||||
if cur_kl > 2.0 * self.kl_target:
|
||||
self.kl_coef *= 1.5
|
||||
elif cur_kl < 0.5 * self.kl_target:
|
||||
self.kl_coef *= 0.5
|
||||
res = {
|
||||
"loss/total_loss": losses,
|
||||
"loss/policy": clip_losses,
|
||||
"loss/vf": vf_losses,
|
||||
"loss/entropy": ent_losses,
|
||||
"loss/kl": kl_losses,
|
||||
}
|
||||
if not self.teacher is None:
|
||||
res["loss/supervision"] = supervision_losses
|
||||
return res
|
||||
|
||||
|
||||
Student_new = PPO
|
||||
187
examples/trade/policy/ppo_supervision.py
Normal file
187
examples/trade/policy/ppo_supervision.py
Normal file
@@ -0,0 +1,187 @@
|
||||
import torch
|
||||
|
||||
import numpy as np
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from typing import Dict, List, Tuple, Union, Optional
|
||||
|
||||
from tianshou.policy import PGPolicy
|
||||
from tianshou.data import Batch, ReplayBuffer
|
||||
from tianshou.data import to_torch
|
||||
from numba import njit
|
||||
import sys
|
||||
|
||||
sys.path.append("..")
|
||||
from util import to_numpy, to_torch_as
|
||||
|
||||
from .ppo import _episodic_return
|
||||
|
||||
|
||||
class PPO_sup(PGPolicy):
|
||||
"""The PPO policy with a log-likelihood supervision loss"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
actor: torch.nn.Module,
|
||||
critic: torch.nn.Module,
|
||||
optim: torch.optim.Optimizer,
|
||||
dist_fn: torch.distributions.Distribution,
|
||||
discount_factor: float = 0.99,
|
||||
max_grad_norm: Optional[float] = None,
|
||||
eps_clip: float = 0.2,
|
||||
vf_clip_para=10.0,
|
||||
vf_coef: float = 0.5,
|
||||
kl_coef=0.5,
|
||||
kl_target=0.01,
|
||||
ent_coef: float = 0.01,
|
||||
sup_coef=0.1,
|
||||
action_range: Optional[Tuple[float, float]] = None,
|
||||
gae_lambda: float = 0.95,
|
||||
dual_clip: Optional[float] = None,
|
||||
value_clip: bool = True,
|
||||
reward_normalization: bool = True,
|
||||
**kwargs
|
||||
) -> None:
|
||||
super().__init__(None, None, dist_fn, discount_factor, **kwargs)
|
||||
self._max_grad_norm = max_grad_norm
|
||||
self._eps_clip = eps_clip
|
||||
self._vf_clip_para = vf_clip_para
|
||||
self._w_vf = vf_coef
|
||||
self._w_ent = ent_coef
|
||||
self._range = action_range
|
||||
self.actor = actor
|
||||
self.critic = critic
|
||||
self.optim = optim
|
||||
self.sup_coef = sup_coef
|
||||
self.kl_target = kl_target
|
||||
self.kl_coef = kl_coef
|
||||
self._batch = 64
|
||||
assert 0 <= gae_lambda <= 1, "GAE lambda should be in [0, 1]."
|
||||
self._lambda = gae_lambda
|
||||
assert dual_clip is None or dual_clip > 1, "Dual-clip PPO parameter should greater than 1."
|
||||
self._dual_clip = dual_clip
|
||||
self._value_clip = value_clip
|
||||
self._rew_norm = reward_normalization
|
||||
|
||||
def process_fn(self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray) -> Batch:
|
||||
if self._rew_norm:
|
||||
mean, std = batch.rew.mean(), batch.rew.std()
|
||||
if not np.isclose(std, 0):
|
||||
batch.rew = (batch.rew - mean) / std
|
||||
if self._lambda in [0, 1]:
|
||||
return self.compute_episodic_return(batch, None, gamma=self._gamma, gae_lambda=self._lambda)
|
||||
else:
|
||||
v_ = []
|
||||
with torch.no_grad():
|
||||
for b in batch.split(self._batch, shuffle=False):
|
||||
v_.append(self.critic(b.obs_next))
|
||||
v_ = to_numpy(torch.cat(v_, dim=0))
|
||||
return self.compute_episodic_return(batch, v_, gamma=self._gamma, gae_lambda=self._lambda)
|
||||
|
||||
def forward(self, batch: Batch, state: Optional[Union[dict, Batch, np.ndarray]] = None, **kwargs) -> Batch:
|
||||
logits, h = self.actor(batch.obs, state=state, info=batch.info)
|
||||
if isinstance(logits, tuple):
|
||||
dist = self.dist_fn(*logits)
|
||||
else:
|
||||
dist = self.dist_fn(logits)
|
||||
if self.training:
|
||||
act = dist.sample()
|
||||
else:
|
||||
act = torch.argmax(logits, dim=1)
|
||||
if self._range:
|
||||
act = act.clamp(self._range[0], self._range[1])
|
||||
return Batch(logits=logits, act=act, state=h, dist=dist)
|
||||
|
||||
def learn(self, batch: Batch, batch_size: int, repeat: int, **kwargs) -> Dict[str, List[float]]:
|
||||
self._batch = batch_size
|
||||
losses, clip_losses, vf_losses, ent_losses, kl_losses, supervision_losses = (
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
)
|
||||
v = []
|
||||
old_log_prob = []
|
||||
teacher_action = []
|
||||
old_logits = []
|
||||
with torch.no_grad():
|
||||
for b in batch.split(batch_size, shuffle=False):
|
||||
v.append(self.critic(b.obs))
|
||||
b_ = self(b)
|
||||
dist = b_.dist
|
||||
logits = b_.logits
|
||||
old_log_prob.append(dist.log_prob(to_torch_as(b.act, v[0])))
|
||||
old_logits.append(logits)
|
||||
teacher_action.append(self.actor.teacher_action)
|
||||
|
||||
batch.teacher_action = torch.cat(teacher_action, dim=0).to(torch.long)
|
||||
batch.old_logits = torch.cat(old_logits, dim=0)
|
||||
batch.v = torch.cat(v, dim=0) # old value
|
||||
batch.act = to_torch_as(batch.act, v[0])
|
||||
batch.logp_old = torch.cat(old_log_prob, dim=0)
|
||||
batch.returns = to_torch_as(batch.returns, v[0]).reshape(batch.v.shape)
|
||||
if self._rew_norm:
|
||||
mean, std = batch.returns.mean(), batch.returns.std()
|
||||
if not np.isclose(std.item(), 0):
|
||||
batch.returns = (batch.returns - mean) / std
|
||||
batch.adv = batch.returns - batch.v
|
||||
if self._rew_norm:
|
||||
mean, std = batch.adv.mean(), batch.adv.std()
|
||||
if not np.isclose(std.item(), 0):
|
||||
batch.adv = (batch.adv - mean) / std
|
||||
for _ in range(repeat):
|
||||
for b in batch.split(batch_size):
|
||||
res = self(b)
|
||||
logits = res.logits
|
||||
dist = res.dist
|
||||
value = self.critic(b.obs)
|
||||
ratio = (dist.log_prob(b.act) - b.logp_old).exp().float()
|
||||
surr1 = ratio * b.adv
|
||||
surr2 = ratio.clamp(1.0 - self._eps_clip, 1.0 + self._eps_clip) * b.adv
|
||||
if self._dual_clip:
|
||||
clip_loss = -torch.max(torch.min(surr1, surr2), self._dual_clip * b.adv).mean()
|
||||
else:
|
||||
clip_loss = -torch.min(surr1, surr2).mean()
|
||||
clip_losses.append(clip_loss.item())
|
||||
if self._value_clip:
|
||||
v_clip = b.v + (value - b.v).clamp(-self._vf_clip_para, self._vf_clip_para)
|
||||
vf1 = (b.returns - value).pow(2)
|
||||
vf2 = (b.returns - v_clip).pow(2)
|
||||
vf_loss = torch.max(vf1, vf2).mean()
|
||||
else:
|
||||
vf_loss = (b.returns - value).pow(2).mean()
|
||||
supervision_loss = F.nll_loss(logits.log(), b.teacher_action)
|
||||
supervision_losses.append(supervision_loss.item())
|
||||
kl = torch.distributions.kl.kl_divergence(self.dist_fn(b.old_logits), dist)
|
||||
kl_loss = kl.mean()
|
||||
kl_losses.append(kl_loss.item())
|
||||
vf_losses.append(vf_loss.item())
|
||||
e_loss = dist.entropy().mean()
|
||||
ent_losses.append(e_loss.item())
|
||||
loss = clip_loss + self._w_vf * vf_loss - self._w_ent * e_loss + self.kl_coef * kl_loss
|
||||
loss += self.sup_coef * supervision_loss
|
||||
losses.append(loss.item())
|
||||
self.optim.zero_grad()
|
||||
loss.backward()
|
||||
nn.utils.clip_grad_norm_(
|
||||
list(self.actor.parameters()) + list(self.critic.parameters()), self._max_grad_norm,
|
||||
)
|
||||
self.optim.step()
|
||||
if hasattr(self.actor, "callback"):
|
||||
self.actor.callback()
|
||||
cur_kl = np.mean(kl_losses)
|
||||
if cur_kl > 2.0 * self.kl_target:
|
||||
self.kl_coef *= 1.5
|
||||
elif cur_kl < 0.5 * self.kl_target:
|
||||
self.kl_coef *= 0.5
|
||||
res = {
|
||||
"loss/total_loss": losses,
|
||||
"loss/policy": clip_losses,
|
||||
"loss/vf": vf_losses,
|
||||
"loss/entropy": ent_losses,
|
||||
"loss/kl": kl_losses,
|
||||
"loss/supervision": supervision_losses,
|
||||
}
|
||||
return res
|
||||
10
examples/trade/requirements.txt
Normal file
10
examples/trade/requirements.txt
Normal file
@@ -0,0 +1,10 @@
|
||||
gym==0.17.3
|
||||
torch==1.6.0
|
||||
numba==0.51.2
|
||||
numpy==1.19.1
|
||||
pandas==1.1.3
|
||||
tqdm==4.50.2
|
||||
tianshou==0.3.0.post1
|
||||
env==0.1.0
|
||||
PyYAML==5.4.1
|
||||
redis==3.5.3
|
||||
4
examples/trade/reward/__init__.py
Normal file
4
examples/trade/reward/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .base import *
|
||||
from .pa_penalty import *
|
||||
from .ppo_reward import *
|
||||
from .vp_penalty import *
|
||||
38
examples/trade/reward/base.py
Normal file
38
examples/trade/reward/base.py
Normal file
@@ -0,0 +1,38 @@
|
||||
import numpy as np
|
||||
|
||||
|
||||
class Abs_Reward(object):
|
||||
"""The abstract class for Reward."""
|
||||
|
||||
def __init__(self, config):
|
||||
return
|
||||
|
||||
def get_reward(self):
|
||||
""":return: reward"""
|
||||
reward = 0
|
||||
return reward
|
||||
|
||||
def __call__(self, *args, **kargs):
|
||||
return self.get_reward(*args, **kargs)
|
||||
|
||||
def isinstant(self):
|
||||
""":return: Whether the reward should be given at every timestep or only at the end of this episode."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class Instant_Reward(Abs_Reward):
|
||||
def __init__(self, config):
|
||||
self.ffr_ratio = config["ffr_ratio"]
|
||||
self.vvr_ratio = config["vvr_ratio"]
|
||||
|
||||
def isinstant(self):
|
||||
return True
|
||||
|
||||
|
||||
class EndEpisode_Reward(Abs_Reward):
|
||||
def __init__(self, config):
|
||||
self.ffr_ratio = config["ffr_ratio"]
|
||||
self.vvr_ratio = config["vvr_ratio"]
|
||||
|
||||
def isinstant(self):
|
||||
return False
|
||||
14
examples/trade/reward/pa_penalty.py
Normal file
14
examples/trade/reward/pa_penalty.py
Normal file
@@ -0,0 +1,14 @@
|
||||
import numpy as np
|
||||
from .base import Instant_Reward
|
||||
|
||||
|
||||
class PA_Penalty(Instant_Reward):
|
||||
"""Reward: (Abs(tt_ratio_t - 1) * 10000 * v_t / target - v_t^2 * penalty) / 100"""
|
||||
|
||||
def __init__(self, config):
|
||||
self.penalty = config["penalty"]
|
||||
|
||||
def get_reward(self, performance_raise, v_t, target, PA_t, *args):
|
||||
reward = PA_t * v_t / target
|
||||
reward -= self.penalty * (v_t / target) ** 2
|
||||
return reward / 100
|
||||
22
examples/trade/reward/ppo_reward.py
Normal file
22
examples/trade/reward/ppo_reward.py
Normal file
@@ -0,0 +1,22 @@
|
||||
import numpy as np
|
||||
from .base import Abs_Reward
|
||||
|
||||
|
||||
class PPO_Reward(Abs_Reward):
|
||||
"""The reward function defined in IJCAI 2020"""
|
||||
|
||||
def __init__(self, *args):
|
||||
pass
|
||||
|
||||
def isinstant(self):
|
||||
return False
|
||||
|
||||
def get_reward(self, performace_raise, ffr, this_tt_ratio, is_buy):
|
||||
if is_buy:
|
||||
this_tt_ratio = 1 / this_tt_ratio
|
||||
if this_tt_ratio < 1:
|
||||
return -1.0
|
||||
elif this_tt_ratio < 1.1:
|
||||
return 0.0
|
||||
else:
|
||||
return 1.0
|
||||
37
examples/trade/reward/vp_penalty.py
Normal file
37
examples/trade/reward/vp_penalty.py
Normal file
@@ -0,0 +1,37 @@
|
||||
import numpy as np
|
||||
from .base import Instant_Reward
|
||||
|
||||
|
||||
class VP_Penalty_small(Instant_Reward):
|
||||
"""Reward: (Abs(vv_ratio_t - 1) * 10000 - v_t^2 * penalty) / 100"""
|
||||
|
||||
def __init__(self, config):
|
||||
self.penalty = config["penalty"]
|
||||
|
||||
def get_reward(self, performance_raise, v_t, target, *args):
|
||||
"""
|
||||
|
||||
:param performance_raise: Abs(vv_ratio_t - 1) * 10000.
|
||||
:param target: Target volume
|
||||
:param v_t: The traded volume
|
||||
"""
|
||||
assert target > 0
|
||||
reward = performance_raise * v_t / target
|
||||
reward -= self.penalty * (v_t / target) ** 2
|
||||
assert not (np.isnan(reward) or np.isinf(reward)), f"{performance_raise}, {v_t}, {target}"
|
||||
return reward / 100
|
||||
|
||||
|
||||
class VP_Penalty_small_vec(VP_Penalty_small):
|
||||
def get_reward(self, performance_raise, v_t, target, *args):
|
||||
"""
|
||||
|
||||
:param performance_raise: Abs(vv_ratio_t - 1) * 10000.
|
||||
:param target: Target volume
|
||||
:param v_t: The traded volume
|
||||
"""
|
||||
assert target > 0
|
||||
reward = performance_raise * v_t.sum() / target
|
||||
reward -= self.penalty * ((v_t / target) ** 2).sum()
|
||||
assert not (np.isnan(reward) or np.isinf(reward)), f"{performance_raise}, {v_t}, {target}"
|
||||
return reward / 100
|
||||
1
examples/trade/sampler/__init__.py
Normal file
1
examples/trade/sampler/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .single_sampler import *
|
||||
184
examples/trade/sampler/single_sampler.py
Normal file
184
examples/trade/sampler/single_sampler.py
Normal file
@@ -0,0 +1,184 @@
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from multiprocessing.context import Process
|
||||
from multiprocessing import Queue
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.append("..")
|
||||
|
||||
|
||||
def toArray(data):
|
||||
if type(data) == np.ndarray:
|
||||
return data
|
||||
|
||||
elif type(data) == list:
|
||||
data = np.array(data)
|
||||
return data
|
||||
|
||||
elif type(data) == pd.DataFrame:
|
||||
share_index = toArray(data.index)
|
||||
share_value = toArray(data.values)
|
||||
share_colmns = toArray(data.columns)
|
||||
return share_index, share_value, share_colmns
|
||||
|
||||
else:
|
||||
try:
|
||||
share_array = np.array(data)
|
||||
return share_array
|
||||
except:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class Sampler:
|
||||
"""The sampler for training of single-assert RL."""
|
||||
|
||||
def __init__(self, config):
|
||||
self.raw_dir = config["raw_dir"] + "/"
|
||||
self.order_dir = config["order_dir"] + "/"
|
||||
self.ins_list = [f[:-11] for f in os.listdir(self.order_dir) if f.endswith("target")]
|
||||
self.features = config["features"]
|
||||
self.queue = Queue(1000)
|
||||
self.child = None
|
||||
self.ins = None
|
||||
self.raw_df = None
|
||||
self.df_list = None
|
||||
self.order_df = None
|
||||
|
||||
@staticmethod
|
||||
def _worker(order_dir, raw_dir, features, ins_list, queue):
|
||||
ins = None
|
||||
index = 0
|
||||
date_list = []
|
||||
while True:
|
||||
if ins is None or index == len(date_list):
|
||||
ins = np.random.choice(ins_list, 1)[0]
|
||||
# print(ins)
|
||||
order_df = pd.read_pickle(order_dir + ins + ".pkl.target")
|
||||
feature_df_list = []
|
||||
for feature in features:
|
||||
feature_df_list.append(pd.read_pickle(f"{feature['loc']}/{ins}.pkl"))
|
||||
raw_df = pd.read_pickle(raw_dir + ins + ".pkl.backtest")
|
||||
date_list = order_df.index.get_level_values(0).tolist()
|
||||
index = 0
|
||||
date = date_list[index]
|
||||
day_order_df = order_df.iloc[index]
|
||||
target = day_order_df["amount"]
|
||||
index += 1
|
||||
if target == 0:
|
||||
continue
|
||||
day_feature_dfs = []
|
||||
day_raw_df = raw_df.loc[pd.IndexSlice[ins, :, date]]
|
||||
is_buy = bool(day_order_df["order_type"])
|
||||
for df in feature_df_list:
|
||||
day_feature_dfs.append(df.loc[ins, date].values)
|
||||
day_feature_dfs = np.array(day_feature_dfs)
|
||||
day_raw_df_index, day_raw_df_value, day_raw_df_column = toArray(day_raw_df)
|
||||
day_feature_dfs_ = toArray(day_feature_dfs)
|
||||
queue.put(
|
||||
(ins, date, day_raw_df_value, day_raw_df_column, day_raw_df_index, day_feature_dfs_, target, is_buy,),
|
||||
block=True,
|
||||
)
|
||||
|
||||
def _sample_ins(self):
|
||||
""" """
|
||||
return np.random.choice(self.ins_list, 1)[0]
|
||||
|
||||
def reset(self):
|
||||
""" """
|
||||
if self.child is None:
|
||||
self.child = Process(
|
||||
target=self._worker,
|
||||
args=(self.order_dir, self.raw_dir, self.features, self.ins_list, self.queue,),
|
||||
daemon=True,
|
||||
)
|
||||
self.child.start()
|
||||
|
||||
def sample(self):
|
||||
""" """
|
||||
sample = self.queue.get(block=True)
|
||||
return sample
|
||||
|
||||
def stop(self):
|
||||
""" """
|
||||
try:
|
||||
self.child.terminate()
|
||||
except:
|
||||
for p in self.child:
|
||||
p.terminate()
|
||||
|
||||
|
||||
class TestSampler(Sampler):
|
||||
"""The sampler for backtest of single-assert strategies."""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.ins_index = -1
|
||||
|
||||
def _sample_ins(self):
|
||||
""" """
|
||||
self.ins_index += 1
|
||||
if self.ins_index >= len(self.ins_list):
|
||||
return None
|
||||
else:
|
||||
return self.ins_list[self.ins_index]
|
||||
|
||||
@staticmethod
|
||||
def _worker(order_dir, raw_dir, features, ins_list, queue):
|
||||
for ins in ins_list:
|
||||
order_df = pd.read_pickle(order_dir + ins + ".pkl.target")
|
||||
df_list = []
|
||||
for feature in features:
|
||||
df_list.append(pd.read_pickle(f"{feature['loc']}/{ins}.pkl"))
|
||||
raw_df = pd.read_pickle(raw_dir + ins + ".pkl.backtest")
|
||||
date_list = order_df.index.get_level_values(0).tolist()
|
||||
for index in range(len(date_list)):
|
||||
date = date_list[index]
|
||||
day_df_list = []
|
||||
day_raw_df = raw_df.loc[pd.IndexSlice[ins, :, date]]
|
||||
day_order_df = order_df.iloc[index]
|
||||
target = day_order_df["amount"]
|
||||
if target == 0:
|
||||
continue
|
||||
is_buy = bool(day_order_df["order_type"])
|
||||
for df in df_list:
|
||||
day_df_list.append(df.loc[ins, date].values)
|
||||
day_feature_dfs = np.array(day_df_list)
|
||||
day_raw_df_index, day_raw_df_value, day_raw_df_column = toArray(day_raw_df)
|
||||
day_feature_dfs_ = toArray(day_feature_dfs)
|
||||
queue.put(
|
||||
(
|
||||
ins,
|
||||
date,
|
||||
day_raw_df_value,
|
||||
day_raw_df_column,
|
||||
day_raw_df_index,
|
||||
day_feature_dfs_,
|
||||
target,
|
||||
is_buy,
|
||||
),
|
||||
block=True,
|
||||
)
|
||||
for _ in range(100):
|
||||
queue.put(None)
|
||||
|
||||
def reset(self, order_dir=None):
|
||||
"""
|
||||
|
||||
reset the sampler and change self.order_dir if order_dir is not None.
|
||||
|
||||
"""
|
||||
if order_dir:
|
||||
self.order_dir = order_dir
|
||||
self.ins_list = [f[:-11] for f in os.listdir(self.order_dir) if f.endswith("target")]
|
||||
if not self.child is None:
|
||||
self.child.terminate()
|
||||
while not self.queue.empty():
|
||||
self.queue.get()
|
||||
self.child = Process(
|
||||
target=self._worker,
|
||||
args=(self.order_dir, self.raw_dir, self.features, self.ins_list, self.queue,),
|
||||
daemon=True,
|
||||
)
|
||||
self.child.start()
|
||||
28
examples/trade/teacher_feature.py
Normal file
28
examples/trade/teacher_feature.py
Normal file
@@ -0,0 +1,28 @@
|
||||
import pandas as pd
|
||||
import os
|
||||
|
||||
data_path = '../data/'
|
||||
feature_path = os.path.join(data_path, 'feature/teacher/')
|
||||
if not os.path.exists(feature_path):
|
||||
os.makedirs(feature_path)
|
||||
|
||||
|
||||
log_file = os.path.join(os.environ.get('OUTPUT_DIR'),'example/OPDT_b/test/')
|
||||
|
||||
files = os.listdir(log_file)
|
||||
|
||||
for f in files:
|
||||
if f.endswith(".log"):
|
||||
df = pd.read_pickle(log_file + f)
|
||||
|
||||
#df['datetime'] = df.index.get_level_values(1).map(lambda x: x[1])
|
||||
df['datetime'] = df.index.get_level_values(1)
|
||||
df.set_index('datetime', append=True, drop=True, inplace=True)
|
||||
action = df['action']
|
||||
action = action.reset_index(level=1, drop=True)
|
||||
action.index = action.index.map(lambda x: (x[0], x[1], x[2].time()))
|
||||
action = action.unstack().iloc[:, ::30] * 2
|
||||
action = action.fillna(0)
|
||||
train_action = action.astype("int")
|
||||
final = train_action
|
||||
final.to_pickle(feature_path + f[:-4] + '.pkl')
|
||||
303
examples/trade/util.py
Normal file
303
examples/trade/util.py
Normal file
@@ -0,0 +1,303 @@
|
||||
from collections import namedtuple
|
||||
from torch.nn.utils.rnn import pack_padded_sequence
|
||||
from tianshou.data import Batch
|
||||
import numpy as np
|
||||
import torch
|
||||
import copy
|
||||
from typing import Union, Optional
|
||||
from numbers import Number
|
||||
|
||||
|
||||
def nan_weighted_avg(vals, weights, axis=None):
|
||||
"""
|
||||
|
||||
:param vals: The values to be averaged on.
|
||||
:param weights: The weights of weighted avrage.
|
||||
:param axis: On which axis to calculate the weighted avrage. (Default value = None)
|
||||
|
||||
"""
|
||||
assert vals.shape == weights.shape, AssertionError(f"{vals.shape} & {weights.shape}")
|
||||
vals = vals.copy()
|
||||
weights = weights.copy()
|
||||
res = (vals * weights).sum(axis=axis) / weights.sum(axis=axis)
|
||||
return np.nan_to_num(res, nan=vals[0])
|
||||
|
||||
|
||||
def robust_auc(y_true, y_pred):
|
||||
"""
|
||||
|
||||
Calculate AUC.
|
||||
|
||||
"""
|
||||
try:
|
||||
return roc_auc_score(y_true, y_pred)
|
||||
except:
|
||||
return np.nan
|
||||
|
||||
|
||||
def merge_dicts(d1, d2):
|
||||
"""
|
||||
|
||||
:param d1: Dict 1.
|
||||
:type d1: dict
|
||||
:param d2: Dict 2.
|
||||
:returns: A new dict that is d1 and d2 deep merged.
|
||||
:rtype: dict
|
||||
|
||||
"""
|
||||
merged = copy.deepcopy(d1)
|
||||
deep_update(merged, d2, True, [])
|
||||
return merged
|
||||
|
||||
|
||||
def deep_update(
|
||||
original, new_dict, new_keys_allowed=False, whitelist=None, override_all_if_type_changes=None,
|
||||
):
|
||||
"""Updates original dict with values from new_dict recursively.
|
||||
If new key is introduced in new_dict, then if new_keys_allowed is not
|
||||
True, an error will be thrown. Further, for sub-dicts, if the key is
|
||||
in the whitelist, then new subkeys can be introduced.
|
||||
|
||||
:param original: Dictionary with default values.
|
||||
:type original: dict
|
||||
:param new_dict(dict: dict): Dictionary with values to be updated
|
||||
:param new_keys_allowed: Whether new keys are allowed. (Default value = False)
|
||||
:type new_keys_allowed: bool
|
||||
:param whitelist: List of keys that correspond to dict
|
||||
values where new subkeys can be introduced. This is only at the top
|
||||
level. (Default value = None)
|
||||
:type whitelist: Optional[List[str]]
|
||||
:param override_all_if_type_changes: List of top level
|
||||
keys with value=dict, for which we always simply override the
|
||||
entire value (dict), iff the "type" key in that value dict changes. (Default value = None)
|
||||
:type override_all_if_type_changes: Optional[List[str]]
|
||||
:param new_dict:
|
||||
|
||||
"""
|
||||
whitelist = whitelist or []
|
||||
override_all_if_type_changes = override_all_if_type_changes or []
|
||||
|
||||
for k, value in new_dict.items():
|
||||
if k not in original and not new_keys_allowed:
|
||||
raise Exception("Unknown config parameter `{}` ".format(k))
|
||||
|
||||
# Both orginal value and new one are dicts.
|
||||
if isinstance(original.get(k), dict) and isinstance(value, dict):
|
||||
# Check old type vs old one. If different, override entire value.
|
||||
if (
|
||||
k in override_all_if_type_changes
|
||||
and "type" in value
|
||||
and "type" in original[k]
|
||||
and value["type"] != original[k]["type"]
|
||||
):
|
||||
original[k] = value
|
||||
# Whitelisted key -> ok to add new subkeys.
|
||||
elif k in whitelist:
|
||||
deep_update(original[k], value, True)
|
||||
# Non-whitelisted key.
|
||||
else:
|
||||
deep_update(original[k], value, new_keys_allowed)
|
||||
# Original value not a dict OR new value not a dict:
|
||||
# Override entire value.
|
||||
else:
|
||||
original[k] = value
|
||||
return original
|
||||
|
||||
|
||||
def get_seqlen(done_seq):
|
||||
"""
|
||||
|
||||
:param done_seq:
|
||||
|
||||
"""
|
||||
seqlen = []
|
||||
length = 0
|
||||
for i, done in enumerate(done_seq):
|
||||
length += 1
|
||||
if done:
|
||||
seqlen.append(length)
|
||||
length = 0
|
||||
if length > 0:
|
||||
seqlen.append(length)
|
||||
return np.array(seqlen)
|
||||
|
||||
|
||||
def generate_seq(seqlen, list):
|
||||
"""
|
||||
|
||||
:param seqlen: param list:
|
||||
:param list:
|
||||
|
||||
"""
|
||||
res = []
|
||||
index = 0
|
||||
maxlen = np.max(seqlen)
|
||||
for i in seqlen:
|
||||
if isinstance(list, torch.Tensor):
|
||||
res.append(torch.cat((list[index : index + i], torch.zeros_like(list[: maxlen - i])), dim=0,))
|
||||
else:
|
||||
res.append(np.concatenate((list[index : index + i], np.zeros_like(list[: maxlen - i])), axis=0))
|
||||
index += i
|
||||
if isinstance(list, torch.Tensor):
|
||||
res = torch.stack(res, dim=0)
|
||||
else:
|
||||
res = np.stack(res, axis=0)
|
||||
return res
|
||||
|
||||
|
||||
def sequence_batch(batch):
|
||||
"""
|
||||
|
||||
:param batch:
|
||||
|
||||
"""
|
||||
seqlen = get_seqlen(batch.done)
|
||||
# print(seqlen.max())
|
||||
# print(len(seqlen))
|
||||
res = Batch()
|
||||
# print(batch.keys())
|
||||
|
||||
for v in batch.keys():
|
||||
if v not in ["policy", "info"]:
|
||||
res[v] = generate_seq(seqlen, batch[v])
|
||||
else:
|
||||
res[v] = batch[v]
|
||||
res.seqlen = seqlen
|
||||
return res
|
||||
|
||||
|
||||
def flatten_seq(seq, seqlen):
|
||||
"""
|
||||
|
||||
:param seq: param seqlen:
|
||||
:param seqlen:
|
||||
|
||||
"""
|
||||
res = []
|
||||
for i, length in enumerate(seqlen):
|
||||
res.append(seq[i][:length])
|
||||
if isinstance(seq, torch.Tensor):
|
||||
res = torch.cat(res, dim=0)
|
||||
else:
|
||||
res = np.concatenate(res, axis=0)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def flatten_batch(batch):
|
||||
"""
|
||||
|
||||
:param batch:
|
||||
|
||||
"""
|
||||
for v in batch.keys():
|
||||
if v in ["policy", "info", "seqlen"]:
|
||||
continue
|
||||
batch[v] = flatten_seq(batch[v], batch.seqlen)
|
||||
return batch
|
||||
|
||||
|
||||
def to_numpy(
|
||||
x: Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor]
|
||||
) -> Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor]:
|
||||
"""
|
||||
|
||||
:param x: Union[Batch:
|
||||
:param dict: param list:
|
||||
:param tuple: param np.ndarray:
|
||||
:param torch: Tensor]:
|
||||
:param x: Union[Batch:
|
||||
:param list:
|
||||
:param np.ndarray:
|
||||
:param torch.Tensor]:
|
||||
:param x: Union[Batch:
|
||||
|
||||
"""
|
||||
if isinstance(x, torch.Tensor):
|
||||
x = x.detach().cpu().numpy()
|
||||
elif isinstance(x, dict):
|
||||
for k, v in x.items():
|
||||
x[k] = to_numpy(v)
|
||||
elif isinstance(x, Batch):
|
||||
x.to_numpy()
|
||||
elif isinstance(x, (list, tuple)):
|
||||
try:
|
||||
x = to_numpy(_parse_value(x))
|
||||
except TypeError:
|
||||
x = [to_numpy(e) for e in x]
|
||||
else: # fallback
|
||||
x = np.asanyarray(x)
|
||||
return x
|
||||
|
||||
|
||||
def to_torch(
|
||||
x: Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor],
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Union[str, int, torch.device] = "cpu",
|
||||
) -> Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor]:
|
||||
"""
|
||||
|
||||
:param x: Union[Batch:
|
||||
:param dict: param list:
|
||||
:param tuple: param np.ndarray:
|
||||
:param torch: Tensor]:
|
||||
:param dtype: Optional[torch.dtype]: (Default value = None)
|
||||
:param device: Union[str:
|
||||
:param int: param torch.device]: (Default value = 'cpu')
|
||||
:param x: Union[Batch:
|
||||
:param list:
|
||||
:param np.ndarray:
|
||||
:param torch.Tensor]:
|
||||
:param dtype: Optional[torch.dtype]: (Default value = None)
|
||||
:param device: Union[str:
|
||||
:param torch.device]: (Default value = 'cpu')
|
||||
:param x: Union[Batch:
|
||||
:param dtype: Optional[torch.dtype]: (Default value = None)
|
||||
:param device: Union[str:
|
||||
|
||||
"""
|
||||
if isinstance(x, torch.Tensor):
|
||||
if dtype is not None:
|
||||
x = x.type(dtype)
|
||||
x = x.to(device)
|
||||
elif isinstance(x, dict):
|
||||
for k, v in x.items():
|
||||
x[k] = to_torch(v, dtype, device)
|
||||
elif isinstance(x, Batch):
|
||||
x.to_torch(dtype, device)
|
||||
elif isinstance(x, (np.number, np.bool_, Number)):
|
||||
x = to_torch(np.asanyarray(x), dtype, device)
|
||||
elif isinstance(x, (list, tuple)):
|
||||
try:
|
||||
x = to_torch(_parse_value(x), dtype, device)
|
||||
except TypeError:
|
||||
x = [to_torch(e, dtype, device) for e in x]
|
||||
else: # fallback
|
||||
x = np.asanyarray(x)
|
||||
if issubclass(x.dtype.type, (np.bool_, np.number)):
|
||||
x = torch.from_numpy(x).to(device)
|
||||
if dtype is not None:
|
||||
x = x.type(dtype)
|
||||
else:
|
||||
raise TypeError(f"object {x} cannot be converted to torch.")
|
||||
return x
|
||||
|
||||
|
||||
def to_torch_as(x: Union[torch.Tensor, dict, Batch, np.ndarray], y: torch.Tensor) -> Union[dict, Batch, torch.Tensor]:
|
||||
"""
|
||||
|
||||
:param x: Union[torch.Tensor:
|
||||
:param dict: param Batch:
|
||||
:param np: ndarray]:
|
||||
:param y: torch.Tensor:
|
||||
:param x: Union[torch.Tensor:
|
||||
:param Batch:
|
||||
:param np.ndarray]:
|
||||
:param y: torch.Tensor:
|
||||
:param x: Union[torch.Tensor:
|
||||
:param y: torch.Tensor:
|
||||
:returns: to_torch(x, dtype=y.dtype, device=y.device)``.
|
||||
|
||||
"""
|
||||
assert isinstance(y, torch.Tensor)
|
||||
return to_torch(x, dtype=y.dtype, device=y.device)
|
||||
695
examples/trade/vecenv.py
Normal file
695
examples/trade/vecenv.py
Normal file
@@ -0,0 +1,695 @@
|
||||
import gym
|
||||
import time
|
||||
import ctypes
|
||||
import numpy as np
|
||||
from collections import OrderedDict
|
||||
from multiprocessing.context import Process
|
||||
from multiprocessing import Array, Pipe, connection, Queue
|
||||
from typing import Any, List, Tuple, Union, Callable, Optional
|
||||
|
||||
from tianshou.env.worker import EnvWorker
|
||||
from tianshou.env.utils import CloudpickleWrapper
|
||||
|
||||
|
||||
_NP_TO_CT = {
|
||||
np.bool: ctypes.c_bool,
|
||||
np.bool_: ctypes.c_bool,
|
||||
np.uint8: ctypes.c_uint8,
|
||||
np.uint16: ctypes.c_uint16,
|
||||
np.uint32: ctypes.c_uint32,
|
||||
np.uint64: ctypes.c_uint64,
|
||||
np.int8: ctypes.c_int8,
|
||||
np.int16: ctypes.c_int16,
|
||||
np.int32: ctypes.c_int32,
|
||||
np.int64: ctypes.c_int64,
|
||||
np.float32: ctypes.c_float,
|
||||
np.float64: ctypes.c_double,
|
||||
}
|
||||
|
||||
|
||||
class ShArray:
|
||||
"""Wrapper of multiprocessing Array."""
|
||||
|
||||
def __init__(self, dtype: np.generic, shape: Tuple[int]) -> None:
|
||||
self.arr = Array(
|
||||
_NP_TO_CT[dtype.type], # type: ignore
|
||||
int(np.prod(shape)),
|
||||
)
|
||||
self.dtype = dtype
|
||||
self.shape = shape
|
||||
|
||||
def save(self, ndarray: np.ndarray) -> None:
|
||||
"""
|
||||
|
||||
:param ndarray: np.ndarray:
|
||||
:param ndarray: np.ndarray:
|
||||
:param ndarray: np.ndarray:
|
||||
|
||||
"""
|
||||
assert isinstance(ndarray, np.ndarray)
|
||||
dst = self.arr.get_obj()
|
||||
dst_np = np.frombuffer(dst, dtype=self.dtype).reshape(self.shape)
|
||||
np.copyto(dst_np, ndarray)
|
||||
|
||||
def get(self) -> np.ndarray:
|
||||
""" """
|
||||
obj = self.arr.get_obj()
|
||||
return np.frombuffer(obj, dtype=self.dtype).reshape(self.shape)
|
||||
|
||||
|
||||
def _setup_buf(space: gym.Space) -> Union[dict, tuple, ShArray]:
|
||||
"""
|
||||
|
||||
:param space: gym.Space:
|
||||
:param space: gym.Space:
|
||||
:param space: gym.Space:
|
||||
|
||||
"""
|
||||
if isinstance(space, gym.spaces.Dict):
|
||||
assert isinstance(space.spaces, OrderedDict)
|
||||
return {k: _setup_buf(v) for k, v in space.spaces.items()}
|
||||
elif isinstance(space, gym.spaces.Tuple):
|
||||
assert isinstance(space.spaces, tuple)
|
||||
return tuple([_setup_buf(t) for t in space.spaces])
|
||||
else:
|
||||
return ShArray(space.dtype, space.shape)
|
||||
|
||||
|
||||
def _worker(
|
||||
parent: connection.Connection,
|
||||
p: connection.Connection,
|
||||
env_fn_wrapper: CloudpickleWrapper,
|
||||
obs_bufs: Optional[Union[dict, tuple, ShArray]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
|
||||
:param parent: connection.Connection:
|
||||
:param p: connection.Connection:
|
||||
:param env_fn_wrapper: CloudpickleWrapper:
|
||||
:param obs_bufs: Optional[Union[dict:
|
||||
:param tuple: param ShArray]]: (Default value = None)
|
||||
:param parent: connection.Connection:
|
||||
:param p: connection.Connection:
|
||||
:param env_fn_wrapper: CloudpickleWrapper:
|
||||
:param obs_bufs: Optional[Union[dict:
|
||||
:param ShArray]]: (Default value = None)
|
||||
:param parent: connection.Connection:
|
||||
:param p: connection.Connection:
|
||||
:param env_fn_wrapper: CloudpickleWrapper:
|
||||
:param obs_bufs: Optional[Union[dict:
|
||||
|
||||
"""
|
||||
|
||||
def _encode_obs(obs: Union[dict, tuple, np.ndarray], buffer: Union[dict, tuple, ShArray],) -> None:
|
||||
"""
|
||||
|
||||
:param obs: Union[dict:
|
||||
:param tuple: param np.ndarray]:
|
||||
:param buffer: Union[dict:
|
||||
:param ShArray:
|
||||
:param obs: Union[dict:
|
||||
:param np.ndarray]:
|
||||
:param buffer: Union[dict:
|
||||
:param ShArray]:
|
||||
:param obs: Union[dict:
|
||||
:param buffer: Union[dict:
|
||||
|
||||
"""
|
||||
if isinstance(obs, np.ndarray) and isinstance(buffer, ShArray):
|
||||
buffer.save(obs)
|
||||
elif isinstance(obs, tuple) and isinstance(buffer, tuple):
|
||||
for o, b in zip(obs, buffer):
|
||||
_encode_obs(o, b)
|
||||
elif isinstance(obs, dict) and isinstance(buffer, dict):
|
||||
for k in obs.keys():
|
||||
_encode_obs(obs[k], buffer[k])
|
||||
return None
|
||||
|
||||
parent.close()
|
||||
env = env_fn_wrapper.data()
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
cmd, data = p.recv()
|
||||
except EOFError: # the pipe has been closed
|
||||
p.close()
|
||||
break
|
||||
if cmd == "step":
|
||||
obs, reward, done, info = env.step(data)
|
||||
if obs_bufs is not None:
|
||||
_encode_obs(obs, obs_bufs)
|
||||
obs = None
|
||||
p.send((obs, reward, done, info))
|
||||
elif cmd == "reset":
|
||||
obs = env.reset(data)
|
||||
if obs_bufs is not None:
|
||||
_encode_obs(obs, obs_bufs)
|
||||
obs = None
|
||||
p.send(obs)
|
||||
elif cmd == "close":
|
||||
p.send(env.close())
|
||||
p.close()
|
||||
break
|
||||
elif cmd == "render":
|
||||
p.send(env.render(**data) if hasattr(env, "render") else None)
|
||||
elif cmd == "seed":
|
||||
p.send(env.seed(data) if hasattr(env, "seed") else None)
|
||||
elif cmd == "getattr":
|
||||
p.send(getattr(env, data) if hasattr(env, data) else None)
|
||||
elif cmd == "toggle_log":
|
||||
env.toggle_log(data)
|
||||
else:
|
||||
p.close()
|
||||
raise NotImplementedError
|
||||
except KeyboardInterrupt:
|
||||
p.close()
|
||||
|
||||
|
||||
class SubprocEnvWorker(EnvWorker):
|
||||
"""Subprocess worker used in SubprocVectorEnv and ShmemVectorEnv."""
|
||||
|
||||
def __init__(self, env_fn: Callable[[], gym.Env], share_memory: bool = False) -> None:
|
||||
super().__init__(env_fn)
|
||||
self.parent_remote, self.child_remote = Pipe()
|
||||
self.share_memory = share_memory
|
||||
self.buffer: Optional[Union[dict, tuple, ShArray]] = None
|
||||
if self.share_memory:
|
||||
dummy = env_fn()
|
||||
obs_space = dummy.observation_space
|
||||
dummy.close()
|
||||
del dummy
|
||||
self.buffer = _setup_buf(obs_space)
|
||||
args = (
|
||||
self.parent_remote,
|
||||
self.child_remote,
|
||||
CloudpickleWrapper(env_fn),
|
||||
self.buffer,
|
||||
)
|
||||
self.process = Process(target=_worker, args=args, daemon=True)
|
||||
self.process.start()
|
||||
self.child_remote.close()
|
||||
|
||||
def __getattr__(self, key: str) -> Any:
|
||||
self.parent_remote.send(["getattr", key])
|
||||
return self.parent_remote.recv()
|
||||
|
||||
def _decode_obs(self) -> Union[dict, tuple, np.ndarray]:
|
||||
""" """
|
||||
|
||||
def decode_obs(buffer: Optional[Union[dict, tuple, ShArray]]) -> Union[dict, tuple, np.ndarray]:
|
||||
"""
|
||||
|
||||
:param buffer: Optional[Union[dict:
|
||||
:param tuple: param ShArray]]:
|
||||
:param buffer: Optional[Union[dict:
|
||||
:param ShArray]]:
|
||||
:param buffer: Optional[Union[dict:
|
||||
|
||||
"""
|
||||
if isinstance(buffer, ShArray):
|
||||
return buffer.get()
|
||||
elif isinstance(buffer, tuple):
|
||||
return tuple([decode_obs(b) for b in buffer])
|
||||
elif isinstance(buffer, dict):
|
||||
return {k: decode_obs(v) for k, v in buffer.items()}
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
return decode_obs(self.buffer)
|
||||
|
||||
def reset(self, sample) -> Any:
|
||||
"""
|
||||
|
||||
:param sample:
|
||||
|
||||
"""
|
||||
self.parent_remote.send(["reset", sample])
|
||||
# obs = self.parent_remote.recv()
|
||||
# if self.share_memory:
|
||||
# obs = self._decode_obs()
|
||||
# return obs
|
||||
|
||||
def get_reset_result(self):
|
||||
""" """
|
||||
obs = self.parent_remote.recv()
|
||||
if self.share_memory:
|
||||
obs = self._decode_obs()
|
||||
return obs
|
||||
|
||||
@staticmethod
|
||||
def wait( # type: ignore
|
||||
workers: List["SubprocEnvWorker"], wait_num: int, timeout: Optional[float] = None,
|
||||
) -> List["SubprocEnvWorker"]:
|
||||
"""
|
||||
|
||||
:param # type: ignoreworkers: List["SubprocEnvWorker"]:
|
||||
:param wait_num: int:
|
||||
:param timeout: Optional[float]: (Default value = None)
|
||||
:param # type: ignoreworkers: List["SubprocEnvWorker"]:
|
||||
:param wait_num: int:
|
||||
:param timeout: Optional[float]: (Default value = None)
|
||||
|
||||
"""
|
||||
remain_conns = conns = [x.parent_remote for x in workers]
|
||||
ready_conns: List[connection.Connection] = []
|
||||
remain_time, t1 = timeout, time.time()
|
||||
while len(remain_conns) > 0 and len(ready_conns) < wait_num:
|
||||
if timeout:
|
||||
remain_time = timeout - (time.time() - t1)
|
||||
if remain_time <= 0:
|
||||
break
|
||||
# connection.wait hangs if the list is empty
|
||||
new_ready_conns = connection.wait(remain_conns, timeout=remain_time)
|
||||
ready_conns.extend(new_ready_conns) # type: ignore
|
||||
remain_conns = [conn for conn in remain_conns if conn not in ready_conns]
|
||||
return [workers[conns.index(con)] for con in ready_conns]
|
||||
|
||||
def send_action(self, action: np.ndarray) -> None:
|
||||
"""
|
||||
|
||||
:param action: np.ndarray:
|
||||
:param action: np.ndarray:
|
||||
:param action: np.ndarray:
|
||||
|
||||
"""
|
||||
self.parent_remote.send(["step", action])
|
||||
|
||||
def toggle_log(self, log):
|
||||
self.parent_remote.send(["toggle_log", log])
|
||||
|
||||
def get_result(self,) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||
""" """
|
||||
obs, rew, done, info = self.parent_remote.recv()
|
||||
if self.share_memory:
|
||||
obs = self._decode_obs()
|
||||
return obs, rew, done, info
|
||||
|
||||
def seed(self, seed: Optional[int] = None) -> Optional[List[int]]:
|
||||
"""
|
||||
|
||||
:param seed: Optional[int]: (Default value = None)
|
||||
:param seed: Optional[int]: (Default value = None)
|
||||
:param seed: Optional[int]: (Default value = None)
|
||||
|
||||
"""
|
||||
self.parent_remote.send(["seed", seed])
|
||||
return self.parent_remote.recv()
|
||||
|
||||
def render(self, **kwargs: Any) -> Any:
|
||||
"""
|
||||
|
||||
:param **kwargs: Any:
|
||||
:param **kwargs: Any:
|
||||
|
||||
"""
|
||||
self.parent_remote.send(["render", kwargs])
|
||||
return self.parent_remote.recv()
|
||||
|
||||
def close_env(self) -> None:
|
||||
""" """
|
||||
try:
|
||||
self.parent_remote.send(["close", None])
|
||||
# mp may be deleted so it may raise AttributeError
|
||||
self.parent_remote.recv()
|
||||
self.process.join()
|
||||
except (BrokenPipeError, EOFError, AttributeError):
|
||||
pass
|
||||
# ensure the subproc is terminated
|
||||
self.process.terminate()
|
||||
|
||||
|
||||
class BaseVectorEnv(gym.Env):
|
||||
"""Base class for vectorized environments wrapper.
|
||||
Usage:
|
||||
::
|
||||
env_num = 8
|
||||
envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(env_num)])
|
||||
assert len(envs) == env_num
|
||||
It accepts a list of environment generators. In other words, an environment
|
||||
generator ``efn`` of a specific task means that ``efn()`` returns the
|
||||
environment of the given task, for example, ``gym.make(task)``.
|
||||
All of the VectorEnv must inherit :class:`~tianshou.env.BaseVectorEnv`.
|
||||
Here are some other usages:
|
||||
::
|
||||
envs.seed(2) # which is equal to the next line
|
||||
envs.seed([2, 3, 4, 5, 6, 7, 8, 9]) # set specific seed for each env
|
||||
obs = envs.reset() # reset all environments
|
||||
obs = envs.reset([0, 5, 7]) # reset 3 specific environments
|
||||
obs, rew, done, info = envs.step([1] * 8) # step synchronously
|
||||
envs.render() # render all environments
|
||||
envs.close() # close all environments
|
||||
.. warning::
|
||||
If you use your own environment, please make sure the ``seed`` method
|
||||
is set up properly, e.g.,
|
||||
::
|
||||
def seed(self, seed):
|
||||
np.random.seed(seed)
|
||||
Otherwise, the outputs of these envs may be the same with each other.
|
||||
|
||||
:param env_fns: a list of callable envs
|
||||
:param env:
|
||||
:param worker_fn: a callable worker
|
||||
:param worker: which contains the i
|
||||
:param int: wait_num
|
||||
:param env: step
|
||||
:param environments: to finish a step is time
|
||||
:param return: when
|
||||
:param simulation: in these environments
|
||||
:param is: disabled
|
||||
:param float: timeout
|
||||
:param vectorized: step it only deal with those environments spending time
|
||||
:param within: timeout
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
env_fns: List[Callable[[], gym.Env]],
|
||||
worker_fn: Callable[[Callable[[], gym.Env]], EnvWorker],
|
||||
sampler=None,
|
||||
testing: Optional[bool] = False,
|
||||
wait_num: Optional[int] = None,
|
||||
timeout: Optional[float] = None,
|
||||
) -> None:
|
||||
self._env_fns = env_fns
|
||||
# A VectorEnv contains a pool of EnvWorkers, which corresponds to
|
||||
# interact with the given envs (one worker <-> one env).
|
||||
self.workers = [worker_fn(fn) for fn in env_fns]
|
||||
self.worker_class = type(self.workers[0])
|
||||
assert issubclass(self.worker_class, EnvWorker)
|
||||
assert all([isinstance(w, self.worker_class) for w in self.workers])
|
||||
|
||||
self.env_num = len(env_fns)
|
||||
self.wait_num = wait_num or len(env_fns)
|
||||
assert 1 <= self.wait_num <= len(env_fns), f"wait_num should be in [1, {len(env_fns)}], but got {wait_num}"
|
||||
self.timeout = timeout
|
||||
assert self.timeout is None or self.timeout > 0, f"timeout is {timeout}, it should be positive if provided!"
|
||||
self.is_async = self.wait_num != len(env_fns) or timeout is not None or testing
|
||||
self.waiting_conn: List[EnvWorker] = []
|
||||
# environments in self.ready_id is actually ready
|
||||
# but environments in self.waiting_id are just waiting when checked,
|
||||
# and they may be ready now, but this is not known until we check it
|
||||
# in the step() function
|
||||
self.waiting_id: List[int] = []
|
||||
# all environments are ready in the beginning
|
||||
self.ready_id = list(range(self.env_num))
|
||||
self.is_closed = False
|
||||
self.sampler = sampler
|
||||
self.sample_obs = None
|
||||
|
||||
def _assert_is_not_closed(self) -> None:
|
||||
""" """
|
||||
assert not self.is_closed, f"Methods of {self.__class__.__name__} cannot be called after " "close."
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return len(self), which is the number of environments."""
|
||||
return self.env_num
|
||||
|
||||
def __getattribute__(self, key: str) -> Any:
|
||||
"""Switch the attribute getter depending on the key.
|
||||
Any class who inherits ``gym.Env`` will inherit some attributes, like
|
||||
``action_space``. However, we would like the attribute lookup to go
|
||||
straight into the worker (in fact, this vector env's action_space is
|
||||
always None).
|
||||
"""
|
||||
if key in [
|
||||
"metadata",
|
||||
"reward_range",
|
||||
"spec",
|
||||
"action_space",
|
||||
"observation_space",
|
||||
]: # reserved keys in gym.Env
|
||||
return self.__getattr__(key)
|
||||
else:
|
||||
return super().__getattribute__(key)
|
||||
|
||||
def __getattr__(self, key: str) -> List[Any]:
|
||||
"""Fetch a list of env attributes.
|
||||
This function tries to retrieve an attribute from each individual
|
||||
wrapped environment, if it does not belong to the wrapping vector
|
||||
environment class.
|
||||
"""
|
||||
return [getattr(worker, key) for worker in self.workers]
|
||||
|
||||
def _wrap_id(self, id: Optional[Union[int, List[int], np.ndarray]] = None) -> Union[List[int], np.ndarray]:
|
||||
"""
|
||||
|
||||
:param id: Optional[Union[int:
|
||||
:param List: int]:
|
||||
:param np: ndarray]]: (Default value = None)
|
||||
:param id: Optional[Union[int:
|
||||
:param List[int]:
|
||||
:param np.ndarray]]: (Default value = None)
|
||||
:param id: Optional[Union[int:
|
||||
|
||||
"""
|
||||
if id is None:
|
||||
id = list(range(self.env_num))
|
||||
elif np.isscalar(id):
|
||||
id = [id]
|
||||
return id
|
||||
|
||||
def _assert_id(self, id: List[int]) -> None:
|
||||
"""
|
||||
|
||||
:param id: List[int]:
|
||||
:param id: List[int]:
|
||||
:param id: List[int]:
|
||||
|
||||
"""
|
||||
for i in id:
|
||||
assert i not in self.waiting_id, f"Cannot interact with environment {i} which is stepping now."
|
||||
assert i in self.ready_id, f"Can only interact with ready environments {self.ready_id}."
|
||||
|
||||
def reset(self, id: Optional[Union[int, List[int], np.ndarray]] = None) -> np.ndarray:
|
||||
"""Reset the state of some envs and return initial observations.
|
||||
If id is None, reset the state of all the environments and return
|
||||
initial observations, otherwise reset the specific environments with
|
||||
the given id, either an int or a list.
|
||||
|
||||
:param id: Optional[Union[int:
|
||||
:param List: int]:
|
||||
:param np: ndarray]]: (Default value = None)
|
||||
:param id: Optional[Union[int:
|
||||
:param List[int]:
|
||||
:param np.ndarray]]: (Default value = None)
|
||||
:param id: Optional[Union[int:
|
||||
|
||||
"""
|
||||
start_time = time.time()
|
||||
self._assert_is_not_closed()
|
||||
id = self._wrap_id(id)
|
||||
if self.is_async:
|
||||
self._assert_id(id)
|
||||
obs = []
|
||||
stop_id = []
|
||||
for i in id:
|
||||
sample = self.sampler.sample()
|
||||
if sample is None:
|
||||
stop_id.append(i)
|
||||
else:
|
||||
self.workers[i].reset(sample)
|
||||
for i in id:
|
||||
if i in stop_id:
|
||||
obs.append(self.sample_obs)
|
||||
else:
|
||||
this_obs = self.workers[i].get_reset_result()
|
||||
if self.sample_obs is None:
|
||||
self.sample_obs = this_obs
|
||||
for j in range(len(obs)):
|
||||
if obs[j] is None:
|
||||
obs[j] = self.sample_obs
|
||||
obs.append(this_obs)
|
||||
|
||||
if len(obs) > 0:
|
||||
obs = np.stack(obs)
|
||||
# if len(stop_id)> 0:
|
||||
# obs_zero =
|
||||
# print(time.time() - start_timed)
|
||||
|
||||
return obs, stop_id
|
||||
|
||||
def toggle_log(self, log):
|
||||
for worker in self.workers:
|
||||
worker.toggle_log(log)
|
||||
|
||||
def reset_sampler(self):
|
||||
""" """
|
||||
self.sampler.reset()
|
||||
|
||||
def step(self, action: np.ndarray, id: Optional[Union[int, List[int], np.ndarray]] = None) -> List[np.ndarray]:
|
||||
"""Run one timestep of some environments' dynamics.
|
||||
If id is None, run one timestep of all the environments’ dynamics;
|
||||
otherwise run one timestep for some environments with given id, either
|
||||
an int or a list. When the end of episode is reached, you are
|
||||
responsible for calling reset(id) to reset this environment’s state.
|
||||
Accept a batch of action and return a tuple (batch_obs, batch_rew,
|
||||
batch_done, batch_info) in numpy format.
|
||||
|
||||
:param numpy: ndarray action: a batch of action provided by the agent.
|
||||
:param action: np.ndarray:
|
||||
:param id: Optional[Union[int:
|
||||
:param List: int]:
|
||||
:param np: ndarray]]: (Default value = None)
|
||||
:param action: np.ndarray:
|
||||
:param id: Optional[Union[int:
|
||||
:param List[int]:
|
||||
:param np.ndarray]]: (Default value = None)
|
||||
:param action: np.ndarray:
|
||||
:param id: Optional[Union[int:
|
||||
:rtype: A tuple including four items
|
||||
|
||||
"""
|
||||
self._assert_is_not_closed()
|
||||
id = self._wrap_id(id)
|
||||
if not self.is_async:
|
||||
assert len(action) == len(id)
|
||||
for i, j in enumerate(id):
|
||||
self.workers[j].send_action(action[i])
|
||||
result = []
|
||||
for j in id:
|
||||
obs, rew, done, info = self.workers[j].get_result()
|
||||
info["env_id"] = j
|
||||
result.append((obs, rew, done, info))
|
||||
else:
|
||||
if action is not None:
|
||||
self._assert_id(id)
|
||||
assert len(action) == len(id)
|
||||
for i, (act, env_id) in enumerate(zip(action, id)):
|
||||
self.workers[env_id].send_action(act)
|
||||
self.waiting_conn.append(self.workers[env_id])
|
||||
self.waiting_id.append(env_id)
|
||||
self.ready_id = [x for x in self.ready_id if x not in id]
|
||||
ready_conns: List[EnvWorker] = []
|
||||
while not ready_conns:
|
||||
ready_conns = self.worker_class.wait(self.waiting_conn, self.wait_num, self.timeout)
|
||||
result = []
|
||||
for conn in ready_conns:
|
||||
waiting_index = self.waiting_conn.index(conn)
|
||||
self.waiting_conn.pop(waiting_index)
|
||||
env_id = self.waiting_id.pop(waiting_index)
|
||||
obs, rew, done, info = conn.get_result()
|
||||
info["env_id"] = env_id
|
||||
result.append((obs, rew, done, info))
|
||||
self.ready_id.append(env_id)
|
||||
return list(map(np.stack, zip(*result)))
|
||||
|
||||
def seed(self, seed: Optional[Union[int, List[int]]] = None) -> List[Optional[List[int]]]:
|
||||
"""Set the seed for all environments.
|
||||
Accept ``None``, an int (which will extend ``i`` to
|
||||
``[i, i + 1, i + 2, ...]``) or a list.
|
||||
|
||||
:param seed: Optional[Union[int:
|
||||
:param List: int]]]: (Default value = None)
|
||||
:param seed: Optional[Union[int:
|
||||
:param List[int]]]: (Default value = None)
|
||||
:param seed: Optional[Union[int:
|
||||
:returns: The list of seeds used in this env's random number generators.
|
||||
The first value in the list should be the "main" seed, or the value
|
||||
which a reproducer pass to "seed".
|
||||
|
||||
"""
|
||||
self._assert_is_not_closed()
|
||||
seed_list: Union[List[None], List[int]]
|
||||
if seed is None:
|
||||
seed_list = [seed] * self.env_num
|
||||
elif isinstance(seed, int):
|
||||
seed_list = [seed + i for i in range(self.env_num)]
|
||||
else:
|
||||
seed_list = seed
|
||||
return [w.seed(s) for w, s in zip(self.workers, seed_list)]
|
||||
|
||||
def render(self, **kwargs: Any) -> List[Any]:
|
||||
"""Render all of the environments.
|
||||
|
||||
:param **kwargs: Any:
|
||||
:param **kwargs: Any:
|
||||
|
||||
"""
|
||||
self._assert_is_not_closed()
|
||||
if self.is_async and len(self.waiting_id) > 0:
|
||||
raise RuntimeError(f"Environments {self.waiting_id} are still stepping, cannot " "render them now.")
|
||||
return [w.render(**kwargs) for w in self.workers]
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close all of the environments.
|
||||
This function will be called only once (if not, it will be called
|
||||
during garbage collected). This way, ``close`` of all workers can be
|
||||
assured.
|
||||
|
||||
|
||||
"""
|
||||
self._assert_is_not_closed()
|
||||
for w in self.workers:
|
||||
w.close()
|
||||
self.is_closed = True
|
||||
|
||||
def __del__(self) -> None:
|
||||
"""Redirect to self.close()."""
|
||||
if not self.is_closed:
|
||||
self.close()
|
||||
|
||||
|
||||
class SubprocVectorEnv(BaseVectorEnv):
|
||||
"""Vectorized environment wrapper based on subprocess.
|
||||
.. seealso::
|
||||
Please refer to :class:`~tianshou.env.BaseVectorEnv` for more detailed
|
||||
explanation.
|
||||
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
env_fns: List[Callable[[], gym.Env]],
|
||||
sampler=None,
|
||||
testing=False,
|
||||
wait_num: Optional[int] = None,
|
||||
timeout: Optional[float] = None,
|
||||
) -> None:
|
||||
def worker_fn(fn: Callable[[], gym.Env]) -> SubprocEnvWorker:
|
||||
"""
|
||||
|
||||
:param fn: Callable[[]:
|
||||
:param gym: Env]:
|
||||
:param fn: Callable[[]:
|
||||
:param gym.Env]:
|
||||
:param fn: Callable[[]:
|
||||
|
||||
"""
|
||||
return SubprocEnvWorker(fn, share_memory=False)
|
||||
|
||||
super().__init__(env_fns, worker_fn, sampler, testing, wait_num=wait_num, timeout=timeout)
|
||||
|
||||
|
||||
class ShmemVectorEnv(BaseVectorEnv):
|
||||
"""Optimized SubprocVectorEnv with shared buffers to exchange observations.
|
||||
ShmemVectorEnv has exactly the same API as SubprocVectorEnv.
|
||||
.. seealso::
|
||||
Please refer to :class:`~tianshou.env.SubprocVectorEnv` for more
|
||||
detailed explanation.
|
||||
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
env_fns: List[Callable[[], gym.Env]],
|
||||
sampler=None,
|
||||
testing=False,
|
||||
wait_num: Optional[int] = None,
|
||||
timeout: Optional[float] = None,
|
||||
) -> None:
|
||||
def worker_fn(fn: Callable[[], gym.Env]) -> SubprocEnvWorker:
|
||||
"""
|
||||
|
||||
:param fn: Callable[[]:
|
||||
:param gym: Env]:
|
||||
:param fn: Callable[[]:
|
||||
:param gym.Env]:
|
||||
:param fn: Callable[[]:
|
||||
|
||||
"""
|
||||
return SubprocEnvWorker(fn, share_memory=True)
|
||||
|
||||
super().__init__(env_fns, worker_fn, sampler, testing, wait_num=wait_num, timeout=timeout)
|
||||
@@ -17,7 +17,7 @@ from qlib.contrib.evaluate import (
|
||||
from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.record_temp import SignalRecord, PortAnaRecord
|
||||
|
||||
from qlib.tests.data import GetData
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -25,9 +25,6 @@ if __name__ == "__main__":
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
|
||||
if not exists_qlib_data(provider_uri):
|
||||
print(f"Qlib data is not found in {provider_uri}")
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
|
||||
from get_data import GetData
|
||||
|
||||
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
|
||||
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
@@ -98,6 +95,7 @@ if __name__ == "__main__":
|
||||
"open_cost": 0.0005,
|
||||
"close_cost": 0.0015,
|
||||
"min_cost": 5,
|
||||
"return_order": True,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -105,6 +103,11 @@ if __name__ == "__main__":
|
||||
model = init_instance_by_config(task["model"])
|
||||
dataset = init_instance_by_config(task["dataset"])
|
||||
|
||||
# NOTE: This line is optional
|
||||
# It demonstrates that the dataset can be used standalone.
|
||||
example_df = dataset.prepare("train")
|
||||
print(example_df.head())
|
||||
|
||||
# start exp
|
||||
with R.start(experiment_name="workflow"):
|
||||
R.log_params(**flatten_dict(task))
|
||||
|
||||
@@ -2,91 +2,49 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
__version__ = "0.6.0"
|
||||
__version__ = "0.6.1.99"
|
||||
|
||||
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import copy
|
||||
import yaml
|
||||
import logging
|
||||
import platform
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
from .utils import can_use_cache, init_instance_by_config, get_module_by_module_path
|
||||
from .workflow.utils import experiment_exit_handler
|
||||
|
||||
# init qlib
|
||||
def init(default_conf="client", **kwargs):
|
||||
from .config import C, REG_CN, REG_US, QlibConfig
|
||||
from .data.data import register_all_wrappers
|
||||
from .log import get_module_logger, set_log_with_config
|
||||
from .config import C
|
||||
from .log import get_module_logger
|
||||
from .data.cache import H
|
||||
from .workflow import R, QlibRecorder
|
||||
|
||||
C.reset()
|
||||
H.clear()
|
||||
|
||||
_logging_config = C.logging_config
|
||||
if "logging_config" in kwargs:
|
||||
_logging_config = kwargs["logging_config"]
|
||||
|
||||
# set global config
|
||||
if _logging_config:
|
||||
set_log_with_config(_logging_config)
|
||||
|
||||
# FIXME: this logger ignored the level in config
|
||||
LOG = get_module_logger("Initialization", level=logging.INFO)
|
||||
LOG.info(f"default_conf: {default_conf}.")
|
||||
logger = get_module_logger("Initialization", level=logging.INFO)
|
||||
|
||||
C.set_mode(default_conf)
|
||||
C.set_region(kwargs.get("region", C["region"] if "region" in C else REG_CN))
|
||||
|
||||
for k, v in kwargs.items():
|
||||
C[k] = v
|
||||
if k not in C:
|
||||
LOG.warning("Unrecognized config %s" % k)
|
||||
|
||||
C.resolve_path()
|
||||
|
||||
if not (C["expression_cache"] is None and C["dataset_cache"] is None):
|
||||
# check redis
|
||||
if not can_use_cache():
|
||||
LOG.warning(
|
||||
f"redis connection failed(host={C['redis_host']} port={C['redis_port']}), cache will not be used!"
|
||||
)
|
||||
C["expression_cache"] = None
|
||||
C["dataset_cache"] = None
|
||||
C.set(default_conf, **kwargs)
|
||||
|
||||
# check path if server/local
|
||||
if C.get_uri_type() == QlibConfig.LOCAL_URI:
|
||||
if C.get_uri_type() == C.LOCAL_URI:
|
||||
if not os.path.exists(C["provider_uri"]):
|
||||
if C["auto_mount"]:
|
||||
LOG.error(
|
||||
logger.error(
|
||||
f"Invalid provider uri: {C['provider_uri']}, please check if a valid provider uri has been set. This path does not exist."
|
||||
)
|
||||
else:
|
||||
LOG.warning(f"auto_path is False, please make sure {C['mount_path']} is mounted")
|
||||
elif C.get_uri_type() == QlibConfig.NFS_URI:
|
||||
logger.warning(f"auto_path is False, please make sure {C['mount_path']} is mounted")
|
||||
elif C.get_uri_type() == C.NFS_URI:
|
||||
_mount_nfs_uri(C)
|
||||
else:
|
||||
raise NotImplementedError(f"This type of URI is not supported")
|
||||
|
||||
LOG.info("qlib successfully initialized based on %s settings." % default_conf)
|
||||
register_all_wrappers()
|
||||
|
||||
LOG.info(f"data_path={C.get_data_path()}")
|
||||
C.register()
|
||||
|
||||
if "flask_server" in C:
|
||||
LOG.info(f"flask_server={C['flask_server']}, flask_port={C['flask_port']}")
|
||||
|
||||
# set up QlibRecorder
|
||||
exp_manager = init_instance_by_config(C["exp_manager"])
|
||||
qr = QlibRecorder(exp_manager)
|
||||
R.register(qr)
|
||||
# clean up experiment when python program ends
|
||||
experiment_exit_handler()
|
||||
logger.info(f"flask_server={C['flask_server']}, flask_port={C['flask_port']}")
|
||||
logger.info("qlib successfully initialized based on %s settings." % default_conf)
|
||||
logger.info(f"data_path={C.get_data_path()}")
|
||||
|
||||
|
||||
def _mount_nfs_uri(C):
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user