1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-29 00:51:19 +08:00

Compare commits

..

27 Commits

Author SHA1 Message Date
you-n-g
670ae6aa61 Update test_qlib_from_source.yml 2022-06-29 17:07:05 +08:00
you-n-g
7e3ca3c5f4 Update test_qlib_from_pip.yml 2022-06-29 17:02:57 +08:00
you-n-g
ab0174a363 Update test_qlib_from_source.yml 2022-06-29 17:01:51 +08:00
you-n-g
8c72ed99c2 Update test_qlib_from_source.yml 2022-06-29 17:00:35 +08:00
you-n-g
b9624b074f Update test_qlib_from_source_slow.yml 2022-06-29 13:38:36 +08:00
Huoran Li
23c657a7a2 Backtest Mypy (#1130)
* Done

* Fix test errors

* Revert profit_attribution.py

* Minor

* A minor update on collect_data type hint

* Resolve PR comments

* Use black to format code

* Fix CI errors
2022-06-28 22:16:46 +08:00
you-n-g
9bf3423a64 Auto log uncommmitted code (#1167)
* Auto log uncommmitted code

* Support set record name & trainer;

* Update recorder.py
2022-06-28 19:53:21 +08:00
Yuge Zhang
25ecb1135f Qlib RL framework (stage 2) - trainer (#1125)
* checkpoint

(cherry picked from commit 1a8e0bd4671ee6d624a7d09bb198a273282cd050)

* Not a workable version

(cherry picked from commit 3498e185684cd5590d3ab97e0ab69eab8c1e0e3a)

* vessel

* ckpt

* .

* vessel

* .

* .

* checkpoint callback

* .

* cleanup

* logger

* .

* test

* .

* add test

* .

* .

* .

* .

* New reward

* Add train API

* fix mypy

* fix lint

* More comment

* 3.7 compat

* fix test

* fix test

* .

* Resolve comments

* fix typehint
2022-06-28 19:53:05 +08:00
Linlang
2ca0d88d2d change_pitdata_source (#1171)
* change_pitdata_source

* retain_normalize

* add_comment
2022-06-28 16:29:59 +08:00
Linlang
50d74b5560 split_CI (#1141) 2022-06-28 10:17:29 +08:00
you-n-g
a87b02619a Qlib dev doc (#1142) 2022-06-21 09:46:30 +08:00
you-n-g
da676a20a2 Add time limit for CI (#1127)
* Add time limit for CI

* Update test_macos.yml
2022-06-16 16:35:20 +08:00
you-n-g
13d904d9a9 Update Version To Dev 2022-06-15 14:53:54 +08:00
Young
36950b905d Update Qlib Version 2022-06-15 14:48:54 +08:00
you-n-g
58540f76ee Csi500 example (#1126)
* Stage code

* Update results and scripts
2022-06-15 10:18:13 +08:00
YaOzI
3e6e2865ce Fixed a few mixed Chinese punctuation typos (#1123) 2022-06-14 20:12:14 +08:00
you-n-g
3fcbaa33fa Fix hist_ref in update.py (#1096)
* Fix hist_ref in update.py

* Update setup.py
2022-06-14 11:59:43 +08:00
you-n-g
50409ff17b Add log info for ensemble (#1113)
* Add log info for ensemble

* Update ensemble.py

* Update setup.py
2022-06-14 11:58:57 +08:00
you-n-g
afcea404a5 opt local trainer (better mem releasing) (#1116)
* opt local trainer (better mem releasing)

* Update setup.py

* Update data.py

* fix CI
2022-06-14 11:58:39 +08:00
you-n-g
e24ef67663 Update README.md 2022-06-14 10:53:09 +08:00
you-n-g
2d5eecb9a2 Update README.md 2022-06-14 10:52:50 +08:00
Huoran Li
89972f6c6f Refine backtest codes (#1120)
* Refine backtest code

* Keep working

* Minor

* Resolve PR comments

* Fix import error

* Fix import error
2022-06-10 12:14:48 +08:00
Linlang
1ef8e61abd fix_pylint_for_CI (#1119)
* fix_pylint_for_CI

* reformat_with_black

* fix_pylint_C3001

* fix_flake8_error
2022-06-09 16:12:33 +08:00
you-n-g
1a4114b683 Add explanation for the evalution metrics of Qlib (#1090)
* Add explanation for the evalution metrics of Qlib

* Update evaluate.py
2022-05-31 19:37:55 +08:00
Linlang
e874ef2bc1 change_datasource (#1109)
* change_datasource

* split_test_data_and_complete_data

* fix_CI
2022-05-31 19:35:49 +08:00
Huoran Li
14b2b355a7 Update .gitignore (#1110) 2022-05-30 21:27:49 +08:00
Huoran Li
64fadff218 Add .idea/ into gitignore (#1108) 2022-05-25 13:59:35 +08:00
78 changed files with 3136 additions and 1097 deletions

View File

@@ -1,94 +0,0 @@
# There are some issues (in the downloading data phase) on MacOS when running with other tests. So we split it into an individual config.
name: Test MacOS
on:
push:
branches: [ main ]
pull_request:
branches: [ main ]
jobs:
build:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [macos-11, macos-latest]
# not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129
python-version: [3.7, 3.8]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Lint with Black
run: |
cd ..
python -m pip install pip --upgrade
python -m pip install wheel --upgrade
python -m pip install black
python -m black qlib -l 120 --check --diff
# Test Qlib installed with pip
- name: Check Qlib with flake8
run: |
pip install --upgrade pip
pip install flake8
flake8 --ignore=E501,F541,E266,E402,W503,E731,E203 --per-file-ignores="__init__.py:F401,F403" qlib
- name: Install Qlib with pip
run: |
python -m pip install numpy==1.19.5
python -m pip install pyqlib --ignore-installed ruamel.yaml numpy
- name: Make html with sphnix
run: |
pip install -U sphinx
pip install sphinx_rtd_theme readthedocs_sphinx_ext
pip install --exists-action=w --no-cache-dir -r docs/requirements.txt
cd docs
sphinx-build -b html . build
cd ..
- name: Install Lightgbm for MacOS
run: |
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Microsoft/qlib/main/.github/brew_install.sh)"
HOMEBREW_NO_AUTO_UPDATE=1 brew install lightgbm
# FIX MacOS error: Segmentation fault
# reference: https://github.com/microsoft/LightGBM/issues/4229
wget https://raw.githubusercontent.com/Homebrew/homebrew-core/fb8323f2b170bd4ae97e1bac9bf3e2983af3fdb0/Formula/libomp.rb
brew unlink libomp
brew install libomp.rb
- name: Test data downloads
run: |
python scripts/get_data.py qlib_data --name qlib_data_simple --target_dir ~/.qlib/qlib_data/cn_data_simple --interval 1d --region cn
python -c "import os; userpath=os.path.expanduser('~'); os.rename(userpath + '/.qlib/qlib_data/cn_data_simple', userpath + '/.qlib/qlib_data/cn_data')"
azcopy copy https://qlibpublic.blob.core.windows.net/data /tmp/qlibpublic --recursive
mv /tmp/qlibpublic/data tests/.data
- name: Test workflow by config (install from pip)
run: |
python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
python -m pip uninstall -y pyqlib
# Test Qlib installed from source
- name: Install Qlib from source
run: |
python -m pip install --upgrade cython
python -m pip install numpy jupyter jupyter_contrib_nbextensions
python -m pip install -U scipy scikit-learn # installing without this line will cause errors on GitHub Actions, while instsalling locally won't
python -m pip install gym tianshou torch
pip install -e .
- name: Install test dependencies
run: |
python -m pip install --upgrade pip
python -m pip install -U pyopenssl idna
python -m pip install black pytest
- name: Unit tests with Pytest
run: |
pip install -r scripts/data_collector/pit/requirements.txt
cd tests
python -m pytest . --durations=0
- name: Test workflow by config (install from source)
run: |
python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml

View File

@@ -0,0 +1,57 @@
name: Test qlib from pip
on:
push:
branches: [ main ]
pull_request:
branches: [ main ]
jobs:
build:
timeout-minutes: 120
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [windows-latest, ubuntu-18.04, ubuntu-20.04, macos-11, macos-latest]
# not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129
python-version: [3.7, 3.8]
steps:
- name: Test qlib from pip
uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Update pip to the latest version
run: |
python -m pip install --upgrade pip
- name: Qlib installation test
run: |
python -m pip install pyqlib
# Specify the numpy version because the numpy upgrade caused the CI test to fail,
# and this line of code will be removed when the next version of qlib is released.
python -m pip install "numpy<1.23"
- name: Install Lightgbm for MacOS
if: ${{ matrix.os == 'macos-11' || matrix.os == 'macos-latest' }}
run: |
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Microsoft/qlib/main/.github/brew_install.sh)"
HOMEBREW_NO_AUTO_UPDATE=1 brew install lightgbm
# FIX MacOS error: Segmentation fault
# reference: https://github.com/microsoft/LightGBM/issues/4229
wget https://raw.githubusercontent.com/Homebrew/homebrew-core/fb8323f2b170bd4ae97e1bac9bf3e2983af3fdb0/Formula/libomp.rb
brew unlink libomp
brew install libomp.rb
- name: Downloads dependencies data
run: |
python scripts/get_data.py qlib_data --name qlib_data_simple --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
- name: Test workflow by config
run: |
qrun examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml

View File

@@ -1,4 +1,4 @@
name: Test
name: Test qlib from source
on:
push:
@@ -8,42 +8,60 @@ on:
jobs:
build:
timeout-minutes: 120
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [windows-latest, ubuntu-18.04, ubuntu-20.04]
os: [windows-latest, ubuntu-18.04, ubuntu-20.04, macos-11, macos-latest]
# not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129
python-version: [3.7, 3.8]
steps:
- uses: actions/checkout@v2
- name: Test qlib from source
uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Update pip to the latest version
run: |
python -m pip install --upgrade pip
- name: Installing pytorch for macos
if: ${{ matrix.os == 'macos-11' || matrix.os == 'macos-latest' }}
run: |
python -m pip install torch torchvision torchaudio
- name: Installing pytorch for ubuntu
if: ${{ matrix.os == 'ubuntu-18.04' || matrix.os == 'ubuntu-20.04' }}
run: |
python -m pip install --upgrade pip
python -m pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
- name: Installing pytorch for windows
if: ${{ matrix.os == 'windows-latest' }}
run: |
python -m pip install --upgrade pip
python -m pip install torch torchvision torchaudio
- name: Set up Python tools
run: |
python -m pip install --upgrade cython
python -m pip install -e .[dev]
- name: Lint with Black
run: |
pip install --upgrade pip
pip install black wheel
black qlib -l 120 --check --diff
- name: Install Qlib with pip
run: |
pip install numpy==1.19.5 ruamel.yaml
pip install pyqlib --ignore-installed
black . -l 120 --check --diff
- name: Make html with sphinx
run: |
pip install -U sphinx
pip install sphinx_rtd_theme readthedocs_sphinx_ext
pip install --exists-action=w --no-cache-dir -r docs/requirements.txt
cd docs
sphinx-build -b html . build
cd ..
# Check Qlib with pylint
# TODO: These problems we will solve in the future. Important among them are: W0221, W0223, W0237, E1102
# C0103: invalid-name
@@ -67,12 +85,10 @@ jobs:
# W1309: f-string-without-interpolation
# E1102: not-callable
# E1136: unsubscriptable-object
# References for parameters: https://github.com/PyCQA/pylint/issues/4577#issuecomment-1000245962
# References for parameters: https://github.com/PyCQA/pylint/issues/4577#issuecomment-1000245962
- name: Check Qlib with pylint
run: |
pip install --upgrade pip
pip install pylint
pylint --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0201,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0612,W0621,W0622,W0703,W1309,E1102,E1136 --const-rgx='[a-z_][a-z0-9_]{2,30}$' qlib --init-hook "import astroid; astroid.context.InferenceContext.max_inferred = 500"
pylint --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0612,W0621,W0622,W0703,W1309,E1102,E1136 --const-rgx='[a-z_][a-z0-9_]{2,30}$' qlib --init-hook "import astroid; astroid.context.InferenceContext.max_inferred = 500"
# The following flake8 error codes were ignored:
# E501 line too long
@@ -95,47 +111,40 @@ jobs:
# Description: If there is whitespace before ":", it cannot pass the black check.
- name: Check Qlib with flake8
run: |
pip install --upgrade pip
pip install flake8
flake8 --ignore=E501,F541,E266,E402,W503,E731,E203 --per-file-ignores="__init__.py:F401,F403" qlib
# https://github.com/python/mypy/issues/10600
- name: Check Qlib with mypy
run: |
pip install mypy
mypy qlib --install-types --non-interactive || true
mypy qlib
mypy qlib --verbose
- name: Test data downloads
run: |
python scripts/get_data.py qlib_data --name qlib_data_simple --target_dir ~/.qlib/qlib_data/cn_data_simple --interval 1d --region cn
python -c "import os; userpath=os.path.expanduser('~'); os.rename(userpath + '/.qlib/qlib_data/cn_data_simple', userpath + '/.qlib/qlib_data/cn_data')"
azcopy copy https://qlibpublic.blob.core.windows.net/data /tmp/qlibpublic --recursive
python scripts/get_data.py qlib_data --name qlib_data_simple --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
azcopy copy https://qlibpublic.blob.core.windows.net/data/rl /tmp/qlibpublic/data --recursive
mv /tmp/qlibpublic/data tests/.data
- name: Test workflow by config (install from pip)
- name: Install Lightgbm for MacOS
if: ${{ matrix.os == 'macos-11' || matrix.os == 'macos-latest' }}
run: |
python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
python -m pip uninstall -y pyqlib
# Test Qlib installed from source
- name: Install Qlib from source
run: |
pip install --upgrade cython jupyter jupyter_contrib_nbextensions numpy scipy scikit-learn # installing without this line will cause errors on GitHub Actions, while instsalling locally won't
pip install gym tianshou torch
pip install -e .
- name: Install test dependencies
run: |
pip install --upgrade pip
pip install black pytest
- name: Unit tests with Pytest
run: |
pip install -r scripts/data_collector/pit/requirements.txt
cd tests
python -m pytest . --durations=10
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Microsoft/qlib/main/.github/brew_install.sh)"
HOMEBREW_NO_AUTO_UPDATE=1 brew install lightgbm
# FIX MacOS error: Segmentation fault
# reference: https://github.com/microsoft/LightGBM/issues/4229
wget https://raw.githubusercontent.com/Homebrew/homebrew-core/fb8323f2b170bd4ae97e1bac9bf3e2983af3fdb0/Formula/libomp.rb
brew unlink libomp
brew install libomp.rb
- name: Test workflow by config (install from source)
run: |
# Version 0.52.0 of numba must be installed manually in CI, otherwise it will cause incompatibility with the latest version of numpy.
python -m pip install numba==0.52.0
# You must update numpy manually, because when installing python tools, it will try to uninstall numpy and cause CI to fail.
python -m pip install --upgrade numpy
python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
- name: Unit tests with Pytest
run: |
cd tests
python -m pytest . -m "not slow" --durations=0

View File

@@ -0,0 +1,56 @@
name: Test qlib from source slow
on:
push:
branches: [ main ]
pull_request:
branches: [ main ]
jobs:
build:
timeout-minutes: 120
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [windows-latest, ubuntu-18.04, ubuntu-20.04, macos-11, macos-latest]
# not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129
python-version: [3.7, 3.8]
steps:
- name: Test qlib from source slow
uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Set up Python tools
run: |
pip install --upgrade cython numpy pip
pip install -e .[dev]
- name: Downloads dependencies data
run: |
python scripts/get_data.py qlib_data --name qlib_data_simple --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
- name: Install Lightgbm for MacOS
if: ${{ matrix.os == 'macos-11' || matrix.os == 'macos-latest' }}
run: |
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Microsoft/qlib/main/.github/brew_install.sh)"
HOMEBREW_NO_AUTO_UPDATE=1 brew install lightgbm
# FIX MacOS error: Segmentation fault
# reference: https://github.com/microsoft/LightGBM/issues/4229
wget https://raw.githubusercontent.com/Homebrew/homebrew-core/fb8323f2b170bd4ae97e1bac9bf3e2983af3fdb0/Formula/libomp.rb
brew unlink libomp
brew install libomp.rb
- name: Unit tests with Pytest
uses: nick-fields/retry@v2
with:
timeout_minutes: 120
max_attempts: 3
command: |
cd tests
python -m pytest . -m "slow" --durations=0

1
.gitignore vendored
View File

@@ -44,3 +44,4 @@ tags
*.swp
./pretrain
.idea/

View File

@@ -1,6 +1,6 @@
[mypy]
exclude = (?x)(
^qlib/backtest
^qlib/backtest/high_performance_ds\.py$
| ^qlib/contrib
| ^qlib/data
| ^qlib/model

View File

@@ -458,7 +458,7 @@ Before we released Qlib as an open-source project on Github in Sep 2020, Qlib is
This project welcomes contributions and suggestions.
**Here are some
[code standards](docs/developer/code_standard.rst) for submiting a pull request.**
[code standards and development guidance](docs/developer/code_standard_and_dev_guide.rst) for submiting a pull request.**
Making contributions is not a hard thing. Solving an issue(maybe just answering a question raised in [issues list](https://github.com/microsoft/qlib/issues) or [gitter](https://gitter.im/Microsoft/qlib)), fixing/issuing a bug, improving the documents and even fixing a typo are important contributions to Qlib.

View File

@@ -66,7 +66,7 @@ TopkDropoutStrategy
- Adopt the ``Topk-Drop`` algorithm to calculate the target amount of each stock
.. note::
There are two parameters for the ``Topk-Drop`` algorithm
There are two parameters for the ``Topk-Drop`` algorithm:
- `Topk`: The number of stocks held
- `Drop`: The number of stocks sold on each trading day

View File

@@ -45,4 +45,16 @@ When you submit a PR request, you can check whether your code passes the CI test
.. code-block:: bash
pip install -e .[dev]
pre-commit install
pre-commit install
=================================
Development Guidance
=================================
As a developer, you often want make changes to `Qlib` and hope it would reflect directly in your environment without reinstalling it. You can install `Qlib` in editable mode with following command.
The `[dev]` option will help you to install some related packages when developing `Qlib` (e.g. pytest, sphinx)
.. code-block:: bash
pip install -e .[dev]

View File

@@ -1,3 +1,3 @@
pandas==1.1.2
numpy==1.21.0
lightgbm==3.1.0
lightgbm

View File

@@ -0,0 +1,72 @@
qlib_init:
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
market: &market csi500
benchmark: &benchmark SH000905
data_handler_config: &data_handler_config
start_time: 2008-01-01
end_time: 2020-08-01
fit_start_time: 2008-01-01
fit_end_time: 2014-12-31
instruments: *market
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
topk: 50
n_drop: 5
backtest:
start_time: 2017-01-01
end_time: 2020-08-01
account: 100000000
benchmark: *benchmark
exchange_kwargs:
limit_threshold: 0.095
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
task:
model:
class: LGBModel
module_path: qlib.contrib.model.gbdt
kwargs:
loss: mse
colsample_bytree: 0.8879
learning_rate: 0.2
subsample: 0.8789
lambda_l1: 205.6999
lambda_l2: 580.9768
max_depth: 8
num_leaves: 210
num_threads: 20
dataset:
class: DatasetH
module_path: qlib.data.dataset
kwargs:
handler:
class: Alpha158
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:
train: [2008-01-01, 2014-12-31]
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs:
model: <MODEL>
dataset: <DATASET>
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
ana_long_short: False
ann_scaler: 252
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config

View File

@@ -0,0 +1,80 @@
qlib_init:
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
market: &market csi500
benchmark: &benchmark SH000905
data_handler_config: &data_handler_config
start_time: 2008-01-01
end_time: 2020-08-01
fit_start_time: 2008-01-01
fit_end_time: 2014-12-31
instruments: *market
infer_processors: []
learn_processors:
- class: DropnaLabel
- class: CSRankNorm
kwargs:
fields_group: label
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:
start_time: 2017-01-01
end_time: 2020-08-01
account: 100000000
benchmark: *benchmark
exchange_kwargs:
limit_threshold: 0.095
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
task:
model:
class: LGBModel
module_path: qlib.contrib.model.gbdt
kwargs:
loss: mse
colsample_bytree: 0.8879
learning_rate: 0.0421
subsample: 0.8789
lambda_l1: 205.6999
lambda_l2: 580.9768
max_depth: 8
num_leaves: 210
num_threads: 20
dataset:
class: DatasetH
module_path: qlib.data.dataset
kwargs:
handler:
class: Alpha360
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:
train: [2008-01-01, 2014-12-31]
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs:
model: <MODEL>
dataset: <DATASET>
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
ana_long_short: False
ann_scaler: 252
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config

View File

@@ -20,7 +20,9 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
> NOTE:
> We have very limited resources to implement and finetune the models. We tried our best effort to fairly compare these models. But some models may have greater potential than what it looks like in the table below. Your contribution is highly welcomed to explore their potential.
## Alpha158 dataset
## Results on CSI300
### Alpha158 dataset
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |
|------------------------------------------|-------------------------------------|-------------|-------------|-------------|-------------|-------------------|-------------------|--------------|
@@ -44,7 +46,7 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
| DoubleEnsemble(Chuheng Zhang, et al.) | Alpha158 | 0.0544±0.00 | 0.4340±0.00 | 0.0523±0.00 | 0.4284±0.01 | 0.1168±0.01 | 1.3384±0.12 | -0.1036±0.01 |
## Alpha360 dataset
### Alpha360 dataset
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |
|-------------------------------------------|----------|-------------|-------------|-------------|-------------|-------------------|-------------------|--------------|
@@ -79,6 +81,38 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
- Signal-based evaluation: IC, ICIR, Rank IC, Rank ICIR
- Portfolio-based metrics: Annualized Return, Information Ratio, Max Drawdown
## Results on CSI500
The results on CSI500 is not complete. PR's for models on csi500 are welcome!
Transfer previous models in CSI300 to CSI500 is quite easy. You can try models with just a few commands below.
```
cd examples/benchmarks/LightGBM
pip install -r requirements.txt
# create new config and set the benchmark to csi500
cp workflow_config_lightgbm_Alpha158.yaml workflow_config_lightgbm_Alpha158_csi500.yaml
sed -i "s/csi300/csi500/g" workflow_config_lightgbm_Alpha158_csi500.yaml
sed -i "s/SH000300/SH000905/g" workflow_config_lightgbm_Alpha158_csi500.yaml
# you can either run the model once
qrun workflow_config_lightgbm_Alpha158_csi500.yaml
# or run it for multiple times automatically and get the summarized results.
cd ../../
python run_all_model.py run 3 lightgbm Alpha158 csi500 # for models with randomness. please run it for 20 times.
```
### Alpha158 dataset
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |
|------------|----------|-------------|-------------|-------------|-------------|-------------------|-------------------|--------------|
| LightGBM | Alpha158 | 0.0377±0.00 | 0.3860±0.00 | 0.0448±0.00 | 0.4675±0.00 | 0.1151±0.00 | 1.3884±0.00 | -0.0898±0.00 |
### Alpha360 dataset
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |
|------------|----------|-------------|-------------|-------------|-------------|-------------------|-------------------|--------------|
| LightGBM | Alpha360 | 0.0400±0.00 | 0.3605±0.00 | 0.0536±0.00 | 0.5431±0.00 | 0.0505±0.00 | 0.7658±0.02 | -0.1880±0.00 |
# Contributing

View File

@@ -28,6 +28,8 @@ The default forecasting models are `Linear`. Users can choose other forecasting
The results of related methods in Qlib's public dataset can be found [here](../)
# Requirements
Here is the minimal hardware requirements to run the ``workflow.py`` of DDG-DA.
Here are the minimal hardware requirements to run the ``workflow.py`` of DDG-DA.
* Memory: 45G
* Disk: 4G
Pytorch with CPU & RAM will be enough for this example.

View File

@@ -117,8 +117,10 @@ def get_all_folders(models, exclude) -> dict:
# function to get all the files under the model folder
def get_all_files(folder_path, dataset) -> (str, str):
yaml_path = str(Path(f"{folder_path}") / f"*{dataset}*.yaml")
def get_all_files(folder_path, dataset, universe="") -> (str, str):
if universe != "":
universe = f"_{universe}"
yaml_path = str(Path(f"{folder_path}") / f"*{dataset}{universe}.yaml")
req_path = str(Path(f"{folder_path}") / f"*.txt")
yaml_file = glob.glob(yaml_path)
req_file = glob.glob(req_path)
@@ -224,6 +226,7 @@ class ModelRunner:
times=1,
models=None,
dataset="Alpha360",
universe="",
exclude=False,
qlib_uri: str = "git+https://github.com/microsoft/qlib#egg=pyqlib",
exp_folder_name: str = "run_all_model_records",
@@ -245,6 +248,9 @@ class ModelRunner:
determines whether the model being used is excluded or included.
dataset : str
determines the dataset to be used for each model.
universe : str
the stock universe of the dataset.
default "" indicates that
qlib_uri : str
the uri to install qlib with pip
it could be url on the we or local path (NOTE: the local path must be a absolute path)
@@ -259,6 +265,15 @@ class ModelRunner:
-------
Here are some use cases of the function in the bash:
The run_all_models will decide which config to run based no `models` `dataset` `universe`
Example 1):
models="lightgbm", dataset="Alpha158", universe="" will result in running the following config
examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
models="lightgbm", dataset="Alpha158", universe="csi500" will result in running the following config
examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158_csi500.yaml
.. code-block:: bash
# Case 1 - run all models multiple times
@@ -279,6 +294,9 @@ class ModelRunner:
# Case 6 - run other models except those are given as arguments for one time
python run_all_model.py run --models=[mlp,tft,sfm] --exclude=True
# Case 7 - run lightgbm model on csi500.
python run_all_model.py run 3 lightgbm Alpha158 csi500
"""
self._init_qlib(exp_folder_name)
@@ -290,7 +308,7 @@ class ModelRunner:
for fn in folders:
# get all files
sys.stderr.write("Retrieving files...\n")
yaml_path, req_path = get_all_files(folders[fn], dataset)
yaml_path, req_path = get_all_files(folders[fn], dataset, universe=universe)
if yaml_path is None:
sys.stderr.write(f"There is no {dataset}.yaml file in {folders[fn]}")
continue

View File

@@ -2,7 +2,7 @@
# Licensed under the MIT License.
from pathlib import Path
__version__ = "0.8.5.99"
__version__ = "0.8.6.99"
__version__bak = __version__ # This version is backup for QlibConfig.reset_qlib_version
import os
from typing import Union

View File

@@ -2,24 +2,28 @@
# Licensed under the MIT License.
from __future__ import annotations
import copy
from typing import List, Tuple, Union, TYPE_CHECKING
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generator, List, Optional, Tuple, Union
import pandas as pd
from .account import Account
from .report import Indicator, PortfolioMetrics
if TYPE_CHECKING:
from ..strategy.base import BaseStrategy
from .executor import BaseExecutor
from .decision import BaseTradeDecision
from .position import Position
from .exchange import Exchange
from .backtest import backtest_loop
from .backtest import collect_data_loop
from .utils import CommonInfrastructure
from .decision import Order
from ..utils import init_instance_by_config
from ..log import get_module_logger
from ..config import C
from ..log import get_module_logger
from ..utils import init_instance_by_config
from .backtest import backtest_loop, collect_data_loop
from .decision import Order
from .exchange import Exchange
from .utils import CommonInfrastructure
# make import more user-friendly by adding `from qlib.backtest import STH`
@@ -28,26 +32,35 @@ logger = get_module_logger("backtest caller")
def get_exchange(
exchange=None,
freq="day",
start_time=None,
end_time=None,
codes="all",
subscribe_fields=[],
open_cost=0.0015,
close_cost=0.0025,
min_cost=5.0,
limit_threshold=None,
exchange: Union[str, dict, object, Path] = None,
freq: str = "day",
start_time: Union[pd.Timestamp, str] = None,
end_time: Union[pd.Timestamp, str] = None,
codes: Union[list, str] = "all",
subscribe_fields: list = [],
open_cost: float = 0.0015,
close_cost: float = 0.0025,
min_cost: float = 5.0,
limit_threshold: Union[Tuple[str, str], float, None] = None,
deal_price: Union[str, Tuple[str], List[str]] = None,
**kwargs,
):
**kwargs: Any,
) -> Exchange:
"""get_exchange
Parameters
----------
# exchange related arguments
exchange: Exchange().
exchange: Exchange
It could be None or any types that are acceptable by `init_instance_by_config`.
freq: str
frequency of data.
start_time: Union[pd.Timestamp, str]
closed start time for backtest.
end_time: Union[pd.Timestamp, str]
closed end time for backtest.
codes: Union[list, str]
list stock_id list or a string of instruments (i.e. all, csi500, sse50)
subscribe_fields: list
subscribe fields.
open_cost : float
@@ -57,8 +70,6 @@ def get_exchange(
min_cost : float
min transaction cost. It is an absolute amount of cost instead of a ratio of your order's deal amount.
e.g. You must pay at least 5 yuan of commission regardless of your order's deal amount.
trade_unit : int
Included in kwargs. Please refer to the docs of `__init__` of `Exchange`
deal_price: Union[str, Tuple[str], List[str]]
The `deal_price` supports following two types of input
- <deal_price> : str
@@ -101,10 +112,14 @@ def get_exchange(
def create_account_instance(
start_time, end_time, benchmark: str, account: Union[float, int, dict], pos_type: str = "Position"
start_time: Union[pd.Timestamp, str],
end_time: Union[pd.Timestamp, str],
benchmark: str,
account: Union[float, int, dict],
pos_type: str = "Position",
) -> Account:
"""
# TODO: is very strange pass benchmark_config in the account(maybe for report)
# TODO: is very strange pass benchmark_config in the account (maybe for report)
# There should be a post-step to process the report.
Parameters
@@ -132,42 +147,40 @@ def create_account_instance(
key "cash" means initial cash.
key "stock1" means the information of first stock with amount and price(optional).
...
pos_type: str
Postion type.
"""
if isinstance(account, (int, float)):
pos_kwargs = {"init_cash": account}
init_cash = account
position_dict = {}
elif isinstance(account, dict):
init_cash = account["cash"]
del account["cash"]
pos_kwargs = {
"init_cash": init_cash,
"position_dict": account,
}
init_cash = account.pop("cash")
position_dict = account
else:
raise ValueError("account must be in (int, float, Position)")
raise ValueError("account must be in (int, float, dict)")
kwargs = {
"init_cash": account,
"benchmark_config": {
return Account(
init_cash=init_cash,
position_dict=position_dict,
pos_type=pos_type,
benchmark_config={
"benchmark": benchmark,
"start_time": start_time,
"end_time": end_time,
},
"pos_type": pos_type,
}
kwargs.update(pos_kwargs)
return Account(**kwargs)
)
def get_strategy_executor(
start_time,
end_time,
strategy: BaseStrategy,
executor: BaseExecutor,
start_time: Union[pd.Timestamp, str],
end_time: Union[pd.Timestamp, str],
strategy: Union[str, dict, object, Path],
executor: Union[str, dict, object, Path],
benchmark: str = "SH000300",
account: Union[float, int, Position] = 1e9,
account: Union[float, int, dict] = 1e9,
exchange_kwargs: dict = {},
pos_type: str = "Position",
):
) -> Tuple[BaseStrategy, BaseExecutor]:
# NOTE:
# - for avoiding recursive import
@@ -176,7 +189,11 @@ def get_strategy_executor(
from .executor import BaseExecutor # pylint: disable=C0415
trade_account = create_account_instance(
start_time=start_time, end_time=end_time, benchmark=benchmark, account=account, pos_type=pos_type
start_time=start_time,
end_time=end_time,
benchmark=benchmark,
account=account,
pos_type=pos_type,
)
exchange_kwargs = copy.copy(exchange_kwargs)
@@ -196,29 +213,31 @@ def get_strategy_executor(
def backtest(
start_time,
end_time,
strategy,
executor,
benchmark="SH000300",
account=1e9,
exchange_kwargs={},
start_time: Union[pd.Timestamp, str],
end_time: Union[pd.Timestamp, str],
strategy: Union[str, dict, object, Path],
executor: Union[str, dict, object, Path],
benchmark: str = "SH000300",
account: Union[float, int, dict] = 1e9,
exchange_kwargs: dict = {},
pos_type: str = "Position",
):
"""initialize the strategy and executor, then backtest function for the interaction of the outermost strategy and executor in the nested decision execution
) -> Tuple[PortfolioMetrics, Indicator]:
"""initialize the strategy and executor, then backtest function for the interaction of the outermost strategy and
executor in the nested decision execution
Parameters
----------
start_time : pd.Timestamp|str
start_time : Union[pd.Timestamp, str]
closed start time for backtest
**NOTE**: This will be applied to the outmost executor's calendar.
end_time : pd.Timestamp|str
end_time : Union[pd.Timestamp, str]
closed end time for backtest
**NOTE**: This will be applied to the outmost executor's calendar.
E.g. Executor[day](Executor[1min]), setting `end_time == 20XX0301` will include all the minutes on 20XX0301
strategy : Union[str, dict, BaseStrategy]
for initializing outermost portfolio strategy. Please refer to the docs of init_instance_by_config for more information.
executor : Union[str, dict, BaseExecutor]
strategy : Union[str, dict, object, Path]
for initializing outermost portfolio strategy. Please refer to the docs of init_instance_by_config for more
information.
executor : Union[str, dict, object, Path]
for initializing the outermost executor.
benchmark: str
the benchmark for reporting.
@@ -257,16 +276,16 @@ def backtest(
def collect_data(
start_time,
end_time,
strategy,
executor,
benchmark="SH000300",
account=1e9,
exchange_kwargs={},
start_time: Union[pd.Timestamp, str],
end_time: Union[pd.Timestamp, str],
strategy: Union[str, dict, object, Path],
executor: Union[str, dict, object, Path],
benchmark: str = "SH000300",
account: Union[float, int, dict] = 1e9,
exchange_kwargs: dict = {},
pos_type: str = "Position",
return_value: dict = None,
):
) -> Generator[object, None, None]:
"""initialize the strategy and executor, then collect the trade decision data for rl training
please refer to the docs of the backtest for the explanation of the parameters
@@ -291,7 +310,7 @@ def collect_data(
def format_decisions(
decisions: List[BaseTradeDecision],
) -> Tuple[str, List[Tuple[BaseTradeDecision, Union[Tuple, None]]]]:
) -> Optional[Tuple[str, List[Tuple[BaseTradeDecision, Union[Tuple, None]]]]]:
"""
format the decisions collected by `qlib.backtest.collect_data`
The decisions will be organized into a tree-like structure.
@@ -316,7 +335,7 @@ def format_decisions(
cur_freq = decisions[0].strategy.trade_calendar.get_freq()
res = (cur_freq, [])
res: Tuple[str, list] = (cur_freq, [])
last_dec_idx = 0
for i, dec in enumerate(decisions[1:], 1):
if dec.strategy.trade_calendar.get_freq() == cur_freq:
@@ -326,4 +345,4 @@ def format_decisions(
return res
__all__ = ["Order"]
__all__ = ["Order", "backtest"]

View File

@@ -1,15 +1,19 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
import copy
from typing import Dict, List, Tuple
from qlib.utils import init_instance_by_config
from typing import Dict, List, Optional, Tuple, cast
import pandas as pd
from .position import BasePosition
from .report import PortfolioMetrics, Indicator
from qlib.utils import init_instance_by_config
from .decision import BaseTradeDecision, Order
from .exchange import Exchange
from .high_performance_ds import BaseOrderIndicator
from .position import BasePosition
from .report import Indicator, PortfolioMetrics
"""
rtn & earning in the Account
@@ -34,40 +38,42 @@ class AccumulatedInfo:
AccumulatedInfo should be shared across different levels
"""
def __init__(self):
def __init__(self) -> None:
self.reset()
def reset(self):
self.rtn = 0 # accumulated return, do not consider cost
self.cost = 0 # accumulated cost
self.to = 0 # accumulated turnover
def reset(self) -> None:
self.rtn: float = 0.0 # accumulated return, do not consider cost
self.cost: float = 0.0 # accumulated cost
self.to: float = 0.0 # accumulated turnover
def add_return_value(self, value):
def add_return_value(self, value: float) -> None:
self.rtn += value
def add_cost(self, value):
def add_cost(self, value: float) -> None:
self.cost += value
def add_turnover(self, value):
def add_turnover(self, value: float) -> None:
self.to += value
@property
def get_return(self):
def get_return(self) -> float:
return self.rtn
@property
def get_cost(self):
def get_cost(self) -> float:
return self.cost
@property
def get_turnover(self):
def get_turnover(self) -> float:
return self.to
class Account:
"""
The correctness of the metrics of Account in nested execution depends on the shallow copy of `trade_account` in qlib/backtest/executor.py:NestedExecutor
Different level of executor has different Account object when calculating metrics. But the position object is shared cross all the Account object.
The correctness of the metrics of Account in nested execution depends on the shallow copy of `trade_account` in
qlib/backtest/executor.py:NestedExecutor
Different level of executor has different Account object when calculating metrics. But the position object is
shared cross all the Account object.
"""
def __init__(
@@ -78,7 +84,7 @@ class Account:
benchmark_config: dict = {},
pos_type: str = "Position",
port_metr_enabled: bool = True,
):
) -> None:
"""the trade account of backtest.
Parameters
@@ -99,10 +105,10 @@ class Account:
self._pos_type = pos_type
self._port_metr_enabled = port_metr_enabled
self.benchmark_config = None # avoid no attribute error
self.benchmark_config: dict = {} # avoid no attribute error
self.init_vars(init_cash, position_dict, freq, benchmark_config)
def init_vars(self, init_cash, position_dict, freq: str, benchmark_config: dict):
def init_vars(self, init_cash: float, position_dict: dict, freq: str, benchmark_config: dict) -> None:
# 1) the following variables are shared by multiple layers
# - you will see a shallow copy instead of deepcopy in the NestedExecutor;
self.init_cash = init_cash
@@ -114,22 +120,22 @@ class Account:
"position_dict": position_dict,
},
"module_path": "qlib.backtest.position",
}
},
)
self.accum_info = AccumulatedInfo()
# 2) following variables are not shared between layers
self.portfolio_metrics = None
self.hist_positions = {}
self.portfolio_metrics: Optional[PortfolioMetrics] = None
self.hist_positions: Dict[pd.Timestamp, BasePosition] = {}
self.reset(freq=freq, benchmark_config=benchmark_config)
def is_port_metr_enabled(self):
def is_port_metr_enabled(self) -> bool:
"""
Is portfolio-based metrics enabled.
"""
return self._port_metr_enabled and not self.current_position.skip_update()
def reset_report(self, freq, benchmark_config):
def reset_report(self, freq: str, benchmark_config: dict) -> None:
# portfolio related metrics
if self.is_port_metr_enabled():
# NOTE:
@@ -140,13 +146,13 @@ class Account:
# fill stock value
# The frequency of account may not align with the trading frequency.
# This may result in obscure bugs when data quality is low.
if isinstance(self.benchmark_config, dict) and self.benchmark_config.get("start_time") is not None:
if isinstance(self.benchmark_config, dict) and "start_time" in self.benchmark_config:
self.current_position.fill_stock_value(self.benchmark_config["start_time"], self.freq)
# trading related metrics(e.g. high-frequency trading)
self.indicator = Indicator()
def reset(self, freq=None, benchmark_config=None, port_metr_enabled: bool = None):
def reset(self, freq: str = None, benchmark_config: dict = None, port_metr_enabled: bool = None) -> None:
"""reset freq and report of account
Parameters
@@ -155,6 +161,7 @@ class Account:
frequency of account & report, by default None
benchmark_config : {}, optional
benchmark config of report, by default None
port_metr_enabled: bool
"""
if freq is not None:
self.freq = freq
@@ -165,13 +172,13 @@ class Account:
self.reset_report(self.freq, self.benchmark_config)
def get_hist_positions(self):
def get_hist_positions(self) -> Dict[pd.Timestamp, BasePosition]:
return self.hist_positions
def get_cash(self):
def get_cash(self) -> float:
return self.current_position.get_cash()
def _update_state_from_order(self, order, trade_val, cost, trade_price):
def _update_state_from_order(self, order: Order, trade_val: float, cost: float, trade_price: float) -> None:
if self.is_port_metr_enabled():
# update turnover
self.accum_info.add_turnover(trade_val)
@@ -191,13 +198,14 @@ class Account:
profit = self.current_position.get_stock_price(order.stock_id) * trade_amount - trade_val
self.accum_info.add_return_value(profit) # note here do not consider cost
def update_order(self, order, trade_val, cost, trade_price):
def update_order(self, order: Order, trade_val: float, cost: float, trade_price: float) -> None:
if self.current_position.skip_update():
# TODO: supporting polymorphism for account
# updating order for infinite position is meaningless
return
# if stock is sold out, no stock price information in Position, then we should update account first, then update current position
# if stock is sold out, no stock price information in Position, then we should update account first,
# then update current position
# if stock is bought, there is no stock in current position, update current, then update account
# The cost will be subtracted from the cash at last. So the trading logic can ignore the cost calculation
if order.direction == Order.SELL:
@@ -212,29 +220,40 @@ class Account:
self.current_position.update_order(order, trade_val, cost, trade_price)
self._update_state_from_order(order, trade_val, cost, trade_price)
def update_current_position(self, trade_start_time, trade_end_time, trade_exchange):
"""update current to make rtn consistent with earning at the end of bar, and update holding bar count of stock"""
def update_current_position(
self,
trade_start_time: pd.Timestamp,
trade_end_time: pd.Timestamp,
trade_exchange: Exchange,
) -> None:
"""
Update current to make rtn consistent with earning at the end of bar, and update holding bar count of stock
"""
# update price for stock in the position and the profit from changed_price
# NOTE: updating position does not only serve portfolio metrics, it also serve the strategy
assert self.current_position is not None
if not self.current_position.skip_update():
stock_list = self.current_position.get_stock_list()
for code in stock_list:
# if suspend, no new price to be updated, profit is 0
if trade_exchange.check_stock_suspended(code, trade_start_time, trade_end_time):
continue
bar_close = trade_exchange.get_close(code, trade_start_time, trade_end_time)
bar_close = cast(float, trade_exchange.get_close(code, trade_start_time, trade_end_time))
self.current_position.update_stock_price(stock_id=code, price=bar_close)
# update holding day count
# NOTE: updating bar_count does not only serve portfolio metrics, it also serve the strategy
self.current_position.add_count_all(bar=self.freq)
def update_portfolio_metrics(self, trade_start_time, trade_end_time):
def update_portfolio_metrics(self, trade_start_time: pd.Timestamp, trade_end_time: pd.Timestamp) -> None:
"""update portfolio_metrics"""
# calculate earning
# account_value - last_account_value
# for the first trade date, account_value - init_cash
# self.portfolio_metrics.is_empty() to judge is_first_trade_date
# get last_account_value, last_total_cost, last_total_turnover
assert self.portfolio_metrics is not None
if self.portfolio_metrics.is_empty():
last_account_value = self.init_cash
last_total_cost = 0
@@ -243,14 +262,16 @@ class Account:
last_account_value = self.portfolio_metrics.get_latest_account_value()
last_total_cost = self.portfolio_metrics.get_latest_total_cost()
last_total_turnover = self.portfolio_metrics.get_latest_total_turnover()
# get now_account_value, now_stock_value, now_earning, now_cost, now_turnover
now_account_value = self.current_position.calculate_value()
now_stock_value = self.current_position.calculate_stock_value()
now_earning = now_account_value - last_account_value
now_cost = self.accum_info.get_cost - last_total_cost
now_turnover = self.accum_info.get_turnover - last_total_turnover
# update portfolio_metrics for today
# judge whether the the trading is begin.
# judge whether the trading is begin.
# and don't add init account state into portfolio_metrics, due to we don't have excess return in those days.
self.portfolio_metrics.update_portfolio_metrics_record(
trade_start_time=trade_start_time,
@@ -267,7 +288,7 @@ class Account:
stock_value=now_stock_value,
)
def update_hist_positions(self, trade_start_time):
def update_hist_positions(self, trade_start_time: pd.Timestamp) -> None:
"""update history position"""
now_account_value = self.current_position.calculate_value()
# set now_account_value to position
@@ -283,11 +304,11 @@ class Account:
trade_exchange: Exchange,
atomic: bool,
outer_trade_decision: BaseTradeDecision,
trade_info: list = None,
inner_order_indicators: List[Dict[str, pd.Series]] = None,
decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]] = None,
trade_info: list = [],
inner_order_indicators: List[BaseOrderIndicator] = [],
decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]] = [],
indicator_config: dict = {},
):
) -> None:
"""update trade indicators and order indicators in each bar end"""
# TODO: will skip empty decisions make it faster? `outer_trade_decision.empty():`
@@ -319,11 +340,11 @@ class Account:
trade_exchange: Exchange,
atomic: bool,
outer_trade_decision: BaseTradeDecision,
trade_info: list = None,
inner_order_indicators: List[Dict[str, pd.Series]] = None,
decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]] = None,
trade_info: list = [],
inner_order_indicators: List[BaseOrderIndicator] = [],
decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]] = [],
indicator_config: dict = {},
):
) -> None:
"""update account at each trading bar step
Parameters
@@ -338,6 +359,8 @@ class Account:
whether the trading executor is atomic, which means there is no higher-frequency trading executor inside it
- if atomic is True, calculate the indicators with trade_info
- else, aggregate indicators with inner indicators
outer_trade_decision: BaseTradeDecision
external trade decision
trade_info : List[(Order, float, float, float)], optional
trading information, by default None
- necessary if atomic is True
@@ -377,9 +400,10 @@ class Account:
indicator_config=indicator_config,
)
def get_portfolio_metrics(self):
def get_portfolio_metrics(self) -> Tuple[pd.DataFrame, dict]:
"""get the history portfolio_metrics and positions instance"""
if self.is_port_metr_enabled():
assert self.portfolio_metrics is not None
_portfolio_metrics = self.portfolio_metrics.generate_portfolio_metrics_dataframe()
_positions = self.get_hist_positions()
return _portfolio_metrics, _positions

View File

@@ -2,17 +2,29 @@
# Licensed under the MIT License.
from __future__ import annotations
from typing import TYPE_CHECKING, Generator, Optional, Tuple, Union, cast
import pandas as pd
from qlib.backtest.decision import BaseTradeDecision
from typing import TYPE_CHECKING
from qlib.backtest.report import Indicator, PortfolioMetrics
if TYPE_CHECKING:
from qlib.strategy.base import BaseStrategy
from qlib.backtest.executor import BaseExecutor
from ..utils.time import Freq
from tqdm.auto import tqdm
from ..utils.time import Freq
def backtest_loop(start_time, end_time, trade_strategy: BaseStrategy, trade_executor: BaseExecutor):
def backtest_loop(
start_time: Union[pd.Timestamp, str],
end_time: Union[pd.Timestamp, str],
trade_strategy: BaseStrategy,
trade_executor: BaseExecutor,
) -> Tuple[PortfolioMetrics, Indicator]:
"""backtest function for the interaction of the outermost strategy and executor in the nested decision execution
please refer to the docs of `collect_data_loop`
@@ -24,26 +36,33 @@ def backtest_loop(start_time, end_time, trade_strategy: BaseStrategy, trade_exec
indicator: Indicator
it computes the trading indicator
"""
return_value = {}
return_value: dict = {}
for _decision in collect_data_loop(start_time, end_time, trade_strategy, trade_executor, return_value):
pass
return return_value.get("portfolio_metrics"), return_value.get("indicator")
portfolio_metrics = cast(PortfolioMetrics, return_value.get("portfolio_metrics"))
indicator = cast(Indicator, return_value.get("indicator"))
return portfolio_metrics, indicator
def collect_data_loop(
start_time, end_time, trade_strategy: BaseStrategy, trade_executor: BaseExecutor, return_value: dict = None
):
start_time: Union[pd.Timestamp, str],
end_time: Union[pd.Timestamp, str],
trade_strategy: BaseStrategy,
trade_executor: BaseExecutor,
return_value: dict = None,
) -> Generator[BaseTradeDecision, Optional[BaseTradeDecision], None]:
"""Generator for collecting the trade decision data for rl training
Parameters
----------
start_time : pd.Timestamp|str
start_time : Union[pd.Timestamp, str]
closed start time for backtest
**NOTE**: This will be applied to the outmost executor's calendar.
end_time : pd.Timestamp|str
end_time : Union[pd.Timestamp, str]
closed end time for backtest
**NOTE**: This will be applied to the outmost executor's calendar.
E.g. Executor[day](Executor[1min]), setting `end_time == 20XX0301` will include all the minutes on 20XX0301
E.g. Executor[day](Executor[1min]), setting `end_time == 20XX0301` will include all the minutes on 20XX0301
trade_strategy : BaseStrategy
the outermost portfolio strategy
trade_executor : BaseExecutor

View File

@@ -2,27 +2,33 @@
# Licensed under the MIT License.
from __future__ import annotations
from enum import IntEnum
from qlib.data.data import Cal
from qlib.utils.time import concat_date_time, epsilon_change
from qlib.log import get_module_logger
from typing import ClassVar, Optional, Union, List, Tuple
from abc import abstractmethod
from enum import IntEnum
# try to fix circular imports when enabling type hints
from typing import TYPE_CHECKING
from typing import Generic, List, TYPE_CHECKING, Any, ClassVar, Optional, Tuple, TypeVar, Union, cast
from qlib.backtest.utils import TradeCalendarManager
from qlib.data.data import Cal
from qlib.log import get_module_logger
from qlib.utils.time import concat_date_time, epsilon_change
if TYPE_CHECKING:
from qlib.strategy.base import BaseStrategy
from qlib.backtest.exchange import Exchange
from qlib.backtest.utils import TradeCalendarManager
from dataclasses import dataclass
import numpy as np
import pandas as pd
from dataclasses import dataclass
DecisionType = TypeVar("DecisionType")
class OrderDir(IntEnum):
# Order direction
# Order direction
SELL = 0
BUY = 1
@@ -46,7 +52,7 @@ class Order:
# - they are set by users and is time-invariant.
stock_id: str
amount: float # `amount` is a non-negative and adjusted value
direction: int
direction: OrderDir
# 2) time variant values:
# - Users may want to set these values when using lower level APIs
@@ -61,8 +67,8 @@ class Order:
# What the value should be about in all kinds of cases
# - not tradable: the deal_amount == 0 , factor is None
# - the stock is suspended and the entire order fails. No cost for this order
# - dealed or partially dealed: deal_amount >= 0 and factor is not None
deal_amount: Optional[float] = None # `deal_amount` is a non-negative value
# - dealt or partially dealt: deal_amount >= 0 and factor is not None
deal_amount: float = 0.0 # `deal_amount` is a non-negative value
factor: Optional[float] = None
# TODO:
@@ -74,10 +80,10 @@ class Order:
SELL: ClassVar[OrderDir] = OrderDir.SELL
BUY: ClassVar[OrderDir] = OrderDir.BUY
def __post_init__(self):
def __post_init__(self) -> None:
if self.direction not in {Order.SELL, Order.BUY}:
raise NotImplementedError("direction not supported, `Order.SELL` for sell, `Order.BUY` for buy")
self.deal_amount = 0
self.deal_amount = 0.0
self.factor = None
@property
@@ -99,7 +105,7 @@ class Order:
return self.deal_amount * self.sign
@property
def sign(self) -> float:
def sign(self) -> int:
"""
return the sign of trading
- `+1` indicates buying
@@ -112,15 +118,12 @@ class Order:
if isinstance(direction, OrderDir):
return direction
elif isinstance(direction, (int, float, np.integer, np.floating)):
if direction > 0:
return Order.BUY
else:
return Order.SELL
return Order.BUY if direction > 0 else Order.SELL
elif isinstance(direction, str):
dl = direction.lower()
if dl.strip() == "sell":
dl = direction.lower().strip()
if dl == "sell":
return OrderDir.SELL
elif dl.strip() == "buy":
elif dl == "buy":
return OrderDir.BUY
else:
raise NotImplementedError(f"This type of input is not supported")
@@ -138,14 +141,14 @@ class OrderHelper:
Motivation
- Make generating order easier
- User may have no knowledge about the adjust-factor information about the system.
- It involves to much interaction with the exchange when generating orders.
- It involves too much interaction with the exchange when generating orders.
"""
def __init__(self, exchange: Exchange):
def __init__(self, exchange: Exchange) -> None:
self.exchange = exchange
@staticmethod
def create(
self,
code: str,
amount: float,
direction: OrderDir,
@@ -175,21 +178,18 @@ class OrderHelper:
Order:
The created order
"""
if start_time is not None:
start_time = pd.Timestamp(start_time)
if end_time is not None:
end_time = pd.Timestamp(end_time)
# NOTE: factor is a value belongs to the results section. User don't have to care about it when creating orders
return Order(
stock_id=code,
amount=amount,
start_time=start_time,
end_time=end_time,
start_time=start_time if start_time is not None else pd.Timestamp(start_time),
end_time=end_time if end_time is not None else pd.Timestamp(end_time),
direction=direction,
)
class TradeRange:
@abstractmethod
def __call__(self, trade_calendar: TradeCalendarManager) -> Tuple[int, int]:
"""
This method will be call with following way
@@ -216,6 +216,7 @@ class TradeRange:
"""
raise NotImplementedError(f"Please implement the `__call__` method")
@abstractmethod
def clip_time_range(self, start_time: pd.Timestamp, end_time: pd.Timestamp) -> Tuple[pd.Timestamp, pd.Timestamp]:
"""
Parameters
@@ -234,23 +235,26 @@ class TradeRange:
class IdxTradeRange(TradeRange):
def __init__(self, start_idx: int, end_idx: int):
def __init__(self, start_idx: int, end_idx: int) -> None:
self._start_idx = start_idx
self._end_idx = end_idx
def __call__(self, trade_calendar: TradeCalendarManager = None) -> Tuple[int, int]:
return self._start_idx, self._end_idx
def clip_time_range(self, start_time: pd.Timestamp, end_time: pd.Timestamp) -> Tuple[pd.Timestamp, pd.Timestamp]:
raise NotImplementedError
class TradeRangeByTime(TradeRange):
"""This is a helper function for make decisions"""
def __init__(self, start_time: str, end_time: str):
def __init__(self, start_time: str, end_time: str) -> None:
"""
This is a callable class.
**NOTE**:
- It is designed for minute-bar for intraday trading!!!!!
- It is designed for minute-bar for intra-day trading!!!!!
- Both start_time and end_time are **closed** in the range
Parameters
@@ -264,26 +268,25 @@ class TradeRangeByTime(TradeRange):
self.end_time = pd.Timestamp(end_time).time()
assert self.start_time < self.end_time
def __call__(self, trade_calendar: TradeCalendarManager = None) -> Tuple[int, int]:
def __call__(self, trade_calendar: TradeCalendarManager) -> Tuple[int, int]:
if trade_calendar is None:
raise NotImplementedError("trade_calendar is necessary for getting TradeRangeByTime.")
start = trade_calendar.start_time
val_start, val_end = concat_date_time(start.date(), self.start_time), concat_date_time(
start.date(), self.end_time
)
start_date = trade_calendar.start_time.date()
val_start, val_end = concat_date_time(start_date, self.start_time), concat_date_time(start_date, self.end_time)
return trade_calendar.get_range_idx(val_start, val_end)
def clip_time_range(self, start_time: pd.Timestamp, end_time: pd.Timestamp) -> Tuple[pd.Timestamp, pd.Timestamp]:
start_date = start_time.date()
val_start, val_end = concat_date_time(start_date, self.start_time), concat_date_time(start_date, self.end_time)
# NOTE: `end_date` should not be used. Because the `end_date` is for slicing. It may be in the next day
# Assumption: start_time and end_time is for intraday trading. So it is OK for only using start_date
# Assumption: start_time and end_time is for intra-day trading. So it is OK for only using start_date
return max(val_start, start_time), min(val_end, end_time)
class BaseTradeDecision:
class BaseTradeDecision(Generic[DecisionType]):
"""
Trade decisions ara made by strategy and executed by exeuter
Trade decisions ara made by strategy and executed by executor
Motivation:
Here are several typical scenarios for `BaseTradeDecision`
@@ -297,7 +300,7 @@ class BaseTradeDecision:
2. Same as `case 1.3`
"""
def __init__(self, strategy: BaseStrategy, trade_range: Union[Tuple[int, int], TradeRange] = None):
def __init__(self, strategy: BaseStrategy, trade_range: Union[Tuple[int, int], TradeRange] = None) -> None:
"""
Parameters
----------
@@ -316,20 +319,21 @@ class BaseTradeDecision:
"""
self.strategy = strategy
self.start_time, self.end_time = strategy.trade_calendar.get_step_time()
self.total_step = None # upper strategy has no knowledge about the sub executor before `_init_sub_trading`
if isinstance(trade_range, Tuple):
# upper strategy has no knowledge about the sub executor before `_init_sub_trading`
self.total_step: Optional[int] = None
if isinstance(trade_range, tuple):
# for Tuple[int, int]
trade_range = IdxTradeRange(*trade_range)
self.trade_range: TradeRange = trade_range
self.trade_range: Optional[TradeRange] = trade_range
def get_decision(self) -> List[object]:
def get_decision(self) -> List[DecisionType]:
"""
get the **concrete decision** (e.g. execution orders)
This will be called by the inner strategy
Returns
-------
List[object]:
List[DecisionType:
The decision result. Typically it is some orders
Example:
[]:
@@ -339,7 +343,7 @@ class BaseTradeDecision:
"""
raise NotImplementedError(f"This type of input is not supported")
def update(self, trade_calendar: TradeCalendarManager) -> Union["BaseTradeDecision", None]:
def update(self, trade_calendar: TradeCalendarManager) -> Optional[BaseTradeDecision]:
"""
Be called at the **start** of each step.
@@ -354,10 +358,8 @@ class BaseTradeDecision:
Returns
-------
None:
No update, use previous decision(or unavailable)
BaseTradeDecision:
New update, use new decision
New update, use new decision. If no updates, return None (use previous decision (or unavailable))
"""
# purpose 1)
self.total_step = trade_calendar.get_trade_len()
@@ -365,13 +367,13 @@ class BaseTradeDecision:
# purpose 2)
return self.strategy.update_trade_decision(self, trade_calendar)
def _get_range_limit(self, **kwargs) -> Tuple[int, int]:
def _get_range_limit(self, **kwargs: Any) -> Tuple[int, int]:
if self.trade_range is not None:
return self.trade_range(trade_calendar=kwargs.get("inner_calendar"))
return self.trade_range(trade_calendar=cast(TradeCalendarManager, kwargs.get("inner_calendar")))
else:
raise NotImplementedError("The decision didn't provide an index range")
def get_range_limit(self, **kwargs) -> Tuple[int, int]:
def get_range_limit(self, **kwargs: Any) -> Tuple[int, int]:
"""
return the expected step range for limiting the decision execution time
Both left and right are **closed**
@@ -412,21 +414,22 @@ class BaseTradeDecision:
"""
try:
_start_idx, _end_idx = self._get_range_limit(**kwargs)
except NotImplementedError:
except NotImplementedError as e:
if "default_value" in kwargs:
return kwargs["default_value"]
else:
# Default to get full index
raise NotImplementedError(f"The decision didn't provide an index range") from NotImplementedError
raise NotImplementedError(f"The decision didn't provide an index range") from e
# clip index
if getattr(self, "total_step", None) is not None:
# if `self.update` is called.
# Then the _start_idx, _end_idx should be clipped
assert self.total_step is not None
if _start_idx < 0 or _end_idx >= self.total_step:
logger = get_module_logger("decision")
logger.warning(
f"[{_start_idx},{_end_idx}] go beyoud the total_step({self.total_step}), it will be clipped"
f"[{_start_idx},{_end_idx}] go beyond the total_step({self.total_step}), it will be clipped.",
)
_start_idx, _end_idx = max(0, _start_idx), min(self.total_step - 1, _end_idx)
return _start_idx, _end_idx
@@ -444,7 +447,7 @@ class BaseTradeDecision:
Parameters
----------
rtype: str
- "full": return the full limitation of the deicsion in the day
- "full": return the full limitation of the decision in the day
- "step": return the limitation of current step
raise_error: bool
@@ -497,11 +500,10 @@ class BaseTradeDecision:
return True
return True
def mod_inner_decision(self, inner_trade_decision: BaseTradeDecision):
def mod_inner_decision(self, inner_trade_decision: BaseTradeDecision) -> None:
"""
This method will be called on the inner_trade_decision after it is generated.
`inner_trade_decision` will be changed **inplaced**.
`inner_trade_decision` will be changed **inplace**.
Motivation of the `mod_inner_decision`
- Leave a hook for outer decision to affect the decision generated by the inner strategy
@@ -519,29 +521,38 @@ class BaseTradeDecision:
inner_trade_decision.trade_range = self.trade_range
class EmptyTradeDecision(BaseTradeDecision):
class EmptyTradeDecision(BaseTradeDecision[object]):
def get_decision(self) -> List[object]:
return []
def empty(self) -> bool:
return True
class TradeDecisionWO(BaseTradeDecision):
class TradeDecisionWO(BaseTradeDecision[Order]):
"""
Trade Decision (W)ith (O)rder.
Besides, the time_range is also included.
"""
def __init__(self, order_list: List[Order], strategy: BaseStrategy, trade_range: Tuple[int, int] = None):
def __init__(self, order_list: List[object], strategy: BaseStrategy, trade_range: Tuple[int, int] = None) -> None:
super().__init__(strategy, trade_range=trade_range)
self.order_list = order_list
self.order_list = cast(List[Order], order_list)
start, end = strategy.trade_calendar.get_step_time()
for o in order_list:
assert isinstance(o, Order)
if o.start_time is None:
o.start_time = start
if o.end_time is None:
o.end_time = end
def get_decision(self) -> List[object]:
def get_decision(self) -> List[Order]:
return self.order_list
def __repr__(self) -> str:
return f"class: {self.__class__.__name__}; strategy: {self.strategy}; trade_range: {self.trade_range}; order_list[{len(self.order_list)}]"
return (
f"class: {self.__class__.__name__}; "
f"strategy: {self.strategy}; "
f"trade_range: {self.trade_range}; "
f"order_list[{len(self.order_list)}]"
)

View File

@@ -1,21 +1,25 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
from collections import defaultdict
from typing import TYPE_CHECKING
from typing import List, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union, cast
from ..utils.index_data import IndexData
if TYPE_CHECKING:
from .account import Account
from qlib.backtest.position import BasePosition, Position
import random
import numpy as np
import pandas as pd
from ..data.data import D
from qlib.backtest.position import BasePosition
from ..config import C
from ..constant import REG_CN
from ..data.data import D
from ..log import get_module_logger
from .decision import Order, OrderDir, OrderHelper
from .high_performance_ds import BaseQuote, NumpyQuote
@@ -24,22 +28,22 @@ from .high_performance_ds import BaseQuote, NumpyQuote
class Exchange:
def __init__(
self,
freq="day",
start_time=None,
end_time=None,
codes="all",
freq: str = "day",
start_time: Union[pd.Timestamp, str] = None,
end_time: Union[pd.Timestamp, str] = None,
codes: Union[list, str] = "all",
deal_price: Union[str, Tuple[str], List[str]] = None,
subscribe_fields=[],
subscribe_fields: list = [],
limit_threshold: Union[Tuple[str, str], float, None] = None,
volume_threshold=None,
open_cost=0.0015,
close_cost=0.0025,
min_cost=5,
impact_cost=0.0,
extra_quote=None,
quote_cls=NumpyQuote,
**kwargs,
):
volume_threshold: Union[tuple, dict] = None,
open_cost: float = 0.0015,
close_cost: float = 0.0025,
min_cost: float = 5.0,
impact_cost: float = 0.0,
extra_quote: pd.DataFrame = None,
quote_cls: Type[BaseQuote] = NumpyQuote,
**kwargs: Any,
) -> None:
"""__init__
:param freq: frequency of data
:param start_time: closed start time for backtest
@@ -72,11 +76,12 @@ class Exchange:
]
1) ("cum" or "current", limit_str) denotes a single volume limit.
- limit_str is qlib data expression which is allowed to define your own Operator.
Please refer to qlib/contrib/ops/high_freq.py, here are any custom operator for high frequency,
such as DayCumsum. !!!NOTE: if you want you use the custom operator, you need to
register it in qlib_init.
- "cum" means that this is a cumulative value over time, such as cumulative market volume.
So when it is used as a volume limit, it is necessary to subtract the dealt amount.
Please refer to qlib/contrib/ops/high_freq.py, here are any custom operator for
high frequency, such as DayCumsum. !!!NOTE: if you want you use the custom
operator, you need to register it in qlib_init.
- "cum" means that this is a cumulative value over time, such as cumulative market
volume. So when it is used as a volume limit, it is necessary to subtract the dealt
amount.
- "current" means that this is a real-time value and will not accumulate over time,
so it can be directly used as a capacity limit.
e.g. ("cum", "0.2 * DayCumsum($volume, '9:45', '14:45')"), ("current", "$bidV1")
@@ -84,7 +89,7 @@ class Exchange:
"buy" means the volume limits of buying. "sell" means the volume limits of selling.
Different volume limits will be aggregated with min(). If volume_threshold is only
("cum" or "current", limit_str) instead of a dict, the volume limits are for
both by deault. In other words, it is same as {"all": ("cum" or "current", limit_str)}.
both by default. In other words, it is same as {"all": ("cum" or "current", limit_str)}.
3) e.g. "volume_threshold": {
"all": ("cum", "0.2 * DayCumsum($volume, '9:45', '14:45')"),
"buy": ("current", "$askV1"),
@@ -104,13 +109,14 @@ class Exchange:
Necessary fields:
$close is for calculating the total value at end of each day.
Optional fields:
$volume is only necessary when we limit the trade amount or calculate PA(vwap) indicator
$volume is only necessary when we limit the trade amount or calculate
PA(vwap) indicator
$vwap is only necessary when we use the $vwap price as the deal price
$factor is for rounding to the trading unit
limit_sell will be set to False by default(False indicates we can sell this
target on this day).
limit_buy will be set to False by default(False indicates we can buy this
target on this day).
limit_sell will be set to False by default (False indicates we can sell
this target on this day).
limit_buy will be set to False by default (False indicates we can buy
this target on this day).
index: MultipleIndex(instrument, pd.Datetime)
"""
self.freq = freq
@@ -135,7 +141,7 @@ class Exchange:
if limit_threshold is None:
if C.region == REG_CN:
self.logger.warning(f"limit_threshold not set. The stocks hit the limit may be bought/sold")
elif self.limit_type == self.LT_FLT and abs(limit_threshold) > 0.1:
elif self.limit_type == self.LT_FLT and abs(cast(float, limit_threshold)) > 0.1:
if C.region == REG_CN:
self.logger.warning(f"limit_threshold may not be set to a reasonable value")
@@ -144,7 +150,7 @@ class Exchange:
deal_price = "$" + deal_price
self.buy_price = self.sell_price = deal_price
elif isinstance(deal_price, (tuple, list)):
self.buy_price, self.sell_price = deal_price
self.buy_price, self.sell_price = cast(Tuple[str, str], deal_price)
else:
raise NotImplementedError(f"This type of input is not supported")
@@ -161,10 +167,10 @@ class Exchange:
necessary_fields = {self.buy_price, self.sell_price, "$close", "$change", "$factor", "$volume"}
if self.limit_type == self.LT_TP_EXP:
assert isinstance(limit_threshold, tuple)
for exp in limit_threshold:
necessary_fields.add(exp)
all_fields = necessary_fields | vol_lt_fields
all_fields = list(all_fields | set(subscribe_fields))
all_fields = list(necessary_fields | set(vol_lt_fields) | set(subscribe_fields))
self.all_fields = all_fields
@@ -182,17 +188,22 @@ class Exchange:
self.quote_cls = quote_cls
self.quote: BaseQuote = self.quote_cls(self.quote_df, freq)
def get_quote_from_qlib(self):
def get_quote_from_qlib(self) -> None:
# get stock data from qlib
if len(self.codes) == 0:
self.codes = D.instruments()
self.quote_df = D.features(
self.codes, self.all_fields, self.start_time, self.end_time, freq=self.freq, disk_cache=True
self.codes,
self.all_fields,
self.start_time,
self.end_time,
freq=self.freq,
disk_cache=True,
).dropna(subset=["$close"])
self.quote_df.columns = self.all_fields
# check buy_price data and sell_price data
for attr in "buy_price", "sell_price":
for attr in ("buy_price", "sell_price"):
pstr = getattr(self, attr) # price string
if self.quote_df[pstr].isna().any():
self.logger.warning("{} field data contains nan.".format(pstr))
@@ -238,9 +249,9 @@ class Exchange:
LT_FLT = "float" # float
LT_NONE = "none" # none
def _get_limit_type(self, limit_threshold):
def _get_limit_type(self, limit_threshold: Union[tuple, float, None]) -> str:
"""get limit type"""
if isinstance(limit_threshold, Tuple):
if isinstance(limit_threshold, tuple):
return self.LT_TP_EXP
elif isinstance(limit_threshold, float):
return self.LT_FLT
@@ -249,7 +260,7 @@ class Exchange:
else:
raise NotImplementedError(f"This type of `limit_threshold` is not supported")
def _update_limit(self, limit_threshold):
def _update_limit(self, limit_threshold: Union[Tuple, float, None]) -> None:
# check limit_threshold
limit_type = self._get_limit_type(limit_threshold)
if limit_type == self.LT_NONE:
@@ -257,15 +268,18 @@ class Exchange:
self.quote_df["limit_sell"] = False
elif limit_type == self.LT_TP_EXP:
# set limit
limit_threshold = cast(tuple, limit_threshold)
self.quote_df["limit_buy"] = self.quote_df[limit_threshold[0]]
self.quote_df["limit_sell"] = self.quote_df[limit_threshold[1]]
elif limit_type == self.LT_FLT:
limit_threshold = cast(float, limit_threshold)
self.quote_df["limit_buy"] = self.quote_df["$change"].ge(limit_threshold)
self.quote_df["limit_sell"] = self.quote_df["$change"].le(-limit_threshold) # pylint: disable=E1130
def _get_vol_limit(self, volume_threshold):
@staticmethod
def _get_vol_limit(volume_threshold: Union[tuple, dict, None]) -> Tuple[Optional[list], Optional[list], set]:
"""
preproccess the volume limit.
preprocess the volume limit.
get the fields need to get from qlib.
get the volume limit list of buying and selling which is composed of all limits.
Parameters
@@ -295,8 +309,7 @@ class Exchange:
volume_threshold = {"all": volume_threshold}
assert isinstance(volume_threshold, dict)
for key in volume_threshold:
vol_limit = volume_threshold[key]
for key, vol_limit in volume_threshold.items():
assert isinstance(vol_limit, tuple)
fields.add(vol_limit[1])
@@ -307,10 +320,19 @@ class Exchange:
return buy_vol_limit, sell_vol_limit, fields
def check_stock_limit(self, stock_id, start_time, end_time, direction=None):
def check_stock_limit(
self,
stock_id: str,
start_time: pd.Timestamp,
end_time: pd.Timestamp,
direction: int = None,
) -> bool:
"""
Parameters
----------
stock_id : str
start_time: pd.Timestamp
end_time: pd.Timestamp
direction : int, optional
trade direction, by default None
- if direction is None, check if tradable for buying and selling.
@@ -320,47 +342,50 @@ class Exchange:
if direction is None:
buy_limit = self.quote.get_data(stock_id, start_time, end_time, field="limit_buy", method="all")
sell_limit = self.quote.get_data(stock_id, start_time, end_time, field="limit_sell", method="all")
return buy_limit or sell_limit
return bool(buy_limit or sell_limit)
elif direction == Order.BUY:
return self.quote.get_data(stock_id, start_time, end_time, field="limit_buy", method="all")
return cast(bool, self.quote.get_data(stock_id, start_time, end_time, field="limit_buy", method="all"))
elif direction == Order.SELL:
return self.quote.get_data(stock_id, start_time, end_time, field="limit_sell", method="all")
return cast(bool, self.quote.get_data(stock_id, start_time, end_time, field="limit_sell", method="all"))
else:
raise ValueError(f"direction {direction} is not supported!")
def check_stock_suspended(self, stock_id, start_time, end_time):
def check_stock_suspended(
self,
stock_id: str,
start_time: pd.Timestamp,
end_time: pd.Timestamp,
) -> bool:
# is suspended
if stock_id in self.quote.get_all_stock():
return self.quote.get_data(stock_id, start_time, end_time, "$close") is None
else:
return True
def is_stock_tradable(self, stock_id, start_time, end_time, direction=None):
def is_stock_tradable(
self,
stock_id: str,
start_time: pd.Timestamp,
end_time: pd.Timestamp,
direction: int = None,
) -> bool:
# check if stock can be traded
# same as check in check_order
if self.check_stock_suspended(stock_id, start_time, end_time) or self.check_stock_limit(
stock_id, start_time, end_time, direction
):
return False
else:
return True
return not (
self.check_stock_suspended(stock_id, start_time, end_time)
or self.check_stock_limit(stock_id, start_time, end_time, direction)
)
def check_order(self, order):
def check_order(self, order: Order) -> bool:
# check limit and suspended
if self.check_stock_suspended(order.stock_id, order.start_time, order.end_time) or self.check_stock_limit(
order.stock_id, order.start_time, order.end_time, order.direction
):
return False
else:
return True
return self.is_stock_tradable(order.stock_id, order.start_time, order.end_time, order.direction)
def deal_order(
self,
order,
order: Order,
trade_account: Account = None,
position: BasePosition = None,
dealt_order_amount: defaultdict = defaultdict(float),
):
dealt_order_amount: Dict[str, float] = defaultdict(float),
) -> Tuple[float, float, float]:
"""
Deal order when the actual transaction
the results section in `Order` will be changed.
@@ -371,9 +396,9 @@ class Exchange:
:return: trade_val, trade_cost, trade_price
"""
# check order first.
if self.check_order(order) is False:
if not self.check_order(order):
order.deal_amount = 0.0
# using np.nan instead of None to make it more convenient to should the value in format string
# using np.nan instead of None to make it more convenient to show the value in format string
self.logger.debug(f"Order failed due to trading limitation: {order}")
return 0.0, 0.0, np.nan
@@ -382,7 +407,9 @@ class Exchange:
# NOTE: order will be changed in this function
trade_price, trade_val, trade_cost = self._calc_trade_info_by_order(
order, trade_account.current_position if trade_account else position, dealt_order_amount
order,
trade_account.current_position if trade_account else position,
dealt_order_amount,
)
if trade_val > 1e-5:
# If the order can only be deal 0 value. Nothing to be updated
@@ -396,35 +423,67 @@ class Exchange:
return trade_val, trade_cost, trade_price
def get_quote_info(self, stock_id, start_time, end_time, method="ts_data_last"):
return self.quote.get_data(stock_id, start_time, end_time, method=method)
def get_quote_info(
self,
stock_id: str,
start_time: pd.Timestamp,
end_time: pd.Timestamp,
field: str,
method: str = "ts_data_last",
) -> Union[None, int, float, bool, IndexData]:
return self.quote.get_data(stock_id, start_time, end_time, field=field, method=method)
def get_close(self, stock_id, start_time, end_time, method="ts_data_last"):
def get_close(
self,
stock_id: str,
start_time: pd.Timestamp,
end_time: pd.Timestamp,
method: str = "ts_data_last",
) -> Union[None, int, float, bool, IndexData]:
return self.quote.get_data(stock_id, start_time, end_time, field="$close", method=method)
def get_volume(self, stock_id, start_time, end_time, method="sum"):
def get_volume(
self,
stock_id: str,
start_time: pd.Timestamp,
end_time: pd.Timestamp,
method: Optional[str] = "sum",
) -> float:
"""get the total deal volume of stock with `stock_id` between the time interval [start_time, end_time)"""
return self.quote.get_data(stock_id, start_time, end_time, field="$volume", method=method)
return cast(float, self.quote.get_data(stock_id, start_time, end_time, field="$volume", method=method))
def get_deal_price(self, stock_id, start_time, end_time, direction: OrderDir, method="ts_data_last"):
def get_deal_price(
self,
stock_id: str,
start_time: pd.Timestamp,
end_time: pd.Timestamp,
direction: OrderDir,
method: Optional[str] = "ts_data_last",
) -> float:
if direction == OrderDir.SELL:
pstr = self.sell_price
elif direction == OrderDir.BUY:
pstr = self.buy_price
else:
raise NotImplementedError(f"This type of input is not supported")
deal_price = self.quote.get_data(stock_id, start_time, end_time, field=pstr, method=method)
if method is not None and (deal_price is None or np.isnan(deal_price) or deal_price <= 1e-08):
self.logger.warning(f"(stock_id:{stock_id}, trade_time:{(start_time, end_time)}, {pstr}): {deal_price}!!!")
self.logger.warning(f"setting deal_price to close price")
deal_price = self.get_close(stock_id, start_time, end_time, method)
return deal_price
return cast(float, deal_price)
def get_factor(self, stock_id, start_time, end_time) -> Union[float, None]:
def get_factor(
self,
stock_id: str,
start_time: pd.Timestamp,
end_time: pd.Timestamp,
) -> Optional[float]:
"""
Returns
-------
Union[float, None]:
Optional[float]:
`None`: if the stock is suspended `None` may be returned
`float`: return factor if the factor exists
"""
@@ -434,11 +493,16 @@ class Exchange:
return self.quote.get_data(stock_id, start_time, end_time, field="$factor", method="ts_data_last")
def generate_amount_position_from_weight_position(
self, weight_position, cash, start_time, end_time, direction=OrderDir.BUY
):
self,
weight_position: dict,
cash: float,
start_time: pd.Timestamp,
end_time: pd.Timestamp,
direction: OrderDir = OrderDir.BUY,
) -> dict:
"""
The generate the target position according to the weight and the cash.
NOTE: All the cash will assigned to the tadable stock.
NOTE: All the cash will assigned to the tradable stock.
Parameter:
weight_position : dict {stock_id : weight}; allocate cash by weight_position
among then, weight must be in this range: 0 < weight < 1
@@ -451,15 +515,14 @@ class Exchange:
# calculate the total weight of tradable value
tradable_weight = 0.0
for stock_id in weight_position:
for stock_id, wp in weight_position.items():
if self.is_stock_tradable(stock_id=stock_id, start_time=start_time, end_time=end_time):
# weight_position must be greater than 0 and less than 1
if weight_position[stock_id] < 0 or weight_position[stock_id] > 1:
if wp < 0 or wp > 1:
raise ValueError(
"weight_position is {}, "
"weight_position is not in the range of (0, 1).".format(weight_position[stock_id])
"weight_position is {}, " "weight_position is not in the range of (0, 1).".format(wp),
)
tradable_weight += weight_position[stock_id]
tradable_weight += wp
if tradable_weight - 1.0 >= 1e-5:
raise ValueError("tradable_weight is {}, can not greater than 1.".format(tradable_weight))
@@ -467,19 +530,24 @@ class Exchange:
amount_dict = {}
for stock_id in weight_position:
if weight_position[stock_id] > 0.0 and self.is_stock_tradable(
stock_id=stock_id, start_time=start_time, end_time=end_time
stock_id=stock_id,
start_time=start_time,
end_time=end_time,
):
amount_dict[stock_id] = (
cash
* weight_position[stock_id]
/ tradable_weight
// self.get_deal_price(
stock_id=stock_id, start_time=start_time, end_time=end_time, direction=direction
stock_id=stock_id,
start_time=start_time,
end_time=end_time,
direction=direction,
)
)
return amount_dict
def get_real_deal_amount(self, current_amount, target_amount, factor):
def get_real_deal_amount(self, current_amount: float, target_amount: float, factor: float = None) -> float:
"""
Calculate the real adjust deal amount when considering the trading unit
:param current_amount:
@@ -501,7 +569,13 @@ class Exchange:
deal_amount = self.round_amount_by_trade_unit(deal_amount, factor)
return -deal_amount
def generate_order_for_target_amount_position(self, target_position, current_position, start_time, end_time):
def generate_order_for_target_amount_position(
self,
target_position: dict,
current_position: dict,
start_time: pd.Timestamp,
end_time: pd.Timestamp,
) -> List[Order]:
"""
Note: some future information is used in this function
Parameter:
@@ -517,7 +591,8 @@ class Exchange:
# three parts: kept stock_id, dropped stock_id, new stock_id
# handle kept stock_id
# because the order of the set is not fixed, the trading order of the stock is different, so that the backtest results of the same parameter are different;
# because the order of the set is not fixed, the trading order of the stock is different, so that the backtest
# results of the same parameter are different;
# so here we sort stock_id, and then randomly shuffle the order of stock_id
# because the same random seed is used, the final stock_id order is fixed
sorted_ids = sorted(set(list(current_position.keys()) + list(target_position.keys())))
@@ -546,7 +621,7 @@ class Exchange:
start_time=start_time,
end_time=end_time,
factor=factor,
)
),
)
else:
# sell stock
@@ -558,14 +633,19 @@ class Exchange:
start_time=start_time,
end_time=end_time,
factor=factor,
)
),
)
# return order_list : buy + sell
return sell_order_list + buy_order_list
def calculate_amount_position_value(
self, amount_dict, start_time, end_time, only_tradable=False, direction=OrderDir.SELL
):
self,
amount_dict: dict,
start_time: pd.Timestamp,
end_time: pd.Timestamp,
only_tradable: bool = False,
direction: OrderDir = OrderDir.SELL,
) -> float:
"""Parameter
position : Position()
amount_dict : {stock_id : amount}
@@ -576,30 +656,44 @@ class Exchange:
"""
value = 0
for stock_id in amount_dict:
if (
only_tradable is True
and self.check_stock_suspended(stock_id=stock_id, start_time=start_time, end_time=end_time) is False
and self.check_stock_limit(stock_id=stock_id, start_time=start_time, end_time=end_time) is False
or only_tradable is False
if not only_tradable or (
not self.check_stock_suspended(stock_id=stock_id, start_time=start_time, end_time=end_time)
and not self.check_stock_limit(stock_id=stock_id, start_time=start_time, end_time=end_time)
):
value += (
self.get_deal_price(
stock_id=stock_id, start_time=start_time, end_time=end_time, direction=direction
stock_id=stock_id,
start_time=start_time,
end_time=end_time,
direction=direction,
)
* amount_dict[stock_id]
)
return value
def _get_factor_or_raise_error(self, factor: float = None, stock_id: str = None, start_time=None, end_time=None):
def _get_factor_or_raise_error(
self,
factor: float = None,
stock_id: str = None,
start_time: pd.Timestamp = None,
end_time: pd.Timestamp = None,
) -> float:
"""Please refer to the docs of get_amount_of_trade_unit"""
if factor is None:
if stock_id is not None and start_time is not None and end_time is not None:
factor = self.get_factor(stock_id=stock_id, start_time=start_time, end_time=end_time)
else:
raise ValueError(f"`factor` and (`stock_id`, `start_time`, `end_time`) can't both be None")
assert factor is not None
return factor
def get_amount_of_trade_unit(self, factor: float = None, stock_id: str = None, start_time=None, end_time=None):
def get_amount_of_trade_unit(
self,
factor: float = None,
stock_id: str = None,
start_time: pd.Timestamp = None,
end_time: pd.Timestamp = None,
) -> Optional[float]:
"""
get the trade unit of amount based on **factor**
the factor can be given directly or calculated in given time range and stock id.
@@ -617,15 +711,23 @@ class Exchange:
"""
if not self.trade_w_adj_price and self.trade_unit is not None:
factor = self._get_factor_or_raise_error(
factor=factor, stock_id=stock_id, start_time=start_time, end_time=end_time
factor=factor,
stock_id=stock_id,
start_time=start_time,
end_time=end_time,
)
return self.trade_unit / factor
else:
return None
def round_amount_by_trade_unit(
self, deal_amount, factor: float = None, stock_id: str = None, start_time=None, end_time=None
):
self,
deal_amount: float,
factor: float = None,
stock_id: str = None,
start_time: pd.Timestamp = None,
end_time: pd.Timestamp = None,
) -> float:
"""Parameter
Please refer to the docs of get_amount_of_trade_unit
deal_amount : float, adjusted amount
@@ -635,12 +737,15 @@ class Exchange:
if not self.trade_w_adj_price and self.trade_unit is not None:
# the minimal amount is 1. Add 0.1 for solving precision problem.
factor = self._get_factor_or_raise_error(
factor=factor, stock_id=stock_id, start_time=start_time, end_time=end_time
factor=factor,
stock_id=stock_id,
start_time=start_time,
end_time=end_time,
)
return (deal_amount * factor + 0.1) // self.trade_unit * self.trade_unit / factor
return deal_amount
def _clip_amount_by_volume(self, order: Order, dealt_order_amount: dict) -> int:
def _clip_amount_by_volume(self, order: Order, dealt_order_amount: dict) -> Optional[float]:
"""parse the capacity limit string and return the actual amount of orders that can be executed.
NOTE:
this function will change the order.deal_amount **inplace**
@@ -652,15 +757,12 @@ class Exchange:
dealt_order_amount : dict
:param dealt_order_amount: the dealt order amount dict with the format of {stock_id: float}
"""
if order.direction == Order.BUY:
vol_limit = self.buy_vol_limit
elif order.direction == Order.SELL:
vol_limit = self.sell_vol_limit
vol_limit = self.buy_vol_limit if order.direction == Order.BUY else self.sell_vol_limit
if vol_limit is None:
return order.deal_amount
vol_limit_num = []
vol_limit_num: List[float] = []
for limit in vol_limit:
assert isinstance(limit, tuple)
if limit[0] == "current":
@@ -671,7 +773,7 @@ class Exchange:
field=limit[1],
method="sum",
)
vol_limit_num.append(limit_value)
vol_limit_num.append(cast(float, limit_value))
elif limit[0] == "cum":
limit_value = self.quote.get_data(
order.stock_id,
@@ -689,12 +791,14 @@ class Exchange:
if vol_limit_min < orig_deal_amount:
self.logger.debug(f"Order clipped due to volume limitation: {order}, {list(zip(vol_limit_num, vol_limit))}")
def _get_buy_amount_by_cash_limit(self, trade_price, cash, cost_ratio):
return None
def _get_buy_amount_by_cash_limit(self, trade_price: float, cash: float, cost_ratio: float) -> float:
"""return the real order amount after cash limit for buying.
Parameters
----------
trade_price : float
position : cash
cash : float
cost_ratio : float
Return
@@ -702,7 +806,7 @@ class Exchange:
float
the real order amount after cash limit for buying.
"""
max_trade_amount = 0
max_trade_amount = 0.0
if cash >= self.min_cost:
# critical_price means the stock transaction price when the service fee is equal to min_cost.
critical_price = self.min_cost / cost_ratio + self.min_cost
@@ -714,7 +818,12 @@ class Exchange:
max_trade_amount = (cash - self.min_cost) / trade_price
return max_trade_amount
def _calc_trade_info_by_order(self, order, position: Position, dealt_order_amount):
def _calc_trade_info_by_order(
self,
order: Order,
position: Optional[BasePosition],
dealt_order_amount: dict,
) -> Tuple[float, float, float]:
"""
Calculation of trade info
**NOTE**: Order will be changed in this function
@@ -753,7 +862,8 @@ class Exchange:
if not np.isclose(order.deal_amount, current_amount):
# when not selling last stock. rounding is necessary
order.deal_amount = self.round_amount_by_trade_unit(
min(current_amount, order.deal_amount), order.factor
min(current_amount, order.deal_amount),
order.factor,
)
# in case of negative value of cash
@@ -778,7 +888,8 @@ class Exchange:
# The money is not enough
max_buy_amount = self._get_buy_amount_by_cash_limit(trade_price, cash, cost_ratio)
order.deal_amount = self.round_amount_by_trade_unit(
min(max_buy_amount, order.deal_amount), order.factor
min(max_buy_amount, order.deal_amount),
order.factor,
)
self.logger.debug(f"Order clipped due to cash limitation: {order}")
else:
@@ -789,7 +900,7 @@ class Exchange:
order.deal_amount = self.round_amount_by_trade_unit(order.deal_amount, order.factor)
else:
raise NotImplementedError("order type {} error".format(order.type))
raise NotImplementedError("order direction {} error".format(order.direction))
trade_val = order.deal_amount * trade_price
trade_cost = max(trade_val * cost_ratio, self.min_cost)

View File

@@ -1,19 +1,22 @@
from abc import abstractmethod
from __future__ import annotations
import copy
from abc import abstractmethod
from collections import defaultdict
from types import GeneratorType
from typing import Any, Dict, Generator, List, Tuple, Union, cast
import pandas as pd
from qlib.backtest.account import Account
from qlib.backtest.position import BasePosition
from qlib.log import get_module_logger
from types import GeneratorType
from qlib.backtest.account import Account
import pandas as pd
from typing import List, Tuple, Union
from collections import defaultdict
from .decision import Order, BaseTradeDecision
from .exchange import Exchange
from .utils import TradeCalendarManager, CommonInfrastructure, LevelInfrastructure, get_start_end_idx
from ..utils import init_instance_by_config
from ..strategy.base import BaseStrategy
from ..utils import init_instance_by_config
from .decision import BaseTradeDecision, Order
from .exchange import Exchange
from .utils import CommonInfrastructure, LevelInfrastructure, TradeCalendarManager, get_start_end_idx
class BaseExecutor:
@@ -30,9 +33,9 @@ class BaseExecutor:
track_data: bool = False,
trade_exchange: Exchange = None,
common_infra: CommonInfrastructure = None,
settle_type=BasePosition.ST_NO,
**kwargs,
):
settle_type: str = BasePosition.ST_NO,
**kwargs: Any,
) -> None:
"""
Parameters
----------
@@ -53,15 +56,21 @@ class BaseExecutor:
- 'base_price': the based price than which the trading price is advanced, Optional, default by 'twap'
- If 'base_price' is 'twap', the based price is the time weighted average price
- If 'base_price' is 'vwap', the based price is the volume weighted average price
- 'weight_method': weighted method when calculating total trading pa by different orders' pa in each step, optional, default by 'mean'
- 'weight_method': weighted method when calculating total trading pa by different orders' pa in each
step, optional, default by 'mean'
- If 'weight_method' is 'mean', calculating mean value of different orders' pa
- If 'weight_method' is 'amount_weighted', calculating amount weighted average value of different orders' pa
- If 'weight_method' is 'value_weighted', calculating value weighted average value of different orders' pa
- If 'weight_method' is 'amount_weighted', calculating amount weighted average value of different
orders' pa
- If 'weight_method' is 'value_weighted', calculating value weighted average value of different
orders' pa
- 'ffr_config': config for calculating fulfill rate(ffr), optional
- 'weight_method': weighted method when calculating total trading ffr by different orders' ffr in each step, optional, default by 'mean'
- 'weight_method': weighted method when calculating total trading ffr by different orders' ffr in each
step, optional, default by 'mean'
- If 'weight_method' is 'mean', calculating mean value of different orders' ffr
- If 'weight_method' is 'amount_weighted', calculating amount weighted average value of different orders' ffr
- If 'weight_method' is 'value_weighted', calculating value weighted average value of different orders' ffr
- If 'weight_method' is 'amount_weighted', calculating amount weighted average value of different
orders' ffr
- If 'weight_method' is 'value_weighted', calculating value weighted average value of different
orders' ffr
Example:
{
'show_indicator': True,
@@ -79,7 +88,8 @@ class BaseExecutor:
whether to print trading info, by default False
track_data : bool, optional
whether to generate trade_decision, will be used when training rl agent
- If `self.track_data` is true, when making data for training, the input `trade_decision` of `execute` will be generated by `collect_data`
- If `self.track_data` is true, when making data for training, the input `trade_decision` of `execute` will
be generated by `collect_data`
- Else, `trade_decision` will not be generated
trade_exchange : Exchange
@@ -111,10 +121,10 @@ class BaseExecutor:
get_module_logger("BaseExecutor").warning(f"`common_infra` is not set for {self}")
# record deal order amount in one day
self.dealt_order_amount = defaultdict(float)
self.dealt_order_amount: Dict[str, float] = defaultdict(float)
self.deal_day = None
def reset_common_infra(self, common_infra, copy_trade_account=False):
def reset_common_infra(self, common_infra: CommonInfrastructure, copy_trade_account: bool = False) -> None:
"""
reset infrastructure for trading
- reset trade_account
@@ -125,14 +135,15 @@ class BaseExecutor:
self.common_infra.update(common_infra)
if common_infra.has("trade_account"):
if copy_trade_account:
# NOTE: there is a trick in the code.
# shallow copy is used instead of deepcopy.
# 1. So positions are shared
# 2. Others are not shared, so each level has it own metrics (portfolio and trading metrics)
self.trade_account: Account = copy.copy(common_infra.get("trade_account"))
else:
self.trade_account = common_infra.get("trade_account")
# NOTE: there is a trick in the code.
# shallow copy is used instead of deepcopy.
# 1. So positions are shared
# 2. Others are not shared, so each level has it own metrics (portfolio and trading metrics)
self.trade_account: Account = (
copy.copy(common_infra.get("trade_account"))
if copy_trade_account
else common_infra.get("trade_account")
)
self.trade_account.reset(freq=self.time_per_step, port_metr_enabled=self.generate_portfolio_metrics)
@property
@@ -148,7 +159,7 @@ class BaseExecutor:
"""
return self.level_infra.get("trade_calendar")
def reset(self, common_infra: CommonInfrastructure = None, **kwargs):
def reset(self, common_infra: CommonInfrastructure = None, **kwargs: Any) -> None:
"""
- reset `start_time` and `end_time`, used in trade calendar
- reset `common_infra`, used to reset `trade_account`, `trade_exchange`, .etc
@@ -161,13 +172,13 @@ class BaseExecutor:
if common_infra is not None:
self.reset_common_infra(common_infra)
def get_level_infra(self):
def get_level_infra(self) -> LevelInfrastructure:
return self.level_infra
def finished(self):
def finished(self) -> bool:
return self.trade_calendar.finished()
def execute(self, trade_decision: BaseTradeDecision, level: int = 0):
def execute(self, trade_decision: BaseTradeDecision, level: int = 0) -> List[object]:
"""execute the trade decision and return the executed result
NOTE: this function is never used directly in the framework. Should we delete it?
@@ -184,14 +195,17 @@ class BaseExecutor:
execute_result : List[object]
the executed result for trade decision
"""
return_value = {}
return_value: dict = {}
for _decision in self.collect_data(trade_decision, return_value=return_value, level=level):
pass
return return_value.get("execute_result")
return cast(list, return_value.get("execute_result"))
@classmethod
@abstractmethod
def _collect_data(cls, trade_decision: BaseTradeDecision, level: int = 0) -> Tuple[List[object], dict]:
def _collect_data(
self,
trade_decision: BaseTradeDecision,
level: int = 0,
) -> Union[Generator[Any, Any, Tuple[List[object], dict]], Tuple[List[object], dict]]:
"""
Please refer to the doc of collect_data
The only difference between `_collect_data` and `collect_data` is that some common steps are moved into
@@ -209,8 +223,11 @@ class BaseExecutor:
"""
def collect_data(
self, trade_decision: BaseTradeDecision, return_value: dict = None, level: int = 0
) -> List[object]:
self,
trade_decision: BaseTradeDecision,
return_value: dict = None,
level: int = 0,
) -> Generator[Any, Any, List[object]]:
"""Generator for collecting the trade decision data for rl training
his function will make a step forward
@@ -253,7 +270,9 @@ class BaseExecutor:
obj = self._collect_data(trade_decision=trade_decision, level=level)
if isinstance(obj, GeneratorType):
res, kwargs = yield from obj
yield_res = yield from obj
assert isinstance(yield_res, tuple) and len(yield_res) == 2
res, kwargs = yield_res
else:
# Some concrete executor don't have inner decisions
res, kwargs = obj
@@ -279,7 +298,7 @@ class BaseExecutor:
return_value.update({"execute_result": res})
return res
def get_all_executors(self):
def get_all_executors(self) -> List[BaseExecutor]:
"""get all executors"""
return [self]
@@ -287,7 +306,8 @@ class BaseExecutor:
class NestedExecutor(BaseExecutor):
"""
Nested Executor with inner strategy and executor
- At each time `execute` is called, it will call the inner strategy and executor to execute the `trade_decision` in a higher frequency env.
- At each time `execute` is called, it will call the inner strategy and executor to execute the `trade_decision`
in a higher frequency env.
"""
def __init__(
@@ -304,8 +324,8 @@ class NestedExecutor(BaseExecutor):
skip_empty_decision: bool = True,
align_range_limit: bool = True,
common_infra: CommonInfrastructure = None,
**kwargs,
):
**kwargs: Any,
) -> None:
"""
Parameters
----------
@@ -323,10 +343,14 @@ class NestedExecutor(BaseExecutor):
It is only for nested executor, because range_limit is given by outer strategy
"""
self.inner_executor: BaseExecutor = init_instance_by_config(
inner_executor, common_infra=common_infra, accept_types=BaseExecutor
inner_executor,
common_infra=common_infra,
accept_types=BaseExecutor,
)
self.inner_strategy: BaseStrategy = init_instance_by_config(
inner_strategy, common_infra=common_infra, accept_types=BaseStrategy
inner_strategy,
common_infra=common_infra,
accept_types=BaseStrategy,
)
self._skip_empty_decision = skip_empty_decision
@@ -344,10 +368,10 @@ class NestedExecutor(BaseExecutor):
**kwargs,
)
def reset_common_infra(self, common_infra, copy_trade_account=False):
def reset_common_infra(self, common_infra: CommonInfrastructure, copy_trade_account: bool = False) -> None:
"""
reset infrastructure for trading
- reset inner_strategyand inner_executor common infra
- reset inner_strategy and inner_executor common infra
"""
# NOTE: please refer to the docs of BaseExecutor.reset_common_infra for the meaning of `copy_trade_account`
@@ -358,7 +382,7 @@ class NestedExecutor(BaseExecutor):
self.inner_executor.reset_common_infra(common_infra, copy_trade_account=True)
self.inner_strategy.reset_common_infra(common_infra)
def _init_sub_trading(self, trade_decision):
def _init_sub_trading(self, trade_decision: BaseTradeDecision) -> None:
trade_start_time, trade_end_time = self.trade_calendar.get_step_time()
self.inner_executor.reset(start_time=trade_start_time, end_time=trade_end_time)
sub_level_infra = self.inner_executor.get_level_infra()
@@ -368,14 +392,18 @@ class NestedExecutor(BaseExecutor):
def _update_trade_decision(self, trade_decision: BaseTradeDecision) -> BaseTradeDecision:
# outer strategy have chance to update decision each iterator
updated_trade_decision = trade_decision.update(self.inner_executor.trade_calendar)
if updated_trade_decision is not None:
if updated_trade_decision is not None: # TODO: always is None for now?
trade_decision = updated_trade_decision
# NEW UPDATE
# create a hook for inner strategy to update outer decision
self.inner_strategy.alter_outer_trade_decision(trade_decision)
return trade_decision
def _collect_data(self, trade_decision: BaseTradeDecision, level: int = 0):
def _collect_data(
self,
trade_decision: BaseTradeDecision,
level: int = 0,
) -> Generator[Any, Any, Tuple[List[object], dict]]:
execute_result = []
inner_order_indicators = []
decision_list = []
@@ -390,8 +418,8 @@ class NestedExecutor(BaseExecutor):
if trade_decision.empty() and self._skip_empty_decision:
# give one chance for outer strategy to update the strategy
# - For updating some information in the sub executor(the strategy have no knowledge of the inner
# executor when generating the decision)
# - For updating some information in the sub executor (the strategy have no knowledge of the inner
# executor when generating the decision)
break
sub_cal: TradeCalendarManager = self.inner_executor.trade_calendar
@@ -405,15 +433,19 @@ class NestedExecutor(BaseExecutor):
# NOTE: !!!!!
# the two lines below is for a special case in RL
# To solve the confliction below
# - Normally, user will create a strategy and embed it into Qlib's executor and simulator interaction loop
# For a _nested qlib example_, (Qlib Strategy) <=> (Qlib Executor[(inner Qlib Strategy) <=> (inner Qlib Executor)])
# To solve the conflicts below
# - Normally, user will create a strategy and embed it into Qlib's executor and simulator interaction
# loop For a _nested qlib example_, (Qlib Strategy) <=> (Qlib Executor[(inner Qlib Strategy) <=>
# (inner Qlib Executor)])
# - However, RL-based framework has it's own script to run the loop
# For an _RL learning example_, (RL Policy) <=> (RL Env[(inner Qlib Executor)])
# To make it possible to run _nested qlib example_ and _RL learning example_ together, the solution below is proposed
# - The entry script follow the example of _RL learning example_ to be compatible with all kinds of RL Framework
# To make it possible to run _nested qlib example_ and _RL learning example_ together, the solution
# below is proposed
# - The entry script follow the example of _RL learning example_ to be compatible with all kinds of
# RL Framework
# - Each step of (RL Env) will make (inner Qlib Executor) one step forward
# - (inner Qlib Strategy) is a proxy strategy, it will give the program control right to (RL Env) by `yield from` and wait for the action from the policy
# - (inner Qlib Strategy) is a proxy strategy, it will give the program control right to (RL Env)
# by `yield from` and wait for the action from the policy
# So the two lines below is the implementation of yielding control rights
if isinstance(res, GeneratorType):
res = yield from res
@@ -427,13 +459,15 @@ class NestedExecutor(BaseExecutor):
# NOTE: Trade Calendar will step forward in the follow line
_inner_execute_result = yield from self.inner_executor.collect_data(
trade_decision=_inner_trade_decision, level=level + 1
trade_decision=_inner_trade_decision,
level=level + 1,
)
assert isinstance(_inner_execute_result, list)
self.post_inner_exe_step(_inner_execute_result)
execute_result.extend(_inner_execute_result)
inner_order_indicators.append(
self.inner_executor.trade_account.get_trade_indicator().get_order_indicator(raw=True)
self.inner_executor.trade_account.get_trade_indicator().get_order_indicator(raw=True),
)
else:
# do nothing and just step forward
@@ -441,7 +475,7 @@ class NestedExecutor(BaseExecutor):
return execute_result, {"inner_order_indicators": inner_order_indicators, "decision_list": decision_list}
def post_inner_exe_step(self, inner_exe_res):
def post_inner_exe_step(self, inner_exe_res: List[object]) -> None:
"""
A hook for doing sth after each step of inner strategy
@@ -451,11 +485,23 @@ class NestedExecutor(BaseExecutor):
the execution result of inner task
"""
def get_all_executors(self):
def get_all_executors(self) -> List[BaseExecutor]:
"""get all executors, including self and inner_executor.get_all_executors()"""
return [self, *self.inner_executor.get_all_executors()]
def _retrieve_orders_from_decision(trade_decision: BaseTradeDecision) -> List[Order]:
"""
IDE-friendly helper function.
"""
decisions = trade_decision.get_decision()
orders: List[Order] = []
for decision in decisions:
assert isinstance(decision, Order)
orders.append(decision)
return orders
class SimulatorExecutor(BaseExecutor):
"""Executor that simulate the true market"""
@@ -464,10 +510,10 @@ class SimulatorExecutor(BaseExecutor):
# available trade_types
TT_SERIAL = "serial"
## The orders will be executed serially in a sequence
# The orders will be executed serially in a sequence
# In each trading step, it is possible that users sell instruments first and use the money to buy new instruments
TT_PARAL = "parallel"
## The orders will be executed parallelly
# The orders will be executed in parallel
# In each trading step, if users try to sell instruments first and buy new instruments with money, failure will
# occur
@@ -482,8 +528,8 @@ class SimulatorExecutor(BaseExecutor):
track_data: bool = False,
common_infra: CommonInfrastructure = None,
trade_type: str = TT_SERIAL,
**kwargs,
):
**kwargs: Any,
) -> None:
"""
Parameters
----------
@@ -517,7 +563,7 @@ class SimulatorExecutor(BaseExecutor):
List[Order]:
get a list orders according to `self.trade_type`
"""
orders = trade_decision.get_decision()
orders = _retrieve_orders_from_decision(trade_decision)
if self.trade_type == self.TT_SERIAL:
# Orders will be traded in a parallel way
@@ -525,15 +571,15 @@ class SimulatorExecutor(BaseExecutor):
elif self.trade_type == self.TT_PARAL:
# NOTE: !!!!!!!
# Assumption: there will not be orders in different trading direction in a single step of a strategy !!!!
# The parallel trading failure will be caused only by the confliction of money
# Therefore, make the buying go first will make sure the confliction happen.
# The parallel trading failure will be caused only by the conflicts of money
# Therefore, make the buying go first will make sure the conflicts happen.
# It equals to parallel trading after sorting the order by direction
order_it = sorted(orders, key=lambda order: -order.direction)
else:
raise NotImplementedError(f"This type of input is not supported")
return order_it
def _update_dealt_order_amount(self, order):
def _update_dealt_order_amount(self, order: Order) -> None:
"""update date and dealt order amount in the day."""
now_deal_day = self.trade_calendar.get_step_time()[0].floor(freq="D")
@@ -542,10 +588,9 @@ class SimulatorExecutor(BaseExecutor):
self.deal_day = now_deal_day
self.dealt_order_amount[order.stock_id] += order.deal_amount
def _collect_data(self, trade_decision: BaseTradeDecision, level: int = 0):
def _collect_data(self, trade_decision: BaseTradeDecision, level: int = 0) -> Tuple[List[object], dict]:
trade_start_time, _ = self.trade_calendar.get_step_time()
execute_result = []
execute_result: list = []
for order in self._get_order_iterator(trade_decision):
# execute the order.
@@ -559,7 +604,8 @@ class SimulatorExecutor(BaseExecutor):
self._update_dealt_order_amount(order)
if self.verbose:
print(
"[I {:%Y-%m-%d %H:%M:%S}]: {} {}, price {:.2f}, amount {}, deal_amount {}, factor {}, value {:.2f}, cash {:.2f}.".format(
"[I {:%Y-%m-%d %H:%M:%S}]: {} {}, price {:.2f}, amount {}, deal_amount {}, factor {}, "
"value {:.2f}, cash {:.2f}.".format(
trade_start_time,
"sell" if order.direction == Order.SELL else "buy",
order.stock_id,
@@ -569,6 +615,6 @@ class SimulatorExecutor(BaseExecutor):
order.factor,
trade_val,
self.trade_account.get_cash(),
)
),
)
return execute_result, {"trade_info": execute_result}

View File

@@ -1,24 +1,27 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from functools import lru_cache
import logging
from typing import List, Text, Union, Callable, Iterable, Dict
from collections import OrderedDict
from __future__ import annotations
import inspect
import pandas as pd
import numpy as np
import logging
from collections import OrderedDict
from functools import lru_cache
from typing import Any, Callable, Dict, Iterable, List, Optional, Text, Union, cast
import numpy as np
import pandas as pd
import qlib.utils.index_data as idd
from ..log import get_module_logger
from ..utils.index_data import IndexData, SingleData
from ..utils.resam import resam_ts_data, ts_data_last
from ..log import get_module_logger
from ..utils.time import is_single_value, Freq
import qlib.utils.index_data as idd
from ..utils.time import Freq, is_single_value
class BaseQuote:
def __init__(self, quote_df: pd.DataFrame, freq):
def __init__(self, quote_df: pd.DataFrame, freq: str) -> None:
self.logger = get_module_logger("online operator", level=logging.INFO)
def get_all_stock(self) -> Iterable:
@@ -38,7 +41,7 @@ class BaseQuote:
start_time: Union[pd.Timestamp, str],
end_time: Union[pd.Timestamp, str],
field: Union[str],
method: Union[str, None] = None,
method: Optional[str] = None,
) -> Union[None, int, float, bool, IndexData]:
"""get the specific field of stock data during start time and end_time,
and apply method to the data.
@@ -98,7 +101,7 @@ class BaseQuote:
class PandasQuote(BaseQuote):
def __init__(self, quote_df: pd.DataFrame, freq):
def __init__(self, quote_df: pd.DataFrame, freq: str) -> None:
super().__init__(quote_df=quote_df, freq=freq)
quote_dict = {}
for stock_id, stock_val in quote_df.groupby(level="instrument"):
@@ -123,7 +126,7 @@ class PandasQuote(BaseQuote):
class NumpyQuote(BaseQuote):
def __init__(self, quote_df: pd.DataFrame, freq, region="cn"):
def __init__(self, quote_df: pd.DataFrame, freq: str, region: str = "cn") -> None:
"""NumpyQuote
Parameters
@@ -177,7 +180,8 @@ class NumpyQuote(BaseQuote):
data = self._agg_data(data, method)
return data
def _agg_data(self, data: IndexData, method):
@staticmethod
def _agg_data(data: IndexData, method: str) -> Union[IndexData, np.ndarray, None]:
"""Agg data by specific method."""
# FIXME: why not call the method of data directly?
if method == "sum":
@@ -223,31 +227,31 @@ class BaseSingleMetric:
"""
raise NotImplementedError(f"Please implement the `__init__` method")
def __add__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric":
def __add__(self, other: Union[BaseSingleMetric, int, float]) -> BaseSingleMetric:
raise NotImplementedError(f"Please implement the `__add__` method")
def __radd__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric":
def __radd__(self, other: Union[BaseSingleMetric, int, float]) -> BaseSingleMetric:
return self + other
def __sub__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric":
def __sub__(self, other: Union[BaseSingleMetric, int, float]) -> BaseSingleMetric:
raise NotImplementedError(f"Please implement the `__sub__` method")
def __rsub__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric":
def __rsub__(self, other: Union[BaseSingleMetric, int, float]) -> BaseSingleMetric:
raise NotImplementedError(f"Please implement the `__rsub__` method")
def __mul__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric":
def __mul__(self, other: Union[BaseSingleMetric, int, float]) -> BaseSingleMetric:
raise NotImplementedError(f"Please implement the `__mul__` method")
def __truediv__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric":
def __truediv__(self, other: Union[BaseSingleMetric, int, float]) -> BaseSingleMetric:
raise NotImplementedError(f"Please implement the `__truediv__` method")
def __eq__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric":
def __eq__(self, other: object) -> BaseSingleMetric:
raise NotImplementedError(f"Please implement the `__eq__` method")
def __gt__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric":
def __gt__(self, other: Union[BaseSingleMetric, int, float]) -> BaseSingleMetric:
raise NotImplementedError(f"Please implement the `__gt__` method")
def __lt__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric":
def __lt__(self, other: Union[BaseSingleMetric, int, float]) -> BaseSingleMetric:
raise NotImplementedError(f"Please implement the `__lt__` method")
def __len__(self) -> int:
@@ -264,7 +268,7 @@ class BaseSingleMetric:
raise NotImplementedError(f"Please implement the `count` method")
def abs(self) -> "BaseSingleMetric":
def abs(self) -> BaseSingleMetric:
raise NotImplementedError(f"Please implement the `abs` method")
@property
@@ -273,18 +277,18 @@ class BaseSingleMetric:
raise NotImplementedError(f"Please implement the `empty` method")
def add(self, other: "BaseSingleMetric", fill_value: float = None) -> "BaseSingleMetric":
def add(self, other: BaseSingleMetric, fill_value: float = None) -> BaseSingleMetric:
"""Replace np.NaN with fill_value in two metrics and add them."""
raise NotImplementedError(f"Please implement the `add` method")
def replace(self, replace_dict: dict) -> "BaseSingleMetric":
def replace(self, replace_dict: dict) -> BaseSingleMetric:
"""Replace the value of metric according to replace_dict."""
raise NotImplementedError(f"Please implement the `replace` method")
def apply(self, func: dict) -> "BaseSingleMetric":
"""Replace the value of metric with func(metric).
def apply(self, func: Callable) -> BaseSingleMetric:
"""Replace the value of metric with func (metric).
Currently, the func is only qlib/backtest/order/Order.parse_dir.
"""
@@ -303,11 +307,11 @@ class BaseOrderIndicator:
to inherit the BaseSingleMetric.
"""
def __init__(self, data):
self.data = data
def __init__(self):
self.data = {} # will be created in the subclass
self.logger = get_module_logger("online operator")
def assign(self, col: str, metric: Union[dict, pd.Series]):
def assign(self, col: str, metric: Union[dict, pd.Series]) -> None:
"""assign one metric.
Parameters
@@ -327,7 +331,7 @@ class BaseOrderIndicator:
raise NotImplementedError(f"Please implement the 'assign' method")
def transfer(self, func: Callable, new_col: str = None) -> Union[None, BaseSingleMetric]:
def transfer(self, func: Callable, new_col: str = None) -> Optional[BaseSingleMetric]:
"""compute new metric with existing metrics.
Parameters
@@ -351,6 +355,7 @@ class BaseOrderIndicator:
tmp_metric = func(**func_kwargs)
if new_col is not None:
self.data[new_col] = tmp_metric
return None
else:
return tmp_metric
@@ -371,7 +376,7 @@ class BaseOrderIndicator:
raise NotImplementedError(f"Please implement the 'get_metric_series' method")
def get_index_data(self, metric) -> SingleData:
def get_index_data(self, metric: str) -> SingleData:
"""get one metric with the format of SingleData
Parameters
@@ -388,7 +393,12 @@ class BaseOrderIndicator:
raise NotImplementedError(f"Please implement the 'get_index_data' method")
@staticmethod
def sum_all_indicators(order_indicator, indicators: list, metrics: Union[str, List[str]], fill_value: float = None):
def sum_all_indicators(
order_indicator: BaseOrderIndicator,
indicators: List[BaseOrderIndicator],
metrics: Union[str, List[str]],
fill_value: float = 0,
) -> None:
"""sum indicators with the same metrics.
and assign to the order_indicator(BaseOrderIndicator).
NOTE: indicators could be a empty list when orders in lower level all fail.
@@ -526,16 +536,17 @@ class PandasSingleMetric(SingleMetric):
def index(self):
return list(self.metric.index)
def add(self, other, fill_value=None):
def add(self, other: BaseSingleMetric, fill_value: float = None) -> PandasSingleMetric:
other = cast(PandasSingleMetric, other)
return self.__class__(self.metric.add(other.metric, fill_value=fill_value))
def replace(self, replace_dict: dict):
def replace(self, replace_dict: dict) -> PandasSingleMetric:
return self.__class__(self.metric.replace(replace_dict))
def apply(self, func: Callable):
def apply(self, func: Callable) -> PandasSingleMetric:
return self.__class__(self.metric.apply(func))
def reindex(self, index, fill_value):
def reindex(self, index: Any, fill_value: float) -> PandasSingleMetric:
return self.__class__(self.metric.reindex(index, fill_value=fill_value))
def __repr__(self):
@@ -549,13 +560,14 @@ class PandasOrderIndicator(BaseOrderIndicator):
Str is the name of metric.
"""
def __init__(self):
def __init__(self) -> None:
super(PandasOrderIndicator, self).__init__()
self.data: Dict[str, PandasSingleMetric] = OrderedDict()
def assign(self, col: str, metric: Union[dict, pd.Series]):
def assign(self, col: str, metric: Union[dict, pd.Series]) -> None:
self.data[col] = PandasSingleMetric(metric)
def get_index_data(self, metric):
def get_index_data(self, metric: str) -> SingleData:
if metric in self.data:
return idd.SingleData(self.data[metric].metric)
else:
@@ -571,7 +583,12 @@ class PandasOrderIndicator(BaseOrderIndicator):
return {k: v.metric for k, v in self.data.items()}
@staticmethod
def sum_all_indicators(order_indicator, indicators: list, metrics: Union[str, List[str]], fill_value=0):
def sum_all_indicators(
order_indicator: BaseOrderIndicator,
indicators: List[BaseOrderIndicator],
metrics: Union[str, List[str]],
fill_value: float = 0,
) -> None:
if isinstance(metrics, str):
metrics = [metrics]
for metric in metrics:
@@ -591,13 +608,14 @@ class NumpyOrderIndicator(BaseOrderIndicator):
Str is the name of metric.
"""
def __init__(self):
def __init__(self) -> None:
super(NumpyOrderIndicator, self).__init__()
self.data: Dict[str, SingleData] = OrderedDict()
def assign(self, col: str, metric: dict):
def assign(self, col: str, metric: dict) -> None:
self.data[col] = idd.SingleData(metric)
def get_index_data(self, metric):
def get_index_data(self, metric: str) -> SingleData:
if metric in self.data:
return self.data[metric]
else:
@@ -613,21 +631,27 @@ class NumpyOrderIndicator(BaseOrderIndicator):
return tmp_metric_dict
@staticmethod
def sum_all_indicators(order_indicator, indicators: list, metrics: Union[str, List[str]], fill_value=0):
def sum_all_indicators(
order_indicator: BaseOrderIndicator,
indicators: List[BaseOrderIndicator],
metrics: Union[str, List[str]],
fill_value: float = 0,
) -> None:
# get all index(stock_id)
stocks = set()
stock_set: set = set()
for indicator in indicators:
# set(np.ndarray.tolist()) is faster than set(np.ndarray)
stocks = stocks | set(indicator.data[metrics[0]].index.tolist())
stocks = list(stocks)
stocks.sort()
stock_set = stock_set | set(indicator.data[metrics[0]].index.tolist())
stocks = sorted(list(stock_set))
# add metric by index
if isinstance(metrics, str):
metrics = [metrics]
for metric in metrics:
order_indicator.data[metric] = idd.sum_by_index(
[indicator.data[metric] for indicator in indicators], stocks, fill_value
[indicator.data[metric] for indicator in indicators],
stocks,
fill_value,
)
def __repr__(self):

View File

@@ -2,24 +2,28 @@
# Licensed under the MIT License.
from typing import Dict, List, Union
import pandas as pd
from datetime import timedelta
import numpy as np
from typing import Any, Dict, List, Union
import numpy as np
import pandas as pd
from .decision import Order
from ..data.data import D
from .decision import Order
class BasePosition:
"""
The Position want to maintain the position like a dictionary
The Position wants to maintain the position like a dictionary
Please refer to the `Position` class for the position
"""
def __init__(self, *args, cash=0.0, **kwargs):
def __init__(self, *args: Any, cash: float = 0.0, **kwargs: Any) -> None:
self._settle_type = self.ST_NO
self.position: dict = {}
def fill_stock_value(self, start_time: Union[str, pd.Timestamp], freq: str, last_days: int = 30) -> None:
pass
def skip_update(self) -> bool:
"""
@@ -49,7 +53,7 @@ class BasePosition:
"""
raise NotImplementedError(f"Please implement the `check_stock` method")
def update_order(self, order: Order, trade_val: float, cost: float, trade_price: float):
def update_order(self, order: Order, trade_val: float, cost: float, trade_price: float) -> None:
"""
Parameters
----------
@@ -64,7 +68,7 @@ class BasePosition:
"""
raise NotImplementedError(f"Please implement the `update_order` method")
def update_stock_price(self, stock_id, price: float):
def update_stock_price(self, stock_id: str, price: float) -> None:
"""
Updating the latest price of the order
The useful when clearing balance at each bar end
@@ -89,13 +93,16 @@ class BasePosition:
"""
raise NotImplementedError(f"Please implement the `calculate_stock_value` method")
def get_stock_list(self) -> List:
def calculate_value(self) -> float:
raise NotImplementedError(f"Please implement the `calculate_value` method")
def get_stock_list(self) -> List[str]:
"""
Get the list of stocks in the position.
"""
raise NotImplementedError(f"Please implement the `get_stock_list` method")
def get_stock_price(self, code) -> float:
def get_stock_price(self, code: str) -> float:
"""
get the latest price of the stock
@@ -106,7 +113,7 @@ class BasePosition:
"""
raise NotImplementedError(f"Please implement the `get_stock_price` method")
def get_stock_amount(self, code) -> float:
def get_stock_amount(self, code: str) -> float:
"""
get the amount of the stock
@@ -124,18 +131,20 @@ class BasePosition:
def get_cash(self, include_settle: bool = False) -> float:
"""
Parameters
----------
include_settle:
will the unsettled(delayed) cash included
Default: not include those unavailable cash
Returns
-------
float:
the available(tradable) cash in position
include_settle:
will the unsettled(delayed) cash included
Default: not include those unavailable cash
"""
raise NotImplementedError(f"Please implement the `get_cash` method")
def get_stock_amount_dict(self) -> Dict:
def get_stock_amount_dict(self) -> dict:
"""
generate stock amount dict {stock_id : amount of stock}
@@ -146,7 +155,7 @@ class BasePosition:
"""
raise NotImplementedError(f"Please implement the `get_stock_amount_dict` method")
def get_stock_weight_dict(self, only_stock: bool = False) -> Dict:
def get_stock_weight_dict(self, only_stock: bool = False) -> dict:
"""
generate stock weight dict {stock_id : value weight of stock in the position}
it is meaningful in the beginning or the end of each trade step
@@ -165,7 +174,7 @@ class BasePosition:
"""
raise NotImplementedError(f"Please implement the `get_stock_weight_dict` method")
def add_count_all(self, bar):
def add_count_all(self, bar: str) -> None:
"""
Will be called at the end of each bar on each level
@@ -176,24 +185,19 @@ class BasePosition:
"""
raise NotImplementedError(f"Please implement the `add_count_all` method")
def update_weight_all(self):
def update_weight_all(self) -> None:
"""
Updating the position weight;
# TODO: this function is a little weird. The weight data in the position is in a wrong state after dealing order
# and before updating weight.
Parameters
----------
bar :
The level to be updated
"""
raise NotImplementedError(f"Please implement the `add_count_all` method")
ST_CASH = "cash"
ST_NO = None
ST_NO = "None" # String is more typehint friendly than None
def settle_start(self, settle_type: str):
def settle_start(self, settle_type: str) -> None:
"""
settlement start
It will act like start and commit a transaction
@@ -210,21 +214,16 @@ class BasePosition:
"""
raise NotImplementedError(f"Please implement the `settle_conf` method")
def settle_commit(self):
def settle_commit(self) -> None:
"""
settlement commit
Parameters
----------
settle_type : str
please refer to the documents of Executor
"""
raise NotImplementedError(f"Please implement the `settle_commit` method")
def __str__(self):
def __str__(self) -> str:
return self.__dict__.__str__()
def __repr__(self):
def __repr__(self) -> str:
return self.__dict__.__repr__()
@@ -242,13 +241,11 @@ class Position(BasePosition):
}
"""
def __init__(self, cash: float = 0, position_dict: Dict[str, Dict[str, float]] = {}):
def __init__(self, cash: float = 0, position_dict: Dict[str, Union[Dict[str, float], float]] = {}) -> None:
"""Init position by cash and position_dict.
Parameters
----------
start_time :
the start time of backtest. It's for filling the initial value of stocks.
cash : float, optional
initial cash in account, by default 0
position_dict : Dict[
@@ -268,9 +265,9 @@ class Position(BasePosition):
# Otherwise the initial value
self.init_cash = cash
self.position = position_dict.copy()
for stock in self.position:
if isinstance(self.position[stock], int):
self.position[stock] = {"amount": self.position[stock]}
for stock, value in self.position.items():
if isinstance(value, int):
self.position[stock] = {"amount": value}
self.position["cash"] = cash
# If the stock price information is missing, the account value will not be calculated temporarily
@@ -279,21 +276,23 @@ class Position(BasePosition):
except KeyError:
pass
def fill_stock_value(self, start_time: Union[str, pd.Timestamp], freq: str, last_days: int = 30):
def fill_stock_value(self, start_time: Union[str, pd.Timestamp], freq: str, last_days: int = 30) -> None:
"""fill the stock value by the close price of latest last_days from qlib.
Parameters
----------
start_time :
the start time of backtest.
freq : str
Frequency
last_days : int, optional
the days to get the latest close price, by default 30.
"""
stock_list = []
for stock in self.position:
if not isinstance(self.position[stock], dict):
for stock, value in self.position.items():
if not isinstance(value, dict):
continue
if ("price" not in self.position[stock]) or (self.position[stock]["price"] is None):
if value.get("price", None) is None:
stock_list.append(stock)
if len(stock_list) == 0:
@@ -304,7 +303,12 @@ class Position(BasePosition):
price_end_time = start_time
price_start_time = start_time - timedelta(days=last_days)
price_df = D.features(
stock_list, ["$close"], price_start_time, price_end_time, freq=freq, disk_cache=True
stock_list,
["$close"],
price_start_time,
price_end_time,
freq=freq,
disk_cache=True,
).dropna()
price_dict = price_df.groupby(["instrument"]).tail(1).reset_index(level=1, drop=True)["$close"].to_dict()
@@ -316,7 +320,7 @@ class Position(BasePosition):
self.position[stock]["price"] = price_dict[stock]
self.position["now_account_value"] = self.calculate_value()
def _init_stock(self, stock_id, amount, price=None):
def _init_stock(self, stock_id: str, amount: float, price: float = None) -> None:
"""
initialization the stock in current position
@@ -334,7 +338,7 @@ class Position(BasePosition):
self.position[stock_id]["price"] = price
self.position[stock_id]["weight"] = 0 # update the weight in the end of the trade date
def _buy_stock(self, stock_id, trade_val, cost, trade_price):
def _buy_stock(self, stock_id: str, trade_val: float, cost: float, trade_price: float) -> None:
trade_amount = trade_val / trade_price
if stock_id not in self.position:
self._init_stock(stock_id=stock_id, amount=trade_amount, price=trade_price)
@@ -344,15 +348,16 @@ class Position(BasePosition):
self.position["cash"] -= trade_val + cost
def _sell_stock(self, stock_id, trade_val, cost, trade_price):
def _sell_stock(self, stock_id: str, trade_val: float, cost: float, trade_price: float) -> None:
trade_amount = trade_val / trade_price
if stock_id not in self.position:
raise KeyError("{} not in current position".format(stock_id))
else:
if np.isclose(self.position[stock_id]["amount"], trade_amount):
# Selling all the stocks
# we use np.isclose instead of abs(<the final amount>) <= 1e-5 because `np.isclose` consider both ralative amount and absolute amount
# Using abs(<the final amount>) <= 1e-5 will result in error when the amount is large
# we use np.isclose instead of abs(<the final amount>) <= 1e-5 because `np.isclose` consider both
# relative amount and absolute amount
# Using abs(<the final amount>) <= 1e-5 will result in error when the amount is large
self._del_stock(stock_id)
else:
# decrease the amount of stock
@@ -361,8 +366,10 @@ class Position(BasePosition):
if self.position[stock_id]["amount"] < -1e-5:
raise ValueError(
"only have {} {}, require {}".format(
self.position[stock_id]["amount"] + trade_amount, stock_id, trade_amount
)
self.position[stock_id]["amount"] + trade_amount,
stock_id,
trade_amount,
),
)
new_cash = trade_val - cost
@@ -373,13 +380,13 @@ class Position(BasePosition):
else:
raise NotImplementedError(f"This type of input is not supported")
def _del_stock(self, stock_id):
def _del_stock(self, stock_id: str) -> None:
del self.position[stock_id]
def check_stock(self, stock_id):
def check_stock(self, stock_id: str) -> bool:
return stock_id in self.position
def update_order(self, order, trade_val, cost, trade_price):
def update_order(self, order: Order, trade_val: float, cost: float, trade_price: float) -> None:
# handle order, order is a order class, defined in exchange.py
if order.direction == Order.BUY:
# BUY
@@ -390,54 +397,54 @@ class Position(BasePosition):
else:
raise NotImplementedError("do not support order direction {}".format(order.direction))
def update_stock_price(self, stock_id, price):
def update_stock_price(self, stock_id: str, price: float) -> None:
self.position[stock_id]["price"] = price
def update_stock_count(self, stock_id, bar, count):
def update_stock_count(self, stock_id: str, bar: str, count: float) -> None: # TODO: check type of `bar`
self.position[stock_id][f"count_{bar}"] = count
def update_stock_weight(self, stock_id, weight):
def update_stock_weight(self, stock_id: str, weight: float) -> None:
self.position[stock_id]["weight"] = weight
def calculate_stock_value(self):
def calculate_stock_value(self) -> float:
stock_list = self.get_stock_list()
value = 0
for stock_id in stock_list:
value += self.position[stock_id]["amount"] * self.position[stock_id]["price"]
return value
def calculate_value(self):
def calculate_value(self) -> float:
value = self.calculate_stock_value()
value += self.position["cash"] + self.position.get("cash_delay", 0.0)
return value
def get_stock_list(self):
def get_stock_list(self) -> List[str]:
stock_list = list(set(self.position.keys()) - {"cash", "now_account_value", "cash_delay"})
return stock_list
def get_stock_price(self, code):
def get_stock_price(self, code: str) -> float:
return self.position[code]["price"]
def get_stock_amount(self, code):
def get_stock_amount(self, code: str) -> float:
return self.position[code]["amount"] if code in self.position else 0
def get_stock_count(self, code, bar):
def get_stock_count(self, code: str, bar: str) -> float:
"""the days the account has been hold, it may be used in some special strategies"""
if f"count_{bar}" in self.position[code]:
return self.position[code][f"count_{bar}"]
else:
return 0
def get_stock_weight(self, code):
def get_stock_weight(self, code: str) -> float:
return self.position[code]["weight"]
def get_cash(self, include_settle=False):
def get_cash(self, include_settle: bool = False) -> float:
cash = self.position["cash"]
if include_settle:
cash += self.position.get("cash_delay", 0.0)
return cash
def get_stock_amount_dict(self):
def get_stock_amount_dict(self) -> dict:
"""generate stock amount dict {stock_id : amount of stock}"""
d = {}
stock_list = self.get_stock_list()
@@ -445,7 +452,7 @@ class Position(BasePosition):
d[stock_code] = self.get_stock_amount(code=stock_code)
return d
def get_stock_weight_dict(self, only_stock=False):
def get_stock_weight_dict(self, only_stock: bool = False) -> dict:
"""get_stock_weight_dict
generate stock weight dict {stock_id : value weight of stock in the position}
it is meaningful in the beginning or the end of each trade date
@@ -463,7 +470,7 @@ class Position(BasePosition):
d[stock_code] = self.position[stock_code]["amount"] * self.position[stock_code]["price"] / position_value
return d
def add_count_all(self, bar):
def add_count_all(self, bar: str) -> None:
stock_list = self.get_stock_list()
for code in stock_list:
if f"count_{bar}" in self.position[code]:
@@ -471,18 +478,18 @@ class Position(BasePosition):
else:
self.position[code][f"count_{bar}"] = 1
def update_weight_all(self):
def update_weight_all(self) -> None:
weight_dict = self.get_stock_weight_dict()
for stock_code, weight in weight_dict.items():
self.update_stock_weight(stock_code, weight)
def settle_start(self, settle_type):
def settle_start(self, settle_type: str) -> None:
assert self._settle_type == self.ST_NO, "Currently, settlement can't be nested!!!!!"
self._settle_type = settle_type
if settle_type == self.ST_CASH:
self.position["cash_delay"] = 0.0
def settle_commit(self):
def settle_commit(self) -> None:
if self._settle_type != self.ST_NO:
if self._settle_type == self.ST_CASH:
self.position["cash"] += self.position["cash_delay"]
@@ -507,10 +514,10 @@ class InfPosition(BasePosition):
# InfPosition always have any stocks
return True
def update_order(self, order: Order, trade_val: float, cost: float, trade_price: float):
def update_order(self, order: Order, trade_val: float, cost: float, trade_price: float) -> None:
pass
def update_stock_price(self, stock_id, price: float):
def update_stock_price(self, stock_id: str, price: float) -> None:
pass
def calculate_stock_value(self) -> float:
@@ -522,33 +529,36 @@ class InfPosition(BasePosition):
"""
return np.inf
def get_stock_list(self) -> List:
def calculate_value(self) -> float:
raise NotImplementedError(f"InfPosition doesn't support calculating value")
def get_stock_list(self) -> List[str]:
raise NotImplementedError(f"InfPosition doesn't support stock list position")
def get_stock_price(self, code) -> float:
def get_stock_price(self, code: str) -> float:
"""the price of the inf position is meaningless"""
return np.nan
def get_stock_amount(self, code) -> float:
def get_stock_amount(self, code: str) -> float:
return np.inf
def get_cash(self, include_settle=False) -> float:
def get_cash(self, include_settle: bool = False) -> float:
return np.inf
def get_stock_amount_dict(self) -> Dict:
def get_stock_amount_dict(self) -> dict:
raise NotImplementedError(f"InfPosition doesn't support get_stock_amount_dict")
def get_stock_weight_dict(self, only_stock: bool = False) -> Dict:
def get_stock_weight_dict(self, only_stock: bool = False) -> dict:
raise NotImplementedError(f"InfPosition doesn't support get_stock_weight_dict")
def add_count_all(self, bar):
def add_count_all(self, bar: str) -> None:
raise NotImplementedError(f"InfPosition doesn't support add_count_all")
def update_weight_all(self):
def update_weight_all(self) -> None:
raise NotImplementedError(f"InfPosition doesn't support update_weight_all")
def settle_start(self, settle_type: str):
def settle_start(self, settle_type: str) -> None:
pass
def settle_commit(self):
def settle_commit(self) -> None:
pass

View File

@@ -4,14 +4,16 @@
This module is not well maintained.
"""
import numpy as np
import pandas as pd
from .position import Position
from ..data import D
from ..config import C
import datetime
from pathlib import Path
import numpy as np
import pandas as pd
from ..config import C
from ..data import D
from .position import Position
def get_benchmark_weight(
bench,
@@ -214,7 +216,9 @@ def get_stock_group(stock_group_field_df, bench_stock_weight_df, group_method, g
for idx, row in (~bench_stock_weight_df.isna()).iterrows():
bench_values = stock_group_field_df.loc[idx, row[row].index]
new_stock_group_df.loc[idx] = get_daily_bin_group(
bench_values, stock_group_field_df.loc[idx], group_n=group_n
bench_values,
stock_group_field_df.loc[idx],
group_n=group_n,
)
return new_stock_group_df
@@ -315,7 +319,7 @@ def brinson_pa(
# The excess profit from the interaction of assets allocation and stocks selection
"RIN": Q4 - Q3 - Q2 + Q1,
"RTotal": Q4 - Q1, # The totoal excess profit
}
},
),
{
"port_group_ret": port_group_ret_df,

View File

@@ -2,19 +2,20 @@
# Licensed under the MIT License.
from collections import OrderedDict
import pathlib
from typing import Dict, List, Tuple, Union
from collections import OrderedDict
from typing import Any, Dict, List, Optional, Text, Tuple, Type, Union, cast
import numpy as np
import pandas as pd
from qlib.backtest.exchange import Exchange
import qlib.utils.index_data as idd
from qlib.backtest.decision import BaseTradeDecision, Order, OrderDir
from .high_performance_ds import BaseOrderIndicator, NumpyOrderIndicator, SingleMetric
from qlib.backtest.exchange import Exchange
from ..tests.config import CSI300_BENCH
from ..utils.resam import get_higher_eq_freq_feature, resam_ts_data
import qlib.utils.index_data as idd
from .high_performance_ds import BaseOrderIndicator, BaseSingleMetric, NumpyOrderIndicator
class PortfolioMetrics:
@@ -37,7 +38,7 @@ class PortfolioMetrics:
update report
"""
def __init__(self, freq: str = "day", benchmark_config: dict = {}):
def __init__(self, freq: str = "day", benchmark_config: dict = {}) -> None:
"""
Parameters
----------
@@ -48,13 +49,17 @@ class PortfolioMetrics:
- benchmark : Union[str, list, pd.Series]
- If `benchmark` is pd.Series, `index` is trading date; the value T is the change from T-1 to T.
example:
print(D.features(D.instruments('csi500'), ['$close/Ref($close, 1)-1'])['$close/Ref($close, 1)-1'].head())
print(
D.features(D.instruments('csi500'),
['$close/Ref($close, 1)-1'])['$close/Ref($close, 1)-1'].head()
)
2017-01-04 0.011693
2017-01-05 0.000721
2017-01-06 -0.004322
2017-01-09 0.006874
2017-01-10 -0.003350
- If `benchmark` is list, will use the daily average change of the stock pool in the list as the 'bench'.
- If `benchmark` is list, will use the daily average change of the stock pool in the list as the
'bench'.
- If `benchmark` is str, will use the daily change as the 'bench'.
benchmark code, default is SH000300 CSI300
- start_time : Union[str, pd.Timestamp], optional
@@ -69,25 +74,26 @@ class PortfolioMetrics:
self.init_vars()
self.init_bench(freq=freq, benchmark_config=benchmark_config)
def init_vars(self):
self.accounts = OrderedDict() # account position value for each trade time
self.returns = OrderedDict() # daily return rate for each trade time
self.total_turnovers = OrderedDict() # total turnover for each trade time
self.turnovers = OrderedDict() # turnover for each trade time
self.total_costs = OrderedDict() # total trade cost for each trade time
self.costs = OrderedDict() # trade cost rate for each trade time
self.values = OrderedDict() # value for each trade time
self.cashes = OrderedDict()
self.benches = OrderedDict()
self.latest_pm_time = None # pd.TimeStamp
def init_vars(self) -> None:
self.accounts: dict = OrderedDict() # account position value for each trade time
self.returns: dict = OrderedDict() # daily return rate for each trade time
self.total_turnovers: dict = OrderedDict() # total turnover for each trade time
self.turnovers: dict = OrderedDict() # turnover for each trade time
self.total_costs: dict = OrderedDict() # total trade cost for each trade time
self.costs: dict = OrderedDict() # trade cost rate for each trade time
self.values: dict = OrderedDict() # value for each trade time
self.cashes: dict = OrderedDict()
self.benches: dict = OrderedDict()
self.latest_pm_time: Optional[pd.TimeStamp] = None
def init_bench(self, freq=None, benchmark_config=None):
def init_bench(self, freq: str = None, benchmark_config: dict = None) -> None:
if freq is not None:
self.freq = freq
self.benchmark_config = benchmark_config
self.bench = self._cal_benchmark(self.benchmark_config, self.freq)
def _cal_benchmark(self, benchmark_config, freq):
@staticmethod
def _cal_benchmark(benchmark_config: Optional[dict], freq: str) -> Optional[pd.Series]:
if benchmark_config is None:
return None
benchmark = benchmark_config.get("benchmark", CSI300_BENCH)
@@ -109,7 +115,12 @@ class PortfolioMetrics:
raise ValueError(f"The benchmark {_codes} does not exist. Please provide the right benchmark")
return _temp_result.groupby(level="datetime")[_temp_result.columns.tolist()[0]].mean().fillna(0)
def _sample_benchmark(self, bench, trade_start_time, trade_end_time):
def _sample_benchmark(
self,
bench: pd.Series,
trade_start_time: Union[str, pd.Timestamp],
trade_end_time: Union[str, pd.Timestamp],
) -> Optional[float]:
if self.bench is None:
return None
@@ -119,35 +130,35 @@ class PortfolioMetrics:
_ret = resam_ts_data(bench, trade_start_time, trade_end_time, method=cal_change)
return 0.0 if _ret is None else _ret - 1
def is_empty(self):
def is_empty(self) -> bool:
return len(self.accounts) == 0
def get_latest_date(self):
def get_latest_date(self) -> pd.Timestamp:
return self.latest_pm_time
def get_latest_account_value(self):
def get_latest_account_value(self) -> float:
return self.accounts[self.latest_pm_time]
def get_latest_total_cost(self):
def get_latest_total_cost(self) -> Any:
return self.total_costs[self.latest_pm_time]
def get_latest_total_turnover(self):
def get_latest_total_turnover(self) -> Any:
return self.total_turnovers[self.latest_pm_time]
def update_portfolio_metrics_record(
self,
trade_start_time=None,
trade_end_time=None,
account_value=None,
cash=None,
return_rate=None,
total_turnover=None,
turnover_rate=None,
total_cost=None,
cost_rate=None,
stock_value=None,
bench_value=None,
):
trade_start_time: Union[str, pd.Timestamp] = None,
trade_end_time: Union[str, pd.Timestamp] = None,
account_value: float = None,
cash: float = None,
return_rate: float = None,
total_turnover: float = None,
turnover_rate: float = None,
total_cost: float = None,
cost_rate: float = None,
stock_value: float = None,
bench_value: float = None,
) -> None:
# check data
if None in [
trade_start_time,
@@ -161,7 +172,8 @@ class PortfolioMetrics:
stock_value,
]:
raise ValueError(
"None in [trade_start_time, account_value, cash, return_rate, total_turnover, turnover_rate, total_cost, cost_rate, stock_value]"
"None in [trade_start_time, account_value, cash, return_rate, total_turnover, turnover_rate, "
"total_cost, cost_rate, stock_value]",
)
if trade_end_time is None and bench_value is None:
@@ -183,7 +195,7 @@ class PortfolioMetrics:
self.latest_pm_time = trade_start_time
# finish pm update in each step
def generate_portfolio_metrics_dataframe(self):
def generate_portfolio_metrics_dataframe(self) -> pd.DataFrame:
pm = pd.DataFrame()
pm["account"] = pd.Series(self.accounts)
pm["return"] = pd.Series(self.returns)
@@ -197,19 +209,18 @@ class PortfolioMetrics:
pm.index.name = "datetime"
return pm
def save_portfolio_metrics(self, path):
def save_portfolio_metrics(self, path: str) -> None:
r = self.generate_portfolio_metrics_dataframe()
r.to_csv(path)
def load_portfolio_metrics(self, path):
def load_portfolio_metrics(self, path: str) -> None:
"""load pm from a file
should have format like
columns = ['account', 'return', 'total_turnover', 'turnover', 'cost', 'total_cost', 'value', 'cash', 'bench']
:param
path: str/ pathlib.Path()
"""
path = pathlib.Path(path)
with path.open("rb") as f:
with pathlib.Path(path).open("rb") as f:
r = pd.read_csv(f, index_col=0)
r.index = pd.DatetimeIndex(r.index)
@@ -259,30 +270,30 @@ class Indicator:
"""
def __init__(self, order_indicator_cls=NumpyOrderIndicator):
def __init__(self, order_indicator_cls: Type[BaseOrderIndicator] = NumpyOrderIndicator) -> None:
self.order_indicator_cls = order_indicator_cls
# order indicator is metrics for a single order for a specific step
self.order_indicator_his = OrderedDict()
self.order_indicator_his: dict = OrderedDict()
self.order_indicator: BaseOrderIndicator = self.order_indicator_cls()
# trade indicator is metrics for all orders for a specific step
self.trade_indicator_his = OrderedDict()
self.trade_indicator: Dict[str, float] = OrderedDict()
self.trade_indicator_his: dict = OrderedDict()
self.trade_indicator: Dict[str, Optional[BaseSingleMetric]] = OrderedDict()
self._trade_calendar = None
# def reset(self, trade_calendar: TradeCalendarManager):
def reset(self):
self.order_indicator: BaseOrderIndicator = self.order_indicator_cls()
def reset(self) -> None:
self.order_indicator = self.order_indicator_cls()
self.trade_indicator = OrderedDict()
# self._trade_calendar = trade_calendar
def record(self, trade_start_time):
def record(self, trade_start_time: Union[str, pd.Timestamp]) -> None:
self.order_indicator_his[trade_start_time] = self.get_order_indicator()
self.trade_indicator_his[trade_start_time] = self.get_trade_indicator()
def _update_order_trade_info(self, trade_info: list):
def _update_order_trade_info(self, trade_info: List[Tuple[Order, float, float, float]]) -> None:
amount = dict()
deal_amount = dict()
trade_price = dict()
@@ -311,7 +322,7 @@ class Indicator:
self.order_indicator.assign("trade_dir", trade_dir)
self.order_indicator.assign("pa", pa)
def _update_order_fulfill_rate(self):
def _update_order_fulfill_rate(self) -> None:
def func(deal_amount, amount):
# deal_amount is np.NaN or None when there is no inner decision. So full fill rate is 0.
tmp_deal_amount = deal_amount.reindex(amount.index, 0)
@@ -320,11 +331,11 @@ class Indicator:
self.order_indicator.transfer(func, "ffr")
def update_order_indicators(self, trade_info: list):
def update_order_indicators(self, trade_info: List[Tuple[Order, float, float, float]]) -> None:
self._update_order_trade_info(trade_info=trade_info)
self._update_order_fulfill_rate()
def _agg_order_trade_info(self, inner_order_indicators: List[Dict[str, pd.Series]]):
def _agg_order_trade_info(self, inner_order_indicators: List[BaseOrderIndicator]) -> None:
# calculate total trade amount with each inner order indicator.
def trade_amount_func(deal_amount, trade_price):
return deal_amount * trade_price
@@ -335,7 +346,10 @@ class Indicator:
# sum inner order indicators with same metric.
all_metric = ["inner_amount", "deal_amount", "trade_price", "trade_value", "trade_cost", "trade_dir"]
self.order_indicator_cls.sum_all_indicators(
self.order_indicator, inner_order_indicators, all_metric, fill_value=0
self.order_indicator,
inner_order_indicators,
all_metric,
fill_value=0,
)
def func(trade_price, deal_amount):
@@ -350,9 +364,9 @@ class Indicator:
self.order_indicator.transfer(func_apply, "trade_dir")
def _update_trade_amount(self, outer_trade_decision: BaseTradeDecision):
def _update_trade_amount(self, outer_trade_decision: BaseTradeDecision) -> None:
# NOTE: these indicator is designed for order execution, so the
decision: List[Order] = outer_trade_decision.get_decision()
decision: List[Order] = cast(List[Order], outer_trade_decision.get_decision())
if len(decision) == 0:
self.order_indicator.assign("amount", {})
else:
@@ -367,7 +381,7 @@ class Indicator:
decision: BaseTradeDecision,
trade_exchange: Exchange,
pa_config: dict = {},
):
) -> Tuple[Optional[float], Optional[float]]:
"""
Get the base volume and price information
All the base price values are rooted from this function
@@ -378,12 +392,17 @@ class Indicator:
if decision.trade_range is not None:
trade_start_time, trade_end_time = decision.trade_range.clip_time_range(
start_time=trade_start_time, end_time=trade_end_time
start_time=trade_start_time,
end_time=trade_end_time,
)
if price == "deal_price":
price_s = trade_exchange.get_deal_price(
inst, trade_start_time, trade_end_time, direction=direction, method=None
inst,
trade_start_time,
trade_end_time,
direction=direction,
method=None,
)
else:
raise NotImplementedError(f"This type of input is not supported")
@@ -402,31 +421,35 @@ class Indicator:
# NOTE: there are some zeros in the trading price. These cases are known meaningless
# for aligning the previous logic, remove it.
# remove zero and negative values.
price_s = price_s.loc[(price_s > 1e-08).data.astype(np.bool)]
assert isinstance(price_s, idd.SingleData)
price_s = price_s.loc[(price_s > 1e-08).data.astype(bool)]
# NOTE ~(price_s < 1e-08) is different from price_s >= 1e-8
# ~(np.NaN < 1e-8) -> ~(False) -> True
assert isinstance(price_s, idd.SingleData)
if agg == "vwap":
volume_s = trade_exchange.get_volume(inst, trade_start_time, trade_end_time, method=None)
if isinstance(volume_s, (int, float, np.number)):
volume_s = idd.SingleData(volume_s, [trade_start_time])
assert isinstance(volume_s, idd.SingleData)
volume_s = volume_s.reindex(price_s.index)
elif agg == "twap":
volume_s = idd.SingleData(1, price_s.index)
else:
raise NotImplementedError(f"This type of input is not supported")
assert isinstance(volume_s, idd.SingleData)
base_volume = volume_s.sum()
base_price = (price_s * volume_s).sum() / base_volume
return base_price, base_volume
def _agg_base_price(
self,
inner_order_indicators: List[Dict[str, Union[SingleMetric, idd.SingleData]]],
inner_order_indicators: List[BaseOrderIndicator],
decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]],
trade_exchange: Exchange,
pa_config: dict = {},
):
) -> None:
"""
# NOTE:!!!!
# Strong assumption!!!!!!
@@ -434,7 +457,7 @@ class Indicator:
Parameters
----------
inner_order_indicators : List[Dict[str, pd.Series]]
inner_order_indicators : List[BaseOrderIndicator]
the indicators of account of inner executor
decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]],
a list of decisions according to inner_order_indicators
@@ -479,14 +502,17 @@ class Indicator:
bv_new = idd.SingleData(bv_new)
bp_all.append(bp_new)
bv_all.append(bv_new)
bp_all = idd.concat(bp_all, axis=1)
bv_all = idd.concat(bv_all, axis=1)
bp_all_multi_data = idd.concat(bp_all, axis=1)
bv_all_multi_data = idd.concat(bv_all, axis=1)
base_volume = bv_all.sum(axis=1)
base_volume = bv_all_multi_data.sum(axis=1)
self.order_indicator.assign("base_volume", base_volume.to_dict())
self.order_indicator.assign("base_price", ((bp_all * bv_all).sum(axis=1) / base_volume).to_dict())
self.order_indicator.assign(
"base_price",
((bp_all_multi_data * bv_all_multi_data).sum(axis=1) / base_volume).to_dict(),
)
def _agg_order_price_advantage(self):
def _agg_order_price_advantage(self) -> None:
def if_empty_func(trade_price):
return trade_price.empty
@@ -503,12 +529,12 @@ class Indicator:
def agg_order_indicators(
self,
inner_order_indicators: List[Dict[str, pd.Series]],
inner_order_indicators: List[BaseOrderIndicator],
decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]],
outer_trade_decision: BaseTradeDecision,
trade_exchange: Exchange,
indicator_config={},
):
indicator_config: dict = {},
) -> None:
self._agg_order_trade_info(inner_order_indicators)
self._update_trade_amount(outer_trade_decision)
self._update_order_fulfill_rate()
@@ -516,71 +542,66 @@ class Indicator:
self._agg_base_price(inner_order_indicators, decision_list, trade_exchange, pa_config=pa_config) # TODO
self._agg_order_price_advantage()
def _cal_trade_fulfill_rate(self, method="mean"):
def _cal_trade_fulfill_rate(self, method: str = "mean") -> Optional[BaseSingleMetric]:
if method == "mean":
def func(ffr):
return ffr.mean()
return self.order_indicator.transfer(
lambda ffr: ffr.mean(),
)
elif method == "amount_weighted":
def func(ffr, deal_amount):
return (ffr * deal_amount.abs()).sum() / (deal_amount.abs().sum())
return self.order_indicator.transfer(
lambda ffr, deal_amount: (ffr * deal_amount.abs()).sum() / (deal_amount.abs().sum()),
)
elif method == "value_weighted":
def func(ffr, trade_value):
return (ffr * trade_value.abs()).sum() / (trade_value.abs().sum())
return self.order_indicator.transfer(
lambda ffr, trade_value: (ffr * trade_value.abs()).sum() / (trade_value.abs().sum()),
)
else:
raise ValueError(f"method {method} is not supported!")
return self.order_indicator.transfer(func)
def _cal_trade_price_advantage(self, method="mean"):
def _cal_trade_price_advantage(self, method: str = "mean") -> Optional[BaseSingleMetric]:
if method == "mean":
def func(pa):
return pa.mean()
return self.order_indicator.transfer(lambda pa: pa.mean())
elif method == "amount_weighted":
def func(pa, deal_amount):
return (pa * deal_amount.abs()).sum() / (deal_amount.abs().sum())
return self.order_indicator.transfer(
lambda pa, deal_amount: (pa * deal_amount.abs()).sum() / (deal_amount.abs().sum()),
)
elif method == "value_weighted":
def func(pa, trade_value):
return (pa * trade_value.abs()).sum() / (trade_value.abs().sum())
return self.order_indicator.transfer(
lambda pa, trade_value: (pa * trade_value.abs()).sum() / (trade_value.abs().sum()),
)
else:
raise ValueError(f"method {method} is not supported!")
return self.order_indicator.transfer(func)
def _cal_trade_positive_rate(self):
def _cal_trade_positive_rate(self) -> Optional[BaseSingleMetric]:
def func(pa):
return (pa > 0).sum() / pa.count()
return self.order_indicator.transfer(func)
def _cal_deal_amount(self):
def _cal_deal_amount(self) -> Optional[BaseSingleMetric]:
def func(deal_amount):
return deal_amount.abs().sum()
return self.order_indicator.transfer(func)
def _cal_trade_value(self):
def _cal_trade_value(self) -> Optional[BaseSingleMetric]:
def func(trade_value):
return trade_value.abs().sum()
return self.order_indicator.transfer(func)
def _cal_trade_order_count(self):
def _cal_trade_order_count(self) -> Optional[BaseSingleMetric]:
def func(amount):
return amount.count()
return self.order_indicator.transfer(func)
def cal_trade_indicators(self, trade_start_time, freq, indicator_config={}):
def cal_trade_indicators(
self,
trade_start_time: Union[str, pd.Timestamp],
freq: str,
indicator_config: dict = {},
) -> None:
show_indicator = indicator_config.get("show_indicator", False)
ffr_config = indicator_config.get("ffr_config", {})
pa_config = indicator_config.get("pa_config", {})
@@ -598,18 +619,22 @@ class Indicator:
self.trade_indicator["count"] = order_count
if show_indicator:
print(
"[Indicator({}) {:%Y-%m-%d %H:%M:%S}]: FFR: {}, PA: {}, POS: {}".format(
freq, trade_start_time, fulfill_rate, price_advantage, positive_rate
)
"[Indicator({}) {}]: FFR: {}, PA: {}, POS: {}".format(
freq,
trade_start_time
if isinstance(trade_start_time, str)
else trade_start_time.strftime("%Y-%m-%d %H:%M:%S"),
fulfill_rate,
price_advantage,
positive_rate,
),
)
def get_order_indicator(self, raw: bool = True):
if raw:
return self.order_indicator
return self.order_indicator.to_series()
def get_order_indicator(self, raw: bool = True) -> Union[BaseOrderIndicator, Dict[Text, pd.Series]]:
return self.order_indicator if raw else self.order_indicator.to_series()
def get_trade_indicator(self):
def get_trade_indicator(self) -> Dict[str, Optional[BaseSingleMetric]]:
return self.trade_indicator
def generate_trade_indicators_dataframe(self):
def generate_trade_indicators_dataframe(self) -> pd.DataFrame:
return pd.DataFrame.from_dict(self.trade_indicator_his, orient="index")

