mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-29 09:01:18 +08:00
Compare commits
146 Commits
v0.8.2
...
you-n-g-pa
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
670ae6aa61 | ||
|
|
7e3ca3c5f4 | ||
|
|
ab0174a363 | ||
|
|
8c72ed99c2 | ||
|
|
b9624b074f | ||
|
|
23c657a7a2 | ||
|
|
9bf3423a64 | ||
|
|
25ecb1135f | ||
|
|
2ca0d88d2d | ||
|
|
50d74b5560 | ||
|
|
a87b02619a | ||
|
|
da676a20a2 | ||
|
|
13d904d9a9 | ||
|
|
36950b905d | ||
|
|
58540f76ee | ||
|
|
3e6e2865ce | ||
|
|
3fcbaa33fa | ||
|
|
50409ff17b | ||
|
|
afcea404a5 | ||
|
|
e24ef67663 | ||
|
|
2d5eecb9a2 | ||
|
|
89972f6c6f | ||
|
|
1ef8e61abd | ||
|
|
1a4114b683 | ||
|
|
e874ef2bc1 | ||
|
|
14b2b355a7 | ||
|
|
64fadff218 | ||
|
|
a02ac95538 | ||
|
|
cc94c32db6 | ||
|
|
9a40fd3cdc | ||
|
|
c4281121e3 | ||
|
|
2de9903200 | ||
|
|
2cf842bcfe | ||
|
|
9e381493c2 | ||
|
|
a73b60d05a | ||
|
|
64979ad769 | ||
|
|
c5cf8fb9cc | ||
|
|
5d579d1a20 | ||
|
|
3c9c76b384 | ||
|
|
9d0a8f61d1 | ||
|
|
701b18af1b | ||
|
|
84ff662a26 | ||
|
|
00e40e775b | ||
|
|
45fe5e6974 | ||
|
|
366a9c33f3 | ||
|
|
982e0da715 | ||
|
|
cd5e5d5235 | ||
|
|
caea495f40 | ||
|
|
d934c8caba | ||
|
|
a139986f4e | ||
|
|
12c3de42d0 | ||
|
|
fe0f9427f2 | ||
|
|
a973e4fb66 | ||
|
|
c60366addd | ||
|
|
41447f320b | ||
|
|
e1271a83f7 | ||
|
|
30b531086c | ||
|
|
87926513cb | ||
|
|
7bfc7e1797 | ||
|
|
85e7cdcac3 | ||
|
|
08fd1d3f42 | ||
|
|
defd6758f6 | ||
|
|
61cc1a3867 | ||
|
|
655ed982cf | ||
|
|
2952c443ca | ||
|
|
7f1293ec34 | ||
|
|
73438807f9 | ||
|
|
962751c72d | ||
|
|
56cfa480dc | ||
|
|
6edd0bf298 | ||
|
|
fe155703b0 | ||
|
|
3c4f4bfd44 | ||
|
|
5200ff520a | ||
|
|
30e457119c | ||
|
|
243e516cf1 | ||
|
|
e229b567ad | ||
|
|
f129bfef5d | ||
|
|
9dd5e07819 | ||
|
|
00ed35fc1b | ||
|
|
3f53a097b0 | ||
|
|
fb230a8097 | ||
|
|
ff4724e248 | ||
|
|
73d90f7f44 | ||
|
|
b7988e6428 | ||
|
|
8efc8b92ef | ||
|
|
f2a5ecd98a | ||
|
|
705354cc28 | ||
|
|
1b5d0d4d6d | ||
|
|
f4a481945b | ||
|
|
5f18ba7970 | ||
|
|
2681c61c60 | ||
|
|
776b0c5bb4 | ||
|
|
829ad9f5e9 | ||
|
|
921c13cc90 | ||
|
|
0f519f6053 | ||
|
|
2ed806c846 | ||
|
|
d2f0bebf60 | ||
|
|
615a381038 | ||
|
|
568a88fddb | ||
|
|
058f976727 | ||
|
|
faa99f30fa | ||
|
|
837067b9e1 | ||
|
|
3a911bc09b | ||
|
|
90be21bb40 | ||
|
|
7540b1257b | ||
|
|
57f7ed9914 | ||
|
|
9e3d0249f7 | ||
|
|
2ac964c470 | ||
|
|
07f0d4f599 | ||
|
|
ea4fb33ff2 | ||
|
|
ed0c238787 | ||
|
|
80af395b3c | ||
|
|
4dc66932d5 | ||
|
|
40dd84857c | ||
|
|
74cc21fc2c | ||
|
|
ec8969a3ae | ||
|
|
528f74af09 | ||
|
|
d482726f28 | ||
|
|
cfc3e886ed | ||
|
|
60d45ad770 | ||
|
|
0e8b94a552 | ||
|
|
4bf127eba5 | ||
|
|
c149c8616c | ||
|
|
3274e16c95 | ||
|
|
d496cf7476 | ||
|
|
357ee74b6f | ||
|
|
5da5cf5175 | ||
|
|
6a946761cf | ||
|
|
76b7b5f24b | ||
|
|
d7d19feb4e | ||
|
|
bba6972a55 | ||
|
|
18af288692 | ||
|
|
ba056850cb | ||
|
|
aed5b8ebc0 | ||
|
|
79355666a9 | ||
|
|
144e1e2459 | ||
|
|
635632e4ed | ||
|
|
c5834476e2 | ||
|
|
01afd06e18 | ||
|
|
d533219738 | ||
|
|
5b5c99fe75 | ||
|
|
da48f42f3f | ||
|
|
f979dcf5e8 | ||
|
|
97aa16a078 | ||
|
|
094be9be86 | ||
|
|
d9b9386032 |
2
.github/PULL_REQUEST_TEMPLATE.md
vendored
2
.github/PULL_REQUEST_TEMPLATE.md
vendored
@@ -8,7 +8,7 @@
|
||||
<!--- Why is this change required? What problem does it solve? -->
|
||||
|
||||
## How Has This Been Tested?
|
||||
<! --- Put an `x` in all the boxes that apply: --->
|
||||
<!--- Put an `x` in all the boxes that apply: --->
|
||||
- [ ] Pass the test by running: `pytest qlib/tests/test_all_pipeline.py` under upper directory of `qlib`.
|
||||
- [ ] If you are adding a new feature, test on your own test scripts.
|
||||
|
||||
|
||||
3
.github/workflows/python-publish.yml
vendored
3
.github/workflows/python-publish.yml
vendored
@@ -12,7 +12,8 @@ jobs:
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [windows-latest, macos-latest, macos-11]
|
||||
os: [windows-latest, macos-11]
|
||||
# FIXME: macos-latest will raise error now.
|
||||
# not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129
|
||||
python-version: [3.7, 3.8]
|
||||
|
||||
|
||||
66
.github/workflows/test.yml
vendored
66
.github/workflows/test.yml
vendored
@@ -1,66 +0,0 @@
|
||||
name: Test
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
|
||||
jobs:
|
||||
build:
|
||||
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [windows-latest, ubuntu-18.04, ubuntu-20.04]
|
||||
# not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129
|
||||
python-version: [3.7, 3.8]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Lint with Black
|
||||
run: |
|
||||
pip install --upgrade pip
|
||||
pip install black wheel
|
||||
black qlib -l 120 --check --diff
|
||||
|
||||
- name: Install Qlib with pip
|
||||
run: |
|
||||
pip install numpy==1.19.5 ruamel.yaml
|
||||
pip install pyqlib --ignore-installed
|
||||
|
||||
- name: Test data downloads
|
||||
run: |
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
|
||||
|
||||
- name: Test workflow by config (install from pip)
|
||||
run: |
|
||||
python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
python -m pip uninstall -y pyqlib
|
||||
|
||||
# Test Qlib installed from source
|
||||
- name: Install Qlib from source
|
||||
run: |
|
||||
pip install --upgrade cython jupyter jupyter_contrib_nbextensions numpy scipy scikit-learn # installing without this line will cause errors on GitHub Actions, while instsalling locally won't
|
||||
pip install -e .
|
||||
|
||||
- name: Install test dependencies
|
||||
run: |
|
||||
pip install --upgrade pip
|
||||
pip install black pytest
|
||||
|
||||
- name: Unit tests with Pytest
|
||||
run: |
|
||||
cd tests
|
||||
python -m pytest . --durations=10
|
||||
|
||||
- name: Test workflow by config (install from source)
|
||||
run: |
|
||||
python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
|
||||
75
.github/workflows/test_macos.yml
vendored
75
.github/workflows/test_macos.yml
vendored
@@ -1,75 +0,0 @@
|
||||
# There are some issues (in the downloading data phase) on MacOS when running with other tests. So we split it into an individual config.
|
||||
name: Test MacOS
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
|
||||
jobs:
|
||||
build:
|
||||
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [macos-11, macos-latest]
|
||||
# not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129
|
||||
python-version: [3.7, 3.8]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Lint with Black
|
||||
run: |
|
||||
cd ..
|
||||
python -m pip install pip --upgrade
|
||||
python -m pip install wheel --upgrade
|
||||
python -m pip install black
|
||||
python -m black qlib -l 120 --check --diff
|
||||
# Test Qlib installed with pip
|
||||
|
||||
- name: Install Qlib with pip
|
||||
run: |
|
||||
python -m pip install numpy==1.19.5
|
||||
python -m pip install pyqlib --ignore-installed ruamel.yaml numpy
|
||||
- name: Install Lightgbm for MacOS
|
||||
run: |
|
||||
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Microsoft/qlib/main/.github/brew_install.sh)"
|
||||
HOMEBREW_NO_AUTO_UPDATE=1 brew install lightgbm
|
||||
# FIX MacOS error: Segmentation fault
|
||||
# reference: https://github.com/microsoft/LightGBM/issues/4229
|
||||
wget https://raw.githubusercontent.com/Homebrew/homebrew-core/fb8323f2b170bd4ae97e1bac9bf3e2983af3fdb0/Formula/libomp.rb
|
||||
brew unlink libomp
|
||||
brew install libomp.rb
|
||||
- name: Test data downloads
|
||||
run: |
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
|
||||
- name: Test workflow by config (install from pip)
|
||||
run: |
|
||||
python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
python -m pip uninstall -y pyqlib
|
||||
# Test Qlib installed from source
|
||||
- name: Install Qlib from source
|
||||
run: |
|
||||
python -m pip install --upgrade cython
|
||||
python -m pip install numpy jupyter jupyter_contrib_nbextensions
|
||||
python -m pip install -U scipy scikit-learn # installing without this line will cause errors on GitHub Actions, while instsalling locally won't
|
||||
pip install -e .
|
||||
- name: Install test dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install -U pyopenssl idna
|
||||
python -m pip install black pytest
|
||||
- name: Unit tests with Pytest
|
||||
run: |
|
||||
cd tests
|
||||
python -m pytest . --durations=0
|
||||
- name: Test workflow by config (install from source)
|
||||
run: |
|
||||
python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
57
.github/workflows/test_qlib_from_pip.yml
vendored
Normal file
57
.github/workflows/test_qlib_from_pip.yml
vendored
Normal file
@@ -0,0 +1,57 @@
|
||||
name: Test qlib from pip
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
|
||||
jobs:
|
||||
build:
|
||||
timeout-minutes: 120
|
||||
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [windows-latest, ubuntu-18.04, ubuntu-20.04, macos-11, macos-latest]
|
||||
# not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129
|
||||
python-version: [3.7, 3.8]
|
||||
|
||||
steps:
|
||||
- name: Test qlib from pip
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Update pip to the latest version
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
|
||||
- name: Qlib installation test
|
||||
run: |
|
||||
python -m pip install pyqlib
|
||||
# Specify the numpy version because the numpy upgrade caused the CI test to fail,
|
||||
# and this line of code will be removed when the next version of qlib is released.
|
||||
python -m pip install "numpy<1.23"
|
||||
|
||||
- name: Install Lightgbm for MacOS
|
||||
if: ${{ matrix.os == 'macos-11' || matrix.os == 'macos-latest' }}
|
||||
run: |
|
||||
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Microsoft/qlib/main/.github/brew_install.sh)"
|
||||
HOMEBREW_NO_AUTO_UPDATE=1 brew install lightgbm
|
||||
# FIX MacOS error: Segmentation fault
|
||||
# reference: https://github.com/microsoft/LightGBM/issues/4229
|
||||
wget https://raw.githubusercontent.com/Homebrew/homebrew-core/fb8323f2b170bd4ae97e1bac9bf3e2983af3fdb0/Formula/libomp.rb
|
||||
brew unlink libomp
|
||||
brew install libomp.rb
|
||||
|
||||
- name: Downloads dependencies data
|
||||
run: |
|
||||
python scripts/get_data.py qlib_data --name qlib_data_simple --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
|
||||
|
||||
- name: Test workflow by config
|
||||
run: |
|
||||
qrun examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
150
.github/workflows/test_qlib_from_source.yml
vendored
Normal file
150
.github/workflows/test_qlib_from_source.yml
vendored
Normal file
@@ -0,0 +1,150 @@
|
||||
name: Test qlib from source
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
|
||||
jobs:
|
||||
build:
|
||||
timeout-minutes: 120
|
||||
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [windows-latest, ubuntu-18.04, ubuntu-20.04, macos-11, macos-latest]
|
||||
# not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129
|
||||
python-version: [3.7, 3.8]
|
||||
|
||||
steps:
|
||||
- name: Test qlib from source
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Update pip to the latest version
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
|
||||
- name: Installing pytorch for macos
|
||||
if: ${{ matrix.os == 'macos-11' || matrix.os == 'macos-latest' }}
|
||||
run: |
|
||||
python -m pip install torch torchvision torchaudio
|
||||
|
||||
- name: Installing pytorch for ubuntu
|
||||
if: ${{ matrix.os == 'ubuntu-18.04' || matrix.os == 'ubuntu-20.04' }}
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
|
||||
|
||||
- name: Installing pytorch for windows
|
||||
if: ${{ matrix.os == 'windows-latest' }}
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install torch torchvision torchaudio
|
||||
|
||||
- name: Set up Python tools
|
||||
run: |
|
||||
python -m pip install --upgrade cython
|
||||
python -m pip install -e .[dev]
|
||||
|
||||
- name: Lint with Black
|
||||
run: |
|
||||
black . -l 120 --check --diff
|
||||
|
||||
- name: Make html with sphinx
|
||||
run: |
|
||||
cd docs
|
||||
sphinx-build -b html . build
|
||||
cd ..
|
||||
|
||||
# Check Qlib with pylint
|
||||
# TODO: These problems we will solve in the future. Important among them are: W0221, W0223, W0237, E1102
|
||||
# C0103: invalid-name
|
||||
# C0209: consider-using-f-string
|
||||
# R0402: consider-using-from-import
|
||||
# R1705: no-else-return
|
||||
# R1710: inconsistent-return-statements
|
||||
# R1725: super-with-arguments
|
||||
# R1735: use-dict-literal
|
||||
# W0102: dangerous-default-value
|
||||
# W0212: protected-access
|
||||
# W0221: arguments-differ
|
||||
# W0223: abstract-method
|
||||
# W0231: super-init-not-called
|
||||
# W0237: arguments-renamed
|
||||
# W0612: unused-variable
|
||||
# W0621: redefined-outer-name
|
||||
# W0622: redefined-builtin
|
||||
# FIXME: specify exception type
|
||||
# W0703: broad-except
|
||||
# W1309: f-string-without-interpolation
|
||||
# E1102: not-callable
|
||||
# E1136: unsubscriptable-object
|
||||
# References for parameters: https://github.com/PyCQA/pylint/issues/4577#issuecomment-1000245962
|
||||
- name: Check Qlib with pylint
|
||||
run: |
|
||||
pylint --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0612,W0621,W0622,W0703,W1309,E1102,E1136 --const-rgx='[a-z_][a-z0-9_]{2,30}$' qlib --init-hook "import astroid; astroid.context.InferenceContext.max_inferred = 500"
|
||||
|
||||
# The following flake8 error codes were ignored:
|
||||
# E501 line too long
|
||||
# Description: We have used black to limit the length of each line to 120.
|
||||
# F541 f-string is missing placeholders
|
||||
# Description: The same thing is done when using pylint for detection.
|
||||
# E266 too many leading '#' for block comment
|
||||
# Description: To make the code more readable, a lot of "#" is used.
|
||||
# This error code appears centrally in:
|
||||
# qlib/backtest/executor.py
|
||||
# qlib/data/ops.py
|
||||
# qlib/utils/__init__.py
|
||||
# E402 module level import not at top of file
|
||||
# Description: There are times when module level import is not available at the top of the file.
|
||||
# W503 line break before binary operator
|
||||
# Description: Since black formats the length of each line of code, it has to perform a line break when a line of arithmetic is too long.
|
||||
# E731 do not assign a lambda expression, use a def
|
||||
# Description: Restricts the use of lambda expressions, but at some point lambda expressions are required.
|
||||
# E203 whitespace before ':'
|
||||
# Description: If there is whitespace before ":", it cannot pass the black check.
|
||||
- name: Check Qlib with flake8
|
||||
run: |
|
||||
flake8 --ignore=E501,F541,E266,E402,W503,E731,E203 --per-file-ignores="__init__.py:F401,F403" qlib
|
||||
|
||||
# https://github.com/python/mypy/issues/10600
|
||||
- name: Check Qlib with mypy
|
||||
run: |
|
||||
mypy qlib --install-types --non-interactive || true
|
||||
mypy qlib --verbose
|
||||
|
||||
- name: Test data downloads
|
||||
run: |
|
||||
python scripts/get_data.py qlib_data --name qlib_data_simple --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
|
||||
azcopy copy https://qlibpublic.blob.core.windows.net/data/rl /tmp/qlibpublic/data --recursive
|
||||
mv /tmp/qlibpublic/data tests/.data
|
||||
|
||||
- name: Install Lightgbm for MacOS
|
||||
if: ${{ matrix.os == 'macos-11' || matrix.os == 'macos-latest' }}
|
||||
run: |
|
||||
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Microsoft/qlib/main/.github/brew_install.sh)"
|
||||
HOMEBREW_NO_AUTO_UPDATE=1 brew install lightgbm
|
||||
# FIX MacOS error: Segmentation fault
|
||||
# reference: https://github.com/microsoft/LightGBM/issues/4229
|
||||
wget https://raw.githubusercontent.com/Homebrew/homebrew-core/fb8323f2b170bd4ae97e1bac9bf3e2983af3fdb0/Formula/libomp.rb
|
||||
brew unlink libomp
|
||||
brew install libomp.rb
|
||||
|
||||
- name: Test workflow by config (install from source)
|
||||
run: |
|
||||
# Version 0.52.0 of numba must be installed manually in CI, otherwise it will cause incompatibility with the latest version of numpy.
|
||||
python -m pip install numba==0.52.0
|
||||
# You must update numpy manually, because when installing python tools, it will try to uninstall numpy and cause CI to fail.
|
||||
python -m pip install --upgrade numpy
|
||||
python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
|
||||
- name: Unit tests with Pytest
|
||||
run: |
|
||||
cd tests
|
||||
python -m pytest . -m "not slow" --durations=0
|
||||
56
.github/workflows/test_qlib_from_source_slow.yml
vendored
Normal file
56
.github/workflows/test_qlib_from_source_slow.yml
vendored
Normal file
@@ -0,0 +1,56 @@
|
||||
name: Test qlib from source slow
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
|
||||
jobs:
|
||||
build:
|
||||
timeout-minutes: 120
|
||||
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [windows-latest, ubuntu-18.04, ubuntu-20.04, macos-11, macos-latest]
|
||||
# not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129
|
||||
python-version: [3.7, 3.8]
|
||||
|
||||
steps:
|
||||
- name: Test qlib from source slow
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Set up Python tools
|
||||
run: |
|
||||
pip install --upgrade cython numpy pip
|
||||
pip install -e .[dev]
|
||||
|
||||
- name: Downloads dependencies data
|
||||
run: |
|
||||
python scripts/get_data.py qlib_data --name qlib_data_simple --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
|
||||
|
||||
- name: Install Lightgbm for MacOS
|
||||
if: ${{ matrix.os == 'macos-11' || matrix.os == 'macos-latest' }}
|
||||
run: |
|
||||
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Microsoft/qlib/main/.github/brew_install.sh)"
|
||||
HOMEBREW_NO_AUTO_UPDATE=1 brew install lightgbm
|
||||
# FIX MacOS error: Segmentation fault
|
||||
# reference: https://github.com/microsoft/LightGBM/issues/4229
|
||||
wget https://raw.githubusercontent.com/Homebrew/homebrew-core/fb8323f2b170bd4ae97e1bac9bf3e2983af3fdb0/Formula/libomp.rb
|
||||
brew unlink libomp
|
||||
brew install libomp.rb
|
||||
|
||||
- name: Unit tests with Pytest
|
||||
uses: nick-fields/retry@v2
|
||||
with:
|
||||
timeout_minutes: 120
|
||||
max_attempts: 3
|
||||
command: |
|
||||
cd tests
|
||||
python -m pytest . -m "slow" --durations=0
|
||||
6
.gitignore
vendored
6
.gitignore
vendored
@@ -27,6 +27,10 @@ examples/estimator/estimator_example/
|
||||
|
||||
*.egg-info/
|
||||
|
||||
# test related
|
||||
test-output.xml
|
||||
.output
|
||||
.data
|
||||
|
||||
# special software
|
||||
mlruns/
|
||||
@@ -34,8 +38,10 @@ mlruns/
|
||||
tags
|
||||
|
||||
.pytest_cache/
|
||||
.mypy_cache/
|
||||
.vscode/
|
||||
|
||||
*.swp
|
||||
|
||||
./pretrain
|
||||
.idea/
|
||||
|
||||
17
.mypy.ini
Normal file
17
.mypy.ini
Normal file
@@ -0,0 +1,17 @@
|
||||
[mypy]
|
||||
exclude = (?x)(
|
||||
^qlib/backtest/high_performance_ds\.py$
|
||||
| ^qlib/contrib
|
||||
| ^qlib/data
|
||||
| ^qlib/model
|
||||
| ^qlib/strategy
|
||||
| ^qlib/tests
|
||||
| ^qlib/utils
|
||||
| ^qlib/workflow
|
||||
| ^qlib/config\.py$
|
||||
| ^qlib/log\.py$
|
||||
| ^qlib/__init__\.py$
|
||||
)
|
||||
ignore_missing_imports = true
|
||||
disallow_incomplete_defs = true
|
||||
follow_imports = skip
|
||||
12
.pre-commit-config.yaml
Normal file
12
.pre-commit-config.yaml
Normal file
@@ -0,0 +1,12 @@
|
||||
repos:
|
||||
- repo: https://github.com/psf/black
|
||||
rev: 22.1.0
|
||||
hooks:
|
||||
- id: black
|
||||
args: ["qlib", "-l 120"]
|
||||
|
||||
- repo: https://github.com/PyCQA/flake8
|
||||
rev: 4.0.1
|
||||
hooks:
|
||||
- id: flake8
|
||||
args: ["--ignore=E501,F541,E266,E402,W503,E731,E203"]
|
||||
5
.pylintrc
Normal file
5
.pylintrc
Normal file
@@ -0,0 +1,5 @@
|
||||
[TYPECHECK]
|
||||
# https://stackoverflow.com/a/53572939
|
||||
# List of members which are set dynamically and missed by Pylint inference
|
||||
# system, and so shouldn't trigger E1101 when accessed.
|
||||
generated-members=numpy.*, torch.*
|
||||
95
README.md
95
README.md
@@ -11,7 +11,11 @@
|
||||
Recent released features
|
||||
| Feature | Status |
|
||||
| -- | ------ |
|
||||
| Arctic Provider Backend & Orderbook data example | :hammer: [Rleased](https://github.com/microsoft/qlib/pull/744) on Jan 17, 2022 |
|
||||
| HIST and IGMTF models | :chart_with_upwards_trend: [Released](https://github.com/microsoft/qlib/pull/1040) on Apr 10, 2022 |
|
||||
| Qlib [notebook tutorial](https://github.com/microsoft/qlib/tree/main/examples/tutorial) | 📖 [Released](https://github.com/microsoft/qlib/pull/1037) on Apr 7, 2022 |
|
||||
| Ibovespa index data | :rice: [Released](https://github.com/microsoft/qlib/pull/990) on Apr 6, 2022 |
|
||||
| Point-in-Time database | :hammer: [Released](https://github.com/microsoft/qlib/pull/343) on Mar 10, 2022 |
|
||||
| Arctic Provider Backend & Orderbook data example | :hammer: [Released](https://github.com/microsoft/qlib/pull/744) on Jan 17, 2022 |
|
||||
| Meta-Learning-based framework & DDG-DA | :chart_with_upwards_trend: :hammer: [Released](https://github.com/microsoft/qlib/pull/743) on Jan 10, 2022 |
|
||||
| Planning-based portfolio optimization | :hammer: [Released](https://github.com/microsoft/qlib/pull/754) on Dec 28, 2021 |
|
||||
| Release Qlib v0.8.0 | :octocat: [Released](https://github.com/microsoft/qlib/releases/tag/v0.8.0) on Dec 8, 2021 |
|
||||
@@ -28,7 +32,7 @@ Recent released features
|
||||
| High-frequency data processing example | :hammer: [Released](https://github.com/microsoft/qlib/pull/257) on Feb 5, 2021 |
|
||||
| High-frequency trading example | :chart_with_upwards_trend: [Part of code released](https://github.com/microsoft/qlib/pull/227) on Jan 28, 2021 |
|
||||
| High-frequency data(1min) | :rice: [Released](https://github.com/microsoft/qlib/pull/221) on Jan 27, 2021 |
|
||||
| Tabnet Model | :chart_with_upwards_trend: [Released](https://github.com/microsoft/qlib/pull/205) on Jan 22, 2021 |
|
||||
| Tabnet Model | :chart_with_upwards_trend: [Released](https://github.com/microsoft/qlib/pull/205) on Jan 22, 2021 |
|
||||
|
||||
Features released before 2021 are not listed here.
|
||||
|
||||
@@ -45,34 +49,58 @@ With Qlib, users can easily try ideas to create better Quant investment strategi
|
||||
|
||||
For more details, please refer to our paper ["Qlib: An AI-oriented Quantitative Investment Platform"](https://arxiv.org/abs/2009.11189).
|
||||
|
||||
- [**Plans**](#plans)
|
||||
- [Framework of Qlib](#framework-of-qlib)
|
||||
- [Quick Start](#quick-start)
|
||||
- [Installation](#installation)
|
||||
- [Data Preparation](#data-preparation)
|
||||
- [Auto Quant Research Workflow](#auto-quant-research-workflow)
|
||||
- [Building Customized Quant Research Workflow by Code](#building-customized-quant-research-workflow-by-code)
|
||||
- [Main Challenges & Solutions in Quant Research](#main-challenges--solutions-in-quant-research)
|
||||
- [Forecasting: Finding Valuable Signals/Patterns](#forecasting-finding-valuable-signalspatterns)
|
||||
- [**Quant Model (Paper) Zoo**](#quant-model-paper-zoo)
|
||||
- [Run a Single Model](#run-a-single-model)
|
||||
- [Run Multiple Models](#run-multiple-models)
|
||||
- [Adapting to Market Dynamics](#adapting-to-market-dynamics)
|
||||
- [**Quant Dataset Zoo**](#quant-dataset-zoo)
|
||||
- [More About Qlib](#more-about-qlib)
|
||||
- [Offline Mode and Online Mode](#offline-mode-and-online-mode)
|
||||
- [Performance of Qlib Data Server](#performance-of-qlib-data-server)
|
||||
- [Related Reports](#related-reports)
|
||||
- [Contact Us](#contact-us)
|
||||
- [Contributing](#contributing)
|
||||
|
||||
<table>
|
||||
<tbody>
|
||||
<tr>
|
||||
<th>Frameworks, Tutorial, Data & DevOps</th>
|
||||
<th>Main Challenges & Solutions in Quant Research</th>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>
|
||||
<li><a href="#plans"><strong>Plans</strong></a></li>
|
||||
<li><a href="#framework-of-qlib">Framework of Qlib</a></li>
|
||||
<li><a href="#quick-start">Quick Start</a></li>
|
||||
<ul dir="auto">
|
||||
<li type="circle"><a href="#installation">Installation</a> </li>
|
||||
<li type="circle"><a href="#data-preparation">Data Preparation</a></li>
|
||||
<li type="circle"><a href="#auto-quant-research-workflow">Auto Quant Research Workflow</a></li>
|
||||
<li type="circle"><a href="#building-customized-quant-research-workflow-by-code">Building Customized Quant Research Workflow by Code</a></li></ul>
|
||||
<li><a href="#quant-dataset-zoo"><strong>Quant Dataset Zoo</strong></a></li>
|
||||
<li><a href="#more-about-qlib">More About Qlib</a></li>
|
||||
<li><a href="#offline-mode-and-online-mode">Offline Mode and Online Mode</a>
|
||||
<ul>
|
||||
<li type="circle"><a href="#performance-of-qlib-data-server">Performance of Qlib Data Server</a></li></ul>
|
||||
<li><a href="#related-reports">Related Reports</a></li>
|
||||
<li><a href="#contact-us">Contact Us</a></li>
|
||||
<li><a href="#contributing">Contributing</a></li>
|
||||
</td>
|
||||
<td valign="baseline">
|
||||
<li><a href="#main-challenges--solutions-in-quant-research">Main Challenges & Solutions in Quant Research</a>
|
||||
<ul>
|
||||
<li type="circle"><a href="#forecasting-finding-valuable-signalspatterns">Forecasting: Finding Valuable Signals/Patterns</a>
|
||||
<ul>
|
||||
<li type="disc"><a href="#quant-model-paper-zoo"><strong>Quant Model (Paper) Zoo</strong></a>
|
||||
<ul>
|
||||
<li type="circle"><a href="#run-a-single-model">Run a Single Model</a></li>
|
||||
<li type="circle"><a href="#run-multiple-models">Run Multiple Models</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
</ul>
|
||||
</li>
|
||||
<li type="circle"><a href="#adapting-to-market-dynamics">Adapting to Market Dynamics</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
|
||||
# Plans
|
||||
New features under development(order by estimated release time).
|
||||
Your feedbacks about the features are very important.
|
||||
| Feature | Status |
|
||||
| -- | ------ |
|
||||
| Point-in-Time database | Under review: https://github.com/microsoft/qlib/pull/343 |
|
||||
<!-- | Feature | Status | -->
|
||||
<!-- | -- | ------ | -->
|
||||
|
||||
# Framework of Qlib
|
||||
|
||||
@@ -80,7 +108,6 @@ Your feedbacks about the features are very important.
|
||||
<img src="docs/_static/img/framework.svg" />
|
||||
</div>
|
||||
|
||||
|
||||
At the module level, Qlib is a platform that consists of the above components. The components are designed as loose-coupled modules, and each component could be used stand-alone.
|
||||
|
||||
| Name | Description |
|
||||
@@ -92,6 +119,8 @@ At the module level, Qlib is a platform that consists of the above components. T
|
||||
* The modules with hand-drawn style are under development and will be released in the future.
|
||||
* The modules with dashed borders are highly user-customizable and extendible.
|
||||
|
||||
(p.s. framework image is created with https://draw.io/)
|
||||
|
||||
|
||||
# Quick Start
|
||||
|
||||
@@ -311,6 +340,8 @@ Here is a list of models built on `Qlib`.
|
||||
- [TCN based on pytorch (Shaojie Bai, et al. 2018)](examples/benchmarks/TCN/)
|
||||
- [ADARNN based on pytorch (YunTao Du, et al. 2021)](examples/benchmarks/ADARNN/)
|
||||
- [ADD based on pytorch (Hongshun Tang, et al.2020)](examples/benchmarks/ADD/)
|
||||
- [IGMTF based on pytorch (Wentao Xu, et al.2021)](examples/benchmarks/IGMTF/)
|
||||
- [HIST based on pytorch (Wentao Xu, et al.2021)](examples/benchmarks/HIST/)
|
||||
|
||||
Your PR of new Quant models is highly welcomed.
|
||||
|
||||
@@ -359,6 +390,8 @@ Dataset plays a very important role in Quant. Here is a list of the datasets bui
|
||||
Your PR to build new Quant dataset is highly welcomed.
|
||||
|
||||
# More About Qlib
|
||||
If you want to have a quick glance at the most frequently used components of qlib, you can try notebooks [here](examples/tutorial/).
|
||||
|
||||
The detailed documents are organized in [docs](docs/).
|
||||
[Sphinx](http://www.sphinx-doc.org) and the readthedocs theme is required to build the documentation in html formats.
|
||||
```bash
|
||||
@@ -425,7 +458,7 @@ Before we released Qlib as an open-source project on Github in Sep 2020, Qlib is
|
||||
|
||||
This project welcomes contributions and suggestions.
|
||||
**Here are some
|
||||
[code standards](docs/developer/code_standard.rst) for submiting a pull request.**
|
||||
[code standards and development guidance](docs/developer/code_standard_and_dev_guide.rst) for submiting a pull request.**
|
||||
|
||||
Making contributions is not a hard thing. Solving an issue(maybe just answering a question raised in [issues list](https://github.com/microsoft/qlib/issues) or [gitter](https://gitter.im/Microsoft/qlib)), fixing/issuing a bug, improving the documents and even fixing a typo are important contributions to Qlib.
|
||||
|
||||
@@ -441,9 +474,13 @@ If you don't know how to start to contribute, you can refer to the following exa
|
||||
| Docs | [Improve docs quality](https://github.com/microsoft/qlib/pull/797/files) ; [Fix a typo](https://github.com/microsoft/qlib/pull/774) |
|
||||
| Feature | Implement a [requested feature](https://github.com/microsoft/qlib/projects) like [this](https://github.com/microsoft/qlib/pull/754); [Refactor interfaces](https://github.com/microsoft/qlib/pull/539/files) |
|
||||
| Dataset | [Add a dataset](https://github.com/microsoft/qlib/pull/733) |
|
||||
| Models | [Implement a new model](https://github.com/microsoft/qlib/pull/689) |
|
||||
| Models | [Implement a new model](https://github.com/microsoft/qlib/pull/689), [some instructions to contribute models](https://github.com/microsoft/qlib/tree/main/examples/benchmarks#contributing) |
|
||||
|
||||
If you would like to become one of Qlib's maintainers to contribute more (e.g. help merge PR, triage issues), please contact us by email([qlib@microsoft.com](mailto:qlib@microsoft.com)). We are glad to help you to set the right permission.
|
||||
[Good first issues](https://github.com/microsoft/qlib/labels/good%20first%20issue) are labelled to indicate that they are easy to start your contributions.
|
||||
|
||||
You can find some impefect implementation in Qlib by `rg 'TODO|FIXME' qlib`
|
||||
|
||||
If you would like to become one of Qlib's maintainers to contribute more (e.g. help merge PR, triage issues), please contact us by email([qlib@microsoft.com](mailto:qlib@microsoft.com)). We are glad to help to upgrade your permission.
|
||||
|
||||
## Licence
|
||||
Most contributions require you to agree to a
|
||||
|
||||
136
docs/advanced/PIT.rst
Normal file
136
docs/advanced/PIT.rst
Normal file
@@ -0,0 +1,136 @@
|
||||
.. _pit:
|
||||
|
||||
===========================
|
||||
(P)oint-(I)n-(T)ime Database
|
||||
===========================
|
||||
.. currentmodule:: qlib
|
||||
|
||||
|
||||
Introduction
|
||||
------------
|
||||
Point-in-time data is a very important consideration when performing any sort of historical market analysis.
|
||||
|
||||
For example, let’s say we are backtesting a trading strategy and we are using the past five years of historical data as our input.
|
||||
Our model is assumed to trade once a day, at the market close, and we’ll say we are calculating the trading signal for 1 January 2020 in our backtest. At that point, we should only have data for 1 January 2020, 31 December 2019, 30 December 2019 etc.
|
||||
|
||||
In financial data (especially financial reports), the same piece of data may be amended for multiple times overtime. If we only use the latest version for historical backtesting, data leakage will happen.
|
||||
Point-in-time database is designed for solving this problem to make sure user get the right version of data at any historical timestamp. It will keep the performance of online trading and historical backtesting the same.
|
||||
|
||||
|
||||
|
||||
Data Preparation
|
||||
----------------
|
||||
|
||||
Qlib provides a crawler to help users to download financial data and then a converter to dump the data in Qlib format.
|
||||
Please follow `scripts/data_collector/pit/README.md <https://github.com/microsoft/qlib/tree/main/scripts/data_collector/pit/>`_ to download and convert data.
|
||||
Besides, you can find some additional usage examples there.
|
||||
|
||||
|
||||
File-based design for PIT data
|
||||
------------------------------
|
||||
|
||||
Qlib provides a file-based storage for PIT data.
|
||||
|
||||
For each feature, it contains 4 columns, i.e. date, period, value, _next.
|
||||
Each row corresponds to a statement.
|
||||
|
||||
The meaning of each feature with filename like `XXX_a.data`:
|
||||
|
||||
- `date`: the statement's date of publication.
|
||||
- `period`: the period of the statement. (e.g. it will be quarterly frequency in most of the markets)
|
||||
- If it is an annual period, it will be an integer corresponding to the year
|
||||
- If it is an quarterly periods, it will be an integer like `<year><index of quarter>`. The last two decimal digits represents the index of quarter. Others represent the year.
|
||||
- `value`: the described value
|
||||
- `_next`: the byte index of the next occurance of the field.
|
||||
|
||||
Besides the feature data, an index `XXX_a.index` is included to speed up the querying performance
|
||||
|
||||
The statements are soted by the `date` in ascending order from the beginning of the file.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# the data format from XXXX.data
|
||||
array([(20070428, 200701, 0.090219 , 4294967295),
|
||||
(20070817, 200702, 0.13933 , 4294967295),
|
||||
(20071023, 200703, 0.24586301, 4294967295),
|
||||
(20080301, 200704, 0.3479 , 80),
|
||||
(20080313, 200704, 0.395989 , 4294967295),
|
||||
(20080422, 200801, 0.100724 , 4294967295),
|
||||
(20080828, 200802, 0.24996801, 4294967295),
|
||||
(20081027, 200803, 0.33412001, 4294967295),
|
||||
(20090325, 200804, 0.39011699, 4294967295),
|
||||
(20090421, 200901, 0.102675 , 4294967295),
|
||||
(20090807, 200902, 0.230712 , 4294967295),
|
||||
(20091024, 200903, 0.30072999, 4294967295),
|
||||
(20100402, 200904, 0.33546099, 4294967295),
|
||||
(20100426, 201001, 0.083825 , 4294967295),
|
||||
(20100812, 201002, 0.200545 , 4294967295),
|
||||
(20101029, 201003, 0.260986 , 4294967295),
|
||||
(20110321, 201004, 0.30739301, 4294967295),
|
||||
(20110423, 201101, 0.097411 , 4294967295),
|
||||
(20110831, 201102, 0.24825101, 4294967295),
|
||||
(20111018, 201103, 0.318919 , 4294967295),
|
||||
(20120323, 201104, 0.4039 , 420),
|
||||
(20120411, 201104, 0.403925 , 4294967295),
|
||||
(20120426, 201201, 0.112148 , 4294967295),
|
||||
(20120810, 201202, 0.26484701, 4294967295),
|
||||
(20121026, 201203, 0.370487 , 4294967295),
|
||||
(20130329, 201204, 0.45004699, 4294967295),
|
||||
(20130418, 201301, 0.099958 , 4294967295),
|
||||
(20130831, 201302, 0.21044201, 4294967295),
|
||||
(20131016, 201303, 0.30454299, 4294967295),
|
||||
(20140325, 201304, 0.394328 , 4294967295),
|
||||
(20140425, 201401, 0.083217 , 4294967295),
|
||||
(20140829, 201402, 0.16450299, 4294967295),
|
||||
(20141030, 201403, 0.23408499, 4294967295),
|
||||
(20150421, 201404, 0.319612 , 4294967295),
|
||||
(20150421, 201501, 0.078494 , 4294967295),
|
||||
(20150828, 201502, 0.137504 , 4294967295),
|
||||
(20151023, 201503, 0.201709 , 4294967295),
|
||||
(20160324, 201504, 0.26420501, 4294967295),
|
||||
(20160421, 201601, 0.073664 , 4294967295),
|
||||
(20160827, 201602, 0.136576 , 4294967295),
|
||||
(20161029, 201603, 0.188062 , 4294967295),
|
||||
(20170415, 201604, 0.244385 , 4294967295),
|
||||
(20170425, 201701, 0.080614 , 4294967295),
|
||||
(20170728, 201702, 0.15151 , 4294967295),
|
||||
(20171026, 201703, 0.25416601, 4294967295),
|
||||
(20180328, 201704, 0.32954201, 4294967295),
|
||||
(20180428, 201801, 0.088887 , 4294967295),
|
||||
(20180802, 201802, 0.170563 , 4294967295),
|
||||
(20181029, 201803, 0.25522 , 4294967295),
|
||||
(20190329, 201804, 0.34464401, 4294967295),
|
||||
(20190425, 201901, 0.094737 , 4294967295),
|
||||
(20190713, 201902, 0. , 1040),
|
||||
(20190718, 201902, 0.175322 , 4294967295),
|
||||
(20191016, 201903, 0.25581899, 4294967295)],
|
||||
dtype=[('date', '<u4'), ('period', '<u4'), ('value', '<f8'), ('_next', '<u4')])
|
||||
# - each row contains 20 byte
|
||||
|
||||
|
||||
# The data format from XXXX.index. It consists of two parts
|
||||
# 1) the start index of the data. So the first part of the info will be like
|
||||
2007
|
||||
# 2) the remain index data will be like information below
|
||||
# - The data indicate the **byte index** of first data update of a period.
|
||||
# - e.g. Because the info at both byte 80 and 100 corresponds to 200704. The byte index of first occurance (i.e. 100) is recorded in the data.
|
||||
array([ 0, 20, 40, 60, 100,
|
||||
120, 140, 160, 180, 200,
|
||||
220, 240, 260, 280, 300,
|
||||
320, 340, 360, 380, 400,
|
||||
440, 460, 480, 500, 520,
|
||||
540, 560, 580, 600, 620,
|
||||
640, 660, 680, 700, 720,
|
||||
740, 760, 780, 800, 820,
|
||||
840, 860, 880, 900, 920,
|
||||
940, 960, 980, 1000, 1020,
|
||||
1060, 4294967295], dtype=uint32)
|
||||
|
||||
|
||||
|
||||
|
||||
Known limitations:
|
||||
|
||||
- Currently, the PIT database is designed for quarterly or annually factors, which can handle fundamental data of financial reports in most markets.
|
||||
- Qlib leverage the file name to identify the type of the data. File with name like `XXX_q.data` corresponds to quarterly data. File with name like `XXX_a.data` corresponds to annual data.
|
||||
- The caclulation of PIT is not performed in the optimal way. There is great potential to boost the performance of PIT data calcuation.
|
||||
@@ -52,7 +52,8 @@ Also, ``Qlib`` provides a high-frequency dataset. Users can run a high-frequency
|
||||
Qlib Format Dataset
|
||||
--------------------
|
||||
``Qlib`` has provided an off-the-shelf dataset in `.bin` format, users could use the script ``scripts/get_data.py`` to download the China-Stock dataset as follows.
|
||||
The price volume data look different from the actual dealling price because of they are **adjusted** (`adjusted price <https://www.investopedia.com/terms/a/adjusted_closing_price.asp>`_). And then you may find that the adjusted price may be different from different data sources. This is because different data sources may vary in the way of adjusting prices. Qlib normalize the price on first trading day of each stock to 1 when adjusting them.
|
||||
The price volume data look different from the actual dealling price because of they are **adjusted** (`adjusted price <https://www.investopedia.com/terms/a/adjusted_closing_price.asp>`_). And then you may find that the adjusted price may be different from different data sources. This is because different data sources may vary in the way of adjusting prices. Qlib normalize the price on first trading day of each stock to 1 when adjusting them.
|
||||
Users can leverage `$factor` to get the original trading price (e.g. `$close / $factor` to get the original close price).
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
@@ -436,7 +437,7 @@ Dataset
|
||||
|
||||
The ``Dataset`` module in ``Qlib`` aims to prepare data for model training and inferencing.
|
||||
|
||||
The motivation of this module is that we want to maximize the flexibility of of different models to handle data that are suitable for themselves. This module gives the model the flexibility to process their data in an unique way. For instance, models such as ``GBDT`` may work well on data that contains `nan` or `None` value, while neural networks such as ``MLP`` will break down on such data.
|
||||
The motivation of this module is that we want to maximize the flexibility of different models to handle data that are suitable for themselves. This module gives the model the flexibility to process their data in an unique way. For instance, models such as ``GBDT`` may work well on data that contains `nan` or `None` value, while neural networks such as ``MLP`` will break down on such data.
|
||||
|
||||
If user's model need process its data in a different way, user could implement his own ``Dataset`` class. If the model's
|
||||
data processing is not special, ``DatasetH`` can be used directly.
|
||||
|
||||
@@ -28,4 +28,11 @@ The frequency of trading algorithm, decision content and execution environment c
|
||||
Example
|
||||
===========================
|
||||
|
||||
An example of nested decision execution framework for high-frequency can be found `here <https://github.com/microsoft/qlib/blob/main/examples/nested_decision_execution/workflow.py>`_.
|
||||
An example of nested decision execution framework for high-frequency can be found `here <https://github.com/microsoft/qlib/blob/main/examples/nested_decision_execution/workflow.py>`_.
|
||||
|
||||
|
||||
Besides, the above examples, here are some other related work about high-frequency trading in Qlib.
|
||||
|
||||
- `Prediction with high-frequency data <https://github.com/microsoft/qlib/tree/main/examples/highfreq#benchmarks-performance-predicting-the-price-trend-in-high-frequency-data>`_
|
||||
- `Examples <https://github.com/microsoft/qlib/blob/main/examples/orderbook_data/>`_ to extract features form high-frequency data without fixed frequency.
|
||||
- `A paper <https://github.com/microsoft/qlib/tree/high-freq-execution#high-frequency-execution>`_ for high-frequency trading.
|
||||
|
||||
@@ -143,3 +143,9 @@ Here is a simple exampke of what is done in ``PortAnaRecord``, which users can r
|
||||
print(analysis_df)
|
||||
|
||||
For more information about the APIs, please refer to `Record Template API <../reference/api.html#module-qlib.workflow.record_temp>`_.
|
||||
|
||||
|
||||
|
||||
Known Limitations
|
||||
=================
|
||||
- The Python objects are saved based on pickle, which may results in issues when the environment dumping objects and loading objects are different.
|
||||
|
||||
@@ -20,6 +20,9 @@ Introduction
|
||||
- model_performance_graph
|
||||
|
||||
|
||||
All of the accumulated profit metrics(e.g. return, max drawdown) in Qlib are calculated by summation.
|
||||
This avoids the metrics or the plots being skewed exponentially over time.
|
||||
|
||||
Graphical Reports
|
||||
===================
|
||||
|
||||
@@ -101,7 +104,7 @@ Graphical Result
|
||||
- Axis Y:
|
||||
- `ic`
|
||||
The `Pearson correlation coefficient` series between `label` and `prediction score`.
|
||||
In the above example, the `label` is formulated as `Ref($close, -1)/$close - 1`. Please refer to `Data Feature <data.html#feature>`_ for more details.
|
||||
In the above example, the `label` is formulated as `Ref($close, -2)/Ref($close, -1)-1`. Please refer to `Data Feature <data.html#feature>`_ for more details.
|
||||
|
||||
- `rank_ic`
|
||||
The `Spearman's rank correlation coefficient` series between `label` and `prediction score`.
|
||||
|
||||
@@ -24,11 +24,8 @@ BaseStrategy
|
||||
|
||||
Qlib provides a base class ``qlib.strategy.base.BaseStrategy``. All strategy classes need to inherit the base class and implement its interface.
|
||||
|
||||
- `get_risk_degree`
|
||||
Return the proportion of your total value you will use in investment. Dynamically risk_degree will result in Market timing.
|
||||
|
||||
- `generate_order_list`
|
||||
Return the order list.
|
||||
- `generate_trade_decision`
|
||||
generate_trade_decision is a key interface that generates trade decisions in each trading bar.
|
||||
The frequency to call this method depends on the executor frequency("time_per_step"="day" by default). But the trading frequency can be decided by users' implementation.
|
||||
For example, if the user wants to trading in weekly while the `time_per_step` is "day" in executor, user can return non-empty TradeDecision weekly(otherwise return empty like `this <https://github.com/microsoft/qlib/blob/main/qlib/contrib/strategy/signal_strategy.py#L132>`_ ).
|
||||
|
||||
@@ -69,18 +66,24 @@ TopkDropoutStrategy
|
||||
- Adopt the ``Topk-Drop`` algorithm to calculate the target amount of each stock
|
||||
|
||||
.. note::
|
||||
``Topk-Drop`` algorithm:
|
||||
There are two parameters for the ``Topk-Drop`` algorithm:
|
||||
|
||||
- `Topk`: The number of stocks held
|
||||
- `Drop`: The number of stocks sold on each trading day
|
||||
|
||||
Currently, the number of held stocks is `Topk`.
|
||||
On each trading day, the `Drop` number of held stocks with the worst `prediction score` will be sold, and the same number of unheld stocks with the best `prediction score` will be bought.
|
||||
|
||||
In general, the number of stocks currently held is `Topk`, with the exception of being zero at the beginning period of trading.
|
||||
For each trading day, let $d$ be the number of the instruments currently held and with a rank $\gt K$ when ranked by the prediction scores from high to low.
|
||||
Then `d` number of stocks currently held with the worst `prediction score` will be sold, and the same number of unheld stocks with the best `prediction score` will be bought.
|
||||
|
||||
In general, $d=$`Drop`, especially when the pool of the candidate instruments is large, $K$ is large, and `Drop` is small.
|
||||
|
||||
In most cases, ``TopkDrop`` algorithm sells and buys `Drop` stocks every trading day, which yields a turnover rate of 2$\times$`Drop`/$K$.
|
||||
|
||||
The following images illustrate a typical scenario.
|
||||
.. image:: ../_static/img/topk_drop.png
|
||||
:alt: Topk-Drop
|
||||
|
||||
``TopkDrop`` algorithm sells `Drop` stocks every trading day, which guarantees a fixed turnover rate.
|
||||
|
||||
|
||||
- Generate the order list from the target amount
|
||||
|
||||
@@ -126,7 +129,9 @@ A prediction sample is shown as follows.
|
||||
|
||||
Normally, the prediction score is the output of the models. But some models are learned from a label with a different scale. So the scale of the prediction score may be different from your expectation(e.g. the return of instruments).
|
||||
|
||||
Qlib didn't add a step to scale the prediction score to a unified scale. Because not every trading strategy cares about the scale(e.g. TopkDropoutStrategy only cares about the order). So the strategy is responsible for rescaling the prediction score(e.g. some portfolio-optimization-based strategies may require a meaningful scale).
|
||||
Qlib didn't add a step to scale the prediction score to a unified scale due to the following reasons.
|
||||
- Because not every trading strategy cares about the scale(e.g. TopkDropoutStrategy only cares about the order). So the strategy is responsible for rescaling the prediction score(e.g. some portfolio-optimization-based strategies may require a meaningful scale).
|
||||
- The model has the flexibility to define the target, loss, and data processing. So we don't think there is a silver bullet to rescale it back directly barely based on the model's outputs. If you want to scale it back to some meaningful values(e.g. stock returns.), an intuitive solution is to create a regression model for the model's recent outputs and your recent target values.
|
||||
|
||||
Running backtest
|
||||
-----------------
|
||||
@@ -162,12 +167,9 @@ Running backtest
|
||||
start_time="2017-01-01", end_time="2020-08-01", strategy=strategy_obj
|
||||
)
|
||||
analysis = dict()
|
||||
analysis["excess_return_without_cost"] = risk_analysis(
|
||||
report_normal["return"] - report_normal["bench"], freq=analysis_freq
|
||||
)
|
||||
analysis["excess_return_with_cost"] = risk_analysis(
|
||||
report_normal["return"] - report_normal["bench"] - report_normal["cost"], freq=analysis_freq
|
||||
)
|
||||
# default frequency will be daily (i.e. "day")
|
||||
analysis["excess_return_without_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"])
|
||||
analysis["excess_return_with_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"] - report_normal["cost"])
|
||||
|
||||
analysis_df = pd.concat(analysis) # type: pd.DataFrame
|
||||
pprint(analysis_df)
|
||||
@@ -192,6 +194,14 @@ Running backtest
|
||||
qlib.init(provider_uri=<qlib data dir>)
|
||||
|
||||
CSI300_BENCH = "SH000300"
|
||||
# Benchmark is for calculating the excess return of your strategy.
|
||||
# Its data format will be like **ONE normal instrument**.
|
||||
# For example, you can query its data with the code below
|
||||
# `D.features(["SH000300"], ["$close"], start_time='2010-01-01', end_time='2017-12-31', freq='day')`
|
||||
# It is different from the argument `market`, which indicates a universe of stocks (e.g. **A SET** of stocks like csi300)
|
||||
# For example, you can query all data from a stock market with the code below.
|
||||
# ` D.features(D.instruments(market='csi300'), ["$close"], start_time='2010-01-01', end_time='2017-12-31', freq='day')`
|
||||
|
||||
FREQ = "day"
|
||||
STRATEGY_CONFIG = {
|
||||
"topk": 50,
|
||||
|
||||
@@ -233,7 +233,7 @@ The meaning of each field is as follows:
|
||||
Dataset Section
|
||||
~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
The `dataset` field describes the parameters for the ``Dataset`` module in ``Qlib`` as well those for the module ``DataHandler``. For more information about the ``Dataset`` module, please refer to `Qlib Model <../component/data.html#dataset>`_.
|
||||
The `dataset` field describes the parameters for the ``Dataset`` module in ``Qlib`` as well those for the module ``DataHandler``. For more information about the ``Dataset`` module, please refer to `Qlib Data <../component/data.html#dataset>`_.
|
||||
|
||||
The keywords arguments configuration of the ``DataHandler`` is as follows:
|
||||
|
||||
@@ -248,7 +248,7 @@ The keywords arguments configuration of the ``DataHandler`` is as follows:
|
||||
|
||||
Users can refer to the document of `DataHandler <../component/data.html#datahandler>`_ for more information about the meaning of each field in the configuration.
|
||||
|
||||
Here is the configuration for the ``Dataset`` module which will take care of data preprossing and slicing during the training and testing phase.
|
||||
Here is the configuration for the ``Dataset`` module which will take care of data preprocessing and slicing during the training and testing phase.
|
||||
|
||||
.. code-block:: YAML
|
||||
|
||||
|
||||
12
docs/conf.py
12
docs/conf.py
@@ -54,9 +54,9 @@ master_doc = "index"
|
||||
|
||||
|
||||
# General information about the project.
|
||||
project = u"QLib"
|
||||
copyright = u"Microsoft"
|
||||
author = u"Microsoft"
|
||||
project = "QLib"
|
||||
copyright = "Microsoft"
|
||||
author = "Microsoft"
|
||||
|
||||
# The version info for the project you're documenting, acts as replacement for
|
||||
# |version| and |release|, also used in various other places throughout the
|
||||
@@ -174,7 +174,7 @@ latex_elements = {
|
||||
# (source start file, target name, title,
|
||||
# author, documentclass [howto, manual, or own class]).
|
||||
latex_documents = [
|
||||
(master_doc, "qlib.tex", u"QLib Documentation", u"Microsoft", "manual"),
|
||||
(master_doc, "qlib.tex", "QLib Documentation", "Microsoft", "manual"),
|
||||
]
|
||||
|
||||
|
||||
@@ -182,7 +182,7 @@ latex_documents = [
|
||||
|
||||
# One entry per manual page. List of tuples
|
||||
# (source start file, name, description, authors, manual section).
|
||||
man_pages = [(master_doc, "qlib", u"QLib Documentation", [author], 1)]
|
||||
man_pages = [(master_doc, "qlib", "QLib Documentation", [author], 1)]
|
||||
|
||||
|
||||
# -- Options for Texinfo output -------------------------------------------
|
||||
@@ -194,7 +194,7 @@ texinfo_documents = [
|
||||
(
|
||||
master_doc,
|
||||
"QLib",
|
||||
u"QLib Documentation",
|
||||
"QLib Documentation",
|
||||
author,
|
||||
"QLib",
|
||||
"One line description of project.",
|
||||
|
||||
@@ -1,22 +0,0 @@
|
||||
.. _code_standard:
|
||||
|
||||
=================================
|
||||
Code Standard
|
||||
=================================
|
||||
|
||||
Docstring
|
||||
=================================
|
||||
Please use the `Numpydoc Style <https://stackoverflow.com/a/24385103>`_.
|
||||
|
||||
Continuous Integration
|
||||
=================================
|
||||
Continuous Integration (CI) tools help you stick to the quality standards by running tests every time you push a new commit and reporting the results to a pull request.
|
||||
|
||||
When you submit a PR request, you can check whether your code passes the CI tests in the "check" section at the bottom of the web page.
|
||||
|
||||
A common error is the mixed use of space and tab. You can fix the bug by inputing the following code in the command line.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
pip install black
|
||||
python -m black . -l 120
|
||||
60
docs/developer/code_standard_and_dev_guide.rst
Normal file
60
docs/developer/code_standard_and_dev_guide.rst
Normal file
@@ -0,0 +1,60 @@
|
||||
.. _code_standard:
|
||||
|
||||
=================================
|
||||
Code Standard
|
||||
=================================
|
||||
|
||||
Docstring
|
||||
=================================
|
||||
Please use the `Numpydoc Style <https://stackoverflow.com/a/24385103>`_.
|
||||
|
||||
Continuous Integration
|
||||
=================================
|
||||
Continuous Integration (CI) tools help you stick to the quality standards by running tests every time you push a new commit and reporting the results to a pull request.
|
||||
|
||||
When you submit a PR request, you can check whether your code passes the CI tests in the "check" section at the bottom of the web page.
|
||||
|
||||
1. Qlib will check the code format with black. The PR will raise error if your code does not align to the standard of Qlib(e.g. a common error is the mixed use of space and tab).
|
||||
You can fix the bug by inputing the following code in the command line.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install black
|
||||
python -m black . -l 120
|
||||
|
||||
|
||||
2. Qlib will check your code style pylint. The checking command is implemented in [github action workflow](https://github.com/microsoft/qlib/blob/0e8b94a552f1c457cfa6cd2c1bb3b87ebb3fb279/.github/workflows/test.yml#L66).
|
||||
Sometime pylint's restrictions are not that reasonable. You can ignore specific errors like this
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
return -ICLoss()(pred, target, index) # pylint: disable=E1130
|
||||
|
||||
|
||||
3. Qlib will check your code style flake8. The checking command is implemented in [github action workflow](https://github.com/microsoft/qlib/blob/0e8b94a552f1c457cfa6cd2c1bb3b87ebb3fb279/.github/workflows/test.yml#L73).
|
||||
You can fix the bug by inputing the following code in the command line.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
flake8 --ignore E501,F541,E402,F401,W503,E741,E266,E203,E302,E731,E262,F523,F821,F811,F841,E713,E265,W291,E712,E722,W293 qlib
|
||||
|
||||
|
||||
4. Qlib has integrated pre-commit, which will make it easier for developers to format their code.
|
||||
Just run the following two commands, and the code will be automatically formatted using black and flake8 when the git commit command is executed.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install -e .[dev]
|
||||
pre-commit install
|
||||
|
||||
|
||||
=================================
|
||||
Development Guidance
|
||||
=================================
|
||||
|
||||
As a developer, you often want make changes to `Qlib` and hope it would reflect directly in your environment without reinstalling it. You can install `Qlib` in editable mode with following command.
|
||||
The `[dev]` option will help you to install some related packages when developing `Qlib` (e.g. pytest, sphinx)
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install -e .[dev]
|
||||
@@ -53,6 +53,7 @@ Document Structure
|
||||
Online & Offline mode <advanced/server.rst>
|
||||
Serialization <advanced/serial.rst>
|
||||
Task Management <advanced/task_management.rst>
|
||||
Point-In-Time database <advanced/PIT.rst>
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 3
|
||||
|
||||
@@ -120,6 +120,32 @@ For more details about features, please refer `Feature API <../component/data.ht
|
||||
|
||||
.. note:: When calling `D.features()` at the client, use parameter `disk_cache=0` to skip dataset cache, use `disk_cache=1` to generate and use dataset cache. In addition, when calling at the server, users can use `disk_cache=2` to update the dataset cache.
|
||||
|
||||
|
||||
When you are building complicated expressions, implementing all the expressions in a single string may not be easy.
|
||||
For example, it looks quite long and complicated:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
>> from qlib.data import D
|
||||
>> data = D.features(["sh600519"], ["(($high / $close) + ($open / $close)) * (($high / $close) + ($open / $close)) / ($high / $close) + ($open / $close)"], start_time="20200101")
|
||||
|
||||
|
||||
But using string is not the only way to implement the expression. You can also implement expression by code.
|
||||
Here is an exmaple which does the same thing as above examples.
|
||||
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
>> from qlib.data.ops import *
|
||||
>> f1 = Feature("high") / Feature("close")
|
||||
>> f2 = Feature("open") / Feature("close")
|
||||
>> f3 = f1 + f2
|
||||
>> f4 = f3 * f3 / f3
|
||||
|
||||
>> data = D.features(["sh600519"], [f4], start_time="20200101")
|
||||
>> data.head()
|
||||
|
||||
|
||||
API
|
||||
====================
|
||||
To know more about how to use the Data, go to API Reference: `Data API <../reference/api.html#data>`_
|
||||
|
||||
@@ -37,7 +37,8 @@ Initialize Qlib before calling other APIs: run following code in python.
|
||||
Parameters
|
||||
-------------------
|
||||
|
||||
Besides `provider_uri` and `region`, `qlib.init` has other parameters. The following are several important parameters of `qlib.init`:
|
||||
Besides `provider_uri` and `region`, `qlib.init` has other parameters.
|
||||
The following are several important parameters of `qlib.init` (`Qlib` has a lot of config. Only part of parameters are limited here. More detailed setting can be found `here <https://github.com/microsoft/qlib/blob/main/qlib/config.py>`_):
|
||||
|
||||
- `provider_uri`
|
||||
Type: str. The URI of the Qlib data. For example, it could be the location where the data loaded by ``get_data.py`` are stored.
|
||||
@@ -48,7 +49,7 @@ Besides `provider_uri` and `region`, `qlib.init` has other parameters. The follo
|
||||
- ``qlib.constant.REG_CN``: China stock market.
|
||||
|
||||
Different modes will result in different trading limitations and costs.
|
||||
The region is just `shortcuts for defining a batch of configurations <https://github.com/microsoft/qlib/blob/main/qlib/config.py#L239>`_. Users can set the key configurations manually if the existing region setting can't meet their requirements.
|
||||
The region is just `shortcuts for defining a batch of configurations <https://github.com/microsoft/qlib/blob/528f74af099bf6156e9480bcd2bb28e453231212/qlib/config.py#L249>`_, which include minimal trading order unit (``trade_unit``), trading limitation (``limit_threshold``) , etc. It is not a necessary part and users can set the key configurations manually if the existing region setting can't meet their requirements.
|
||||
- `redis_host`
|
||||
Type: str, optional parameter(default: "127.0.0.1"), host of `redis`
|
||||
The lock and cache mechanism relies on redis.
|
||||
@@ -88,3 +89,9 @@ Besides `provider_uri` and `region`, `qlib.init` has other parameters. The follo
|
||||
"task_url": "mongodb://localhost:27017/", # your mongo url
|
||||
"task_db_name": "rolling_db", # the database name of Task Management
|
||||
})
|
||||
|
||||
- `logging_level`
|
||||
The logging level for the system.
|
||||
|
||||
- `kernels`
|
||||
The number of processes used when calculating features in Qlib's expression engine. It is very helpful to set it to 1 when you are debuggin an expression calculating exception
|
||||
|
||||
@@ -6,3 +6,4 @@
|
||||
|
||||
[https://www.ijcai.org/Proceedings/2017/0366.pdf](https://www.ijcai.org/Proceedings/2017/0366.pdf)
|
||||
|
||||
- NOTE: Current version of implementation is just a simplified version of ALSTM. It is an LSTM with attention.
|
||||
|
||||
3
examples/benchmarks/HIST/README.md
Normal file
3
examples/benchmarks/HIST/README.md
Normal file
@@ -0,0 +1,3 @@
|
||||
# HIST
|
||||
* Code: [https://github.com/Wentao-Xu/HIST](https://github.com/Wentao-Xu/HIST)
|
||||
* Paper: [HIST: A Graph-based Framework for Stock Trend Forecasting via Mining Concept-Oriented Shared InformationAdaRNN: Adaptive Learning and Forecasting for Time Series](https://arxiv.org/abs/2110.13716).
|
||||
BIN
examples/benchmarks/HIST/qlib_csi300_stock_index.npy
Normal file
BIN
examples/benchmarks/HIST/qlib_csi300_stock_index.npy
Normal file
Binary file not shown.
4
examples/benchmarks/HIST/requirements.txt
Normal file
4
examples/benchmarks/HIST/requirements.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
pandas==1.1.2
|
||||
numpy==1.21.0
|
||||
scikit_learn==0.23.2
|
||||
torch==1.7.0
|
||||
92
examples/benchmarks/HIST/workflow_config_hist_Alpha360.yaml
Normal file
92
examples/benchmarks/HIST/workflow_config_hist_Alpha360.yaml
Normal file
@@ -0,0 +1,92 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
start_time: 2017-01-01
|
||||
end_time: 2020-08-01
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
exchange_kwargs:
|
||||
limit_threshold: 0.095
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: HIST
|
||||
module_path: qlib.contrib.model.pytorch_hist
|
||||
kwargs:
|
||||
d_feat: 6
|
||||
hidden_size: 64
|
||||
num_layers: 2
|
||||
dropout: 0
|
||||
n_epochs: 200
|
||||
lr: 1e-4
|
||||
early_stop: 20
|
||||
metric: ic
|
||||
loss: mse
|
||||
base_model: LSTM
|
||||
model_path: "benchmarks/LSTM/model_lstm_csi300.pkl"
|
||||
stock2concept: "benchmarks/HIST/qlib_csi300_stock2concept.npy"
|
||||
stock_index: "benchmarks/HIST/qlib_csi300_stock_index.npy"
|
||||
GPU: 0
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha360
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
4
examples/benchmarks/IGMTF/README.md
Normal file
4
examples/benchmarks/IGMTF/README.md
Normal file
@@ -0,0 +1,4 @@
|
||||
# IGMTF
|
||||
* Code: [https://github.com/Wentao-Xu/IGMTF](https://github.com/Wentao-Xu/IGMTF)
|
||||
* Paper: [IGMTF: An Instance-wise Graph-based Framework for
|
||||
Multivariate Time Series Forecasting](https://arxiv.org/abs/2109.06489).
|
||||
4
examples/benchmarks/IGMTF/requirements.txt
Normal file
4
examples/benchmarks/IGMTF/requirements.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
pandas==1.1.2
|
||||
numpy==1.21.0
|
||||
scikit_learn==0.23.2
|
||||
torch==1.7.0
|
||||
@@ -0,0 +1,89 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
start_time: 2017-01-01
|
||||
end_time: 2020-08-01
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
exchange_kwargs:
|
||||
limit_threshold: 0.095
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: IGMTF
|
||||
module_path: qlib.contrib.model.pytorch_igmtf
|
||||
kwargs:
|
||||
d_feat: 6
|
||||
hidden_size: 64
|
||||
num_layers: 2
|
||||
dropout: 0
|
||||
n_epochs: 200
|
||||
lr: 1e-4
|
||||
early_stop: 20
|
||||
metric: ic
|
||||
loss: mse
|
||||
base_model: LSTM
|
||||
model_path: "benchmarks/LSTM/model_lstm_csi300.pkl"
|
||||
GPU: 0
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha360
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -1,3 +1,3 @@
|
||||
pandas==1.1.2
|
||||
numpy==1.21.0
|
||||
lightgbm==3.1.0
|
||||
lightgbm
|
||||
|
||||
@@ -0,0 +1,72 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi500
|
||||
benchmark: &benchmark SH000905
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
start_time: 2017-01-01
|
||||
end_time: 2020-08-01
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
exchange_kwargs:
|
||||
limit_threshold: 0.095
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: LGBModel
|
||||
module_path: qlib.contrib.model.gbdt
|
||||
kwargs:
|
||||
loss: mse
|
||||
colsample_bytree: 0.8879
|
||||
learning_rate: 0.2
|
||||
subsample: 0.8789
|
||||
lambda_l1: 205.6999
|
||||
lambda_l2: 580.9768
|
||||
max_depth: 8
|
||||
num_leaves: 210
|
||||
num_threads: 20
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha158
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -0,0 +1,80 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi500
|
||||
benchmark: &benchmark SH000905
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors: []
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
start_time: 2017-01-01
|
||||
end_time: 2020-08-01
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
exchange_kwargs:
|
||||
limit_threshold: 0.095
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: LGBModel
|
||||
module_path: qlib.contrib.model.gbdt
|
||||
kwargs:
|
||||
loss: mse
|
||||
colsample_bytree: 0.8879
|
||||
learning_rate: 0.0421
|
||||
subsample: 0.8789
|
||||
lambda_l1: 205.6999
|
||||
lambda_l2: 580.9768
|
||||
max_depth: 8
|
||||
num_leaves: 210
|
||||
num_threads: 20
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha360
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -63,8 +63,6 @@ task:
|
||||
module_path: qlib.contrib.model.pytorch_nn
|
||||
kwargs:
|
||||
loss: mse
|
||||
input_dim: 157
|
||||
output_dim: 1
|
||||
lr: 0.002
|
||||
lr_decay: 0.96
|
||||
lr_decay_steps: 100
|
||||
@@ -73,6 +71,8 @@ task:
|
||||
batch_size: 8192
|
||||
GPU: 0
|
||||
weight_decay: 0.0002
|
||||
pt_model_kwargs:
|
||||
input_dim: 157
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
|
||||
@@ -51,8 +51,6 @@ task:
|
||||
module_path: qlib.contrib.model.pytorch_nn
|
||||
kwargs:
|
||||
loss: mse
|
||||
input_dim: 360
|
||||
output_dim: 1
|
||||
lr: 0.002
|
||||
lr_decay: 0.96
|
||||
lr_decay_steps: 100
|
||||
@@ -60,6 +58,8 @@ task:
|
||||
max_steps: 8000
|
||||
batch_size: 4096
|
||||
GPU: 0
|
||||
pt_model_kwargs:
|
||||
input_dim: 360
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
|
||||
@@ -4,6 +4,7 @@ This page lists a batch of methods designed for alpha seeking. Each method tries
|
||||
The alpha is evaluated in two ways.
|
||||
1. The correlation between the alpha and future return.
|
||||
1. Constructing portfolio based on the alpha and evaluating the final total return.
|
||||
- The explanation of metrics can be found [here](https://qlib.readthedocs.io/en/latest/component/report.html#id4)
|
||||
|
||||
Here are the results of each benchmark model running on Qlib's `Alpha360` and `Alpha158` dataset with China's A shared-stock & CSI300 data respectively. The values of each metric are the mean and std calculated based on 20 runs with different random seeds.
|
||||
|
||||
@@ -16,8 +17,12 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
|
||||
> NOTE:
|
||||
> The backtest start from 0.8.0 is quite different from previous version. Please check out the changelog for the difference.
|
||||
|
||||
> NOTE:
|
||||
> We have very limited resources to implement and finetune the models. We tried our best effort to fairly compare these models. But some models may have greater potential than what it looks like in the table below. Your contribution is highly welcomed to explore their potential.
|
||||
|
||||
## Alpha158 dataset
|
||||
## Results on CSI300
|
||||
|
||||
### Alpha158 dataset
|
||||
|
||||
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |
|
||||
|------------------------------------------|-------------------------------------|-------------|-------------|-------------|-------------|-------------------|-------------------|--------------|
|
||||
@@ -41,7 +46,7 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
|
||||
| DoubleEnsemble(Chuheng Zhang, et al.) | Alpha158 | 0.0544±0.00 | 0.4340±0.00 | 0.0523±0.00 | 0.4284±0.01 | 0.1168±0.01 | 1.3384±0.12 | -0.1036±0.01 |
|
||||
|
||||
|
||||
## Alpha360 dataset
|
||||
### Alpha360 dataset
|
||||
|
||||
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |
|
||||
|-------------------------------------------|----------|-------------|-------------|-------------|-------------|-------------------|-------------------|--------------|
|
||||
@@ -62,7 +67,65 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
|
||||
| GATs (Petar Velickovic, et al.) | Alpha360 | 0.0476±0.00 | 0.3508±0.02 | 0.0598±0.00 | 0.4604±0.01 | 0.0824±0.02 | 1.1079±0.26 | -0.0894±0.03 |
|
||||
| TCTS(Xueqing Wu, et al.) | Alpha360 | 0.0508±0.00 | 0.3931±0.04 | 0.0599±0.00 | 0.4756±0.03 | 0.0893±0.03 | 1.2256±0.36 | -0.0857±0.02 |
|
||||
| TRA(Hengxu Lin, et al.) | Alpha360 | 0.0485±0.00 | 0.3787±0.03 | 0.0587±0.00 | 0.4756±0.03 | 0.0920±0.03 | 1.2789±0.42 | -0.0834±0.02 |
|
||||
| IGMTF(Wentao Xu, et al.) | Alpha360 | 0.0480±0.00 | 0.3589±0.02 | 0.0606±0.00 | 0.4773±0.01 | 0.0946±0.02 | 1.3509±0.25 | -0.0716±0.02 |
|
||||
| HIST(Wentao Xu, et al.) | Alpha360 | 0.0522±0.00 | 0.3530±0.01 | 0.0667±0.00 | 0.4576±0.01 | 0.0987±0.02 | 1.3726±0.27 | -0.0681±0.01 |
|
||||
|
||||
|
||||
- The selected 20 features are based on the feature importance of a lightgbm-based model.
|
||||
- The base model of DoubleEnsemble is LGBM.
|
||||
- The base model of TCTS is GRU.
|
||||
- About the datasets
|
||||
- Alpha158 is a tabular dataset. There are less spatial relationships between different features. Each feature are carefully desgined by human (a.k.a feature engineering)
|
||||
- Alpha360 contains raw price and volue data without much feature engineering. There are strong strong spatial relationships between the features in the time dimension.
|
||||
- The metrics can be categorized into two
|
||||
- Signal-based evaluation: IC, ICIR, Rank IC, Rank ICIR
|
||||
- Portfolio-based metrics: Annualized Return, Information Ratio, Max Drawdown
|
||||
|
||||
## Results on CSI500
|
||||
The results on CSI500 is not complete. PR's for models on csi500 are welcome!
|
||||
|
||||
Transfer previous models in CSI300 to CSI500 is quite easy. You can try models with just a few commands below.
|
||||
```
|
||||
cd examples/benchmarks/LightGBM
|
||||
pip install -r requirements.txt
|
||||
|
||||
# create new config and set the benchmark to csi500
|
||||
cp workflow_config_lightgbm_Alpha158.yaml workflow_config_lightgbm_Alpha158_csi500.yaml
|
||||
sed -i "s/csi300/csi500/g" workflow_config_lightgbm_Alpha158_csi500.yaml
|
||||
sed -i "s/SH000300/SH000905/g" workflow_config_lightgbm_Alpha158_csi500.yaml
|
||||
|
||||
# you can either run the model once
|
||||
qrun workflow_config_lightgbm_Alpha158_csi500.yaml
|
||||
|
||||
# or run it for multiple times automatically and get the summarized results.
|
||||
cd ../../
|
||||
python run_all_model.py run 3 lightgbm Alpha158 csi500 # for models with randomness. please run it for 20 times.
|
||||
```
|
||||
|
||||
### Alpha158 dataset
|
||||
|
||||
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |
|
||||
|------------|----------|-------------|-------------|-------------|-------------|-------------------|-------------------|--------------|
|
||||
| LightGBM | Alpha158 | 0.0377±0.00 | 0.3860±0.00 | 0.0448±0.00 | 0.4675±0.00 | 0.1151±0.00 | 1.3884±0.00 | -0.0898±0.00 |
|
||||
|
||||
### Alpha360 dataset
|
||||
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |
|
||||
|------------|----------|-------------|-------------|-------------|-------------|-------------------|-------------------|--------------|
|
||||
| LightGBM | Alpha360 | 0.0400±0.00 | 0.3605±0.00 | 0.0536±0.00 | 0.5431±0.00 | 0.0505±0.00 | 0.7658±0.02 | -0.1880±0.00 |
|
||||
|
||||
|
||||
# Contributing
|
||||
|
||||
Your contributions to new models are highly welcome!
|
||||
|
||||
If you want to contribute your new models, you can follow the steps below.
|
||||
1. Create a folder for your model
|
||||
2. The folder contains following items(you can refer to [this example](https://github.com/microsoft/qlib/tree/main/examples/benchmarks/TCTS)).
|
||||
- `requirements.txt`: required dependencies.
|
||||
- `README.md`: a brief introduction to your models
|
||||
- `workflow_config_<model name>_<dataset>.yaml`: a configuration which can read by `qrun`. You are encouraged to run your model in all datasets.
|
||||
3. You can integrate your model as a module [in this folder](https://github.com/microsoft/qlib/tree/main/qlib/contrib/model).
|
||||
4. Please updated your results in the benchmark tables, e.g. [Alpha360](#alpha158-dataset), [Alpha158](#alpha158-dataset)(the values of each metric are the mean and std calculated based on 20 runs with different random seeds, if you don't have enough computational resource, you can ask for help in the PR).
|
||||
5. Update the info in the index page in the [news list](https://github.com/microsoft/qlib#newspaper-whats-new----sparkling_heart) and [model list](https://github.com/microsoft/qlib#quant-model-paper-zoo).
|
||||
|
||||
Finally, you can send PR for review. ([here is an example](https://github.com/microsoft/qlib/pull/1040))
|
||||
|
||||
@@ -6,8 +6,7 @@ import torch
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from qlib.utils import init_instance_by_config
|
||||
from qlib.data.dataset import DatasetH, DataHandler
|
||||
from qlib.data.dataset import DatasetH
|
||||
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
@@ -95,7 +94,7 @@ class MTSDatasetH(DatasetH):
|
||||
shuffle=True,
|
||||
pin_memory=False,
|
||||
drop_last=False,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
assert horizon > 0, "please specify `horizon` to avoid data leakage"
|
||||
@@ -150,8 +149,15 @@ class MTSDatasetH(DatasetH):
|
||||
|
||||
def _prepare_seg(self, slc, **kwargs):
|
||||
fn = _get_date_parse_fn(self._index[0][1])
|
||||
start_date = fn(slc.start)
|
||||
end_date = fn(slc.stop)
|
||||
|
||||
if isinstance(slc, slice):
|
||||
start, stop = slc.start, slc.stop
|
||||
elif isinstance(slc, (list, tuple)):
|
||||
start, stop = slc
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
start_date = fn(start)
|
||||
end_date = fn(stop)
|
||||
obj = copy.copy(self) # shallow copy
|
||||
# NOTE: Seriable will disable copy `self._data` so we manually assign them here
|
||||
obj._data = self._data
|
||||
|
||||
@@ -130,7 +130,7 @@ class TRAModel(Model):
|
||||
|
||||
if prob is not None:
|
||||
P = sinkhorn(-L, epsilon=0.01) # sample assignment matrix
|
||||
lamb = self.lamb * (self.rho ** self.global_step)
|
||||
lamb = self.lamb * (self.rho**self.global_step)
|
||||
reg = prob.log().mul(P).sum(dim=-1).mean()
|
||||
loss = loss - lamb * reg
|
||||
|
||||
@@ -547,7 +547,7 @@ def evaluate(pred):
|
||||
score = pred.score
|
||||
label = pred.label
|
||||
diff = score - label
|
||||
MSE = (diff ** 2).mean()
|
||||
MSE = (diff**2).mean()
|
||||
MAE = (diff.abs()).mean()
|
||||
IC = score.corr(label)
|
||||
return {"MSE": MSE, "MAE": MAE, "IC": IC}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
pandas==1.1.2
|
||||
numpy==1.17.4
|
||||
numpy==1.21.0
|
||||
scikit_learn==0.23.2
|
||||
torch==1.7.0
|
||||
@@ -1,3 +1,3 @@
|
||||
numpy==1.17.4
|
||||
numpy==1.21.0
|
||||
pandas==1.1.2
|
||||
torch==1.2.0
|
||||
@@ -1,3 +1,3 @@
|
||||
numpy==1.17.4
|
||||
numpy==1.21.0
|
||||
pandas==1.1.2
|
||||
xgboost==1.2.1
|
||||
@@ -4,16 +4,16 @@ This is the implementation of `DDG-DA` based on `Meta Controller` component prov
|
||||
Please refer to the paper for more details: *DDG-DA: Data Distribution Generation for Predictable Concept Drift Adaptation* [[arXiv](https://arxiv.org/abs/2201.04038)]
|
||||
|
||||
|
||||
## Background
|
||||
# Background
|
||||
In many real-world scenarios, we often deal with streaming data that is sequentially collected over time. Due to the non-stationary nature of the environment, the streaming data distribution may change in unpredictable ways, which is known as concept drift. To handle concept drift, previous methods first detect when/where the concept drift happens and then adapt models to fit the distribution of the latest data. However, there are still many cases that some underlying factors of environment evolution are predictable, making it possible to model the future concept drift trend of the streaming data, while such cases are not fully explored in previous work.
|
||||
|
||||
Therefore, we propose a novel method `DDG-DA`, that can effectively forecast the evolution of data distribution and improve the performance of models. Specifically, we first train a predictor to estimate the future data distribution, then leverage it to generate training samples, and finally train models on the generated data.
|
||||
|
||||
## Dataset
|
||||
# Dataset
|
||||
The data in the paper are private. So we conduct experiments on Qlib's public dataset.
|
||||
Though the dataset is different, the conclusion remains the same. By applying `DDG-DA`, users can see rising trends at the test phase both in the proxy models' ICs and the performances of the forecasting models.
|
||||
|
||||
## Run the Code
|
||||
# Run the Code
|
||||
Users can try `DDG-DA` by running the following command:
|
||||
```bash
|
||||
python workflow.py run_all
|
||||
@@ -24,7 +24,12 @@ The default forecasting models are `Linear`. Users can choose other forecasting
|
||||
python workflow.py --forecast_model="gbdt" run_all
|
||||
```
|
||||
|
||||
|
||||
## Results
|
||||
|
||||
# Results
|
||||
The results of related methods in Qlib's public dataset can be found [here](../)
|
||||
|
||||
# Requirements
|
||||
Here are the minimal hardware requirements to run the ``workflow.py`` of DDG-DA.
|
||||
* Memory: 45G
|
||||
* Disk: 4G
|
||||
|
||||
Pytorch with CPU & RAM will be enough for this example.
|
||||
|
||||
@@ -9,13 +9,10 @@ from qlib.data.dataset.handler import DataHandlerLP
|
||||
import pandas as pd
|
||||
import fire
|
||||
import sys
|
||||
from tqdm.auto import tqdm
|
||||
import yaml
|
||||
import pickle
|
||||
from qlib import auto_init
|
||||
from qlib.model.trainer import TrainerR, task_train
|
||||
from qlib.model.trainer import TrainerR
|
||||
from qlib.utils import init_instance_by_config
|
||||
from qlib.workflow.task.gen import RollingGen, task_generator
|
||||
from qlib.workflow import R
|
||||
from qlib.tests.data import GetData
|
||||
|
||||
@@ -47,9 +44,10 @@ class DDGDA:
|
||||
rb = RollingBenchmark(model_type="gbdt")
|
||||
task = rb.basic_task()
|
||||
|
||||
model = init_instance_by_config(task["model"])
|
||||
dataset = init_instance_by_config(task["dataset"])
|
||||
model.fit(dataset)
|
||||
with R.start(experiment_name="feature_importance"):
|
||||
model = init_instance_by_config(task["model"])
|
||||
dataset = init_instance_by_config(task["dataset"])
|
||||
model.fit(dataset)
|
||||
|
||||
fi = model.get_feature_importance()
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
pandas==1.1.2
|
||||
numpy==1.17.4
|
||||
numpy==1.21.0
|
||||
lightgbm==3.1.0
|
||||
optuna==2.7.0
|
||||
optuna-dashboard==0.4.1
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
This example is about how can simulate the OnlineManager based on rolling tasks.
|
||||
This example is about how can simulate the OnlineManager based on rolling tasks.
|
||||
"""
|
||||
|
||||
from pprint import pprint
|
||||
@@ -15,6 +15,10 @@ from qlib.workflow.online.strategy import RollingStrategy
|
||||
from qlib.workflow.task.gen import RollingGen
|
||||
from qlib.workflow.task.manage import TaskManager
|
||||
from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG_ONLINE, CSI100_RECORD_XGBOOST_TASK_CONFIG_ONLINE
|
||||
import pandas as pd
|
||||
from qlib.contrib.evaluate import backtest_daily
|
||||
from qlib.contrib.evaluate import risk_analysis
|
||||
from qlib.contrib.strategy import TopkDropoutStrategy
|
||||
|
||||
|
||||
class OnlineSimulationExample:
|
||||
@@ -30,6 +34,7 @@ class OnlineSimulationExample:
|
||||
start_time="2018-09-10",
|
||||
end_time="2018-10-31",
|
||||
tasks=None,
|
||||
trainer="TrainerR",
|
||||
):
|
||||
"""
|
||||
Init OnlineManagerExample.
|
||||
@@ -60,7 +65,13 @@ class OnlineSimulationExample:
|
||||
self.rolling_gen = RollingGen(
|
||||
step=rolling_step, rtype=RollingGen.ROLL_SD, ds_extra_mod_func=None
|
||||
) # The rolling tasks generator, ds_extra_mod_func is None because we just need to simulate to 2018-10-31 and needn't change the handler end time.
|
||||
self.trainer = TrainerRM(self.exp_name, self.task_pool) # Also can be TrainerR, TrainerRM, DelayTrainerR
|
||||
if trainer == "TrainerRM":
|
||||
self.trainer = TrainerRM(self.exp_name, self.task_pool)
|
||||
elif trainer == "TrainerR":
|
||||
self.trainer = TrainerR(self.exp_name)
|
||||
else:
|
||||
# TODO: support all the trainers: TrainerR, TrainerRM, DelayTrainerR
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
self.rolling_online_manager = OnlineManager(
|
||||
RollingStrategy(exp_name, task_template=tasks, rolling_gen=self.rolling_gen),
|
||||
trainer=self.trainer,
|
||||
@@ -70,7 +81,8 @@ class OnlineSimulationExample:
|
||||
|
||||
# Reset all things to the first status, be careful to save important data
|
||||
def reset(self):
|
||||
TaskManager(self.task_pool).remove()
|
||||
if isinstance(self.trainer, TrainerRM):
|
||||
TaskManager(self.task_pool).remove()
|
||||
exp = R.get_exp(experiment_name=self.exp_name)
|
||||
for rid in exp.list_recorders():
|
||||
exp.delete_recorder(rid)
|
||||
@@ -84,7 +96,30 @@ class OnlineSimulationExample:
|
||||
print("========== collect results ==========")
|
||||
print(self.rolling_online_manager.get_collector()())
|
||||
print("========== signals ==========")
|
||||
print(self.rolling_online_manager.get_signals())
|
||||
signals = self.rolling_online_manager.get_signals()
|
||||
print(signals)
|
||||
# Backtesting
|
||||
# - the code is based on this example https://qlib.readthedocs.io/en/latest/component/strategy.html
|
||||
CSI300_BENCH = "SH000903"
|
||||
STRATEGY_CONFIG = {
|
||||
"topk": 30,
|
||||
"n_drop": 3,
|
||||
"signal": signals.to_frame("score"),
|
||||
}
|
||||
strategy_obj = TopkDropoutStrategy(**STRATEGY_CONFIG)
|
||||
report_normal, positions_normal = backtest_daily(
|
||||
start_time=signals.index.get_level_values("datetime").min(),
|
||||
end_time=signals.index.get_level_values("datetime").max(),
|
||||
strategy=strategy_obj,
|
||||
)
|
||||
analysis = dict()
|
||||
analysis["excess_return_without_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"])
|
||||
analysis["excess_return_with_cost"] = risk_analysis(
|
||||
report_normal["return"] - report_normal["bench"] - report_normal["cost"]
|
||||
)
|
||||
|
||||
analysis_df = pd.concat(analysis) # type: pd.DataFrame
|
||||
pprint(analysis_df)
|
||||
|
||||
def worker(self):
|
||||
# train tasks by other progress or machines for multiprocessing
|
||||
|
||||
@@ -21,7 +21,7 @@ class TestClass(unittest.TestCase):
|
||||
provider_uri = "~/.qlib/qlib_data/yahoo_cn_1min"
|
||||
qlib.init(
|
||||
provider_uri=provider_uri,
|
||||
mem_cache_size_limit=1024 ** 3 * 2,
|
||||
mem_cache_size_limit=1024**3 * 2,
|
||||
mem_cache_type="sizeof",
|
||||
kernels=1,
|
||||
expression_provider={"class": "LocalExpressionProvider", "kwargs": {"time2idx": False}},
|
||||
|
||||
@@ -24,6 +24,7 @@ We use China stock market data for our example.
|
||||
unzip -d ~/.qlib/qlib_data/cn_data csi300_weight.zip
|
||||
rm -f csi300_weight.zip
|
||||
```
|
||||
NOTE: We don't find any public free resource to get the weight in the benchmark. To run the example, we manually create this weight data.
|
||||
|
||||
2. Prepare risk model data:
|
||||
|
||||
|
||||
@@ -117,8 +117,10 @@ def get_all_folders(models, exclude) -> dict:
|
||||
|
||||
|
||||
# function to get all the files under the model folder
|
||||
def get_all_files(folder_path, dataset) -> (str, str):
|
||||
yaml_path = str(Path(f"{folder_path}") / f"*{dataset}*.yaml")
|
||||
def get_all_files(folder_path, dataset, universe="") -> (str, str):
|
||||
if universe != "":
|
||||
universe = f"_{universe}"
|
||||
yaml_path = str(Path(f"{folder_path}") / f"*{dataset}{universe}.yaml")
|
||||
req_path = str(Path(f"{folder_path}") / f"*.txt")
|
||||
yaml_file = glob.glob(yaml_path)
|
||||
req_file = glob.glob(req_path)
|
||||
@@ -224,6 +226,7 @@ class ModelRunner:
|
||||
times=1,
|
||||
models=None,
|
||||
dataset="Alpha360",
|
||||
universe="",
|
||||
exclude=False,
|
||||
qlib_uri: str = "git+https://github.com/microsoft/qlib#egg=pyqlib",
|
||||
exp_folder_name: str = "run_all_model_records",
|
||||
@@ -245,6 +248,9 @@ class ModelRunner:
|
||||
determines whether the model being used is excluded or included.
|
||||
dataset : str
|
||||
determines the dataset to be used for each model.
|
||||
universe : str
|
||||
the stock universe of the dataset.
|
||||
default "" indicates that
|
||||
qlib_uri : str
|
||||
the uri to install qlib with pip
|
||||
it could be url on the we or local path (NOTE: the local path must be a absolute path)
|
||||
@@ -259,6 +265,15 @@ class ModelRunner:
|
||||
-------
|
||||
Here are some use cases of the function in the bash:
|
||||
|
||||
The run_all_models will decide which config to run based no `models` `dataset` `universe`
|
||||
Example 1):
|
||||
|
||||
models="lightgbm", dataset="Alpha158", universe="" will result in running the following config
|
||||
examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
|
||||
models="lightgbm", dataset="Alpha158", universe="csi500" will result in running the following config
|
||||
examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158_csi500.yaml
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
# Case 1 - run all models multiple times
|
||||
@@ -279,6 +294,9 @@ class ModelRunner:
|
||||
# Case 6 - run other models except those are given as arguments for one time
|
||||
python run_all_model.py run --models=[mlp,tft,sfm] --exclude=True
|
||||
|
||||
# Case 7 - run lightgbm model on csi500.
|
||||
python run_all_model.py run 3 lightgbm Alpha158 csi500
|
||||
|
||||
"""
|
||||
self._init_qlib(exp_folder_name)
|
||||
|
||||
@@ -290,7 +308,7 @@ class ModelRunner:
|
||||
for fn in folders:
|
||||
# get all files
|
||||
sys.stderr.write("Retrieving files...\n")
|
||||
yaml_path, req_path = get_all_files(folders[fn], dataset)
|
||||
yaml_path, req_path = get_all_files(folders[fn], dataset, universe=universe)
|
||||
if yaml_path is None:
|
||||
sys.stderr.write(f"There is no {dataset}.yaml file in {folders[fn]}")
|
||||
continue
|
||||
|
||||
1218
examples/tutorial/detailed_workflow.ipynb
Normal file
1218
examples/tutorial/detailed_workflow.ipynb
Normal file
File diff suppressed because it is too large
Load Diff
@@ -256,7 +256,6 @@
|
||||
"recorder = R.get_recorder(recorder_id=ba_rid, experiment_name=\"backtest_analysis\")\n",
|
||||
"print(recorder)\n",
|
||||
"pred_df = recorder.load_object(\"pred.pkl\")\n",
|
||||
"pred_df_dates = pred_df.index.get_level_values(level='datetime')\n",
|
||||
"report_normal_df = recorder.load_object(\"portfolio_analysis/report_normal_1day.pkl\")\n",
|
||||
"positions = recorder.load_object(\"portfolio_analysis/positions_normal_1day.pkl\")\n",
|
||||
"analysis_df = recorder.load_object(\"portfolio_analysis/port_analysis_1day.pkl\")"
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# Licensed under the MIT License.
|
||||
from pathlib import Path
|
||||
|
||||
__version__ = "0.8.2"
|
||||
__version__ = "0.8.6.99"
|
||||
__version__bak = __version__ # This version is backup for QlibConfig.reset_qlib_version
|
||||
import os
|
||||
from typing import Union
|
||||
@@ -12,6 +12,7 @@ import platform
|
||||
import subprocess
|
||||
from .log import get_module_logger
|
||||
|
||||
|
||||
# init qlib
|
||||
def init(default_conf="client", **kwargs):
|
||||
"""
|
||||
@@ -30,8 +31,8 @@ def init(default_conf="client", **kwargs):
|
||||
When using the recorder, skip_if_reg can set to True to avoid loss of recorder.
|
||||
|
||||
"""
|
||||
from .config import C
|
||||
from .data.cache import H
|
||||
from .config import C # pylint: disable=C0415
|
||||
from .data.cache import H # pylint: disable=C0415
|
||||
|
||||
# FIXME: this logger ignored the level in config
|
||||
logger = get_module_logger("Initialization", level=logging.INFO)
|
||||
@@ -85,7 +86,7 @@ def _mount_nfs_uri(provider_uri, mount_path, auto_mount: bool = False):
|
||||
mount_command = "sudo mount.nfs %s %s" % (provider_uri, mount_path)
|
||||
# If the provider uri looks like this 172.23.233.89//data/csdesign'
|
||||
# It will be a nfs path. The client provider will be used
|
||||
if not auto_mount:
|
||||
if not auto_mount: # pylint: disable=R1702
|
||||
if not Path(mount_path).exists():
|
||||
raise FileNotFoundError(
|
||||
f"Invalid mount path: {mount_path}! Please mount manually: {mount_command} or Set init parameter `auto_mount=True`"
|
||||
@@ -139,8 +140,10 @@ def _mount_nfs_uri(provider_uri, mount_path, auto_mount: bool = False):
|
||||
if not _is_mount:
|
||||
try:
|
||||
Path(mount_path).mkdir(parents=True, exist_ok=True)
|
||||
except Exception:
|
||||
raise OSError(f"Failed to create directory {mount_path}, please create {mount_path} manually!")
|
||||
except Exception as e:
|
||||
raise OSError(
|
||||
f"Failed to create directory {mount_path}, please create {mount_path} manually!"
|
||||
) from e
|
||||
|
||||
# check nfs-common
|
||||
command_res = os.popen("dpkg -l | grep nfs-common")
|
||||
|
||||
@@ -1,24 +1,29 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from typing import List, Tuple, Union, TYPE_CHECKING
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Generator, List, Optional, Tuple, Union
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from .account import Account
|
||||
from .report import Indicator, PortfolioMetrics
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..strategy.base import BaseStrategy
|
||||
from .executor import BaseExecutor
|
||||
from .decision import BaseTradeDecision
|
||||
from .position import Position
|
||||
from .exchange import Exchange
|
||||
from .backtest import backtest_loop
|
||||
from .backtest import collect_data_loop
|
||||
from .utils import CommonInfrastructure
|
||||
from .decision import Order
|
||||
from ..utils import init_instance_by_config
|
||||
from ..log import get_module_logger
|
||||
|
||||
from ..config import C
|
||||
from ..log import get_module_logger
|
||||
from ..utils import init_instance_by_config
|
||||
from .backtest import backtest_loop, collect_data_loop
|
||||
from .decision import Order
|
||||
from .exchange import Exchange
|
||||
from .utils import CommonInfrastructure
|
||||
|
||||
# make import more user-friendly by adding `from qlib.backtest import STH`
|
||||
|
||||
@@ -27,26 +32,35 @@ logger = get_module_logger("backtest caller")
|
||||
|
||||
|
||||
def get_exchange(
|
||||
exchange=None,
|
||||
freq="day",
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
codes="all",
|
||||
subscribe_fields=[],
|
||||
open_cost=0.0015,
|
||||
close_cost=0.0025,
|
||||
min_cost=5.0,
|
||||
limit_threshold=None,
|
||||
exchange: Union[str, dict, object, Path] = None,
|
||||
freq: str = "day",
|
||||
start_time: Union[pd.Timestamp, str] = None,
|
||||
end_time: Union[pd.Timestamp, str] = None,
|
||||
codes: Union[list, str] = "all",
|
||||
subscribe_fields: list = [],
|
||||
open_cost: float = 0.0015,
|
||||
close_cost: float = 0.0025,
|
||||
min_cost: float = 5.0,
|
||||
limit_threshold: Union[Tuple[str, str], float, None] = None,
|
||||
deal_price: Union[str, Tuple[str], List[str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
**kwargs: Any,
|
||||
) -> Exchange:
|
||||
"""get_exchange
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
||||
# exchange related arguments
|
||||
exchange: Exchange().
|
||||
exchange: Exchange
|
||||
It could be None or any types that are acceptable by `init_instance_by_config`.
|
||||
freq: str
|
||||
frequency of data.
|
||||
start_time: Union[pd.Timestamp, str]
|
||||
closed start time for backtest.
|
||||
end_time: Union[pd.Timestamp, str]
|
||||
closed end time for backtest.
|
||||
codes: Union[list, str]
|
||||
list stock_id list or a string of instruments (i.e. all, csi500, sse50)
|
||||
subscribe_fields: list
|
||||
subscribe fields.
|
||||
open_cost : float
|
||||
@@ -56,8 +70,6 @@ def get_exchange(
|
||||
min_cost : float
|
||||
min transaction cost. It is an absolute amount of cost instead of a ratio of your order's deal amount.
|
||||
e.g. You must pay at least 5 yuan of commission regardless of your order's deal amount.
|
||||
trade_unit : int
|
||||
Included in kwargs. Please refer to the docs of `__init__` of `Exchange`
|
||||
deal_price: Union[str, Tuple[str], List[str]]
|
||||
The `deal_price` supports following two types of input
|
||||
- <deal_price> : str
|
||||
@@ -100,10 +112,14 @@ def get_exchange(
|
||||
|
||||
|
||||
def create_account_instance(
|
||||
start_time, end_time, benchmark: str, account: Union[float, int, dict], pos_type: str = "Position"
|
||||
start_time: Union[pd.Timestamp, str],
|
||||
end_time: Union[pd.Timestamp, str],
|
||||
benchmark: str,
|
||||
account: Union[float, int, dict],
|
||||
pos_type: str = "Position",
|
||||
) -> Account:
|
||||
"""
|
||||
# TODO: is very strange pass benchmark_config in the account(maybe for report)
|
||||
# TODO: is very strange pass benchmark_config in the account (maybe for report)
|
||||
# There should be a post-step to process the report.
|
||||
|
||||
Parameters
|
||||
@@ -131,51 +147,53 @@ def create_account_instance(
|
||||
key "cash" means initial cash.
|
||||
key "stock1" means the information of first stock with amount and price(optional).
|
||||
...
|
||||
pos_type: str
|
||||
Postion type.
|
||||
"""
|
||||
if isinstance(account, (int, float)):
|
||||
pos_kwargs = {"init_cash": account}
|
||||
init_cash = account
|
||||
position_dict = {}
|
||||
elif isinstance(account, dict):
|
||||
init_cash = account["cash"]
|
||||
del account["cash"]
|
||||
pos_kwargs = {
|
||||
"init_cash": init_cash,
|
||||
"position_dict": account,
|
||||
}
|
||||
init_cash = account.pop("cash")
|
||||
position_dict = account
|
||||
else:
|
||||
raise ValueError("account must be in (int, float, Position)")
|
||||
raise ValueError("account must be in (int, float, dict)")
|
||||
|
||||
kwargs = {
|
||||
"init_cash": account,
|
||||
"benchmark_config": {
|
||||
return Account(
|
||||
init_cash=init_cash,
|
||||
position_dict=position_dict,
|
||||
pos_type=pos_type,
|
||||
benchmark_config={
|
||||
"benchmark": benchmark,
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
},
|
||||
"pos_type": pos_type,
|
||||
}
|
||||
kwargs.update(pos_kwargs)
|
||||
return Account(**kwargs)
|
||||
)
|
||||
|
||||
|
||||
def get_strategy_executor(
|
||||
start_time,
|
||||
end_time,
|
||||
strategy: BaseStrategy,
|
||||
executor: BaseExecutor,
|
||||
start_time: Union[pd.Timestamp, str],
|
||||
end_time: Union[pd.Timestamp, str],
|
||||
strategy: Union[str, dict, object, Path],
|
||||
executor: Union[str, dict, object, Path],
|
||||
benchmark: str = "SH000300",
|
||||
account: Union[float, int, Position] = 1e9,
|
||||
account: Union[float, int, dict] = 1e9,
|
||||
exchange_kwargs: dict = {},
|
||||
pos_type: str = "Position",
|
||||
):
|
||||
) -> Tuple[BaseStrategy, BaseExecutor]:
|
||||
|
||||
# NOTE:
|
||||
# - for avoiding recursive import
|
||||
# - typing annotations is not reliable
|
||||
from ..strategy.base import BaseStrategy
|
||||
from .executor import BaseExecutor
|
||||
from ..strategy.base import BaseStrategy # pylint: disable=C0415
|
||||
from .executor import BaseExecutor # pylint: disable=C0415
|
||||
|
||||
trade_account = create_account_instance(
|
||||
start_time=start_time, end_time=end_time, benchmark=benchmark, account=account, pos_type=pos_type
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
benchmark=benchmark,
|
||||
account=account,
|
||||
pos_type=pos_type,
|
||||
)
|
||||
|
||||
exchange_kwargs = copy.copy(exchange_kwargs)
|
||||
@@ -195,29 +213,31 @@ def get_strategy_executor(
|
||||
|
||||
|
||||
def backtest(
|
||||
start_time,
|
||||
end_time,
|
||||
strategy,
|
||||
executor,
|
||||
benchmark="SH000300",
|
||||
account=1e9,
|
||||
exchange_kwargs={},
|
||||
start_time: Union[pd.Timestamp, str],
|
||||
end_time: Union[pd.Timestamp, str],
|
||||
strategy: Union[str, dict, object, Path],
|
||||
executor: Union[str, dict, object, Path],
|
||||
benchmark: str = "SH000300",
|
||||
account: Union[float, int, dict] = 1e9,
|
||||
exchange_kwargs: dict = {},
|
||||
pos_type: str = "Position",
|
||||
):
|
||||
"""initialize the strategy and executor, then backtest function for the interaction of the outermost strategy and executor in the nested decision execution
|
||||
) -> Tuple[PortfolioMetrics, Indicator]:
|
||||
"""initialize the strategy and executor, then backtest function for the interaction of the outermost strategy and
|
||||
executor in the nested decision execution
|
||||
|
||||
Parameters
|
||||
----------
|
||||
start_time : pd.Timestamp|str
|
||||
start_time : Union[pd.Timestamp, str]
|
||||
closed start time for backtest
|
||||
**NOTE**: This will be applied to the outmost executor's calendar.
|
||||
end_time : pd.Timestamp|str
|
||||
end_time : Union[pd.Timestamp, str]
|
||||
closed end time for backtest
|
||||
**NOTE**: This will be applied to the outmost executor's calendar.
|
||||
E.g. Executor[day](Executor[1min]), setting `end_time == 20XX0301` will include all the minutes on 20XX0301
|
||||
strategy : Union[str, dict, BaseStrategy]
|
||||
for initializing outermost portfolio strategy. Please refer to the docs of init_instance_by_config for more information.
|
||||
executor : Union[str, dict, BaseExecutor]
|
||||
strategy : Union[str, dict, object, Path]
|
||||
for initializing outermost portfolio strategy. Please refer to the docs of init_instance_by_config for more
|
||||
information.
|
||||
executor : Union[str, dict, object, Path]
|
||||
for initializing the outermost executor.
|
||||
benchmark: str
|
||||
the benchmark for reporting.
|
||||
@@ -256,16 +276,16 @@ def backtest(
|
||||
|
||||
|
||||
def collect_data(
|
||||
start_time,
|
||||
end_time,
|
||||
strategy,
|
||||
executor,
|
||||
benchmark="SH000300",
|
||||
account=1e9,
|
||||
exchange_kwargs={},
|
||||
start_time: Union[pd.Timestamp, str],
|
||||
end_time: Union[pd.Timestamp, str],
|
||||
strategy: Union[str, dict, object, Path],
|
||||
executor: Union[str, dict, object, Path],
|
||||
benchmark: str = "SH000300",
|
||||
account: Union[float, int, dict] = 1e9,
|
||||
exchange_kwargs: dict = {},
|
||||
pos_type: str = "Position",
|
||||
return_value: dict = None,
|
||||
):
|
||||
) -> Generator[object, None, None]:
|
||||
"""initialize the strategy and executor, then collect the trade decision data for rl training
|
||||
|
||||
please refer to the docs of the backtest for the explanation of the parameters
|
||||
@@ -290,7 +310,7 @@ def collect_data(
|
||||
|
||||
def format_decisions(
|
||||
decisions: List[BaseTradeDecision],
|
||||
) -> Tuple[str, List[Tuple[BaseTradeDecision, Union[Tuple, None]]]]:
|
||||
) -> Optional[Tuple[str, List[Tuple[BaseTradeDecision, Union[Tuple, None]]]]]:
|
||||
"""
|
||||
format the decisions collected by `qlib.backtest.collect_data`
|
||||
The decisions will be organized into a tree-like structure.
|
||||
@@ -315,7 +335,7 @@ def format_decisions(
|
||||
|
||||
cur_freq = decisions[0].strategy.trade_calendar.get_freq()
|
||||
|
||||
res = (cur_freq, [])
|
||||
res: Tuple[str, list] = (cur_freq, [])
|
||||
last_dec_idx = 0
|
||||
for i, dec in enumerate(decisions[1:], 1):
|
||||
if dec.strategy.trade_calendar.get_freq() == cur_freq:
|
||||
@@ -323,3 +343,6 @@ def format_decisions(
|
||||
last_dec_idx = i
|
||||
res[1].append((decisions[last_dec_idx], format_decisions(decisions[last_dec_idx + 1 :])))
|
||||
return res
|
||||
|
||||
|
||||
__all__ = ["Order", "backtest"]
|
||||
|
||||
@@ -1,15 +1,19 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from typing import Dict, List, Tuple, TYPE_CHECKING
|
||||
from qlib.utils import init_instance_by_config
|
||||
from typing import Dict, List, Optional, Tuple, cast
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from .position import BasePosition, InfPosition, Position
|
||||
from .report import PortfolioMetrics, Indicator
|
||||
from qlib.utils import init_instance_by_config
|
||||
|
||||
from .decision import BaseTradeDecision, Order
|
||||
from .exchange import Exchange
|
||||
from .high_performance_ds import BaseOrderIndicator
|
||||
from .position import BasePosition
|
||||
from .report import Indicator, PortfolioMetrics
|
||||
|
||||
"""
|
||||
rtn & earning in the Account
|
||||
@@ -34,40 +38,42 @@ class AccumulatedInfo:
|
||||
AccumulatedInfo should be shared across different levels
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.rtn = 0 # accumulated return, do not consider cost
|
||||
self.cost = 0 # accumulated cost
|
||||
self.to = 0 # accumulated turnover
|
||||
def reset(self) -> None:
|
||||
self.rtn: float = 0.0 # accumulated return, do not consider cost
|
||||
self.cost: float = 0.0 # accumulated cost
|
||||
self.to: float = 0.0 # accumulated turnover
|
||||
|
||||
def add_return_value(self, value):
|
||||
def add_return_value(self, value: float) -> None:
|
||||
self.rtn += value
|
||||
|
||||
def add_cost(self, value):
|
||||
def add_cost(self, value: float) -> None:
|
||||
self.cost += value
|
||||
|
||||
def add_turnover(self, value):
|
||||
def add_turnover(self, value: float) -> None:
|
||||
self.to += value
|
||||
|
||||
@property
|
||||
def get_return(self):
|
||||
def get_return(self) -> float:
|
||||
return self.rtn
|
||||
|
||||
@property
|
||||
def get_cost(self):
|
||||
def get_cost(self) -> float:
|
||||
return self.cost
|
||||
|
||||
@property
|
||||
def get_turnover(self):
|
||||
def get_turnover(self) -> float:
|
||||
return self.to
|
||||
|
||||
|
||||
class Account:
|
||||
"""
|
||||
The correctness of the metrics of Account in nested execution depends on the shallow copy of `trade_account` in qlib/backtest/executor.py:NestedExecutor
|
||||
Different level of executor has different Account object when calculating metrics. But the position object is shared cross all the Account object.
|
||||
The correctness of the metrics of Account in nested execution depends on the shallow copy of `trade_account` in
|
||||
qlib/backtest/executor.py:NestedExecutor
|
||||
Different level of executor has different Account object when calculating metrics. But the position object is
|
||||
shared cross all the Account object.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -78,7 +84,7 @@ class Account:
|
||||
benchmark_config: dict = {},
|
||||
pos_type: str = "Position",
|
||||
port_metr_enabled: bool = True,
|
||||
):
|
||||
) -> None:
|
||||
"""the trade account of backtest.
|
||||
|
||||
Parameters
|
||||
@@ -99,10 +105,10 @@ class Account:
|
||||
|
||||
self._pos_type = pos_type
|
||||
self._port_metr_enabled = port_metr_enabled
|
||||
self.benchmark_config = None # avoid no attribute error
|
||||
self.benchmark_config: dict = {} # avoid no attribute error
|
||||
self.init_vars(init_cash, position_dict, freq, benchmark_config)
|
||||
|
||||
def init_vars(self, init_cash, position_dict, freq: str, benchmark_config: dict):
|
||||
def init_vars(self, init_cash: float, position_dict: dict, freq: str, benchmark_config: dict) -> None:
|
||||
# 1) the following variables are shared by multiple layers
|
||||
# - you will see a shallow copy instead of deepcopy in the NestedExecutor;
|
||||
self.init_cash = init_cash
|
||||
@@ -114,22 +120,22 @@ class Account:
|
||||
"position_dict": position_dict,
|
||||
},
|
||||
"module_path": "qlib.backtest.position",
|
||||
}
|
||||
},
|
||||
)
|
||||
self.accum_info = AccumulatedInfo()
|
||||
|
||||
# 2) following variables are not shared between layers
|
||||
self.portfolio_metrics = None
|
||||
self.hist_positions = {}
|
||||
self.portfolio_metrics: Optional[PortfolioMetrics] = None
|
||||
self.hist_positions: Dict[pd.Timestamp, BasePosition] = {}
|
||||
self.reset(freq=freq, benchmark_config=benchmark_config)
|
||||
|
||||
def is_port_metr_enabled(self):
|
||||
def is_port_metr_enabled(self) -> bool:
|
||||
"""
|
||||
Is portfolio-based metrics enabled.
|
||||
"""
|
||||
return self._port_metr_enabled and not self.current_position.skip_update()
|
||||
|
||||
def reset_report(self, freq, benchmark_config):
|
||||
def reset_report(self, freq: str, benchmark_config: dict) -> None:
|
||||
# portfolio related metrics
|
||||
if self.is_port_metr_enabled():
|
||||
# NOTE:
|
||||
@@ -140,13 +146,13 @@ class Account:
|
||||
# fill stock value
|
||||
# The frequency of account may not align with the trading frequency.
|
||||
# This may result in obscure bugs when data quality is low.
|
||||
if isinstance(self.benchmark_config, dict) and self.benchmark_config.get("start_time") is not None:
|
||||
if isinstance(self.benchmark_config, dict) and "start_time" in self.benchmark_config:
|
||||
self.current_position.fill_stock_value(self.benchmark_config["start_time"], self.freq)
|
||||
|
||||
# trading related metrics(e.g. high-frequency trading)
|
||||
self.indicator = Indicator()
|
||||
|
||||
def reset(self, freq=None, benchmark_config=None, port_metr_enabled: bool = None):
|
||||
def reset(self, freq: str = None, benchmark_config: dict = None, port_metr_enabled: bool = None) -> None:
|
||||
"""reset freq and report of account
|
||||
|
||||
Parameters
|
||||
@@ -155,6 +161,7 @@ class Account:
|
||||
frequency of account & report, by default None
|
||||
benchmark_config : {}, optional
|
||||
benchmark config of report, by default None
|
||||
port_metr_enabled: bool
|
||||
"""
|
||||
if freq is not None:
|
||||
self.freq = freq
|
||||
@@ -165,13 +172,13 @@ class Account:
|
||||
|
||||
self.reset_report(self.freq, self.benchmark_config)
|
||||
|
||||
def get_hist_positions(self):
|
||||
def get_hist_positions(self) -> Dict[pd.Timestamp, BasePosition]:
|
||||
return self.hist_positions
|
||||
|
||||
def get_cash(self):
|
||||
def get_cash(self) -> float:
|
||||
return self.current_position.get_cash()
|
||||
|
||||
def _update_state_from_order(self, order, trade_val, cost, trade_price):
|
||||
def _update_state_from_order(self, order: Order, trade_val: float, cost: float, trade_price: float) -> None:
|
||||
if self.is_port_metr_enabled():
|
||||
# update turnover
|
||||
self.accum_info.add_turnover(trade_val)
|
||||
@@ -191,13 +198,14 @@ class Account:
|
||||
profit = self.current_position.get_stock_price(order.stock_id) * trade_amount - trade_val
|
||||
self.accum_info.add_return_value(profit) # note here do not consider cost
|
||||
|
||||
def update_order(self, order, trade_val, cost, trade_price):
|
||||
def update_order(self, order: Order, trade_val: float, cost: float, trade_price: float) -> None:
|
||||
if self.current_position.skip_update():
|
||||
# TODO: supporting polymorphism for account
|
||||
# updating order for infinite position is meaningless
|
||||
return
|
||||
|
||||
# if stock is sold out, no stock price information in Position, then we should update account first, then update current position
|
||||
# if stock is sold out, no stock price information in Position, then we should update account first,
|
||||
# then update current position
|
||||
# if stock is bought, there is no stock in current position, update current, then update account
|
||||
# The cost will be subtracted from the cash at last. So the trading logic can ignore the cost calculation
|
||||
if order.direction == Order.SELL:
|
||||
@@ -212,29 +220,40 @@ class Account:
|
||||
self.current_position.update_order(order, trade_val, cost, trade_price)
|
||||
self._update_state_from_order(order, trade_val, cost, trade_price)
|
||||
|
||||
def update_current_position(self, trade_start_time, trade_end_time, trade_exchange):
|
||||
"""update current to make rtn consistent with earning at the end of bar, and update holding bar count of stock"""
|
||||
def update_current_position(
|
||||
self,
|
||||
trade_start_time: pd.Timestamp,
|
||||
trade_end_time: pd.Timestamp,
|
||||
trade_exchange: Exchange,
|
||||
) -> None:
|
||||
"""
|
||||
Update current to make rtn consistent with earning at the end of bar, and update holding bar count of stock
|
||||
"""
|
||||
# update price for stock in the position and the profit from changed_price
|
||||
# NOTE: updating position does not only serve portfolio metrics, it also serve the strategy
|
||||
assert self.current_position is not None
|
||||
|
||||
if not self.current_position.skip_update():
|
||||
stock_list = self.current_position.get_stock_list()
|
||||
for code in stock_list:
|
||||
# if suspend, no new price to be updated, profit is 0
|
||||
if trade_exchange.check_stock_suspended(code, trade_start_time, trade_end_time):
|
||||
continue
|
||||
bar_close = trade_exchange.get_close(code, trade_start_time, trade_end_time)
|
||||
bar_close = cast(float, trade_exchange.get_close(code, trade_start_time, trade_end_time))
|
||||
self.current_position.update_stock_price(stock_id=code, price=bar_close)
|
||||
# update holding day count
|
||||
# NOTE: updating bar_count does not only serve portfolio metrics, it also serve the strategy
|
||||
self.current_position.add_count_all(bar=self.freq)
|
||||
|
||||
def update_portfolio_metrics(self, trade_start_time, trade_end_time):
|
||||
def update_portfolio_metrics(self, trade_start_time: pd.Timestamp, trade_end_time: pd.Timestamp) -> None:
|
||||
"""update portfolio_metrics"""
|
||||
# calculate earning
|
||||
# account_value - last_account_value
|
||||
# for the first trade date, account_value - init_cash
|
||||
# self.portfolio_metrics.is_empty() to judge is_first_trade_date
|
||||
# get last_account_value, last_total_cost, last_total_turnover
|
||||
assert self.portfolio_metrics is not None
|
||||
|
||||
if self.portfolio_metrics.is_empty():
|
||||
last_account_value = self.init_cash
|
||||
last_total_cost = 0
|
||||
@@ -243,14 +262,16 @@ class Account:
|
||||
last_account_value = self.portfolio_metrics.get_latest_account_value()
|
||||
last_total_cost = self.portfolio_metrics.get_latest_total_cost()
|
||||
last_total_turnover = self.portfolio_metrics.get_latest_total_turnover()
|
||||
|
||||
# get now_account_value, now_stock_value, now_earning, now_cost, now_turnover
|
||||
now_account_value = self.current_position.calculate_value()
|
||||
now_stock_value = self.current_position.calculate_stock_value()
|
||||
now_earning = now_account_value - last_account_value
|
||||
now_cost = self.accum_info.get_cost - last_total_cost
|
||||
now_turnover = self.accum_info.get_turnover - last_total_turnover
|
||||
|
||||
# update portfolio_metrics for today
|
||||
# judge whether the the trading is begin.
|
||||
# judge whether the trading is begin.
|
||||
# and don't add init account state into portfolio_metrics, due to we don't have excess return in those days.
|
||||
self.portfolio_metrics.update_portfolio_metrics_record(
|
||||
trade_start_time=trade_start_time,
|
||||
@@ -267,7 +288,7 @@ class Account:
|
||||
stock_value=now_stock_value,
|
||||
)
|
||||
|
||||
def update_hist_positions(self, trade_start_time):
|
||||
def update_hist_positions(self, trade_start_time: pd.Timestamp) -> None:
|
||||
"""update history position"""
|
||||
now_account_value = self.current_position.calculate_value()
|
||||
# set now_account_value to position
|
||||
@@ -283,11 +304,11 @@ class Account:
|
||||
trade_exchange: Exchange,
|
||||
atomic: bool,
|
||||
outer_trade_decision: BaseTradeDecision,
|
||||
trade_info: list = None,
|
||||
inner_order_indicators: List[Dict[str, pd.Series]] = None,
|
||||
decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]] = None,
|
||||
trade_info: list = [],
|
||||
inner_order_indicators: List[BaseOrderIndicator] = [],
|
||||
decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]] = [],
|
||||
indicator_config: dict = {},
|
||||
):
|
||||
) -> None:
|
||||
"""update trade indicators and order indicators in each bar end"""
|
||||
# TODO: will skip empty decisions make it faster? `outer_trade_decision.empty():`
|
||||
|
||||
@@ -319,11 +340,11 @@ class Account:
|
||||
trade_exchange: Exchange,
|
||||
atomic: bool,
|
||||
outer_trade_decision: BaseTradeDecision,
|
||||
trade_info: list = None,
|
||||
inner_order_indicators: List[Dict[str, pd.Series]] = None,
|
||||
decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]] = None,
|
||||
trade_info: list = [],
|
||||
inner_order_indicators: List[BaseOrderIndicator] = [],
|
||||
decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]] = [],
|
||||
indicator_config: dict = {},
|
||||
):
|
||||
) -> None:
|
||||
"""update account at each trading bar step
|
||||
|
||||
Parameters
|
||||
@@ -338,6 +359,8 @@ class Account:
|
||||
whether the trading executor is atomic, which means there is no higher-frequency trading executor inside it
|
||||
- if atomic is True, calculate the indicators with trade_info
|
||||
- else, aggregate indicators with inner indicators
|
||||
outer_trade_decision: BaseTradeDecision
|
||||
external trade decision
|
||||
trade_info : List[(Order, float, float, float)], optional
|
||||
trading information, by default None
|
||||
- necessary if atomic is True
|
||||
@@ -377,9 +400,10 @@ class Account:
|
||||
indicator_config=indicator_config,
|
||||
)
|
||||
|
||||
def get_portfolio_metrics(self):
|
||||
def get_portfolio_metrics(self) -> Tuple[pd.DataFrame, dict]:
|
||||
"""get the history portfolio_metrics and positions instance"""
|
||||
if self.is_port_metr_enabled():
|
||||
assert self.portfolio_metrics is not None
|
||||
_portfolio_metrics = self.portfolio_metrics.generate_portfolio_metrics_dataframe()
|
||||
_positions = self.get_hist_positions()
|
||||
return _portfolio_metrics, _positions
|
||||
|
||||
@@ -2,17 +2,29 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Generator, Optional, Tuple, Union, cast
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from qlib.backtest.decision import BaseTradeDecision
|
||||
from typing import TYPE_CHECKING
|
||||
from qlib.backtest.report import Indicator, PortfolioMetrics
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from qlib.strategy.base import BaseStrategy
|
||||
from qlib.backtest.executor import BaseExecutor
|
||||
from ..utils.time import Freq
|
||||
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from ..utils.time import Freq
|
||||
|
||||
def backtest_loop(start_time, end_time, trade_strategy: BaseStrategy, trade_executor: BaseExecutor):
|
||||
|
||||
def backtest_loop(
|
||||
start_time: Union[pd.Timestamp, str],
|
||||
end_time: Union[pd.Timestamp, str],
|
||||
trade_strategy: BaseStrategy,
|
||||
trade_executor: BaseExecutor,
|
||||
) -> Tuple[PortfolioMetrics, Indicator]:
|
||||
"""backtest function for the interaction of the outermost strategy and executor in the nested decision execution
|
||||
|
||||
please refer to the docs of `collect_data_loop`
|
||||
@@ -24,26 +36,33 @@ def backtest_loop(start_time, end_time, trade_strategy: BaseStrategy, trade_exec
|
||||
indicator: Indicator
|
||||
it computes the trading indicator
|
||||
"""
|
||||
return_value = {}
|
||||
return_value: dict = {}
|
||||
for _decision in collect_data_loop(start_time, end_time, trade_strategy, trade_executor, return_value):
|
||||
pass
|
||||
return return_value.get("portfolio_metrics"), return_value.get("indicator")
|
||||
|
||||
portfolio_metrics = cast(PortfolioMetrics, return_value.get("portfolio_metrics"))
|
||||
indicator = cast(Indicator, return_value.get("indicator"))
|
||||
return portfolio_metrics, indicator
|
||||
|
||||
|
||||
def collect_data_loop(
|
||||
start_time, end_time, trade_strategy: BaseStrategy, trade_executor: BaseExecutor, return_value: dict = None
|
||||
):
|
||||
start_time: Union[pd.Timestamp, str],
|
||||
end_time: Union[pd.Timestamp, str],
|
||||
trade_strategy: BaseStrategy,
|
||||
trade_executor: BaseExecutor,
|
||||
return_value: dict = None,
|
||||
) -> Generator[BaseTradeDecision, Optional[BaseTradeDecision], None]:
|
||||
"""Generator for collecting the trade decision data for rl training
|
||||
|
||||
Parameters
|
||||
----------
|
||||
start_time : pd.Timestamp|str
|
||||
start_time : Union[pd.Timestamp, str]
|
||||
closed start time for backtest
|
||||
**NOTE**: This will be applied to the outmost executor's calendar.
|
||||
end_time : pd.Timestamp|str
|
||||
end_time : Union[pd.Timestamp, str]
|
||||
closed end time for backtest
|
||||
**NOTE**: This will be applied to the outmost executor's calendar.
|
||||
E.g. Executor[day](Executor[1min]), setting `end_time == 20XX0301` will include all the minutes on 20XX0301
|
||||
E.g. Executor[day](Executor[1min]), setting `end_time == 20XX0301` will include all the minutes on 20XX0301
|
||||
trade_strategy : BaseStrategy
|
||||
the outermost portfolio strategy
|
||||
trade_executor : BaseExecutor
|
||||
|
||||
@@ -2,28 +2,33 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import abstractmethod
|
||||
from enum import IntEnum
|
||||
from qlib.data.data import Cal
|
||||
from qlib.utils.time import concat_date_time, epsilon_change
|
||||
from qlib.log import get_module_logger
|
||||
|
||||
# try to fix circular imports when enabling type hints
|
||||
from typing import Callable, TYPE_CHECKING
|
||||
from typing import Generic, List, TYPE_CHECKING, Any, ClassVar, Optional, Tuple, TypeVar, Union, cast
|
||||
|
||||
from qlib.backtest.utils import TradeCalendarManager
|
||||
from qlib.data.data import Cal
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.utils.time import concat_date_time, epsilon_change
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from qlib.strategy.base import BaseStrategy
|
||||
from qlib.backtest.exchange import Exchange
|
||||
from qlib.backtest.utils import TradeCalendarManager
|
||||
import warnings
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from dataclasses import dataclass, field
|
||||
from typing import ClassVar, Optional, Union, List, Set, Tuple
|
||||
|
||||
|
||||
DecisionType = TypeVar("DecisionType")
|
||||
|
||||
|
||||
class OrderDir(IntEnum):
|
||||
# Order direction
|
||||
# Order direction
|
||||
SELL = 0
|
||||
BUY = 1
|
||||
|
||||
@@ -47,7 +52,7 @@ class Order:
|
||||
# - they are set by users and is time-invariant.
|
||||
stock_id: str
|
||||
amount: float # `amount` is a non-negative and adjusted value
|
||||
direction: int
|
||||
direction: OrderDir
|
||||
|
||||
# 2) time variant values:
|
||||
# - Users may want to set these values when using lower level APIs
|
||||
@@ -62,8 +67,8 @@ class Order:
|
||||
# What the value should be about in all kinds of cases
|
||||
# - not tradable: the deal_amount == 0 , factor is None
|
||||
# - the stock is suspended and the entire order fails. No cost for this order
|
||||
# - dealed or partially dealed: deal_amount >= 0 and factor is not None
|
||||
deal_amount: Optional[float] = None # `deal_amount` is a non-negative value
|
||||
# - dealt or partially dealt: deal_amount >= 0 and factor is not None
|
||||
deal_amount: float = 0.0 # `deal_amount` is a non-negative value
|
||||
factor: Optional[float] = None
|
||||
|
||||
# TODO:
|
||||
@@ -75,10 +80,10 @@ class Order:
|
||||
SELL: ClassVar[OrderDir] = OrderDir.SELL
|
||||
BUY: ClassVar[OrderDir] = OrderDir.BUY
|
||||
|
||||
def __post_init__(self):
|
||||
def __post_init__(self) -> None:
|
||||
if self.direction not in {Order.SELL, Order.BUY}:
|
||||
raise NotImplementedError("direction not supported, `Order.SELL` for sell, `Order.BUY` for buy")
|
||||
self.deal_amount = 0
|
||||
self.deal_amount = 0.0
|
||||
self.factor = None
|
||||
|
||||
@property
|
||||
@@ -100,7 +105,7 @@ class Order:
|
||||
return self.deal_amount * self.sign
|
||||
|
||||
@property
|
||||
def sign(self) -> float:
|
||||
def sign(self) -> int:
|
||||
"""
|
||||
return the sign of trading
|
||||
- `+1` indicates buying
|
||||
@@ -113,15 +118,12 @@ class Order:
|
||||
if isinstance(direction, OrderDir):
|
||||
return direction
|
||||
elif isinstance(direction, (int, float, np.integer, np.floating)):
|
||||
if direction > 0:
|
||||
return Order.BUY
|
||||
else:
|
||||
return Order.SELL
|
||||
return Order.BUY if direction > 0 else Order.SELL
|
||||
elif isinstance(direction, str):
|
||||
dl = direction.lower()
|
||||
if dl.strip() == "sell":
|
||||
dl = direction.lower().strip()
|
||||
if dl == "sell":
|
||||
return OrderDir.SELL
|
||||
elif dl.strip() == "buy":
|
||||
elif dl == "buy":
|
||||
return OrderDir.BUY
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
@@ -139,14 +141,14 @@ class OrderHelper:
|
||||
Motivation
|
||||
- Make generating order easier
|
||||
- User may have no knowledge about the adjust-factor information about the system.
|
||||
- It involves to much interaction with the exchange when generating orders.
|
||||
- It involves too much interaction with the exchange when generating orders.
|
||||
"""
|
||||
|
||||
def __init__(self, exchange: Exchange):
|
||||
def __init__(self, exchange: Exchange) -> None:
|
||||
self.exchange = exchange
|
||||
|
||||
@staticmethod
|
||||
def create(
|
||||
self,
|
||||
code: str,
|
||||
amount: float,
|
||||
direction: OrderDir,
|
||||
@@ -176,21 +178,18 @@ class OrderHelper:
|
||||
Order:
|
||||
The created order
|
||||
"""
|
||||
if start_time is not None:
|
||||
start_time = pd.Timestamp(start_time)
|
||||
if end_time is not None:
|
||||
end_time = pd.Timestamp(end_time)
|
||||
# NOTE: factor is a value belongs to the results section. User don't have to care about it when creating orders
|
||||
return Order(
|
||||
stock_id=code,
|
||||
amount=amount,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
start_time=start_time if start_time is not None else pd.Timestamp(start_time),
|
||||
end_time=end_time if end_time is not None else pd.Timestamp(end_time),
|
||||
direction=direction,
|
||||
)
|
||||
|
||||
|
||||
class TradeRange:
|
||||
@abstractmethod
|
||||
def __call__(self, trade_calendar: TradeCalendarManager) -> Tuple[int, int]:
|
||||
"""
|
||||
This method will be call with following way
|
||||
@@ -217,6 +216,7 @@ class TradeRange:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `__call__` method")
|
||||
|
||||
@abstractmethod
|
||||
def clip_time_range(self, start_time: pd.Timestamp, end_time: pd.Timestamp) -> Tuple[pd.Timestamp, pd.Timestamp]:
|
||||
"""
|
||||
Parameters
|
||||
@@ -235,23 +235,26 @@ class TradeRange:
|
||||
|
||||
|
||||
class IdxTradeRange(TradeRange):
|
||||
def __init__(self, start_idx: int, end_idx: int):
|
||||
def __init__(self, start_idx: int, end_idx: int) -> None:
|
||||
self._start_idx = start_idx
|
||||
self._end_idx = end_idx
|
||||
|
||||
def __call__(self, trade_calendar: TradeCalendarManager = None) -> Tuple[int, int]:
|
||||
return self._start_idx, self._end_idx
|
||||
|
||||
def clip_time_range(self, start_time: pd.Timestamp, end_time: pd.Timestamp) -> Tuple[pd.Timestamp, pd.Timestamp]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class TradeRangeByTime(TradeRange):
|
||||
"""This is a helper function for make decisions"""
|
||||
|
||||
def __init__(self, start_time: str, end_time: str):
|
||||
def __init__(self, start_time: str, end_time: str) -> None:
|
||||
"""
|
||||
This is a callable class.
|
||||
|
||||
**NOTE**:
|
||||
- It is designed for minute-bar for intraday trading!!!!!
|
||||
- It is designed for minute-bar for intra-day trading!!!!!
|
||||
- Both start_time and end_time are **closed** in the range
|
||||
|
||||
Parameters
|
||||
@@ -265,26 +268,25 @@ class TradeRangeByTime(TradeRange):
|
||||
self.end_time = pd.Timestamp(end_time).time()
|
||||
assert self.start_time < self.end_time
|
||||
|
||||
def __call__(self, trade_calendar: TradeCalendarManager = None) -> Tuple[int, int]:
|
||||
def __call__(self, trade_calendar: TradeCalendarManager) -> Tuple[int, int]:
|
||||
if trade_calendar is None:
|
||||
raise NotImplementedError("trade_calendar is necessary for getting TradeRangeByTime.")
|
||||
start = trade_calendar.start_time
|
||||
val_start, val_end = concat_date_time(start.date(), self.start_time), concat_date_time(
|
||||
start.date(), self.end_time
|
||||
)
|
||||
|
||||
start_date = trade_calendar.start_time.date()
|
||||
val_start, val_end = concat_date_time(start_date, self.start_time), concat_date_time(start_date, self.end_time)
|
||||
return trade_calendar.get_range_idx(val_start, val_end)
|
||||
|
||||
def clip_time_range(self, start_time: pd.Timestamp, end_time: pd.Timestamp) -> Tuple[pd.Timestamp, pd.Timestamp]:
|
||||
start_date = start_time.date()
|
||||
val_start, val_end = concat_date_time(start_date, self.start_time), concat_date_time(start_date, self.end_time)
|
||||
# NOTE: `end_date` should not be used. Because the `end_date` is for slicing. It may be in the next day
|
||||
# Assumption: start_time and end_time is for intraday trading. So it is OK for only using start_date
|
||||
# Assumption: start_time and end_time is for intra-day trading. So it is OK for only using start_date
|
||||
return max(val_start, start_time), min(val_end, end_time)
|
||||
|
||||
|
||||
class BaseTradeDecision:
|
||||
class BaseTradeDecision(Generic[DecisionType]):
|
||||
"""
|
||||
Trade decisions ara made by strategy and executed by exeuter
|
||||
Trade decisions ara made by strategy and executed by executor
|
||||
|
||||
Motivation:
|
||||
Here are several typical scenarios for `BaseTradeDecision`
|
||||
@@ -298,7 +300,7 @@ class BaseTradeDecision:
|
||||
2. Same as `case 1.3`
|
||||
"""
|
||||
|
||||
def __init__(self, strategy: BaseStrategy, trade_range: Union[Tuple[int, int], TradeRange] = None):
|
||||
def __init__(self, strategy: BaseStrategy, trade_range: Union[Tuple[int, int], TradeRange] = None) -> None:
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
@@ -317,20 +319,21 @@ class BaseTradeDecision:
|
||||
"""
|
||||
self.strategy = strategy
|
||||
self.start_time, self.end_time = strategy.trade_calendar.get_step_time()
|
||||
self.total_step = None # upper strategy has no knowledge about the sub executor before `_init_sub_trading`
|
||||
if isinstance(trade_range, Tuple):
|
||||
# upper strategy has no knowledge about the sub executor before `_init_sub_trading`
|
||||
self.total_step: Optional[int] = None
|
||||
if isinstance(trade_range, tuple):
|
||||
# for Tuple[int, int]
|
||||
trade_range = IdxTradeRange(*trade_range)
|
||||
self.trade_range: TradeRange = trade_range
|
||||
self.trade_range: Optional[TradeRange] = trade_range
|
||||
|
||||
def get_decision(self) -> List[object]:
|
||||
def get_decision(self) -> List[DecisionType]:
|
||||
"""
|
||||
get the **concrete decision** (e.g. execution orders)
|
||||
This will be called by the inner strategy
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[object]:
|
||||
List[DecisionType:
|
||||
The decision result. Typically it is some orders
|
||||
Example:
|
||||
[]:
|
||||
@@ -340,7 +343,7 @@ class BaseTradeDecision:
|
||||
"""
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
|
||||
def update(self, trade_calendar: TradeCalendarManager) -> Union["BaseTradeDecision", None]:
|
||||
def update(self, trade_calendar: TradeCalendarManager) -> Optional[BaseTradeDecision]:
|
||||
"""
|
||||
Be called at the **start** of each step.
|
||||
|
||||
@@ -355,10 +358,8 @@ class BaseTradeDecision:
|
||||
|
||||
Returns
|
||||
-------
|
||||
None:
|
||||
No update, use previous decision(or unavailable)
|
||||
BaseTradeDecision:
|
||||
New update, use new decision
|
||||
New update, use new decision. If no updates, return None (use previous decision (or unavailable))
|
||||
"""
|
||||
# purpose 1)
|
||||
self.total_step = trade_calendar.get_trade_len()
|
||||
@@ -366,13 +367,13 @@ class BaseTradeDecision:
|
||||
# purpose 2)
|
||||
return self.strategy.update_trade_decision(self, trade_calendar)
|
||||
|
||||
def _get_range_limit(self, **kwargs) -> Tuple[int, int]:
|
||||
def _get_range_limit(self, **kwargs: Any) -> Tuple[int, int]:
|
||||
if self.trade_range is not None:
|
||||
return self.trade_range(trade_calendar=kwargs.get("inner_calendar"))
|
||||
return self.trade_range(trade_calendar=cast(TradeCalendarManager, kwargs.get("inner_calendar")))
|
||||
else:
|
||||
raise NotImplementedError("The decision didn't provide an index range")
|
||||
|
||||
def get_range_limit(self, **kwargs) -> Tuple[int, int]:
|
||||
def get_range_limit(self, **kwargs: Any) -> Tuple[int, int]:
|
||||
"""
|
||||
return the expected step range for limiting the decision execution time
|
||||
Both left and right are **closed**
|
||||
@@ -413,21 +414,22 @@ class BaseTradeDecision:
|
||||
"""
|
||||
try:
|
||||
_start_idx, _end_idx = self._get_range_limit(**kwargs)
|
||||
except NotImplementedError:
|
||||
except NotImplementedError as e:
|
||||
if "default_value" in kwargs:
|
||||
return kwargs["default_value"]
|
||||
else:
|
||||
# Default to get full index
|
||||
raise NotImplementedError(f"The decision didn't provide an index range")
|
||||
raise NotImplementedError(f"The decision didn't provide an index range") from e
|
||||
|
||||
# clip index
|
||||
if getattr(self, "total_step", None) is not None:
|
||||
# if `self.update` is called.
|
||||
# Then the _start_idx, _end_idx should be clipped
|
||||
assert self.total_step is not None
|
||||
if _start_idx < 0 or _end_idx >= self.total_step:
|
||||
logger = get_module_logger("decision")
|
||||
logger.warning(
|
||||
f"[{_start_idx},{_end_idx}] go beyoud the total_step({self.total_step}), it will be clipped"
|
||||
f"[{_start_idx},{_end_idx}] go beyond the total_step({self.total_step}), it will be clipped.",
|
||||
)
|
||||
_start_idx, _end_idx = max(0, _start_idx), min(self.total_step - 1, _end_idx)
|
||||
return _start_idx, _end_idx
|
||||
@@ -445,7 +447,7 @@ class BaseTradeDecision:
|
||||
Parameters
|
||||
----------
|
||||
rtype: str
|
||||
- "full": return the full limitation of the deicsion in the day
|
||||
- "full": return the full limitation of the decision in the day
|
||||
- "step": return the limitation of current step
|
||||
|
||||
raise_error: bool
|
||||
@@ -498,11 +500,10 @@ class BaseTradeDecision:
|
||||
return True
|
||||
return True
|
||||
|
||||
def mod_inner_decision(self, inner_trade_decision: BaseTradeDecision):
|
||||
def mod_inner_decision(self, inner_trade_decision: BaseTradeDecision) -> None:
|
||||
"""
|
||||
|
||||
This method will be called on the inner_trade_decision after it is generated.
|
||||
`inner_trade_decision` will be changed **inplaced**.
|
||||
`inner_trade_decision` will be changed **inplace**.
|
||||
|
||||
Motivation of the `mod_inner_decision`
|
||||
- Leave a hook for outer decision to affect the decision generated by the inner strategy
|
||||
@@ -520,29 +521,38 @@ class BaseTradeDecision:
|
||||
inner_trade_decision.trade_range = self.trade_range
|
||||
|
||||
|
||||
class EmptyTradeDecision(BaseTradeDecision):
|
||||
class EmptyTradeDecision(BaseTradeDecision[object]):
|
||||
def get_decision(self) -> List[object]:
|
||||
return []
|
||||
|
||||
def empty(self) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
class TradeDecisionWO(BaseTradeDecision):
|
||||
class TradeDecisionWO(BaseTradeDecision[Order]):
|
||||
"""
|
||||
Trade Decision (W)ith (O)rder.
|
||||
Besides, the time_range is also included.
|
||||
"""
|
||||
|
||||
def __init__(self, order_list: List[Order], strategy: BaseStrategy, trade_range: Tuple[int, int] = None):
|
||||
def __init__(self, order_list: List[object], strategy: BaseStrategy, trade_range: Tuple[int, int] = None) -> None:
|
||||
super().__init__(strategy, trade_range=trade_range)
|
||||
self.order_list = order_list
|
||||
self.order_list = cast(List[Order], order_list)
|
||||
start, end = strategy.trade_calendar.get_step_time()
|
||||
for o in order_list:
|
||||
assert isinstance(o, Order)
|
||||
if o.start_time is None:
|
||||
o.start_time = start
|
||||
if o.end_time is None:
|
||||
o.end_time = end
|
||||
|
||||
def get_decision(self) -> List[object]:
|
||||
def get_decision(self) -> List[Order]:
|
||||
return self.order_list
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"class: {self.__class__.__name__}; strategy: {self.strategy}; trade_range: {self.trade_range}; order_list[{len(self.order_list)}]"
|
||||
return (
|
||||
f"class: {self.__class__.__name__}; "
|
||||
f"strategy: {self.strategy}; "
|
||||
f"trade_range: {self.trade_range}; "
|
||||
f"order_list[{len(self.order_list)}]"
|
||||
)
|
||||
|
||||
@@ -1,45 +1,49 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union, cast
|
||||
|
||||
from ..utils.index_data import IndexData
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .account import Account
|
||||
|
||||
from qlib.backtest.position import BasePosition, Position
|
||||
import random
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from ..data.data import D
|
||||
from qlib.backtest.position import BasePosition
|
||||
|
||||
from ..config import C
|
||||
from ..constant import REG_CN
|
||||
from ..data.data import D
|
||||
from ..log import get_module_logger
|
||||
from .decision import Order, OrderDir, OrderHelper
|
||||
from .high_performance_ds import BaseQuote, PandasQuote, NumpyQuote
|
||||
from .high_performance_ds import BaseQuote, NumpyQuote
|
||||
|
||||
|
||||
class Exchange:
|
||||
def __init__(
|
||||
self,
|
||||
freq="day",
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
codes="all",
|
||||
freq: str = "day",
|
||||
start_time: Union[pd.Timestamp, str] = None,
|
||||
end_time: Union[pd.Timestamp, str] = None,
|
||||
codes: Union[list, str] = "all",
|
||||
deal_price: Union[str, Tuple[str], List[str]] = None,
|
||||
subscribe_fields=[],
|
||||
subscribe_fields: list = [],
|
||||
limit_threshold: Union[Tuple[str, str], float, None] = None,
|
||||
volume_threshold=None,
|
||||
open_cost=0.0015,
|
||||
close_cost=0.0025,
|
||||
min_cost=5,
|
||||
impact_cost=0.0,
|
||||
extra_quote=None,
|
||||
quote_cls=NumpyQuote,
|
||||
**kwargs,
|
||||
):
|
||||
volume_threshold: Union[tuple, dict] = None,
|
||||
open_cost: float = 0.0015,
|
||||
close_cost: float = 0.0025,
|
||||
min_cost: float = 5.0,
|
||||
impact_cost: float = 0.0,
|
||||
extra_quote: pd.DataFrame = None,
|
||||
quote_cls: Type[BaseQuote] = NumpyQuote,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""__init__
|
||||
:param freq: frequency of data
|
||||
:param start_time: closed start time for backtest
|
||||
@@ -72,11 +76,12 @@ class Exchange:
|
||||
]
|
||||
1) ("cum" or "current", limit_str) denotes a single volume limit.
|
||||
- limit_str is qlib data expression which is allowed to define your own Operator.
|
||||
Please refer to qlib/contrib/ops/high_freq.py, here are any custom operator for high frequency,
|
||||
such as DayCumsum. !!!NOTE: if you want you use the custom operator, you need to
|
||||
register it in qlib_init.
|
||||
- "cum" means that this is a cumulative value over time, such as cumulative market volume.
|
||||
So when it is used as a volume limit, it is necessary to subtract the dealt amount.
|
||||
Please refer to qlib/contrib/ops/high_freq.py, here are any custom operator for
|
||||
high frequency, such as DayCumsum. !!!NOTE: if you want you use the custom
|
||||
operator, you need to register it in qlib_init.
|
||||
- "cum" means that this is a cumulative value over time, such as cumulative market
|
||||
volume. So when it is used as a volume limit, it is necessary to subtract the dealt
|
||||
amount.
|
||||
- "current" means that this is a real-time value and will not accumulate over time,
|
||||
so it can be directly used as a capacity limit.
|
||||
e.g. ("cum", "0.2 * DayCumsum($volume, '9:45', '14:45')"), ("current", "$bidV1")
|
||||
@@ -84,7 +89,7 @@ class Exchange:
|
||||
"buy" means the volume limits of buying. "sell" means the volume limits of selling.
|
||||
Different volume limits will be aggregated with min(). If volume_threshold is only
|
||||
("cum" or "current", limit_str) instead of a dict, the volume limits are for
|
||||
both by deault. In other words, it is same as {"all": ("cum" or "current", limit_str)}.
|
||||
both by default. In other words, it is same as {"all": ("cum" or "current", limit_str)}.
|
||||
3) e.g. "volume_threshold": {
|
||||
"all": ("cum", "0.2 * DayCumsum($volume, '9:45', '14:45')"),
|
||||
"buy": ("current", "$askV1"),
|
||||
@@ -104,13 +109,14 @@ class Exchange:
|
||||
Necessary fields:
|
||||
$close is for calculating the total value at end of each day.
|
||||
Optional fields:
|
||||
$volume is only necessary when we limit the trade amount or calculate PA(vwap) indicator
|
||||
$volume is only necessary when we limit the trade amount or calculate
|
||||
PA(vwap) indicator
|
||||
$vwap is only necessary when we use the $vwap price as the deal price
|
||||
$factor is for rounding to the trading unit
|
||||
limit_sell will be set to False by default(False indicates we can sell this
|
||||
target on this day).
|
||||
limit_buy will be set to False by default(False indicates we can buy this
|
||||
target on this day).
|
||||
limit_sell will be set to False by default (False indicates we can sell
|
||||
this target on this day).
|
||||
limit_buy will be set to False by default (False indicates we can buy
|
||||
this target on this day).
|
||||
index: MultipleIndex(instrument, pd.Datetime)
|
||||
"""
|
||||
self.freq = freq
|
||||
@@ -135,7 +141,7 @@ class Exchange:
|
||||
if limit_threshold is None:
|
||||
if C.region == REG_CN:
|
||||
self.logger.warning(f"limit_threshold not set. The stocks hit the limit may be bought/sold")
|
||||
elif self.limit_type == self.LT_FLT and abs(limit_threshold) > 0.1:
|
||||
elif self.limit_type == self.LT_FLT and abs(cast(float, limit_threshold)) > 0.1:
|
||||
if C.region == REG_CN:
|
||||
self.logger.warning(f"limit_threshold may not be set to a reasonable value")
|
||||
|
||||
@@ -144,7 +150,7 @@ class Exchange:
|
||||
deal_price = "$" + deal_price
|
||||
self.buy_price = self.sell_price = deal_price
|
||||
elif isinstance(deal_price, (tuple, list)):
|
||||
self.buy_price, self.sell_price = deal_price
|
||||
self.buy_price, self.sell_price = cast(Tuple[str, str], deal_price)
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
|
||||
@@ -161,10 +167,10 @@ class Exchange:
|
||||
|
||||
necessary_fields = {self.buy_price, self.sell_price, "$close", "$change", "$factor", "$volume"}
|
||||
if self.limit_type == self.LT_TP_EXP:
|
||||
assert isinstance(limit_threshold, tuple)
|
||||
for exp in limit_threshold:
|
||||
necessary_fields.add(exp)
|
||||
all_fields = necessary_fields | vol_lt_fields
|
||||
all_fields = list(all_fields | set(subscribe_fields))
|
||||
all_fields = list(necessary_fields | set(vol_lt_fields) | set(subscribe_fields))
|
||||
|
||||
self.all_fields = all_fields
|
||||
|
||||
@@ -182,17 +188,22 @@ class Exchange:
|
||||
self.quote_cls = quote_cls
|
||||
self.quote: BaseQuote = self.quote_cls(self.quote_df, freq)
|
||||
|
||||
def get_quote_from_qlib(self):
|
||||
def get_quote_from_qlib(self) -> None:
|
||||
# get stock data from qlib
|
||||
if len(self.codes) == 0:
|
||||
self.codes = D.instruments()
|
||||
self.quote_df = D.features(
|
||||
self.codes, self.all_fields, self.start_time, self.end_time, freq=self.freq, disk_cache=True
|
||||
self.codes,
|
||||
self.all_fields,
|
||||
self.start_time,
|
||||
self.end_time,
|
||||
freq=self.freq,
|
||||
disk_cache=True,
|
||||
).dropna(subset=["$close"])
|
||||
self.quote_df.columns = self.all_fields
|
||||
|
||||
# check buy_price data and sell_price data
|
||||
for attr in "buy_price", "sell_price":
|
||||
for attr in ("buy_price", "sell_price"):
|
||||
pstr = getattr(self, attr) # price string
|
||||
if self.quote_df[pstr].isna().any():
|
||||
self.logger.warning("{} field data contains nan.".format(pstr))
|
||||
@@ -238,9 +249,9 @@ class Exchange:
|
||||
LT_FLT = "float" # float
|
||||
LT_NONE = "none" # none
|
||||
|
||||
def _get_limit_type(self, limit_threshold):
|
||||
def _get_limit_type(self, limit_threshold: Union[tuple, float, None]) -> str:
|
||||
"""get limit type"""
|
||||
if isinstance(limit_threshold, Tuple):
|
||||
if isinstance(limit_threshold, tuple):
|
||||
return self.LT_TP_EXP
|
||||
elif isinstance(limit_threshold, float):
|
||||
return self.LT_FLT
|
||||
@@ -249,7 +260,7 @@ class Exchange:
|
||||
else:
|
||||
raise NotImplementedError(f"This type of `limit_threshold` is not supported")
|
||||
|
||||
def _update_limit(self, limit_threshold):
|
||||
def _update_limit(self, limit_threshold: Union[Tuple, float, None]) -> None:
|
||||
# check limit_threshold
|
||||
limit_type = self._get_limit_type(limit_threshold)
|
||||
if limit_type == self.LT_NONE:
|
||||
@@ -257,15 +268,18 @@ class Exchange:
|
||||
self.quote_df["limit_sell"] = False
|
||||
elif limit_type == self.LT_TP_EXP:
|
||||
# set limit
|
||||
limit_threshold = cast(tuple, limit_threshold)
|
||||
self.quote_df["limit_buy"] = self.quote_df[limit_threshold[0]]
|
||||
self.quote_df["limit_sell"] = self.quote_df[limit_threshold[1]]
|
||||
elif limit_type == self.LT_FLT:
|
||||
limit_threshold = cast(float, limit_threshold)
|
||||
self.quote_df["limit_buy"] = self.quote_df["$change"].ge(limit_threshold)
|
||||
self.quote_df["limit_sell"] = self.quote_df["$change"].le(-limit_threshold) # pylint: disable=E1130
|
||||
|
||||
def _get_vol_limit(self, volume_threshold):
|
||||
@staticmethod
|
||||
def _get_vol_limit(volume_threshold: Union[tuple, dict, None]) -> Tuple[Optional[list], Optional[list], set]:
|
||||
"""
|
||||
preproccess the volume limit.
|
||||
preprocess the volume limit.
|
||||
get the fields need to get from qlib.
|
||||
get the volume limit list of buying and selling which is composed of all limits.
|
||||
Parameters
|
||||
@@ -295,8 +309,7 @@ class Exchange:
|
||||
volume_threshold = {"all": volume_threshold}
|
||||
|
||||
assert isinstance(volume_threshold, dict)
|
||||
for key in volume_threshold:
|
||||
vol_limit = volume_threshold[key]
|
||||
for key, vol_limit in volume_threshold.items():
|
||||
assert isinstance(vol_limit, tuple)
|
||||
fields.add(vol_limit[1])
|
||||
|
||||
@@ -307,10 +320,19 @@ class Exchange:
|
||||
|
||||
return buy_vol_limit, sell_vol_limit, fields
|
||||
|
||||
def check_stock_limit(self, stock_id, start_time, end_time, direction=None):
|
||||
def check_stock_limit(
|
||||
self,
|
||||
stock_id: str,
|
||||
start_time: pd.Timestamp,
|
||||
end_time: pd.Timestamp,
|
||||
direction: int = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
stock_id : str
|
||||
start_time: pd.Timestamp
|
||||
end_time: pd.Timestamp
|
||||
direction : int, optional
|
||||
trade direction, by default None
|
||||
- if direction is None, check if tradable for buying and selling.
|
||||
@@ -320,47 +342,50 @@ class Exchange:
|
||||
if direction is None:
|
||||
buy_limit = self.quote.get_data(stock_id, start_time, end_time, field="limit_buy", method="all")
|
||||
sell_limit = self.quote.get_data(stock_id, start_time, end_time, field="limit_sell", method="all")
|
||||
return buy_limit or sell_limit
|
||||
return bool(buy_limit or sell_limit)
|
||||
elif direction == Order.BUY:
|
||||
return self.quote.get_data(stock_id, start_time, end_time, field="limit_buy", method="all")
|
||||
return cast(bool, self.quote.get_data(stock_id, start_time, end_time, field="limit_buy", method="all"))
|
||||
elif direction == Order.SELL:
|
||||
return self.quote.get_data(stock_id, start_time, end_time, field="limit_sell", method="all")
|
||||
return cast(bool, self.quote.get_data(stock_id, start_time, end_time, field="limit_sell", method="all"))
|
||||
else:
|
||||
raise ValueError(f"direction {direction} is not supported!")
|
||||
|
||||
def check_stock_suspended(self, stock_id, start_time, end_time):
|
||||
def check_stock_suspended(
|
||||
self,
|
||||
stock_id: str,
|
||||
start_time: pd.Timestamp,
|
||||
end_time: pd.Timestamp,
|
||||
) -> bool:
|
||||
# is suspended
|
||||
if stock_id in self.quote.get_all_stock():
|
||||
return self.quote.get_data(stock_id, start_time, end_time, "$close") is None
|
||||
else:
|
||||
return True
|
||||
|
||||
def is_stock_tradable(self, stock_id, start_time, end_time, direction=None):
|
||||
def is_stock_tradable(
|
||||
self,
|
||||
stock_id: str,
|
||||
start_time: pd.Timestamp,
|
||||
end_time: pd.Timestamp,
|
||||
direction: int = None,
|
||||
) -> bool:
|
||||
# check if stock can be traded
|
||||
# same as check in check_order
|
||||
if self.check_stock_suspended(stock_id, start_time, end_time) or self.check_stock_limit(
|
||||
stock_id, start_time, end_time, direction
|
||||
):
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
return not (
|
||||
self.check_stock_suspended(stock_id, start_time, end_time)
|
||||
or self.check_stock_limit(stock_id, start_time, end_time, direction)
|
||||
)
|
||||
|
||||
def check_order(self, order):
|
||||
def check_order(self, order: Order) -> bool:
|
||||
# check limit and suspended
|
||||
if self.check_stock_suspended(order.stock_id, order.start_time, order.end_time) or self.check_stock_limit(
|
||||
order.stock_id, order.start_time, order.end_time, order.direction
|
||||
):
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
return self.is_stock_tradable(order.stock_id, order.start_time, order.end_time, order.direction)
|
||||
|
||||
def deal_order(
|
||||
self,
|
||||
order,
|
||||
order: Order,
|
||||
trade_account: Account = None,
|
||||
position: BasePosition = None,
|
||||
dealt_order_amount: defaultdict = defaultdict(float),
|
||||
):
|
||||
dealt_order_amount: Dict[str, float] = defaultdict(float),
|
||||
) -> Tuple[float, float, float]:
|
||||
"""
|
||||
Deal order when the actual transaction
|
||||
the results section in `Order` will be changed.
|
||||
@@ -371,9 +396,9 @@ class Exchange:
|
||||
:return: trade_val, trade_cost, trade_price
|
||||
"""
|
||||
# check order first.
|
||||
if self.check_order(order) is False:
|
||||
if not self.check_order(order):
|
||||
order.deal_amount = 0.0
|
||||
# using np.nan instead of None to make it more convenient to should the value in format string
|
||||
# using np.nan instead of None to make it more convenient to show the value in format string
|
||||
self.logger.debug(f"Order failed due to trading limitation: {order}")
|
||||
return 0.0, 0.0, np.nan
|
||||
|
||||
@@ -382,7 +407,9 @@ class Exchange:
|
||||
|
||||
# NOTE: order will be changed in this function
|
||||
trade_price, trade_val, trade_cost = self._calc_trade_info_by_order(
|
||||
order, trade_account.current_position if trade_account else position, dealt_order_amount
|
||||
order,
|
||||
trade_account.current_position if trade_account else position,
|
||||
dealt_order_amount,
|
||||
)
|
||||
if trade_val > 1e-5:
|
||||
# If the order can only be deal 0 value. Nothing to be updated
|
||||
@@ -396,35 +423,67 @@ class Exchange:
|
||||
|
||||
return trade_val, trade_cost, trade_price
|
||||
|
||||
def get_quote_info(self, stock_id, start_time, end_time, method="ts_data_last"):
|
||||
return self.quote.get_data(stock_id, start_time, end_time, method=method)
|
||||
def get_quote_info(
|
||||
self,
|
||||
stock_id: str,
|
||||
start_time: pd.Timestamp,
|
||||
end_time: pd.Timestamp,
|
||||
field: str,
|
||||
method: str = "ts_data_last",
|
||||
) -> Union[None, int, float, bool, IndexData]:
|
||||
return self.quote.get_data(stock_id, start_time, end_time, field=field, method=method)
|
||||
|
||||
def get_close(self, stock_id, start_time, end_time, method="ts_data_last"):
|
||||
def get_close(
|
||||
self,
|
||||
stock_id: str,
|
||||
start_time: pd.Timestamp,
|
||||
end_time: pd.Timestamp,
|
||||
method: str = "ts_data_last",
|
||||
) -> Union[None, int, float, bool, IndexData]:
|
||||
return self.quote.get_data(stock_id, start_time, end_time, field="$close", method=method)
|
||||
|
||||
def get_volume(self, stock_id, start_time, end_time, method="sum"):
|
||||
def get_volume(
|
||||
self,
|
||||
stock_id: str,
|
||||
start_time: pd.Timestamp,
|
||||
end_time: pd.Timestamp,
|
||||
method: Optional[str] = "sum",
|
||||
) -> float:
|
||||
"""get the total deal volume of stock with `stock_id` between the time interval [start_time, end_time)"""
|
||||
return self.quote.get_data(stock_id, start_time, end_time, field="$volume", method=method)
|
||||
return cast(float, self.quote.get_data(stock_id, start_time, end_time, field="$volume", method=method))
|
||||
|
||||
def get_deal_price(self, stock_id, start_time, end_time, direction: OrderDir, method="ts_data_last"):
|
||||
def get_deal_price(
|
||||
self,
|
||||
stock_id: str,
|
||||
start_time: pd.Timestamp,
|
||||
end_time: pd.Timestamp,
|
||||
direction: OrderDir,
|
||||
method: Optional[str] = "ts_data_last",
|
||||
) -> float:
|
||||
if direction == OrderDir.SELL:
|
||||
pstr = self.sell_price
|
||||
elif direction == OrderDir.BUY:
|
||||
pstr = self.buy_price
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
|
||||
deal_price = self.quote.get_data(stock_id, start_time, end_time, field=pstr, method=method)
|
||||
if method is not None and (deal_price is None or np.isnan(deal_price) or deal_price <= 1e-08):
|
||||
self.logger.warning(f"(stock_id:{stock_id}, trade_time:{(start_time, end_time)}, {pstr}): {deal_price}!!!")
|
||||
self.logger.warning(f"setting deal_price to close price")
|
||||
deal_price = self.get_close(stock_id, start_time, end_time, method)
|
||||
return deal_price
|
||||
return cast(float, deal_price)
|
||||
|
||||
def get_factor(self, stock_id, start_time, end_time) -> Union[float, None]:
|
||||
def get_factor(
|
||||
self,
|
||||
stock_id: str,
|
||||
start_time: pd.Timestamp,
|
||||
end_time: pd.Timestamp,
|
||||
) -> Optional[float]:
|
||||
"""
|
||||
Returns
|
||||
-------
|
||||
Union[float, None]:
|
||||
Optional[float]:
|
||||
`None`: if the stock is suspended `None` may be returned
|
||||
`float`: return factor if the factor exists
|
||||
"""
|
||||
@@ -434,11 +493,16 @@ class Exchange:
|
||||
return self.quote.get_data(stock_id, start_time, end_time, field="$factor", method="ts_data_last")
|
||||
|
||||
def generate_amount_position_from_weight_position(
|
||||
self, weight_position, cash, start_time, end_time, direction=OrderDir.BUY
|
||||
):
|
||||
self,
|
||||
weight_position: dict,
|
||||
cash: float,
|
||||
start_time: pd.Timestamp,
|
||||
end_time: pd.Timestamp,
|
||||
direction: OrderDir = OrderDir.BUY,
|
||||
) -> dict:
|
||||
"""
|
||||
The generate the target position according to the weight and the cash.
|
||||
NOTE: All the cash will assigned to the tadable stock.
|
||||
NOTE: All the cash will assigned to the tradable stock.
|
||||
Parameter:
|
||||
weight_position : dict {stock_id : weight}; allocate cash by weight_position
|
||||
among then, weight must be in this range: 0 < weight < 1
|
||||
@@ -451,15 +515,14 @@ class Exchange:
|
||||
|
||||
# calculate the total weight of tradable value
|
||||
tradable_weight = 0.0
|
||||
for stock_id in weight_position:
|
||||
for stock_id, wp in weight_position.items():
|
||||
if self.is_stock_tradable(stock_id=stock_id, start_time=start_time, end_time=end_time):
|
||||
# weight_position must be greater than 0 and less than 1
|
||||
if weight_position[stock_id] < 0 or weight_position[stock_id] > 1:
|
||||
if wp < 0 or wp > 1:
|
||||
raise ValueError(
|
||||
"weight_position is {}, "
|
||||
"weight_position is not in the range of (0, 1).".format(weight_position[stock_id])
|
||||
"weight_position is {}, " "weight_position is not in the range of (0, 1).".format(wp),
|
||||
)
|
||||
tradable_weight += weight_position[stock_id]
|
||||
tradable_weight += wp
|
||||
|
||||
if tradable_weight - 1.0 >= 1e-5:
|
||||
raise ValueError("tradable_weight is {}, can not greater than 1.".format(tradable_weight))
|
||||
@@ -467,19 +530,24 @@ class Exchange:
|
||||
amount_dict = {}
|
||||
for stock_id in weight_position:
|
||||
if weight_position[stock_id] > 0.0 and self.is_stock_tradable(
|
||||
stock_id=stock_id, start_time=start_time, end_time=end_time
|
||||
stock_id=stock_id,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
):
|
||||
amount_dict[stock_id] = (
|
||||
cash
|
||||
* weight_position[stock_id]
|
||||
/ tradable_weight
|
||||
// self.get_deal_price(
|
||||
stock_id=stock_id, start_time=start_time, end_time=end_time, direction=direction
|
||||
stock_id=stock_id,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
direction=direction,
|
||||
)
|
||||
)
|
||||
return amount_dict
|
||||
|
||||
def get_real_deal_amount(self, current_amount, target_amount, factor):
|
||||
def get_real_deal_amount(self, current_amount: float, target_amount: float, factor: float = None) -> float:
|
||||
"""
|
||||
Calculate the real adjust deal amount when considering the trading unit
|
||||
:param current_amount:
|
||||
@@ -501,7 +569,13 @@ class Exchange:
|
||||
deal_amount = self.round_amount_by_trade_unit(deal_amount, factor)
|
||||
return -deal_amount
|
||||
|
||||
def generate_order_for_target_amount_position(self, target_position, current_position, start_time, end_time):
|
||||
def generate_order_for_target_amount_position(
|
||||
self,
|
||||
target_position: dict,
|
||||
current_position: dict,
|
||||
start_time: pd.Timestamp,
|
||||
end_time: pd.Timestamp,
|
||||
) -> List[Order]:
|
||||
"""
|
||||
Note: some future information is used in this function
|
||||
Parameter:
|
||||
@@ -517,7 +591,8 @@ class Exchange:
|
||||
# three parts: kept stock_id, dropped stock_id, new stock_id
|
||||
# handle kept stock_id
|
||||
|
||||
# because the order of the set is not fixed, the trading order of the stock is different, so that the backtest results of the same parameter are different;
|
||||
# because the order of the set is not fixed, the trading order of the stock is different, so that the backtest
|
||||
# results of the same parameter are different;
|
||||
# so here we sort stock_id, and then randomly shuffle the order of stock_id
|
||||
# because the same random seed is used, the final stock_id order is fixed
|
||||
sorted_ids = sorted(set(list(current_position.keys()) + list(target_position.keys())))
|
||||
@@ -546,7 +621,7 @@ class Exchange:
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
factor=factor,
|
||||
)
|
||||
),
|
||||
)
|
||||
else:
|
||||
# sell stock
|
||||
@@ -558,14 +633,19 @@ class Exchange:
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
factor=factor,
|
||||
)
|
||||
),
|
||||
)
|
||||
# return order_list : buy + sell
|
||||
return sell_order_list + buy_order_list
|
||||
|
||||
def calculate_amount_position_value(
|
||||
self, amount_dict, start_time, end_time, only_tradable=False, direction=OrderDir.SELL
|
||||
):
|
||||
self,
|
||||
amount_dict: dict,
|
||||
start_time: pd.Timestamp,
|
||||
end_time: pd.Timestamp,
|
||||
only_tradable: bool = False,
|
||||
direction: OrderDir = OrderDir.SELL,
|
||||
) -> float:
|
||||
"""Parameter
|
||||
position : Position()
|
||||
amount_dict : {stock_id : amount}
|
||||
@@ -576,30 +656,44 @@ class Exchange:
|
||||
"""
|
||||
value = 0
|
||||
for stock_id in amount_dict:
|
||||
if (
|
||||
only_tradable is True
|
||||
and self.check_stock_suspended(stock_id=stock_id, start_time=start_time, end_time=end_time) is False
|
||||
and self.check_stock_limit(stock_id=stock_id, start_time=start_time, end_time=end_time) is False
|
||||
or only_tradable is False
|
||||
if not only_tradable or (
|
||||
not self.check_stock_suspended(stock_id=stock_id, start_time=start_time, end_time=end_time)
|
||||
and not self.check_stock_limit(stock_id=stock_id, start_time=start_time, end_time=end_time)
|
||||
):
|
||||
value += (
|
||||
self.get_deal_price(
|
||||
stock_id=stock_id, start_time=start_time, end_time=end_time, direction=direction
|
||||
stock_id=stock_id,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
direction=direction,
|
||||
)
|
||||
* amount_dict[stock_id]
|
||||
)
|
||||
return value
|
||||
|
||||
def _get_factor_or_raise_error(self, factor: float = None, stock_id: str = None, start_time=None, end_time=None):
|
||||
def _get_factor_or_raise_error(
|
||||
self,
|
||||
factor: float = None,
|
||||
stock_id: str = None,
|
||||
start_time: pd.Timestamp = None,
|
||||
end_time: pd.Timestamp = None,
|
||||
) -> float:
|
||||
"""Please refer to the docs of get_amount_of_trade_unit"""
|
||||
if factor is None:
|
||||
if stock_id is not None and start_time is not None and end_time is not None:
|
||||
factor = self.get_factor(stock_id=stock_id, start_time=start_time, end_time=end_time)
|
||||
else:
|
||||
raise ValueError(f"`factor` and (`stock_id`, `start_time`, `end_time`) can't both be None")
|
||||
assert factor is not None
|
||||
return factor
|
||||
|
||||
def get_amount_of_trade_unit(self, factor: float = None, stock_id: str = None, start_time=None, end_time=None):
|
||||
def get_amount_of_trade_unit(
|
||||
self,
|
||||
factor: float = None,
|
||||
stock_id: str = None,
|
||||
start_time: pd.Timestamp = None,
|
||||
end_time: pd.Timestamp = None,
|
||||
) -> Optional[float]:
|
||||
"""
|
||||
get the trade unit of amount based on **factor**
|
||||
the factor can be given directly or calculated in given time range and stock id.
|
||||
@@ -617,15 +711,23 @@ class Exchange:
|
||||
"""
|
||||
if not self.trade_w_adj_price and self.trade_unit is not None:
|
||||
factor = self._get_factor_or_raise_error(
|
||||
factor=factor, stock_id=stock_id, start_time=start_time, end_time=end_time
|
||||
factor=factor,
|
||||
stock_id=stock_id,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
)
|
||||
return self.trade_unit / factor
|
||||
else:
|
||||
return None
|
||||
|
||||
def round_amount_by_trade_unit(
|
||||
self, deal_amount, factor: float = None, stock_id: str = None, start_time=None, end_time=None
|
||||
):
|
||||
self,
|
||||
deal_amount: float,
|
||||
factor: float = None,
|
||||
stock_id: str = None,
|
||||
start_time: pd.Timestamp = None,
|
||||
end_time: pd.Timestamp = None,
|
||||
) -> float:
|
||||
"""Parameter
|
||||
Please refer to the docs of get_amount_of_trade_unit
|
||||
deal_amount : float, adjusted amount
|
||||
@@ -635,12 +737,15 @@ class Exchange:
|
||||
if not self.trade_w_adj_price and self.trade_unit is not None:
|
||||
# the minimal amount is 1. Add 0.1 for solving precision problem.
|
||||
factor = self._get_factor_or_raise_error(
|
||||
factor=factor, stock_id=stock_id, start_time=start_time, end_time=end_time
|
||||
factor=factor,
|
||||
stock_id=stock_id,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
)
|
||||
return (deal_amount * factor + 0.1) // self.trade_unit * self.trade_unit / factor
|
||||
return deal_amount
|
||||
|
||||
def _clip_amount_by_volume(self, order: Order, dealt_order_amount: dict) -> int:
|
||||
def _clip_amount_by_volume(self, order: Order, dealt_order_amount: dict) -> Optional[float]:
|
||||
"""parse the capacity limit string and return the actual amount of orders that can be executed.
|
||||
NOTE:
|
||||
this function will change the order.deal_amount **inplace**
|
||||
@@ -652,15 +757,12 @@ class Exchange:
|
||||
dealt_order_amount : dict
|
||||
:param dealt_order_amount: the dealt order amount dict with the format of {stock_id: float}
|
||||
"""
|
||||
if order.direction == Order.BUY:
|
||||
vol_limit = self.buy_vol_limit
|
||||
elif order.direction == Order.SELL:
|
||||
vol_limit = self.sell_vol_limit
|
||||
vol_limit = self.buy_vol_limit if order.direction == Order.BUY else self.sell_vol_limit
|
||||
|
||||
if vol_limit is None:
|
||||
return order.deal_amount
|
||||
|
||||
vol_limit_num = []
|
||||
vol_limit_num: List[float] = []
|
||||
for limit in vol_limit:
|
||||
assert isinstance(limit, tuple)
|
||||
if limit[0] == "current":
|
||||
@@ -671,7 +773,7 @@ class Exchange:
|
||||
field=limit[1],
|
||||
method="sum",
|
||||
)
|
||||
vol_limit_num.append(limit_value)
|
||||
vol_limit_num.append(cast(float, limit_value))
|
||||
elif limit[0] == "cum":
|
||||
limit_value = self.quote.get_data(
|
||||
order.stock_id,
|
||||
@@ -689,12 +791,14 @@ class Exchange:
|
||||
if vol_limit_min < orig_deal_amount:
|
||||
self.logger.debug(f"Order clipped due to volume limitation: {order}, {list(zip(vol_limit_num, vol_limit))}")
|
||||
|
||||
def _get_buy_amount_by_cash_limit(self, trade_price, cash, cost_ratio):
|
||||
return None
|
||||
|
||||
def _get_buy_amount_by_cash_limit(self, trade_price: float, cash: float, cost_ratio: float) -> float:
|
||||
"""return the real order amount after cash limit for buying.
|
||||
Parameters
|
||||
----------
|
||||
trade_price : float
|
||||
position : cash
|
||||
cash : float
|
||||
cost_ratio : float
|
||||
|
||||
Return
|
||||
@@ -702,7 +806,7 @@ class Exchange:
|
||||
float
|
||||
the real order amount after cash limit for buying.
|
||||
"""
|
||||
max_trade_amount = 0
|
||||
max_trade_amount = 0.0
|
||||
if cash >= self.min_cost:
|
||||
# critical_price means the stock transaction price when the service fee is equal to min_cost.
|
||||
critical_price = self.min_cost / cost_ratio + self.min_cost
|
||||
@@ -714,7 +818,12 @@ class Exchange:
|
||||
max_trade_amount = (cash - self.min_cost) / trade_price
|
||||
return max_trade_amount
|
||||
|
||||
def _calc_trade_info_by_order(self, order, position: Position, dealt_order_amount):
|
||||
def _calc_trade_info_by_order(
|
||||
self,
|
||||
order: Order,
|
||||
position: Optional[BasePosition],
|
||||
dealt_order_amount: dict,
|
||||
) -> Tuple[float, float, float]:
|
||||
"""
|
||||
Calculation of trade info
|
||||
**NOTE**: Order will be changed in this function
|
||||
@@ -753,7 +862,8 @@ class Exchange:
|
||||
if not np.isclose(order.deal_amount, current_amount):
|
||||
# when not selling last stock. rounding is necessary
|
||||
order.deal_amount = self.round_amount_by_trade_unit(
|
||||
min(current_amount, order.deal_amount), order.factor
|
||||
min(current_amount, order.deal_amount),
|
||||
order.factor,
|
||||
)
|
||||
|
||||
# in case of negative value of cash
|
||||
@@ -778,7 +888,8 @@ class Exchange:
|
||||
# The money is not enough
|
||||
max_buy_amount = self._get_buy_amount_by_cash_limit(trade_price, cash, cost_ratio)
|
||||
order.deal_amount = self.round_amount_by_trade_unit(
|
||||
min(max_buy_amount, order.deal_amount), order.factor
|
||||
min(max_buy_amount, order.deal_amount),
|
||||
order.factor,
|
||||
)
|
||||
self.logger.debug(f"Order clipped due to cash limitation: {order}")
|
||||
else:
|
||||
@@ -789,7 +900,7 @@ class Exchange:
|
||||
order.deal_amount = self.round_amount_by_trade_unit(order.deal_amount, order.factor)
|
||||
|
||||
else:
|
||||
raise NotImplementedError("order type {} error".format(order.type))
|
||||
raise NotImplementedError("order direction {} error".format(order.direction))
|
||||
|
||||
trade_val = order.deal_amount * trade_price
|
||||
trade_cost = max(trade_val * cost_ratio, self.min_cost)
|
||||
|
||||
@@ -1,23 +1,22 @@
|
||||
from abc import abstractclassmethod, abstractmethod
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from abc import abstractmethod
|
||||
from collections import defaultdict
|
||||
from types import GeneratorType
|
||||
from typing import Any, Dict, Generator, List, Tuple, Union, cast
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from qlib.backtest.account import Account
|
||||
from qlib.backtest.position import BasePosition
|
||||
from qlib.log import get_module_logger
|
||||
from types import GeneratorType
|
||||
from qlib.backtest.account import Account
|
||||
import warnings
|
||||
import pandas as pd
|
||||
from typing import List, Tuple, Union
|
||||
from collections import defaultdict
|
||||
|
||||
from qlib.backtest.report import Indicator
|
||||
|
||||
from .decision import EmptyTradeDecision, Order, BaseTradeDecision
|
||||
from .exchange import Exchange
|
||||
from .utils import TradeCalendarManager, CommonInfrastructure, LevelInfrastructure, get_start_end_idx
|
||||
|
||||
from ..utils import init_instance_by_config
|
||||
from ..utils.time import Freq
|
||||
from ..strategy.base import BaseStrategy
|
||||
from ..utils import init_instance_by_config
|
||||
from .decision import BaseTradeDecision, Order
|
||||
from .exchange import Exchange
|
||||
from .utils import CommonInfrastructure, LevelInfrastructure, TradeCalendarManager, get_start_end_idx
|
||||
|
||||
|
||||
class BaseExecutor:
|
||||
@@ -34,9 +33,9 @@ class BaseExecutor:
|
||||
track_data: bool = False,
|
||||
trade_exchange: Exchange = None,
|
||||
common_infra: CommonInfrastructure = None,
|
||||
settle_type=BasePosition.ST_NO,
|
||||
**kwargs,
|
||||
):
|
||||
settle_type: str = BasePosition.ST_NO,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
@@ -57,15 +56,21 @@ class BaseExecutor:
|
||||
- 'base_price': the based price than which the trading price is advanced, Optional, default by 'twap'
|
||||
- If 'base_price' is 'twap', the based price is the time weighted average price
|
||||
- If 'base_price' is 'vwap', the based price is the volume weighted average price
|
||||
- 'weight_method': weighted method when calculating total trading pa by different orders' pa in each step, optional, default by 'mean'
|
||||
- 'weight_method': weighted method when calculating total trading pa by different orders' pa in each
|
||||
step, optional, default by 'mean'
|
||||
- If 'weight_method' is 'mean', calculating mean value of different orders' pa
|
||||
- If 'weight_method' is 'amount_weighted', calculating amount weighted average value of different orders' pa
|
||||
- If 'weight_method' is 'value_weighted', calculating value weighted average value of different orders' pa
|
||||
- If 'weight_method' is 'amount_weighted', calculating amount weighted average value of different
|
||||
orders' pa
|
||||
- If 'weight_method' is 'value_weighted', calculating value weighted average value of different
|
||||
orders' pa
|
||||
- 'ffr_config': config for calculating fulfill rate(ffr), optional
|
||||
- 'weight_method': weighted method when calculating total trading ffr by different orders' ffr in each step, optional, default by 'mean'
|
||||
- 'weight_method': weighted method when calculating total trading ffr by different orders' ffr in each
|
||||
step, optional, default by 'mean'
|
||||
- If 'weight_method' is 'mean', calculating mean value of different orders' ffr
|
||||
- If 'weight_method' is 'amount_weighted', calculating amount weighted average value of different orders' ffr
|
||||
- If 'weight_method' is 'value_weighted', calculating value weighted average value of different orders' ffr
|
||||
- If 'weight_method' is 'amount_weighted', calculating amount weighted average value of different
|
||||
orders' ffr
|
||||
- If 'weight_method' is 'value_weighted', calculating value weighted average value of different
|
||||
orders' ffr
|
||||
Example:
|
||||
{
|
||||
'show_indicator': True,
|
||||
@@ -83,7 +88,8 @@ class BaseExecutor:
|
||||
whether to print trading info, by default False
|
||||
track_data : bool, optional
|
||||
whether to generate trade_decision, will be used when training rl agent
|
||||
- If `self.track_data` is true, when making data for training, the input `trade_decision` of `execute` will be generated by `collect_data`
|
||||
- If `self.track_data` is true, when making data for training, the input `trade_decision` of `execute` will
|
||||
be generated by `collect_data`
|
||||
- Else, `trade_decision` will not be generated
|
||||
|
||||
trade_exchange : Exchange
|
||||
@@ -115,10 +121,10 @@ class BaseExecutor:
|
||||
get_module_logger("BaseExecutor").warning(f"`common_infra` is not set for {self}")
|
||||
|
||||
# record deal order amount in one day
|
||||
self.dealt_order_amount = defaultdict(float)
|
||||
self.dealt_order_amount: Dict[str, float] = defaultdict(float)
|
||||
self.deal_day = None
|
||||
|
||||
def reset_common_infra(self, common_infra, copy_trade_account=False):
|
||||
def reset_common_infra(self, common_infra: CommonInfrastructure, copy_trade_account: bool = False) -> None:
|
||||
"""
|
||||
reset infrastructure for trading
|
||||
- reset trade_account
|
||||
@@ -129,14 +135,15 @@ class BaseExecutor:
|
||||
self.common_infra.update(common_infra)
|
||||
|
||||
if common_infra.has("trade_account"):
|
||||
if copy_trade_account:
|
||||
# NOTE: there is a trick in the code.
|
||||
# shallow copy is used instead of deepcopy.
|
||||
# 1. So positions are shared
|
||||
# 2. Others are not shared, so each level has it own metrics (portfolio and trading metrics)
|
||||
self.trade_account: Account = copy.copy(common_infra.get("trade_account"))
|
||||
else:
|
||||
self.trade_account = common_infra.get("trade_account")
|
||||
# NOTE: there is a trick in the code.
|
||||
# shallow copy is used instead of deepcopy.
|
||||
# 1. So positions are shared
|
||||
# 2. Others are not shared, so each level has it own metrics (portfolio and trading metrics)
|
||||
self.trade_account: Account = (
|
||||
copy.copy(common_infra.get("trade_account"))
|
||||
if copy_trade_account
|
||||
else common_infra.get("trade_account")
|
||||
)
|
||||
self.trade_account.reset(freq=self.time_per_step, port_metr_enabled=self.generate_portfolio_metrics)
|
||||
|
||||
@property
|
||||
@@ -152,7 +159,7 @@ class BaseExecutor:
|
||||
"""
|
||||
return self.level_infra.get("trade_calendar")
|
||||
|
||||
def reset(self, common_infra: CommonInfrastructure = None, **kwargs):
|
||||
def reset(self, common_infra: CommonInfrastructure = None, **kwargs: Any) -> None:
|
||||
"""
|
||||
- reset `start_time` and `end_time`, used in trade calendar
|
||||
- reset `common_infra`, used to reset `trade_account`, `trade_exchange`, .etc
|
||||
@@ -165,13 +172,13 @@ class BaseExecutor:
|
||||
if common_infra is not None:
|
||||
self.reset_common_infra(common_infra)
|
||||
|
||||
def get_level_infra(self):
|
||||
def get_level_infra(self) -> LevelInfrastructure:
|
||||
return self.level_infra
|
||||
|
||||
def finished(self):
|
||||
def finished(self) -> bool:
|
||||
return self.trade_calendar.finished()
|
||||
|
||||
def execute(self, trade_decision: BaseTradeDecision, level: int = 0):
|
||||
def execute(self, trade_decision: BaseTradeDecision, level: int = 0) -> List[object]:
|
||||
"""execute the trade decision and return the executed result
|
||||
|
||||
NOTE: this function is never used directly in the framework. Should we delete it?
|
||||
@@ -188,13 +195,17 @@ class BaseExecutor:
|
||||
execute_result : List[object]
|
||||
the executed result for trade decision
|
||||
"""
|
||||
return_value = {}
|
||||
return_value: dict = {}
|
||||
for _decision in self.collect_data(trade_decision, return_value=return_value, level=level):
|
||||
pass
|
||||
return return_value.get("execute_result")
|
||||
return cast(list, return_value.get("execute_result"))
|
||||
|
||||
@abstractclassmethod
|
||||
def _collect_data(cls, trade_decision: BaseTradeDecision, level: int = 0) -> Tuple[List[object], dict]:
|
||||
@abstractmethod
|
||||
def _collect_data(
|
||||
self,
|
||||
trade_decision: BaseTradeDecision,
|
||||
level: int = 0,
|
||||
) -> Union[Generator[Any, Any, Tuple[List[object], dict]], Tuple[List[object], dict]]:
|
||||
"""
|
||||
Please refer to the doc of collect_data
|
||||
The only difference between `_collect_data` and `collect_data` is that some common steps are moved into
|
||||
@@ -212,8 +223,11 @@ class BaseExecutor:
|
||||
"""
|
||||
|
||||
def collect_data(
|
||||
self, trade_decision: BaseTradeDecision, return_value: dict = None, level: int = 0
|
||||
) -> List[object]:
|
||||
self,
|
||||
trade_decision: BaseTradeDecision,
|
||||
return_value: dict = None,
|
||||
level: int = 0,
|
||||
) -> Generator[Any, Any, List[object]]:
|
||||
"""Generator for collecting the trade decision data for rl training
|
||||
|
||||
his function will make a step forward
|
||||
@@ -245,7 +259,7 @@ class BaseExecutor:
|
||||
if self.track_data:
|
||||
yield trade_decision
|
||||
|
||||
atomic = not issubclass(self.__class__, NestedExecutor) # issubclass(A, A) is True
|
||||
atomic = not issubclass(self.__class__, NestedExecutor) # issubclass(A, A) is True
|
||||
|
||||
if atomic and trade_decision.get_range_limit(default_value=None) is not None:
|
||||
raise ValueError("atomic executor doesn't support specify `range_limit`")
|
||||
@@ -256,7 +270,9 @@ class BaseExecutor:
|
||||
obj = self._collect_data(trade_decision=trade_decision, level=level)
|
||||
|
||||
if isinstance(obj, GeneratorType):
|
||||
res, kwargs = yield from obj
|
||||
yield_res = yield from obj
|
||||
assert isinstance(yield_res, tuple) and len(yield_res) == 2
|
||||
res, kwargs = yield_res
|
||||
else:
|
||||
# Some concrete executor don't have inner decisions
|
||||
res, kwargs = obj
|
||||
@@ -282,7 +298,7 @@ class BaseExecutor:
|
||||
return_value.update({"execute_result": res})
|
||||
return res
|
||||
|
||||
def get_all_executors(self):
|
||||
def get_all_executors(self) -> List[BaseExecutor]:
|
||||
"""get all executors"""
|
||||
return [self]
|
||||
|
||||
@@ -290,7 +306,8 @@ class BaseExecutor:
|
||||
class NestedExecutor(BaseExecutor):
|
||||
"""
|
||||
Nested Executor with inner strategy and executor
|
||||
- At each time `execute` is called, it will call the inner strategy and executor to execute the `trade_decision` in a higher frequency env.
|
||||
- At each time `execute` is called, it will call the inner strategy and executor to execute the `trade_decision`
|
||||
in a higher frequency env.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -307,8 +324,8 @@ class NestedExecutor(BaseExecutor):
|
||||
skip_empty_decision: bool = True,
|
||||
align_range_limit: bool = True,
|
||||
common_infra: CommonInfrastructure = None,
|
||||
**kwargs,
|
||||
):
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
@@ -326,10 +343,14 @@ class NestedExecutor(BaseExecutor):
|
||||
It is only for nested executor, because range_limit is given by outer strategy
|
||||
"""
|
||||
self.inner_executor: BaseExecutor = init_instance_by_config(
|
||||
inner_executor, common_infra=common_infra, accept_types=BaseExecutor
|
||||
inner_executor,
|
||||
common_infra=common_infra,
|
||||
accept_types=BaseExecutor,
|
||||
)
|
||||
self.inner_strategy: BaseStrategy = init_instance_by_config(
|
||||
inner_strategy, common_infra=common_infra, accept_types=BaseStrategy
|
||||
inner_strategy,
|
||||
common_infra=common_infra,
|
||||
accept_types=BaseStrategy,
|
||||
)
|
||||
|
||||
self._skip_empty_decision = skip_empty_decision
|
||||
@@ -347,10 +368,10 @@ class NestedExecutor(BaseExecutor):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def reset_common_infra(self, common_infra, copy_trade_account=False):
|
||||
def reset_common_infra(self, common_infra: CommonInfrastructure, copy_trade_account: bool = False) -> None:
|
||||
"""
|
||||
reset infrastructure for trading
|
||||
- reset inner_strategyand inner_executor common infra
|
||||
- reset inner_strategy and inner_executor common infra
|
||||
"""
|
||||
# NOTE: please refer to the docs of BaseExecutor.reset_common_infra for the meaning of `copy_trade_account`
|
||||
|
||||
@@ -361,7 +382,7 @@ class NestedExecutor(BaseExecutor):
|
||||
self.inner_executor.reset_common_infra(common_infra, copy_trade_account=True)
|
||||
self.inner_strategy.reset_common_infra(common_infra)
|
||||
|
||||
def _init_sub_trading(self, trade_decision):
|
||||
def _init_sub_trading(self, trade_decision: BaseTradeDecision) -> None:
|
||||
trade_start_time, trade_end_time = self.trade_calendar.get_step_time()
|
||||
self.inner_executor.reset(start_time=trade_start_time, end_time=trade_end_time)
|
||||
sub_level_infra = self.inner_executor.get_level_infra()
|
||||
@@ -371,14 +392,18 @@ class NestedExecutor(BaseExecutor):
|
||||
def _update_trade_decision(self, trade_decision: BaseTradeDecision) -> BaseTradeDecision:
|
||||
# outer strategy have chance to update decision each iterator
|
||||
updated_trade_decision = trade_decision.update(self.inner_executor.trade_calendar)
|
||||
if updated_trade_decision is not None:
|
||||
if updated_trade_decision is not None: # TODO: always is None for now?
|
||||
trade_decision = updated_trade_decision
|
||||
# NEW UPDATE
|
||||
# create a hook for inner strategy to update outer decision
|
||||
self.inner_strategy.alter_outer_trade_decision(trade_decision)
|
||||
return trade_decision
|
||||
|
||||
def _collect_data(self, trade_decision: BaseTradeDecision, level: int = 0):
|
||||
def _collect_data(
|
||||
self,
|
||||
trade_decision: BaseTradeDecision,
|
||||
level: int = 0,
|
||||
) -> Generator[Any, Any, Tuple[List[object], dict]]:
|
||||
execute_result = []
|
||||
inner_order_indicators = []
|
||||
decision_list = []
|
||||
@@ -393,8 +418,8 @@ class NestedExecutor(BaseExecutor):
|
||||
|
||||
if trade_decision.empty() and self._skip_empty_decision:
|
||||
# give one chance for outer strategy to update the strategy
|
||||
# - For updating some information in the sub executor(the strategy have no knowledge of the inner
|
||||
# executor when generating the decision)
|
||||
# - For updating some information in the sub executor (the strategy have no knowledge of the inner
|
||||
# executor when generating the decision)
|
||||
break
|
||||
|
||||
sub_cal: TradeCalendarManager = self.inner_executor.trade_calendar
|
||||
@@ -408,15 +433,19 @@ class NestedExecutor(BaseExecutor):
|
||||
|
||||
# NOTE: !!!!!
|
||||
# the two lines below is for a special case in RL
|
||||
# To solve the confliction below
|
||||
# - Normally, user will create a strategy and embed it into Qlib's executor and simulator interaction loop
|
||||
# For a _nested qlib example_, (Qlib Strategy) <=> (Qlib Executor[(inner Qlib Strategy) <=> (inner Qlib Executor)])
|
||||
# To solve the conflicts below
|
||||
# - Normally, user will create a strategy and embed it into Qlib's executor and simulator interaction
|
||||
# loop For a _nested qlib example_, (Qlib Strategy) <=> (Qlib Executor[(inner Qlib Strategy) <=>
|
||||
# (inner Qlib Executor)])
|
||||
# - However, RL-based framework has it's own script to run the loop
|
||||
# For an _RL learning example_, (RL Policy) <=> (RL Env[(inner Qlib Executor)])
|
||||
# To make it possible to run _nested qlib example_ and _RL learning example_ together, the solution below is proposed
|
||||
# - The entry script follow the example of _RL learning example_ to be compatible with all kinds of RL Framework
|
||||
# To make it possible to run _nested qlib example_ and _RL learning example_ together, the solution
|
||||
# below is proposed
|
||||
# - The entry script follow the example of _RL learning example_ to be compatible with all kinds of
|
||||
# RL Framework
|
||||
# - Each step of (RL Env) will make (inner Qlib Executor) one step forward
|
||||
# - (inner Qlib Strategy) is a proxy strategy, it will give the program control right to (RL Env) by `yield from` and wait for the action from the policy
|
||||
# - (inner Qlib Strategy) is a proxy strategy, it will give the program control right to (RL Env)
|
||||
# by `yield from` and wait for the action from the policy
|
||||
# So the two lines below is the implementation of yielding control rights
|
||||
if isinstance(res, GeneratorType):
|
||||
res = yield from res
|
||||
@@ -430,13 +459,15 @@ class NestedExecutor(BaseExecutor):
|
||||
|
||||
# NOTE: Trade Calendar will step forward in the follow line
|
||||
_inner_execute_result = yield from self.inner_executor.collect_data(
|
||||
trade_decision=_inner_trade_decision, level=level + 1
|
||||
trade_decision=_inner_trade_decision,
|
||||
level=level + 1,
|
||||
)
|
||||
assert isinstance(_inner_execute_result, list)
|
||||
self.post_inner_exe_step(_inner_execute_result)
|
||||
execute_result.extend(_inner_execute_result)
|
||||
|
||||
inner_order_indicators.append(
|
||||
self.inner_executor.trade_account.get_trade_indicator().get_order_indicator(raw=True)
|
||||
self.inner_executor.trade_account.get_trade_indicator().get_order_indicator(raw=True),
|
||||
)
|
||||
else:
|
||||
# do nothing and just step forward
|
||||
@@ -444,7 +475,7 @@ class NestedExecutor(BaseExecutor):
|
||||
|
||||
return execute_result, {"inner_order_indicators": inner_order_indicators, "decision_list": decision_list}
|
||||
|
||||
def post_inner_exe_step(self, inner_exe_res):
|
||||
def post_inner_exe_step(self, inner_exe_res: List[object]) -> None:
|
||||
"""
|
||||
A hook for doing sth after each step of inner strategy
|
||||
|
||||
@@ -453,13 +484,24 @@ class NestedExecutor(BaseExecutor):
|
||||
inner_exe_res :
|
||||
the execution result of inner task
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_all_executors(self):
|
||||
def get_all_executors(self) -> List[BaseExecutor]:
|
||||
"""get all executors, including self and inner_executor.get_all_executors()"""
|
||||
return [self, *self.inner_executor.get_all_executors()]
|
||||
|
||||
|
||||
def _retrieve_orders_from_decision(trade_decision: BaseTradeDecision) -> List[Order]:
|
||||
"""
|
||||
IDE-friendly helper function.
|
||||
"""
|
||||
decisions = trade_decision.get_decision()
|
||||
orders: List[Order] = []
|
||||
for decision in decisions:
|
||||
assert isinstance(decision, Order)
|
||||
orders.append(decision)
|
||||
return orders
|
||||
|
||||
|
||||
class SimulatorExecutor(BaseExecutor):
|
||||
"""Executor that simulate the true market"""
|
||||
|
||||
@@ -468,10 +510,10 @@ class SimulatorExecutor(BaseExecutor):
|
||||
|
||||
# available trade_types
|
||||
TT_SERIAL = "serial"
|
||||
## The orders will be executed serially in a sequence
|
||||
# The orders will be executed serially in a sequence
|
||||
# In each trading step, it is possible that users sell instruments first and use the money to buy new instruments
|
||||
TT_PARAL = "parallel"
|
||||
## The orders will be executed parallelly
|
||||
# The orders will be executed in parallel
|
||||
# In each trading step, if users try to sell instruments first and buy new instruments with money, failure will
|
||||
# occur
|
||||
|
||||
@@ -486,8 +528,8 @@ class SimulatorExecutor(BaseExecutor):
|
||||
track_data: bool = False,
|
||||
common_infra: CommonInfrastructure = None,
|
||||
trade_type: str = TT_SERIAL,
|
||||
**kwargs,
|
||||
):
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
@@ -521,7 +563,7 @@ class SimulatorExecutor(BaseExecutor):
|
||||
List[Order]:
|
||||
get a list orders according to `self.trade_type`
|
||||
"""
|
||||
orders = trade_decision.get_decision()
|
||||
orders = _retrieve_orders_from_decision(trade_decision)
|
||||
|
||||
if self.trade_type == self.TT_SERIAL:
|
||||
# Orders will be traded in a parallel way
|
||||
@@ -529,15 +571,15 @@ class SimulatorExecutor(BaseExecutor):
|
||||
elif self.trade_type == self.TT_PARAL:
|
||||
# NOTE: !!!!!!!
|
||||
# Assumption: there will not be orders in different trading direction in a single step of a strategy !!!!
|
||||
# The parallel trading failure will be caused only by the confliction of money
|
||||
# Therefore, make the buying go first will make sure the confliction happen.
|
||||
# The parallel trading failure will be caused only by the conflicts of money
|
||||
# Therefore, make the buying go first will make sure the conflicts happen.
|
||||
# It equals to parallel trading after sorting the order by direction
|
||||
order_it = sorted(orders, key=lambda order: -order.direction)
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
return order_it
|
||||
|
||||
def _update_dealt_order_amount(self, order):
|
||||
def _update_dealt_order_amount(self, order: Order) -> None:
|
||||
"""update date and dealt order amount in the day."""
|
||||
|
||||
now_deal_day = self.trade_calendar.get_step_time()[0].floor(freq="D")
|
||||
@@ -546,10 +588,9 @@ class SimulatorExecutor(BaseExecutor):
|
||||
self.deal_day = now_deal_day
|
||||
self.dealt_order_amount[order.stock_id] += order.deal_amount
|
||||
|
||||
def _collect_data(self, trade_decision: BaseTradeDecision, level: int = 0):
|
||||
|
||||
def _collect_data(self, trade_decision: BaseTradeDecision, level: int = 0) -> Tuple[List[object], dict]:
|
||||
trade_start_time, _ = self.trade_calendar.get_step_time()
|
||||
execute_result = []
|
||||
execute_result: list = []
|
||||
|
||||
for order in self._get_order_iterator(trade_decision):
|
||||
# execute the order.
|
||||
@@ -563,7 +604,8 @@ class SimulatorExecutor(BaseExecutor):
|
||||
self._update_dealt_order_amount(order)
|
||||
if self.verbose:
|
||||
print(
|
||||
"[I {:%Y-%m-%d %H:%M:%S}]: {} {}, price {:.2f}, amount {}, deal_amount {}, factor {}, value {:.2f}, cash {:.2f}.".format(
|
||||
"[I {:%Y-%m-%d %H:%M:%S}]: {} {}, price {:.2f}, amount {}, deal_amount {}, factor {}, "
|
||||
"value {:.2f}, cash {:.2f}.".format(
|
||||
trade_start_time,
|
||||
"sell" if order.direction == Order.SELL else "buy",
|
||||
order.stock_id,
|
||||
@@ -573,6 +615,6 @@ class SimulatorExecutor(BaseExecutor):
|
||||
order.factor,
|
||||
trade_val,
|
||||
self.trade_account.get_cash(),
|
||||
)
|
||||
),
|
||||
)
|
||||
return execute_result, {"trade_info": execute_result}
|
||||
|
||||
@@ -1,24 +1,27 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from functools import lru_cache
|
||||
import logging
|
||||
from typing import List, Text, Union, Callable, Iterable, Dict
|
||||
from collections import OrderedDict
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import logging
|
||||
from collections import OrderedDict
|
||||
from functools import lru_cache
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Text, Union, cast
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
import qlib.utils.index_data as idd
|
||||
|
||||
from ..log import get_module_logger
|
||||
from ..utils.index_data import IndexData, SingleData
|
||||
from ..utils.resam import resam_ts_data, ts_data_last
|
||||
from ..log import get_module_logger
|
||||
from ..utils.time import is_single_value, Freq
|
||||
import qlib.utils.index_data as idd
|
||||
from ..utils.time import Freq, is_single_value
|
||||
|
||||
|
||||
class BaseQuote:
|
||||
def __init__(self, quote_df: pd.DataFrame, freq):
|
||||
def __init__(self, quote_df: pd.DataFrame, freq: str) -> None:
|
||||
self.logger = get_module_logger("online operator", level=logging.INFO)
|
||||
|
||||
def get_all_stock(self) -> Iterable:
|
||||
@@ -38,7 +41,7 @@ class BaseQuote:
|
||||
start_time: Union[pd.Timestamp, str],
|
||||
end_time: Union[pd.Timestamp, str],
|
||||
field: Union[str],
|
||||
method: Union[str, None] = None,
|
||||
method: Optional[str] = None,
|
||||
) -> Union[None, int, float, bool, IndexData]:
|
||||
"""get the specific field of stock data during start time and end_time,
|
||||
and apply method to the data.
|
||||
@@ -98,7 +101,7 @@ class BaseQuote:
|
||||
|
||||
|
||||
class PandasQuote(BaseQuote):
|
||||
def __init__(self, quote_df: pd.DataFrame, freq):
|
||||
def __init__(self, quote_df: pd.DataFrame, freq: str) -> None:
|
||||
super().__init__(quote_df=quote_df, freq=freq)
|
||||
quote_dict = {}
|
||||
for stock_id, stock_val in quote_df.groupby(level="instrument"):
|
||||
@@ -123,7 +126,7 @@ class PandasQuote(BaseQuote):
|
||||
|
||||
|
||||
class NumpyQuote(BaseQuote):
|
||||
def __init__(self, quote_df: pd.DataFrame, freq, region="cn"):
|
||||
def __init__(self, quote_df: pd.DataFrame, freq: str, region: str = "cn") -> None:
|
||||
"""NumpyQuote
|
||||
|
||||
Parameters
|
||||
@@ -177,7 +180,8 @@ class NumpyQuote(BaseQuote):
|
||||
data = self._agg_data(data, method)
|
||||
return data
|
||||
|
||||
def _agg_data(self, data: IndexData, method):
|
||||
@staticmethod
|
||||
def _agg_data(data: IndexData, method: str) -> Union[IndexData, np.ndarray, None]:
|
||||
"""Agg data by specific method."""
|
||||
# FIXME: why not call the method of data directly?
|
||||
if method == "sum":
|
||||
@@ -223,31 +227,31 @@ class BaseSingleMetric:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `__init__` method")
|
||||
|
||||
def __add__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric":
|
||||
def __add__(self, other: Union[BaseSingleMetric, int, float]) -> BaseSingleMetric:
|
||||
raise NotImplementedError(f"Please implement the `__add__` method")
|
||||
|
||||
def __radd__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric":
|
||||
def __radd__(self, other: Union[BaseSingleMetric, int, float]) -> BaseSingleMetric:
|
||||
return self + other
|
||||
|
||||
def __sub__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric":
|
||||
def __sub__(self, other: Union[BaseSingleMetric, int, float]) -> BaseSingleMetric:
|
||||
raise NotImplementedError(f"Please implement the `__sub__` method")
|
||||
|
||||
def __rsub__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric":
|
||||
def __rsub__(self, other: Union[BaseSingleMetric, int, float]) -> BaseSingleMetric:
|
||||
raise NotImplementedError(f"Please implement the `__rsub__` method")
|
||||
|
||||
def __mul__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric":
|
||||
def __mul__(self, other: Union[BaseSingleMetric, int, float]) -> BaseSingleMetric:
|
||||
raise NotImplementedError(f"Please implement the `__mul__` method")
|
||||
|
||||
def __truediv__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric":
|
||||
def __truediv__(self, other: Union[BaseSingleMetric, int, float]) -> BaseSingleMetric:
|
||||
raise NotImplementedError(f"Please implement the `__truediv__` method")
|
||||
|
||||
def __eq__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric":
|
||||
def __eq__(self, other: object) -> BaseSingleMetric:
|
||||
raise NotImplementedError(f"Please implement the `__eq__` method")
|
||||
|
||||
def __gt__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric":
|
||||
def __gt__(self, other: Union[BaseSingleMetric, int, float]) -> BaseSingleMetric:
|
||||
raise NotImplementedError(f"Please implement the `__gt__` method")
|
||||
|
||||
def __lt__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric":
|
||||
def __lt__(self, other: Union[BaseSingleMetric, int, float]) -> BaseSingleMetric:
|
||||
raise NotImplementedError(f"Please implement the `__lt__` method")
|
||||
|
||||
def __len__(self) -> int:
|
||||
@@ -264,7 +268,7 @@ class BaseSingleMetric:
|
||||
|
||||
raise NotImplementedError(f"Please implement the `count` method")
|
||||
|
||||
def abs(self) -> "BaseSingleMetric":
|
||||
def abs(self) -> BaseSingleMetric:
|
||||
raise NotImplementedError(f"Please implement the `abs` method")
|
||||
|
||||
@property
|
||||
@@ -273,18 +277,18 @@ class BaseSingleMetric:
|
||||
|
||||
raise NotImplementedError(f"Please implement the `empty` method")
|
||||
|
||||
def add(self, other: "BaseSingleMetric", fill_value: float = None) -> "BaseSingleMetric":
|
||||
def add(self, other: BaseSingleMetric, fill_value: float = None) -> BaseSingleMetric:
|
||||
"""Replace np.NaN with fill_value in two metrics and add them."""
|
||||
|
||||
raise NotImplementedError(f"Please implement the `add` method")
|
||||
|
||||
def replace(self, replace_dict: dict) -> "BaseSingleMetric":
|
||||
def replace(self, replace_dict: dict) -> BaseSingleMetric:
|
||||
"""Replace the value of metric according to replace_dict."""
|
||||
|
||||
raise NotImplementedError(f"Please implement the `replace` method")
|
||||
|
||||
def apply(self, func: dict) -> "BaseSingleMetric":
|
||||
"""Replace the value of metric with func(metric).
|
||||
def apply(self, func: Callable) -> BaseSingleMetric:
|
||||
"""Replace the value of metric with func (metric).
|
||||
Currently, the func is only qlib/backtest/order/Order.parse_dir.
|
||||
"""
|
||||
|
||||
@@ -303,11 +307,11 @@ class BaseOrderIndicator:
|
||||
to inherit the BaseSingleMetric.
|
||||
"""
|
||||
|
||||
def __init__(self, data):
|
||||
self.data = data
|
||||
def __init__(self):
|
||||
self.data = {} # will be created in the subclass
|
||||
self.logger = get_module_logger("online operator")
|
||||
|
||||
def assign(self, col: str, metric: Union[dict, pd.Series]):
|
||||
def assign(self, col: str, metric: Union[dict, pd.Series]) -> None:
|
||||
"""assign one metric.
|
||||
|
||||
Parameters
|
||||
@@ -327,7 +331,7 @@ class BaseOrderIndicator:
|
||||
|
||||
raise NotImplementedError(f"Please implement the 'assign' method")
|
||||
|
||||
def transfer(self, func: Callable, new_col: str = None) -> Union[None, BaseSingleMetric]:
|
||||
def transfer(self, func: Callable, new_col: str = None) -> Optional[BaseSingleMetric]:
|
||||
"""compute new metric with existing metrics.
|
||||
|
||||
Parameters
|
||||
@@ -351,6 +355,7 @@ class BaseOrderIndicator:
|
||||
tmp_metric = func(**func_kwargs)
|
||||
if new_col is not None:
|
||||
self.data[new_col] = tmp_metric
|
||||
return None
|
||||
else:
|
||||
return tmp_metric
|
||||
|
||||
@@ -371,7 +376,7 @@ class BaseOrderIndicator:
|
||||
|
||||
raise NotImplementedError(f"Please implement the 'get_metric_series' method")
|
||||
|
||||
def get_index_data(self, metric) -> SingleData:
|
||||
def get_index_data(self, metric: str) -> SingleData:
|
||||
"""get one metric with the format of SingleData
|
||||
|
||||
Parameters
|
||||
@@ -388,7 +393,12 @@ class BaseOrderIndicator:
|
||||
raise NotImplementedError(f"Please implement the 'get_index_data' method")
|
||||
|
||||
@staticmethod
|
||||
def sum_all_indicators(order_indicator, indicators: list, metrics: Union[str, List[str]], fill_value: float = None):
|
||||
def sum_all_indicators(
|
||||
order_indicator: BaseOrderIndicator,
|
||||
indicators: List[BaseOrderIndicator],
|
||||
metrics: Union[str, List[str]],
|
||||
fill_value: float = 0,
|
||||
) -> None:
|
||||
"""sum indicators with the same metrics.
|
||||
and assign to the order_indicator(BaseOrderIndicator).
|
||||
NOTE: indicators could be a empty list when orders in lower level all fail.
|
||||
@@ -526,16 +536,17 @@ class PandasSingleMetric(SingleMetric):
|
||||
def index(self):
|
||||
return list(self.metric.index)
|
||||
|
||||
def add(self, other, fill_value=None):
|
||||
def add(self, other: BaseSingleMetric, fill_value: float = None) -> PandasSingleMetric:
|
||||
other = cast(PandasSingleMetric, other)
|
||||
return self.__class__(self.metric.add(other.metric, fill_value=fill_value))
|
||||
|
||||
def replace(self, replace_dict: dict):
|
||||
def replace(self, replace_dict: dict) -> PandasSingleMetric:
|
||||
return self.__class__(self.metric.replace(replace_dict))
|
||||
|
||||
def apply(self, func: Callable):
|
||||
def apply(self, func: Callable) -> PandasSingleMetric:
|
||||
return self.__class__(self.metric.apply(func))
|
||||
|
||||
def reindex(self, index, fill_value):
|
||||
def reindex(self, index: Any, fill_value: float) -> PandasSingleMetric:
|
||||
return self.__class__(self.metric.reindex(index, fill_value=fill_value))
|
||||
|
||||
def __repr__(self):
|
||||
@@ -549,13 +560,14 @@ class PandasOrderIndicator(BaseOrderIndicator):
|
||||
Str is the name of metric.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super(PandasOrderIndicator, self).__init__()
|
||||
self.data: Dict[str, PandasSingleMetric] = OrderedDict()
|
||||
|
||||
def assign(self, col: str, metric: Union[dict, pd.Series]):
|
||||
def assign(self, col: str, metric: Union[dict, pd.Series]) -> None:
|
||||
self.data[col] = PandasSingleMetric(metric)
|
||||
|
||||
def get_index_data(self, metric):
|
||||
def get_index_data(self, metric: str) -> SingleData:
|
||||
if metric in self.data:
|
||||
return idd.SingleData(self.data[metric].metric)
|
||||
else:
|
||||
@@ -571,7 +583,12 @@ class PandasOrderIndicator(BaseOrderIndicator):
|
||||
return {k: v.metric for k, v in self.data.items()}
|
||||
|
||||
@staticmethod
|
||||
def sum_all_indicators(order_indicator, indicators: list, metrics: Union[str, List[str]], fill_value=0):
|
||||
def sum_all_indicators(
|
||||
order_indicator: BaseOrderIndicator,
|
||||
indicators: List[BaseOrderIndicator],
|
||||
metrics: Union[str, List[str]],
|
||||
fill_value: float = 0,
|
||||
) -> None:
|
||||
if isinstance(metrics, str):
|
||||
metrics = [metrics]
|
||||
for metric in metrics:
|
||||
@@ -591,13 +608,14 @@ class NumpyOrderIndicator(BaseOrderIndicator):
|
||||
Str is the name of metric.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super(NumpyOrderIndicator, self).__init__()
|
||||
self.data: Dict[str, SingleData] = OrderedDict()
|
||||
|
||||
def assign(self, col: str, metric: dict):
|
||||
def assign(self, col: str, metric: dict) -> None:
|
||||
self.data[col] = idd.SingleData(metric)
|
||||
|
||||
def get_index_data(self, metric):
|
||||
def get_index_data(self, metric: str) -> SingleData:
|
||||
if metric in self.data:
|
||||
return self.data[metric]
|
||||
else:
|
||||
@@ -613,21 +631,27 @@ class NumpyOrderIndicator(BaseOrderIndicator):
|
||||
return tmp_metric_dict
|
||||
|
||||
@staticmethod
|
||||
def sum_all_indicators(order_indicator, indicators: list, metrics: Union[str, List[str]], fill_value=0):
|
||||
def sum_all_indicators(
|
||||
order_indicator: BaseOrderIndicator,
|
||||
indicators: List[BaseOrderIndicator],
|
||||
metrics: Union[str, List[str]],
|
||||
fill_value: float = 0,
|
||||
) -> None:
|
||||
# get all index(stock_id)
|
||||
stocks = set()
|
||||
stock_set: set = set()
|
||||
for indicator in indicators:
|
||||
# set(np.ndarray.tolist()) is faster than set(np.ndarray)
|
||||
stocks = stocks | set(indicator.data[metrics[0]].index.tolist())
|
||||
stocks = list(stocks)
|
||||
stocks.sort()
|
||||
stock_set = stock_set | set(indicator.data[metrics[0]].index.tolist())
|
||||
stocks = sorted(list(stock_set))
|
||||
|
||||
# add metric by index
|
||||
if isinstance(metrics, str):
|
||||
metrics = [metrics]
|
||||
for metric in metrics:
|
||||
order_indicator.data[metric] = idd.sum_by_index(
|
||||
[indicator.data[metric] for indicator in indicators], stocks, fill_value
|
||||
[indicator.data[metric] for indicator in indicators],
|
||||
stocks,
|
||||
fill_value,
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
|
||||
@@ -2,26 +2,28 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
import copy
|
||||
import pathlib
|
||||
from typing import Dict, List, Union
|
||||
|
||||
import pandas as pd
|
||||
from datetime import timedelta
|
||||
import numpy as np
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from .decision import Order
|
||||
from ..data.data import D
|
||||
from .decision import Order
|
||||
|
||||
|
||||
class BasePosition:
|
||||
"""
|
||||
The Position want to maintain the position like a dictionary
|
||||
The Position wants to maintain the position like a dictionary
|
||||
Please refer to the `Position` class for the position
|
||||
"""
|
||||
|
||||
def __init__(self, *args, cash=0.0, **kwargs):
|
||||
def __init__(self, *args: Any, cash: float = 0.0, **kwargs: Any) -> None:
|
||||
self._settle_type = self.ST_NO
|
||||
self.position: dict = {}
|
||||
|
||||
def fill_stock_value(self, start_time: Union[str, pd.Timestamp], freq: str, last_days: int = 30) -> None:
|
||||
pass
|
||||
|
||||
def skip_update(self) -> bool:
|
||||
"""
|
||||
@@ -51,7 +53,7 @@ class BasePosition:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `check_stock` method")
|
||||
|
||||
def update_order(self, order: Order, trade_val: float, cost: float, trade_price: float):
|
||||
def update_order(self, order: Order, trade_val: float, cost: float, trade_price: float) -> None:
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
@@ -66,7 +68,7 @@ class BasePosition:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `update_order` method")
|
||||
|
||||
def update_stock_price(self, stock_id, price: float):
|
||||
def update_stock_price(self, stock_id: str, price: float) -> None:
|
||||
"""
|
||||
Updating the latest price of the order
|
||||
The useful when clearing balance at each bar end
|
||||
@@ -91,13 +93,16 @@ class BasePosition:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `calculate_stock_value` method")
|
||||
|
||||
def get_stock_list(self) -> List:
|
||||
def calculate_value(self) -> float:
|
||||
raise NotImplementedError(f"Please implement the `calculate_value` method")
|
||||
|
||||
def get_stock_list(self) -> List[str]:
|
||||
"""
|
||||
Get the list of stocks in the position.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `get_stock_list` method")
|
||||
|
||||
def get_stock_price(self, code) -> float:
|
||||
def get_stock_price(self, code: str) -> float:
|
||||
"""
|
||||
get the latest price of the stock
|
||||
|
||||
@@ -108,7 +113,7 @@ class BasePosition:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `get_stock_price` method")
|
||||
|
||||
def get_stock_amount(self, code) -> float:
|
||||
def get_stock_amount(self, code: str) -> float:
|
||||
"""
|
||||
get the amount of the stock
|
||||
|
||||
@@ -126,18 +131,20 @@ class BasePosition:
|
||||
|
||||
def get_cash(self, include_settle: bool = False) -> float:
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
include_settle:
|
||||
will the unsettled(delayed) cash included
|
||||
Default: not include those unavailable cash
|
||||
|
||||
Returns
|
||||
-------
|
||||
float:
|
||||
the available(tradable) cash in position
|
||||
include_settle:
|
||||
will the unsettled(delayed) cash included
|
||||
Default: not include those unavailable cash
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `get_cash` method")
|
||||
|
||||
def get_stock_amount_dict(self) -> Dict:
|
||||
def get_stock_amount_dict(self) -> dict:
|
||||
"""
|
||||
generate stock amount dict {stock_id : amount of stock}
|
||||
|
||||
@@ -148,7 +155,7 @@ class BasePosition:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `get_stock_amount_dict` method")
|
||||
|
||||
def get_stock_weight_dict(self, only_stock: bool = False) -> Dict:
|
||||
def get_stock_weight_dict(self, only_stock: bool = False) -> dict:
|
||||
"""
|
||||
generate stock weight dict {stock_id : value weight of stock in the position}
|
||||
it is meaningful in the beginning or the end of each trade step
|
||||
@@ -167,7 +174,7 @@ class BasePosition:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `get_stock_weight_dict` method")
|
||||
|
||||
def add_count_all(self, bar):
|
||||
def add_count_all(self, bar: str) -> None:
|
||||
"""
|
||||
Will be called at the end of each bar on each level
|
||||
|
||||
@@ -178,24 +185,19 @@ class BasePosition:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `add_count_all` method")
|
||||
|
||||
def update_weight_all(self):
|
||||
def update_weight_all(self) -> None:
|
||||
"""
|
||||
Updating the position weight;
|
||||
|
||||
# TODO: this function is a little weird. The weight data in the position is in a wrong state after dealing order
|
||||
# and before updating weight.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
bar :
|
||||
The level to be updated
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `add_count_all` method")
|
||||
|
||||
ST_CASH = "cash"
|
||||
ST_NO = None
|
||||
ST_NO = "None" # String is more typehint friendly than None
|
||||
|
||||
def settle_start(self, settle_type: str):
|
||||
def settle_start(self, settle_type: str) -> None:
|
||||
"""
|
||||
settlement start
|
||||
It will act like start and commit a transaction
|
||||
@@ -212,21 +214,16 @@ class BasePosition:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `settle_conf` method")
|
||||
|
||||
def settle_commit(self):
|
||||
def settle_commit(self) -> None:
|
||||
"""
|
||||
settlement commit
|
||||
|
||||
Parameters
|
||||
----------
|
||||
settle_type : str
|
||||
please refer to the documents of Executor
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `settle_commit` method")
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return self.__dict__.__str__()
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return self.__dict__.__repr__()
|
||||
|
||||
|
||||
@@ -244,13 +241,11 @@ class Position(BasePosition):
|
||||
}
|
||||
"""
|
||||
|
||||
def __init__(self, cash: float = 0, position_dict: Dict[str, Dict[str, float]] = {}):
|
||||
def __init__(self, cash: float = 0, position_dict: Dict[str, Union[Dict[str, float], float]] = {}) -> None:
|
||||
"""Init position by cash and position_dict.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
start_time :
|
||||
the start time of backtest. It's for filling the initial value of stocks.
|
||||
cash : float, optional
|
||||
initial cash in account, by default 0
|
||||
position_dict : Dict[
|
||||
@@ -270,9 +265,9 @@ class Position(BasePosition):
|
||||
# Otherwise the initial value
|
||||
self.init_cash = cash
|
||||
self.position = position_dict.copy()
|
||||
for stock in self.position:
|
||||
if isinstance(self.position[stock], int):
|
||||
self.position[stock] = {"amount": self.position[stock]}
|
||||
for stock, value in self.position.items():
|
||||
if isinstance(value, int):
|
||||
self.position[stock] = {"amount": value}
|
||||
self.position["cash"] = cash
|
||||
|
||||
# If the stock price information is missing, the account value will not be calculated temporarily
|
||||
@@ -281,21 +276,23 @@ class Position(BasePosition):
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
def fill_stock_value(self, start_time: Union[str, pd.Timestamp], freq: str, last_days: int = 30):
|
||||
def fill_stock_value(self, start_time: Union[str, pd.Timestamp], freq: str, last_days: int = 30) -> None:
|
||||
"""fill the stock value by the close price of latest last_days from qlib.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
start_time :
|
||||
the start time of backtest.
|
||||
freq : str
|
||||
Frequency
|
||||
last_days : int, optional
|
||||
the days to get the latest close price, by default 30.
|
||||
"""
|
||||
stock_list = []
|
||||
for stock in self.position:
|
||||
if not isinstance(self.position[stock], dict):
|
||||
for stock, value in self.position.items():
|
||||
if not isinstance(value, dict):
|
||||
continue
|
||||
if ("price" not in self.position[stock]) or (self.position[stock]["price"] is None):
|
||||
if value.get("price", None) is None:
|
||||
stock_list.append(stock)
|
||||
|
||||
if len(stock_list) == 0:
|
||||
@@ -306,7 +303,12 @@ class Position(BasePosition):
|
||||
price_end_time = start_time
|
||||
price_start_time = start_time - timedelta(days=last_days)
|
||||
price_df = D.features(
|
||||
stock_list, ["$close"], price_start_time, price_end_time, freq=freq, disk_cache=True
|
||||
stock_list,
|
||||
["$close"],
|
||||
price_start_time,
|
||||
price_end_time,
|
||||
freq=freq,
|
||||
disk_cache=True,
|
||||
).dropna()
|
||||
price_dict = price_df.groupby(["instrument"]).tail(1).reset_index(level=1, drop=True)["$close"].to_dict()
|
||||
|
||||
@@ -318,7 +320,7 @@ class Position(BasePosition):
|
||||
self.position[stock]["price"] = price_dict[stock]
|
||||
self.position["now_account_value"] = self.calculate_value()
|
||||
|
||||
def _init_stock(self, stock_id, amount, price=None):
|
||||
def _init_stock(self, stock_id: str, amount: float, price: float = None) -> None:
|
||||
"""
|
||||
initialization the stock in current position
|
||||
|
||||
@@ -336,7 +338,7 @@ class Position(BasePosition):
|
||||
self.position[stock_id]["price"] = price
|
||||
self.position[stock_id]["weight"] = 0 # update the weight in the end of the trade date
|
||||
|
||||
def _buy_stock(self, stock_id, trade_val, cost, trade_price):
|
||||
def _buy_stock(self, stock_id: str, trade_val: float, cost: float, trade_price: float) -> None:
|
||||
trade_amount = trade_val / trade_price
|
||||
if stock_id not in self.position:
|
||||
self._init_stock(stock_id=stock_id, amount=trade_amount, price=trade_price)
|
||||
@@ -346,15 +348,16 @@ class Position(BasePosition):
|
||||
|
||||
self.position["cash"] -= trade_val + cost
|
||||
|
||||
def _sell_stock(self, stock_id, trade_val, cost, trade_price):
|
||||
def _sell_stock(self, stock_id: str, trade_val: float, cost: float, trade_price: float) -> None:
|
||||
trade_amount = trade_val / trade_price
|
||||
if stock_id not in self.position:
|
||||
raise KeyError("{} not in current position".format(stock_id))
|
||||
else:
|
||||
if np.isclose(self.position[stock_id]["amount"], trade_amount):
|
||||
# Selling all the stocks
|
||||
# we use np.isclose instead of abs(<the final amount>) <= 1e-5 because `np.isclose` consider both ralative amount and absolute amount
|
||||
# Using abs(<the final amount>) <= 1e-5 will result in error when the amount is large
|
||||
# we use np.isclose instead of abs(<the final amount>) <= 1e-5 because `np.isclose` consider both
|
||||
# relative amount and absolute amount
|
||||
# Using abs(<the final amount>) <= 1e-5 will result in error when the amount is large
|
||||
self._del_stock(stock_id)
|
||||
else:
|
||||
# decrease the amount of stock
|
||||
@@ -362,7 +365,11 @@ class Position(BasePosition):
|
||||
# check if to delete
|
||||
if self.position[stock_id]["amount"] < -1e-5:
|
||||
raise ValueError(
|
||||
"only have {} {}, require {}".format(self.position[stock_id]["amount"], stock_id, trade_amount)
|
||||
"only have {} {}, require {}".format(
|
||||
self.position[stock_id]["amount"] + trade_amount,
|
||||
stock_id,
|
||||
trade_amount,
|
||||
),
|
||||
)
|
||||
|
||||
new_cash = trade_val - cost
|
||||
@@ -373,13 +380,13 @@ class Position(BasePosition):
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
|
||||
def _del_stock(self, stock_id):
|
||||
def _del_stock(self, stock_id: str) -> None:
|
||||
del self.position[stock_id]
|
||||
|
||||
def check_stock(self, stock_id):
|
||||
def check_stock(self, stock_id: str) -> bool:
|
||||
return stock_id in self.position
|
||||
|
||||
def update_order(self, order, trade_val, cost, trade_price):
|
||||
def update_order(self, order: Order, trade_val: float, cost: float, trade_price: float) -> None:
|
||||
# handle order, order is a order class, defined in exchange.py
|
||||
if order.direction == Order.BUY:
|
||||
# BUY
|
||||
@@ -390,54 +397,54 @@ class Position(BasePosition):
|
||||
else:
|
||||
raise NotImplementedError("do not support order direction {}".format(order.direction))
|
||||
|
||||
def update_stock_price(self, stock_id, price):
|
||||
def update_stock_price(self, stock_id: str, price: float) -> None:
|
||||
self.position[stock_id]["price"] = price
|
||||
|
||||
def update_stock_count(self, stock_id, bar, count):
|
||||
def update_stock_count(self, stock_id: str, bar: str, count: float) -> None: # TODO: check type of `bar`
|
||||
self.position[stock_id][f"count_{bar}"] = count
|
||||
|
||||
def update_stock_weight(self, stock_id, weight):
|
||||
def update_stock_weight(self, stock_id: str, weight: float) -> None:
|
||||
self.position[stock_id]["weight"] = weight
|
||||
|
||||
def calculate_stock_value(self):
|
||||
def calculate_stock_value(self) -> float:
|
||||
stock_list = self.get_stock_list()
|
||||
value = 0
|
||||
for stock_id in stock_list:
|
||||
value += self.position[stock_id]["amount"] * self.position[stock_id]["price"]
|
||||
return value
|
||||
|
||||
def calculate_value(self):
|
||||
def calculate_value(self) -> float:
|
||||
value = self.calculate_stock_value()
|
||||
value += self.position["cash"] + self.position.get("cash_delay", 0.0)
|
||||
return value
|
||||
|
||||
def get_stock_list(self):
|
||||
def get_stock_list(self) -> List[str]:
|
||||
stock_list = list(set(self.position.keys()) - {"cash", "now_account_value", "cash_delay"})
|
||||
return stock_list
|
||||
|
||||
def get_stock_price(self, code):
|
||||
def get_stock_price(self, code: str) -> float:
|
||||
return self.position[code]["price"]
|
||||
|
||||
def get_stock_amount(self, code):
|
||||
def get_stock_amount(self, code: str) -> float:
|
||||
return self.position[code]["amount"] if code in self.position else 0
|
||||
|
||||
def get_stock_count(self, code, bar):
|
||||
def get_stock_count(self, code: str, bar: str) -> float:
|
||||
"""the days the account has been hold, it may be used in some special strategies"""
|
||||
if f"count_{bar}" in self.position[code]:
|
||||
return self.position[code][f"count_{bar}"]
|
||||
else:
|
||||
return 0
|
||||
|
||||
def get_stock_weight(self, code):
|
||||
def get_stock_weight(self, code: str) -> float:
|
||||
return self.position[code]["weight"]
|
||||
|
||||
def get_cash(self, include_settle=False):
|
||||
def get_cash(self, include_settle: bool = False) -> float:
|
||||
cash = self.position["cash"]
|
||||
if include_settle:
|
||||
cash += self.position.get("cash_delay", 0.0)
|
||||
return cash
|
||||
|
||||
def get_stock_amount_dict(self):
|
||||
def get_stock_amount_dict(self) -> dict:
|
||||
"""generate stock amount dict {stock_id : amount of stock}"""
|
||||
d = {}
|
||||
stock_list = self.get_stock_list()
|
||||
@@ -445,7 +452,7 @@ class Position(BasePosition):
|
||||
d[stock_code] = self.get_stock_amount(code=stock_code)
|
||||
return d
|
||||
|
||||
def get_stock_weight_dict(self, only_stock=False):
|
||||
def get_stock_weight_dict(self, only_stock: bool = False) -> dict:
|
||||
"""get_stock_weight_dict
|
||||
generate stock weight dict {stock_id : value weight of stock in the position}
|
||||
it is meaningful in the beginning or the end of each trade date
|
||||
@@ -463,7 +470,7 @@ class Position(BasePosition):
|
||||
d[stock_code] = self.position[stock_code]["amount"] * self.position[stock_code]["price"] / position_value
|
||||
return d
|
||||
|
||||
def add_count_all(self, bar):
|
||||
def add_count_all(self, bar: str) -> None:
|
||||
stock_list = self.get_stock_list()
|
||||
for code in stock_list:
|
||||
if f"count_{bar}" in self.position[code]:
|
||||
@@ -471,18 +478,18 @@ class Position(BasePosition):
|
||||
else:
|
||||
self.position[code][f"count_{bar}"] = 1
|
||||
|
||||
def update_weight_all(self):
|
||||
def update_weight_all(self) -> None:
|
||||
weight_dict = self.get_stock_weight_dict()
|
||||
for stock_code, weight in weight_dict.items():
|
||||
self.update_stock_weight(stock_code, weight)
|
||||
|
||||
def settle_start(self, settle_type):
|
||||
def settle_start(self, settle_type: str) -> None:
|
||||
assert self._settle_type == self.ST_NO, "Currently, settlement can't be nested!!!!!"
|
||||
self._settle_type = settle_type
|
||||
if settle_type == self.ST_CASH:
|
||||
self.position["cash_delay"] = 0.0
|
||||
|
||||
def settle_commit(self):
|
||||
def settle_commit(self) -> None:
|
||||
if self._settle_type != self.ST_NO:
|
||||
if self._settle_type == self.ST_CASH:
|
||||
self.position["cash"] += self.position["cash_delay"]
|
||||
@@ -507,10 +514,10 @@ class InfPosition(BasePosition):
|
||||
# InfPosition always have any stocks
|
||||
return True
|
||||
|
||||
def update_order(self, order: Order, trade_val: float, cost: float, trade_price: float):
|
||||
def update_order(self, order: Order, trade_val: float, cost: float, trade_price: float) -> None:
|
||||
pass
|
||||
|
||||
def update_stock_price(self, stock_id, price: float):
|
||||
def update_stock_price(self, stock_id: str, price: float) -> None:
|
||||
pass
|
||||
|
||||
def calculate_stock_value(self) -> float:
|
||||
@@ -522,33 +529,36 @@ class InfPosition(BasePosition):
|
||||
"""
|
||||
return np.inf
|
||||
|
||||
def get_stock_list(self) -> List:
|
||||
def calculate_value(self) -> float:
|
||||
raise NotImplementedError(f"InfPosition doesn't support calculating value")
|
||||
|
||||
def get_stock_list(self) -> List[str]:
|
||||
raise NotImplementedError(f"InfPosition doesn't support stock list position")
|
||||
|
||||
def get_stock_price(self, code) -> float:
|
||||
def get_stock_price(self, code: str) -> float:
|
||||
"""the price of the inf position is meaningless"""
|
||||
return np.nan
|
||||
|
||||
def get_stock_amount(self, code) -> float:
|
||||
def get_stock_amount(self, code: str) -> float:
|
||||
return np.inf
|
||||
|
||||
def get_cash(self, include_settle=False) -> float:
|
||||
def get_cash(self, include_settle: bool = False) -> float:
|
||||
return np.inf
|
||||
|
||||
def get_stock_amount_dict(self) -> Dict:
|
||||
def get_stock_amount_dict(self) -> dict:
|
||||
raise NotImplementedError(f"InfPosition doesn't support get_stock_amount_dict")
|
||||
|
||||
def get_stock_weight_dict(self, only_stock: bool) -> Dict:
|
||||
def get_stock_weight_dict(self, only_stock: bool = False) -> dict:
|
||||
raise NotImplementedError(f"InfPosition doesn't support get_stock_weight_dict")
|
||||
|
||||
def add_count_all(self, bar):
|
||||
def add_count_all(self, bar: str) -> None:
|
||||
raise NotImplementedError(f"InfPosition doesn't support add_count_all")
|
||||
|
||||
def update_weight_all(self):
|
||||
def update_weight_all(self) -> None:
|
||||
raise NotImplementedError(f"InfPosition doesn't support update_weight_all")
|
||||
|
||||
def settle_start(self, settle_type: str):
|
||||
def settle_start(self, settle_type: str) -> None:
|
||||
pass
|
||||
|
||||
def settle_commit(self):
|
||||
def settle_commit(self) -> None:
|
||||
pass
|
||||
|
||||
@@ -4,14 +4,16 @@
|
||||
This module is not well maintained.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from .position import Position
|
||||
from ..data import D
|
||||
from ..config import C
|
||||
import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from ..config import C
|
||||
from ..data import D
|
||||
from .position import Position
|
||||
|
||||
|
||||
def get_benchmark_weight(
|
||||
bench,
|
||||
@@ -214,7 +216,9 @@ def get_stock_group(stock_group_field_df, bench_stock_weight_df, group_method, g
|
||||
for idx, row in (~bench_stock_weight_df.isna()).iterrows():
|
||||
bench_values = stock_group_field_df.loc[idx, row[row].index]
|
||||
new_stock_group_df.loc[idx] = get_daily_bin_group(
|
||||
bench_values, stock_group_field_df.loc[idx], group_n=group_n
|
||||
bench_values,
|
||||
stock_group_field_df.loc[idx],
|
||||
group_n=group_n,
|
||||
)
|
||||
return new_stock_group_df
|
||||
|
||||
@@ -315,7 +319,7 @@ def brinson_pa(
|
||||
# The excess profit from the interaction of assets allocation and stocks selection
|
||||
"RIN": Q4 - Q3 - Q2 + Q1,
|
||||
"RTotal": Q4 - Q1, # The totoal excess profit
|
||||
}
|
||||
},
|
||||
),
|
||||
{
|
||||
"port_group_ret": port_group_ret_df,
|
||||
|
||||
@@ -2,22 +2,20 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
from collections import OrderedDict
|
||||
import pathlib
|
||||
from typing import Dict, List, Tuple, Union
|
||||
from collections import OrderedDict
|
||||
from typing import Any, Dict, List, Optional, Text, Tuple, Type, Union, cast
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from qlib.backtest.exchange import Exchange
|
||||
from .decision import IdxTradeRange
|
||||
import qlib.utils.index_data as idd
|
||||
from qlib.backtest.decision import BaseTradeDecision, Order, OrderDir
|
||||
from qlib.backtest.utils import TradeCalendarManager
|
||||
from .high_performance_ds import BaseOrderIndicator, PandasOrderIndicator, NumpyOrderIndicator, SingleMetric
|
||||
from ..data import D
|
||||
from qlib.backtest.exchange import Exchange
|
||||
|
||||
from ..tests.config import CSI300_BENCH
|
||||
from ..utils.resam import get_higher_eq_freq_feature, resam_ts_data
|
||||
import qlib.utils.index_data as idd
|
||||
from .high_performance_ds import BaseOrderIndicator, BaseSingleMetric, NumpyOrderIndicator
|
||||
|
||||
|
||||
class PortfolioMetrics:
|
||||
@@ -40,7 +38,7 @@ class PortfolioMetrics:
|
||||
update report
|
||||
"""
|
||||
|
||||
def __init__(self, freq: str = "day", benchmark_config: dict = {}):
|
||||
def __init__(self, freq: str = "day", benchmark_config: dict = {}) -> None:
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
@@ -51,13 +49,17 @@ class PortfolioMetrics:
|
||||
- benchmark : Union[str, list, pd.Series]
|
||||
- If `benchmark` is pd.Series, `index` is trading date; the value T is the change from T-1 to T.
|
||||
example:
|
||||
print(D.features(D.instruments('csi500'), ['$close/Ref($close, 1)-1'])['$close/Ref($close, 1)-1'].head())
|
||||
print(
|
||||
D.features(D.instruments('csi500'),
|
||||
['$close/Ref($close, 1)-1'])['$close/Ref($close, 1)-1'].head()
|
||||
)
|
||||
2017-01-04 0.011693
|
||||
2017-01-05 0.000721
|
||||
2017-01-06 -0.004322
|
||||
2017-01-09 0.006874
|
||||
2017-01-10 -0.003350
|
||||
- If `benchmark` is list, will use the daily average change of the stock pool in the list as the 'bench'.
|
||||
- If `benchmark` is list, will use the daily average change of the stock pool in the list as the
|
||||
'bench'.
|
||||
- If `benchmark` is str, will use the daily change as the 'bench'.
|
||||
benchmark code, default is SH000300 CSI300
|
||||
- start_time : Union[str, pd.Timestamp], optional
|
||||
@@ -72,25 +74,26 @@ class PortfolioMetrics:
|
||||
self.init_vars()
|
||||
self.init_bench(freq=freq, benchmark_config=benchmark_config)
|
||||
|
||||
def init_vars(self):
|
||||
self.accounts = OrderedDict() # account position value for each trade time
|
||||
self.returns = OrderedDict() # daily return rate for each trade time
|
||||
self.total_turnovers = OrderedDict() # total turnover for each trade time
|
||||
self.turnovers = OrderedDict() # turnover for each trade time
|
||||
self.total_costs = OrderedDict() # total trade cost for each trade time
|
||||
self.costs = OrderedDict() # trade cost rate for each trade time
|
||||
self.values = OrderedDict() # value for each trade time
|
||||
self.cashes = OrderedDict()
|
||||
self.benches = OrderedDict()
|
||||
self.latest_pm_time = None # pd.TimeStamp
|
||||
def init_vars(self) -> None:
|
||||
self.accounts: dict = OrderedDict() # account position value for each trade time
|
||||
self.returns: dict = OrderedDict() # daily return rate for each trade time
|
||||
self.total_turnovers: dict = OrderedDict() # total turnover for each trade time
|
||||
self.turnovers: dict = OrderedDict() # turnover for each trade time
|
||||
self.total_costs: dict = OrderedDict() # total trade cost for each trade time
|
||||
self.costs: dict = OrderedDict() # trade cost rate for each trade time
|
||||
self.values: dict = OrderedDict() # value for each trade time
|
||||
self.cashes: dict = OrderedDict()
|
||||
self.benches: dict = OrderedDict()
|
||||
self.latest_pm_time: Optional[pd.TimeStamp] = None
|
||||
|
||||
def init_bench(self, freq=None, benchmark_config=None):
|
||||
def init_bench(self, freq: str = None, benchmark_config: dict = None) -> None:
|
||||
if freq is not None:
|
||||
self.freq = freq
|
||||
self.benchmark_config = benchmark_config
|
||||
self.bench = self._cal_benchmark(self.benchmark_config, self.freq)
|
||||
|
||||
def _cal_benchmark(self, benchmark_config, freq):
|
||||
@staticmethod
|
||||
def _cal_benchmark(benchmark_config: Optional[dict], freq: str) -> Optional[pd.Series]:
|
||||
if benchmark_config is None:
|
||||
return None
|
||||
benchmark = benchmark_config.get("benchmark", CSI300_BENCH)
|
||||
@@ -112,7 +115,12 @@ class PortfolioMetrics:
|
||||
raise ValueError(f"The benchmark {_codes} does not exist. Please provide the right benchmark")
|
||||
return _temp_result.groupby(level="datetime")[_temp_result.columns.tolist()[0]].mean().fillna(0)
|
||||
|
||||
def _sample_benchmark(self, bench, trade_start_time, trade_end_time):
|
||||
def _sample_benchmark(
|
||||
self,
|
||||
bench: pd.Series,
|
||||
trade_start_time: Union[str, pd.Timestamp],
|
||||
trade_end_time: Union[str, pd.Timestamp],
|
||||
) -> Optional[float]:
|
||||
if self.bench is None:
|
||||
return None
|
||||
|
||||
@@ -122,35 +130,35 @@ class PortfolioMetrics:
|
||||
_ret = resam_ts_data(bench, trade_start_time, trade_end_time, method=cal_change)
|
||||
return 0.0 if _ret is None else _ret - 1
|
||||
|
||||
def is_empty(self):
|
||||
def is_empty(self) -> bool:
|
||||
return len(self.accounts) == 0
|
||||
|
||||
def get_latest_date(self):
|
||||
def get_latest_date(self) -> pd.Timestamp:
|
||||
return self.latest_pm_time
|
||||
|
||||
def get_latest_account_value(self):
|
||||
def get_latest_account_value(self) -> float:
|
||||
return self.accounts[self.latest_pm_time]
|
||||
|
||||
def get_latest_total_cost(self):
|
||||
def get_latest_total_cost(self) -> Any:
|
||||
return self.total_costs[self.latest_pm_time]
|
||||
|
||||
def get_latest_total_turnover(self):
|
||||
def get_latest_total_turnover(self) -> Any:
|
||||
return self.total_turnovers[self.latest_pm_time]
|
||||
|
||||
def update_portfolio_metrics_record(
|
||||
self,
|
||||
trade_start_time=None,
|
||||
trade_end_time=None,
|
||||
account_value=None,
|
||||
cash=None,
|
||||
return_rate=None,
|
||||
total_turnover=None,
|
||||
turnover_rate=None,
|
||||
total_cost=None,
|
||||
cost_rate=None,
|
||||
stock_value=None,
|
||||
bench_value=None,
|
||||
):
|
||||
trade_start_time: Union[str, pd.Timestamp] = None,
|
||||
trade_end_time: Union[str, pd.Timestamp] = None,
|
||||
account_value: float = None,
|
||||
cash: float = None,
|
||||
return_rate: float = None,
|
||||
total_turnover: float = None,
|
||||
turnover_rate: float = None,
|
||||
total_cost: float = None,
|
||||
cost_rate: float = None,
|
||||
stock_value: float = None,
|
||||
bench_value: float = None,
|
||||
) -> None:
|
||||
# check data
|
||||
if None in [
|
||||
trade_start_time,
|
||||
@@ -164,7 +172,8 @@ class PortfolioMetrics:
|
||||
stock_value,
|
||||
]:
|
||||
raise ValueError(
|
||||
"None in [trade_start_time, account_value, cash, return_rate, total_turnover, turnover_rate, total_cost, cost_rate, stock_value]"
|
||||
"None in [trade_start_time, account_value, cash, return_rate, total_turnover, turnover_rate, "
|
||||
"total_cost, cost_rate, stock_value]",
|
||||
)
|
||||
|
||||
if trade_end_time is None and bench_value is None:
|
||||
@@ -186,7 +195,7 @@ class PortfolioMetrics:
|
||||
self.latest_pm_time = trade_start_time
|
||||
# finish pm update in each step
|
||||
|
||||
def generate_portfolio_metrics_dataframe(self):
|
||||
def generate_portfolio_metrics_dataframe(self) -> pd.DataFrame:
|
||||
pm = pd.DataFrame()
|
||||
pm["account"] = pd.Series(self.accounts)
|
||||
pm["return"] = pd.Series(self.returns)
|
||||
@@ -200,19 +209,18 @@ class PortfolioMetrics:
|
||||
pm.index.name = "datetime"
|
||||
return pm
|
||||
|
||||
def save_portfolio_metrics(self, path):
|
||||
def save_portfolio_metrics(self, path: str) -> None:
|
||||
r = self.generate_portfolio_metrics_dataframe()
|
||||
r.to_csv(path)
|
||||
|
||||
def load_portfolio_metrics(self, path):
|
||||
def load_portfolio_metrics(self, path: str) -> None:
|
||||
"""load pm from a file
|
||||
should have format like
|
||||
columns = ['account', 'return', 'total_turnover', 'turnover', 'cost', 'total_cost', 'value', 'cash', 'bench']
|
||||
:param
|
||||
path: str/ pathlib.Path()
|
||||
"""
|
||||
path = pathlib.Path(path)
|
||||
with path.open("rb") as f:
|
||||
with pathlib.Path(path).open("rb") as f:
|
||||
r = pd.read_csv(f, index_col=0)
|
||||
r.index = pd.DatetimeIndex(r.index)
|
||||
|
||||
@@ -262,30 +270,30 @@ class Indicator:
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, order_indicator_cls=NumpyOrderIndicator):
|
||||
def __init__(self, order_indicator_cls: Type[BaseOrderIndicator] = NumpyOrderIndicator) -> None:
|
||||
self.order_indicator_cls = order_indicator_cls
|
||||
|
||||
# order indicator is metrics for a single order for a specific step
|
||||
self.order_indicator_his = OrderedDict()
|
||||
self.order_indicator_his: dict = OrderedDict()
|
||||
self.order_indicator: BaseOrderIndicator = self.order_indicator_cls()
|
||||
|
||||
# trade indicator is metrics for all orders for a specific step
|
||||
self.trade_indicator_his = OrderedDict()
|
||||
self.trade_indicator: Dict[str, float] = OrderedDict()
|
||||
self.trade_indicator_his: dict = OrderedDict()
|
||||
self.trade_indicator: Dict[str, Optional[BaseSingleMetric]] = OrderedDict()
|
||||
|
||||
self._trade_calendar = None
|
||||
|
||||
# def reset(self, trade_calendar: TradeCalendarManager):
|
||||
def reset(self):
|
||||
self.order_indicator: BaseOrderIndicator = self.order_indicator_cls()
|
||||
def reset(self) -> None:
|
||||
self.order_indicator = self.order_indicator_cls()
|
||||
self.trade_indicator = OrderedDict()
|
||||
# self._trade_calendar = trade_calendar
|
||||
|
||||
def record(self, trade_start_time):
|
||||
def record(self, trade_start_time: Union[str, pd.Timestamp]) -> None:
|
||||
self.order_indicator_his[trade_start_time] = self.get_order_indicator()
|
||||
self.trade_indicator_his[trade_start_time] = self.get_trade_indicator()
|
||||
|
||||
def _update_order_trade_info(self, trade_info: list):
|
||||
def _update_order_trade_info(self, trade_info: List[Tuple[Order, float, float, float]]) -> None:
|
||||
amount = dict()
|
||||
deal_amount = dict()
|
||||
trade_price = dict()
|
||||
@@ -314,7 +322,7 @@ class Indicator:
|
||||
self.order_indicator.assign("trade_dir", trade_dir)
|
||||
self.order_indicator.assign("pa", pa)
|
||||
|
||||
def _update_order_fulfill_rate(self):
|
||||
def _update_order_fulfill_rate(self) -> None:
|
||||
def func(deal_amount, amount):
|
||||
# deal_amount is np.NaN or None when there is no inner decision. So full fill rate is 0.
|
||||
tmp_deal_amount = deal_amount.reindex(amount.index, 0)
|
||||
@@ -323,11 +331,11 @@ class Indicator:
|
||||
|
||||
self.order_indicator.transfer(func, "ffr")
|
||||
|
||||
def update_order_indicators(self, trade_info: list):
|
||||
def update_order_indicators(self, trade_info: List[Tuple[Order, float, float, float]]) -> None:
|
||||
self._update_order_trade_info(trade_info=trade_info)
|
||||
self._update_order_fulfill_rate()
|
||||
|
||||
def _agg_order_trade_info(self, inner_order_indicators: List[Dict[str, pd.Series]]):
|
||||
def _agg_order_trade_info(self, inner_order_indicators: List[BaseOrderIndicator]) -> None:
|
||||
# calculate total trade amount with each inner order indicator.
|
||||
def trade_amount_func(deal_amount, trade_price):
|
||||
return deal_amount * trade_price
|
||||
@@ -338,7 +346,10 @@ class Indicator:
|
||||
# sum inner order indicators with same metric.
|
||||
all_metric = ["inner_amount", "deal_amount", "trade_price", "trade_value", "trade_cost", "trade_dir"]
|
||||
self.order_indicator_cls.sum_all_indicators(
|
||||
self.order_indicator, inner_order_indicators, all_metric, fill_value=0
|
||||
self.order_indicator,
|
||||
inner_order_indicators,
|
||||
all_metric,
|
||||
fill_value=0,
|
||||
)
|
||||
|
||||
def func(trade_price, deal_amount):
|
||||
@@ -353,9 +364,9 @@ class Indicator:
|
||||
|
||||
self.order_indicator.transfer(func_apply, "trade_dir")
|
||||
|
||||
def _update_trade_amount(self, outer_trade_decision: BaseTradeDecision):
|
||||
def _update_trade_amount(self, outer_trade_decision: BaseTradeDecision) -> None:
|
||||
# NOTE: these indicator is designed for order execution, so the
|
||||
decision: List[Order] = outer_trade_decision.get_decision()
|
||||
decision: List[Order] = cast(List[Order], outer_trade_decision.get_decision())
|
||||
if len(decision) == 0:
|
||||
self.order_indicator.assign("amount", {})
|
||||
else:
|
||||
@@ -370,7 +381,7 @@ class Indicator:
|
||||
decision: BaseTradeDecision,
|
||||
trade_exchange: Exchange,
|
||||
pa_config: dict = {},
|
||||
):
|
||||
) -> Tuple[Optional[float], Optional[float]]:
|
||||
"""
|
||||
Get the base volume and price information
|
||||
All the base price values are rooted from this function
|
||||
@@ -381,12 +392,17 @@ class Indicator:
|
||||
|
||||
if decision.trade_range is not None:
|
||||
trade_start_time, trade_end_time = decision.trade_range.clip_time_range(
|
||||
start_time=trade_start_time, end_time=trade_end_time
|
||||
start_time=trade_start_time,
|
||||
end_time=trade_end_time,
|
||||
)
|
||||
|
||||
if price == "deal_price":
|
||||
price_s = trade_exchange.get_deal_price(
|
||||
inst, trade_start_time, trade_end_time, direction=direction, method=None
|
||||
inst,
|
||||
trade_start_time,
|
||||
trade_end_time,
|
||||
direction=direction,
|
||||
method=None,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
@@ -405,31 +421,35 @@ class Indicator:
|
||||
# NOTE: there are some zeros in the trading price. These cases are known meaningless
|
||||
# for aligning the previous logic, remove it.
|
||||
# remove zero and negative values.
|
||||
price_s = price_s.loc[(price_s > 1e-08).data.astype(np.bool)]
|
||||
assert isinstance(price_s, idd.SingleData)
|
||||
price_s = price_s.loc[(price_s > 1e-08).data.astype(bool)]
|
||||
# NOTE ~(price_s < 1e-08) is different from price_s >= 1e-8
|
||||
# ~(np.NaN < 1e-8) -> ~(False) -> True
|
||||
|
||||
assert isinstance(price_s, idd.SingleData)
|
||||
if agg == "vwap":
|
||||
volume_s = trade_exchange.get_volume(inst, trade_start_time, trade_end_time, method=None)
|
||||
if isinstance(volume_s, (int, float, np.number)):
|
||||
volume_s = idd.SingleData(volume_s, [trade_start_time])
|
||||
assert isinstance(volume_s, idd.SingleData)
|
||||
volume_s = volume_s.reindex(price_s.index)
|
||||
elif agg == "twap":
|
||||
volume_s = idd.SingleData(1, price_s.index)
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
|
||||
assert isinstance(volume_s, idd.SingleData)
|
||||
base_volume = volume_s.sum()
|
||||
base_price = (price_s * volume_s).sum() / base_volume
|
||||
return base_price, base_volume
|
||||
|
||||
def _agg_base_price(
|
||||
self,
|
||||
inner_order_indicators: List[Dict[str, Union[SingleMetric, idd.SingleData]]],
|
||||
inner_order_indicators: List[BaseOrderIndicator],
|
||||
decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]],
|
||||
trade_exchange: Exchange,
|
||||
pa_config: dict = {},
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
# NOTE:!!!!
|
||||
# Strong assumption!!!!!!
|
||||
@@ -437,7 +457,7 @@ class Indicator:
|
||||
|
||||
Parameters
|
||||
----------
|
||||
inner_order_indicators : List[Dict[str, pd.Series]]
|
||||
inner_order_indicators : List[BaseOrderIndicator]
|
||||
the indicators of account of inner executor
|
||||
decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]],
|
||||
a list of decisions according to inner_order_indicators
|
||||
@@ -482,14 +502,17 @@ class Indicator:
|
||||
bv_new = idd.SingleData(bv_new)
|
||||
bp_all.append(bp_new)
|
||||
bv_all.append(bv_new)
|
||||
bp_all = idd.concat(bp_all, axis=1)
|
||||
bv_all = idd.concat(bv_all, axis=1)
|
||||
bp_all_multi_data = idd.concat(bp_all, axis=1)
|
||||
bv_all_multi_data = idd.concat(bv_all, axis=1)
|
||||
|
||||
base_volume = bv_all.sum(axis=1)
|
||||
base_volume = bv_all_multi_data.sum(axis=1)
|
||||
self.order_indicator.assign("base_volume", base_volume.to_dict())
|
||||
self.order_indicator.assign("base_price", ((bp_all * bv_all).sum(axis=1) / base_volume).to_dict())
|
||||
self.order_indicator.assign(
|
||||
"base_price",
|
||||
((bp_all_multi_data * bv_all_multi_data).sum(axis=1) / base_volume).to_dict(),
|
||||
)
|
||||
|
||||
def _agg_order_price_advantage(self):
|
||||
def _agg_order_price_advantage(self) -> None:
|
||||
def if_empty_func(trade_price):
|
||||
return trade_price.empty
|
||||
|
||||
@@ -506,12 +529,12 @@ class Indicator:
|
||||
|
||||
def agg_order_indicators(
|
||||
self,
|
||||
inner_order_indicators: List[Dict[str, pd.Series]],
|
||||
inner_order_indicators: List[BaseOrderIndicator],
|
||||
decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]],
|
||||
outer_trade_decision: BaseTradeDecision,
|
||||
trade_exchange: Exchange,
|
||||
indicator_config={},
|
||||
):
|
||||
indicator_config: dict = {},
|
||||
) -> None:
|
||||
self._agg_order_trade_info(inner_order_indicators)
|
||||
self._update_trade_amount(outer_trade_decision)
|
||||
self._update_order_fulfill_rate()
|
||||
@@ -519,71 +542,66 @@ class Indicator:
|
||||
self._agg_base_price(inner_order_indicators, decision_list, trade_exchange, pa_config=pa_config) # TODO
|
||||
self._agg_order_price_advantage()
|
||||
|
||||
def _cal_trade_fulfill_rate(self, method="mean"):
|
||||
def _cal_trade_fulfill_rate(self, method: str = "mean") -> Optional[BaseSingleMetric]:
|
||||
if method == "mean":
|
||||
|
||||
def func(ffr):
|
||||
return ffr.mean()
|
||||
|
||||
return self.order_indicator.transfer(
|
||||
lambda ffr: ffr.mean(),
|
||||
)
|
||||
elif method == "amount_weighted":
|
||||
|
||||
def func(ffr, deal_amount):
|
||||
return (ffr * deal_amount.abs()).sum() / (deal_amount.abs().sum())
|
||||
|
||||
return self.order_indicator.transfer(
|
||||
lambda ffr, deal_amount: (ffr * deal_amount.abs()).sum() / (deal_amount.abs().sum()),
|
||||
)
|
||||
elif method == "value_weighted":
|
||||
|
||||
def func(ffr, trade_value):
|
||||
return (ffr * trade_value.abs()).sum() / (trade_value.abs().sum())
|
||||
|
||||
return self.order_indicator.transfer(
|
||||
lambda ffr, trade_value: (ffr * trade_value.abs()).sum() / (trade_value.abs().sum()),
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"method {method} is not supported!")
|
||||
return self.order_indicator.transfer(func)
|
||||
|
||||
def _cal_trade_price_advantage(self, method="mean"):
|
||||
def _cal_trade_price_advantage(self, method: str = "mean") -> Optional[BaseSingleMetric]:
|
||||
if method == "mean":
|
||||
|
||||
def func(pa):
|
||||
return pa.mean()
|
||||
|
||||
return self.order_indicator.transfer(lambda pa: pa.mean())
|
||||
elif method == "amount_weighted":
|
||||
|
||||
def func(pa, deal_amount):
|
||||
return (pa * deal_amount.abs()).sum() / (deal_amount.abs().sum())
|
||||
|
||||
return self.order_indicator.transfer(
|
||||
lambda pa, deal_amount: (pa * deal_amount.abs()).sum() / (deal_amount.abs().sum()),
|
||||
)
|
||||
elif method == "value_weighted":
|
||||
|
||||
def func(pa, trade_value):
|
||||
return (pa * trade_value.abs()).sum() / (trade_value.abs().sum())
|
||||
|
||||
return self.order_indicator.transfer(
|
||||
lambda pa, trade_value: (pa * trade_value.abs()).sum() / (trade_value.abs().sum()),
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"method {method} is not supported!")
|
||||
return self.order_indicator.transfer(func)
|
||||
|
||||
def _cal_trade_positive_rate(self):
|
||||
def _cal_trade_positive_rate(self) -> Optional[BaseSingleMetric]:
|
||||
def func(pa):
|
||||
return (pa > 0).sum() / pa.count()
|
||||
|
||||
return self.order_indicator.transfer(func)
|
||||
|
||||
def _cal_deal_amount(self):
|
||||
def _cal_deal_amount(self) -> Optional[BaseSingleMetric]:
|
||||
def func(deal_amount):
|
||||
return deal_amount.abs().sum()
|
||||
|
||||
return self.order_indicator.transfer(func)
|
||||
|
||||
def _cal_trade_value(self):
|
||||
def _cal_trade_value(self) -> Optional[BaseSingleMetric]:
|
||||
def func(trade_value):
|
||||
return trade_value.abs().sum()
|
||||
|
||||
return self.order_indicator.transfer(func)
|
||||
|
||||
def _cal_trade_order_count(self):
|
||||
def _cal_trade_order_count(self) -> Optional[BaseSingleMetric]:
|
||||
def func(amount):
|
||||
return amount.count()
|
||||
|
||||
return self.order_indicator.transfer(func)
|
||||
|
||||
def cal_trade_indicators(self, trade_start_time, freq, indicator_config={}):
|
||||
def cal_trade_indicators(
|
||||
self,
|
||||
trade_start_time: Union[str, pd.Timestamp],
|
||||
freq: str,
|
||||
indicator_config: dict = {},
|
||||
) -> None:
|
||||
show_indicator = indicator_config.get("show_indicator", False)
|
||||
ffr_config = indicator_config.get("ffr_config", {})
|
||||
pa_config = indicator_config.get("pa_config", {})
|
||||
@@ -601,18 +619,22 @@ class Indicator:
|
||||
self.trade_indicator["count"] = order_count
|
||||
if show_indicator:
|
||||
print(
|
||||
"[Indicator({}) {:%Y-%m-%d %H:%M:%S}]: FFR: {}, PA: {}, POS: {}".format(
|
||||
freq, trade_start_time, fulfill_rate, price_advantage, positive_rate
|
||||
)
|
||||
"[Indicator({}) {}]: FFR: {}, PA: {}, POS: {}".format(
|
||||
freq,
|
||||
trade_start_time
|
||||
if isinstance(trade_start_time, str)
|
||||
else trade_start_time.strftime("%Y-%m-%d %H:%M:%S"),
|
||||
fulfill_rate,
|
||||
price_advantage,
|
||||
positive_rate,
|
||||
),
|
||||
)
|
||||
|
||||
def get_order_indicator(self, raw: bool = True):
|
||||
if raw:
|
||||
return self.order_indicator
|
||||
return self.order_indicator.to_series()
|
||||
def get_order_indicator(self, raw: bool = True) -> Union[BaseOrderIndicator, Dict[Text, pd.Series]]:
|
||||
return self.order_indicator if raw else self.order_indicator.to_series()
|
||||
|
||||
def get_trade_indicator(self):
|
||||
def get_trade_indicator(self) -> Dict[str, Optional[BaseSingleMetric]]:
|
||||
return self.trade_indicator
|
||||
|
||||
def generate_trade_indicators_dataframe(self):
|
||||
def generate_trade_indicators_dataframe(self) -> pd.DataFrame:
|
||||
return pd.DataFrame.from_dict(self.trade_indicator_his, orient="index")
|
||||
|
||||
@@ -1,13 +1,16 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
from qlib.utils import init_instance_by_config
|
||||
import abc
|
||||
from typing import Dict, List, Text, Tuple, Union
|
||||
from ..model.base import BaseModel
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from qlib.utils import init_instance_by_config
|
||||
|
||||
from ..data.dataset import Dataset
|
||||
from ..data.dataset.utils import convert_index_format
|
||||
from ..model.base import BaseModel
|
||||
from ..utils.resam import resam_ts_data
|
||||
import pandas as pd
|
||||
import abc
|
||||
|
||||
|
||||
class Signal(metaclass=abc.ABCMeta):
|
||||
@@ -19,7 +22,7 @@ class Signal(metaclass=abc.ABCMeta):
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_signal(self, start_time, end_time) -> Union[pd.Series, pd.DataFrame, None]:
|
||||
def get_signal(self, start_time: pd.Timestamp, end_time: pd.Timestamp) -> Union[pd.Series, pd.DataFrame, None]:
|
||||
"""
|
||||
get the signal at the end of the decision step(from `start_time` to `end_time`)
|
||||
|
||||
@@ -28,7 +31,6 @@ class Signal(metaclass=abc.ABCMeta):
|
||||
Union[pd.Series, pd.DataFrame, None]:
|
||||
returns None if no signal in the specific day
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class SignalWCache(Signal):
|
||||
@@ -37,13 +39,14 @@ class SignalWCache(Signal):
|
||||
SignalWCache will store the prepared signal as a attribute and give the according signal based on input query
|
||||
"""
|
||||
|
||||
def __init__(self, signal: Union[pd.Series, pd.DataFrame]):
|
||||
def __init__(self, signal: Union[pd.Series, pd.DataFrame]) -> None:
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
signal : Union[pd.Series, pd.DataFrame]
|
||||
The expected format of the signal is like the data below (the order of index is not important and can be automatically adjusted)
|
||||
The expected format of the signal is like the data below (the order of index is not important and can be
|
||||
automatically adjusted)
|
||||
|
||||
instrument datetime
|
||||
SH600000 2008-01-02 0.079704
|
||||
@@ -54,8 +57,8 @@ class SignalWCache(Signal):
|
||||
"""
|
||||
self.signal_cache = convert_index_format(signal, level="datetime")
|
||||
|
||||
def get_signal(self, start_time, end_time) -> Union[pd.Series, pd.DataFrame]:
|
||||
# the frequency of the signal may not algin with the decision frequency of strategy
|
||||
def get_signal(self, start_time: pd.Timestamp, end_time: pd.Timestamp) -> Union[pd.Series, pd.DataFrame]:
|
||||
# the frequency of the signal may not align with the decision frequency of strategy
|
||||
# so resampling from the data is necessary
|
||||
# the latest signal leverage more recent data and therefore is used in trading.
|
||||
signal = resam_ts_data(self.signal_cache, start_time=start_time, end_time=end_time, method="last")
|
||||
@@ -63,7 +66,7 @@ class SignalWCache(Signal):
|
||||
|
||||
|
||||
class ModelSignal(SignalWCache):
|
||||
def __init__(self, model: BaseModel, dataset: Dataset):
|
||||
def __init__(self, model: BaseModel, dataset: Dataset) -> None:
|
||||
self.model = model
|
||||
self.dataset = dataset
|
||||
pred_scores = self.model.predict(dataset)
|
||||
@@ -71,7 +74,7 @@ class ModelSignal(SignalWCache):
|
||||
pred_scores = pred_scores.iloc[:, 0]
|
||||
super().__init__(pred_scores)
|
||||
|
||||
def _update_model(self):
|
||||
def _update_model(self) -> None:
|
||||
"""
|
||||
When using online data, update model in each bar as the following steps:
|
||||
- update dataset with online data, the dataset should support online update
|
||||
@@ -83,7 +86,7 @@ class ModelSignal(SignalWCache):
|
||||
|
||||
|
||||
def create_signal_from(
|
||||
obj: Union[Signal, Tuple[BaseModel, Dataset], List, Dict, Text, pd.Series, pd.DataFrame]
|
||||
obj: Union[Signal, Tuple[BaseModel, Dataset], List, Dict, Text, pd.Series, pd.DataFrame],
|
||||
) -> Signal:
|
||||
"""
|
||||
create signal from diverse information
|
||||
|
||||
@@ -2,16 +2,22 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import bisect
|
||||
from abc import abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Set, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from qlib.utils.time import epsilon_change
|
||||
from typing import TYPE_CHECKING, Tuple, Union
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from qlib.backtest.decision import BaseTradeDecision
|
||||
|
||||
import pandas as pd
|
||||
import warnings
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from ..data.data import Cal
|
||||
|
||||
|
||||
@@ -26,8 +32,8 @@ class TradeCalendarManager:
|
||||
freq: str,
|
||||
start_time: Union[str, pd.Timestamp] = None,
|
||||
end_time: Union[str, pd.Timestamp] = None,
|
||||
level_infra: "LevelInfrastructure" = None,
|
||||
):
|
||||
level_infra: LevelInfrastructure = None,
|
||||
) -> None:
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
@@ -43,19 +49,26 @@ class TradeCalendarManager:
|
||||
self.level_infra = level_infra
|
||||
self.reset(freq=freq, start_time=start_time, end_time=end_time)
|
||||
|
||||
def reset(self, freq, start_time, end_time):
|
||||
def reset(
|
||||
self,
|
||||
freq: str,
|
||||
start_time: Union[str, pd.Timestamp] = None,
|
||||
end_time: Union[str, pd.Timestamp] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Please refer to the docs of `__init__`
|
||||
|
||||
Reset the trade calendar
|
||||
- self.trade_len : The total count for trading step
|
||||
- self.trade_step : The number of trading step finished, self.trade_step can be [0, 1, 2, ..., self.trade_len - 1]
|
||||
- self.trade_step : The number of trading step finished, self.trade_step can be
|
||||
[0, 1, 2, ..., self.trade_len - 1]
|
||||
"""
|
||||
self.freq = freq
|
||||
self.start_time = pd.Timestamp(start_time) if start_time else None
|
||||
self.end_time = pd.Timestamp(end_time) if end_time else None
|
||||
|
||||
_calendar = Cal.calendar(freq=freq, future=True)
|
||||
assert isinstance(_calendar, np.ndarray)
|
||||
self._calendar = _calendar
|
||||
_, _, _start_index, _end_index = Cal.locate_index(start_time, end_time, freq=freq, future=True)
|
||||
self.start_index = _start_index
|
||||
@@ -63,7 +76,7 @@ class TradeCalendarManager:
|
||||
self.trade_len = _end_index - _start_index + 1
|
||||
self.trade_step = 0
|
||||
|
||||
def finished(self):
|
||||
def finished(self) -> bool:
|
||||
"""
|
||||
Check if the trading finished
|
||||
- Should check before calling strategy.generate_decisions and executor.execute
|
||||
@@ -72,29 +85,32 @@ class TradeCalendarManager:
|
||||
"""
|
||||
return self.trade_step >= self.trade_len
|
||||
|
||||
def step(self):
|
||||
def step(self) -> None:
|
||||
if self.finished():
|
||||
raise RuntimeError(f"The calendar is finished, please reset it if you want to call it!")
|
||||
self.trade_step = self.trade_step + 1
|
||||
self.trade_step += 1
|
||||
|
||||
def get_freq(self):
|
||||
def get_freq(self) -> str:
|
||||
return self.freq
|
||||
|
||||
def get_trade_len(self):
|
||||
def get_trade_len(self) -> int:
|
||||
"""get the total step length"""
|
||||
return self.trade_len
|
||||
|
||||
def get_trade_step(self):
|
||||
def get_trade_step(self) -> int:
|
||||
return self.trade_step
|
||||
|
||||
def get_step_time(self, trade_step=None, shift=0):
|
||||
def get_step_time(self, trade_step: int = None, shift: int = 0) -> Tuple[pd.Timestamp, pd.Timestamp]:
|
||||
"""
|
||||
Get the left and right endpoints of the trade_step'th trading interval
|
||||
|
||||
About the endpoints:
|
||||
- Qlib uses the closed interval in time-series data selection, which has the same performance as pandas.Series.loc
|
||||
# - The returned right endpoints should minus 1 seconds because of the closed interval representation in Qlib.
|
||||
# Note: Qlib supports up to minutely decision execution, so 1 seconds is less than any trading time interval.
|
||||
- Qlib uses the closed interval in time-series data selection, which has the same performance as
|
||||
pandas.Series.loc
|
||||
# - The returned right endpoints should minus 1 seconds because of the closed interval representation in
|
||||
# Qlib.
|
||||
# Note: Qlib supports up to minutely decision execution, so 1 seconds is less than any trading time
|
||||
# interval.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -105,15 +121,14 @@ class TradeCalendarManager:
|
||||
|
||||
Returns
|
||||
-------
|
||||
Tuple[pd.Timestamp, pd.Timestap]
|
||||
Tuple[pd.Timestamp, pd.Timestamp]
|
||||
- If shift == 0, return the trading time range
|
||||
- If shift > 0, return the trading time range of the earlier shift bars
|
||||
- If shift < 0, return the trading time range of the later shift bar
|
||||
"""
|
||||
if trade_step is None:
|
||||
trade_step = self.get_trade_step()
|
||||
trade_step = trade_step - shift
|
||||
calendar_index = self.start_index + trade_step
|
||||
calendar_index = self.start_index + trade_step - shift
|
||||
return self._calendar[calendar_index], epsilon_change(self._calendar[calendar_index + 1])
|
||||
|
||||
def get_data_cal_range(self, rtype: str = "full") -> Tuple[int, int]:
|
||||
@@ -126,7 +141,7 @@ class TradeCalendarManager:
|
||||
Parameters
|
||||
----------
|
||||
rtype: str
|
||||
- "full": return the full limitation of the deicsion in the day
|
||||
- "full": return the full limitation of the decision in the day
|
||||
- "step": return the limitation of current step
|
||||
|
||||
Returns
|
||||
@@ -134,6 +149,8 @@ class TradeCalendarManager:
|
||||
Tuple[int, int]:
|
||||
"""
|
||||
# potential performance issue
|
||||
assert self.level_infra is not None
|
||||
|
||||
day_start = pd.Timestamp(self.start_time.date())
|
||||
day_end = epsilon_change(day_start + pd.Timedelta(days=1))
|
||||
freq = self.level_infra.get("common_infra").get("trade_exchange").freq
|
||||
@@ -148,7 +165,7 @@ class TradeCalendarManager:
|
||||
|
||||
return start_idx - day_start_idx, end_index - day_start_idx
|
||||
|
||||
def get_all_time(self):
|
||||
def get_all_time(self) -> Tuple[pd.Timestamp, pd.Timestamp]:
|
||||
"""Get the start_time and end_time for trading"""
|
||||
return self.start_time, self.end_time
|
||||
|
||||
@@ -167,30 +184,33 @@ class TradeCalendarManager:
|
||||
Tuple[int, int]:
|
||||
the index of the range. **the left and right are closed**
|
||||
"""
|
||||
left, right = (
|
||||
bisect.bisect_right(self._calendar, start_time) - 1,
|
||||
bisect.bisect_right(self._calendar, end_time) - 1,
|
||||
)
|
||||
left = bisect.bisect_right(list(self._calendar), start_time) - 1
|
||||
right = bisect.bisect_right(list(self._calendar), end_time) - 1
|
||||
left -= self.start_index
|
||||
right -= self.start_index
|
||||
|
||||
def clip(idx):
|
||||
def clip(idx: int) -> int:
|
||||
return min(max(0, idx), self.trade_len - 1)
|
||||
|
||||
return clip(left), clip(right)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"class: {self.__class__.__name__}; {self.start_time}[{self.start_index}]~{self.end_time}[{self.end_index}]: [{self.trade_step}/{self.trade_len}]"
|
||||
return (
|
||||
f"class: {self.__class__.__name__}; "
|
||||
f"{self.start_time}[{self.start_index}]~{self.end_time}[{self.end_index}]: "
|
||||
f"[{self.trade_step}/{self.trade_len}]"
|
||||
)
|
||||
|
||||
|
||||
class BaseInfrastructure:
|
||||
def __init__(self, **kwargs):
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
self.reset_infra(**kwargs)
|
||||
|
||||
def get_support_infra(self):
|
||||
@abstractmethod
|
||||
def get_support_infra(self) -> Set[str]:
|
||||
raise NotImplementedError("`get_support_infra` is not implemented!")
|
||||
|
||||
def reset_infra(self, **kwargs):
|
||||
def reset_infra(self, **kwargs: Any) -> None:
|
||||
support_infra = self.get_support_infra()
|
||||
for k, v in kwargs.items():
|
||||
if k in support_infra:
|
||||
@@ -198,53 +218,58 @@ class BaseInfrastructure:
|
||||
else:
|
||||
warnings.warn(f"{k} is ignored in `reset_infra`!")
|
||||
|
||||
def get(self, infra_name):
|
||||
def get(self, infra_name: str) -> Any:
|
||||
if hasattr(self, infra_name):
|
||||
return getattr(self, infra_name)
|
||||
else:
|
||||
warnings.warn(f"infra {infra_name} is not found!")
|
||||
|
||||
def has(self, infra_name):
|
||||
def has(self, infra_name: str) -> bool:
|
||||
return infra_name in self.get_support_infra() and hasattr(self, infra_name)
|
||||
|
||||
def update(self, other):
|
||||
def update(self, other: BaseInfrastructure) -> None:
|
||||
support_infra = other.get_support_infra()
|
||||
infra_dict = {_infra: getattr(other, _infra) for _infra in support_infra if hasattr(other, _infra)}
|
||||
self.reset_infra(**infra_dict)
|
||||
|
||||
|
||||
class CommonInfrastructure(BaseInfrastructure):
|
||||
def get_support_infra(self):
|
||||
return ["trade_account", "trade_exchange"]
|
||||
def get_support_infra(self) -> Set[str]:
|
||||
return {"trade_account", "trade_exchange"}
|
||||
|
||||
|
||||
class LevelInfrastructure(BaseInfrastructure):
|
||||
"""level infrastructure is created by executor, and then shared to strategies on the same level"""
|
||||
|
||||
def get_support_infra(self):
|
||||
def get_support_infra(self) -> Set[str]:
|
||||
"""
|
||||
Descriptions about the infrastructure
|
||||
|
||||
sub_level_infra:
|
||||
- **NOTE**: this will only work after _init_sub_trading !!!
|
||||
"""
|
||||
return ["trade_calendar", "sub_level_infra", "common_infra"]
|
||||
return {"trade_calendar", "sub_level_infra", "common_infra"}
|
||||
|
||||
def reset_cal(self, freq, start_time, end_time):
|
||||
def reset_cal(
|
||||
self,
|
||||
freq: str,
|
||||
start_time: Union[str, pd.Timestamp, None],
|
||||
end_time: Union[str, pd.Timestamp, None],
|
||||
) -> None:
|
||||
"""reset trade calendar manager"""
|
||||
if self.has("trade_calendar"):
|
||||
self.get("trade_calendar").reset(freq, start_time=start_time, end_time=end_time)
|
||||
else:
|
||||
self.reset_infra(
|
||||
trade_calendar=TradeCalendarManager(freq, start_time=start_time, end_time=end_time, level_infra=self)
|
||||
trade_calendar=TradeCalendarManager(freq, start_time=start_time, end_time=end_time, level_infra=self),
|
||||
)
|
||||
|
||||
def set_sub_level_infra(self, sub_level_infra: LevelInfrastructure):
|
||||
"""this will make the calendar access easier when acrossing multi-levels"""
|
||||
def set_sub_level_infra(self, sub_level_infra: LevelInfrastructure) -> None:
|
||||
"""this will make the calendar access easier when crossing multi-levels"""
|
||||
self.reset_infra(sub_level_infra=sub_level_infra)
|
||||
|
||||
|
||||
def get_start_end_idx(trade_calendar: TradeCalendarManager, outer_trade_decision: BaseTradeDecision) -> Union[int, int]:
|
||||
def get_start_end_idx(trade_calendar: TradeCalendarManager, outer_trade_decision: BaseTradeDecision) -> Tuple[int, int]:
|
||||
"""
|
||||
A helper function for getting the decision-level index range limitation for inner strategy
|
||||
- NOTE: this function is not applicable to order-level
|
||||
|
||||
@@ -22,7 +22,7 @@ from pathlib import Path
|
||||
from typing import Callable, Optional, Union
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from qlib.constant import REG_CN, REG_US
|
||||
from qlib.constant import REG_CN, REG_US, REG_TW
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from qlib.utils.time import Freq
|
||||
@@ -75,6 +75,17 @@ class Config:
|
||||
def set_conf_from_C(self, config_c):
|
||||
self.update(**config_c.__dict__["_config"])
|
||||
|
||||
def register_from_C(self, config, skip_register=True):
|
||||
from .utils import set_log_with_config # pylint: disable=C0415
|
||||
|
||||
if C.registered and skip_register:
|
||||
return
|
||||
|
||||
C.set_conf_from_C(config)
|
||||
if C.logging_config:
|
||||
set_log_with_config(C.logging_config)
|
||||
C.register()
|
||||
|
||||
|
||||
# pickle.dump protocol version: https://docs.python.org/3/library/pickle.html#data-stream-format
|
||||
PROTOCOL_VERSION = 4
|
||||
@@ -92,6 +103,7 @@ _default_config = {
|
||||
"calendar_provider": "LocalCalendarProvider",
|
||||
"instrument_provider": "LocalInstrumentProvider",
|
||||
"feature_provider": "LocalFeatureProvider",
|
||||
"pit_provider": "LocalPITProvider",
|
||||
"expression_provider": "LocalExpressionProvider",
|
||||
"dataset_provider": "LocalDatasetProvider",
|
||||
"provider": "LocalProvider",
|
||||
@@ -101,14 +113,13 @@ _default_config = {
|
||||
# "~/.qlib/stock_data/cn_data"
|
||||
# # dict
|
||||
# {"day": "~/.qlib/stock_data/cn_data", "1min": "~/.qlib/stock_data/cn_data_1min"}
|
||||
# NOTE: provider_uri priority:
|
||||
# NOTE: provider_uri priority:
|
||||
# 1. backend_config: backend_obj["kwargs"]["provider_uri"]
|
||||
# 2. backend_config: backend_obj["kwargs"]["provider_uri_map"]
|
||||
# 3. qlib.init: provider_uri
|
||||
"provider_uri": "",
|
||||
# cache
|
||||
"expression_cache": None,
|
||||
"dataset_cache": None,
|
||||
"calendar_cache": None,
|
||||
# for simple dataset cache
|
||||
"local_cache_path": None,
|
||||
@@ -171,6 +182,18 @@ _default_config = {
|
||||
"default_exp_name": "Experiment",
|
||||
},
|
||||
},
|
||||
"pit_record_type": {
|
||||
"date": "I", # uint32
|
||||
"period": "I", # uint32
|
||||
"value": "d", # float64
|
||||
"index": "I", # uint32
|
||||
},
|
||||
"pit_record_nan": {
|
||||
"date": 0,
|
||||
"period": 0,
|
||||
"value": float("NAN"),
|
||||
"index": 0xFFFFFFFF,
|
||||
},
|
||||
# Default config for MongoDB
|
||||
"mongo": {
|
||||
"task_url": "mongodb://localhost:27017/",
|
||||
@@ -184,20 +207,12 @@ _default_config = {
|
||||
|
||||
MODE_CONF = {
|
||||
"server": {
|
||||
# data provider config
|
||||
"calendar_provider": "LocalCalendarProvider",
|
||||
"instrument_provider": "LocalInstrumentProvider",
|
||||
"feature_provider": "LocalFeatureProvider",
|
||||
"expression_provider": "LocalExpressionProvider",
|
||||
"dataset_provider": "LocalDatasetProvider",
|
||||
"provider": "LocalProvider",
|
||||
# config it in qlib.init()
|
||||
"provider_uri": "",
|
||||
# redis
|
||||
"redis_host": "127.0.0.1",
|
||||
"redis_port": 6379,
|
||||
"redis_task_db": 1,
|
||||
"kernels": NUM_USABLE_CPU,
|
||||
# cache
|
||||
"expression_cache": DISK_EXPRESSION_CACHE,
|
||||
"dataset_cache": DISK_DATASET_CACHE,
|
||||
@@ -205,25 +220,15 @@ MODE_CONF = {
|
||||
"mount_path": None,
|
||||
},
|
||||
"client": {
|
||||
# data provider config
|
||||
"calendar_provider": "LocalCalendarProvider",
|
||||
"instrument_provider": "LocalInstrumentProvider",
|
||||
"feature_provider": "LocalFeatureProvider",
|
||||
"expression_provider": "LocalExpressionProvider",
|
||||
"dataset_provider": "LocalDatasetProvider",
|
||||
"provider": "LocalProvider",
|
||||
# config it in user's own code
|
||||
"provider_uri": "~/.qlib/qlib_data/cn_data",
|
||||
# cache
|
||||
# Using parameter 'remote' to announce the client is using server_cache, and the writing access will be disabled.
|
||||
# Disable cache by default. Avoid introduce advanced features for beginners
|
||||
"expression_cache": None,
|
||||
"dataset_cache": None,
|
||||
# SimpleDatasetCache directory
|
||||
"local_cache_path": Path("~/.cache/qlib_simple_cache").expanduser().resolve(),
|
||||
"calendar_cache": None,
|
||||
# client config
|
||||
"kernels": NUM_USABLE_CPU,
|
||||
"mount_path": None,
|
||||
"auto_mount": False, # The nfs is already mounted on our server[auto_mount: False].
|
||||
# The nfs should be auto-mounted by qlib on other
|
||||
@@ -257,6 +262,11 @@ _default_region_config = {
|
||||
"limit_threshold": None,
|
||||
"deal_price": "close",
|
||||
},
|
||||
REG_TW: {
|
||||
"trade_unit": 1000,
|
||||
"limit_threshold": 0.1,
|
||||
"deal_price": "close",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -388,13 +398,11 @@ class QlibConfig(Config):
|
||||
default_conf : str
|
||||
the default config template chosen by user: "server", "client"
|
||||
"""
|
||||
from .utils import set_log_with_config, get_module_logger, can_use_cache
|
||||
from .utils import set_log_with_config, get_module_logger, can_use_cache # pylint: disable=C0415
|
||||
|
||||
self.reset()
|
||||
|
||||
_logging_config = self.logging_config
|
||||
if "logging_config" in kwargs:
|
||||
_logging_config = kwargs["logging_config"]
|
||||
_logging_config = kwargs.get("logging_config", self.logging_config)
|
||||
|
||||
# set global config
|
||||
if _logging_config:
|
||||
@@ -433,11 +441,11 @@ class QlibConfig(Config):
|
||||
)
|
||||
|
||||
def register(self):
|
||||
from .utils import init_instance_by_config
|
||||
from .data.ops import register_all_ops
|
||||
from .data.data import register_all_wrappers
|
||||
from .workflow import R, QlibRecorder
|
||||
from .workflow.utils import experiment_exit_handler
|
||||
from .utils import init_instance_by_config # pylint: disable=C0415
|
||||
from .data.ops import register_all_ops # pylint: disable=C0415
|
||||
from .data.data import register_all_wrappers # pylint: disable=C0415
|
||||
from .workflow import R, QlibRecorder # pylint: disable=C0415
|
||||
from .workflow.utils import experiment_exit_handler # pylint: disable=C0415
|
||||
|
||||
register_all_ops(self)
|
||||
register_all_wrappers(self)
|
||||
@@ -454,7 +462,7 @@ class QlibConfig(Config):
|
||||
self._registered = True
|
||||
|
||||
def reset_qlib_version(self):
|
||||
import qlib
|
||||
import qlib # pylint: disable=C0415
|
||||
|
||||
reset_version = self.get("qlib_reset_version", None)
|
||||
if reset_version is not None:
|
||||
|
||||
@@ -4,6 +4,10 @@
|
||||
# REGION CONST
|
||||
REG_CN = "cn"
|
||||
REG_US = "us"
|
||||
REG_TW = "tw"
|
||||
|
||||
# Epsilon for avoiding division by zero.
|
||||
EPS = 1e-12
|
||||
|
||||
# Infinity in integer
|
||||
INF = 10**18
|
||||
|
||||
@@ -7,8 +7,7 @@ import warnings
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from qlib.utils import init_instance_by_config
|
||||
from qlib.data.dataset import DatasetH, DataHandler
|
||||
from qlib.data.dataset import DatasetH
|
||||
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
@@ -16,7 +15,7 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
def _to_tensor(x):
|
||||
if not isinstance(x, torch.Tensor):
|
||||
return torch.tensor(x, dtype=torch.float, device=device)
|
||||
return torch.tensor(x, dtype=torch.float, device=device) # pylint: disable=E1101
|
||||
return x
|
||||
|
||||
|
||||
@@ -64,11 +63,20 @@ def _get_date_parse_fn(target):
|
||||
get_date_parse_fn(20120101)('2017-01-01') => 20170101
|
||||
"""
|
||||
if isinstance(target, int):
|
||||
_fn = lambda x: int(str(x).replace("-", "")[:8]) # 20200201
|
||||
|
||||
def _fn(x):
|
||||
return int(str(x).replace("-", "")[:8]) # 20200201
|
||||
|
||||
elif isinstance(target, str) and len(target) == 8:
|
||||
_fn = lambda x: str(x).replace("-", "")[:8] # '20200201'
|
||||
|
||||
def _fn(x):
|
||||
return str(x).replace("-", "")[:8] # '20200201'
|
||||
|
||||
else:
|
||||
_fn = lambda x: x # '2021-01-01'
|
||||
|
||||
def _fn(x):
|
||||
return x # '2021-01-01'
|
||||
|
||||
return _fn
|
||||
|
||||
|
||||
|
||||
@@ -5,9 +5,7 @@ from ...data.dataset.handler import DataHandlerLP
|
||||
from ...data.dataset.processor import Processor
|
||||
from ...utils import get_callable_kwargs
|
||||
from ...data.dataset import processor as processor_module
|
||||
from ...log import TimeInspector
|
||||
from inspect import getfullargspec
|
||||
import copy
|
||||
|
||||
|
||||
def check_transform_proc(proc_l, fit_start_time, fit_end_time):
|
||||
@@ -257,7 +255,10 @@ class Alpha158(DataHandlerLP):
|
||||
exclude = config["rolling"].get("exclude", [])
|
||||
# `exclude` in dataset config unnecessary filed
|
||||
# `include` in dataset config necessary field
|
||||
use = lambda x: x not in exclude and (include is None or x in include)
|
||||
|
||||
def use(x):
|
||||
return x not in exclude and (include is None or x in include)
|
||||
|
||||
if use("ROC"):
|
||||
fields += ["Ref($close, %d)/$close" % d for d in windows]
|
||||
names += ["ROC%d" % d for d in windows]
|
||||
|
||||
164
qlib/contrib/data/highfreq_handler.py
Normal file
164
qlib/contrib/data/highfreq_handler.py
Normal file
@@ -0,0 +1,164 @@
|
||||
from qlib.data.dataset.handler import DataHandler, DataHandlerLP
|
||||
|
||||
EPSILON = 1e-4
|
||||
|
||||
|
||||
class HighFreqHandler(DataHandlerLP):
|
||||
def __init__(
|
||||
self,
|
||||
instruments="csi300",
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
infer_processors=[],
|
||||
learn_processors=[],
|
||||
fit_start_time=None,
|
||||
fit_end_time=None,
|
||||
drop_raw=True,
|
||||
):
|
||||
def check_transform_proc(proc_l):
|
||||
new_l = []
|
||||
for p in proc_l:
|
||||
p["kwargs"].update(
|
||||
{
|
||||
"fit_start_time": fit_start_time,
|
||||
"fit_end_time": fit_end_time,
|
||||
}
|
||||
)
|
||||
new_l.append(p)
|
||||
return new_l
|
||||
|
||||
infer_processors = check_transform_proc(infer_processors)
|
||||
learn_processors = check_transform_proc(learn_processors)
|
||||
|
||||
data_loader = {
|
||||
"class": "QlibDataLoader",
|
||||
"kwargs": {
|
||||
"config": self.get_feature_config(),
|
||||
"swap_level": False,
|
||||
"freq": "1min",
|
||||
},
|
||||
}
|
||||
super().__init__(
|
||||
instruments=instruments,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
data_loader=data_loader,
|
||||
infer_processors=infer_processors,
|
||||
learn_processors=learn_processors,
|
||||
drop_raw=drop_raw,
|
||||
)
|
||||
|
||||
def get_feature_config(self):
|
||||
fields = []
|
||||
names = []
|
||||
|
||||
template_if = "If(IsNull({1}), {0}, {1})"
|
||||
template_paused = "Select(Gt($hx_paused_num, 1.001), {0})"
|
||||
|
||||
def get_normalized_price_feature(price_field, shift=0):
|
||||
# norm with the close price of 237th minute of yesterday.
|
||||
if shift == 0:
|
||||
template_norm = "{0}/DayLast(Ref({1}, 243))"
|
||||
else:
|
||||
template_norm = "Ref({0}, " + str(shift) + ")/DayLast(Ref({1}, 243))"
|
||||
|
||||
template_fillnan = "FFillNan({0})"
|
||||
# calculate -> ffill -> remove paused
|
||||
feature_ops = template_paused.format(
|
||||
template_fillnan.format(
|
||||
template_norm.format(template_if.format("$close", price_field), template_fillnan.format("$close"))
|
||||
)
|
||||
)
|
||||
return feature_ops
|
||||
|
||||
fields += [get_normalized_price_feature("$open", 0)]
|
||||
fields += [get_normalized_price_feature("$high", 0)]
|
||||
fields += [get_normalized_price_feature("$low", 0)]
|
||||
fields += [get_normalized_price_feature("$close", 0)]
|
||||
fields += [get_normalized_price_feature("$vwap", 0)]
|
||||
names += ["$open", "$high", "$low", "$close", "$vwap"]
|
||||
|
||||
fields += [get_normalized_price_feature("$open", 240)]
|
||||
fields += [get_normalized_price_feature("$high", 240)]
|
||||
fields += [get_normalized_price_feature("$low", 240)]
|
||||
fields += [get_normalized_price_feature("$close", 240)]
|
||||
fields += [get_normalized_price_feature("$vwap", 240)]
|
||||
names += ["$open_1", "$high_1", "$low_1", "$close_1", "$vwap_1"]
|
||||
|
||||
# calculate and fill nan with 0
|
||||
template_gzero = "If(Ge({0}, 0), {0}, 0)"
|
||||
fields += [
|
||||
template_gzero.format(
|
||||
template_paused.format(
|
||||
"If(IsNull({0}), 0, {0})".format("{0}/Ref(DayLast(Mean({0}, 7200)), 240)".format("$volume"))
|
||||
)
|
||||
)
|
||||
]
|
||||
names += ["$volume"]
|
||||
|
||||
fields += [
|
||||
template_gzero.format(
|
||||
template_paused.format(
|
||||
"If(IsNull({0}), 0, {0})".format(
|
||||
"Ref({0}, 240)/Ref(DayLast(Mean({0}, 7200)), 240)".format("$volume")
|
||||
)
|
||||
)
|
||||
)
|
||||
]
|
||||
names += ["$volume_1"]
|
||||
|
||||
return fields, names
|
||||
|
||||
|
||||
class HighFreqBacktestHandler(DataHandler):
|
||||
def __init__(
|
||||
self,
|
||||
instruments="csi300",
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
):
|
||||
data_loader = {
|
||||
"class": "QlibDataLoader",
|
||||
"kwargs": {
|
||||
"config": self.get_feature_config(),
|
||||
"swap_level": False,
|
||||
"freq": "1min",
|
||||
},
|
||||
}
|
||||
super().__init__(
|
||||
instruments=instruments,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
data_loader=data_loader,
|
||||
)
|
||||
|
||||
def get_feature_config(self):
|
||||
fields = []
|
||||
names = []
|
||||
|
||||
template_if = "If(IsNull({1}), {0}, {1})"
|
||||
template_paused = "Select(Gt($hx_paused_num, 1.001), {0})"
|
||||
# template_paused = "{0}"
|
||||
template_fillnan = "FFillNan({0})"
|
||||
fields += [
|
||||
template_fillnan.format(template_paused.format("$close")),
|
||||
]
|
||||
names += ["$close0"]
|
||||
|
||||
fields += [
|
||||
template_paused.format(
|
||||
template_if.format(
|
||||
template_fillnan.format("$close"),
|
||||
"$vwap",
|
||||
)
|
||||
)
|
||||
]
|
||||
names += ["$vwap0"]
|
||||
|
||||
fields += [template_paused.format("If(IsNull({0}), 0, {0})".format("$volume"))]
|
||||
names += ["$volume0"]
|
||||
|
||||
fields += [template_paused.format("If(IsNull({0}), 0, {0})".format("$factor"))]
|
||||
names += ["$factor0"]
|
||||
|
||||
return fields, names
|
||||
81
qlib/contrib/data/highfreq_processor.py
Normal file
81
qlib/contrib/data/highfreq_processor.py
Normal file
@@ -0,0 +1,81 @@
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from qlib.data.dataset.processor import Processor
|
||||
from qlib.data.dataset.utils import fetch_df_by_index
|
||||
from typing import Dict
|
||||
|
||||
|
||||
class HighFreqTrans(Processor):
|
||||
def __init__(self, dtype: str = "bool"):
|
||||
self.dtype = dtype
|
||||
|
||||
def fit(self, df_features):
|
||||
pass
|
||||
|
||||
def __call__(self, df_features):
|
||||
if self.dtype == "bool":
|
||||
return df_features.astype(np.int8)
|
||||
else:
|
||||
return df_features.astype(np.float32)
|
||||
|
||||
|
||||
class HighFreqNorm(Processor):
|
||||
def __init__(
|
||||
self,
|
||||
fit_start_time: pd.Timestamp,
|
||||
fit_end_time: pd.Timestamp,
|
||||
feature_save_dir: str,
|
||||
norm_groups: Dict[str, int],
|
||||
):
|
||||
|
||||
self.fit_start_time = fit_start_time
|
||||
self.fit_end_time = fit_end_time
|
||||
self.feature_save_dir = feature_save_dir
|
||||
self.norm_groups = norm_groups
|
||||
|
||||
def fit(self, df_features) -> None:
|
||||
if os.path.exists(self.feature_save_dir) and len(os.listdir(self.feature_save_dir)) != 0:
|
||||
return
|
||||
os.makedirs(self.feature_save_dir)
|
||||
fetch_df = fetch_df_by_index(df_features, slice(self.fit_start_time, self.fit_end_time), level="datetime")
|
||||
del df_features
|
||||
index = 0
|
||||
names = {}
|
||||
for name, dim in self.norm_groups.items():
|
||||
names[name] = slice(index, index + dim)
|
||||
index += dim
|
||||
for name, name_val in names.items():
|
||||
df_values = fetch_df.iloc(axis=1)[name_val].values
|
||||
if name.endswith("volume"):
|
||||
df_values = np.log1p(df_values)
|
||||
self.feature_mean = np.nanmean(df_values)
|
||||
np.save(self.feature_save_dir + name + "_mean.npy", self.feature_mean)
|
||||
df_values = df_values - self.feature_mean
|
||||
self.feature_std = np.nanstd(np.absolute(df_values))
|
||||
np.save(self.feature_save_dir + name + "_std.npy", self.feature_std)
|
||||
df_values = df_values / self.feature_std
|
||||
np.save(self.feature_save_dir + name + "_vmax.npy", np.nanmax(df_values))
|
||||
np.save(self.feature_save_dir + name + "_vmin.npy", np.nanmin(df_values))
|
||||
return
|
||||
|
||||
def __call__(self, df_features):
|
||||
if "date" in df_features:
|
||||
df_features.droplevel("date", inplace=True)
|
||||
df_values = df_features.values
|
||||
index = 0
|
||||
names = {}
|
||||
for name, dim in self.norm_groups.items():
|
||||
names[name] = slice(index, index + dim)
|
||||
index += dim
|
||||
for name, name_val in names.items():
|
||||
feature_mean = np.load(self.feature_save_dir + name + "_mean.npy")
|
||||
feature_std = np.load(self.feature_save_dir + name + "_std.npy")
|
||||
|
||||
if name.endswith("volume"):
|
||||
df_values[:, name_val] = np.log1p(df_values[:, name_val])
|
||||
df_values[:, name_val] -= feature_mean
|
||||
df_values[:, name_val] /= feature_std
|
||||
df_features = pd.DataFrame(data=df_values, index=df_features.index, columns=df_features.columns)
|
||||
return df_features.fillna(0)
|
||||
301
qlib/contrib/data/highfreq_provider.py
Normal file
301
qlib/contrib/data/highfreq_provider.py
Normal file
@@ -0,0 +1,301 @@
|
||||
import os
|
||||
import time
|
||||
import datetime
|
||||
from typing import Optional
|
||||
|
||||
import qlib
|
||||
from qlib.data import D
|
||||
from qlib.config import REG_CN
|
||||
from qlib.utils import init_instance_by_config
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
from qlib.data.data import Cal
|
||||
from qlib.contrib.ops.high_freq import get_calendar_day, DayLast, FFillNan, BFillNan, Date, Select, IsNull, IsInf, Cut
|
||||
import pickle as pkl
|
||||
from joblib import Parallel, delayed
|
||||
from utilsd.logging import print_log
|
||||
|
||||
|
||||
class HighFreqProvider:
|
||||
def __init__(
|
||||
self,
|
||||
start_time: str,
|
||||
end_time: str,
|
||||
train_end_time: str,
|
||||
valid_start_time: str,
|
||||
valid_end_time: str,
|
||||
test_start_time: str,
|
||||
qlib_conf: dict,
|
||||
feature_conf: dict,
|
||||
label_conf: Optional[dict] = None,
|
||||
backtest_conf: dict = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
self.start_time = start_time
|
||||
self.end_time = end_time
|
||||
self.test_start_time = test_start_time
|
||||
self.train_end_time = train_end_time
|
||||
self.valid_start_time = valid_start_time
|
||||
self.valid_end_time = valid_end_time
|
||||
self._init_qlib(qlib_conf)
|
||||
self.feature_conf = feature_conf
|
||||
self.label_conf = label_conf
|
||||
self.backtest_conf = backtest_conf
|
||||
self.qlib_conf = qlib_conf
|
||||
|
||||
def get_pre_datasets(self):
|
||||
"""Generate the training, validation and test datasets for prediction
|
||||
|
||||
Returns:
|
||||
Tuple[BaseDataset, BaseDataset, BaseDataset]: The training and test datasets
|
||||
"""
|
||||
|
||||
dict_feature_path = self.feature_conf["path"]
|
||||
train_feature_path = dict_feature_path[:-4] + "_train.pkl"
|
||||
valid_feature_path = dict_feature_path[:-4] + "_valid.pkl"
|
||||
test_feature_path = dict_feature_path[:-4] + "_test.pkl"
|
||||
|
||||
dict_label_path = self.label_conf["path"]
|
||||
train_label_path = dict_label_path[:-4] + "_train.pkl"
|
||||
valid_label_path = dict_label_path[:-4] + "_valid.pkl"
|
||||
test_label_path = dict_label_path[:-4] + "_test.pkl"
|
||||
|
||||
if (
|
||||
not os.path.isfile(train_feature_path)
|
||||
or not os.path.isfile(valid_feature_path)
|
||||
or not os.path.isfile(test_feature_path)
|
||||
):
|
||||
xtrain, xvalid, xtest = self._gen_data(self.feature_conf)
|
||||
xtrain.to_pickle(train_feature_path)
|
||||
xvalid.to_pickle(valid_feature_path)
|
||||
xtest.to_pickle(test_feature_path)
|
||||
del xtrain, xvalid, xtest
|
||||
|
||||
if (
|
||||
not os.path.isfile(train_label_path)
|
||||
or not os.path.isfile(valid_label_path)
|
||||
or not os.path.isfile(test_label_path)
|
||||
):
|
||||
ytrain, yvalid, ytest = self._gen_data(self.label_conf)
|
||||
ytrain.to_pickle(train_label_path)
|
||||
yvalid.to_pickle(valid_label_path)
|
||||
ytest.to_pickle(test_label_path)
|
||||
del ytrain, yvalid, ytest
|
||||
|
||||
feature = {
|
||||
"train": train_feature_path,
|
||||
"valid": valid_feature_path,
|
||||
"test": test_feature_path,
|
||||
}
|
||||
|
||||
label = {
|
||||
"train": train_label_path,
|
||||
"valid": valid_label_path,
|
||||
"test": test_label_path,
|
||||
}
|
||||
|
||||
return feature, label
|
||||
|
||||
def get_backtest(self, **kwargs) -> None:
|
||||
self._gen_data(self.backtest_conf)
|
||||
|
||||
def _init_qlib(self, qlib_conf):
|
||||
"""initialize qlib"""
|
||||
|
||||
qlib.init(
|
||||
region=REG_CN,
|
||||
auto_mount=False,
|
||||
custom_ops=[DayLast, FFillNan, BFillNan, Date, Select, IsNull, IsInf, Cut],
|
||||
expression_cache=None,
|
||||
**qlib_conf,
|
||||
)
|
||||
|
||||
def _prepare_calender_cache(self):
|
||||
"""preload the calendar for cache"""
|
||||
|
||||
# This code used the copy-on-write feature of Linux
|
||||
# to avoid calculating the calendar multiple times in the subprocess.
|
||||
# This code may accelerate, but may be not useful on Windows and Mac Os
|
||||
Cal.calendar(freq="1min")
|
||||
get_calendar_day(freq="1min")
|
||||
|
||||
def _gen_dataframe(self, config, datasets=["train", "valid", "test"]):
|
||||
try:
|
||||
path = config.pop("path")
|
||||
except KeyError as e:
|
||||
raise ValueError("Must specify the path to save the dataset.") from e
|
||||
if os.path.isfile(path):
|
||||
start = time.time()
|
||||
print_log("Dataset exists, load from disk.", __name__)
|
||||
|
||||
# res = dataset.prepare(['train', 'valid', 'test'])
|
||||
with open(path, "rb") as f:
|
||||
data = pkl.load(f)
|
||||
if isinstance(data, dict):
|
||||
res = [data[i] for i in datasets]
|
||||
else:
|
||||
res = data.prepare(datasets)
|
||||
print_log(f"Data loaded, time cost: {time.time() - start:.2f}", __name__)
|
||||
else:
|
||||
if not os.path.exists(os.path.dirname(path)):
|
||||
os.makedirs(os.path.dirname(path))
|
||||
print_log("Generating dataset", __name__)
|
||||
start_time = time.time()
|
||||
self._prepare_calender_cache()
|
||||
dataset = init_instance_by_config(config)
|
||||
trainset, validset, testset = dataset.prepare(["train", "valid", "test"])
|
||||
data = {
|
||||
"train": trainset,
|
||||
"valid": validset,
|
||||
"test": testset,
|
||||
}
|
||||
with open(path, "wb") as f:
|
||||
pkl.dump(data, f)
|
||||
with open(path[:-4] + "train.pkl", "wb") as f:
|
||||
pkl.dump(trainset, f)
|
||||
with open(path[:-4] + "valid.pkl", "wb") as f:
|
||||
pkl.dump(validset, f)
|
||||
with open(path[:-4] + "test.pkl", "wb") as f:
|
||||
pkl.dump(testset, f)
|
||||
res = [data[i] for i in datasets]
|
||||
print_log(f"Data generated, time cost: {(time.time() - start_time):.2f}", __name__)
|
||||
return res
|
||||
|
||||
def _gen_data(self, config, datasets=["train", "valid", "test"]):
|
||||
try:
|
||||
path = config.pop("path")
|
||||
except KeyError as e:
|
||||
raise ValueError("Must specify the path to save the dataset.") from e
|
||||
if os.path.isfile(path):
|
||||
start = time.time()
|
||||
print_log("Dataset exists, load from disk.", __name__)
|
||||
|
||||
# res = dataset.prepare(['train', 'valid', 'test'])
|
||||
with open(path, "rb") as f:
|
||||
data = pkl.load(f)
|
||||
if isinstance(data, dict):
|
||||
res = [data[i] for i in datasets]
|
||||
else:
|
||||
res = data.prepare(datasets)
|
||||
print_log(f"Data loaded, time cost: {time.time() - start:.2f}", __name__)
|
||||
else:
|
||||
if not os.path.exists(os.path.dirname(path)):
|
||||
os.makedirs(os.path.dirname(path))
|
||||
print_log("Generating dataset", __name__)
|
||||
start_time = time.time()
|
||||
self._prepare_calender_cache()
|
||||
dataset = init_instance_by_config(config)
|
||||
dataset.config(dump_all=True, recursive=True)
|
||||
dataset.to_pickle(path)
|
||||
res = dataset.prepare(datasets)
|
||||
print_log(f"Data generated, time cost: {(time.time() - start_time):.2f}", __name__)
|
||||
return res
|
||||
|
||||
def _gen_dataset(self, config):
|
||||
try:
|
||||
path = config.pop("path")
|
||||
except KeyError as e:
|
||||
raise ValueError("Must specify the path to save the dataset.") from e
|
||||
if os.path.isfile(path):
|
||||
start = time.time()
|
||||
print_log("Dataset exists, load from disk.", __name__)
|
||||
|
||||
with open(path, "rb") as f:
|
||||
dataset = pkl.load(f)
|
||||
print_log(f"Data loaded, time cost: {time.time() - start:.2f}", __name__)
|
||||
else:
|
||||
start = time.time()
|
||||
if not os.path.exists(os.path.dirname(path)):
|
||||
os.makedirs(os.path.dirname(path))
|
||||
print_log("Generating dataset", __name__)
|
||||
self._prepare_calender_cache()
|
||||
dataset = init_instance_by_config(config)
|
||||
print_log(f"Dataset init, time cost: {time.time() - start:.2f}", __name__)
|
||||
dataset.prepare(["train", "valid", "test"])
|
||||
print_log(f"Dataset prepared, time cost: {time.time() - start:.2f}", __name__)
|
||||
dataset.config(dump_all=True, recursive=True)
|
||||
dataset.to_pickle(path)
|
||||
return dataset
|
||||
|
||||
def _gen_day_dataset(self, config, conf_type):
|
||||
try:
|
||||
path = config.pop("path")
|
||||
except KeyError as e:
|
||||
raise ValueError("Must specify the path to save the dataset.") from e
|
||||
|
||||
if os.path.isfile(path + "tmp_dataset.pkl"):
|
||||
start = time.time()
|
||||
print_log("Dataset exists, load from disk.", __name__)
|
||||
else:
|
||||
start = time.time()
|
||||
if not os.path.exists(os.path.dirname(path)):
|
||||
os.makedirs(os.path.dirname(path))
|
||||
print_log("Generating dataset", __name__)
|
||||
self._prepare_calender_cache()
|
||||
dataset = init_instance_by_config(config)
|
||||
print_log(f"Dataset init, time cost: {time.time() - start:.2f}", __name__)
|
||||
dataset.config(dump_all=False, recursive=True)
|
||||
dataset.to_pickle(path + "tmp_dataset.pkl")
|
||||
|
||||
with open(path + "tmp_dataset.pkl", "rb") as f:
|
||||
new_dataset = pkl.load(f)
|
||||
|
||||
time_list = D.calendar(start_time=self.start_time, end_time=self.end_time, freq="1min")[::240]
|
||||
|
||||
def generate_dataset(times):
|
||||
if os.path.isfile(path + times.strftime("%Y-%m-%d") + ".pkl"):
|
||||
print("exist " + times.strftime("%Y-%m-%d"))
|
||||
return
|
||||
self._init_qlib(self.qlib_conf)
|
||||
end_times = times + datetime.timedelta(days=1)
|
||||
new_dataset.handler.config(**{"start_time": times, "end_time": end_times})
|
||||
if conf_type == "backtest":
|
||||
new_dataset.handler.setup_data()
|
||||
else:
|
||||
new_dataset.handler.setup_data(init_type=DataHandlerLP.IT_LS)
|
||||
new_dataset.config(dump_all=True, recursive=True)
|
||||
new_dataset.to_pickle(path + times.strftime("%Y-%m-%d") + ".pkl")
|
||||
|
||||
Parallel(n_jobs=8)(delayed(generate_dataset)(times) for times in time_list)
|
||||
|
||||
def _gen_stock_dataset(self, config, conf_type):
|
||||
try:
|
||||
path = config.pop("path")
|
||||
except KeyError as e:
|
||||
raise ValueError("Must specify the path to save the dataset.") from e
|
||||
|
||||
if os.path.isfile(path + "tmp_dataset.pkl"):
|
||||
start = time.time()
|
||||
print_log("Dataset exists, load from disk.", __name__)
|
||||
else:
|
||||
start = time.time()
|
||||
if not os.path.exists(os.path.dirname(path)):
|
||||
os.makedirs(os.path.dirname(path))
|
||||
print_log("Generating dataset", __name__)
|
||||
self._prepare_calender_cache()
|
||||
dataset = init_instance_by_config(config)
|
||||
print_log(f"Dataset init, time cost: {time.time() - start:.2f}", __name__)
|
||||
dataset.config(dump_all=False, recursive=True)
|
||||
dataset.to_pickle(path + "tmp_dataset.pkl")
|
||||
|
||||
with open(path + "tmp_dataset.pkl", "rb") as f:
|
||||
new_dataset = pkl.load(f)
|
||||
|
||||
instruments = D.instruments(market="all")
|
||||
stock_list = D.list_instruments(
|
||||
instruments=instruments, start_time=self.start_time, end_time=self.end_time, freq="1min", as_list=True
|
||||
)
|
||||
|
||||
def generate_dataset(stock):
|
||||
if os.path.isfile(path + stock + ".pkl"):
|
||||
print("exist " + stock)
|
||||
return
|
||||
self._init_qlib(self.qlib_conf)
|
||||
new_dataset.handler.config(**{"instruments": [stock]})
|
||||
if conf_type == "backtest":
|
||||
new_dataset.handler.setup_data()
|
||||
else:
|
||||
new_dataset.handler.setup_data(init_type=DataHandlerLP.IT_LS)
|
||||
new_dataset.config(dump_all=True, recursive=True)
|
||||
new_dataset.to_pickle(path + stock + ".pkl")
|
||||
|
||||
Parallel(n_jobs=32)(delayed(generate_dataset)(stock) for stock in stock_list)
|
||||
@@ -1,9 +1,6 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import copy
|
||||
|
||||
from ...log import TimeInspector
|
||||
from ...utils.serial import Serializable
|
||||
from ...data.dataset.processor import Processor, get_group_columns
|
||||
|
||||
|
||||
@@ -62,10 +59,10 @@ class ConfigSectionProcessor(Processor):
|
||||
|
||||
# Features
|
||||
cols = df_focus.columns[df_focus.columns.str.contains("^KLEN|^KLOW|^KUP")]
|
||||
df_focus[cols] = df_focus[cols].apply(lambda x: x ** 0.25).groupby(level="datetime").apply(_feature_norm)
|
||||
df_focus[cols] = df_focus[cols].apply(lambda x: x**0.25).groupby(level="datetime").apply(_feature_norm)
|
||||
|
||||
cols = df_focus.columns[df_focus.columns.str.contains("^KLOW2|^KUP2")]
|
||||
df_focus[cols] = df_focus[cols].apply(lambda x: x ** 0.5).groupby(level="datetime").apply(_feature_norm)
|
||||
df_focus[cols] = df_focus[cols].apply(lambda x: x**0.5).groupby(level="datetime").apply(_feature_norm)
|
||||
|
||||
_cols = [
|
||||
"KMID",
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
import pandas as pd
|
||||
from typing import Dict, Iterable
|
||||
from typing import Dict, Iterable, Union
|
||||
|
||||
|
||||
def align_index(df_dict, join):
|
||||
@@ -24,6 +24,10 @@ class SepDataFrame:
|
||||
SepDataFrame tries to act like a DataFrame whose column with multiindex
|
||||
"""
|
||||
|
||||
# TODO:
|
||||
# SepDataFrame try to behave like pandas dataframe, but it is still not them same
|
||||
# Contributions are welcome to make it more complete.
|
||||
|
||||
def __init__(self, df_dict: Dict[str, pd.DataFrame], join: str, skip_align=False):
|
||||
"""
|
||||
initialize the data based on the dataframe dictionary
|
||||
@@ -77,14 +81,37 @@ class SepDataFrame:
|
||||
|
||||
def _update_join(self):
|
||||
if self.join not in self:
|
||||
self.join = next(iter(self._df_dict.keys()))
|
||||
if len(self._df_dict) > 0:
|
||||
self.join = next(iter(self._df_dict.keys()))
|
||||
else:
|
||||
# NOTE: this will change the behavior of previous reindex when all the keys are empty
|
||||
self.join = None
|
||||
|
||||
def __getitem__(self, item):
|
||||
# TODO: behave more like pandas when multiindex
|
||||
return self._df_dict[item]
|
||||
|
||||
def __setitem__(self, item: str, df: pd.DataFrame):
|
||||
def __setitem__(self, item: str, df: Union[pd.DataFrame, pd.Series]):
|
||||
# TODO: consider the join behavior
|
||||
self._df_dict[item] = df
|
||||
if not isinstance(item, tuple):
|
||||
self._df_dict[item] = df
|
||||
else:
|
||||
# NOTE: corner case of MultiIndex
|
||||
_df_dict_key, *col_name = item
|
||||
col_name = tuple(col_name)
|
||||
if _df_dict_key in self._df_dict:
|
||||
if len(col_name) == 1:
|
||||
col_name = col_name[0]
|
||||
self._df_dict[_df_dict_key][col_name] = df
|
||||
else:
|
||||
if isinstance(df, pd.Series):
|
||||
if len(col_name) == 1:
|
||||
col_name = col_name[0]
|
||||
self._df_dict[_df_dict_key] = df.to_frame(col_name)
|
||||
else:
|
||||
df_copy = df.copy() # avoid changing df
|
||||
df_copy.columns = pd.MultiIndex.from_tuples([(*col_name, *idx) for idx in df.columns.to_list()])
|
||||
self._df_dict[_df_dict_key] = df_copy
|
||||
|
||||
def __delitem__(self, item: str):
|
||||
del self._df_dict[item]
|
||||
@@ -164,14 +191,14 @@ import builtins
|
||||
|
||||
|
||||
def _isinstance(instance, cls):
|
||||
if isinstance_orig(instance, SepDataFrame): # pylint: disable=E0602
|
||||
if isinstance_orig(instance, SepDataFrame): # pylint: disable=E0602 # noqa: F821
|
||||
if isinstance(cls, Iterable):
|
||||
for c in cls:
|
||||
if c is pd.DataFrame:
|
||||
return True
|
||||
elif cls is pd.DataFrame:
|
||||
return True
|
||||
return isinstance_orig(instance, cls) # pylint: disable=E0602
|
||||
return isinstance_orig(instance, cls) # pylint: disable=E0602 # noqa: F821
|
||||
|
||||
|
||||
builtins.isinstance_orig = builtins.isinstance
|
||||
|
||||
@@ -4,8 +4,10 @@ Here is a batch of evaluation functions.
|
||||
The interface should be redesigned carefully in the future.
|
||||
"""
|
||||
import pandas as pd
|
||||
|
||||
from typing import Tuple
|
||||
from qlib import get_module_logger
|
||||
from qlib.utils.paral import complex_parallel, DelayedDict
|
||||
from joblib import Parallel, delayed
|
||||
|
||||
|
||||
def calc_long_short_prec(
|
||||
@@ -46,7 +48,9 @@ def calc_long_short_prec(
|
||||
|
||||
group = df.groupby(level=date_col)
|
||||
|
||||
N = lambda x: int(len(x) * quantile)
|
||||
def N(x):
|
||||
return int(len(x) * quantile)
|
||||
|
||||
# find the top/low quantile of prediction and treat them as long and short target
|
||||
long = group.apply(lambda x: x.nlargest(N(x), columns="pred").label).reset_index(level=0, drop=True)
|
||||
short = group.apply(lambda x: x.nsmallest(N(x), columns="pred").label).reset_index(level=0, drop=True)
|
||||
@@ -61,32 +65,6 @@ def calc_long_short_prec(
|
||||
return (l_dom.groupby(date_col).sum() / l_c), (s_dom.groupby(date_col).sum() / s_c)
|
||||
|
||||
|
||||
def calc_ic(pred: pd.Series, label: pd.Series, date_col="datetime", dropna=False) -> Tuple[pd.Series, pd.Series]:
|
||||
"""calc_ic.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
pred :
|
||||
pred
|
||||
label :
|
||||
label
|
||||
date_col :
|
||||
date_col
|
||||
|
||||
Returns
|
||||
-------
|
||||
(pd.Series, pd.Series)
|
||||
ic and rank ic
|
||||
"""
|
||||
df = pd.DataFrame({"pred": pred, "label": label})
|
||||
ic = df.groupby(date_col).apply(lambda df: df["pred"].corr(df["label"]))
|
||||
ric = df.groupby(date_col).apply(lambda df: df["pred"].corr(df["label"], method="spearman"))
|
||||
if dropna:
|
||||
return ic.dropna(), ric.dropna()
|
||||
else:
|
||||
return ic, ric
|
||||
|
||||
|
||||
def calc_long_short_return(
|
||||
pred: pd.Series,
|
||||
label: pd.Series,
|
||||
@@ -122,8 +100,113 @@ def calc_long_short_return(
|
||||
if dropna:
|
||||
df.dropna(inplace=True)
|
||||
group = df.groupby(level=date_col)
|
||||
N = lambda x: int(len(x) * quantile)
|
||||
|
||||
def N(x):
|
||||
return int(len(x) * quantile)
|
||||
|
||||
r_long = group.apply(lambda x: x.nlargest(N(x), columns="pred").label.mean())
|
||||
r_short = group.apply(lambda x: x.nsmallest(N(x), columns="pred").label.mean())
|
||||
r_avg = group.label.mean()
|
||||
return (r_long - r_short) / 2, r_avg
|
||||
|
||||
|
||||
def pred_autocorr(pred: pd.Series, lag=1, inst_col="instrument", date_col="datetime"):
|
||||
"""pred_autocorr.
|
||||
|
||||
Limitation:
|
||||
- If the datetime is not sequential densely, the correlation will be calulated based on adjacent dates. (some users may expected NaN)
|
||||
|
||||
:param pred: pd.Series with following format
|
||||
instrument datetime
|
||||
SH600000 2016-01-04 -0.000403
|
||||
2016-01-05 -0.000753
|
||||
2016-01-06 -0.021801
|
||||
2016-01-07 -0.065230
|
||||
2016-01-08 -0.062465
|
||||
:type pred: pd.Series
|
||||
:param lag:
|
||||
"""
|
||||
if isinstance(pred, pd.DataFrame):
|
||||
pred = pred.iloc[:, 0]
|
||||
get_module_logger("pred_autocorr").warning(f"Only the first column in {pred.columns} of `pred` is kept")
|
||||
pred_ustk = pred.sort_index().unstack(inst_col)
|
||||
corr_s = {}
|
||||
for (idx, cur), (_, prev) in zip(pred_ustk.iterrows(), pred_ustk.shift(lag).iterrows()):
|
||||
corr_s[idx] = cur.corr(prev)
|
||||
corr_s = pd.Series(corr_s).sort_index()
|
||||
return corr_s
|
||||
|
||||
|
||||
def pred_autocorr_all(pred_dict, n_jobs=-1, **kwargs):
|
||||
"""
|
||||
calculate auto correlation for pred_dict
|
||||
|
||||
Parameters
|
||||
----------
|
||||
pred_dict : dict
|
||||
A dict like {<method_name>: <prediction>}
|
||||
kwargs :
|
||||
all these arguments will be passed into pred_autocorr
|
||||
"""
|
||||
ac_dict = {}
|
||||
for k, pred in pred_dict.items():
|
||||
ac_dict[k] = delayed(pred_autocorr)(pred, **kwargs)
|
||||
return complex_parallel(Parallel(n_jobs=n_jobs, verbose=10), ac_dict)
|
||||
|
||||
|
||||
def calc_ic(pred: pd.Series, label: pd.Series, date_col="datetime", dropna=False) -> (pd.Series, pd.Series):
|
||||
"""calc_ic.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
pred :
|
||||
pred
|
||||
label :
|
||||
label
|
||||
date_col :
|
||||
date_col
|
||||
|
||||
Returns
|
||||
-------
|
||||
(pd.Series, pd.Series)
|
||||
ic and rank ic
|
||||
"""
|
||||
df = pd.DataFrame({"pred": pred, "label": label})
|
||||
ic = df.groupby(date_col).apply(lambda df: df["pred"].corr(df["label"]))
|
||||
ric = df.groupby(date_col).apply(lambda df: df["pred"].corr(df["label"], method="spearman"))
|
||||
if dropna:
|
||||
return ic.dropna(), ric.dropna()
|
||||
else:
|
||||
return ic, ric
|
||||
|
||||
|
||||
def calc_all_ic(pred_dict_all, label, date_col="datetime", dropna=False, n_jobs=-1):
|
||||
"""calc_all_ic.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
pred_dict_all :
|
||||
A dict like {<method_name>: <prediction>}
|
||||
label:
|
||||
A pd.Series of label values
|
||||
|
||||
Returns
|
||||
-------
|
||||
{'Q2+IND_z': {'ic': <ic series like>
|
||||
2016-01-04 -0.057407
|
||||
...
|
||||
2020-05-28 0.183470
|
||||
2020-05-29 0.171393
|
||||
'ric': <rank ic series like>
|
||||
2016-01-04 -0.040888
|
||||
...
|
||||
2020-05-28 0.236665
|
||||
2020-05-29 0.183886
|
||||
}
|
||||
...}
|
||||
"""
|
||||
pred_all_ics = {}
|
||||
for k, pred in pred_dict_all.items():
|
||||
pred_all_ics[k] = DelayedDict(["ic", "ric"], delayed(calc_ic)(pred, label, date_col=date_col, dropna=dropna))
|
||||
pred_all_ics = complex_parallel(Parallel(n_jobs=n_jobs, verbose=10), pred_all_ics)
|
||||
return pred_all_ics
|
||||
|
||||
@@ -26,6 +26,13 @@ logger = get_module_logger("Evaluate")
|
||||
|
||||
def risk_analysis(r, N: int = None, freq: str = "day"):
|
||||
"""Risk Analysis
|
||||
NOTE:
|
||||
The calculation of annulaized return is different from the definition of annualized return.
|
||||
It is implemented by design.
|
||||
Qlib tries to cumulated returns by summation instead of production to avoid the cumulated curve being skewed exponentially.
|
||||
All the calculation of annualized returns follows this principle in Qlib.
|
||||
|
||||
TODO: add a parameter to enable calculating metrics with production accumulation of return.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -332,7 +339,7 @@ def long_short_backtest(
|
||||
for stock in long_stocks:
|
||||
if not trade_exchange.is_stock_tradable(stock_id=stock, trade_date=date):
|
||||
continue
|
||||
profit = trade_exchange.get_quote_info(stock_id=stock, trade_date=date)[profit_str]
|
||||
profit = trade_exchange.get_quote_info(stock_id=stock, start_time=date, end_time=date, field=profit_str)
|
||||
if np.isnan(profit):
|
||||
long_profit.append(0)
|
||||
else:
|
||||
@@ -341,17 +348,17 @@ def long_short_backtest(
|
||||
for stock in short_stocks:
|
||||
if not trade_exchange.is_stock_tradable(stock_id=stock, trade_date=date):
|
||||
continue
|
||||
profit = trade_exchange.get_quote_info(stock_id=stock, trade_date=date)[profit_str]
|
||||
profit = trade_exchange.get_quote_info(stock_id=stock, start_time=date, end_time=date, field=profit_str)
|
||||
if np.isnan(profit):
|
||||
short_profit.append(0)
|
||||
else:
|
||||
short_profit.append(-profit)
|
||||
short_profit.append(profit * -1)
|
||||
|
||||
for stock in list(score.loc(axis=0)[pdate, :].index.get_level_values(level=0)):
|
||||
# exclude the suspend stock
|
||||
if trade_exchange.check_stock_suspended(stock_id=stock, trade_date=date):
|
||||
continue
|
||||
profit = trade_exchange.get_quote_info(stock_id=stock, trade_date=date)[profit_str]
|
||||
profit = trade_exchange.get_quote_info(stock_id=stock, start_time=date, end_time=date, field=profit_str)
|
||||
if np.isnan(profit):
|
||||
all_profit.append(0)
|
||||
else:
|
||||
|
||||
@@ -5,12 +5,10 @@
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import copy
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from scipy.stats import spearmanr, pearsonr
|
||||
|
||||
|
||||
from ..data import D
|
||||
|
||||
from collections import OrderedDict
|
||||
@@ -243,4 +241,4 @@ def get_rank_ic(a, b):
|
||||
|
||||
|
||||
def get_normal_ic(a, b):
|
||||
return pearsonr(a, b).correlation
|
||||
return pearsonr(a, b)[0]
|
||||
|
||||
@@ -2,3 +2,6 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from .data_selection import MetaTaskDS, MetaDatasetDS, MetaModelDS
|
||||
|
||||
|
||||
__all__ = ["MetaTaskDS", "MetaDatasetDS", "MetaModelDS"]
|
||||
|
||||
@@ -3,3 +3,6 @@
|
||||
|
||||
from .dataset import MetaDatasetDS, MetaTaskDS
|
||||
from .model import MetaModelDS
|
||||
|
||||
|
||||
__all__ = ["MetaDatasetDS", "MetaTaskDS", "MetaModelDS"]
|
||||
|
||||
@@ -1,24 +1,23 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
from copy import deepcopy
|
||||
from qlib.data.dataset.utils import init_task_handler
|
||||
from qlib.utils.data import deepcopy_basic_type
|
||||
from qlib.contrib.torch import data_to_tensor
|
||||
from qlib.workflow.task.utils import TimeAdjuster
|
||||
from qlib.model.meta.task import MetaTask
|
||||
from typing import Dict, List, Union, Text, Tuple
|
||||
from qlib.data.dataset.handler import DataHandler
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.utils import auto_filter_kwargs, get_date_by_shift, init_instance_by_config
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.task.gen import RollingGen, task_generator
|
||||
from joblib import Parallel, delayed
|
||||
from qlib.model.meta.dataset import MetaTaskDataset
|
||||
from qlib.model.trainer import task_train, TrainerR
|
||||
from qlib.data.dataset import DatasetH
|
||||
from tqdm.auto import tqdm
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from copy import deepcopy
|
||||
from joblib import Parallel, delayed # pylint: disable=E0401
|
||||
from typing import Dict, List, Union, Text, Tuple
|
||||
from qlib.data.dataset.utils import init_task_handler
|
||||
from qlib.data.dataset import DatasetH
|
||||
from qlib.contrib.torch import data_to_tensor
|
||||
from qlib.model.meta.task import MetaTask
|
||||
from qlib.model.meta.dataset import MetaTaskDataset
|
||||
from qlib.model.trainer import TrainerR
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.utils import auto_filter_kwargs, get_date_by_shift, init_instance_by_config
|
||||
from qlib.utils.data import deepcopy_basic_type
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.task.gen import RollingGen, task_generator
|
||||
from qlib.workflow.task.utils import TimeAdjuster
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
|
||||
class InternalData:
|
||||
@@ -218,7 +217,7 @@ class MetaDatasetDS(MetaTaskDataset):
|
||||
----------
|
||||
task_tpl : Union[dict, list]
|
||||
Decide what tasks are used.
|
||||
- dict : the task template, the prepared task is generated with `step`, `trunc_days` and `RollingGen`
|
||||
- dict : the task template, the prepared task is generated with `step`, `trunc_days` and `RollingGen`
|
||||
- list : when list, use the list of tasks directly
|
||||
the list is supposed to be sorted according timeline
|
||||
step : int
|
||||
@@ -291,7 +290,7 @@ class MetaDatasetDS(MetaTaskDataset):
|
||||
ic_df = self.internal_data.data_ic_df
|
||||
|
||||
segs = task["dataset"]["kwargs"]["segments"]
|
||||
end = max([segs[k][1] for k in ("train", "valid") if k in segs])
|
||||
end = max(segs[k][1] for k in ("train", "valid") if k in segs)
|
||||
ic_df_avail = ic_df.loc[:end, pd.IndexSlice[:, :end]]
|
||||
|
||||
# meta data set focus on the **information** instead of preprocess
|
||||
|
||||
@@ -1,28 +1,25 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from qlib.log import get_module_logger
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from qlib.model.meta.task import MetaTask
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch import optim
|
||||
from tqdm.auto import tqdm
|
||||
import collections
|
||||
import copy
|
||||
from typing import Union, List, Tuple, Dict
|
||||
from typing import Union, List
|
||||
|
||||
from ....data.dataset.weight import Reweighter
|
||||
from ....model.meta.dataset import MetaTaskDataset
|
||||
from ....model.meta.model import MetaModel, MetaTaskModel
|
||||
from ....model.meta.model import MetaTaskModel
|
||||
from ....workflow import R
|
||||
|
||||
from .utils import ICLoss
|
||||
from .dataset import MetaDatasetDS
|
||||
from qlib.contrib.meta.data_selection.net import PredNet
|
||||
from qlib.data.dataset.weight import Reweighter
|
||||
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.model.meta.task import MetaTask
|
||||
from qlib.data.dataset.weight import Reweighter
|
||||
from qlib.contrib.meta.data_selection.net import PredNet
|
||||
|
||||
logger = get_module_logger("data selection")
|
||||
|
||||
@@ -100,7 +97,6 @@ class MetaModelDS(MetaTaskModel):
|
||||
|
||||
if phase == "train":
|
||||
opt.zero_grad()
|
||||
norm_loss = nn.MSELoss()
|
||||
loss.backward()
|
||||
opt.step()
|
||||
elif phase == "test":
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
@@ -1,16 +1,17 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
from qlib.contrib.torch import data_to_tensor
|
||||
|
||||
|
||||
class ICLoss(nn.Module):
|
||||
def forward(self, pred, y, idx, skip_size=50):
|
||||
"""forward.
|
||||
FIXME:
|
||||
- Some times it will be a slightly different from the result from `pandas.corr()`
|
||||
- It may be caused by the precision problem of model;
|
||||
|
||||
:param pred:
|
||||
:param y:
|
||||
|
||||
@@ -10,17 +10,19 @@ try:
|
||||
from .gbdt import LGBModel
|
||||
except ModuleNotFoundError:
|
||||
DEnsembleModel, LGBModel = None, None
|
||||
print("Please install necessary libs for DEnsembleModel and LGBModel, such as lightgbm.")
|
||||
print(
|
||||
"ModuleNotFoundError. DEnsembleModel and LGBModel are skipped. (optional: maybe installing lightgbm can fix it.)"
|
||||
)
|
||||
try:
|
||||
from .xgboost import XGBModel
|
||||
except ModuleNotFoundError:
|
||||
XGBModel = None
|
||||
print("Please install necessary libs for XGBModel, such as xgboost.")
|
||||
print("ModuleNotFoundError. XGBModel is skipped(optional: maybe installing xgboost can fix it).")
|
||||
try:
|
||||
from .linear import LinearModel
|
||||
except ModuleNotFoundError:
|
||||
LinearModel = None
|
||||
print("Please install necessary libs for LinearModel, such as scipy and sklearn.")
|
||||
print("ModuleNotFoundError. LinearModel is skipped(optional: maybe installing scipy and sklearn can fix it).")
|
||||
# import pytorch models
|
||||
try:
|
||||
from .pytorch_alstm import ALSTM
|
||||
@@ -36,6 +38,6 @@ try:
|
||||
pytorch_classes = (ALSTM, GATs, GRU, LSTM, DNNModelPytorch, TabnetModel, SFM_Model, TCN, ADD)
|
||||
except ModuleNotFoundError:
|
||||
pytorch_classes = ()
|
||||
print("Please install necessary libs for PyTorch models.")
|
||||
print("ModuleNotFoundError. PyTorch models are skipped (optional: maybe installing pytorch can fix it).")
|
||||
|
||||
all_model_classes = (CatBoostModel, DEnsembleModel, LGBModel, XGBModel, LinearModel) + pytorch_classes
|
||||
|
||||
@@ -160,7 +160,7 @@ class DEnsembleModel(Model, FeatureInt):
|
||||
h_avg = h.groupby("bins")["h_value"].mean()
|
||||
weights = pd.Series(np.zeros(N, dtype=float))
|
||||
for i_b, b in enumerate(h_avg.index):
|
||||
weights[h["bins"] == b] = 1.0 / (self.decay ** k_th * h_avg[i_b] + 0.1)
|
||||
weights[h["bins"] == b] = 1.0 / (self.decay**k_th * h_avg[i_b] + 0.1)
|
||||
return weights
|
||||
|
||||
def feature_selection(self, df_train, loss_values):
|
||||
@@ -249,7 +249,7 @@ class DEnsembleModel(Model, FeatureInt):
|
||||
return pred
|
||||
|
||||
def predict_sub(self, submodel, df_data, features):
|
||||
x_data, y_data = df_data["feature"].loc[:, features], df_data["label"]
|
||||
x_data = df_data["feature"].loc[:, features]
|
||||
pred_sub = pd.Series(submodel.predict(x_data.values), index=x_data.index)
|
||||
return pred_sub
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
from ...model.interpret.base import LightGBMFInt
|
||||
from ...data.dataset.weight import Reweighter
|
||||
from qlib.workflow import R
|
||||
|
||||
|
||||
class LGBModel(ModelFT, LightGBMFInt):
|
||||
@@ -59,27 +60,34 @@ class LGBModel(ModelFT, LightGBMFInt):
|
||||
num_boost_round=None,
|
||||
early_stopping_rounds=None,
|
||||
verbose_eval=20,
|
||||
evals_result=dict(),
|
||||
evals_result=None,
|
||||
reweighter=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
if evals_result is None:
|
||||
evals_result = {} # in case of unsafety of Python default values
|
||||
ds_l = self._prepare_data(dataset, reweighter)
|
||||
ds, names = list(zip(*ds_l))
|
||||
early_stopping_callback = lgb.early_stopping(
|
||||
self.early_stopping_rounds if early_stopping_rounds is None else early_stopping_rounds
|
||||
)
|
||||
# NOTE: if you encounter error here. Please upgrade your lightgbm
|
||||
verbose_eval_callback = lgb.log_evaluation(period=verbose_eval)
|
||||
evals_result_callback = lgb.record_evaluation(evals_result)
|
||||
self.model = lgb.train(
|
||||
self.params,
|
||||
ds[0], # training dataset
|
||||
num_boost_round=self.num_boost_round if num_boost_round is None else num_boost_round,
|
||||
valid_sets=ds,
|
||||
valid_names=names,
|
||||
early_stopping_rounds=(
|
||||
self.early_stopping_rounds if early_stopping_rounds is None else early_stopping_rounds
|
||||
),
|
||||
verbose_eval=verbose_eval,
|
||||
evals_result=evals_result,
|
||||
**kwargs
|
||||
callbacks=[early_stopping_callback, verbose_eval_callback, evals_result_callback],
|
||||
**kwargs,
|
||||
)
|
||||
for k in names:
|
||||
evals_result[k] = list(evals_result[k].values())[0]
|
||||
for key, val in evals_result[k].items():
|
||||
name = f"{key}.{k}"
|
||||
for epoch, m in enumerate(val):
|
||||
R.log_metrics(**{name.replace("@", "_"): m}, step=epoch)
|
||||
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
if self.model is None:
|
||||
@@ -101,9 +109,10 @@ class LGBModel(ModelFT, LightGBMFInt):
|
||||
verbose level
|
||||
"""
|
||||
# Based on existing model and finetune by train more rounds
|
||||
dtrain, _ = self._prepare_data(dataset, reweighter)
|
||||
dtrain, _ = self._prepare_data(dataset, reweighter) # pylint: disable=W0632
|
||||
if dtrain.empty:
|
||||
raise ValueError("Empty data from dataset, please check your dataset config.")
|
||||
verbose_eval_callback = lgb.log_evaluation(period=verbose_eval)
|
||||
self.model = lgb.train(
|
||||
self.params,
|
||||
dtrain,
|
||||
@@ -111,5 +120,5 @@ class LGBModel(ModelFT, LightGBMFInt):
|
||||
init_model=self.model,
|
||||
valid_sets=[dtrain],
|
||||
valid_names=["train"],
|
||||
verbose_eval=verbose_eval,
|
||||
callbacks=[verbose_eval_callback],
|
||||
)
|
||||
|
||||
@@ -58,7 +58,7 @@ class HFLGBModel(ModelFT, LightGBMFInt):
|
||||
"""
|
||||
Test the signal in high frequency test set
|
||||
"""
|
||||
if self.model == None:
|
||||
if self.model is None:
|
||||
raise ValueError("Model hasn't been trained yet")
|
||||
df_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)
|
||||
df_test.dropna(inplace=True)
|
||||
@@ -92,7 +92,10 @@ class HFLGBModel(ModelFT, LightGBMFInt):
|
||||
# Convert label into alpha
|
||||
df_train["label"][l_name] = df_train["label"][l_name] - df_train["label"][l_name].mean(level=0)
|
||||
df_valid["label"][l_name] = df_valid["label"][l_name] - df_valid["label"][l_name].mean(level=0)
|
||||
mapping_fn = lambda x: 0 if x < 0 else 1
|
||||
|
||||
def mapping_fn(x):
|
||||
return 0 if x < 0 else 1
|
||||
|
||||
df_train["label_c"] = df_train["label"][l_name].apply(mapping_fn)
|
||||
df_valid["label_c"] = df_valid["label"][l_name].apply(mapping_fn)
|
||||
x_train, y_train = df_train["feature"], df_train["label_c"].values
|
||||
@@ -110,20 +113,21 @@ class HFLGBModel(ModelFT, LightGBMFInt):
|
||||
num_boost_round=1000,
|
||||
early_stopping_rounds=50,
|
||||
verbose_eval=20,
|
||||
evals_result=dict(),
|
||||
**kwargs
|
||||
evals_result=None,
|
||||
):
|
||||
if evals_result is None:
|
||||
evals_result = dict()
|
||||
dtrain, dvalid = self._prepare_data(dataset)
|
||||
early_stopping_callback = lgb.early_stopping(early_stopping_rounds)
|
||||
verbose_eval_callback = lgb.log_evaluation(period=verbose_eval)
|
||||
evals_result_callback = lgb.record_evaluation(evals_result)
|
||||
self.model = lgb.train(
|
||||
self.params,
|
||||
dtrain,
|
||||
num_boost_round=num_boost_round,
|
||||
valid_sets=[dtrain, dvalid],
|
||||
valid_names=["train", "valid"],
|
||||
early_stopping_rounds=early_stopping_rounds,
|
||||
verbose_eval=verbose_eval,
|
||||
evals_result=evals_result,
|
||||
**kwargs
|
||||
callbacks=[early_stopping_callback, verbose_eval_callback, evals_result_callback],
|
||||
)
|
||||
evals_result["train"] = list(evals_result["train"].values())[0]
|
||||
evals_result["valid"] = list(evals_result["valid"].values())[0]
|
||||
@@ -149,6 +153,7 @@ class HFLGBModel(ModelFT, LightGBMFInt):
|
||||
"""
|
||||
# Based on existing model and finetune by train more rounds
|
||||
dtrain, _ = self._prepare_data(dataset)
|
||||
verbose_eval_callback = lgb.log_evaluation(period=verbose_eval)
|
||||
self.model = lgb.train(
|
||||
self.params,
|
||||
dtrain,
|
||||
@@ -156,5 +161,5 @@ class HFLGBModel(ModelFT, LightGBMFInt):
|
||||
init_model=self.model,
|
||||
valid_sets=[dtrain],
|
||||
valid_names=["train"],
|
||||
verbose_eval=verbose_eval,
|
||||
callbacks=[verbose_eval_callback],
|
||||
)
|
||||
|
||||
@@ -1,12 +1,10 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
import os
|
||||
from pdb import set_trace
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
|
||||
import copy
|
||||
from typing import Text, Union
|
||||
|
||||
import math
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
@@ -146,7 +144,7 @@ class ADARNN(Model):
|
||||
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
|
||||
|
||||
self.fitted = False
|
||||
self.model.cuda()
|
||||
self.model.to(self.device)
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
@@ -155,7 +153,7 @@ class ADARNN(Model):
|
||||
def train_AdaRNN(self, train_loader_list, epoch, dist_old=None, weight_mat=None):
|
||||
self.model.train()
|
||||
criterion = nn.MSELoss()
|
||||
dist_mat = torch.zeros(self.num_layers, self.len_seq).cuda()
|
||||
dist_mat = torch.zeros(self.num_layers, self.len_seq).to(self.device)
|
||||
len_loader = np.inf
|
||||
for loader in train_loader_list:
|
||||
if len(loader) < len_loader:
|
||||
@@ -167,7 +165,7 @@ class ADARNN(Model):
|
||||
list_label = []
|
||||
for data in data_all:
|
||||
# feature :[36, 24, 6]
|
||||
feature, label_reg = data[0].cuda().float(), data[1].cuda().float()
|
||||
feature, label_reg = data[0].to(self.device).float(), data[1].to(self.device).float()
|
||||
list_feat.append(feature)
|
||||
list_label.append(label_reg)
|
||||
flag = False
|
||||
@@ -181,12 +179,12 @@ class ADARNN(Model):
|
||||
if flag:
|
||||
continue
|
||||
|
||||
total_loss = torch.zeros(1).cuda()
|
||||
for i in range(len(index)):
|
||||
feature_s = list_feat[index[i][0]]
|
||||
feature_t = list_feat[index[i][1]]
|
||||
label_reg_s = list_label[index[i][0]]
|
||||
label_reg_t = list_label[index[i][1]]
|
||||
total_loss = torch.zeros(1).to(self.device)
|
||||
for i, n in enumerate(index):
|
||||
feature_s = list_feat[n[0]]
|
||||
feature_t = list_feat[n[1]]
|
||||
label_reg_s = list_label[n[0]]
|
||||
label_reg_t = list_label[n[1]]
|
||||
feature_all = torch.cat((feature_s, feature_t), 0)
|
||||
|
||||
if epoch < self.pre_epoch:
|
||||
@@ -327,7 +325,7 @@ class ADARNN(Model):
|
||||
else:
|
||||
end = begin + self.batch_size
|
||||
|
||||
x_batch = torch.from_numpy(x_values[begin:end]).float().cuda()
|
||||
x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
pred = self.model.predict(x_batch).detach().cpu().numpy()
|
||||
@@ -337,7 +335,7 @@ class ADARNN(Model):
|
||||
return pd.Series(np.concatenate(preds), index=index)
|
||||
|
||||
def transform_type(self, init_weight):
|
||||
weight = torch.ones(self.num_layers, self.len_seq).cuda()
|
||||
weight = torch.ones(self.num_layers, self.len_seq).to(self.device)
|
||||
for i in range(self.num_layers):
|
||||
for j in range(self.len_seq):
|
||||
weight[i, j] = init_weight[i][j].item()
|
||||
@@ -391,6 +389,7 @@ class AdaRNN(nn.Module):
|
||||
len_seq=9,
|
||||
model_type="AdaRNN",
|
||||
trans_loss="mmd",
|
||||
GPU=0,
|
||||
):
|
||||
super(AdaRNN, self).__init__()
|
||||
self.use_bottleneck = use_bottleneck
|
||||
@@ -401,6 +400,7 @@ class AdaRNN(nn.Module):
|
||||
self.model_type = model_type
|
||||
self.trans_loss = trans_loss
|
||||
self.len_seq = len_seq
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
in_size = self.n_input
|
||||
|
||||
features = nn.ModuleList()
|
||||
@@ -410,7 +410,7 @@ class AdaRNN(nn.Module):
|
||||
in_size = hidden
|
||||
self.features = nn.Sequential(*features)
|
||||
|
||||
if use_bottleneck == True: # finance
|
||||
if use_bottleneck is True: # finance
|
||||
self.bottleneck = nn.Sequential(
|
||||
nn.Linear(n_hiddens[-1], bottleneck_width),
|
||||
nn.Linear(bottleneck_width, bottleneck_width),
|
||||
@@ -449,7 +449,7 @@ class AdaRNN(nn.Module):
|
||||
def forward_pre_train(self, x, len_win=0):
|
||||
out = self.gru_features(x)
|
||||
fea = out[0] # [2N,L,H]
|
||||
if self.use_bottleneck == True:
|
||||
if self.use_bottleneck is True:
|
||||
fea_bottleneck = self.bottleneck(fea[:, -1, :])
|
||||
fc_out = self.fc(fea_bottleneck).squeeze()
|
||||
else:
|
||||
@@ -457,9 +457,9 @@ class AdaRNN(nn.Module):
|
||||
|
||||
out_list_all, out_weight_list = out[1], out[2]
|
||||
out_list_s, out_list_t = self.get_features(out_list_all)
|
||||
loss_transfer = torch.zeros((1,)).cuda()
|
||||
for i in range(len(out_list_s)):
|
||||
criterion_transder = TransferLoss(loss_type=self.trans_loss, input_dim=out_list_s[i].shape[2])
|
||||
loss_transfer = torch.zeros((1,)).to(self.device)
|
||||
for i, n in enumerate(out_list_s):
|
||||
criterion_transder = TransferLoss(loss_type=self.trans_loss, input_dim=n.shape[2])
|
||||
h_start = 0
|
||||
for j in range(h_start, self.len_seq, 1):
|
||||
i_start = j - len_win if j - len_win >= 0 else 0
|
||||
@@ -471,7 +471,7 @@ class AdaRNN(nn.Module):
|
||||
else 1 / (self.len_seq - h_start) * (2 * len_win + 1)
|
||||
)
|
||||
loss_transfer = loss_transfer + weight * criterion_transder.compute(
|
||||
out_list_s[i][:, j, :], out_list_t[i][:, k, :]
|
||||
n[:, j, :], out_list_t[i][:, k, :]
|
||||
)
|
||||
return fc_out, loss_transfer, out_weight_list
|
||||
|
||||
@@ -484,7 +484,7 @@ class AdaRNN(nn.Module):
|
||||
out, _ = self.features[i](x_input.float())
|
||||
x_input = out
|
||||
out_lis.append(out)
|
||||
if self.model_type == "AdaRNN" and predict == False:
|
||||
if self.model_type == "AdaRNN" and predict is False:
|
||||
out_gate = self.process_gate_weight(x_input, i)
|
||||
out_weight_list.append(out_gate)
|
||||
return out, out_lis, out_weight_list
|
||||
@@ -518,16 +518,16 @@ class AdaRNN(nn.Module):
|
||||
|
||||
out_list_all = out[1]
|
||||
out_list_s, out_list_t = self.get_features(out_list_all)
|
||||
loss_transfer = torch.zeros((1,)).cuda()
|
||||
loss_transfer = torch.zeros((1,)).to(self.device)
|
||||
if weight_mat is None:
|
||||
weight = (1.0 / self.len_seq * torch.ones(self.num_layers, self.len_seq)).cuda()
|
||||
weight = (1.0 / self.len_seq * torch.ones(self.num_layers, self.len_seq)).to(self.device)
|
||||
else:
|
||||
weight = weight_mat
|
||||
dist_mat = torch.zeros(self.num_layers, self.len_seq).cuda()
|
||||
for i in range(len(out_list_s)):
|
||||
criterion_transder = TransferLoss(loss_type=self.trans_loss, input_dim=out_list_s[i].shape[2])
|
||||
dist_mat = torch.zeros(self.num_layers, self.len_seq).to(self.device)
|
||||
for i, n in enumerate(out_list_s):
|
||||
criterion_transder = TransferLoss(loss_type=self.trans_loss, input_dim=n.shape[2])
|
||||
for j in range(self.len_seq):
|
||||
loss_trans = criterion_transder.compute(out_list_s[i][:, j, :], out_list_t[i][:, j, :])
|
||||
loss_trans = criterion_transder.compute(n[:, j, :], out_list_t[i][:, j, :])
|
||||
loss_transfer = loss_transfer + weight[i, j] * loss_trans
|
||||
dist_mat[i, j] = loss_trans
|
||||
return fc_out, loss_transfer, dist_mat, weight
|
||||
@@ -546,7 +546,7 @@ class AdaRNN(nn.Module):
|
||||
def predict(self, x):
|
||||
out = self.gru_features(x, predict=True)
|
||||
fea = out[0]
|
||||
if self.use_bottleneck == True:
|
||||
if self.use_bottleneck is True:
|
||||
fea_bottleneck = self.bottleneck(fea[:, -1, :])
|
||||
fc_out = self.fc(fea_bottleneck).squeeze()
|
||||
else:
|
||||
@@ -555,12 +555,13 @@ class AdaRNN(nn.Module):
|
||||
|
||||
|
||||
class TransferLoss:
|
||||
def __init__(self, loss_type="cosine", input_dim=512):
|
||||
def __init__(self, loss_type="cosine", input_dim=512, GPU=0):
|
||||
"""
|
||||
Supported loss_type: mmd(mmd_lin), mmd_rbf, coral, cosine, kl, js, mine, adv
|
||||
"""
|
||||
self.loss_type = loss_type
|
||||
self.input_dim = input_dim
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
|
||||
def compute(self, X, Y):
|
||||
"""Compute adaptation loss
|
||||
@@ -572,22 +573,22 @@ class TransferLoss:
|
||||
Returns:
|
||||
[tensor] -- transfer loss
|
||||
"""
|
||||
if self.loss_type == "mmd_lin" or self.loss_type == "mmd":
|
||||
if self.loss_type in ("mmd_lin", "mmd"):
|
||||
mmdloss = MMD_loss(kernel_type="linear")
|
||||
loss = mmdloss(X, Y)
|
||||
elif self.loss_type == "coral":
|
||||
loss = CORAL(X, Y)
|
||||
elif self.loss_type == "cosine" or self.loss_type == "cos":
|
||||
loss = CORAL(X, Y, self.device)
|
||||
elif self.loss_type in ("cosine", "cos"):
|
||||
loss = 1 - cosine(X, Y)
|
||||
elif self.loss_type == "kl":
|
||||
loss = kl_div(X, Y)
|
||||
elif self.loss_type == "js":
|
||||
loss = js(X, Y)
|
||||
elif self.loss_type == "mine":
|
||||
mine_model = Mine_estimator(input_dim=self.input_dim, hidden_dim=60).cuda()
|
||||
mine_model = Mine_estimator(input_dim=self.input_dim, hidden_dim=60).to(self.device)
|
||||
loss = mine_model(X, Y)
|
||||
elif self.loss_type == "adv":
|
||||
loss = adv(X, Y, input_dim=self.input_dim, hidden_dim=32)
|
||||
loss = adv(X, Y, self.device, input_dim=self.input_dim, hidden_dim=32)
|
||||
elif self.loss_type == "mmd_rbf":
|
||||
mmdloss = MMD_loss(kernel_type="rbf")
|
||||
loss = mmdloss(X, Y)
|
||||
@@ -632,12 +633,12 @@ class Discriminator(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
def adv(source, target, input_dim=256, hidden_dim=512):
|
||||
def adv(source, target, device, input_dim=256, hidden_dim=512):
|
||||
domain_loss = nn.BCELoss()
|
||||
# !!! Pay attention to .cuda !!!
|
||||
adv_net = Discriminator(input_dim, hidden_dim).cuda()
|
||||
domain_src = torch.ones(len(source)).cuda()
|
||||
domain_tar = torch.zeros(len(target)).cuda()
|
||||
adv_net = Discriminator(input_dim, hidden_dim).to(device)
|
||||
domain_src = torch.ones(len(source)).to(device)
|
||||
domain_tar = torch.zeros(len(target)).to(device)
|
||||
domain_src, domain_tar = domain_src.view(domain_src.shape[0], 1), domain_tar.view(domain_tar.shape[0], 1)
|
||||
reverse_src = ReverseLayerF.apply(source, 1)
|
||||
reverse_tar = ReverseLayerF.apply(target, 1)
|
||||
@@ -648,16 +649,16 @@ def adv(source, target, input_dim=256, hidden_dim=512):
|
||||
return loss
|
||||
|
||||
|
||||
def CORAL(source, target):
|
||||
def CORAL(source, target, device):
|
||||
d = source.size(1)
|
||||
ns, nt = source.size(0), target.size(0)
|
||||
|
||||
# source covariance
|
||||
tmp_s = torch.ones((1, ns)).cuda() @ source
|
||||
tmp_s = torch.ones((1, ns)).to(device) @ source
|
||||
cs = (source.t() @ source - (tmp_s.t() @ tmp_s) / ns) / (ns - 1)
|
||||
|
||||
# target covariance
|
||||
tmp_t = torch.ones((1, nt)).cuda() @ target
|
||||
tmp_t = torch.ones((1, nt)).to(device) @ target
|
||||
ct = (target.t() @ target - (tmp_t.t() @ tmp_t) / nt) / (nt - 1)
|
||||
|
||||
# frobenius norm
|
||||
@@ -684,9 +685,9 @@ class MMD_loss(nn.Module):
|
||||
if fix_sigma:
|
||||
bandwidth = fix_sigma
|
||||
else:
|
||||
bandwidth = torch.sum(L2_distance.data) / (n_samples ** 2 - n_samples)
|
||||
bandwidth = torch.sum(L2_distance.data) / (n_samples**2 - n_samples)
|
||||
bandwidth /= kernel_mul ** (kernel_num // 2)
|
||||
bandwidth_list = [bandwidth * (kernel_mul ** i) for i in range(kernel_num)]
|
||||
bandwidth_list = [bandwidth * (kernel_mul**i) for i in range(kernel_num)]
|
||||
kernel_val = [torch.exp(-L2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list]
|
||||
return sum(kernel_val)
|
||||
|
||||
|
||||
@@ -20,7 +20,6 @@ from qlib.contrib.model.pytorch_lstm import LSTMModel
|
||||
from qlib.contrib.model.pytorch_utils import count_parameters
|
||||
from qlib.data.dataset import DatasetH
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
from qlib.data.dataset.processor import CSRankNorm
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.model.base import Model
|
||||
from qlib.utils import get_or_create_path
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Text, Union
|
||||
@@ -150,7 +149,7 @@ class ALSTM(Model):
|
||||
|
||||
mask = torch.isfinite(label)
|
||||
|
||||
if self.metric == "" or self.metric == "loss":
|
||||
if self.metric in ("", "loss"):
|
||||
return -self.loss_fn(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown metric `%s`" % self.metric)
|
||||
@@ -312,8 +311,8 @@ class ALSTMModel(nn.Module):
|
||||
def _build_model(self):
|
||||
try:
|
||||
klass = getattr(nn, self.rnn_type.upper())
|
||||
except:
|
||||
raise ValueError("unknown rnn_type `%s`" % self.rnn_type)
|
||||
except Exception as e:
|
||||
raise ValueError("unknown rnn_type `%s`" % self.rnn_type) from e
|
||||
self.net = nn.Sequential()
|
||||
self.net.add_module("fc_in", nn.Linear(in_features=self.input_size, out_features=self.hid_size))
|
||||
self.net.add_module("act", nn.Tanh())
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Text, Union
|
||||
@@ -20,7 +19,7 @@ from torch.utils.data import DataLoader
|
||||
|
||||
from .pytorch_utils import count_parameters
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH, TSDatasetH
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
from ...model.utils import ConcatDataset
|
||||
from ...data.dataset.weight import Reweighter
|
||||
@@ -160,7 +159,7 @@ class ALSTM(Model):
|
||||
|
||||
mask = torch.isfinite(label)
|
||||
|
||||
if self.metric == "" or self.metric == "loss":
|
||||
if self.metric in ("", "loss"):
|
||||
return -self.loss_fn(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown metric `%s`" % self.metric)
|
||||
@@ -320,8 +319,8 @@ class ALSTMModel(nn.Module):
|
||||
def _build_model(self):
|
||||
try:
|
||||
klass = getattr(nn, self.rnn_type.upper())
|
||||
except:
|
||||
raise ValueError("unknown rnn_type `%s`" % self.rnn_type)
|
||||
except Exception as e:
|
||||
raise ValueError("unknown rnn_type `%s`" % self.rnn_type) from e
|
||||
self.net = nn.Sequential()
|
||||
self.net.add_module("fc_in", nn.Linear(in_features=self.input_size, out_features=self.hid_size))
|
||||
self.net.add_module("act", nn.Tanh())
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Text, Union
|
||||
@@ -158,7 +157,7 @@ class GATs(Model):
|
||||
|
||||
mask = torch.isfinite(label)
|
||||
|
||||
if self.metric == "" or self.metric == "loss":
|
||||
if self.metric in ("", "loss"):
|
||||
return -self.loss_fn(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown metric `%s`" % self.metric)
|
||||
@@ -263,7 +262,9 @@ class GATs(Model):
|
||||
pretrained_model.load_state_dict(torch.load(self.model_path, map_location=self.device))
|
||||
|
||||
model_dict = self.GAT_model.state_dict()
|
||||
pretrained_dict = {k: v for k, v in pretrained_model.state_dict().items() if k in model_dict}
|
||||
pretrained_dict = {
|
||||
k: v for k, v in pretrained_model.state_dict().items() if k in model_dict # pylint: disable=E1135
|
||||
}
|
||||
model_dict.update(pretrained_dict)
|
||||
self.GAT_model.load_state_dict(model_dict)
|
||||
self.logger.info("Loading pretrained model Done...")
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import copy
|
||||
@@ -19,7 +18,6 @@ from torch.utils.data import Sampler
|
||||
|
||||
from .pytorch_utils import count_parameters
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
from ...contrib.model.pytorch_lstm import LSTMModel
|
||||
from ...contrib.model.pytorch_gru import GRUModel
|
||||
@@ -178,7 +176,7 @@ class GATs(Model):
|
||||
|
||||
mask = torch.isfinite(label)
|
||||
|
||||
if self.metric == "" or self.metric == "loss":
|
||||
if self.metric in ("", "loss"):
|
||||
return -self.loss_fn(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown metric `%s`" % self.metric)
|
||||
@@ -279,7 +277,9 @@ class GATs(Model):
|
||||
pretrained_model.load_state_dict(torch.load(self.model_path, map_location=self.device))
|
||||
|
||||
model_dict = self.GAT_model.state_dict()
|
||||
pretrained_dict = {k: v for k, v in pretrained_model.state_dict().items() if k in model_dict}
|
||||
pretrained_dict = {
|
||||
k: v for k, v in pretrained_model.state_dict().items() if k in model_dict # pylint: disable=E1135
|
||||
}
|
||||
model_dict.update(pretrained_dict)
|
||||
self.GAT_model.load_state_dict(model_dict)
|
||||
self.logger.info("Loading pretrained model Done...")
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Text, Union
|
||||
@@ -150,7 +149,7 @@ class GRU(Model):
|
||||
|
||||
mask = torch.isfinite(label)
|
||||
|
||||
if self.metric == "" or self.metric == "loss":
|
||||
if self.metric in ("", "loss"):
|
||||
return -self.loss_fn(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown metric `%s`" % self.metric)
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import copy
|
||||
@@ -19,7 +18,6 @@ from torch.utils.data import DataLoader
|
||||
|
||||
from .pytorch_utils import count_parameters
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH, TSDatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
from ...model.utils import ConcatDataset
|
||||
from ...data.dataset.weight import Reweighter
|
||||
@@ -159,7 +157,7 @@ class GRU(Model):
|
||||
|
||||
mask = torch.isfinite(label)
|
||||
|
||||
if self.metric == "" or self.metric == "loss":
|
||||
if self.metric in ("", "loss"):
|
||||
return -self.loss_fn(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown metric `%s`" % self.metric)
|
||||
|
||||
503
qlib/contrib/model/pytorch_hist.py
Normal file
503
qlib/contrib/model/pytorch_hist.py
Normal file
@@ -0,0 +1,503 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Text, Union
|
||||
import urllib.request
|
||||
import copy
|
||||
from ...utils import get_or_create_path
|
||||
from ...log import get_module_logger
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from .pytorch_utils import count_parameters
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
from ...contrib.model.pytorch_lstm import LSTMModel
|
||||
from ...contrib.model.pytorch_gru import GRUModel
|
||||
|
||||
|
||||
class HIST(Model):
|
||||
"""HIST Model
|
||||
|
||||
Parameters
|
||||
----------
|
||||
lr : float
|
||||
learning rate
|
||||
d_feat : int
|
||||
input dimensions for each time step
|
||||
metric : str
|
||||
the evaluate metric used in early stop
|
||||
optimizer : str
|
||||
optimizer name
|
||||
GPU : str
|
||||
the GPU ID(s) used for training
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
d_feat=6,
|
||||
hidden_size=64,
|
||||
num_layers=2,
|
||||
dropout=0.0,
|
||||
n_epochs=200,
|
||||
lr=0.001,
|
||||
metric="",
|
||||
early_stop=20,
|
||||
loss="mse",
|
||||
base_model="GRU",
|
||||
model_path=None,
|
||||
stock2concept=None,
|
||||
stock_index=None,
|
||||
optimizer="adam",
|
||||
GPU=0,
|
||||
seed=None,
|
||||
**kwargs
|
||||
):
|
||||
# Set logger.
|
||||
self.logger = get_module_logger("HIST")
|
||||
self.logger.info("HIST pytorch version...")
|
||||
|
||||
# set hyper-parameters.
|
||||
self.d_feat = d_feat
|
||||
self.hidden_size = hidden_size
|
||||
self.num_layers = num_layers
|
||||
self.dropout = dropout
|
||||
self.n_epochs = n_epochs
|
||||
self.lr = lr
|
||||
self.metric = metric
|
||||
self.early_stop = early_stop
|
||||
self.optimizer = optimizer.lower()
|
||||
self.loss = loss
|
||||
self.base_model = base_model
|
||||
self.model_path = model_path
|
||||
self.stock2concept = stock2concept
|
||||
self.stock_index = stock_index
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.seed = seed
|
||||
|
||||
self.logger.info(
|
||||
"HIST parameters setting:"
|
||||
"\nd_feat : {}"
|
||||
"\nhidden_size : {}"
|
||||
"\nnum_layers : {}"
|
||||
"\ndropout : {}"
|
||||
"\nn_epochs : {}"
|
||||
"\nlr : {}"
|
||||
"\nmetric : {}"
|
||||
"\nearly_stop : {}"
|
||||
"\noptimizer : {}"
|
||||
"\nloss_type : {}"
|
||||
"\nbase_model : {}"
|
||||
"\nmodel_path : {}"
|
||||
"\nstock2concept : {}"
|
||||
"\nstock_index : {}"
|
||||
"\nuse_GPU : {}"
|
||||
"\nseed : {}".format(
|
||||
d_feat,
|
||||
hidden_size,
|
||||
num_layers,
|
||||
dropout,
|
||||
n_epochs,
|
||||
lr,
|
||||
metric,
|
||||
early_stop,
|
||||
optimizer.lower(),
|
||||
loss,
|
||||
base_model,
|
||||
model_path,
|
||||
stock2concept,
|
||||
stock_index,
|
||||
GPU,
|
||||
seed,
|
||||
)
|
||||
)
|
||||
|
||||
if self.seed is not None:
|
||||
np.random.seed(self.seed)
|
||||
torch.manual_seed(self.seed)
|
||||
|
||||
self.HIST_model = HISTModel(
|
||||
d_feat=self.d_feat,
|
||||
hidden_size=self.hidden_size,
|
||||
num_layers=self.num_layers,
|
||||
dropout=self.dropout,
|
||||
base_model=self.base_model,
|
||||
)
|
||||
self.logger.info("model:\n{:}".format(self.HIST_model))
|
||||
self.logger.info("model size: {:.4f} MB".format(count_parameters(self.HIST_model)))
|
||||
if optimizer.lower() == "adam":
|
||||
self.train_optimizer = optim.Adam(self.HIST_model.parameters(), lr=self.lr)
|
||||
elif optimizer.lower() == "gd":
|
||||
self.train_optimizer = optim.SGD(self.HIST_model.parameters(), lr=self.lr)
|
||||
else:
|
||||
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
|
||||
|
||||
self.fitted = False
|
||||
self.HIST_model.to(self.device)
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
return self.device != torch.device("cpu")
|
||||
|
||||
def mse(self, pred, label):
|
||||
loss = (pred - label) ** 2
|
||||
return torch.mean(loss)
|
||||
|
||||
def loss_fn(self, pred, label):
|
||||
mask = ~torch.isnan(label)
|
||||
|
||||
if self.loss == "mse":
|
||||
return self.mse(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown loss `%s`" % self.loss)
|
||||
|
||||
def metric_fn(self, pred, label):
|
||||
|
||||
mask = torch.isfinite(label)
|
||||
|
||||
if self.metric == "ic":
|
||||
x = pred[mask]
|
||||
y = label[mask]
|
||||
|
||||
vx = x - torch.mean(x)
|
||||
vy = y - torch.mean(y)
|
||||
return torch.sum(vx * vy) / (torch.sqrt(torch.sum(vx**2)) * torch.sqrt(torch.sum(vy**2)))
|
||||
|
||||
if self.metric == ("", "loss"):
|
||||
return -self.loss_fn(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown metric `%s`" % self.metric)
|
||||
|
||||
def get_daily_inter(self, df, shuffle=False):
|
||||
# organize the train data into daily batches
|
||||
daily_count = df.groupby(level=0).size().values
|
||||
daily_index = np.roll(np.cumsum(daily_count), 1)
|
||||
daily_index[0] = 0
|
||||
if shuffle:
|
||||
# shuffle data
|
||||
daily_shuffle = list(zip(daily_index, daily_count))
|
||||
np.random.shuffle(daily_shuffle)
|
||||
daily_index, daily_count = zip(*daily_shuffle)
|
||||
return daily_index, daily_count
|
||||
|
||||
def train_epoch(self, x_train, y_train, stock_index):
|
||||
|
||||
stock2concept_matrix = np.load(self.stock2concept)
|
||||
x_train_values = x_train.values
|
||||
y_train_values = np.squeeze(y_train.values)
|
||||
stock_index = stock_index.values
|
||||
stock_index[np.isnan(stock_index)] = 733
|
||||
self.HIST_model.train()
|
||||
|
||||
# organize the train data into daily batches
|
||||
daily_index, daily_count = self.get_daily_inter(x_train, shuffle=True)
|
||||
|
||||
for idx, count in zip(daily_index, daily_count):
|
||||
batch = slice(idx, idx + count)
|
||||
feature = torch.from_numpy(x_train_values[batch]).float().to(self.device)
|
||||
concept_matrix = torch.from_numpy(stock2concept_matrix[stock_index[batch]]).float().to(self.device)
|
||||
label = torch.from_numpy(y_train_values[batch]).float().to(self.device)
|
||||
pred = self.HIST_model(feature, concept_matrix)
|
||||
loss = self.loss_fn(pred, label)
|
||||
|
||||
self.train_optimizer.zero_grad()
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_value_(self.HIST_model.parameters(), 3.0)
|
||||
self.train_optimizer.step()
|
||||
|
||||
def test_epoch(self, data_x, data_y, stock_index):
|
||||
|
||||
# prepare training data
|
||||
stock2concept_matrix = np.load(self.stock2concept)
|
||||
x_values = data_x.values
|
||||
y_values = np.squeeze(data_y.values)
|
||||
stock_index = stock_index.values
|
||||
stock_index[np.isnan(stock_index)] = 733
|
||||
self.HIST_model.eval()
|
||||
|
||||
scores = []
|
||||
losses = []
|
||||
|
||||
# organize the test data into daily batches
|
||||
daily_index, daily_count = self.get_daily_inter(data_x, shuffle=False)
|
||||
|
||||
for idx, count in zip(daily_index, daily_count):
|
||||
batch = slice(idx, idx + count)
|
||||
feature = torch.from_numpy(x_values[batch]).float().to(self.device)
|
||||
concept_matrix = torch.from_numpy(stock2concept_matrix[stock_index[batch]]).float().to(self.device)
|
||||
label = torch.from_numpy(y_values[batch]).float().to(self.device)
|
||||
with torch.no_grad():
|
||||
pred = self.HIST_model(feature, concept_matrix)
|
||||
loss = self.loss_fn(pred, label)
|
||||
losses.append(loss.item())
|
||||
|
||||
score = self.metric_fn(pred, label)
|
||||
scores.append(score.item())
|
||||
|
||||
return np.mean(losses), np.mean(scores)
|
||||
|
||||
def fit(
|
||||
self,
|
||||
dataset: DatasetH,
|
||||
evals_result=dict(),
|
||||
save_path=None,
|
||||
):
|
||||
df_train, df_valid, df_test = dataset.prepare(
|
||||
["train", "valid", "test"],
|
||||
col_set=["feature", "label"],
|
||||
data_key=DataHandlerLP.DK_L,
|
||||
)
|
||||
if df_train.empty or df_valid.empty:
|
||||
raise ValueError("Empty data from dataset, please check your dataset config.")
|
||||
|
||||
if not os.path.exists(self.stock2concept):
|
||||
url = "http://fintech.msra.cn/stock_data/downloads/qlib_csi300_stock2concept.npy"
|
||||
urllib.request.urlretrieve(url, self.stock2concept)
|
||||
|
||||
stock_index = np.load(self.stock_index, allow_pickle=True).item()
|
||||
df_train["stock_index"] = 733
|
||||
df_train["stock_index"] = df_train.index.get_level_values("instrument").map(stock_index)
|
||||
df_valid["stock_index"] = 733
|
||||
df_valid["stock_index"] = df_valid.index.get_level_values("instrument").map(stock_index)
|
||||
|
||||
x_train, y_train, stock_index_train = df_train["feature"], df_train["label"], df_train["stock_index"]
|
||||
x_valid, y_valid, stock_index_valid = df_valid["feature"], df_valid["label"], df_valid["stock_index"]
|
||||
|
||||
save_path = get_or_create_path(save_path)
|
||||
|
||||
stop_steps = 0
|
||||
best_score = -np.inf
|
||||
best_epoch = 0
|
||||
evals_result["train"] = []
|
||||
evals_result["valid"] = []
|
||||
|
||||
# load pretrained base_model
|
||||
if self.base_model == "LSTM":
|
||||
pretrained_model = LSTMModel()
|
||||
elif self.base_model == "GRU":
|
||||
pretrained_model = GRUModel()
|
||||
else:
|
||||
raise ValueError("unknown base model name `%s`" % self.base_model)
|
||||
|
||||
if self.model_path is not None:
|
||||
self.logger.info("Loading pretrained model...")
|
||||
pretrained_model.load_state_dict(torch.load(self.model_path))
|
||||
|
||||
model_dict = self.HIST_model.state_dict()
|
||||
pretrained_dict = {
|
||||
k: v for k, v in pretrained_model.state_dict().items() if k in model_dict # pylint: disable=E1135
|
||||
}
|
||||
model_dict.update(pretrained_dict)
|
||||
self.HIST_model.load_state_dict(model_dict)
|
||||
self.logger.info("Loading pretrained model Done...")
|
||||
|
||||
# train
|
||||
self.logger.info("training...")
|
||||
self.fitted = True
|
||||
|
||||
for step in range(self.n_epochs):
|
||||
self.logger.info("Epoch%d:", step)
|
||||
self.logger.info("training...")
|
||||
self.train_epoch(x_train, y_train, stock_index_train)
|
||||
|
||||
self.logger.info("evaluating...")
|
||||
train_loss, train_score = self.test_epoch(x_train, y_train, stock_index_train)
|
||||
val_loss, val_score = self.test_epoch(x_valid, y_valid, stock_index_valid)
|
||||
self.logger.info("train %.6f, valid %.6f" % (train_score, val_score))
|
||||
evals_result["train"].append(train_score)
|
||||
evals_result["valid"].append(val_score)
|
||||
|
||||
if val_score > best_score:
|
||||
best_score = val_score
|
||||
stop_steps = 0
|
||||
best_epoch = step
|
||||
best_param = copy.deepcopy(self.HIST_model.state_dict())
|
||||
else:
|
||||
stop_steps += 1
|
||||
if stop_steps >= self.early_stop:
|
||||
self.logger.info("early stop")
|
||||
break
|
||||
|
||||
self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
|
||||
self.HIST_model.load_state_dict(best_param)
|
||||
torch.save(best_param, save_path)
|
||||
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
stock2concept_matrix = np.load(self.stock2concept)
|
||||
stock_index = np.load(self.stock_index, allow_pickle=True).item()
|
||||
df_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
df_test["stock_index"] = 733
|
||||
df_test["stock_index"] = df_test.index.get_level_values("instrument").map(stock_index)
|
||||
stock_index_test = df_test["stock_index"].values
|
||||
stock_index_test[np.isnan(stock_index_test)] = 733
|
||||
stock_index_test = stock_index_test.astype("int")
|
||||
df_test = df_test.drop(["stock_index"], axis=1)
|
||||
index = df_test.index
|
||||
|
||||
self.HIST_model.eval()
|
||||
x_values = df_test.values
|
||||
preds = []
|
||||
|
||||
# organize the data into daily batches
|
||||
daily_index, daily_count = self.get_daily_inter(df_test, shuffle=False)
|
||||
|
||||
for idx, count in zip(daily_index, daily_count):
|
||||
batch = slice(idx, idx + count)
|
||||
x_batch = torch.from_numpy(x_values[batch]).float().to(self.device)
|
||||
concept_matrix = torch.from_numpy(stock2concept_matrix[stock_index_test[batch]]).float().to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
pred = self.HIST_model(x_batch, concept_matrix).detach().cpu().numpy()
|
||||
|
||||
preds.append(pred)
|
||||
|
||||
return pd.Series(np.concatenate(preds), index=index)
|
||||
|
||||
|
||||
class HISTModel(nn.Module):
|
||||
def __init__(self, d_feat=6, hidden_size=64, num_layers=2, dropout=0.0, base_model="GRU"):
|
||||
super().__init__()
|
||||
|
||||
self.d_feat = d_feat
|
||||
self.hidden_size = hidden_size
|
||||
|
||||
if base_model == "GRU":
|
||||
self.rnn = nn.GRU(
|
||||
input_size=d_feat,
|
||||
hidden_size=hidden_size,
|
||||
num_layers=num_layers,
|
||||
batch_first=True,
|
||||
dropout=dropout,
|
||||
)
|
||||
elif base_model == "LSTM":
|
||||
self.rnn = nn.LSTM(
|
||||
input_size=d_feat,
|
||||
hidden_size=hidden_size,
|
||||
num_layers=num_layers,
|
||||
batch_first=True,
|
||||
dropout=dropout,
|
||||
)
|
||||
else:
|
||||
raise ValueError("unknown base model name `%s`" % base_model)
|
||||
|
||||
self.fc_es = nn.Linear(hidden_size, hidden_size)
|
||||
torch.nn.init.xavier_uniform_(self.fc_es.weight)
|
||||
self.fc_is = nn.Linear(hidden_size, hidden_size)
|
||||
torch.nn.init.xavier_uniform_(self.fc_is.weight)
|
||||
|
||||
self.fc_es_middle = nn.Linear(hidden_size, hidden_size)
|
||||
torch.nn.init.xavier_uniform_(self.fc_es_middle.weight)
|
||||
self.fc_is_middle = nn.Linear(hidden_size, hidden_size)
|
||||
torch.nn.init.xavier_uniform_(self.fc_is_middle.weight)
|
||||
|
||||
self.fc_es_fore = nn.Linear(hidden_size, hidden_size)
|
||||
torch.nn.init.xavier_uniform_(self.fc_es_fore.weight)
|
||||
self.fc_is_fore = nn.Linear(hidden_size, hidden_size)
|
||||
torch.nn.init.xavier_uniform_(self.fc_is_fore.weight)
|
||||
self.fc_indi_fore = nn.Linear(hidden_size, hidden_size)
|
||||
torch.nn.init.xavier_uniform_(self.fc_indi_fore.weight)
|
||||
|
||||
self.fc_es_back = nn.Linear(hidden_size, hidden_size)
|
||||
torch.nn.init.xavier_uniform_(self.fc_es_back.weight)
|
||||
self.fc_is_back = nn.Linear(hidden_size, hidden_size)
|
||||
torch.nn.init.xavier_uniform_(self.fc_is_back.weight)
|
||||
self.fc_indi = nn.Linear(hidden_size, hidden_size)
|
||||
torch.nn.init.xavier_uniform_(self.fc_indi.weight)
|
||||
|
||||
self.leaky_relu = nn.LeakyReLU()
|
||||
self.softmax_s2t = torch.nn.Softmax(dim=0)
|
||||
self.softmax_t2s = torch.nn.Softmax(dim=1)
|
||||
|
||||
self.fc_out_es = nn.Linear(hidden_size, 1)
|
||||
self.fc_out_is = nn.Linear(hidden_size, 1)
|
||||
self.fc_out_indi = nn.Linear(hidden_size, 1)
|
||||
self.fc_out = nn.Linear(hidden_size, 1)
|
||||
|
||||
def cal_cos_similarity(self, x, y): # the 2nd dimension of x and y are the same
|
||||
xy = x.mm(torch.t(y))
|
||||
x_norm = torch.sqrt(torch.sum(x * x, dim=1)).reshape(-1, 1)
|
||||
y_norm = torch.sqrt(torch.sum(y * y, dim=1)).reshape(-1, 1)
|
||||
cos_similarity = xy / (x_norm.mm(torch.t(y_norm)) + 1e-6)
|
||||
return cos_similarity
|
||||
|
||||
def forward(self, x, concept_matrix):
|
||||
device = torch.device(torch.get_device(x))
|
||||
|
||||
x_hidden = x.reshape(len(x), self.d_feat, -1) # [N, F, T]
|
||||
x_hidden = x_hidden.permute(0, 2, 1) # [N, T, F]
|
||||
x_hidden, _ = self.rnn(x_hidden)
|
||||
x_hidden = x_hidden[:, -1, :]
|
||||
|
||||
# Predefined Concept Module
|
||||
|
||||
stock_to_concept = concept_matrix
|
||||
|
||||
stock_to_concept_sum = torch.sum(stock_to_concept, 0).reshape(1, -1).repeat(stock_to_concept.shape[0], 1)
|
||||
stock_to_concept_sum = stock_to_concept_sum.mul(concept_matrix)
|
||||
|
||||
stock_to_concept_sum = stock_to_concept_sum + (
|
||||
torch.ones(stock_to_concept.shape[0], stock_to_concept.shape[1]).to(device)
|
||||
)
|
||||
stock_to_concept = stock_to_concept / stock_to_concept_sum
|
||||
hidden = torch.t(stock_to_concept).mm(x_hidden)
|
||||
|
||||
hidden = hidden[hidden.sum(1) != 0]
|
||||
|
||||
concept_to_stock = self.cal_cos_similarity(x_hidden, hidden)
|
||||
concept_to_stock = self.softmax_t2s(concept_to_stock)
|
||||
|
||||
e_shared_info = concept_to_stock.mm(hidden)
|
||||
e_shared_info = self.fc_es(e_shared_info)
|
||||
|
||||
e_shared_back = self.fc_es_back(e_shared_info)
|
||||
output_es = self.fc_es_fore(e_shared_info)
|
||||
output_es = self.leaky_relu(output_es)
|
||||
|
||||
# Hidden Concept Module
|
||||
i_shared_info = x_hidden - e_shared_back
|
||||
hidden = i_shared_info
|
||||
i_stock_to_concept = self.cal_cos_similarity(i_shared_info, hidden)
|
||||
dim = i_stock_to_concept.shape[0]
|
||||
diag = i_stock_to_concept.diagonal(0)
|
||||
i_stock_to_concept = i_stock_to_concept * (torch.ones(dim, dim) - torch.eye(dim)).to(device)
|
||||
row = torch.linspace(0, dim - 1, dim).to(device).long()
|
||||
column = i_stock_to_concept.max(1)[1].long()
|
||||
value = i_stock_to_concept.max(1)[0]
|
||||
i_stock_to_concept[row, column] = 10
|
||||
i_stock_to_concept[i_stock_to_concept != 10] = 0
|
||||
i_stock_to_concept[row, column] = value
|
||||
i_stock_to_concept = i_stock_to_concept + torch.diag_embed((i_stock_to_concept.sum(0) != 0).float() * diag)
|
||||
hidden = torch.t(i_shared_info).mm(i_stock_to_concept).t()
|
||||
hidden = hidden[hidden.sum(1) != 0]
|
||||
|
||||
i_concept_to_stock = self.cal_cos_similarity(i_shared_info, hidden)
|
||||
i_concept_to_stock = self.softmax_t2s(i_concept_to_stock)
|
||||
i_shared_info = i_concept_to_stock.mm(hidden)
|
||||
i_shared_info = self.fc_is(i_shared_info)
|
||||
|
||||
i_shared_back = self.fc_is_back(i_shared_info)
|
||||
output_is = self.fc_is_fore(i_shared_info)
|
||||
output_is = self.leaky_relu(output_is)
|
||||
|
||||
# Individual Information Module
|
||||
individual_info = x_hidden - e_shared_back - i_shared_back
|
||||
output_indi = individual_info
|
||||
output_indi = self.fc_indi(output_indi)
|
||||
output_indi = self.leaky_relu(output_indi)
|
||||
|
||||
# Stock Trend Prediction
|
||||
all_info = output_es + output_is + output_indi
|
||||
pred_all = self.fc_out(all_info).squeeze()
|
||||
|
||||
return pred_all
|
||||
446
qlib/contrib/model/pytorch_igmtf.py
Normal file
446
qlib/contrib/model/pytorch_igmtf.py
Normal file
@@ -0,0 +1,446 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Text, Union
|
||||
import copy
|
||||
from ...utils import get_or_create_path
|
||||
from ...log import get_module_logger
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
|
||||
from .pytorch_utils import count_parameters
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
from ...contrib.model.pytorch_lstm import LSTMModel
|
||||
from ...contrib.model.pytorch_gru import GRUModel
|
||||
|
||||
|
||||
class IGMTF(Model):
|
||||
"""IGMTF Model
|
||||
|
||||
Parameters
|
||||
----------
|
||||
d_feat : int
|
||||
input dimension for each time step
|
||||
metric: str
|
||||
the evaluate metric used in early stop
|
||||
optimizer : str
|
||||
optimizer name
|
||||
GPU : str
|
||||
the GPU ID(s) used for training
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
d_feat=6,
|
||||
hidden_size=64,
|
||||
num_layers=2,
|
||||
dropout=0.0,
|
||||
n_epochs=200,
|
||||
lr=0.001,
|
||||
metric="",
|
||||
early_stop=20,
|
||||
loss="mse",
|
||||
base_model="GRU",
|
||||
model_path=None,
|
||||
optimizer="adam",
|
||||
GPU=0,
|
||||
seed=None,
|
||||
**kwargs
|
||||
):
|
||||
# Set logger.
|
||||
self.logger = get_module_logger("IGMTF")
|
||||
self.logger.info("IMGTF pytorch version...")
|
||||
|
||||
# set hyper-parameters.
|
||||
self.d_feat = d_feat
|
||||
self.hidden_size = hidden_size
|
||||
self.num_layers = num_layers
|
||||
self.dropout = dropout
|
||||
self.n_epochs = n_epochs
|
||||
self.lr = lr
|
||||
self.metric = metric
|
||||
self.early_stop = early_stop
|
||||
self.optimizer = optimizer.lower()
|
||||
self.loss = loss
|
||||
self.base_model = base_model
|
||||
self.model_path = model_path
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.seed = seed
|
||||
|
||||
self.logger.info(
|
||||
"IGMTF parameters setting:"
|
||||
"\nd_feat : {}"
|
||||
"\nhidden_size : {}"
|
||||
"\nnum_layers : {}"
|
||||
"\ndropout : {}"
|
||||
"\nn_epochs : {}"
|
||||
"\nlr : {}"
|
||||
"\nmetric : {}"
|
||||
"\nearly_stop : {}"
|
||||
"\noptimizer : {}"
|
||||
"\nloss_type : {}"
|
||||
"\nbase_model : {}"
|
||||
"\nmodel_path : {}"
|
||||
"\nvisible_GPU : {}"
|
||||
"\nuse_GPU : {}"
|
||||
"\nseed : {}".format(
|
||||
d_feat,
|
||||
hidden_size,
|
||||
num_layers,
|
||||
dropout,
|
||||
n_epochs,
|
||||
lr,
|
||||
metric,
|
||||
early_stop,
|
||||
optimizer.lower(),
|
||||
loss,
|
||||
base_model,
|
||||
model_path,
|
||||
GPU,
|
||||
self.use_gpu,
|
||||
seed,
|
||||
)
|
||||
)
|
||||
|
||||
if self.seed is not None:
|
||||
np.random.seed(self.seed)
|
||||
torch.manual_seed(self.seed)
|
||||
|
||||
self.igmtf_model = IGMTFModel(
|
||||
d_feat=self.d_feat,
|
||||
hidden_size=self.hidden_size,
|
||||
num_layers=self.num_layers,
|
||||
dropout=self.dropout,
|
||||
base_model=self.base_model,
|
||||
)
|
||||
self.logger.info("model:\n{:}".format(self.igmtf_model))
|
||||
self.logger.info("model size: {:.4f} MB".format(count_parameters(self.igmtf_model)))
|
||||
|
||||
if optimizer.lower() == "adam":
|
||||
self.train_optimizer = optim.Adam(self.igmtf_model.parameters(), lr=self.lr)
|
||||
elif optimizer.lower() == "gd":
|
||||
self.train_optimizer = optim.SGD(self.igmtf_model.parameters(), lr=self.lr)
|
||||
else:
|
||||
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
|
||||
|
||||
self.fitted = False
|
||||
self.igmtf_model.to(self.device)
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
return self.device != torch.device("cpu")
|
||||
|
||||
def mse(self, pred, label):
|
||||
loss = (pred - label) ** 2
|
||||
return torch.mean(loss)
|
||||
|
||||
def loss_fn(self, pred, label):
|
||||
mask = ~torch.isnan(label)
|
||||
|
||||
if self.loss == "mse":
|
||||
return self.mse(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown loss `%s`" % self.loss)
|
||||
|
||||
def metric_fn(self, pred, label):
|
||||
|
||||
mask = torch.isfinite(label)
|
||||
|
||||
if self.metric == "ic":
|
||||
x = pred[mask]
|
||||
y = label[mask]
|
||||
|
||||
vx = x - torch.mean(x)
|
||||
vy = y - torch.mean(y)
|
||||
return torch.sum(vx * vy) / (torch.sqrt(torch.sum(vx**2)) * torch.sqrt(torch.sum(vy**2)))
|
||||
|
||||
if self.metric == ("", "loss"):
|
||||
return -self.loss_fn(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown metric `%s`" % self.metric)
|
||||
|
||||
def get_daily_inter(self, df, shuffle=False):
|
||||
# organize the train data into daily batches
|
||||
daily_count = df.groupby(level=0).size().values
|
||||
daily_index = np.roll(np.cumsum(daily_count), 1)
|
||||
daily_index[0] = 0
|
||||
if shuffle:
|
||||
# shuffle data
|
||||
daily_shuffle = list(zip(daily_index, daily_count))
|
||||
np.random.shuffle(daily_shuffle)
|
||||
daily_index, daily_count = zip(*daily_shuffle)
|
||||
return daily_index, daily_count
|
||||
|
||||
def get_train_hidden(self, x_train):
|
||||
x_train_values = x_train.values
|
||||
daily_index, daily_count = self.get_daily_inter(x_train, shuffle=True)
|
||||
self.igmtf_model.eval()
|
||||
train_hidden = []
|
||||
train_hidden_day = []
|
||||
|
||||
for idx, count in zip(daily_index, daily_count):
|
||||
batch = slice(idx, idx + count)
|
||||
feature = torch.from_numpy(x_train_values[batch]).float().to(self.device)
|
||||
out = self.igmtf_model(feature, get_hidden=True)
|
||||
train_hidden.append(out.detach().cpu())
|
||||
train_hidden_day.append(out.detach().cpu().mean(dim=0).unsqueeze(dim=0))
|
||||
|
||||
train_hidden = np.asarray(train_hidden, dtype=object)
|
||||
train_hidden_day = torch.cat(train_hidden_day)
|
||||
|
||||
return train_hidden, train_hidden_day
|
||||
|
||||
def train_epoch(self, x_train, y_train, train_hidden, train_hidden_day):
|
||||
|
||||
x_train_values = x_train.values
|
||||
y_train_values = np.squeeze(y_train.values)
|
||||
|
||||
self.igmtf_model.train()
|
||||
|
||||
daily_index, daily_count = self.get_daily_inter(x_train, shuffle=True)
|
||||
|
||||
for idx, count in zip(daily_index, daily_count):
|
||||
batch = slice(idx, idx + count)
|
||||
feature = torch.from_numpy(x_train_values[batch]).float().to(self.device)
|
||||
label = torch.from_numpy(y_train_values[batch]).float().to(self.device)
|
||||
pred = self.igmtf_model(feature, train_hidden=train_hidden, train_hidden_day=train_hidden_day)
|
||||
loss = self.loss_fn(pred, label)
|
||||
|
||||
self.train_optimizer.zero_grad()
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_value_(self.igmtf_model.parameters(), 3.0)
|
||||
self.train_optimizer.step()
|
||||
|
||||
def test_epoch(self, data_x, data_y, train_hidden, train_hidden_day):
|
||||
|
||||
# prepare training data
|
||||
x_values = data_x.values
|
||||
y_values = np.squeeze(data_y.values)
|
||||
|
||||
self.igmtf_model.eval()
|
||||
|
||||
scores = []
|
||||
losses = []
|
||||
|
||||
daily_index, daily_count = self.get_daily_inter(data_x, shuffle=False)
|
||||
|
||||
for idx, count in zip(daily_index, daily_count):
|
||||
batch = slice(idx, idx + count)
|
||||
feature = torch.from_numpy(x_values[batch]).float().to(self.device)
|
||||
label = torch.from_numpy(y_values[batch]).float().to(self.device)
|
||||
|
||||
pred = self.igmtf_model(feature, train_hidden=train_hidden, train_hidden_day=train_hidden_day)
|
||||
loss = self.loss_fn(pred, label)
|
||||
losses.append(loss.item())
|
||||
|
||||
score = self.metric_fn(pred, label)
|
||||
scores.append(score.item())
|
||||
|
||||
return np.mean(losses), np.mean(scores)
|
||||
|
||||
def fit(
|
||||
self,
|
||||
dataset: DatasetH,
|
||||
evals_result=dict(),
|
||||
save_path=None,
|
||||
):
|
||||
|
||||
df_train, df_valid = dataset.prepare(
|
||||
["train", "valid"],
|
||||
col_set=["feature", "label"],
|
||||
data_key=DataHandlerLP.DK_L,
|
||||
)
|
||||
if df_train.empty or df_valid.empty:
|
||||
raise ValueError("Empty data from dataset, please check your dataset config.")
|
||||
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
|
||||
save_path = get_or_create_path(save_path)
|
||||
stop_steps = 0
|
||||
train_loss = 0
|
||||
best_score = -np.inf
|
||||
best_epoch = 0
|
||||
evals_result["train"] = []
|
||||
evals_result["valid"] = []
|
||||
|
||||
# load pretrained base_model
|
||||
if self.base_model == "LSTM":
|
||||
pretrained_model = LSTMModel()
|
||||
elif self.base_model == "GRU":
|
||||
pretrained_model = GRUModel()
|
||||
else:
|
||||
raise ValueError("unknown base model name `%s`" % self.base_model)
|
||||
|
||||
if self.model_path is not None:
|
||||
self.logger.info("Loading pretrained model...")
|
||||
pretrained_model.load_state_dict(torch.load(self.model_path, map_location=self.device))
|
||||
|
||||
model_dict = self.igmtf_model.state_dict()
|
||||
pretrained_dict = {
|
||||
k: v for k, v in pretrained_model.state_dict().items() if k in model_dict # pylint: disable=E1135
|
||||
}
|
||||
model_dict.update(pretrained_dict)
|
||||
self.igmtf_model.load_state_dict(model_dict)
|
||||
self.logger.info("Loading pretrained model Done...")
|
||||
|
||||
# train
|
||||
self.logger.info("training...")
|
||||
self.fitted = True
|
||||
|
||||
for step in range(self.n_epochs):
|
||||
self.logger.info("Epoch%d:", step)
|
||||
self.logger.info("training...")
|
||||
train_hidden, train_hidden_day = self.get_train_hidden(x_train)
|
||||
self.train_epoch(x_train, y_train, train_hidden, train_hidden_day)
|
||||
self.logger.info("evaluating...")
|
||||
train_loss, train_score = self.test_epoch(x_train, y_train, train_hidden, train_hidden_day)
|
||||
val_loss, val_score = self.test_epoch(x_valid, y_valid, train_hidden, train_hidden_day)
|
||||
self.logger.info("train %.6f, valid %.6f" % (train_score, val_score))
|
||||
evals_result["train"].append(train_score)
|
||||
evals_result["valid"].append(val_score)
|
||||
|
||||
if val_score > best_score:
|
||||
best_score = val_score
|
||||
stop_steps = 0
|
||||
best_epoch = step
|
||||
best_param = copy.deepcopy(self.igmtf_model.state_dict())
|
||||
else:
|
||||
stop_steps += 1
|
||||
if stop_steps >= self.early_stop:
|
||||
self.logger.info("early stop")
|
||||
break
|
||||
|
||||
self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
|
||||
self.igmtf_model.load_state_dict(best_param)
|
||||
torch.save(best_param, save_path)
|
||||
|
||||
if self.use_gpu:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
x_train = dataset.prepare("train", col_set="feature", data_key=DataHandlerLP.DK_L)
|
||||
train_hidden, train_hidden_day = self.get_train_hidden(x_train)
|
||||
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
index = x_test.index
|
||||
self.igmtf_model.eval()
|
||||
x_values = x_test.values
|
||||
preds = []
|
||||
|
||||
daily_index, daily_count = self.get_daily_inter(x_test, shuffle=False)
|
||||
|
||||
for idx, count in zip(daily_index, daily_count):
|
||||
batch = slice(idx, idx + count)
|
||||
x_batch = torch.from_numpy(x_values[batch]).float().to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
pred = (
|
||||
self.igmtf_model(x_batch, train_hidden=train_hidden, train_hidden_day=train_hidden_day)
|
||||
.detach()
|
||||
.cpu()
|
||||
.numpy()
|
||||
)
|
||||
|
||||
preds.append(pred)
|
||||
|
||||
return pd.Series(np.concatenate(preds), index=index)
|
||||
|
||||
|
||||
class IGMTFModel(nn.Module):
|
||||
def __init__(self, d_feat=6, hidden_size=64, num_layers=2, dropout=0.0, base_model="GRU"):
|
||||
super().__init__()
|
||||
|
||||
if base_model == "GRU":
|
||||
self.rnn = nn.GRU(
|
||||
input_size=d_feat,
|
||||
hidden_size=hidden_size,
|
||||
num_layers=num_layers,
|
||||
batch_first=True,
|
||||
dropout=dropout,
|
||||
)
|
||||
elif base_model == "LSTM":
|
||||
self.rnn = nn.LSTM(
|
||||
input_size=d_feat,
|
||||
hidden_size=hidden_size,
|
||||
num_layers=num_layers,
|
||||
batch_first=True,
|
||||
dropout=dropout,
|
||||
)
|
||||
else:
|
||||
raise ValueError("unknown base model name `%s`" % base_model)
|
||||
self.lins = nn.Sequential()
|
||||
for i in range(2):
|
||||
self.lins.add_module("linear" + str(i), nn.Linear(hidden_size, hidden_size))
|
||||
self.lins.add_module("leakyrelu" + str(i), nn.LeakyReLU())
|
||||
self.fc_output = nn.Linear(hidden_size * 2, hidden_size * 2)
|
||||
self.project1 = nn.Linear(hidden_size, hidden_size, bias=False)
|
||||
self.project2 = nn.Linear(hidden_size, hidden_size, bias=False)
|
||||
self.fc_out_pred = nn.Linear(hidden_size * 2, 1)
|
||||
|
||||
self.leaky_relu = nn.LeakyReLU()
|
||||
self.d_feat = d_feat
|
||||
|
||||
def cal_cos_similarity(self, x, y): # the 2nd dimension of x and y are the same
|
||||
xy = x.mm(torch.t(y))
|
||||
x_norm = torch.sqrt(torch.sum(x * x, dim=1)).reshape(-1, 1)
|
||||
y_norm = torch.sqrt(torch.sum(y * y, dim=1)).reshape(-1, 1)
|
||||
cos_similarity = xy / (x_norm.mm(torch.t(y_norm)) + 1e-6)
|
||||
return cos_similarity
|
||||
|
||||
def sparse_dense_mul(self, s, d):
|
||||
i = s._indices()
|
||||
v = s._values()
|
||||
dv = d[i[0, :], i[1, :]] # get values from relevant entries of dense matrix
|
||||
return torch.sparse.FloatTensor(i, v * dv, s.size())
|
||||
|
||||
def forward(self, x, get_hidden=False, train_hidden=None, train_hidden_day=None, k_day=10, n_neighbor=10):
|
||||
# x: [N, F*T]
|
||||
device = x.device
|
||||
x = x.reshape(len(x), self.d_feat, -1) # [N, F, T]
|
||||
x = x.permute(0, 2, 1) # [N, T, F]
|
||||
out, _ = self.rnn(x)
|
||||
out = out[:, -1, :]
|
||||
out = self.lins(out)
|
||||
mini_batch_out = out
|
||||
if get_hidden is True:
|
||||
return mini_batch_out
|
||||
|
||||
mini_batch_out_day = torch.mean(mini_batch_out, dim=0).unsqueeze(0)
|
||||
day_similarity = self.cal_cos_similarity(mini_batch_out_day, train_hidden_day.to(device))
|
||||
day_index = torch.topk(day_similarity, k_day, dim=1)[1]
|
||||
sample_train_hidden = train_hidden[day_index.long().cpu()].squeeze()
|
||||
sample_train_hidden = torch.cat(list(sample_train_hidden)).to(device)
|
||||
sample_train_hidden = self.lins(sample_train_hidden)
|
||||
cos_similarity = self.cal_cos_similarity(self.project1(mini_batch_out), self.project2(sample_train_hidden))
|
||||
|
||||
row = (
|
||||
torch.linspace(0, x.shape[0] - 1, x.shape[0])
|
||||
.reshape([-1, 1])
|
||||
.repeat(1, n_neighbor)
|
||||
.reshape(1, -1)
|
||||
.to(device)
|
||||
)
|
||||
column = torch.topk(cos_similarity, n_neighbor, dim=1)[1].reshape(1, -1)
|
||||
mask = torch.sparse_coo_tensor(
|
||||
torch.cat([row, column]),
|
||||
torch.ones([row.shape[1]]).to(device) / n_neighbor,
|
||||
(x.shape[0], sample_train_hidden.shape[0]),
|
||||
)
|
||||
cos_similarity = self.sparse_dense_mul(mask, cos_similarity)
|
||||
|
||||
agg_out = torch.sparse.mm(cos_similarity, self.project2(sample_train_hidden))
|
||||
# out = self.fc_out(out).squeeze()
|
||||
out = self.fc_out_pred(torch.cat([mini_batch_out, agg_out], axis=1)).squeeze()
|
||||
return out
|
||||
@@ -5,7 +5,6 @@
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Text, Union
|
||||
@@ -17,11 +16,9 @@ from ...log import get_module_logger
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from .pytorch_utils import count_parameters
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH, TSDatasetH
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
from torch.nn.modules.container import ModuleList
|
||||
|
||||
@@ -102,7 +99,7 @@ class LocalformerModel(Model):
|
||||
|
||||
mask = torch.isfinite(label)
|
||||
|
||||
if self.metric == "" or self.metric == "loss":
|
||||
if self.metric in ("", "loss"):
|
||||
return -self.loss_fn(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown metric `%s`" % self.metric)
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import copy
|
||||
@@ -18,9 +17,8 @@ import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from .pytorch_utils import count_parameters
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH, TSDatasetH
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
from torch.nn.modules.container import ModuleList
|
||||
|
||||
@@ -101,7 +99,7 @@ class LocalformerModel(Model):
|
||||
|
||||
mask = torch.isfinite(label)
|
||||
|
||||
if self.metric == "" or self.metric == "loss":
|
||||
if self.metric in ("", "loss"):
|
||||
return -self.loss_fn(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown metric `%s`" % self.metric)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user