View File

@@ -1,13 +1,16 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from qlib.utils import init_instance_by_config
import abc
from typing import Dict, List, Text, Tuple, Union
from ..model.base import BaseModel
import pandas as pd
from qlib.utils import init_instance_by_config
from ..data.dataset import Dataset
from ..data.dataset.utils import convert_index_format
from ..model.base import BaseModel
from ..utils.resam import resam_ts_data
import pandas as pd
import abc
class Signal(metaclass=abc.ABCMeta):
@@ -19,7 +22,7 @@ class Signal(metaclass=abc.ABCMeta):
"""
@abc.abstractmethod
def get_signal(self, start_time, end_time) -> Union[pd.Series, pd.DataFrame, None]:
def get_signal(self, start_time: pd.Timestamp, end_time: pd.Timestamp) -> Union[pd.Series, pd.DataFrame, None]:
"""
get the signal at the end of the decision step(from `start_time` to `end_time`)
@@ -36,13 +39,14 @@ class SignalWCache(Signal):
SignalWCache will store the prepared signal as a attribute and give the according signal based on input query
"""
def __init__(self, signal: Union[pd.Series, pd.DataFrame]):
def __init__(self, signal: Union[pd.Series, pd.DataFrame]) -> None:
"""
Parameters
----------
signal : Union[pd.Series, pd.DataFrame]
The expected format of the signal is like the data below (the order of index is not important and can be automatically adjusted)
The expected format of the signal is like the data below (the order of index is not important and can be
automatically adjusted)
instrument datetime
SH600000 2008-01-02 0.079704
@@ -53,8 +57,8 @@ class SignalWCache(Signal):
"""
self.signal_cache = convert_index_format(signal, level="datetime")
def get_signal(self, start_time, end_time) -> Union[pd.Series, pd.DataFrame]:
# the frequency of the signal may not algin with the decision frequency of strategy
def get_signal(self, start_time: pd.Timestamp, end_time: pd.Timestamp) -> Union[pd.Series, pd.DataFrame]:
# the frequency of the signal may not align with the decision frequency of strategy
# so resampling from the data is necessary
# the latest signal leverage more recent data and therefore is used in trading.
signal = resam_ts_data(self.signal_cache, start_time=start_time, end_time=end_time, method="last")
@@ -62,7 +66,7 @@ class SignalWCache(Signal):
class ModelSignal(SignalWCache):
def __init__(self, model: BaseModel, dataset: Dataset):
def __init__(self, model: BaseModel, dataset: Dataset) -> None:
self.model = model
self.dataset = dataset
pred_scores = self.model.predict(dataset)
@@ -70,7 +74,7 @@ class ModelSignal(SignalWCache):
pred_scores = pred_scores.iloc[:, 0]
super().__init__(pred_scores)
def _update_model(self):
def _update_model(self) -> None:
"""
When using online data, update model in each bar as the following steps:
- update dataset with online data, the dataset should support online update
@@ -82,7 +86,7 @@ class ModelSignal(SignalWCache):
def create_signal_from(
obj: Union[Signal, Tuple[BaseModel, Dataset], List, Dict, Text, pd.Series, pd.DataFrame]
obj: Union[Signal, Tuple[BaseModel, Dataset], List, Dict, Text, pd.Series, pd.DataFrame],
) -> Signal:
"""
create signal from diverse information

View File

@@ -2,16 +2,22 @@
# Licensed under the MIT License.
from __future__ import annotations
import bisect
from abc import abstractmethod
from typing import TYPE_CHECKING, Any, Set, Tuple, Union
import numpy as np
from qlib.utils.time import epsilon_change
from typing import TYPE_CHECKING, Tuple, Union
if TYPE_CHECKING:
from qlib.backtest.decision import BaseTradeDecision
import pandas as pd
import warnings
import pandas as pd
from ..data.data import Cal
@@ -26,8 +32,8 @@ class TradeCalendarManager:
freq: str,
start_time: Union[str, pd.Timestamp] = None,
end_time: Union[str, pd.Timestamp] = None,
level_infra: "LevelInfrastructure" = None,
):
level_infra: LevelInfrastructure = None,
) -> None:
"""
Parameters
----------
@@ -43,19 +49,26 @@ class TradeCalendarManager:
self.level_infra = level_infra
self.reset(freq=freq, start_time=start_time, end_time=end_time)
def reset(self, freq, start_time, end_time):
def reset(
self,
freq: str,
start_time: Union[str, pd.Timestamp] = None,
end_time: Union[str, pd.Timestamp] = None,
) -> None:
"""
Please refer to the docs of `__init__`
Reset the trade calendar
- self.trade_len : The total count for trading step
- self.trade_step : The number of trading step finished, self.trade_step can be [0, 1, 2, ..., self.trade_len - 1]
- self.trade_step : The number of trading step finished, self.trade_step can be
[0, 1, 2, ..., self.trade_len - 1]
"""
self.freq = freq
self.start_time = pd.Timestamp(start_time) if start_time else None
self.end_time = pd.Timestamp(end_time) if end_time else None
_calendar = Cal.calendar(freq=freq, future=True)
assert isinstance(_calendar, np.ndarray)
self._calendar = _calendar
_, _, _start_index, _end_index = Cal.locate_index(start_time, end_time, freq=freq, future=True)
self.start_index = _start_index
@@ -63,7 +76,7 @@ class TradeCalendarManager:
self.trade_len = _end_index - _start_index + 1
self.trade_step = 0
def finished(self):
def finished(self) -> bool:
"""
Check if the trading finished
- Should check before calling strategy.generate_decisions and executor.execute
@@ -72,29 +85,32 @@ class TradeCalendarManager:
"""
return self.trade_step >= self.trade_len
def step(self):
def step(self) -> None:
if self.finished():
raise RuntimeError(f"The calendar is finished, please reset it if you want to call it!")
self.trade_step = self.trade_step + 1
self.trade_step += 1
def get_freq(self):
def get_freq(self) -> str:
return self.freq
def get_trade_len(self):
def get_trade_len(self) -> int:
"""get the total step length"""
return self.trade_len
def get_trade_step(self):
def get_trade_step(self) -> int:
return self.trade_step
def get_step_time(self, trade_step=None, shift=0):
def get_step_time(self, trade_step: int = None, shift: int = 0) -> Tuple[pd.Timestamp, pd.Timestamp]:
"""
Get the left and right endpoints of the trade_step'th trading interval
About the endpoints:
- Qlib uses the closed interval in time-series data selection, which has the same performance as pandas.Series.loc
# - The returned right endpoints should minus 1 seconds because of the closed interval representation in Qlib.
# Note: Qlib supports up to minutely decision execution, so 1 seconds is less than any trading time interval.
- Qlib uses the closed interval in time-series data selection, which has the same performance as
pandas.Series.loc
# - The returned right endpoints should minus 1 seconds because of the closed interval representation in
# Qlib.
# Note: Qlib supports up to minutely decision execution, so 1 seconds is less than any trading time
# interval.
Parameters
----------
@@ -105,15 +121,14 @@ class TradeCalendarManager:
Returns
-------
Tuple[pd.Timestamp, pd.Timestap]
Tuple[pd.Timestamp, pd.Timestamp]
- If shift == 0, return the trading time range
- If shift > 0, return the trading time range of the earlier shift bars
- If shift < 0, return the trading time range of the later shift bar
"""
if trade_step is None:
trade_step = self.get_trade_step()
trade_step = trade_step - shift
calendar_index = self.start_index + trade_step
calendar_index = self.start_index + trade_step - shift
return self._calendar[calendar_index], epsilon_change(self._calendar[calendar_index + 1])
def get_data_cal_range(self, rtype: str = "full") -> Tuple[int, int]:
@@ -126,7 +141,7 @@ class TradeCalendarManager:
Parameters
----------
rtype: str
- "full": return the full limitation of the deicsion in the day
- "full": return the full limitation of the decision in the day
- "step": return the limitation of current step
Returns
@@ -134,6 +149,8 @@ class TradeCalendarManager:
Tuple[int, int]:
"""
# potential performance issue
assert self.level_infra is not None
day_start = pd.Timestamp(self.start_time.date())
day_end = epsilon_change(day_start + pd.Timedelta(days=1))
freq = self.level_infra.get("common_infra").get("trade_exchange").freq
@@ -148,7 +165,7 @@ class TradeCalendarManager:
return start_idx - day_start_idx, end_index - day_start_idx
def get_all_time(self):
def get_all_time(self) -> Tuple[pd.Timestamp, pd.Timestamp]:
"""Get the start_time and end_time for trading"""
return self.start_time, self.end_time
@@ -167,30 +184,33 @@ class TradeCalendarManager:
Tuple[int, int]:
the index of the range. **the left and right are closed**
"""
left, right = (
bisect.bisect_right(self._calendar, start_time) - 1,
bisect.bisect_right(self._calendar, end_time) - 1,
)
left = bisect.bisect_right(list(self._calendar), start_time) - 1
right = bisect.bisect_right(list(self._calendar), end_time) - 1
left -= self.start_index
right -= self.start_index
def clip(idx):
def clip(idx: int) -> int:
return min(max(0, idx), self.trade_len - 1)
return clip(left), clip(right)
def __repr__(self) -> str:
return f"class: {self.__class__.__name__}; {self.start_time}[{self.start_index}]~{self.end_time}[{self.end_index}]: [{self.trade_step}/{self.trade_len}]"
return (
f"class: {self.__class__.__name__}; "
f"{self.start_time}[{self.start_index}]~{self.end_time}[{self.end_index}]: "
f"[{self.trade_step}/{self.trade_len}]"
)
class BaseInfrastructure:
def __init__(self, **kwargs):
def __init__(self, **kwargs: Any) -> None:
self.reset_infra(**kwargs)
def get_support_infra(self):
@abstractmethod
def get_support_infra(self) -> Set[str]:
raise NotImplementedError("`get_support_infra` is not implemented!")
def reset_infra(self, **kwargs):
def reset_infra(self, **kwargs: Any) -> None:
support_infra = self.get_support_infra()
for k, v in kwargs.items():
if k in support_infra:
@@ -198,53 +218,58 @@ class BaseInfrastructure:
else:
warnings.warn(f"{k} is ignored in `reset_infra`!")
def get(self, infra_name):
def get(self, infra_name: str) -> Any:
if hasattr(self, infra_name):
return getattr(self, infra_name)
else:
warnings.warn(f"infra {infra_name} is not found!")
def has(self, infra_name):
def has(self, infra_name: str) -> bool:
return infra_name in self.get_support_infra() and hasattr(self, infra_name)
def update(self, other):
def update(self, other: BaseInfrastructure) -> None:
support_infra = other.get_support_infra()
infra_dict = {_infra: getattr(other, _infra) for _infra in support_infra if hasattr(other, _infra)}
self.reset_infra(**infra_dict)
class CommonInfrastructure(BaseInfrastructure):
def get_support_infra(self):
return ["trade_account", "trade_exchange"]
def get_support_infra(self) -> Set[str]:
return {"trade_account", "trade_exchange"}
class LevelInfrastructure(BaseInfrastructure):
"""level infrastructure is created by executor, and then shared to strategies on the same level"""
def get_support_infra(self):
def get_support_infra(self) -> Set[str]:
"""
Descriptions about the infrastructure
sub_level_infra:
- **NOTE**: this will only work after _init_sub_trading !!!
"""
return ["trade_calendar", "sub_level_infra", "common_infra"]
return {"trade_calendar", "sub_level_infra", "common_infra"}
def reset_cal(self, freq, start_time, end_time):
def reset_cal(
self,
freq: str,
start_time: Union[str, pd.Timestamp, None],
end_time: Union[str, pd.Timestamp, None],
) -> None:
"""reset trade calendar manager"""
if self.has("trade_calendar"):
self.get("trade_calendar").reset(freq, start_time=start_time, end_time=end_time)
else:
self.reset_infra(
trade_calendar=TradeCalendarManager(freq, start_time=start_time, end_time=end_time, level_infra=self)
trade_calendar=TradeCalendarManager(freq, start_time=start_time, end_time=end_time, level_infra=self),
)
def set_sub_level_infra(self, sub_level_infra: LevelInfrastructure):
"""this will make the calendar access easier when acrossing multi-levels"""
def set_sub_level_infra(self, sub_level_infra: LevelInfrastructure) -> None:
"""this will make the calendar access easier when crossing multi-levels"""
self.reset_infra(sub_level_infra=sub_level_infra)
def get_start_end_idx(trade_calendar: TradeCalendarManager, outer_trade_decision: BaseTradeDecision) -> Union[int, int]:
def get_start_end_idx(trade_calendar: TradeCalendarManager, outer_trade_decision: BaseTradeDecision) -> Tuple[int, int]:
"""
A helper function for getting the decision-level index range limitation for inner strategy
- NOTE: this function is not applicable to order-level

View File

@@ -75,6 +75,17 @@ class Config:
def set_conf_from_C(self, config_c):
self.update(**config_c.__dict__["_config"])
def register_from_C(self, config, skip_register=True):
from .utils import set_log_with_config # pylint: disable=C0415
if C.registered and skip_register:
return
C.set_conf_from_C(config)
if C.logging_config:
set_log_with_config(C.logging_config)
C.register()
# pickle.dump protocol version: https://docs.python.org/3/library/pickle.html#data-stream-format
PROTOCOL_VERSION = 4
@@ -102,7 +113,7 @@ _default_config = {
# "~/.qlib/stock_data/cn_data"
# # dict
# {"day": "~/.qlib/stock_data/cn_data", "1min": "~/.qlib/stock_data/cn_data_1min"}
# NOTE: provider_uri priority
# NOTE: provider_uri priority:
# 1. backend_config: backend_obj["kwargs"]["provider_uri"]
# 2. backend_config: backend_obj["kwargs"]["provider_uri_map"]
# 3. qlib.init: provider_uri

View File

@@ -63,11 +63,20 @@ def _get_date_parse_fn(target):
get_date_parse_fn(20120101)('2017-01-01') => 20170101
"""
if isinstance(target, int):
_fn = lambda x: int(str(x).replace("-", "")[:8]) # 20200201
def _fn(x):
return int(str(x).replace("-", "")[:8]) # 20200201
elif isinstance(target, str) and len(target) == 8:
_fn = lambda x: str(x).replace("-", "")[:8] # '20200201'
def _fn(x):
return str(x).replace("-", "")[:8] # '20200201'
else:
_fn = lambda x: x # '2021-01-01'
def _fn(x):
return x # '2021-01-01'
return _fn

View File

@@ -255,7 +255,10 @@ class Alpha158(DataHandlerLP):
exclude = config["rolling"].get("exclude", [])
# `exclude` in dataset config unnecessary filed
# `include` in dataset config necessary field
use = lambda x: x not in exclude and (include is None or x in include)
def use(x):
return x not in exclude and (include is None or x in include)
if use("ROC"):
fields += ["Ref($close, %d)/$close" % d for d in windows]
names += ["ROC%d" % d for d in windows]

View File

@@ -48,7 +48,9 @@ def calc_long_short_prec(
group = df.groupby(level=date_col)
N = lambda x: int(len(x) * quantile)
def N(x):
return int(len(x) * quantile)
# find the top/low quantile of prediction and treat them as long and short target
long = group.apply(lambda x: x.nlargest(N(x), columns="pred").label).reset_index(level=0, drop=True)
short = group.apply(lambda x: x.nsmallest(N(x), columns="pred").label).reset_index(level=0, drop=True)
@@ -98,7 +100,10 @@ def calc_long_short_return(
if dropna:
df.dropna(inplace=True)
group = df.groupby(level=date_col)
N = lambda x: int(len(x) * quantile)
def N(x):
return int(len(x) * quantile)
r_long = group.apply(lambda x: x.nlargest(N(x), columns="pred").label.mean())
r_short = group.apply(lambda x: x.nsmallest(N(x), columns="pred").label.mean())
r_avg = group.label.mean()

View File

@@ -26,6 +26,13 @@ logger = get_module_logger("Evaluate")
def risk_analysis(r, N: int = None, freq: str = "day"):
"""Risk Analysis
NOTE:
The calculation of annulaized return is different from the definition of annualized return.
It is implemented by design.
Qlib tries to cumulated returns by summation instead of production to avoid the cumulated curve being skewed exponentially.
All the calculation of annualized returns follows this principle in Qlib.
TODO: add a parameter to enable calculating metrics with production accumulation of return.
Parameters
----------
@@ -332,7 +339,7 @@ def long_short_backtest(
for stock in long_stocks:
if not trade_exchange.is_stock_tradable(stock_id=stock, trade_date=date):
continue
profit = trade_exchange.get_quote_info(stock_id=stock, trade_date=date)[profit_str]
profit = trade_exchange.get_quote_info(stock_id=stock, start_time=date, end_time=date, field=profit_str)
if np.isnan(profit):
long_profit.append(0)
else:
@@ -341,17 +348,17 @@ def long_short_backtest(
for stock in short_stocks:
if not trade_exchange.is_stock_tradable(stock_id=stock, trade_date=date):
continue
profit = trade_exchange.get_quote_info(stock_id=stock, trade_date=date)[profit_str]
profit = trade_exchange.get_quote_info(stock_id=stock, start_time=date, end_time=date, field=profit_str)
if np.isnan(profit):
short_profit.append(0)
else:
short_profit.append(-profit)
short_profit.append(profit * -1)
for stock in list(score.loc(axis=0)[pdate, :].index.get_level_values(level=0)):
# exclude the suspend stock
if trade_exchange.check_stock_suspended(stock_id=stock, trade_date=date):
continue
profit = trade_exchange.get_quote_info(stock_id=stock, trade_date=date)[profit_str]
profit = trade_exchange.get_quote_info(stock_id=stock, start_time=date, end_time=date, field=profit_str)
if np.isnan(profit):
all_profit.append(0)
else:

View File

@@ -217,7 +217,7 @@ class MetaDatasetDS(MetaTaskDataset):
----------
task_tpl : Union[dict, list]
Decide what tasks are used.
- dict : the task template the prepared task is generated with `step`, `trunc_days` and `RollingGen`
- dict : the task template, the prepared task is generated with `step`, `trunc_days` and `RollingGen`
- list : when list, use the list of tasks directly
the list is supposed to be sorted according timeline
step : int
@@ -290,7 +290,7 @@ class MetaDatasetDS(MetaTaskDataset):
ic_df = self.internal_data.data_ic_df
segs = task["dataset"]["kwargs"]["segments"]
end = max([segs[k][1] for k in ("train", "valid") if k in segs])
end = max(segs[k][1] for k in ("train", "valid") if k in segs)
ic_df_avail = ic_df.loc[:end, pd.IndexSlice[:, :end]]
# meta data set focus on the **information** instead of preprocess

View File

@@ -92,7 +92,10 @@ class HFLGBModel(ModelFT, LightGBMFInt):
# Convert label into alpha
df_train["label"][l_name] = df_train["label"][l_name] - df_train["label"][l_name].mean(level=0)
df_valid["label"][l_name] = df_valid["label"][l_name] - df_valid["label"][l_name].mean(level=0)
mapping_fn = lambda x: 0 if x < 0 else 1
def mapping_fn(x):
return 0 if x < 0 else 1
df_train["label_c"] = df_train["label"][l_name].apply(mapping_fn)
df_valid["label_c"] = df_valid["label"][l_name].apply(mapping_fn)
x_train, y_train = df_train["feature"], df_train["label_c"].values

View File

@@ -292,7 +292,9 @@ class HIST(Model):
pretrained_model.load_state_dict(torch.load(self.model_path))
model_dict = self.HIST_model.state_dict()
pretrained_dict = {k: v for k, v in pretrained_model.state_dict().items() if k in model_dict}
pretrained_dict = {
k: v for k, v in pretrained_model.state_dict().items() if k in model_dict # pylint: disable=E1135
}
model_dict.update(pretrained_dict)
self.HIST_model.load_state_dict(model_dict)
self.logger.info("Loading pretrained model Done...")

View File

@@ -53,7 +53,7 @@ class TabnetModel(Model):
"""
TabNet model for Qlib
Args
Args:
ps: probability to generate the bernoulli mask
"""
# set hyper-parameters.

View File

@@ -167,8 +167,8 @@ class TRAModel(Model):
for param in self.tra.predictors.parameters():
param.requires_grad_(False)
self.logger.info("# model params: %d" % sum([p.numel() for p in self.model.parameters() if p.requires_grad]))
self.logger.info("# tra params: %d" % sum([p.numel() for p in self.tra.parameters() if p.requires_grad]))
self.logger.info("# model params: %d" % sum(p.numel() for p in self.model.parameters() if p.requires_grad))
self.logger.info("# tra params: %d" % sum(p.numel() for p in self.tra.parameters() if p.requires_grad))
self.optimizer = optim.Adam(list(self.model.parameters()) + list(self.tra.parameters()), lr=self.lr)

View File

@@ -32,7 +32,6 @@ from ..utils import (
hash_args,
normalize_cache_fields,
code_to_fname,
set_log_with_config,
time_to_slc_point,
read_period_data,
get_period_list,
@@ -109,14 +108,16 @@ class CalendarProvider(abc.ABC):
_, _, si, ei = self.locate_index(start_time, end_time, freq, future)
return _calendar[si : ei + 1]
def locate_index(self, start_time, end_time, freq, future=False):
def locate_index(
self, start_time: Union[pd.Timestamp, str], end_time: Union[pd.Timestamp, str], freq: str, future: bool = False
):
"""Locate the start time index and end time index in a calendar under certain frequency.
Parameters
----------
start_time : str
start_time : pd.Timestamp
start of the time range.
end_time : str
end_time : pd.Timestamp
end of the time range.
freq : str
time frequency, available: year/quarter/month/week/day.
@@ -603,11 +604,7 @@ class DatasetProvider(abc.ABC):
"""
# FIXME: Windows OS or MacOS using spawn: https://docs.python.org/3.8/library/multiprocessing.html?highlight=spawn#contexts-and-start-methods
# NOTE: This place is compatible with windows, windows multi-process is spawn
if not C.registered:
C.set_conf_from_C(g_config)
if C.logging_config:
set_log_with_config(C.logging_config)
C.register()
C.register_from_C(g_config)
obj = dict()
for field in column_names:

View File

@@ -438,7 +438,7 @@ class TSDataSampler:
@property
def empty(self):
return self.__len__() == 0
return len(self) == 0
def _get_indices(self, row: int, col: int) -> np.array:
"""

View File

@@ -24,7 +24,7 @@ class FileStorageMixin:
"""
# NOTE: provider_uri priority
# NOTE: provider_uri priority:
# 1. self._provider_uri : if provider_uri is provided.
# 2. provider_uri in qlib.config.C
@@ -106,10 +106,7 @@ class FileCalendarStorage(FileStorageMixin, CalendarStorage):
if not self.uri.exists():
self._write_calendar(values=[])
with self.uri.open("rb") as fp:
return [
str(x)
for x in np.loadtxt(fp, str, skiprows=skip_rows, max_rows=n_rows, delimiter="\n", encoding="utf-8")
]
return [str(x) for x in np.loadtxt(fp, str, skiprows=skip_rows, max_rows=n_rows, encoding="utf-8")]
def _write_calendar(self, values: Iterable[CalVT], mode: str = "wb"):
with self.uri.open(mode=mode) as fp:

View File

@@ -8,6 +8,7 @@ Ensemble module can merge the objects in an Ensemble. For example, if there are
from typing import Union
import pandas as pd
from qlib.utils import FLATTEN_TUPLE, flatten_dict
from qlib.log import get_module_logger
class Ensemble:
@@ -79,6 +80,7 @@ class RollingEnsemble(Ensemble):
"""
def __call__(self, ensemble_dict: dict) -> pd.DataFrame:
get_module_logger("RollingEnsemble").info(f"keys in group: {list(ensemble_dict.keys())}")
artifact_list = list(ensemble_dict.values())
artifact_list.sort(key=lambda x: x.index.get_level_values("datetime").min())
artifact = pd.concat(artifact_list)
@@ -121,6 +123,7 @@ class AverageEnsemble(Ensemble):
"""
# need to flatten the nested dict
ensemble_dict = flatten_dict(ensemble_dict, sep=FLATTEN_TUPLE)
get_module_logger("AverageEnsemble").info(f"keys in group: {list(ensemble_dict.keys())}")
values = list(ensemble_dict.values())
# NOTE: this may change the style underlying data!!!!
# from pd.DataFrame to pd.Series

View File

@@ -12,16 +12,25 @@ In ``DelayTrainer``, the first step is only to save some necessary info to model
"""
import socket
from typing import Callable, List
from typing import Callable, List, Optional
from tqdm.auto import tqdm
from qlib.config import C
from qlib.data.dataset import Dataset
from qlib.data.dataset.weight import Reweighter
from qlib.log import get_module_logger
from qlib.model.base import Model
from qlib.utils import flatten_dict, init_instance_by_config, auto_filter_kwargs, fill_placeholder
from qlib.utils import (
auto_filter_kwargs,
fill_placeholder,
flatten_dict,
init_instance_by_config,
)
from qlib.utils.paral import call_in_subproc
from qlib.workflow import R
from qlib.workflow.recorder import Recorder
from qlib.workflow.task.manage import TaskManager, run_task
from qlib.data.dataset.weight import Reweighter
def _log_task_info(task_config: dict):
@@ -210,17 +219,26 @@ class TrainerR(Trainer):
STATUS_BEGIN = "begin_task_train"
STATUS_END = "end_task_train"
def __init__(self, experiment_name: str = None, train_func: Callable = task_train):
def __init__(
self,
experiment_name: Optional[str] = None,
train_func: Callable = task_train,
call_in_subproc: bool = False,
default_rec_name: Optional[str] = None,
):
"""
Init TrainerR.
Args:
experiment_name (str, optional): the default name of experiment.
train_func (Callable, optional): default training method. Defaults to `task_train`.
call_in_subproc (bool): call the process in subprocess to force memory release
"""
super().__init__()
self.experiment_name = experiment_name
self.default_rec_name = default_rec_name
self.train_func = train_func
self._call_in_subproc = call_in_subproc
def train(self, tasks: list, train_func: Callable = None, experiment_name: str = None, **kwargs) -> List[Recorder]:
"""
@@ -245,7 +263,10 @@ class TrainerR(Trainer):
experiment_name = self.experiment_name
recs = []
for task in tqdm(tasks, desc="train tasks"):
rec = train_func(task, experiment_name, **kwargs)
if self._call_in_subproc:
get_module_logger("TrainerR").info("running models in sub process (for forcing release memroy).")
train_func = call_in_subproc(train_func, C)
rec = train_func(task, experiment_name, recorder_name=self.default_rec_name, **kwargs)
rec.set_tags(**{self.STATUS_KEY: self.STATUS_BEGIN})
recs.append(rec)
return recs
@@ -272,7 +293,9 @@ class DelayTrainerR(TrainerR):
A delayed implementation based on TrainerR, which means `train` method may only do some preparation and `end_train` method can do the real model fitting.
"""
def __init__(self, experiment_name: str = None, train_func=begin_task_train, end_train_func=end_task_train):
def __init__(
self, experiment_name: str = None, train_func=begin_task_train, end_train_func=end_task_train, **kwargs
):
"""
Init TrainerRM.
@@ -281,7 +304,7 @@ class DelayTrainerR(TrainerR):
train_func (Callable, optional): default train method. Defaults to `begin_task_train`.
end_train_func (Callable, optional): default end_train method. Defaults to `end_task_train`.
"""
super().__init__(experiment_name, train_func)
super().__init__(experiment_name, train_func, **kwargs)
self.end_train_func = end_train_func
self.delay = True
@@ -330,7 +353,12 @@ class TrainerRM(Trainer):
TM_ID = "_id in TaskManager"
def __init__(
self, experiment_name: str = None, task_pool: str = None, train_func=task_train, skip_run_task: bool = False
self,
experiment_name: str = None,
task_pool: str = None,
train_func=task_train,
skip_run_task: bool = False,
default_rec_name: Optional[str] = None,
):
"""
Init TrainerR.
@@ -349,6 +377,7 @@ class TrainerRM(Trainer):
self.task_pool = task_pool
self.train_func = train_func
self.skip_run_task = skip_run_task
self.default_rec_name = default_rec_name
def train(
self,
@@ -357,6 +386,7 @@ class TrainerRM(Trainer):
experiment_name: str = None,
before_status: str = TaskManager.STATUS_WAITING,
after_status: str = TaskManager.STATUS_DONE,
default_rec_name: Optional[str] = None,
**kwargs,
) -> List[Recorder]:
"""
@@ -384,6 +414,8 @@ class TrainerRM(Trainer):
train_func = self.train_func
if experiment_name is None:
experiment_name = self.experiment_name
if default_rec_name is None:
default_rec_name = self.default_rec_name
task_pool = self.task_pool
if task_pool is None:
task_pool = experiment_name
@@ -398,6 +430,7 @@ class TrainerRM(Trainer):
experiment_name=experiment_name,
before_status=before_status,
after_status=after_status,
recorder_name=default_rec_name,
**kwargs,
)
@@ -466,6 +499,7 @@ class DelayTrainerRM(TrainerRM):
train_func=begin_task_train,
end_train_func=end_task_train,
skip_run_task: bool = False,
**kwargs,
):
"""
Init DelayTrainerRM.
@@ -480,7 +514,7 @@ class DelayTrainerRM(TrainerRM):
Only run_task in the worker. Otherwise skip run_task.
E.g. Starting trainer on a CPU VM and then waiting tasks to be finished on GPU VMs.
"""
super().__init__(experiment_name, task_pool, train_func)
super().__init__(experiment_name, task_pool, train_func, **kwargs)
self.end_train_func = end_train_func
self.delay = True
self.skip_run_task = skip_run_task

View File

@@ -248,7 +248,7 @@ def load_orders(
Order(
row["instrument"],
row["amount"],
int(row["order_type"]),
OrderDir(int(row["order_type"])),
row["datetime"].replace(hour=start_time.hour, minute=start_time.minute, second=start_time.second),
row["datetime"].replace(hour=end_time.hour, minute=end_time.minute, second=end_time.second),
)

View File

@@ -1,7 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Train, test, inference utilities.
The APIs in this directory are NOT considered final and are subject to change!
"""

View File

@@ -1,99 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
import copy
from typing import Callable, Sequence
from tianshou.data import Collector
from tianshou.policy import BasePolicy
from qlib.constant import INF
from qlib.log import get_module_logger
from qlib.rl.simulator import InitialStateType, Simulator
from qlib.rl.interpreter import StateInterpreter, ActionInterpreter
from qlib.rl.reward import Reward
from qlib.rl.utils import DataQueue, EnvWrapper, FiniteEnvType, LogCollector, LogWriter, vectorize_env
_logger = get_module_logger(__name__)
def backtest(
simulator_fn: Callable[[InitialStateType], Simulator],
state_interpreter: StateInterpreter,
action_interpreter: ActionInterpreter,
initial_states: Sequence[InitialStateType],
policy: BasePolicy,
logger: LogWriter | list[LogWriter],
reward: Reward | None = None,
finite_env_type: FiniteEnvType = "subproc",
concurrency: int = 2,
) -> None:
"""Backtest with the parallelism provided by RL framework.
Parameters
----------
simulator_fn
Callable receiving initial seed, returning a simulator.
state_interpreter
Interprets the state of simulators.
action_interpreter
Interprets the policy actions.
initial_states
Initial states to iterate over. Every state will be run exactly once.
policy
Policy to test against.
logger
Logger to record the backtest results. Logger must be present because
without logger, all information will be lost.
reward
Optional reward function. For backtest, this is for testing the rewards
and logging them only.
finite_env_type
Type of finite env implementation.
concurrency
Parallel workers.
"""
# To save bandwidth
min_loglevel = min(lg.loglevel for lg in logger) if isinstance(logger, list) else logger.loglevel
def env_factory():
# FIXME: state_interpreter and action_interpreter are stateful (having a weakref of env),
# and could be thread unsafe.
# I'm not sure whether it's a design flaw.
# I'll rethink about this when designing the trainer.
if finite_env_type == "dummy":
# We could only experience the "threading-unsafe" problem in dummy.
state = copy.deepcopy(state_interpreter)
action = copy.deepcopy(action_interpreter)
rew = copy.deepcopy(reward)
else:
state, action, rew = state_interpreter, action_interpreter, reward
return EnvWrapper(
simulator_fn,
state,
action,
seed_iterator,
rew,
logger=LogCollector(min_loglevel=min_loglevel),
)
with DataQueue(initial_states) as seed_iterator:
vector_env = vectorize_env(
env_factory,
finite_env_type,
concurrency,
logger,
)
policy.eval()
with vector_env.collector_guard():
test_collector = Collector(policy, vector_env)
_logger.info("All ready. Start backtest.")
test_collector.collect(n_step=INF * len(vector_env))

View File

@@ -1,4 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# TBD

View File

@@ -9,4 +9,5 @@ Multi-asset is on the way.
from .interpreter import *
from .network import *
from .policy import *
from .reward import *
from .simulator_simple import *

View File

@@ -0,0 +1,46 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
from typing import cast
import numpy as np
from qlib.rl.reward import Reward
from .simulator_simple import SAOEState, SAOEMetrics
__all__ = ["PAPenaltyReward"]
class PAPenaltyReward(Reward[SAOEState]):
"""Encourage higher PAs, but penalize stacking all the amounts within a very short time.
Formally, for each time step, the reward is :math:`(PA_t * vol_t / target - vol_t^2 * penalty)`.
Parameters
----------
penalty
The penalty for large volume in a short time.
"""
def __init__(self, penalty: float = 100.0):
self.penalty = penalty
def reward(self, simulator_state: SAOEState) -> float:
whole_order = simulator_state.order.amount
assert whole_order > 0
last_step = cast(SAOEMetrics, simulator_state.history_steps.reset_index().iloc[-1].to_dict())
pa = last_step["pa"] * last_step["amount"] / whole_order
# Inspect the "break-down" of the latest step: trading amount at every tick
last_step_breakdown = simulator_state.history_exec.loc[last_step["datetime"] :]
penalty = -self.penalty * ((last_step_breakdown["amount"] / whole_order) ** 2).sum()
reward = pa + penalty
# Throw error in case of NaN
assert not (np.isnan(reward) or np.isinf(reward)), f"Invalid reward for simulator state: {simulator_state}"
self.log("reward/pa", pa)
self.log("reward/penalty", penalty)
return reward

View File

@@ -131,11 +131,14 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
"""
history_exec: pd.DataFrame
"""All execution history at every possible time ticks. See :class:`SAOEMetrics` for available columns."""
"""All execution history at every possible time ticks. See :class:`SAOEMetrics` for available columns.
Index is ``datetime``.
"""
history_steps: pd.DataFrame
"""Positions at each step. The position before first step is also recorded.
See :class:`SAOEMetrics` for available columns."""
See :class:`SAOEMetrics` for available columns.
Index is ``datetime``, which is the **starting** time of each step."""
metrics: SAOEMetrics | None
"""Metrics. Only available when done."""

View File

@@ -0,0 +1,9 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Train, test, inference utilities."""
from .api import backtest, train
from .callbacks import EarlyStopping, Checkpoint
from .trainer import Trainer
from .vessel import TrainingVessel, TrainingVesselBase

118
qlib/rl/trainer/api.py Normal file
View File

@@ -0,0 +1,118 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
from typing import Callable, Sequence, cast, Any
from tianshou.policy import BasePolicy
from qlib.rl.simulator import InitialStateType, Simulator
from qlib.rl.interpreter import StateInterpreter, ActionInterpreter
from qlib.rl.reward import Reward
from qlib.rl.utils import FiniteEnvType, LogWriter
from .vessel import TrainingVessel
from .trainer import Trainer
def train(
simulator_fn: Callable[[InitialStateType], Simulator],
state_interpreter: StateInterpreter,
action_interpreter: ActionInterpreter,
initial_states: Sequence[InitialStateType],
policy: BasePolicy,
reward: Reward,
vessel_kwargs: dict[str, Any],
trainer_kwargs: dict[str, Any],
) -> None:
"""Train a policy with the parallelism provided by RL framework.
Experimental API. Parameters might change shortly.
Parameters
----------
simulator_fn
Callable receiving initial seed, returning a simulator.
state_interpreter
Interprets the state of simulators.
action_interpreter
Interprets the policy actions.
initial_states
Initial states to iterate over. Every state will be run exactly once.
policy
Policy to train against.
reward
Reward function.
vessel_kwargs
Keyword arguments passed to :class:`TrainingVessel`, like ``episode_per_iter``.
trainer_kwargs
Keyword arguments passed to :class:`Trainer`, like ``finite_env_type``, ``concurrency``.
"""
vessel = TrainingVessel(
simulator_fn=simulator_fn,
state_interpreter=state_interpreter,
action_interpreter=action_interpreter,
policy=policy,
train_initial_states=initial_states,
reward=reward, # ignore none
**vessel_kwargs,
)
trainer = Trainer(**trainer_kwargs)
trainer.fit(vessel)
def backtest(
simulator_fn: Callable[[InitialStateType], Simulator],
state_interpreter: StateInterpreter,
action_interpreter: ActionInterpreter,
initial_states: Sequence[InitialStateType],
policy: BasePolicy,
logger: LogWriter | list[LogWriter],
reward: Reward | None = None,
finite_env_type: FiniteEnvType = "subproc",
concurrency: int = 2,
) -> None:
"""Backtest with the parallelism provided by RL framework.
Experimental API. Parameters might change shortly.
Parameters
----------
simulator_fn
Callable receiving initial seed, returning a simulator.
state_interpreter
Interprets the state of simulators.
action_interpreter
Interprets the policy actions.
initial_states
Initial states to iterate over. Every state will be run exactly once.
policy
Policy to test against.
logger
Logger to record the backtest results. Logger must be present because
without logger, all information will be lost.
reward
Optional reward function. For backtest, this is for testing the rewards
and logging them only.
finite_env_type
Type of finite env implementation.
concurrency
Parallel workers.
"""
vessel = TrainingVessel(
simulator_fn=simulator_fn,
state_interpreter=state_interpreter,
action_interpreter=action_interpreter,
policy=policy,
test_initial_states=initial_states,
reward=cast(Reward, reward), # ignore none
)
trainer = Trainer(
finite_env_type=finite_env_type,
concurrency=concurrency,
loggers=logger,
)
trainer.test(vessel)

View File

@@ -0,0 +1,267 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Callbacks to insert customized recipes during the training.
Mimicks the hooks of Keras / PyTorch-Lightning, but tailored for the context of RL.
"""
from __future__ import annotations
import copy
import shutil
import time
from datetime import datetime
from pathlib import Path
from typing import Any, TYPE_CHECKING
import numpy as np
import torch
from qlib.log import get_module_logger
from qlib.typehint import Literal
if TYPE_CHECKING:
from .trainer import Trainer
from .vessel import TrainingVesselBase
_logger = get_module_logger(__name__)
class Callback:
"""Base class of all callbacks."""
def on_fit_start(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
"""Called before the whole fit process begins."""
def on_fit_end(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
"""Called after the whole fit process ends."""
def on_train_start(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
"""Called when each collect for training begins."""
def on_train_end(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
"""Called when the training ends.
To access all outputs produced during training, cache the data in either trainer and vessel,
and post-process them in this hook.
"""
def on_validate_start(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
"""Called when every run for validation begins."""
def on_validate_end(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
"""Called when the validation ends."""
def on_test_start(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
"""Called when every run of testing begins."""
def on_test_end(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
"""Called when the testing ends."""
def on_iter_start(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
"""Called when every iteration (i.e., collect) starts."""
def on_iter_end(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
"""Called upon every end of iteration.
This is called **after** the bump of ``current_iter``,
when the previous iteration is considered complete.
"""
def state_dict(self) -> Any:
"""Get a state dict of the callback for pause and resume."""
def load_state_dict(self, state_dict: Any) -> None:
"""Resume the callback from a saved state dict."""
class EarlyStopping(Callback):
"""Stop training when a monitored metric has stopped improving.
The earlystopping callback will be triggered each time validation ends.
It will examine the metrics produced in validation,
and get the metric with name ``monitor` (``monitor`` is ``reward`` by default),
to check whether it's no longer increasing / decreasing.
It takes ``min_delta`` and ``patience`` if applicable.
If it's found to be not increasing / decreasing any more.
``trainer.should_stop`` will be set to true,
and the training terminates.
Implementation reference: https://github.com/keras-team/keras/blob/v2.9.0/keras/callbacks.py#L1744-L1893
"""
def __init__(
self,
monitor: str = "reward",
min_delta: float = 0.0,
patience: int = 0,
mode: Literal["min", "max"] = "max",
baseline: float | None = None,
restore_best_weights: bool = False,
):
super().__init__()
self.monitor = monitor
self.patience = patience
self.baseline = baseline
self.min_delta = abs(min_delta)
self.restore_best_weights = restore_best_weights
self.best_weights: Any | None = None
if mode not in ["min", "max"]:
raise ValueError("Unsupported earlystopping mode: " + mode)
if mode == "min":
self.monitor_op = np.less
elif mode == "max":
self.monitor_op = np.greater
if self.monitor_op == np.greater:
self.min_delta *= 1
else:
self.min_delta *= -1
def state_dict(self) -> dict:
return {"wait": self.wait, "best": self.best, "best_weights": self.best_weights, "best_iter": self.best_iter}
def load_state_dict(self, state_dict: dict) -> None:
self.wait = state_dict["wait"]
self.best = state_dict["best"]
self.best_weights = state_dict["best_weights"]
self.best_iter = state_dict["best_iter"]
def on_fit_start(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
# Allow instances to be re-used
self.wait = 0
self.best = np.inf if self.monitor_op == np.less else -np.inf
self.best_weights = None
self.best_iter = 0
def on_validate_end(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
current = self.get_monitor_value(trainer)
if current is None:
return
if self.restore_best_weights and self.best_weights is None:
# Restore the weights after first iteration if no progress is ever made.
self.best_weights = copy.deepcopy(vessel.state_dict())
self.wait += 1
if self._is_improvement(current, self.best):
self.best = current
self.best_iter = trainer.current_iter
if self.restore_best_weights:
self.best_weights = copy.deepcopy(vessel.state_dict())
# Only restart wait if we beat both the baseline and our previous best.
if self.baseline is None or self._is_improvement(current, self.baseline):
self.wait = 0
# Only check after the first epoch.
if self.wait >= self.patience and trainer.current_iter > 0:
trainer.should_stop = True
_logger.info(f"On iteration %d: early stopping", trainer.current_iter + 1)
if self.restore_best_weights and self.best_weights is not None:
_logger.info("Restoring model weights from the end of the best iteration: %d", self.best_iter + 1)
vessel.load_state_dict(self.best_weights)
def get_monitor_value(self, trainer: Trainer) -> Any:
monitor_value = trainer.metrics.get(self.monitor)
if monitor_value is None:
_logger.warning(
"Early stopping conditioned on metric `%s` which is not available. Available metrics are: %s",
self.monitor,
",".join(list(trainer.metrics.keys())),
)
return monitor_value
def _is_improvement(self, monitor_value, reference_value):
return self.monitor_op(monitor_value - self.min_delta, reference_value)
class Checkpoint(Callback):
"""Save checkpoints periodically for persistence and recovery.
Reference: https://github.com/PyTorchLightning/pytorch-lightning/blob/bfa8b7be/pytorch_lightning/callbacks/model_checkpoint.py
Parameters
----------
dirpath
Directory to save the checkpoint file.
filename
Checkpoint filename. Can contain named formatting options to be auto-filled.
For example: ``{iter:03d}-{reward:.2f}.pth``.
Supported argument names are:
- iter (int)
- metrics in ``trainer.metrics``
- time string, in the format of ``%Y%m%d%H%M%S``
save_latest
Save the latest checkpoint in ``latest.pth``.
If ``link``, ``latest.pth`` will be created as a softlink.
If ``copy``, ``latest.pth`` will be stored as an individual copy.
Set to none to disable this.
every_n_iters
Checkpoints are saved at the end of every n iterations of training,
after validation if applicable.
time_interval
Maximum time (seconds) before checkpoints save again.
save_on_fit_end
Save one last checkpoint at the end to fit.
Do nothing if a checkpoint is already saved there.
"""
def __init__(
self,
dirpath: Path,
filename: str = "{iter:03d}.pth",
save_latest: Literal["link", "copy"] | None = "link",
every_n_iters: int | None = None,
time_interval: int | None = None,
save_on_fit_end: bool = True,
):
self.dirpath = Path(dirpath)
self.filename = filename
self.save_latest = save_latest
self.every_n_iters = every_n_iters
self.time_interval = time_interval
self.save_on_fit_end = save_on_fit_end
self._last_checkpoint_name: str | None = None
self._last_checkpoint_iter: int | None = None
self._last_checkpoint_time: float | None = None
def on_fit_end(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
if self.save_on_fit_end and (trainer.current_iter != self._last_checkpoint_iter):
self._save_checkpoint(trainer)
def on_iter_end(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
should_save_ckpt = False
if self.every_n_iters is not None and (trainer.current_iter + 1) % self.every_n_iters == 0:
should_save_ckpt = True
if self.time_interval is not None and (
self._last_checkpoint_time is None or (time.time() - self._last_checkpoint_time) >= self.time_interval
):
should_save_ckpt = True
if should_save_ckpt:
self._save_checkpoint(trainer)
def _save_checkpoint(self, trainer: Trainer) -> None:
self.dirpath.mkdir(exist_ok=True, parents=True)
self._last_checkpoint_name = self._new_checkpoint_name(trainer)
self._last_checkpoint_iter = trainer.current_iter
self._last_checkpoint_time = time.time()
torch.save(trainer.state_dict(), self.dirpath / self._last_checkpoint_name)
latest_pth = self.dirpath / "latest.pth"
# Remove first before saving
if self.save_latest and latest_pth.exists():
latest_pth.unlink()
if self.save_latest == "link":
latest_pth.symlink_to(self.dirpath / self._last_checkpoint_name)
elif self.save_latest == "copy":
shutil.copyfile(self.dirpath / self._last_checkpoint_name, latest_pth)
def _new_checkpoint_name(self, trainer: Trainer) -> str:
return self.filename.format(
iter=trainer.current_iter, time=datetime.now().strftime("%Y%m%d%H%M%S"), **trainer.metrics
)

343
qlib/rl/trainer/trainer.py Normal file
View File

@@ -0,0 +1,343 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
import copy
from contextlib import AbstractContextManager, contextmanager
from pathlib import Path
from typing import Any, Iterable, TypeVar, Sequence, cast
import torch
from qlib.rl.simulator import InitialStateType
from qlib.rl.utils import EnvWrapper, FiniteEnvType, LogCollector, LogWriter, LogBuffer, vectorize_env, LogLevel
from qlib.log import get_module_logger
from qlib.rl.utils.finite_env import FiniteVectorEnv
from qlib.typehint import Literal
from .callbacks import Callback
from .vessel import TrainingVesselBase
_logger = get_module_logger(__name__)
T = TypeVar("T")
class Trainer:
"""
Utility to train a policy on a particular task.
Different from traditional DL trainer, the iteration of this trainer is "collect",
rather than "epoch", or "mini-batch".
In each collect, :class:`Collector` collects a number of policy-env interactions, and accumulates
them into a replay buffer. This buffer is used as the "data" to train the policy.
At the end of each collect, the policy is *updated* several times.
The API has some resemblence with `PyTorch Lightning <https://pytorch-lightning.readthedocs.io/>`__,
but it's essentially different because this trainer is built for RL applications, and thus
most configurations are under RL context.
We are still looking for ways to incorporate existing trainer libraries, because it looks like
big efforts to build a trainer as powerful as those libraries, and also, that's not our primary goal.
It's essentially different
`tianshou's built-in trainers <https://tianshou.readthedocs.io/en/master/api/tianshou.trainer.html>`__,
as it's far much more complicated than that.
Parameters
----------
max_iters
Maximum iterations before stopping.
val_every_n_iters
Perform validation every n iterations (i.e., training collects).
logger
Logger to record the backtest results. Logger must be present because
without logger, all information will be lost.
finite_env_type
Type of finite env implementation.
concurrency
Parallel workers.
fast_dev_run
Create a subset for debugging.
How this is implemented depends on the implementation of training vessel.
For :class:`~qlib.rl.vessel.TrainingVessel`, if greater than zero,
a random subset sized ``fast_dev_run`` will be used
instead of ``train_initial_states`` and ``val_initial_states``.
"""
should_stop: bool
"""Set to stop the training."""
metrics: dict
"""Numeric metrics of produced in train/val/test.
In the middle of training / validation, metrics will be of the latest episode.
When each iteration of training / validation finishes, metrics will be the aggregation
of all episodes encountered in this iteration.
Cleared on every new iteration of training.
In fit, validation metrics will be prefixed with ``val/``.
"""
current_iter: int
"""Current iteration (collect) of training."""
loggers: list[LogWriter]
"""A list of log writers."""
def __init__(
self,
*,
max_iters: int | None = None,
val_every_n_iters: int | None = None,
loggers: LogWriter | list[LogWriter] | None = None,
callbacks: list[Callback] | None = None,
finite_env_type: FiniteEnvType = "subproc",
concurrency: int = 2,
fast_dev_run: int | None = None,
):
self.max_iters = max_iters
self.val_every_n_iters = val_every_n_iters
if isinstance(loggers, list):
self.loggers = loggers
elif isinstance(loggers, LogWriter):
self.loggers = [loggers]
else:
self.loggers = []
self.loggers.append(LogBuffer(self._metrics_callback, loglevel=self._min_loglevel()))
self.callbacks: list[Callback] = callbacks if callbacks is not None else []
self.finite_env_type = finite_env_type
self.concurrency = concurrency
self.fast_dev_run = fast_dev_run
self.current_stage: Literal["train", "val", "test"] = "train"
self.vessel: TrainingVesselBase = cast(TrainingVesselBase, None)
def initialize(self):
"""Initialize the whole training process.
The states here should be synchronized with state_dict.
"""
self.should_stop = False
self.current_iter = 0
self.current_episode = 0
self.current_stage = "train"
def initialize_iter(self):
"""Initialize one iteration / collect."""
self.metrics = {}
def state_dict(self) -> dict:
"""Putting every states of current training into a dict, at best effort.
It doesn't try to handle all the possible kinds of states in the middle of one training collect.
For most cases at the end of each iteration, things should be usually correct.
Note that it's also intended behavior that replay buffer data in the collector will be lost.
"""
return {
"vessel": self.vessel.state_dict(),
"callbacks": {name: callback.state_dict() for name, callback in self.named_callbacks().items()},
"loggers": {name: logger.state_dict() for name, logger in self.named_loggers().items()},
"should_stop": self.should_stop,
"current_iter": self.current_iter,
"current_episode": self.current_episode,
"current_stage": self.current_stage,
"metrics": self.metrics,
}
def load_state_dict(self, state_dict: dict) -> None:
"""Load all states into current trainer."""
self.vessel.load_state_dict(state_dict["vessel"])
for name, callback in self.named_callbacks().items():
callback.load_state_dict(state_dict["callbacks"][name])
for name, logger in self.named_loggers().items():
logger.load_state_dict(state_dict["loggers"][name])
self.should_stop = state_dict["should_stop"]
self.current_iter = state_dict["current_iter"]
self.current_episode = state_dict["current_episode"]
self.current_stage = state_dict["current_stage"]
self.metrics = state_dict["metrics"]
def named_callbacks(self) -> dict[str, Callback]:
"""Retrieve a collection of callbacks where each one has a name.
Useful when saving checkpoints.
"""
return _named_collection(self.callbacks)
def named_loggers(self) -> dict[str, LogWriter]:
"""Retrieve a collection of loggers where each one has a name.
Useful when saving checkpoints.
"""
return _named_collection(self.loggers)
def fit(self, vessel: TrainingVesselBase, ckpt_path: Path | None = None) -> None:
"""Train the RL policy upon the defined simulator.
Parameters
----------
vessel
A bundle of all elements used in training.
ckpt_path
Load a pre-trained / paused training checkpoint.
"""
self.vessel = vessel
vessel.assign_trainer(self)
if ckpt_path is not None:
_logger.info("Resuming states from %s", str(ckpt_path))
self.load_state_dict(torch.load(ckpt_path))
else:
self.initialize()
self._call_callback_hooks("on_fit_start")
while not self.should_stop:
self.initialize_iter()
self._call_callback_hooks("on_iter_start")
self.current_stage = "train"
self._call_callback_hooks("on_train_start")
# TODO
# Add a feature that supports reloading the training environment every few iterations.
with _wrap_context(vessel.train_seed_iterator()) as iterator:
vector_env = self.venv_from_iterator(iterator)
self.vessel.train(vector_env)
self._call_callback_hooks("on_train_end")
if self.val_every_n_iters is not None and (self.current_iter + 1) % self.val_every_n_iters == 0:
# Implementation of validation loop
self.current_stage = "val"
self._call_callback_hooks("on_validate_start")
with _wrap_context(vessel.val_seed_iterator()) as iterator:
vector_env = self.venv_from_iterator(iterator)
self.vessel.validate(vector_env)
self._call_callback_hooks("on_validate_end")
# This iteration is considered complete.
# Bumping the current iteration counter.
self.current_iter += 1
if self.max_iters is not None and self.current_iter >= self.max_iters:
self.should_stop = True
self._call_callback_hooks("on_iter_end")
self._call_callback_hooks("on_fit_end")
def test(self, vessel: TrainingVesselBase) -> None:
"""Test the RL policy against the simulator.
The simulator will be fed with data generated in ``test_seed_iterator``.
Parameters
----------
vessel
A bundle of all related elements.
"""
self.vessel = vessel
vessel.assign_trainer(self)
self.initialize_iter()
self.current_stage = "test"
self._call_callback_hooks("on_test_start")
with _wrap_context(vessel.test_seed_iterator()) as iterator:
vector_env = self.venv_from_iterator(iterator)
self.vessel.test(vector_env)
self._call_callback_hooks("on_test_end")
def venv_from_iterator(self, iterator: Iterable[InitialStateType]) -> FiniteVectorEnv:
"""Create a vectorized environment from iterator and the training vessel."""
def env_factory():
# FIXME: state_interpreter and action_interpreter are stateful (having a weakref of env),
# and could be thread unsafe.
# I'm not sure whether it's a design flaw.
# I'll rethink about this when designing the trainer.
if self.finite_env_type == "dummy":
# We could only experience the "threading-unsafe" problem in dummy.
state = copy.deepcopy(self.vessel.state_interpreter)
action = copy.deepcopy(self.vessel.action_interpreter)
rew = copy.deepcopy(self.vessel.reward)
else:
state = self.vessel.state_interpreter
action = self.vessel.action_interpreter
rew = self.vessel.reward
return EnvWrapper(
self.vessel.simulator_fn,
state,
action,
iterator,
rew,
logger=LogCollector(min_loglevel=self._min_loglevel()),
)
return vectorize_env(
env_factory,
self.finite_env_type,
self.concurrency,
self.loggers,
)
def _metrics_callback(self, on_episode: bool, on_collect: bool, log_buffer: LogBuffer) -> None:
if on_episode:
# Update the global counter.
self.current_episode = log_buffer.global_episode
metrics = log_buffer.episode_metrics()
elif on_collect:
# Update the latest metrics.
metrics = log_buffer.collect_metrics()
if self.current_stage == "val":
metrics = {"val/" + name: value for name, value in metrics.items()}
self.metrics.update(metrics)
def _call_callback_hooks(self, hook_name: str, *args: Any, **kwargs: Any) -> None:
for callback in self.callbacks:
fn = getattr(callback, hook_name)
fn(self, self.vessel, *args, **kwargs)
def _min_loglevel(self):
if not self.loggers:
return LogLevel.PERIODIC
else:
# To save bandwidth
return min(lg.loglevel for lg in self.loggers)
@contextmanager
def _wrap_context(obj):
"""Make any object a (possibly dummy) context manager."""
if isinstance(obj, AbstractContextManager):
# obj has __enter__ and __exit__
with obj as ctx:
yield ctx
else:
yield obj
def _named_collection(seq: Sequence[T]) -> dict[str, T]:
"""Convert a list into a dict, where each item is named with its type."""
res = {}
for item in seq:
typename = type(item).__name__.lower()
if typename not in res:
res[typename] = item
else:
# names are auto-labelled as earlystop1, earlystop2, ...
for retry in range(1, 1000):
if f"{typename}{retry}" not in res:
res[f"{typename}{retry}"] = item
return res

214
qlib/rl/trainer/vessel.py Normal file
View File

@@ -0,0 +1,214 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
import weakref
from typing import Callable, ContextManager, Generic, Iterable, TYPE_CHECKING, Sequence, Any, TypeVar, cast, Dict
import numpy as np
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import BaseVectorEnv
from tianshou.policy import BasePolicy
from qlib.constant import INF
from qlib.rl.interpreter import StateType, ActType, ObsType, PolicyActType
from qlib.rl.simulator import InitialStateType, Simulator
from qlib.rl.interpreter import StateInterpreter, ActionInterpreter
from qlib.rl.reward import Reward
from qlib.rl.utils import DataQueue
from qlib.log import get_module_logger
from qlib.rl.utils.finite_env import FiniteVectorEnv
if TYPE_CHECKING:
from .trainer import Trainer
T = TypeVar("T")
_logger = get_module_logger(__name__)
class SeedIteratorNotAvailable(BaseException):
pass
class TrainingVesselBase(Generic[InitialStateType, StateType, ActType, ObsType, PolicyActType]):
"""A ship that contains simulator, interpreter, and policy, will be sent to trainer.
This class controls algorithm-related parts of training, while trainer is responsible for runtime part.
The ship also defines the most important logic of the core training part,
and (optionally) some callbacks to insert customized logics at specific events.
"""
simulator_fn: Callable[[InitialStateType], Simulator[InitialStateType, StateType, ActType]]
state_interpreter: StateInterpreter[StateType, ObsType]
action_interpreter: ActionInterpreter[StateType, PolicyActType, ActType]
policy: BasePolicy
reward: Reward
trainer: Trainer
def assign_trainer(self, trainer: Trainer) -> None:
self.trainer = weakref.proxy(trainer) # type: ignore
def train_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]:
"""Override this to create a seed iterator for training.
If the iterable is a context manager, the whole training will be invoked in the with-block,
and the iterator will be automatically closed after the training is done."""
raise SeedIteratorNotAvailable("Seed iterator for training is not available.")
def val_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]:
"""Override this to create a seed iterator for validation."""
raise SeedIteratorNotAvailable("Seed iterator for validation is not available.")
def test_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]:
"""Override this to create a seed iterator for testing."""
raise SeedIteratorNotAvailable("Seed iterator for testing is not available.")
def train(self, vector_env: BaseVectorEnv) -> dict[str, Any]:
"""Implement this to train one iteration. In RL, one iteration usually refers to one collect."""
raise NotImplementedError()
def validate(self, vector_env: FiniteVectorEnv) -> dict[str, Any]:
"""Implement this to validate the policy once."""
raise NotImplementedError()
def test(self, vector_env: FiniteVectorEnv) -> dict[str, Any]:
"""Implement this to evaluate the policy on test environment once."""
raise NotImplementedError()
def log(self, name: str, value: Any) -> None:
# FIXME: this is a workaround to make the log at least show somewhere.
# Need a refactor in logger to formalize this.
if isinstance(value, (np.ndarray, list)):
value = np.mean(value)
_logger.info(f"[Iter {self.trainer.current_iter + 1}] {name} = {value}")
def log_dict(self, data: dict[str, Any]) -> None:
for name, value in data.items():
self.log(name, value)
def state_dict(self) -> dict:
"""Return a checkpoint of current vessel state."""
return {"policy": self.policy.state_dict()}
def load_state_dict(self, state_dict: dict) -> None:
"""Restore a checkpoint from a previously saved state dict."""
self.policy.load_state_dict(state_dict["policy"])
class TrainingVessel(TrainingVesselBase):
"""The default implementation of training vessel.
``__init__`` accepts a sequence of initial states so that iterator can be created.
``train``, ``validate``, ``test`` each do one collect (and also update in train).
By default, the train initial states will be repeated infinitely during training,
and collector will control the number of episodes for each iteration.
In validation and testing, the val / test initial states will be used exactly once.
Extra hyper-parameters (only used in train) include:
- ``buffer_size``: Size of replay buffer.
- ``episode_per_iter``: Episodes per collect at training. Can be overridden by fast dev run.
- ``update_kwargs``: Keyword arguments appearing in ``policy.update``.
For example, ``dict(repeat=10, batch_size=64)``.
"""
def __init__(
self,
*,
simulator_fn: Callable[[InitialStateType], Simulator[InitialStateType, StateType, ActType]],
state_interpreter: StateInterpreter[StateType, ObsType],
action_interpreter: ActionInterpreter[StateType, PolicyActType, ActType],
policy: BasePolicy,
reward: Reward,
train_initial_states: Sequence[InitialStateType] | None = None,
val_initial_states: Sequence[InitialStateType] | None = None,
test_initial_states: Sequence[InitialStateType] | None = None,
buffer_size: int = 20000,
episode_per_iter: int = 1000,
update_kwargs: dict[str, Any] = cast(Dict[str, Any], None),
):
self.simulator_fn = simulator_fn # type: ignore
self.state_interpreter = state_interpreter
self.action_interpreter = action_interpreter
self.policy = policy
self.reward = reward
self.train_initial_states = train_initial_states
self.val_initial_states = val_initial_states
self.test_initial_states = test_initial_states
self.buffer_size = buffer_size
self.episode_per_iter = episode_per_iter
self.update_kwargs = update_kwargs or {}
def train_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]:
if self.train_initial_states is not None:
_logger.info("Training initial states collection size: %d", len(self.train_initial_states))
# Implement fast_dev_run here.
train_initial_states = self._random_subset("train", self.train_initial_states, self.trainer.fast_dev_run)
return DataQueue(train_initial_states, repeat=-1, shuffle=True)
return super().train_seed_iterator()
def val_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]:
if self.val_initial_states is not None:
_logger.info("Validation initial states collection size: %d", len(self.val_initial_states))
val_initial_states = self._random_subset("val", self.val_initial_states, self.trainer.fast_dev_run)
return DataQueue(val_initial_states, repeat=1)
return super().val_seed_iterator()
def test_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]:
if self.test_initial_states is not None:
_logger.info("Testing initial states collection size: %d", len(self.test_initial_states))
test_initial_states = self._random_subset("test", self.test_initial_states, self.trainer.fast_dev_run)
return DataQueue(test_initial_states, repeat=1)
return super().test_seed_iterator()
def train(self, vector_env: FiniteVectorEnv) -> dict[str, Any]:
"""Create a collector and collects ``episode_per_iter`` episodes.
Update the policy on the collected replay buffer.
"""
self.policy.train()
with vector_env.collector_guard():
collector = Collector(self.policy, vector_env, VectorReplayBuffer(self.buffer_size, len(vector_env)))
# Number of episodes collected in each training iteration can be overridden by fast dev run.
if self.trainer.fast_dev_run is not None:
episodes = self.trainer.fast_dev_run
else:
episodes = self.episode_per_iter
col_result = collector.collect(n_episode=episodes)
update_result = self.policy.update(sample_size=0, buffer=collector.buffer, **self.update_kwargs)
res = {**col_result, **update_result}
self.log_dict(res)
return res
def validate(self, vector_env: FiniteVectorEnv) -> dict[str, Any]:
self.policy.eval()
with vector_env.collector_guard():
test_collector = Collector(self.policy, vector_env)
res = test_collector.collect(n_step=INF * len(vector_env))
self.log_dict(res)
return res
def test(self, vector_env: FiniteVectorEnv) -> dict[str, Any]:
self.policy.eval()
with vector_env.collector_guard():
test_collector = Collector(self.policy, vector_env)
res = test_collector.collect(n_step=INF * len(vector_env))
self.log_dict(res)
return res
@staticmethod
def _random_subset(name: str, collection: Sequence[T], size: int | None) -> Sequence[T]:
if size is None:
# Size = None -> original collection
return collection
order = np.random.permutation(len(collection))
res = [collection[o] for o in order[:size]]
_logger.info(
"Fast running in development mode. Cut %s initial states from %d to %d.", name, len(collection), len(res)
)
return res

View File

@@ -145,7 +145,9 @@ class DataQueue(Generic[T]):
def __iter__(self):
if not self._activated:
raise ValueError(
"Need to call activate() to launch a daemon worker " "to produce data into data queue before using it."
"Need to call activate() to launch a daemon worker "
"to produce data into data queue before using it. "
"You probably have forgotten to use the DataQueue in a with block."
)
return self._consumer()
@@ -161,19 +163,21 @@ class DataQueue(Generic[T]):
# pytorch dataloader is used here only because we need its sampler and multi-processing
from torch.utils.data import DataLoader, Dataset # pylint: disable=import-outside-toplevel
dataloader = DataLoader(
cast(Dataset[T], self.dataset),
batch_size=None,
num_workers=self.producer_num_workers,
shuffle=self.shuffle,
collate_fn=lambda t: t, # identity collate fn
)
repeat = 10**18 if self.repeat == -1 else self.repeat
for _rep in range(repeat):
for data in dataloader:
if self._done.value:
# Already done.
return
self._queue.put(data)
_logger.debug(f"Dataloader loop done. Repeat {_rep}.")
self.mark_as_done()
try:
dataloader = DataLoader(
cast(Dataset[T], self.dataset),
batch_size=None,
num_workers=self.producer_num_workers,
shuffle=self.shuffle,
collate_fn=lambda t: t, # identity collate fn
)
repeat = 10**18 if self.repeat == -1 else self.repeat
for _rep in range(repeat):
for data in dataloader:
if self._done.value:
# Already done.
return
self._queue.put(data)
_logger.debug(f"Dataloader loop done. Repeat {_rep}.")
finally:
self.mark_as_done()

View File

@@ -120,12 +120,19 @@ class FiniteVectorEnv(BaseVectorEnv):
from child workers. See :class:`qlib.rl.utils.LogWriter`.
"""
_logger: list[LogWriter]
def __init__(
self, logger: LogWriter | list[LogWriter], env_fns: list[Callable[..., gym.Env]], **kwargs: Any
self, logger: LogWriter | list[LogWriter] | None, env_fns: list[Callable[..., gym.Env]], **kwargs: Any
) -> None:
super().__init__(env_fns, **kwargs)
self._logger: list[LogWriter] = logger if isinstance(logger, list) else [logger]
if isinstance(logger, list):
self._logger = logger
elif isinstance(logger, LogWriter):
self._logger = [logger]
else:
self._logger = []
self._alive_env_ids: Set[int] = set()
self._reset_alive_envs()
self._default_obs = self._default_info = self._default_rew = None
@@ -177,7 +184,7 @@ class FiniteVectorEnv(BaseVectorEnv):
1. Catch and ignore the StopIteration exception, which is the stopping signal
thrown by FiniteEnv to let tianshou know that ``collector.collect()`` should exit.
2. Notify the loggers that the collect is done what it's done.
2. Notify the loggers that the collect is ready / done what it's ready / done.
Examples
--------
@@ -186,6 +193,9 @@ class FiniteVectorEnv(BaseVectorEnv):
"""
self._collector_guarded = True
for logger in self._logger:
logger.on_env_all_ready()
try:
yield self
except StopIteration:
@@ -298,7 +308,21 @@ def vectorize_env(
concurrency: int,
logger: LogWriter | list[LogWriter],
) -> FiniteVectorEnv:
"""Helper function to create a vector env.
"""Helper function to create a vector env. Can be used to replace usual VectorEnv.
For example, once you wrote: ::
DummyVectorEnv([lambda: gym.make(task) for _ in range(env_num)])
Now you can replace it with: ::
finite_env_factory(lambda: gym.make(task), "dummy", env_num, my_logger)
By doing such replacement, you have two additional features enabled (compared to normal VectorEnv):
1. The vector env will check for NaN observation and kill the worker when its found.
See :class:`FiniteVectorEnv` for why we need this.
2. A logger to explicit collect logs from environment workers.
Parameters
----------

View File

@@ -12,13 +12,16 @@ in each worker, and writes them to console, log files, or tensorboard...
The two modules communicate by the "log" field in "info" returned by ``env.step()``.
"""
# NOTE: This file contains many hardcoded / ad-hoc rules.
# Refactoring it will be one of the future tasks.
from __future__ import annotations
import logging
from collections import defaultdict
from enum import IntEnum
from pathlib import Path
from typing import Any, TypeVar, Generic, Set, TYPE_CHECKING, Sequence
from typing import Any, TypeVar, Generic, Set, TYPE_CHECKING, Sequence, Callable
import numpy as np
import pandas as pd
@@ -29,7 +32,7 @@ if TYPE_CHECKING:
from .env_wrapper import InfoDict
__all__ = ["LogCollector", "LogWriter", "LogLevel", "ConsoleWriter", "CsvWriter"]
__all__ = ["LogCollector", "LogWriter", "LogLevel", "LogBuffer", "ConsoleWriter", "CsvWriter"]
ObsType = TypeVar("ObsType")
ActType = TypeVar("ActType")
@@ -175,18 +178,53 @@ class LogWriter(Generic[ObsType, ActType]):
self.clear()
def clear(self):
"""Clear all the metrics for a fresh start.
To make the logger instance reusable.
"""
self.episode_count = self.step_count = 0
self.active_env_ids = set()
self.logs = []
def aggregation(self, array: Sequence[Any]) -> Any:
def state_dict(self) -> dict:
"""Save the states of the logger to a dict."""
return {
"episode_count": self.episode_count,
"step_count": self.step_count,
"global_step": self.global_step,
"global_episode": self.global_episode,
"active_env_ids": self.active_env_ids,
"episode_lengths": self.episode_lengths,
"episode_rewards": self.episode_rewards,
"episode_logs": self.episode_logs,
}
def load_state_dict(self, state_dict: dict) -> None:
"""Load the states of current logger from a dict."""
self.episode_count = state_dict["episode_count"]
self.step_count = state_dict["step_count"]
self.global_step = state_dict["global_step"]
self.global_episode = state_dict["global_episode"]
# These are runtime infos.
# Though they are loaded, I don't think it really helps.
self.active_env_ids = state_dict["active_env_ids"]
self.episode_lenghts = state_dict["episode_lengths"]
self.episode_rewards = state_dict["episode_rewards"]
self.episode_logs = state_dict["episode_logs"]
def aggregation(self, array: Sequence[Any], name: str | None = None) -> Any:
"""Aggregation function from step-wise to episode-wise.
If it's a sequence of float, take the mean.
Otherwise, take the first element.
If a name is specified and,
- if it's ``reward``, the reduction will be sum.
"""
assert len(array) > 0, "The aggregated array must be not empty."
if all(isinstance(v, float) for v in array):
if name == "reward":
return np.sum(array)
return np.mean(array)
else:
return array[0]
@@ -253,10 +291,93 @@ class LogWriter(Generic[ObsType, ActType]):
self.episode_rewards[env_id] = []
self.episode_logs[env_id] = []
def on_env_all_ready(self) -> None:
"""When all environments are ready to run.
Usually, loggers should be reset here.
"""
self.clear()
def on_env_all_done(self) -> None:
"""All done. Time for cleanup."""
class LogBuffer(LogWriter):
"""Keep all numbers in memory.
Objects that can't be aggregated like strings, tensors, images can't be stored in the buffer.
To persist them, please use :class:`PickleWriter`.
Every time, Log buffer receives a new metric, the callback is triggered,
which is useful when tracking metrics inside a trainer.
Parameters
----------
callback
A callback receiving three arguments:
- on_episode: Whether it's called at the end of an episode
- on_collect: Whether it's called at the end of a collect
- log_buffer: the :class:`LogBbuffer`object
No return value is expected.
"""
# FIXME: needs a metric count
def __init__(self, callback: Callable[[bool, bool, LogBuffer], None], loglevel: int | LogLevel = LogLevel.PERIODIC):
super().__init__(loglevel)
self.callback = callback
def state_dict(self) -> dict:
return {
**super().state_dict(),
"latest_metrics": self._latest_metrics,
"aggregated_metrics": self._aggregated_metrics,
}
def load_state_dict(self, state_dict: dict) -> None:
self._latest_metrics = state_dict["latest_metrics"]
self._aggregated_metrics = state_dict["aggregated_metrics"]
return super().load_state_dict(state_dict)
def clear(self):
super().clear()
self._latest_metrics: dict[str, float] | None = None
self._aggregated_metrics: dict[str, float] = defaultdict(float)
def log_episode(self, length: int, rewards: list[float], contents: list[dict[str, Any]]) -> None:
# FIXME Dup of ConsoleWriter
episode_wise_contents: dict[str, list] = defaultdict(list)
for step_contents in contents:
for name, value in step_contents.items():
# FIXME This could be false-negative for some numpy types
if isinstance(value, float):
episode_wise_contents[name].append(value)
logs: dict[str, float] = {}
for name, values in episode_wise_contents.items():
logs[name] = self.aggregation(values, name) # type: ignore
self._aggregated_metrics[name] += logs[name]
self._latest_metrics = logs
self.callback(True, False, self)
def on_env_all_done(self) -> None:
# This happens when collect exits
self.callback(False, True, self)
def episode_metrics(self) -> dict[str, float]:
"""Retrieve the numeric metrics of the latest episode."""
if self._latest_metrics is None:
raise ValueError("No episode metrics available yet.")
return self._latest_metrics
def collect_metrics(self) -> dict[str, float]:
"""Retrieve the aggregated metrics of the latest collect."""
return {name: value / self.episode_count for name, value in self._aggregated_metrics.items()}
class ConsoleWriter(LogWriter):
"""Write log messages to console periodically.
@@ -289,6 +410,8 @@ class ConsoleWriter(LogWriter):
self.console_logger = get_module_logger(__name__, level=logging.INFO)
# FIXME: save & reload
def clear(self):
super().clear()
# Clear average meters
@@ -308,7 +431,7 @@ class ConsoleWriter(LogWriter):
# This should be done at every step, regardless of periodic or not.
logs: dict[str, float] = {}
for name, values in episode_wise_contents.items():
logs[name] = self.aggregation(values) # type: ignore
logs[name] = self.aggregation(values, name) # type: ignore
for name, value in logs.items():
self.metric_counts[name] += 1
@@ -350,6 +473,8 @@ class CsvWriter(LogWriter):
all_records: list[dict[str, Any]]
# FIXME: save & reload
def __init__(self, output_dir: Path, loglevel: int | LogLevel = LogLevel.PERIODIC):
super().__init__(loglevel)
self.output_dir = output_dir
@@ -370,7 +495,7 @@ class CsvWriter(LogWriter):
logs: dict[str, float] = {}
for name, values in episode_wise_contents.items():
logs[name] = self.aggregation(values) # type: ignore
logs[name] = self.aggregation(values, name) # type: ignore
self.all_records.append(logs)
@@ -392,7 +517,3 @@ class TensorboardWriter(LogWriter):
class MlflowWriter(LogWriter):
"""Add logs to mlflow."""
class LogBuffer(LogWriter):
"""Keep everything in memory."""

View File

@@ -1,17 +1,20 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
from typing import TYPE_CHECKING
from abc import abstractmethod
from typing import TYPE_CHECKING, Any, Generator, Optional
if TYPE_CHECKING:
from qlib.backtest.exchange import Exchange
from qlib.backtest.position import BasePosition
from typing import Tuple, Union
from ..backtest.decision import BaseTradeDecision
from ..backtest.utils import CommonInfrastructure, LevelInfrastructure, TradeCalendarManager
from ..rl.interpreter import ActionInterpreter, StateInterpreter
from ..utils import init_instance_by_config
from ..backtest.utils import CommonInfrastructure, LevelInfrastructure, TradeCalendarManager
from ..backtest.decision import BaseTradeDecision
__all__ = ["BaseStrategy", "RLStrategy", "RLIntStrategy"]
@@ -25,12 +28,13 @@ class BaseStrategy:
level_infra: LevelInfrastructure = None,
common_infra: CommonInfrastructure = None,
trade_exchange: Exchange = None,
):
) -> None:
"""
Parameters
----------
outer_trade_decision : BaseTradeDecision, optional
the trade decision of outer strategy which this strategy relies, and it will be traded in [start_time, end_time], by default None
the trade decision of outer strategy which this strategy relies, and it will be traded in
[start_time, end_time], by default None
- If the strategy is used to split trade decision, it will be used
- If the strategy is used for portfolio management, it can be ignored
level_infra : LevelInfrastructure, optional
@@ -41,9 +45,10 @@ class BaseStrategy:
trade_exchange : Exchange
exchange that provides market info, used to deal order and generate report
- If `trade_exchange` is None, self.trade_exchange will be set with common_infra
- It allowes different trade_exchanges is used in different executions.
- It allows different trade_exchanges is used in different executions.
- For example:
- In daily execution, both daily exchange and minutely are usable, but the daily exchange is recommended because it run faster.
- In daily execution, both daily exchange and minutely are usable, but the daily exchange is
recommended because it run faster.
- In minutely execution, the daily exchange is not usable, only the minutely exchange is recommended.
"""
@@ -63,13 +68,13 @@ class BaseStrategy:
"""get trade exchange in a prioritized order"""
return getattr(self, "_trade_exchange", None) or self.common_infra.get("trade_exchange")
def reset_level_infra(self, level_infra: LevelInfrastructure):
def reset_level_infra(self, level_infra: LevelInfrastructure) -> None:
if not hasattr(self, "level_infra"):
self.level_infra = level_infra
else:
self.level_infra.update(level_infra)
def reset_common_infra(self, common_infra: CommonInfrastructure):
def reset_common_infra(self, common_infra: CommonInfrastructure) -> None:
if not hasattr(self, "common_infra"):
self.common_infra: CommonInfrastructure = common_infra
else:
@@ -79,9 +84,9 @@ class BaseStrategy:
self,
level_infra: LevelInfrastructure = None,
common_infra: CommonInfrastructure = None,
outer_trade_decision=None,
**kwargs,
):
outer_trade_decision: BaseTradeDecision = None,
**kwargs, # TODO: remove this?
) -> None:
"""
- reset `level_infra`, used to reset trade calendar, .etc
- reset `common_infra`, used to reset `trade_account`, `trade_exchange`, .etc
@@ -89,18 +94,20 @@ class BaseStrategy:
**NOTE**:
split this function into `reset` and `_reset` will make following cases more convenient
1. Users want to initialize his strategy by overriding `reset`, but they don't want to affect the `_reset` called
when initialization
1. Users want to initialize his strategy by overriding `reset`, but they don't want to affect the `_reset`
called when initialization
"""
self._reset(
level_infra=level_infra, common_infra=common_infra, outer_trade_decision=outer_trade_decision, **kwargs
level_infra=level_infra,
common_infra=common_infra,
outer_trade_decision=outer_trade_decision,
)
def _reset(
self,
level_infra: LevelInfrastructure = None,
common_infra: CommonInfrastructure = None,
outer_trade_decision=None,
outer_trade_decision: BaseTradeDecision = None,
):
"""
Please refer to the docs of `reset`
@@ -114,7 +121,11 @@ class BaseStrategy:
if outer_trade_decision is not None:
self.outer_trade_decision = outer_trade_decision
def generate_trade_decision(self, execute_result=None):
@abstractmethod
def generate_trade_decision(
self,
execute_result: list = None,
) -> Union[BaseTradeDecision, Generator[Any, Any, BaseTradeDecision]]:
"""Generate trade decision in each trading bar
Parameters
@@ -125,9 +136,11 @@ class BaseStrategy:
"""
raise NotImplementedError("generate_trade_decision is not implemented!")
@staticmethod
def update_trade_decision(
self, trade_decision: BaseTradeDecision, trade_calendar: TradeCalendarManager
) -> Union[BaseTradeDecision, None]:
trade_decision: BaseTradeDecision,
trade_calendar: TradeCalendarManager,
) -> Optional[BaseTradeDecision]:
"""
update trade decision in each step of inner execution, this method enable all order
@@ -145,7 +158,8 @@ class BaseStrategy:
# default to return None, which indicates that the trade decision is not changed
return None
def alter_outer_trade_decision(self, outer_trade_decision: BaseTradeDecision):
# FIXME: do not define this method as an abstract one since it is never implemented
def alter_outer_trade_decision(self, outer_trade_decision: BaseTradeDecision) -> BaseTradeDecision:
"""
A method for updating the outer_trade_decision.
The outer strategy may change its decision during updating.
@@ -154,6 +168,10 @@ class BaseStrategy:
----------
outer_trade_decision : BaseTradeDecision
the decision updated by the outer strategy
Returns
-------
BaseTradeDecision
"""
# default to reset the decision directly
# NOTE: normally, user should do something to the strategy due to the change of outer decision
@@ -200,7 +218,7 @@ class RLStrategy(BaseStrategy):
level_infra: LevelInfrastructure = None,
common_infra: CommonInfrastructure = None,
**kwargs,
):
) -> None:
"""
Parameters
----------
@@ -223,7 +241,7 @@ class RLIntStrategy(RLStrategy):
level_infra: LevelInfrastructure = None,
common_infra: CommonInfrastructure = None,
**kwargs,
):
) -> None:
"""
Parameters
----------
@@ -242,7 +260,7 @@ class RLIntStrategy(RLStrategy):
self.state_interpreter = init_instance_by_config(state_interpreter, accept_types=StateInterpreter)
self.action_interpreter = init_instance_by_config(action_interpreter, accept_types=ActionInterpreter)
def generate_trade_decision(self, execute_result=None):
def generate_trade_decision(self, execute_result: list = None) -> BaseTradeDecision:
_interpret_state = self.state_interpreter.interpret(execute_result=execute_result)
_action = self.policy.step(_interpret_state)
_trade_decision = self.action_interpreter.interpret(action=_action)

View File

@@ -16,7 +16,7 @@ from qlib.utils import exists_qlib_data
class GetData:
DATASET_VERSION = "v2"
REMOTE_URL = "http://fintech.msra.cn/stock_data/downloads"
REMOTE_URL = "https://qlibpublic.blob.core.windows.net/data/default/stock_data"
QLIB_DATA_NAME = "{dataset_name}_{region}_{interval}_{qlib_version}.zip"
def __init__(self, delete_zip_file=False):

View File

@@ -376,7 +376,7 @@ get_cls_kwargs = get_callable_kwargs # NOTE: this is for compatibility for the
def init_instance_by_config(
config: Union[str, dict, object, Path],
config: Union[str, dict, object, Path], # TODO: use a user-defined type to replace this Union.
default_module=None,
accept_types: Union[type, Tuple[type]] = (),
try_kwargs: Dict = {},
@@ -949,6 +949,10 @@ def auto_filter_kwargs(func: Callable, warning=True) -> Callable:
The decrated function will ignore and give warning when the parameter is not acceptable
For example, if you have a function `f` which may optionally consume the keywards `bar`.
then you can call it by `auto_filter_kwargs(f)(bar=3)`, which will automatically filter out
`bar` when f does not need bar
Parameters
----------
func : Callable
@@ -1063,4 +1067,5 @@ __all__ = [
"unpack_archive_with_buffer",
"get_tmp_file_with_buffer",
"set_log_with_config",
"init_instance_by_config",
]

View File

@@ -9,6 +9,8 @@ Motivation of index_data
`index_data` try to behave like pandas (some API will be different because we try to be simpler and more intuitive) but don't compromise the performance. It provides the basic numpy data and simple indexing feature. If users call APIs which may compromise the performance, index_data will raise Errors.
"""
from __future__ import annotations
from typing import Dict, Tuple, Union, Callable, List
import bisect
@@ -16,7 +18,7 @@ import numpy as np
import pandas as pd
def concat(data_list: Union["SingleData"], axis=0) -> "MultiData":
def concat(data_list: Union[SingleData], axis=0) -> MultiData:
"""concat all SingleData by index.
TODO: now just for SingleData.
@@ -52,7 +54,7 @@ def concat(data_list: Union["SingleData"], axis=0) -> "MultiData":
raise ValueError(f"axis must be 0 or 1")
def sum_by_index(data_list: Union["SingleData"], new_index: list, fill_value=0) -> "SingleData":
def sum_by_index(data_list: Union[SingleData], new_index: list, fill_value=0) -> SingleData:
"""concat all SingleData by new index.
Parameters
@@ -554,7 +556,7 @@ class SingleData(IndexData):
f"The indexes of self and other do not meet the requirements of the four arithmetic operations"
)
def reindex(self, index: Index, fill_value=np.NaN):
def reindex(self, index: Index, fill_value=np.NaN) -> SingleData:
"""reindex data and fill the missing value with np.NaN.
Parameters
@@ -580,7 +582,7 @@ class SingleData(IndexData):
pass
return SingleData(tmp_data, index)
def add(self, other: "SingleData", fill_value=0):
def add(self, other: SingleData, fill_value=0):
# TODO: add and __add__ are a little confusing.
# This could be a more general
common_index = self.index | other.index

View File

@@ -10,6 +10,9 @@ from joblib._parallel_backends import MultiprocessingBackend
import pandas as pd
from queue import Queue
import concurrent
from qlib.config import C, QlibConfig
class ParallelExt(Parallel):
@@ -273,3 +276,40 @@ def complex_parallel(paral: Parallel, complex_iter):
dt.set_res(res)
complex_iter = _recover_dt(complex_iter)
return complex_iter
class call_in_subproc:
"""
When we repeating run functions, it is hard to avoid memory leakage.
So we run it in the subprocess to ensure it is OK.
NOTE: Because local object can't be pickled. So we can't implement it via closure.
We have to implement it via callable Class
"""
def __init__(self, func: Callable, qlib_config: QlibConfig = None):
"""
Parameters
----------
func : Callable
the function to be wrapped
qlib_config : QlibConfig
Qlib config for initialization in subprocess
Returns
-------
Callable
"""
self.func = func
self.qlib_config = qlib_config
def _func_mod(self, *args, **kwargs):
"""Modify the initial function by adding Qlib initialization"""
if self.qlib_config is not None:
C.register_from_C(self.qlib_config)
return self.func(*args, **kwargs)
def __call__(self, *args, **kwargs):
with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor:
return executor.submit(self._func_mod, *args, **kwargs).result()

View File

@@ -131,7 +131,7 @@ class DSBasedUpdater(RecordUpdater, metaclass=ABCMeta):
.. note::
the start_time is not included in the hist_ref
the start_time is not included in the `hist_ref`; So the `hist_ref` will be `step_len - 1` in most cases
loader_cls : type
the class to load the model and dataset
@@ -184,9 +184,9 @@ class DSBasedUpdater(RecordUpdater, metaclass=ABCMeta):
dataset: DatasetH = self.record.load_object("dataset") if unprepared_dataset is None else unprepared_dataset
# Special treatment of historical dependencies
if isinstance(dataset, TSDatasetH):
hist_ref = dataset.step_len
hist_ref = dataset.step_len - 1
else:
hist_ref = 0
hist_ref = 0 # if only the lastest data is used, then only current data will be used and no historical data will be used
else:
hist_ref = self.hist_ref

View File

@@ -2,11 +2,13 @@
# Licensed under the MIT License.
import os
import sys
import mlflow
import logging
import shutil
import pickle
import tempfile
import subprocess
from pathlib import Path
from datetime import datetime
@@ -296,8 +298,32 @@ class MLflowRecorder(Recorder):
# - This may cause delay when uploading results
# - The logging time may not be accurate
self.async_log = AsyncCaller()
# TODO: currently, this is only supported in MLflowRecorder.
# Maybe we can make this feature more general.
self._log_uncommitted_code()
self.log_params(**{"cmd-sys.argv": " ".join(sys.argv)}) # log the command to produce current experiment
return run
def _log_uncommitted_code(self):
"""
Mlflow only log the commit id of the current repo. But usually, user will have a lot of uncommitted changes.
So this tries to automatically to log them all.
"""
# TODO: the sub-directories maybe git repos.
# So it will be better if we can walk the sub-directories and log the uncommitted changes.
for cmd, fname in [
("git diff", "code_diff.txt"),
("git status", "code_status.txt"),
("git diff --cached", "code_cached.txt"),
]:
try:
out = subprocess.check_output(cmd, shell=True)
self.client.log_text(self.id, out.decode(), fname) # this behaves same as above
except subprocess.CalledProcessError:
logger.info(f"Fail to log the uncommitted code of $CWD when run `{cmd}`")
def end_run(self, status: str = Recorder.STATUS_S):
assert status in [
Recorder.STATUS_S,

View File

@@ -169,7 +169,10 @@ class RecorderCollector(Collector):
self.experiment = experiment
self.artifacts_path = artifacts_path
if rec_key_func is None:
rec_key_func = lambda rec: rec.info["id"]
def rec_key_func(rec):
return rec.info["id"]
if artifacts_key is None:
artifacts_key = list(self.artifacts_path.keys())
self.rec_key_func = rec_key_func

View File

@@ -488,7 +488,7 @@ class DumpDataUpdate(DumpDataBase):
except Exception:
error_code[futures[_future]] = traceback.format_exc()
p_bar.update()
logger.info(f"dump bin errors {error_code}")
logger.info(f"dump bin errors: {error_code}")
logger.info("end of features dump.\n")

View File

@@ -80,6 +80,8 @@ REQUIRED = [
"filelock",
"jinja2<3.1.0", # for passing the readthedocs workflow.
"gym",
# Installing the latest version of protobuf for python versions below 3.8 will cause unit tests to fail.
"protobuf<=3.20.1;python_version<='3.8'",
]
# Numpy include
@@ -135,10 +137,24 @@ setup(
"sphinx",
"sphinx_rtd_theme",
"pre-commit",
# CI dependencies
"wheel",
"setuptools",
"black",
"pylint",
"mypy",
"flake8",
"readthedocs_sphinx_ext",
"cmake",
"lxml",
"baostock",
"yahooquery",
"beautifulsoup4",
"tianshou",
"gym>=0.24", # If you do not put gym at the end, gym will degrade causing pytest results to fail.
],
"rl": [
"tianshou",
"gym",
"torch",
],
},

View File

@@ -2,7 +2,7 @@
# Licensed under the MIT License.
import unittest
from qlib.backtest import backtest, decision
from qlib.backtest import backtest
from qlib.tests import TestAutoData
import pandas as pd
from pathlib import Path
@@ -52,13 +52,12 @@ class FileStrTest(TestAutoData):
factor = df["$factor"].item()
price_unit = price / factor * 100
dealt_num_for_1000 = (account_money // price_unit) * (100 / factor)
print(price, factor, price_unit, dealt_num_for_1000)
# 2) generate orders
orders = self._gen_orders(dealt_num_for_1000)
print(orders)
orders.to_csv(self.EXAMPLE_FILE)
orders = pd.read_csv(self.EXAMPLE_FILE, index_col=["datetime", "instrument"])
print(orders)
# 3) run the strategy
strategy_config = {
@@ -101,7 +100,11 @@ class FileStrTest(TestAutoData):
},
},
}
report_dict, indicator_dict = backtest(executor=executor_config, strategy=strategy_config, **backtest_config)
report_dict, indicator_dict = backtest(
executor=executor_config,
strategy=strategy_config,
**backtest_config,
)
# ffr valid
ffr_dict = indicator_dict["1day"]["ffr"].to_dict()

View File

@@ -1,4 +1,6 @@
[pytest]
markers =
slow: marks tests as slow (deselect with '-m "not slow"')
filterwarnings =
ignore:.*rng.randint:DeprecationWarning
ignore:.*Casting input x to numpy array:UserWarning

View File

@@ -81,7 +81,7 @@ def test_simple_env_logger(caplog):
line = line.strip()
if line:
line_counter += 1
assert re.match(r".*reward 42\.0000 \(42.0000\) a .* \((4|5|6)\.\d+\) c .* \((14|15|16)\.\d+\)", line)
assert re.match(r".*reward .* a .* \((4|5|6)\.\d+\) c .* \((14|15|16)\.\d+\)", line)
assert line_counter >= 3

View File

@@ -17,7 +17,7 @@ from qlib.backtest import Order
from qlib.config import C
from qlib.log import set_log_with_config
from qlib.rl.data import pickle_styled
from qlib.rl.entries.test import backtest
from qlib.rl.trainer import backtest, train
from qlib.rl.order_execution import *
from qlib.rl.utils import ConsoleWriter, CsvWriter, EnvWrapperStatus
@@ -306,3 +306,26 @@ def test_cn_ppo_strategy():
assert np.isclose(metrics["pa"].mean(), -16.21578303474833)
assert np.isclose(metrics["market_price"].mean(), 58.68277690875527)
assert np.isclose(metrics["trade_price"].mean(), 58.76063985000002)
def test_ppo_train():
set_log_with_config(C.logging_config)
# The data starts with 9:31 and ends with 15:00
orders = pickle_styled.load_orders(CN_ORDER_DIR, start_time=pd.Timestamp("9:31"), end_time=pd.Timestamp("14:58"))
assert len(orders) == 40
state_interp = FullHistoryStateInterpreter(CN_FEATURE_DATA_DIR, 8, 240, 6)
action_interp = CategoricalActionInterpreter(4)
network = Recurrent(state_interp.observation_space)
policy = PPO(network, state_interp.observation_space, action_interp.action_space, 1e-4)
train(
partial(SingleAssetOrderExecution, data_dir=CN_BACKTEST_DATA_DIR, ticks_per_step=30),
state_interp,
action_interp,
orders,
policy,
PAPenaltyReward(),
vessel_kwargs={"episode_per_iter": 100, "update_kwargs": {"batch_size": 64, "repeat": 5}},
trainer_kwargs={"max_iters": 2, "loggers": ConsoleWriter(total_episodes=100)},
)

202
tests/rl/test_trainer.py Normal file
View File

@@ -0,0 +1,202 @@
import os
import random
import sys
from pathlib import Path
import pytest
import torch
import torch.nn as nn
from gym import spaces
from tianshou.policy import PPOPolicy
from qlib.config import C
from qlib.log import set_log_with_config
from qlib.rl.interpreter import StateInterpreter, ActionInterpreter
from qlib.rl.simulator import Simulator
from qlib.rl.reward import Reward
from qlib.rl.trainer import Trainer, TrainingVessel, EarlyStopping, Checkpoint
pytestmark = pytest.mark.skipif(sys.version_info < (3, 8), reason="Pickle styled data only supports Python >= 3.8")
class ZeroSimulator(Simulator):
def __init__(self, *args, **kwargs):
self.action = self.correct = 0
def step(self, action):
self.action = action
self.correct = action == 0
self._done = random.choice([False, True])
if self._done:
self.env.logger.add_scalar("acc", self.correct * 100)
def get_state(self):
return {
"acc": self.correct * 100,
"action": self.action,
}
def done(self) -> bool:
return self._done
class NoopStateInterpreter(StateInterpreter):
observation_space = spaces.Dict(
{
"acc": spaces.Discrete(200),
"action": spaces.Discrete(2),
}
)
def interpret(self, simulator_state):
return simulator_state
class NoopActionInterpreter(ActionInterpreter):
action_space = spaces.Discrete(2)
def interpret(self, simulator_state, action):
return action
class AccReward(Reward):
def reward(self, simulator_state):
if self.env.status["done"]:
return simulator_state["acc"] / 100
return 0.0
class PolicyNet(nn.Module):
def __init__(self, out_features=1, return_state=False):
super().__init__()
self.fc = nn.Linear(32, out_features)
self.return_state = return_state
def forward(self, obs, state=None, **kwargs):
res = self.fc(torch.randn(obs["acc"].shape[0], 32))
if self.return_state:
return nn.functional.softmax(res, dim=-1), state
else:
return res
def _ppo_policy():
actor = PolicyNet(2, True)
critic = PolicyNet()
policy = PPOPolicy(
actor,
critic,
torch.optim.Adam(tuple(actor.parameters()) + tuple(critic.parameters())),
torch.distributions.Categorical,
action_space=NoopActionInterpreter().action_space,
)
return policy
def test_trainer():
set_log_with_config(C.logging_config)
trainer = Trainer(max_iters=10, finite_env_type="subproc")
policy = _ppo_policy()
vessel = TrainingVessel(
simulator_fn=lambda init: ZeroSimulator(init),
state_interpreter=NoopStateInterpreter(),
action_interpreter=NoopActionInterpreter(),
policy=policy,
train_initial_states=list(range(100)),
val_initial_states=list(range(10)),
test_initial_states=list(range(10)),
reward=AccReward(),
episode_per_iter=500,
update_kwargs=dict(repeat=10, batch_size=64),
)
trainer.fit(vessel)
assert trainer.current_iter == 10
assert trainer.current_episode == 5000
assert abs(trainer.metrics["acc"] - trainer.metrics["reward"] * 100) < 1e-4
assert trainer.metrics["acc"] > 80
trainer.test(vessel)
assert trainer.metrics["acc"] > 60
def test_trainer_fast_dev_run():
set_log_with_config(C.logging_config)
trainer = Trainer(max_iters=2, fast_dev_run=2, finite_env_type="shmem")
policy = _ppo_policy()
vessel = TrainingVessel(
simulator_fn=lambda init: ZeroSimulator(init),
state_interpreter=NoopStateInterpreter(),
action_interpreter=NoopActionInterpreter(),
policy=policy,
train_initial_states=list(range(100)),
val_initial_states=list(range(10)),
test_initial_states=list(range(10)),
reward=AccReward(),
episode_per_iter=500,
update_kwargs=dict(repeat=10, batch_size=64),
)
trainer.fit(vessel)
assert trainer.current_episode == 4
def test_trainer_earlystop():
# TODO this is just sanity check.
# need to see the logs to check whether it works.
set_log_with_config(C.logging_config)
trainer = Trainer(
max_iters=10,
val_every_n_iters=1,
finite_env_type="dummy",
callbacks=[EarlyStopping("val/reward", restore_best_weights=True)],
)
policy = _ppo_policy()
vessel = TrainingVessel(
simulator_fn=lambda init: ZeroSimulator(init),
state_interpreter=NoopStateInterpreter(),
action_interpreter=NoopActionInterpreter(),
policy=policy,
train_initial_states=list(range(100)),
val_initial_states=list(range(10)),
test_initial_states=list(range(10)),
reward=AccReward(),
episode_per_iter=500,
update_kwargs=dict(repeat=10, batch_size=64),
)
trainer.fit(vessel)
assert trainer.metrics["val/acc"] > 30
assert trainer.current_iter == 2 # second iteration
def test_trainer_checkpoint():
set_log_with_config(C.logging_config)
output_dir = Path(__file__).parent / ".output"
trainer = Trainer(max_iters=2, finite_env_type="dummy", callbacks=[Checkpoint(output_dir, every_n_iters=1)])
policy = _ppo_policy()
vessel = TrainingVessel(
simulator_fn=lambda init: ZeroSimulator(init),
state_interpreter=NoopStateInterpreter(),
action_interpreter=NoopActionInterpreter(),
policy=policy,
train_initial_states=list(range(100)),
val_initial_states=list(range(10)),
test_initial_states=list(range(10)),
reward=AccReward(),
episode_per_iter=100,
update_kwargs=dict(repeat=10, batch_size=64),
)
trainer.fit(vessel)
assert (output_dir / "001.pth").exists()
assert (output_dir / "002.pth").exists()
assert os.readlink(output_dir / "latest.pth") == str(output_dir / "002.pth")
trainer.load_state_dict(torch.load(output_dir / "001.pth"))
assert trainer.current_iter == 1
assert trainer.current_episode == 100
# Reload the checkpoint at first iteration
trainer.fit(vessel, ckpt_path=output_dir / "001.pth")

View File

@@ -1,5 +1,6 @@
import copy
import unittest
import pytest
import fire
import pandas as pd
@@ -14,6 +15,7 @@ from qlib.workflow.online.update import LabelUpdater
class TestRolling(TestAutoData):
@pytest.mark.slow
def test_update_pred(self):
"""
This test is for testing if it will raise error if the `to_date` is out of the boundary.
@@ -73,6 +75,7 @@ class TestRolling(TestAutoData):
# this range is fixed now
self.assertTrue((updated_pred.loc[mod_range2] == -2).all().item())
@pytest.mark.slow
def test_update_label(self):
task = copy.deepcopy(CSI300_GBDT_TASK)

View File

@@ -4,6 +4,7 @@
import sys
import shutil
import unittest
import pytest
from pathlib import Path
import qlib
@@ -184,16 +185,19 @@ class TestAllFlow(TestAutoData):
def tearDownClass(cls) -> None:
shutil.rmtree(cls.URI_PATH.lstrip("file:"))
@pytest.mark.slow
def test_0_train_with_sigana(self):
TestAllFlow.PRED_SCORE, ic_ric, uri_path = train_with_sigana(self.URI_PATH)
self.assertGreaterEqual(ic_ric["ic"].all(), 0, "train failed")
self.assertGreaterEqual(ic_ric["ric"].all(), 0, "train failed")
@pytest.mark.slow
def test_1_train(self):
TestAllFlow.PRED_SCORE, ic_ric, TestAllFlow.RID = train(self.URI_PATH)
self.assertGreaterEqual(ic_ric["ic"].all(), 0, "train failed")
self.assertGreaterEqual(ic_ric["ric"].all(), 0, "train failed")
@pytest.mark.slow
def test_2_backtest(self):
analyze_df = backtest_analysis(TestAllFlow.PRED_SCORE, TestAllFlow.RID, self.URI_PATH)
self.assertGreaterEqual(
@@ -203,6 +207,7 @@ class TestAllFlow(TestAutoData):
)
self.assertTrue(not analyze_df.isna().any().any(), "backtest failed")
@pytest.mark.slow
def test_3_expmanager(self):
pass_default, pass_current, uri_path = fake_experiment()
self.assertTrue(pass_default, msg="default uri is incorrect")

View File

@@ -4,6 +4,7 @@
from qlib.workflow.record_temp import SignalRecord
import shutil
import unittest
import pytest
from pathlib import Path
from qlib.contrib.workflow import MultiSegRecord, SignalMseRecord
@@ -47,9 +48,11 @@ class TestAllFlow(TestAutoData):
def tearDownClass(cls) -> None:
shutil.rmtree(cls.URI_PATH.lstrip("file:"))
@pytest.mark.slow
def test_0_multiseg(self):
uri_path = train_multiseg(self.URI_PATH)
@pytest.mark.slow
def test_1_mse(self):
uri_path = train_mse(self.URI_PATH)

View File

@@ -2,6 +2,7 @@
# Licensed under the MIT License.
import unittest
import pytest
import sys
from qlib.tests import TestAutoData
from qlib.data.dataset import TSDatasetH
@@ -11,6 +12,7 @@ from qlib.data.dataset.handler import DataHandlerLP
class TestDataset(TestAutoData):
@pytest.mark.slow
def testTSDataset(self):
tsdh = TSDatasetH(
handler={

View File

@@ -6,12 +6,13 @@ import sys
import qlib
import shutil
import unittest
import pytest
import pandas as pd
import baostock as bs
from pathlib import Path
from qlib.data import D
from scripts.get_data import GetData
from qlib.tests.data import GetData
from scripts.dump_pit import DumpPitData
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts/data_collector/pit")))
@@ -39,17 +40,21 @@ class TestPIT(unittest.TestCase):
pit_dir = str(SOURCE_DIR.joinpath("pit").resolve())
pit_normalized_dir = str(SOURCE_DIR.joinpath("pit_normalized").resolve())
GetData().qlib_data(name="qlib_data_simple", target_dir=cn_data_dir, region="cn")
bs.login()
Run(
source_dir=pit_dir,
interval="quarterly",
).download_data(start="2000-01-01", end="2020-01-01", symbol_regex="^(600519|000725).*")
GetData().qlib_data(name="qlib_data", target_dir=pit_dir, region="pit")
# NOTE: This code does the same thing as line 43, but since baostock is not stable in downloading data, we have chosen to download offline data.
# bs.login()
# Run(
# source_dir=pit_dir,
# interval="quarterly",
# ).download_data(start="2000-01-01", end="2020-01-01", symbol_regex="^(600519|000725).*")
# bs.logout()
Run(
source_dir=pit_dir,
normalize_dir=pit_normalized_dir,
interval="quarterly",
).normalize_data()
bs.logout()
DumpPitData(
csv_path=pit_normalized_dir,
qlib_dir=cn_data_dir,
@@ -119,6 +124,7 @@ class TestPIT(unittest.TestCase):
"""
self.check_same(data, expect)
@pytest.mark.slow
def test_expr(self):
fields = [
"P(Mean($$roewa_q, 1))",