mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 14:01:28 +08:00
Compare commits
46 Commits
v0.8.6
...
mini_proje
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
949d96d768 | ||
|
|
597359f98f | ||
|
|
75aae820e8 | ||
|
|
558603beca | ||
|
|
157481abd1 | ||
|
|
9d7a0f032a | ||
|
|
58f9eed3c9 | ||
|
|
8f1e28c43f | ||
|
|
e7c660f0d4 | ||
|
|
2752bdc92c | ||
|
|
687edd79d0 | ||
|
|
ba705d39e0 | ||
|
|
a53f59cdf7 | ||
|
|
8e063828f9 | ||
|
|
86f08e47e8 | ||
|
|
8199822ca0 | ||
|
|
1b9915501c | ||
|
|
c65c598bde | ||
|
|
fb5779a64c | ||
|
|
d149c2b177 | ||
|
|
6fddae9965 | ||
|
|
107d716cf8 | ||
|
|
792285b64f | ||
|
|
78b6b16640 | ||
|
|
b9bba4940f | ||
|
|
c34051c1ce | ||
|
|
a0c83d7997 | ||
|
|
82b10ee37a | ||
|
|
9b446f9a92 | ||
|
|
59b1820447 | ||
|
|
1dededa33f | ||
|
|
e62684eddf | ||
|
|
8a5efda0f6 | ||
|
|
a6700d81ff | ||
|
|
623774d8fb | ||
|
|
3db22452fb | ||
|
|
b655f90511 | ||
|
|
5e404909cf | ||
|
|
23c657a7a2 | ||
|
|
9bf3423a64 | ||
|
|
25ecb1135f | ||
|
|
2ca0d88d2d | ||
|
|
50d74b5560 | ||
|
|
a87b02619a | ||
|
|
da676a20a2 | ||
|
|
13d904d9a9 |
94
.github/workflows/test_macos.yml
vendored
94
.github/workflows/test_macos.yml
vendored
@@ -1,94 +0,0 @@
|
||||
# There are some issues (in the downloading data phase) on MacOS when running with other tests. So we split it into an individual config.
|
||||
name: Test MacOS
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
|
||||
jobs:
|
||||
build:
|
||||
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [macos-11, macos-latest]
|
||||
# not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129
|
||||
python-version: [3.7, 3.8]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Lint with Black
|
||||
run: |
|
||||
cd ..
|
||||
python -m pip install pip --upgrade
|
||||
python -m pip install wheel --upgrade
|
||||
python -m pip install black
|
||||
python -m black qlib -l 120 --check --diff
|
||||
# Test Qlib installed with pip
|
||||
|
||||
- name: Check Qlib with flake8
|
||||
run: |
|
||||
pip install --upgrade pip
|
||||
pip install flake8
|
||||
flake8 --ignore=E501,F541,E266,E402,W503,E731,E203 --per-file-ignores="__init__.py:F401,F403" qlib
|
||||
|
||||
- name: Install Qlib with pip
|
||||
run: |
|
||||
python -m pip install numpy==1.19.5
|
||||
python -m pip install pyqlib --ignore-installed ruamel.yaml numpy
|
||||
- name: Make html with sphnix
|
||||
run: |
|
||||
pip install -U sphinx
|
||||
pip install sphinx_rtd_theme readthedocs_sphinx_ext
|
||||
pip install --exists-action=w --no-cache-dir -r docs/requirements.txt
|
||||
cd docs
|
||||
sphinx-build -b html . build
|
||||
cd ..
|
||||
- name: Install Lightgbm for MacOS
|
||||
run: |
|
||||
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Microsoft/qlib/main/.github/brew_install.sh)"
|
||||
HOMEBREW_NO_AUTO_UPDATE=1 brew install lightgbm
|
||||
# FIX MacOS error: Segmentation fault
|
||||
# reference: https://github.com/microsoft/LightGBM/issues/4229
|
||||
wget https://raw.githubusercontent.com/Homebrew/homebrew-core/fb8323f2b170bd4ae97e1bac9bf3e2983af3fdb0/Formula/libomp.rb
|
||||
brew unlink libomp
|
||||
brew install libomp.rb
|
||||
- name: Test data downloads
|
||||
run: |
|
||||
python scripts/get_data.py qlib_data --name qlib_data_simple --target_dir ~/.qlib/qlib_data/cn_data_simple --interval 1d --region cn
|
||||
python -c "import os; userpath=os.path.expanduser('~'); os.rename(userpath + '/.qlib/qlib_data/cn_data_simple', userpath + '/.qlib/qlib_data/cn_data')"
|
||||
azcopy copy https://qlibpublic.blob.core.windows.net/data /tmp/qlibpublic --recursive
|
||||
mv /tmp/qlibpublic/data tests/.data
|
||||
- name: Test workflow by config (install from pip)
|
||||
run: |
|
||||
python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
python -m pip uninstall -y pyqlib
|
||||
# Test Qlib installed from source
|
||||
- name: Install Qlib from source
|
||||
run: |
|
||||
python -m pip install --upgrade cython
|
||||
python -m pip install numpy jupyter jupyter_contrib_nbextensions
|
||||
python -m pip install -U scipy scikit-learn # installing without this line will cause errors on GitHub Actions, while instsalling locally won't
|
||||
python -m pip install gym tianshou torch
|
||||
pip install -e .
|
||||
- name: Install test dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install -U pyopenssl idna
|
||||
python -m pip install black pytest
|
||||
- name: Unit tests with Pytest
|
||||
run: |
|
||||
pip install -r scripts/data_collector/pit/requirements.txt
|
||||
cd tests
|
||||
python -m pytest . --durations=0
|
||||
- name: Test workflow by config (install from source)
|
||||
run: |
|
||||
python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
57
.github/workflows/test_qlib_from_pip.yml
vendored
Normal file
57
.github/workflows/test_qlib_from_pip.yml
vendored
Normal file
@@ -0,0 +1,57 @@
|
||||
name: Test qlib from pip
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
|
||||
jobs:
|
||||
build:
|
||||
timeout-minutes: 120
|
||||
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [windows-latest, ubuntu-18.04, ubuntu-20.04, macos-11, macos-latest]
|
||||
# not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129
|
||||
python-version: [3.7, 3.8]
|
||||
|
||||
steps:
|
||||
- name: Test qlib from pip
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Update pip to the latest version
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
|
||||
- name: Qlib installation test
|
||||
run: |
|
||||
python -m pip install pyqlib
|
||||
# Specify the numpy version because the numpy upgrade caused the CI test to fail,
|
||||
# and this line of code will be removed when the next version of qlib is released.
|
||||
python -m pip install "numpy<1.23"
|
||||
|
||||
- name: Install Lightgbm for MacOS
|
||||
if: ${{ matrix.os == 'macos-11' || matrix.os == 'macos-latest' }}
|
||||
run: |
|
||||
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Microsoft/qlib/main/.github/brew_install.sh)"
|
||||
HOMEBREW_NO_AUTO_UPDATE=1 brew install lightgbm
|
||||
# FIX MacOS error: Segmentation fault
|
||||
# reference: https://github.com/microsoft/LightGBM/issues/4229
|
||||
wget https://raw.githubusercontent.com/Homebrew/homebrew-core/fb8323f2b170bd4ae97e1bac9bf3e2983af3fdb0/Formula/libomp.rb
|
||||
brew unlink libomp
|
||||
brew install libomp.rb
|
||||
|
||||
- name: Downloads dependencies data
|
||||
run: |
|
||||
python scripts/get_data.py qlib_data --name qlib_data_simple --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
|
||||
|
||||
- name: Test workflow by config
|
||||
run: |
|
||||
qrun examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
@@ -1,4 +1,4 @@
|
||||
name: Test
|
||||
name: Test qlib from source
|
||||
|
||||
on:
|
||||
push:
|
||||
@@ -8,42 +8,61 @@ on:
|
||||
|
||||
jobs:
|
||||
build:
|
||||
timeout-minutes: 180
|
||||
# we may retry for 3 times for `Unit tests with Pytest`
|
||||
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [windows-latest, ubuntu-18.04, ubuntu-20.04]
|
||||
os: [windows-latest, ubuntu-18.04, ubuntu-20.04, macos-11, macos-latest]
|
||||
# not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129
|
||||
python-version: [3.7, 3.8]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Test qlib from source
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Update pip to the latest version
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
|
||||
- name: Installing pytorch for macos
|
||||
if: ${{ matrix.os == 'macos-11' || matrix.os == 'macos-latest' }}
|
||||
run: |
|
||||
python -m pip install torch torchvision torchaudio
|
||||
|
||||
- name: Installing pytorch for ubuntu
|
||||
if: ${{ matrix.os == 'ubuntu-18.04' || matrix.os == 'ubuntu-20.04' }}
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
|
||||
|
||||
- name: Installing pytorch for windows
|
||||
if: ${{ matrix.os == 'windows-latest' }}
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install torch torchvision torchaudio
|
||||
|
||||
- name: Set up Python tools
|
||||
run: |
|
||||
python -m pip install --upgrade cython
|
||||
python -m pip install -e .[dev]
|
||||
|
||||
- name: Lint with Black
|
||||
run: |
|
||||
pip install --upgrade pip
|
||||
pip install black wheel
|
||||
black qlib -l 120 --check --diff
|
||||
|
||||
- name: Install Qlib with pip
|
||||
run: |
|
||||
pip install numpy==1.19.5 ruamel.yaml
|
||||
pip install pyqlib --ignore-installed
|
||||
black . -l 120 --check --diff
|
||||
|
||||
- name: Make html with sphinx
|
||||
run: |
|
||||
pip install -U sphinx
|
||||
pip install sphinx_rtd_theme readthedocs_sphinx_ext
|
||||
pip install --exists-action=w --no-cache-dir -r docs/requirements.txt
|
||||
cd docs
|
||||
sphinx-build -b html . build
|
||||
cd ..
|
||||
|
||||
|
||||
# Check Qlib with pylint
|
||||
# TODO: These problems we will solve in the future. Important among them are: W0221, W0223, W0237, E1102
|
||||
# C0103: invalid-name
|
||||
@@ -67,11 +86,9 @@ jobs:
|
||||
# W1309: f-string-without-interpolation
|
||||
# E1102: not-callable
|
||||
# E1136: unsubscriptable-object
|
||||
# References for parameters: https://github.com/PyCQA/pylint/issues/4577#issuecomment-1000245962
|
||||
# References for parameters: https://github.com/PyCQA/pylint/issues/4577#issuecomment-1000245962
|
||||
- name: Check Qlib with pylint
|
||||
run: |
|
||||
pip install --upgrade pip
|
||||
pip install pylint
|
||||
pylint --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0612,W0621,W0622,W0703,W1309,E1102,E1136 --const-rgx='[a-z_][a-z0-9_]{2,30}$' qlib --init-hook "import astroid; astroid.context.InferenceContext.max_inferred = 500"
|
||||
|
||||
# The following flake8 error codes were ignored:
|
||||
@@ -95,47 +112,44 @@ jobs:
|
||||
# Description: If there is whitespace before ":", it cannot pass the black check.
|
||||
- name: Check Qlib with flake8
|
||||
run: |
|
||||
pip install --upgrade pip
|
||||
pip install flake8
|
||||
flake8 --ignore=E501,F541,E266,E402,W503,E731,E203 --per-file-ignores="__init__.py:F401,F403" qlib
|
||||
|
||||
# https://github.com/python/mypy/issues/10600
|
||||
- name: Check Qlib with mypy
|
||||
run: |
|
||||
pip install mypy
|
||||
mypy qlib --install-types --non-interactive || true
|
||||
mypy qlib
|
||||
mypy qlib --verbose
|
||||
|
||||
- name: Test data downloads
|
||||
run: |
|
||||
python scripts/get_data.py qlib_data --name qlib_data_simple --target_dir ~/.qlib/qlib_data/cn_data_simple --interval 1d --region cn
|
||||
python -c "import os; userpath=os.path.expanduser('~'); os.rename(userpath + '/.qlib/qlib_data/cn_data_simple', userpath + '/.qlib/qlib_data/cn_data')"
|
||||
python scripts/get_data.py qlib_data --name qlib_data_simple --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
|
||||
azcopy copy https://qlibpublic.blob.core.windows.net/data/rl /tmp/qlibpublic/data --recursive
|
||||
mv /tmp/qlibpublic/data tests/.data
|
||||
|
||||
- name: Test workflow by config (install from pip)
|
||||
- name: Install Lightgbm for MacOS
|
||||
if: ${{ matrix.os == 'macos-11' || matrix.os == 'macos-latest' }}
|
||||
run: |
|
||||
python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
python -m pip uninstall -y pyqlib
|
||||
|
||||
# Test Qlib installed from source
|
||||
- name: Install Qlib from source
|
||||
run: |
|
||||
pip install --upgrade cython jupyter jupyter_contrib_nbextensions numpy scipy scikit-learn # installing without this line will cause errors on GitHub Actions, while instsalling locally won't
|
||||
pip install gym tianshou torch
|
||||
pip install -e .
|
||||
|
||||
- name: Install test dependencies
|
||||
run: |
|
||||
pip install --upgrade pip
|
||||
pip install black pytest
|
||||
|
||||
- name: Unit tests with Pytest
|
||||
run: |
|
||||
pip install -r scripts/data_collector/pit/requirements.txt
|
||||
cd tests
|
||||
python -m pytest . --durations=10
|
||||
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Microsoft/qlib/main/.github/brew_install.sh)"
|
||||
HOMEBREW_NO_AUTO_UPDATE=1 brew install lightgbm
|
||||
# FIX MacOS error: Segmentation fault
|
||||
# reference: https://github.com/microsoft/LightGBM/issues/4229
|
||||
wget https://raw.githubusercontent.com/Homebrew/homebrew-core/fb8323f2b170bd4ae97e1bac9bf3e2983af3fdb0/Formula/libomp.rb
|
||||
brew unlink libomp
|
||||
brew install libomp.rb
|
||||
|
||||
- name: Test workflow by config (install from source)
|
||||
run: |
|
||||
# Version 0.52.0 of numba must be installed manually in CI, otherwise it will cause incompatibility with the latest version of numpy.
|
||||
python -m pip install numba==0.52.0
|
||||
# You must update numpy manually, because when installing python tools, it will try to uninstall numpy and cause CI to fail.
|
||||
python -m pip install --upgrade numpy
|
||||
python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
|
||||
- name: Unit tests with Pytest
|
||||
uses: nick-fields/retry@v2
|
||||
with:
|
||||
timeout_minutes: 60
|
||||
max_attempts: 3
|
||||
command: |
|
||||
cd tests
|
||||
python -m pytest . -m "not slow" --durations=0
|
||||
59
.github/workflows/test_qlib_from_source_slow.yml
vendored
Normal file
59
.github/workflows/test_qlib_from_source_slow.yml
vendored
Normal file
@@ -0,0 +1,59 @@
|
||||
name: Test qlib from source slow
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
|
||||
jobs:
|
||||
build:
|
||||
timeout-minutes: 720
|
||||
# we may retry for 3 times for `Unit tests with Pytest`
|
||||
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [windows-latest, ubuntu-18.04, ubuntu-20.04, macos-11, macos-latest]
|
||||
# not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129
|
||||
python-version: [3.7, 3.8]
|
||||
|
||||
steps:
|
||||
- name: Test qlib from source slow
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Set up Python tools
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
# python -m pip is necessary to upgrade pip.
|
||||
pip install --upgrade cython numpy
|
||||
pip install -e .[dev]
|
||||
|
||||
- name: Downloads dependencies data
|
||||
run: |
|
||||
python scripts/get_data.py qlib_data --name qlib_data_simple --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
|
||||
|
||||
- name: Install Lightgbm for MacOS
|
||||
if: ${{ matrix.os == 'macos-11' || matrix.os == 'macos-latest' }}
|
||||
run: |
|
||||
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Microsoft/qlib/main/.github/brew_install.sh)"
|
||||
HOMEBREW_NO_AUTO_UPDATE=1 brew install lightgbm
|
||||
# FIX MacOS error: Segmentation fault
|
||||
# reference: https://github.com/microsoft/LightGBM/issues/4229
|
||||
wget https://raw.githubusercontent.com/Homebrew/homebrew-core/fb8323f2b170bd4ae97e1bac9bf3e2983af3fdb0/Formula/libomp.rb
|
||||
brew unlink libomp
|
||||
brew install libomp.rb
|
||||
|
||||
- name: Unit tests with Pytest
|
||||
uses: nick-fields/retry@v2
|
||||
with:
|
||||
timeout_minutes: 240
|
||||
max_attempts: 3
|
||||
command: |
|
||||
cd tests
|
||||
python -m pytest . -m "slow" --durations=0
|
||||
@@ -1,6 +1,6 @@
|
||||
[mypy]
|
||||
exclude = (?x)(
|
||||
^qlib/backtest
|
||||
^qlib/backtest/high_performance_ds\.py$
|
||||
| ^qlib/contrib
|
||||
| ^qlib/data
|
||||
| ^qlib/model
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
repos:
|
||||
- repo: https://github.com/psf/black
|
||||
rev: 22.1.0
|
||||
rev: 22.6.0
|
||||
hooks:
|
||||
- id: black
|
||||
args: ["qlib", "-l 120"]
|
||||
|
||||
56
CHANGES.rst
56
CHANGES.rst
@@ -1,63 +1,63 @@
|
||||
Changelog
|
||||
====================
|
||||
=========
|
||||
Here you can see the full list of changes between each QLib release.
|
||||
|
||||
Version 0.1.0
|
||||
--------------------
|
||||
-------------
|
||||
This is the initial release of QLib library.
|
||||
|
||||
Version 0.1.1
|
||||
--------------------
|
||||
-------------
|
||||
Performance optimize. Add more features and operators.
|
||||
|
||||
Version 0.1.2
|
||||
--------------------
|
||||
- Support operator syntax. Now ``High() - Low()`` is equivalent to ``Sub(High(), Low())``.
|
||||
-------------
|
||||
- Support operator syntax. Now ``High() - Low()`` is equivalent to ``Sub(High(), Low())``.
|
||||
- Add more technical indicators.
|
||||
|
||||
Version 0.1.3
|
||||
--------------------
|
||||
-------------
|
||||
Bug fix and add instruments filtering mechanism.
|
||||
|
||||
Version 0.2.0
|
||||
--------------------
|
||||
-------------
|
||||
- Redesign ``LocalProvider`` database format for performance improvement.
|
||||
- Support load features as string fields.
|
||||
- Add scripts for database construction.
|
||||
- More operators and technical indicators.
|
||||
|
||||
Version 0.2.1
|
||||
--------------------
|
||||
-------------
|
||||
- Support registering user-defined ``Provider``.
|
||||
- Support use operators in string format, e.g. ``['Ref($close, 1)']`` is valid field format.
|
||||
- Support dynamic fields in ``$some_field`` format. And existing fields like ``Close()`` may be deprecated in the future.
|
||||
|
||||
Version 0.2.2
|
||||
--------------------
|
||||
-------------
|
||||
- Add ``disk_cache`` for reusing features (enabled by default).
|
||||
- Add ``qlib.contrib`` for experimental model construction and evaluation.
|
||||
|
||||
|
||||
Version 0.2.3
|
||||
--------------------
|
||||
-------------
|
||||
- Add ``backtest`` module
|
||||
- Decoupling the Strategy, Account, Position, Exchange from the backtest module
|
||||
|
||||
Version 0.2.4
|
||||
--------------------
|
||||
-------------
|
||||
- Add ``profit attribution`` module
|
||||
- Add ``rick_control`` and ``cost_control`` strategies
|
||||
|
||||
|
||||
Version 0.3.0
|
||||
--------------------
|
||||
-------------
|
||||
- Add ``estimator`` module
|
||||
|
||||
Version 0.3.1
|
||||
--------------------
|
||||
-------------
|
||||
- Add ``filter`` module
|
||||
|
||||
Version 0.3.2
|
||||
--------------------
|
||||
-------------
|
||||
- Add real price trading, if the ``factor`` field in the data set is incomplete, use ``adj_price`` trading
|
||||
- Refactor ``handler`` ``launcher`` ``trainer`` code
|
||||
- Support ``backtest`` configuration parameters in the configuration file
|
||||
@@ -65,16 +65,16 @@ Version 0.3.2
|
||||
- Fix bug of ``filter`` module
|
||||
|
||||
Version 0.3.3
|
||||
-------------------
|
||||
-------------
|
||||
- Fix bug of ``filter`` module
|
||||
|
||||
Version 0.3.4
|
||||
--------------------
|
||||
-------------
|
||||
- Support for ``finetune model``
|
||||
- Refactor ``fetcher`` code
|
||||
|
||||
Version 0.3.5
|
||||
--------------------
|
||||
-------------
|
||||
- Support multi-label training, you can provide multiple label in ``handler``. (But LightGBM doesn't support due to the algorithm itself)
|
||||
- Refactor ``handler`` code, dataset.py is no longer used, and you can deploy your own labels and features in ``feature_label_config``
|
||||
- Handler only offer DataFrame. Also, ``trainer`` and model.py only receive DataFrame
|
||||
@@ -82,7 +82,7 @@ Version 0.3.5
|
||||
- Move some date config from ``handler`` to ``trainer``
|
||||
|
||||
Version 0.4.0
|
||||
--------------------
|
||||
-------------
|
||||
- Add `data` package that holds all data-related codes
|
||||
- Reform the data provider structure
|
||||
- Create a server for data centralized management `qlib-server<https://amc-msra.visualstudio.com/trading-algo/_git/qlib-server>`_
|
||||
@@ -100,7 +100,7 @@ Version 0.4.0
|
||||
|
||||
|
||||
Version 0.4.1
|
||||
--------------------
|
||||
-------------
|
||||
- Add support Windows
|
||||
- Fix ``instruments`` type bug
|
||||
- Fix ``features`` is empty bug(It will cause failure in updating)
|
||||
@@ -112,19 +112,19 @@ Version 0.4.1
|
||||
|
||||
|
||||
Version 0.4.2
|
||||
--------------------
|
||||
-------------
|
||||
- Refactor DataHandler
|
||||
- Add ``Alpha360`` DataHandler
|
||||
|
||||
|
||||
Version 0.4.3
|
||||
--------------------
|
||||
-------------
|
||||
- Implementing Online Inference and Trading Framework
|
||||
- Refactoring The interfaces of backtest and strategy module.
|
||||
|
||||
|
||||
Version 0.4.4
|
||||
--------------------
|
||||
-------------
|
||||
- Optimize cache generation performance
|
||||
- Add report module
|
||||
- Fix bug when using ``ServerDatasetCache`` offline.
|
||||
@@ -138,7 +138,7 @@ Version 0.4.4
|
||||
|
||||
|
||||
Version 0.4.5
|
||||
--------------------
|
||||
-------------
|
||||
- Add multi-kernel implementation for both client and server.
|
||||
- Support a new way to load data from client which skips dataset cache.
|
||||
- Change the default dataset method from single kernel implementation to multi kernel implementation.
|
||||
@@ -146,14 +146,14 @@ Version 0.4.5
|
||||
- Support a new method to write config file by using dict.
|
||||
|
||||
Version 0.4.6
|
||||
--------------------
|
||||
-------------
|
||||
- Some bugs are fixed
|
||||
- The default config in `Version 0.4.5` is not friendly to daily frequency data.
|
||||
- Backtest error in TopkWeightStrategy when `WithInteract=True`.
|
||||
|
||||
|
||||
Version 0.5.0
|
||||
--------------------
|
||||
-------------
|
||||
- First opensource version
|
||||
- Refine the docs, code
|
||||
- Add baselines
|
||||
@@ -161,7 +161,7 @@ Version 0.5.0
|
||||
|
||||
|
||||
Version 0.8.0
|
||||
--------------------
|
||||
-------------
|
||||
- The backtest is greatly refactored.
|
||||
- Nested decision execution framework is supported
|
||||
- There are lots of changes for daily trading, it is hard to list all of them. But a few important changes could be noticed
|
||||
@@ -175,5 +175,5 @@ Version 0.8.0
|
||||
|
||||
|
||||
Other Versions
|
||||
----------------------------------
|
||||
--------------
|
||||
Please refer to `Github release Notes <https://github.com/microsoft/qlib/releases>`_
|
||||
|
||||
19
README.md
19
README.md
@@ -172,10 +172,23 @@ Also, users can install the latest dev version ``Qlib`` by the source code accor
|
||||
```
|
||||
**Note**: You can install Qlib with `python setup.py install` as well. But it is not the recommanded approach. It will skip `pip` and cause obscure problems. For example, **only** the command ``pip install .`` **can** overwrite the stable version installed by ``pip install pyqlib``, while the command ``python setup.py install`` **can't**.
|
||||
|
||||
**Tips**: If you fail to install `Qlib` or run the examples in your environment, comparing your steps and the [CI workflow](.github/workflows/test.yml) may help you find the problem.
|
||||
**Tips**: If you fail to install `Qlib` or run the examples in your environment, comparing your steps and the [CI workflow](.github/workflows/test_qlib_from_source.yml) may help you find the problem.
|
||||
|
||||
## Data Preparation
|
||||
Load and prepare data by running the following code:
|
||||
|
||||
### Get with module
|
||||
```bash
|
||||
# get 1d data
|
||||
python -m qlib.run.get_data qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn
|
||||
|
||||
# get 1min data
|
||||
python -m qlib.run.get_data qlib_data --target_dir ~/.qlib/qlib_data/cn_data_1min --region cn --interval 1min
|
||||
|
||||
```
|
||||
|
||||
### Get from source
|
||||
|
||||
```bash
|
||||
# get 1d data
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn
|
||||
@@ -197,6 +210,8 @@ We recommend users to prepare their own data if they have a high-quality dataset
|
||||
>
|
||||
> It is recommended that users update the data manually once (--trading_date 2021-05-25) and then set it to update automatically.
|
||||
>
|
||||
> **NOTE**: Users can't incrementally update data based on the offline data provided by Qlib(some fields are removed to reduce the data size). Users should use [yahoo collector](https://github.com/microsoft/qlib/tree/main/scripts/data_collector/yahoo#automatic-update-of-daily-frequency-datafrom-yahoo-finance) to download Yahoo data from scratch and then incrementally update it.
|
||||
>
|
||||
> For more information, please refer to: [yahoo collector](https://github.com/microsoft/qlib/tree/main/scripts/data_collector/yahoo#automatic-update-of-daily-frequency-datafrom-yahoo-finance)
|
||||
|
||||
* Automatic update of data to the "qlib" directory each trading day(Linux)
|
||||
@@ -458,7 +473,7 @@ Before we released Qlib as an open-source project on Github in Sep 2020, Qlib is
|
||||
|
||||
This project welcomes contributions and suggestions.
|
||||
**Here are some
|
||||
[code standards](docs/developer/code_standard.rst) for submiting a pull request.**
|
||||
[code standards and development guidance](docs/developer/code_standard_and_dev_guide.rst) for submiting a pull request.**
|
||||
|
||||
Making contributions is not a hard thing. Solving an issue(maybe just answering a question raised in [issues list](https://github.com/microsoft/qlib/issues) or [gitter](https://gitter.im/Microsoft/qlib)), fixing/issuing a bug, improving the documents and even fixing a typo are important contributions to Qlib.
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ Qlib FAQ
|
||||
############
|
||||
|
||||
Qlib Frequently Asked Questions
|
||||
================================
|
||||
===============================
|
||||
.. contents::
|
||||
:depth: 1
|
||||
:local:
|
||||
@@ -13,7 +13,7 @@ Qlib Frequently Asked Questions
|
||||
|
||||
|
||||
1. RuntimeError: An attempt has been made to start a new process before the current process has finished its bootstrapping phase...
|
||||
------------------------------------------------------------------------------------------------------------------------------------
|
||||
-----------------------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
@@ -52,7 +52,7 @@ This is caused by the limitation of multiprocessing under windows OS. Please ref
|
||||
|
||||
|
||||
2. qlib.data.cache.QlibCacheException: It sees the key(...) of the redis lock has existed in your redis db now.
|
||||
-----------------------------------------------------------------------------------------------------------------
|
||||
---------------------------------------------------------------------------------------------------------------
|
||||
|
||||
It sees the key of the redis lock has existed in your redis db now. You can use the following command to clear your redis keys and rerun your commands
|
||||
|
||||
@@ -72,7 +72,7 @@ If the issue is not resolved, use ``keys *`` to find if multiple keys exist. If
|
||||
Also, feel free to post a new issue in our GitHub repository. We always check each issue carefully and try our best to solve them.
|
||||
|
||||
3. ModuleNotFoundError: No module named 'qlib.data._libs.rolling'
|
||||
------------------------------------------------------------------------------------------------------------------------------------
|
||||
-----------------------------------------------------------------
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@@ -101,7 +101,7 @@ Also, feel free to post a new issue in our GitHub repository. We always check ea
|
||||
|
||||
|
||||
4. BadNamespaceError: / is not a connected namespace
|
||||
------------------------------------------------------------------------------------------------------------------------------------
|
||||
----------------------------------------------------
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@@ -125,7 +125,7 @@ Also, feel free to post a new issue in our GitHub repository. We always check ea
|
||||
|
||||
|
||||
5. TypeError: send() got an unexpected keyword argument 'binary'
|
||||
------------------------------------------------------------------------------------------------------------------------------------
|
||||
----------------------------------------------------------------
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
.. _pit:
|
||||
|
||||
===========================
|
||||
============================
|
||||
(P)oint-(I)n-(T)ime Database
|
||||
===========================
|
||||
============================
|
||||
.. currentmodule:: qlib
|
||||
|
||||
|
||||
Introduction
|
||||
------------
|
||||
Point-in-time data is a very important consideration when performing any sort of historical market analysis.
|
||||
Point-in-time data is a very important consideration when performing any sort of historical market analysis.
|
||||
|
||||
For example, let’s say we are backtesting a trading strategy and we are using the past five years of historical data as our input.
|
||||
Our model is assumed to trade once a day, at the market close, and we’ll say we are calculating the trading signal for 1 January 2020 in our backtest. At that point, we should only have data for 1 January 2020, 31 December 2019, 30 December 2019 etc.
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
.. _alpha:
|
||||
|
||||
===========================
|
||||
Building Formulaic Alphas
|
||||
===========================
|
||||
=========================
|
||||
Building Formulaic Alphas
|
||||
=========================
|
||||
.. currentmodule:: qlib
|
||||
|
||||
Introduction
|
||||
===================
|
||||
============
|
||||
|
||||
In quantitative trading practice, designing novel factors that can explain and predict future asset returns are of vital importance to the profitability of a strategy. Such factors are usually called alpha factors, or alphas in short.
|
||||
|
||||
@@ -15,28 +15,28 @@ A formulaic alpha, as the name suggests, is a kind of alpha that can be presente
|
||||
|
||||
|
||||
Building Formulaic Alphas in ``Qlib``
|
||||
======================================
|
||||
=====================================
|
||||
|
||||
In ``Qlib``, users can easily build formulaic alphas.
|
||||
|
||||
Example
|
||||
-----------------
|
||||
-------
|
||||
|
||||
`MACD`, short for moving average convergence/divergence, is a formulaic alpha used in technical analysis of stock prices. It is designed to reveal changes in the strength, direction, momentum, and duration of a trend in a stock's price.
|
||||
|
||||
`MACD` can be presented as the following formula:
|
||||
|
||||
.. math::
|
||||
.. math::
|
||||
|
||||
MACD = 2\times (DIF-DEA)
|
||||
|
||||
.. note::
|
||||
|
||||
`DIF` means Differential value, which is 12-period EMA minus 26-period EMA.
|
||||
|
||||
|
||||
.. math::
|
||||
|
||||
DIF = \frac{EMA(CLOSE, 12) - EMA(CLOSE, 26)}{CLOSE}
|
||||
DIF = \frac{EMA(CLOSE, 12) - EMA(CLOSE, 26)}{CLOSE}
|
||||
|
||||
`DEA`means a 9-period EMA of the DIF.
|
||||
|
||||
@@ -65,7 +65,7 @@ Users can use ``Data Handler`` to build formulaic alphas `MACD` in qlib:
|
||||
>> print(df)
|
||||
feature label
|
||||
MACD LABEL
|
||||
datetime instrument
|
||||
datetime instrument
|
||||
2010-01-04 SH600000 -0.011547 -0.019672
|
||||
SH600004 0.002745 -0.014721
|
||||
SH600006 0.010133 0.002911
|
||||
@@ -79,7 +79,7 @@ Users can use ``Data Handler`` to build formulaic alphas `MACD` in qlib:
|
||||
SZ300315 -0.030557 0.012455
|
||||
|
||||
Reference
|
||||
===========
|
||||
=========
|
||||
|
||||
To learn more about ``Data Loader``, please refer to `Data Loader <../component/data.html#data-loader>`_
|
||||
|
||||
|
||||
@@ -1,26 +1,26 @@
|
||||
.. _serial:
|
||||
|
||||
=================================
|
||||
=============
|
||||
Serialization
|
||||
=================================
|
||||
=============
|
||||
.. currentmodule:: qlib
|
||||
|
||||
Introduction
|
||||
===================
|
||||
``Qlib`` supports dumping the state of ``DataHandler``, ``DataSet``, ``Processor`` and ``Model``, etc. into a disk and reloading them.
|
||||
============
|
||||
``Qlib`` supports dumping the state of ``DataHandler``, ``DataSet``, ``Processor`` and ``Model``, etc. into a disk and reloading them.
|
||||
|
||||
Serializable Class
|
||||
========================
|
||||
==================
|
||||
|
||||
``Qlib`` provides a base class ``qlib.utils.serial.Serializable``, whose state can be dumped into or loaded from disk in `pickle` format.
|
||||
``Qlib`` provides a base class ``qlib.utils.serial.Serializable``, whose state can be dumped into or loaded from disk in `pickle` format.
|
||||
When users dump the state of a ``Serializable`` instance, the attributes of the instance whose name **does not** start with `_` will be saved on the disk.
|
||||
However, users can use ``config`` method or override ``default_dump_all`` attribute to prevent this feature.
|
||||
|
||||
Users can also override ``pickle_backend`` attribute to choose a pickle backend. The supported value is "pickle" (default and common) and "dill" (dump more things such as function, more information in `here <https://pypi.org/project/dill/>`_).
|
||||
|
||||
Example
|
||||
==========================
|
||||
``Qlib``'s serializable class includes ``DataHandler``, ``DataSet``, ``Processor`` and ``Model``, etc., which are subclass of ``qlib.utils.serial.Serializable``.
|
||||
=======
|
||||
``Qlib``'s serializable class includes ``DataHandler``, ``DataSet``, ``Processor`` and ``Model``, etc., which are subclass of ``qlib.utils.serial.Serializable``.
|
||||
Specifically, ``qlib.data.dataset.DatasetH`` is one of them. Users can serialize ``DatasetH`` as follows.
|
||||
|
||||
.. code-block:: Python
|
||||
@@ -33,7 +33,7 @@ Specifically, ``qlib.data.dataset.DatasetH`` is one of them. Users can serialize
|
||||
dataset = pickle.load(file_dataset)
|
||||
|
||||
.. note::
|
||||
Only state of ``DatasetH`` should be saved on the disk, such as some `mean` and `variance` used for data normalization, etc.
|
||||
Only state of ``DatasetH`` should be saved on the disk, such as some `mean` and `variance` used for data normalization, etc.
|
||||
|
||||
After reloading the ``DatasetH``, users need to reinitialize it. It means that users can reset some states of ``DatasetH`` or ``QlibDataHandler`` such as `instruments`, `start_time`, `end_time` and `segments`, etc., and generate new data according to the states (data is not state and should not be saved on the disk).
|
||||
|
||||
@@ -41,5 +41,5 @@ A more detailed example is in this `link <https://github.com/microsoft/qlib/tree
|
||||
|
||||
|
||||
API
|
||||
===================
|
||||
===
|
||||
Please refer to `Serializable API <../reference/api.html#module-qlib.utils.serial.Serializable>`_.
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
.. _server:
|
||||
|
||||
=================================
|
||||
=============================
|
||||
``Online`` & ``Offline`` mode
|
||||
=================================
|
||||
=============================
|
||||
.. currentmodule:: qlib
|
||||
|
||||
|
||||
Introduction
|
||||
=============
|
||||
============
|
||||
|
||||
``Qlib`` supports ``Online`` mode and ``Offline`` mode. Only the ``Offline`` mode is introduced in this document.
|
||||
``Qlib`` supports ``Online`` mode and ``Offline`` mode. Only the ``Offline`` mode is introduced in this document.
|
||||
|
||||
The ``Online`` mode is designed to solve the following problems:
|
||||
|
||||
@@ -18,12 +18,12 @@ The ``Online`` mode is designed to solve the following problems:
|
||||
- Make the data can be accessed in a remote way.
|
||||
|
||||
Qlib-Server
|
||||
===============
|
||||
===========
|
||||
|
||||
``Qlib-Server`` is the assorted server system for ``Qlib``, which utilizes ``Qlib`` for basic calculations and provides extensive server system and cache mechanism. With QLibServer, the data provided for ``Qlib`` can be managed in a centralized manner. With ``Qlib-Server``, users can use ``Qlib`` in ``Online`` mode.
|
||||
``Qlib-Server`` is the assorted server system for ``Qlib``, which utilizes ``Qlib`` for basic calculations and provides extensive server system and cache mechanism. With QLibServer, the data provided for ``Qlib`` can be managed in a centralized manner. With ``Qlib-Server``, users can use ``Qlib`` in ``Online`` mode.
|
||||
|
||||
|
||||
|
||||
Reference
|
||||
=================
|
||||
If users are interested in ``Qlib-Server`` and ``Online`` mode, please refer to `Qlib-Server Project <https://github.com/microsoft/qlib-server>`_ and `Qlib-Server Document <https://qlib-server.readthedocs.io/en/latest/>`_.
|
||||
=========
|
||||
If users are interested in ``Qlib-Server`` and ``Online`` mode, please refer to `Qlib-Server Project <https://github.com/microsoft/qlib-server>`_ and `Qlib-Server Document <https://qlib-server.readthedocs.io/en/latest/>`_.
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
.. _task_management:
|
||||
|
||||
=================================
|
||||
===============
|
||||
Task Management
|
||||
=================================
|
||||
===============
|
||||
.. currentmodule:: qlib
|
||||
|
||||
|
||||
Introduction
|
||||
=============
|
||||
============
|
||||
|
||||
The `Workflow <../component/introduction.html>`_ part introduces how to run research workflow in a loosely-coupled way. But it can only execute one ``task`` when you use ``qrun``.
|
||||
To automatically generate and execute different tasks, ``Task Management`` provides a whole process including `Task Generating`_, `Task Storing`_, `Task Training`_ and `Task Collecting`_.
|
||||
@@ -36,7 +36,7 @@ Here is the base class of ``TaskGen``:
|
||||
This class allows users to verify the effect of data from different periods on the model in one experiment. More information is `here <../reference/api.html#TaskGen>`_.
|
||||
|
||||
Task Storing
|
||||
===============
|
||||
============
|
||||
To achieve higher efficiency and the possibility of cluster operation, ``Task Manager`` will store all tasks in `MongoDB <https://www.mongodb.com/>`_.
|
||||
``TaskManager`` can fetch undone tasks automatically and manage the lifecycle of a set of tasks with error handling.
|
||||
Users **MUST** finish the configuration of `MongoDB <https://www.mongodb.com/>`_ when using this module.
|
||||
@@ -57,7 +57,7 @@ Users need to provide the MongoDB URL and database name for using ``TaskManager`
|
||||
More information of ``Task Manager`` can be found in `here <../reference/api.html#TaskManager>`_.
|
||||
|
||||
Task Training
|
||||
===============
|
||||
=============
|
||||
After generating and storing those ``task``, it's time to run the ``task`` which is in the *WAITING* status.
|
||||
``Qlib`` provides a method called ``run_task`` to run those ``task`` in task pool, however, users can also customize how tasks are executed.
|
||||
An easy way to get the ``task_func`` is using ``qlib.model.trainer.task_train`` directly.
|
||||
|
||||
@@ -1,2 +1 @@
|
||||
.. include:: ../../CHANGES.rst
|
||||
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
.. _data:
|
||||
|
||||
================================
|
||||
==================================
|
||||
Data Layer: Data Framework & Usage
|
||||
================================
|
||||
==================================
|
||||
|
||||
Introduction
|
||||
============================
|
||||
============
|
||||
|
||||
``Data Layer`` provides user-friendly APIs to manage and retrieve data. It provides high-performance data infrastructure.
|
||||
``Data Layer`` provides user-friendly APIs to manage and retrieve data. It provides high-performance data infrastructure.
|
||||
|
||||
It is designed for quantitative investment. For example, users could build formulaic alphas with ``Data Layer`` easily. Please refer to `Building Formulaic Alphas <../advanced/alpha.html>`_ for more details.
|
||||
|
||||
@@ -23,16 +23,16 @@ The introduction of ``Data Layer`` includes the following parts.
|
||||
|
||||
Here is a typical example of Qlib data workflow
|
||||
|
||||
- Users download data and converting data into Qlib format(with filename suffix `.bin`). In this step, typically only some basic data are stored on disk(such as OHLCV).
|
||||
- Users download data and converting data into Qlib format(with filename suffix `.bin`). In this step, typically only some basic data are stored on disk(such as OHLCV).
|
||||
- Creating some basic features based on Qlib's expression Engine(e.g. "Ref($close, 60) / $close", the return of last 60 trading days). Supported operators in the expression engine can be found `here <https://github.com/microsoft/qlib/blob/main/qlib/data/ops.py>`_. This step is typically implemented in Qlib's `Data Loader <https://qlib.readthedocs.io/en/latest/component/data.html#data-loader>`_ which is a component of `Data Handler <https://qlib.readthedocs.io/en/latest/component/data.html#data-handler>`_ .
|
||||
- If users require more complicated data processing (e.g. data normalization), `Data Handler <https://qlib.readthedocs.io/en/latest/component/data.html#data-handler>`_ support user-customized processors to process data(some predefined processors can be found `here <https://github.com/microsoft/qlib/blob/main/qlib/data/dataset/processor.py>`_). The processors are different from operators in expression engine. It is designed for some complicated data processing methods which is hard to supported in operators in expression engine.
|
||||
- At last, `Dataset <https://qlib.readthedocs.io/en/latest/component/data.html#dataset>`_ is responsible to prepare model-specific dataset from the processed data of Data Handler
|
||||
|
||||
Data Preparation
|
||||
============================
|
||||
================
|
||||
|
||||
Qlib Format Data
|
||||
------------------
|
||||
----------------
|
||||
|
||||
We've specially designed a data structure to manage financial data, please refer to the `File storage design section in Qlib paper <https://arxiv.org/abs/2009.11189>`_ for detailed information.
|
||||
Such data will be stored with filename suffix `.bin` (We'll call them `.bin` file, `.bin` format, or qlib format). `.bin` file is designed for scientific computing on finance data.
|
||||
@@ -50,11 +50,16 @@ Alpha158 √ √
|
||||
Also, ``Qlib`` provides a high-frequency dataset. Users can run a high-frequency dataset example through this `link <https://github.com/microsoft/qlib/tree/main/examples/highfreq>`_.
|
||||
|
||||
Qlib Format Dataset
|
||||
--------------------
|
||||
``Qlib`` has provided an off-the-shelf dataset in `.bin` format, users could use the script ``scripts/get_data.py`` to download the China-Stock dataset as follows.
|
||||
The price volume data look different from the actual dealling price because of they are **adjusted** (`adjusted price <https://www.investopedia.com/terms/a/adjusted_closing_price.asp>`_). And then you may find that the adjusted price may be different from different data sources. This is because different data sources may vary in the way of adjusting prices. Qlib normalize the price on first trading day of each stock to 1 when adjusting them.
|
||||
-------------------
|
||||
``Qlib`` has provided an off-the-shelf dataset in `.bin` format, users could use the script ``scripts/get_data.py`` to download the China-Stock dataset as follows. User can also use numpy to load `.bin` file to validate data.
|
||||
The price volume data look different from the actual dealling price because of they are **adjusted** (`adjusted price <https://www.investopedia.com/terms/a/adjusted_closing_price.asp>`_). And then you may find that the adjusted price may be different from different data sources. This is because different data sources may vary in the way of adjusting prices. Qlib normalize the price on first trading day of each stock to 1 when adjusting them.
|
||||
Users can leverage `$factor` to get the original trading price (e.g. `$close / $factor` to get the original close price).
|
||||
|
||||
Here are some discussions about the price adjusting of Qlib.
|
||||
|
||||
- https://github.com/microsoft/qlib/issues/991#issuecomment-1075252402
|
||||
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
# download 1d
|
||||
@@ -104,7 +109,7 @@ Automatic update of daily frequency data
|
||||
|
||||
|
||||
Converting CSV Format into Qlib Format
|
||||
-------------------------------------------
|
||||
--------------------------------------
|
||||
|
||||
``Qlib`` has provided the script ``scripts/dump_bin.py`` to convert **any** data in CSV format into `.bin` files (``Qlib`` format) as long as they are in the correct format.
|
||||
|
||||
@@ -126,16 +131,16 @@ Users can also provide their own data in CSV format. However, the CSV data **mus
|
||||
- CSV file is named after a specific stock *or* the CSV file includes a column of the stock name
|
||||
|
||||
- Name the CSV file after a stock: `SH600000.csv`, `AAPL.csv` (not case sensitive).
|
||||
|
||||
|
||||
- CSV file includes a column of the stock name. User **must** specify the column name when dumping the data. Here is an example:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python scripts/dump_bin.py dump_all ... --symbol_field_name symbol
|
||||
|
||||
|
||||
where the data are in the following format:
|
||||
|
||||
.. code-block::
|
||||
.. code-block::
|
||||
|
||||
symbol,close
|
||||
SH600000,120
|
||||
@@ -145,10 +150,10 @@ Users can also provide their own data in CSV format. However, the CSV data **mus
|
||||
.. code-block:: bash
|
||||
|
||||
python scripts/dump_bin.py dump_all ... --date_field_name date
|
||||
|
||||
|
||||
where the data are in the following format:
|
||||
|
||||
.. code-block::
|
||||
.. code-block::
|
||||
|
||||
symbol,date,close,open,volume
|
||||
SH600000,2020-11-01,120,121,12300000
|
||||
@@ -172,7 +177,7 @@ After conversion, users can find their Qlib format data in the directory `~/.qli
|
||||
.. note::
|
||||
|
||||
The arguments of `--include_fields` should correspond with the column names of CSV files. The columns names of dataset provided by ``Qlib`` should include open, close, high, low, volume and factor at least.
|
||||
|
||||
|
||||
- `open`
|
||||
The adjusted opening price
|
||||
- `close`
|
||||
@@ -186,11 +191,11 @@ After conversion, users can find their Qlib format data in the directory `~/.qli
|
||||
- `factor`
|
||||
The Restoration factor. Normally, ``factor = adjusted_price / original_price``, `adjusted price` reference: `split adjusted <https://www.investopedia.com/terms/s/splitadjusted.asp>`_
|
||||
|
||||
In the convention of `Qlib` data processing, `open, close, high, low, volume, money and factor` will be set to NaN if the stock is suspended.
|
||||
In the convention of `Qlib` data processing, `open, close, high, low, volume, money and factor` will be set to NaN if the stock is suspended.
|
||||
If you want to use your own alpha-factor which can't be calculate by OCHLV, like PE, EPS and so on, you could add it to the CSV files with OHCLV together and then dump it to the Qlib format data.
|
||||
|
||||
Stock Pool (Market)
|
||||
--------------------------------
|
||||
-------------------
|
||||
|
||||
``Qlib`` defines `stock pool <https://github.com/microsoft/qlib/blob/main/examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml#L4>`_ as stock list and their date ranges. Predefined stock pools (e.g. csi300) may be imported as follows.
|
||||
|
||||
@@ -200,7 +205,7 @@ Stock Pool (Market)
|
||||
|
||||
|
||||
Multiple Stock Modes
|
||||
--------------------------------
|
||||
--------------------
|
||||
|
||||
``Qlib`` now provides two different stock modes for users: China-Stock Mode & US-Stock Mode. Here are some different settings of these two modes:
|
||||
|
||||
@@ -218,23 +223,23 @@ The `trade unit` defines the unit number of stocks can be used in a trade, and t
|
||||
- Download china-stock in qlib format, please refer to section `Qlib Format Dataset <#qlib-format-dataset>`_.
|
||||
- Initialize ``Qlib`` in china-stock mode
|
||||
Supposed that users download their Qlib format data in the directory ``~/.qlib/qlib_data/cn_data``. Users only need to initialize ``Qlib`` as follows.
|
||||
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from qlib.constant import REG_CN
|
||||
qlib.init(provider_uri='~/.qlib/qlib_data/cn_data', region=REG_CN)
|
||||
|
||||
|
||||
|
||||
- If users use ``Qlib`` in US-stock mode, US-stock data is required. ``Qlib`` also provides a script to download US-stock data. Users can use ``Qlib`` in US-stock mode according to the following steps:
|
||||
- Download us-stock in qlib format, please refer to section `Qlib Format Dataset <#qlib-format-dataset>`_.
|
||||
- Initialize ``Qlib`` in US-stock mode
|
||||
Supposed that users prepare their Qlib format data in the directory ``~/.qlib/qlib_data/us_data``. Users only need to initialize ``Qlib`` as follows.
|
||||
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from qlib.config import REG_US
|
||||
qlib.init(provider_uri='~/.qlib/qlib_data/us_data', region=REG_US)
|
||||
|
||||
|
||||
|
||||
.. note::
|
||||
|
||||
@@ -242,14 +247,14 @@ The `trade unit` defines the unit number of stocks can be used in a trade, and t
|
||||
|
||||
|
||||
Data API
|
||||
========================
|
||||
========
|
||||
|
||||
Data Retrieval
|
||||
---------------
|
||||
--------------
|
||||
Users can use APIs in ``qlib.data`` to retrieve data, please refer to `Data Retrieval <../start/getdata.html>`_.
|
||||
|
||||
Feature
|
||||
------------------
|
||||
-------
|
||||
|
||||
``Qlib`` provides `Feature` and `ExpressionOps` to fetch the features according to users' needs.
|
||||
|
||||
@@ -264,7 +269,7 @@ Feature
|
||||
To know more about ``Feature``, please refer to `Feature API <../reference/api.html#module-qlib.data.base>`_.
|
||||
|
||||
Filter
|
||||
-------------------
|
||||
------
|
||||
``Qlib`` provides `NameDFilter` and `ExpressionDFilter` to filter the instruments according to users' needs.
|
||||
|
||||
- `NameDFilter`
|
||||
@@ -272,7 +277,7 @@ Filter
|
||||
|
||||
- `ExpressionDFilter`
|
||||
Expression dynamic instrument filter. Filter the instruments based on a certain expression. An expression rule indicating a certain feature field is required.
|
||||
|
||||
|
||||
- `basic features filter`: rule_expression = '$close/$open>5'
|
||||
- `cross-sectional features filter` \: rule_expression = '$rank($close)<10'
|
||||
- `time-sequence features filter`: rule_expression = '$Ref($close, 3)>100'
|
||||
@@ -299,29 +304,29 @@ Here is a simple example showing how to use filter in a basic ``Qlib`` workflow
|
||||
To know more about ``Filter``, please refer to `Filter API <../reference/api.html#module-qlib.data.filter>`_.
|
||||
|
||||
Reference
|
||||
-------------
|
||||
---------
|
||||
|
||||
To know more about ``Data API``, please refer to `Data API <../reference/api.html#data>`_.
|
||||
|
||||
|
||||
Data Loader
|
||||
=================
|
||||
===========
|
||||
|
||||
``Data Loader`` in ``Qlib`` is designed to load raw data from the original data source. It will be loaded and used in the ``Data Handler`` module.
|
||||
|
||||
QlibDataLoader
|
||||
---------------
|
||||
--------------
|
||||
|
||||
The ``QlibDataLoader`` class in ``Qlib`` is such an interface that allows users to load raw data from the ``Qlib`` data source.
|
||||
|
||||
StaticDataLoader
|
||||
---------------
|
||||
----------------
|
||||
|
||||
The ``StaticDataLoader`` class in ``Qlib`` is such an interface that allows users to load raw data from file or as provided.
|
||||
|
||||
|
||||
Interface
|
||||
------------
|
||||
---------
|
||||
|
||||
Here are some interfaces of the ``QlibDataLoader`` class:
|
||||
|
||||
@@ -329,28 +334,28 @@ Here are some interfaces of the ``QlibDataLoader`` class:
|
||||
:members:
|
||||
|
||||
API
|
||||
-----------
|
||||
---
|
||||
|
||||
To know more about ``Data Loader``, please refer to `Data Loader API <../reference/api.html#module-qlib.data.dataset.loader>`_.
|
||||
|
||||
|
||||
Data Handler
|
||||
=================
|
||||
============
|
||||
|
||||
The ``Data Handler`` module in ``Qlib`` is designed to handler those common data processing methods which will be used by most of the models.
|
||||
|
||||
Users can use ``Data Handler`` in an automatic workflow by ``qrun``, refer to `Workflow: Workflow Management <workflow.html>`_ for more details.
|
||||
Users can use ``Data Handler`` in an automatic workflow by ``qrun``, refer to `Workflow: Workflow Management <workflow.html>`_ for more details.
|
||||
|
||||
DataHandlerLP
|
||||
--------------
|
||||
-------------
|
||||
|
||||
In addition to use ``Data Handler`` in an automatic workflow with ``qrun``, ``Data Handler`` can be used as an independent module, by which users can easily preprocess data (standardization, remove NaN, etc.) and build datasets.
|
||||
In addition to use ``Data Handler`` in an automatic workflow with ``qrun``, ``Data Handler`` can be used as an independent module, by which users can easily preprocess data (standardization, remove NaN, etc.) and build datasets.
|
||||
|
||||
In order to achieve so, ``Qlib`` provides a base class `qlib.data.dataset.DataHandlerLP <../reference/api.html#qlib.data.dataset.handler.DataHandlerLP>`_. The core idea of this class is that: we will have some learnable ``Processors`` which can learn the parameters of data processing(e.g., parameters for zscore normalization). When new data comes in, these `trained` ``Processors`` can then process the new data and thus processing real-time data in an efficient way becomes possible. More information about ``Processors`` will be listed in the next subsection.
|
||||
|
||||
|
||||
Interface
|
||||
----------------------
|
||||
---------
|
||||
|
||||
Here are some important interfaces that ``DataHandlerLP`` provides:
|
||||
|
||||
@@ -364,7 +369,7 @@ Also, users can pass ``qlib.contrib.data.processor.ConfigSectionProcessor`` that
|
||||
|
||||
|
||||
Processor
|
||||
----------
|
||||
---------
|
||||
|
||||
The ``Processor`` module in ``Qlib`` is designed to be learnable and it is responsible for handling data processing such as `normalization` and `drop none/nan features/labels`.
|
||||
|
||||
@@ -382,14 +387,14 @@ The ``Processor`` module in ``Qlib`` is designed to be learnable and it is respo
|
||||
- ``CSRankNorm``: `processor` that applies cross sectional rank normalization.
|
||||
- ``CSZFillna``: `processor` that fills N/A values in a cross sectional way by the mean of the column.
|
||||
|
||||
Users can also create their own `processor` by inheriting the base class of ``Processor``. Please refer to the implementation of all the processors for more information (`Processor Link <https://github.com/microsoft/qlib/blob/main/qlib/data/dataset/processor.py>`_).
|
||||
Users can also create their own `processor` by inheriting the base class of ``Processor``. Please refer to the implementation of all the processors for more information (`Processor Link <https://github.com/microsoft/qlib/blob/main/qlib/data/dataset/processor.py>`_).
|
||||
|
||||
To know more about ``Processor``, please refer to `Processor API <../reference/api.html#module-qlib.data.dataset.processor>`_.
|
||||
|
||||
Example
|
||||
--------------
|
||||
-------
|
||||
|
||||
``Data Handler`` can be run with ``qrun`` by modifying the configuration file, and can also be used as a single module.
|
||||
``Data Handler`` can be run with ``qrun`` by modifying the configuration file, and can also be used as a single module.
|
||||
|
||||
Know more about how to run ``Data Handler`` with ``qrun``, please refer to `Workflow: Workflow Management <workflow.html>`_
|
||||
|
||||
@@ -427,17 +432,17 @@ Qlib provides implemented data handler `Alpha158`. The following example shows h
|
||||
.. note:: In the ``Alpha158``, ``Qlib`` uses the label `Ref($close, -2)/Ref($close, -1) - 1` that means the change from T+1 to T+2, rather than `Ref($close, -1)/$close - 1`, of which the reason is that when getting the T day close price of a china stock, the stock can be bought on T+1 day and sold on T+2 day.
|
||||
|
||||
API
|
||||
---------
|
||||
---
|
||||
|
||||
To know more about ``Data Handler``, please refer to `Data Handler API <../reference/api.html#module-qlib.data.dataset.handler>`_.
|
||||
|
||||
|
||||
Dataset
|
||||
=================
|
||||
=======
|
||||
|
||||
The ``Dataset`` module in ``Qlib`` aims to prepare data for model training and inferencing.
|
||||
|
||||
The motivation of this module is that we want to maximize the flexibility of different models to handle data that are suitable for themselves. This module gives the model the flexibility to process their data in an unique way. For instance, models such as ``GBDT`` may work well on data that contains `nan` or `None` value, while neural networks such as ``MLP`` will break down on such data.
|
||||
The motivation of this module is that we want to maximize the flexibility of different models to handle data that are suitable for themselves. This module gives the model the flexibility to process their data in an unique way. For instance, models such as ``GBDT`` may work well on data that contains `nan` or `None` value, while neural networks such as ``MLP`` will break down on such data.
|
||||
|
||||
If user's model need process its data in a different way, user could implement his own ``Dataset`` class. If the model's
|
||||
data processing is not special, ``DatasetH`` can be used directly.
|
||||
@@ -448,18 +453,18 @@ The ``DatasetH`` class is the `dataset` with `Data Handler`. Here is the most im
|
||||
:members:
|
||||
|
||||
API
|
||||
---------
|
||||
---
|
||||
|
||||
To know more about ``Dataset``, please refer to `Dataset API <../reference/api.html#dataset>`_.
|
||||
|
||||
|
||||
Cache
|
||||
==========
|
||||
=====
|
||||
|
||||
``Cache`` is an optional module that helps accelerate providing data by saving some frequently-used data as cache file. ``Qlib`` provides a `Memcache` class to cache the most-frequently-used data in memory, an inheritable `ExpressionCache` class, and an inheritable `DatasetCache` class.
|
||||
|
||||
Global Memory Cache
|
||||
---------------------
|
||||
-------------------
|
||||
|
||||
`Memcache` is a global memory cache mechanism that composes of three `MemCacheUnit` instances to cache **Calendar**, **Instruments**, and **Features**. The `MemCache` is defined globally in `cache.py` as `H`. Users can use `H['c'], H['i'], H['f']` to get/set `memcache`.
|
||||
|
||||
@@ -471,7 +476,7 @@ Global Memory Cache
|
||||
|
||||
|
||||
ExpressionCache
|
||||
-----------------
|
||||
---------------
|
||||
|
||||
`ExpressionCache` is a cache mechanism that saves expressions such as **Mean($close, 5)**. Users can inherit this base class to define their own cache mechanism that saves expressions according to the following steps.
|
||||
|
||||
@@ -486,7 +491,7 @@ The following shows the details about the interfaces:
|
||||
``Qlib`` has currently provided implemented disk cache `DiskExpressionCache` which inherits from `ExpressionCache` . The expressions data will be stored in the disk.
|
||||
|
||||
DatasetCache
|
||||
-----------------
|
||||
------------
|
||||
|
||||
`DatasetCache` is a cache mechanism that saves datasets. A certain dataset is regulated by a stock pool configuration (or a series of instruments, though not recommended), a list of expressions or static feature fields, the start time, and end time for the collected features and the frequency. Users can inherit this base class to define their own cache mechanism that saves datasets according to the following steps.
|
||||
|
||||
@@ -503,7 +508,7 @@ The following shows the details about the interfaces:
|
||||
|
||||
|
||||
Data and Cache File Structure
|
||||
==================================
|
||||
=============================
|
||||
|
||||
We've specially designed a file structure to manage data and cache, please refer to the `File storage design section in Qlib paper <https://arxiv.org/abs/2009.11189>`_ for detailed information. The file structure of data and cache is listed as follows.
|
||||
|
||||
@@ -536,4 +541,3 @@ We've specially designed a file structure to manage data and cache, please refer
|
||||
- .meta : an assorted meta file recording the stockpool config, field names and visit times
|
||||
- .index : an assorted index file recording the line index of all calendars
|
||||
- ...
|
||||
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
.. _highfreq:
|
||||
|
||||
============================================
|
||||
========================================================================
|
||||
Design of Nested Decision Execution Framework for High-Frequency Trading
|
||||
============================================
|
||||
========================================================================
|
||||
.. currentmodule:: qlib
|
||||
|
||||
Introduction
|
||||
===================
|
||||
============
|
||||
|
||||
Daily trading (e.g. portfolio management) and intraday trading (e.g. orders execution) are two hot topics in Quant investment and usually studied separately.
|
||||
|
||||
@@ -15,18 +15,18 @@ In order to support the joint backtest strategies in multiple levels, a correspo
|
||||
|
||||
Besides backtesting, the optimization of strategies from different levels is not standalone and can be affected by each other.
|
||||
For example, the best portfolio management strategy may change with the performance of order executions(e.g. a portfolio with higher turnover may becomes a better choice when we improve the order execution strategies).
|
||||
To achieve the overall good performance , it is necessary to consider the interaction of strategies in different level.
|
||||
To achieve the overall good performance , it is necessary to consider the interaction of strategies in different level.
|
||||
|
||||
Therefore, building a new framework for trading in multiple levels becomes necessary to solve the various problems mentioned above, for which we designed a nested decision execution framework that consider the interaction of strategies.
|
||||
|
||||
.. image:: ../_static/img/framework.svg
|
||||
|
||||
The design of the framework is shown in the yellow part in the middle of the figure above. Each level consists of ``Trading Agent`` and ``Execution Env``. ``Trading Agent`` has its own data processing module (``Information Extractor``), forecasting module (``Forecast Model``) and decision generator (``Decision Generator``). The trading algorithm generates the decisions by the ``Decision Generator`` based on the forecast signals output by the ``Forecast Module``, and the decisions generated by the trading algorithm are passed to the ``Execution Env``, which returns the execution results.
|
||||
The design of the framework is shown in the yellow part in the middle of the figure above. Each level consists of ``Trading Agent`` and ``Execution Env``. ``Trading Agent`` has its own data processing module (``Information Extractor``), forecasting module (``Forecast Model``) and decision generator (``Decision Generator``). The trading algorithm generates the decisions by the ``Decision Generator`` based on the forecast signals output by the ``Forecast Module``, and the decisions generated by the trading algorithm are passed to the ``Execution Env``, which returns the execution results.
|
||||
|
||||
The frequency of trading algorithm, decision content and execution environment can be customized by users (e.g. intraday trading, daily-frequency trading, weekly-frequency trading), and the execution environment can be nested with finer-grained trading algorithm and execution environment inside (i.e. sub-workflow in the figure, e.g. daily-frequency orders can be turned into finer-grained decisions by splitting orders within the day). The flexibility of nested decision execution framework makes it easy for users to explore the effects of combining different levels of trading strategies and break down the optimization barriers between different levels of trading algorithm.
|
||||
|
||||
Example
|
||||
===========================
|
||||
=======
|
||||
|
||||
An example of nested decision execution framework for high-frequency can be found `here <https://github.com/microsoft/qlib/blob/main/examples/nested_decision_execution/workflow.py>`_.
|
||||
|
||||
|
||||
@@ -1,17 +1,17 @@
|
||||
.. _meta:
|
||||
|
||||
=================================
|
||||
======================================================
|
||||
Meta Controller: Meta-Task & Meta-Dataset & Meta-Model
|
||||
=================================
|
||||
======================================================
|
||||
.. currentmodule:: qlib
|
||||
|
||||
|
||||
Introduction
|
||||
=============
|
||||
============
|
||||
``Meta Controller`` provides guidance to ``Forecast Model``, which aims to learn regular patterns among a series of forecasting tasks and use learned patterns to guide forthcoming forecasting tasks. Users can implement their own meta-model instance based on ``Meta Controller`` module.
|
||||
|
||||
Meta Task
|
||||
=============
|
||||
=========
|
||||
|
||||
A `Meta Task` instance is the basic element in the meta-learning framework. It saves the data that can be used for the `Meta Model`. Multiple `Meta Task` instances may share the same `Data Handler`, controlled by `Meta Dataset`. Users should use `prepare_task_data()` to obtain the data that can be directly fed into the `Meta Model`.
|
||||
|
||||
@@ -19,7 +19,7 @@ A `Meta Task` instance is the basic element in the meta-learning framework. It s
|
||||
:members:
|
||||
|
||||
Meta Dataset
|
||||
=============
|
||||
============
|
||||
|
||||
`Meta Dataset` controls the meta-information generating process. It is on the duty of providing data for training the `Meta Model`. Users should use `prepare_tasks` to retrieve a list of `Meta Task` instances.
|
||||
|
||||
@@ -27,26 +27,26 @@ Meta Dataset
|
||||
:members:
|
||||
|
||||
Meta Model
|
||||
=============
|
||||
==========
|
||||
|
||||
General Meta Model
|
||||
------------------
|
||||
`Meta Model` instance is the part that controls the workflow. The usage of the `Meta Model` includes:
|
||||
1. Users train their `Meta Model` with the `fit` function.
|
||||
1. Users train their `Meta Model` with the `fit` function.
|
||||
2. The `Meta Model` instance guides the workflow by giving useful information via the `inference` function.
|
||||
|
||||
.. autoclass:: qlib.model.meta.model.MetaModel
|
||||
:members:
|
||||
|
||||
Meta Task Model
|
||||
------------------
|
||||
---------------
|
||||
This type of meta-model may interact with task definitions directly. Then, the `Meta Task Model` is the class for them to inherit from. They guide the base tasks by modifying the base task definitions. The function `prepare_tasks` can be used to obtain the modified base task definitions.
|
||||
|
||||
.. autoclass:: qlib.model.meta.model.MetaTaskModel
|
||||
:members:
|
||||
|
||||
Meta Guide Model
|
||||
------------------
|
||||
----------------
|
||||
This type of meta-model participates in the training process of the base forecasting model. The meta-model may guide the base forecasting models during their training to improve their performances.
|
||||
|
||||
.. autoclass:: qlib.model.meta.model.MetaGuideModel
|
||||
@@ -54,9 +54,9 @@ This type of meta-model participates in the training process of the base forecas
|
||||
|
||||
|
||||
Example
|
||||
=============
|
||||
``Qlib`` provides an implementation of ``Meta Model`` module, ``DDG-DA``,
|
||||
which adapts to the market dynamics.
|
||||
=======
|
||||
``Qlib`` provides an implementation of ``Meta Model`` module, ``DDG-DA``,
|
||||
which adapts to the market dynamics.
|
||||
|
||||
``DDG-DA`` includes four steps:
|
||||
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
.. _model:
|
||||
|
||||
============================================
|
||||
===========================================
|
||||
Forecast Model: Model Training & Prediction
|
||||
============================================
|
||||
===========================================
|
||||
|
||||
Introduction
|
||||
===================
|
||||
============
|
||||
|
||||
``Forecast Model`` is designed to make the `prediction score` about stocks. Users can use the ``Forecast Model`` in an automatic workflow by ``qrun``, please refer to `Workflow: Workflow Management <workflow.html>`_.
|
||||
``Forecast Model`` is designed to make the `prediction score` about stocks. Users can use the ``Forecast Model`` in an automatic workflow by ``qrun``, please refer to `Workflow: Workflow Management <workflow.html>`_.
|
||||
|
||||
Because the components in ``Qlib`` are designed in a loosely-coupled way, ``Forecast Model`` can be used as an independent module also.
|
||||
|
||||
@@ -22,11 +22,11 @@ The base class provides the following interfaces:
|
||||
:members:
|
||||
|
||||
``Qlib`` also provides a base class `qlib.model.base.ModelFT <../reference/api.html#qlib.model.base.ModelFT>`_, which includes the method for finetuning the model.
|
||||
|
||||
|
||||
For other interfaces such as `finetune`, please refer to `Model API <../reference/api.html#module-qlib.model.base>`_.
|
||||
|
||||
Example
|
||||
==================
|
||||
=======
|
||||
|
||||
``Qlib``'s `Model Zoo` includes models such as ``LightGBM``, ``MLP``, ``LSTM``, etc.. These models are treated as the baselines of ``Forecast Model``. The following steps show how to run`` LightGBM`` as an independent module.
|
||||
|
||||
@@ -84,7 +84,7 @@ Example
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# model initiaiton
|
||||
model = init_instance_by_config(task["model"])
|
||||
dataset = init_instance_by_config(task["dataset"])
|
||||
@@ -100,22 +100,22 @@ Example
|
||||
sr = SignalRecord(model, dataset, recorder)
|
||||
sr.generate()
|
||||
|
||||
.. note::
|
||||
|
||||
.. note::
|
||||
|
||||
`Alpha158` is the data handler provided by ``Qlib``, please refer to `Data Handler <data.html#data-handler>`_.
|
||||
`SignalRecord` is the `Record Template` in ``Qlib``, please refer to `Workflow <recorder.html#record-template>`_.
|
||||
|
||||
Also, the above example has been given in ``examples/train_backtest_analyze.ipynb``.
|
||||
Technically, the meaning of the model prediction depends on the label setting designed by user.
|
||||
By default, the meaning of the score is normally the rating of the instruments by the forecasting model. The higher the score, the more profit the instruments.
|
||||
By default, the meaning of the score is normally the rating of the instruments by the forecasting model. The higher the score, the more profit the instruments.
|
||||
|
||||
|
||||
Custom Model
|
||||
===================
|
||||
============
|
||||
|
||||
Qlib supports custom models. If users are interested in customizing their own models and integrating the models into ``Qlib``, please refer to `Custom Model Integration <../start/integration.html>`_.
|
||||
|
||||
|
||||
API
|
||||
===================
|
||||
===
|
||||
Please refer to `Model API <../reference/api.html#module-qlib.model.base>`_.
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
.. _online:
|
||||
|
||||
=================================
|
||||
==============
|
||||
Online Serving
|
||||
=================================
|
||||
==============
|
||||
.. currentmodule:: qlib
|
||||
|
||||
|
||||
Introduction
|
||||
=============
|
||||
============
|
||||
|
||||
.. image:: ../_static/img/online_serving.png
|
||||
:align: center
|
||||
@@ -15,7 +15,7 @@ Introduction
|
||||
|
||||
In addition to backtesting, one way to test a model is effective is to make predictions in real market conditions or even do real trading based on those predictions.
|
||||
``Online Serving`` is a set of modules for online models using the latest data,
|
||||
which including `Online Manager <#Online Manager>`_, `Online Strategy <#Online Strategy>`_, `Online Tool <#Online Tool>`_, `Updater <#Updater>`_.
|
||||
which including `Online Manager <#Online Manager>`_, `Online Strategy <#Online Strategy>`_, `Online Tool <#Online Tool>`_, `Updater <#Updater>`_.
|
||||
|
||||
`Here <https://github.com/microsoft/qlib/tree/main/examples/online_srv>`_ are several examples for reference, which demonstrate different features of ``Online Serving``.
|
||||
If you have many models or `task` needs to be managed, please consider `Task Management <../advanced/task_management.html>`_.
|
||||
@@ -28,25 +28,25 @@ Known limitations currently
|
||||
|
||||
|
||||
Online Manager
|
||||
=============
|
||||
==============
|
||||
|
||||
.. automodule:: qlib.workflow.online.manager
|
||||
:members:
|
||||
|
||||
Online Strategy
|
||||
=============
|
||||
===============
|
||||
|
||||
.. automodule:: qlib.workflow.online.strategy
|
||||
:members:
|
||||
|
||||
Online Tool
|
||||
=============
|
||||
===========
|
||||
|
||||
.. automodule:: qlib.workflow.online.utils
|
||||
:members:
|
||||
|
||||
Updater
|
||||
=============
|
||||
=======
|
||||
|
||||
.. automodule:: qlib.workflow.online.update
|
||||
:members:
|
||||
|
||||
@@ -6,8 +6,8 @@ Qlib Recorder: Experiment Management
|
||||
.. currentmodule:: qlib
|
||||
|
||||
Introduction
|
||||
===================
|
||||
``Qlib`` contains an experiment management system named ``QlibRecorder``, which is designed to help users handle experiment and analyse results in an efficient way.
|
||||
============
|
||||
``Qlib`` contains an experiment management system named ``QlibRecorder``, which is designed to help users handle experiment and analyse results in an efficient way.
|
||||
|
||||
There are three components of the system:
|
||||
|
||||
@@ -34,13 +34,13 @@ Here is a general view of the structure of the system:
|
||||
- Recorder 2
|
||||
- ...
|
||||
- ...
|
||||
|
||||
This experiment management system defines a set of interface and provided a concrete implementation ``MLflowExpManager``, which is based on the machine learning platform: ``MLFlow`` (`link <https://mlflow.org/>`_).
|
||||
|
||||
This experiment management system defines a set of interface and provided a concrete implementation ``MLflowExpManager``, which is based on the machine learning platform: ``MLFlow`` (`link <https://mlflow.org/>`_).
|
||||
|
||||
If users set the implementation of ``ExpManager`` to be ``MLflowExpManager``, they can use the command `mlflow ui` to visualize and check the experiment results. For more information, please refer to the related documents `here <https://www.mlflow.org/docs/latest/cli.html#mlflow-ui>`_.
|
||||
|
||||
Qlib Recorder
|
||||
===================
|
||||
=============
|
||||
``QlibRecorder`` provides a high level API for users to use the experiment management system. The interfaces are wrapped in the variable ``R`` in ``Qlib``, and users can directly use ``R`` to interact with the system. The following command shows how to import ``R`` in Python:
|
||||
|
||||
.. code-block:: Python
|
||||
@@ -55,7 +55,7 @@ Here are the available interfaces of ``QlibRecorder``:
|
||||
:members:
|
||||
|
||||
Experiment Manager
|
||||
===================
|
||||
==================
|
||||
|
||||
The ``ExpManager`` module in ``Qlib`` is responsible for managing different experiments. Most of the APIs of ``ExpManager`` are similar to ``QlibRecorder``, and the most important API will be the ``get_exp`` method. User can directly refer to the documents above for some detailed information about how to use the ``get_exp`` method.
|
||||
|
||||
@@ -65,7 +65,7 @@ The ``ExpManager`` module in ``Qlib`` is responsible for managing different expe
|
||||
For other interfaces such as `create_exp`, `delete_exp`, please refer to `Experiment Manager API <../reference/api.html#experiment-manager>`_.
|
||||
|
||||
Experiment
|
||||
===================
|
||||
==========
|
||||
|
||||
The ``Experiment`` class is solely responsible for a single experiment, and it will handle any operations that are related to an experiment. Basic methods such as `start`, `end` an experiment are included. Besides, methods related to `recorders` are also available: such methods include `get_recorder` and `list_recorders`.
|
||||
|
||||
@@ -77,7 +77,7 @@ For other interfaces such as `search_records`, `delete_recorder`, please refer t
|
||||
``Qlib`` also provides a default ``Experiment``, which will be created and used under certain situations when users use the APIs such as `log_metrics` or `get_exp`. If the default ``Experiment`` is used, there will be related logged information when running ``Qlib``. Users are able to change the name of the default ``Experiment`` in the config file of ``Qlib`` or during ``Qlib``'s `initialization <../start/initialization.html#parameters>`_, which is set to be '`Experiment`'.
|
||||
|
||||
Recorder
|
||||
===================
|
||||
========
|
||||
|
||||
The ``Recorder`` class is responsible for a single recorder. It will handle some detailed operations such as ``log_metrics``, ``log_params`` of a single run. It is designed to help user to easily track results and things being generated during a run.
|
||||
|
||||
@@ -89,7 +89,7 @@ Here are some important APIs that are not included in the ``QlibRecorder``:
|
||||
For other interfaces such as `save_objects`, `load_object`, please refer to `Recorder API <../reference/api.html#recorder>`_.
|
||||
|
||||
Record Template
|
||||
===================
|
||||
===============
|
||||
|
||||
The ``RecordTemp`` class is a class that enables generate experiment results such as IC and backtest in a certain format. We have provided three different `Record Template` class:
|
||||
|
||||
@@ -131,7 +131,7 @@ Here is a simple exampke of what is done in ``PortAnaRecord``, which users can r
|
||||
"close_cost": 0.0015,
|
||||
"min_cost": 5,
|
||||
}
|
||||
|
||||
|
||||
strategy = TopkDropoutStrategy(**STRATEGY_CONFIG)
|
||||
report_normal, positions_normal = normal_backtest(pred_score, strategy=strategy, **BACKTEST_CONFIG)
|
||||
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
.. _report:
|
||||
|
||||
==========================================
|
||||
=======================================
|
||||
Analysis: Evaluation & Results Analysis
|
||||
==========================================
|
||||
=======================================
|
||||
|
||||
Introduction
|
||||
===================
|
||||
============
|
||||
|
||||
``Analysis`` is designed to show the graphical reports of ``Intraday Trading`` , which helps users to evaluate and analyse investment portfolios visually. The following are some graphics to view:
|
||||
|
||||
@@ -24,7 +24,7 @@ All of the accumulated profit metrics(e.g. return, max drawdown) in Qlib are cal
|
||||
This avoids the metrics or the plots being skewed exponentially over time.
|
||||
|
||||
Graphical Reports
|
||||
===================
|
||||
=================
|
||||
|
||||
Users can run the following code to get all supported reports.
|
||||
|
||||
@@ -41,13 +41,13 @@ Users can run the following code to get all supported reports.
|
||||
|
||||
|
||||
Usage & Example
|
||||
===================
|
||||
===============
|
||||
|
||||
Usage of `analysis_position.report`
|
||||
-----------------------------------
|
||||
|
||||
API
|
||||
~~~~~~~~~~~~~~~~
|
||||
~~~
|
||||
|
||||
.. automodule:: qlib.contrib.report.analysis_position.report
|
||||
:members:
|
||||
@@ -58,7 +58,7 @@ Graphical Result
|
||||
.. note::
|
||||
|
||||
- Axis X: Trading day
|
||||
- Axis Y:
|
||||
- Axis Y:
|
||||
- `cum bench`
|
||||
Cumulative returns series of benchmark
|
||||
- `cum return wo cost`
|
||||
@@ -82,34 +82,34 @@ Graphical Result
|
||||
- The shaded part above: Maximum drawdown corresponding to `cum return wo cost`
|
||||
- The shaded part below: Maximum drawdown corresponding to `cum ex return wo cost`
|
||||
|
||||
.. image:: ../_static/img/analysis/report.png
|
||||
.. image:: ../_static/img/analysis/report.png
|
||||
|
||||
|
||||
Usage of `analysis_position.score_ic`
|
||||
-------------------------------------
|
||||
|
||||
API
|
||||
~~~~~~~~~~~~~~~~
|
||||
~~~
|
||||
|
||||
.. automodule:: qlib.contrib.report.analysis_position.score_ic
|
||||
:members:
|
||||
|
||||
|
||||
Graphical Result
|
||||
~~~~~~~~~~~~~~~~~
|
||||
~~~~~~~~~~~~~~~~
|
||||
|
||||
.. note::
|
||||
.. note::
|
||||
|
||||
- Axis X: Trading day
|
||||
- Axis Y:
|
||||
- Axis Y:
|
||||
- `ic`
|
||||
The `Pearson correlation coefficient` series between `label` and `prediction score`.
|
||||
In the above example, the `label` is formulated as `Ref($close, -2)/Ref($close, -1)-1`. Please refer to `Data Feature <data.html#feature>`_ for more details.
|
||||
|
||||
|
||||
- `rank_ic`
|
||||
The `Spearman's rank correlation coefficient` series between `label` and `prediction score`.
|
||||
|
||||
.. image:: ../_static/img/analysis/score_ic.png
|
||||
.. image:: ../_static/img/analysis/score_ic.png
|
||||
|
||||
|
||||
.. Usage of `analysis_position.cumulative_return`
|
||||
@@ -124,7 +124,7 @@ Graphical Result
|
||||
.. Graphical Result
|
||||
.. ~~~~~~~~~~~~~~~~~
|
||||
..
|
||||
.. .. note::
|
||||
.. .. note::
|
||||
..
|
||||
.. - Axis X: Trading day
|
||||
.. - Axis Y:
|
||||
@@ -134,27 +134,27 @@ Graphical Result
|
||||
.. - In the **buy_minus_sell** graph, the **y** value of the **weight** graph at the bottom is `buy_weight + sell_weight`.
|
||||
.. - In each graph, the **red line** in the histogram on the right represents the average.
|
||||
..
|
||||
.. .. image:: ../_static/img/analysis/cumulative_return_buy.png
|
||||
.. .. image:: ../_static/img/analysis/cumulative_return_buy.png
|
||||
..
|
||||
.. .. image:: ../_static/img/analysis/cumulative_return_sell.png
|
||||
.. .. image:: ../_static/img/analysis/cumulative_return_sell.png
|
||||
..
|
||||
.. .. image:: ../_static/img/analysis/cumulative_return_buy_minus_sell.png
|
||||
.. .. image:: ../_static/img/analysis/cumulative_return_buy_minus_sell.png
|
||||
..
|
||||
.. .. image:: ../_static/img/analysis/cumulative_return_hold.png
|
||||
.. .. image:: ../_static/img/analysis/cumulative_return_hold.png
|
||||
|
||||
|
||||
Usage of `analysis_position.risk_analysis`
|
||||
----------------------------------------------
|
||||
------------------------------------------
|
||||
|
||||
API
|
||||
~~~~~~~~~~~~~~~~
|
||||
~~~
|
||||
|
||||
.. automodule:: qlib.contrib.report.analysis_position.risk_analysis
|
||||
:members:
|
||||
|
||||
|
||||
Graphical Result
|
||||
~~~~~~~~~~~~~~~~~
|
||||
~~~~~~~~~~~~~~~~
|
||||
|
||||
.. note::
|
||||
|
||||
@@ -210,7 +210,7 @@ Graphical Result
|
||||
The `Standard Deviation` series of monthly `CAR` (cumulative abnormal return) without cost.
|
||||
- `excess_return_with_cost_max_drawdown`
|
||||
The `Standard Deviation` series of monthly `CAR` (cumulative abnormal return) with cost.
|
||||
|
||||
|
||||
|
||||
.. image:: ../_static/img/analysis/risk_analysis_annualized_return.png
|
||||
:align: center
|
||||
@@ -221,58 +221,58 @@ Graphical Result
|
||||
.. image:: ../_static/img/analysis/risk_analysis_information_ratio.png
|
||||
:align: center
|
||||
|
||||
.. image:: ../_static/img/analysis/risk_analysis_std.png
|
||||
.. image:: ../_static/img/analysis/risk_analysis_std.png
|
||||
:align: center
|
||||
|
||||
..
|
||||
.. Usage of `analysis_position.rank_label`
|
||||
.. ----------------------------------------------
|
||||
.. ---------------------------------------
|
||||
..
|
||||
.. API
|
||||
.. ~~~~~
|
||||
.. ~~~
|
||||
..
|
||||
.. .. automodule:: qlib.contrib.report.analysis_position.rank_label
|
||||
.. :members:
|
||||
..
|
||||
..
|
||||
.. Graphical Result
|
||||
.. ~~~~~~~~~~~~~~~~~
|
||||
.. ~~~~~~~~~~~~~~~~
|
||||
..
|
||||
.. .. note::
|
||||
.. .. note::
|
||||
..
|
||||
.. - hold/sell/buy graphics:
|
||||
.. - Axis X: Trading day
|
||||
.. - Axis Y:
|
||||
.. - Axis Y:
|
||||
.. Average `ranking ratio`of `label` for stocks that is held/sold/bought on the trading day.
|
||||
..
|
||||
.. In the above example, the `label` is formulated as `Ref($close, -1)/$close - 1`. The `ranking ratio` can be formulated as follows.
|
||||
.. .. math::
|
||||
..
|
||||
..
|
||||
.. ranking\ ratio = \frac{Ascending\ Ranking\ of\ label}{Number\ of\ Stocks\ in\ the\ Portfolio}
|
||||
..
|
||||
.. .. image:: ../_static/img/analysis/rank_label_hold.png
|
||||
.. .. image:: ../_static/img/analysis/rank_label_hold.png
|
||||
.. :align: center
|
||||
..
|
||||
.. .. image:: ../_static/img/analysis/rank_label_buy.png
|
||||
.. .. image:: ../_static/img/analysis/rank_label_buy.png
|
||||
.. :align: center
|
||||
..
|
||||
.. .. image:: ../_static/img/analysis/rank_label_sell.png
|
||||
.. .. image:: ../_static/img/analysis/rank_label_sell.png
|
||||
.. :align: center
|
||||
..
|
||||
..
|
||||
|
||||
Usage of `analysis_model.analysis_model_performance`
|
||||
-----------------------------------------------------
|
||||
----------------------------------------------------
|
||||
|
||||
API
|
||||
~~~~~
|
||||
~~~
|
||||
|
||||
.. automodule:: qlib.contrib.report.analysis_model.analysis_model_performance
|
||||
:members:
|
||||
|
||||
|
||||
Graphical Results
|
||||
~~~~~~~~~~~~~~~~~~
|
||||
~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. note::
|
||||
|
||||
@@ -291,13 +291,13 @@ Graphical Results
|
||||
The Difference series between `Cumulative Return` of `Group1` and of `Group5`
|
||||
- `long-average`
|
||||
The Difference series between `Cumulative Return` of `Group1` and average `Cumulative Return` for all stocks.
|
||||
|
||||
|
||||
The `ranking ratio` can be formulated as follows.
|
||||
.. math::
|
||||
|
||||
|
||||
ranking\ ratio = \frac{Ascending\ Ranking\ of\ label}{Number\ of\ Stocks\ in\ the\ Portfolio}
|
||||
|
||||
.. image:: ../_static/img/analysis/analysis_model_cumulative_return.png
|
||||
.. image:: ../_static/img/analysis/analysis_model_cumulative_return.png
|
||||
:align: center
|
||||
|
||||
.. note::
|
||||
@@ -305,7 +305,7 @@ Graphical Results
|
||||
The distribution of long-short/long-average returns on each trading day
|
||||
|
||||
|
||||
.. image:: ../_static/img/analysis/analysis_model_long_short.png
|
||||
.. image:: ../_static/img/analysis/analysis_model_long_short.png
|
||||
:align: center
|
||||
|
||||
.. TODO: ask xiao yang for detial
|
||||
@@ -315,14 +315,14 @@ Graphical Results
|
||||
- The `Pearson correlation coefficient` series between `labels` and `prediction scores` of stocks in portfolio.
|
||||
- The graphics reports can be used to evaluate the `prediction scores`.
|
||||
|
||||
.. image:: ../_static/img/analysis/analysis_model_IC.png
|
||||
.. image:: ../_static/img/analysis/analysis_model_IC.png
|
||||
:align: center
|
||||
|
||||
.. note::
|
||||
- Monthly IC
|
||||
Monthly average of the `Information Coefficient`
|
||||
|
||||
.. image:: ../_static/img/analysis/analysis_model_monthly_IC.png
|
||||
.. image:: ../_static/img/analysis/analysis_model_monthly_IC.png
|
||||
:align: center
|
||||
|
||||
.. note::
|
||||
@@ -331,14 +331,14 @@ Graphical Results
|
||||
- IC Normal Dist. Q-Q
|
||||
The `Quantile-Quantile Plot` is used for the normal distribution of `Information Coefficient` on each trading day.
|
||||
|
||||
.. image:: ../_static/img/analysis/analysis_model_NDQ.png
|
||||
.. image:: ../_static/img/analysis/analysis_model_NDQ.png
|
||||
:align: center
|
||||
|
||||
.. note::
|
||||
- Auto Correlation
|
||||
- The `Pearson correlation coefficient` series between the latest `prediction scores` and the `prediction scores` `lag` days ago of stocks in portfolio on each trading day.
|
||||
- The `Pearson correlation coefficient` series between the latest `prediction scores` and the `prediction scores` `lag` days ago of stocks in portfolio on each trading day.
|
||||
- The graphics reports can be used to estimate the turnover rate.
|
||||
|
||||
|
||||
.. image:: ../_static/img/analysis/analysis_model_auto_correlation.png
|
||||
|
||||
.. image:: ../_static/img/analysis/analysis_model_auto_correlation.png
|
||||
:align: center
|
||||
|
||||
@@ -6,7 +6,7 @@ Portfolio Strategy: Portfolio Management
|
||||
.. currentmodule:: qlib
|
||||
|
||||
Introduction
|
||||
===================
|
||||
============
|
||||
|
||||
``Portfolio Strategy`` is designed to adopt different portfolio strategies, which means that users can adopt different algorithms to generate investment portfolios based on the prediction scores of the ``Forecast Model``. Users can use the ``Portfolio Strategy`` in an automatic workflow by ``Workflow`` module, please refer to `Workflow: Workflow Management <workflow.html>`_.
|
||||
|
||||
@@ -20,7 +20,7 @@ Base Class & Interface
|
||||
======================
|
||||
|
||||
BaseStrategy
|
||||
------------------
|
||||
------------
|
||||
|
||||
Qlib provides a base class ``qlib.strategy.base.BaseStrategy``. All strategy classes need to inherit the base class and implement its interface.
|
||||
|
||||
@@ -32,7 +32,7 @@ Qlib provides a base class ``qlib.strategy.base.BaseStrategy``. All strategy cla
|
||||
Users can inherit `BaseStrategy` to customize their strategy class.
|
||||
|
||||
WeightStrategyBase
|
||||
--------------------
|
||||
------------------
|
||||
|
||||
Qlib also provides a class ``qlib.contrib.strategy.WeightStrategyBase`` that is a subclass of `BaseStrategy`.
|
||||
|
||||
@@ -60,7 +60,7 @@ Implemented Strategy
|
||||
Qlib provides a implemented strategy classes named `TopkDropoutStrategy`.
|
||||
|
||||
TopkDropoutStrategy
|
||||
------------------
|
||||
-------------------
|
||||
`TopkDropoutStrategy` is a subclass of `BaseStrategy` and implement the interface `generate_order_list` whose process is as follows.
|
||||
|
||||
- Adopt the ``Topk-Drop`` algorithm to calculate the target amount of each stock
|
||||
@@ -74,16 +74,16 @@ TopkDropoutStrategy
|
||||
In general, the number of stocks currently held is `Topk`, with the exception of being zero at the beginning period of trading.
|
||||
For each trading day, let $d$ be the number of the instruments currently held and with a rank $\gt K$ when ranked by the prediction scores from high to low.
|
||||
Then `d` number of stocks currently held with the worst `prediction score` will be sold, and the same number of unheld stocks with the best `prediction score` will be bought.
|
||||
|
||||
|
||||
In general, $d=$`Drop`, especially when the pool of the candidate instruments is large, $K$ is large, and `Drop` is small.
|
||||
|
||||
|
||||
In most cases, ``TopkDrop`` algorithm sells and buys `Drop` stocks every trading day, which yields a turnover rate of 2$\times$`Drop`/$K$.
|
||||
|
||||
|
||||
The following images illustrate a typical scenario.
|
||||
.. image:: ../_static/img/topk_drop.png
|
||||
:alt: Topk-Drop
|
||||
|
||||
|
||||
|
||||
|
||||
- Generate the order list from the target amount
|
||||
|
||||
@@ -98,12 +98,12 @@ and `qlib.contrib.strategy.optimizer.enhanced_indexing.EnhancedIndexingOptimizer
|
||||
|
||||
|
||||
Usage & Example
|
||||
====================
|
||||
===============
|
||||
|
||||
First, user can create a model to get trading signals(the variable name is ``pred_score`` in following cases).
|
||||
|
||||
Prediction Score
|
||||
-----------------
|
||||
----------------
|
||||
|
||||
The `prediction score` is a pandas DataFrame. Its index is <datetime(pd.Timestamp), instrument(str)> and it must
|
||||
contains a `score` column.
|
||||
@@ -134,7 +134,7 @@ Qlib didn't add a step to scale the prediction score to a unified scale due to t
|
||||
- The model has the flexibility to define the target, loss, and data processing. So we don't think there is a silver bullet to rescale it back directly barely based on the model's outputs. If you want to scale it back to some meaningful values(e.g. stock returns.), an intuitive solution is to create a regression model for the model's recent outputs and your recent target values.
|
||||
|
||||
Running backtest
|
||||
-----------------
|
||||
----------------
|
||||
|
||||
- In most cases, users could backtest their portfolio management strategy with ``backtest_daily``.
|
||||
|
||||
@@ -195,7 +195,7 @@ Running backtest
|
||||
|
||||
CSI300_BENCH = "SH000300"
|
||||
# Benchmark is for calculating the excess return of your strategy.
|
||||
# Its data format will be like **ONE normal instrument**.
|
||||
# Its data format will be like **ONE normal instrument**.
|
||||
# For example, you can query its data with the code below
|
||||
# `D.features(["SH000300"], ["$close"], start_time='2010-01-01', end_time='2017-12-31', freq='day')`
|
||||
# It is different from the argument `market`, which indicates a universe of stocks (e.g. **A SET** of stocks like csi300)
|
||||
@@ -262,7 +262,7 @@ Running backtest
|
||||
|
||||
|
||||
Result
|
||||
------------------
|
||||
------
|
||||
|
||||
The backtest results are in the following form:
|
||||
|
||||
@@ -307,5 +307,5 @@ The backtest results are in the following form:
|
||||
|
||||
|
||||
Reference
|
||||
===================
|
||||
=========
|
||||
To know more about the `prediction score` `pred_score` output by ``Forecast Model``, please refer to `Forecast Model: Model Training & Prediction <model.html>`_.
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
.. _workflow:
|
||||
|
||||
=================================
|
||||
=============================
|
||||
Workflow: Workflow Management
|
||||
=================================
|
||||
=============================
|
||||
.. currentmodule:: qlib
|
||||
|
||||
Introduction
|
||||
===================
|
||||
============
|
||||
|
||||
The components in `Qlib Framework <../introduction/introduction.html#framework>`_ are designed in a loosely-coupled way. Users could build their own Quant research workflow with these components like `Example <https://github.com/microsoft/qlib/blob/main/examples/workflow_by_code.py>`_.
|
||||
|
||||
@@ -28,7 +28,7 @@ With ``qrun``, user can easily start an `execution`, which includes the followin
|
||||
For each `execution`, ``Qlib`` has a complete system to tracking all the information as well as artifacts generated during training, inference and evaluation phase. For more information about how ``Qlib`` handles this, please refer to the related document: `Recorder: Experiment Management <../component/recorder.html>`_.
|
||||
|
||||
Complete Example
|
||||
===================
|
||||
================
|
||||
|
||||
Before getting into details, here is a complete example of ``qrun``, which defines the workflow in typical Quant research.
|
||||
Below is a typical config file of ``qrun``.
|
||||
@@ -54,7 +54,7 @@ Below is a typical config file of ``qrun``.
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
backtest:
|
||||
limit_threshold: 0.095
|
||||
@@ -90,13 +90,13 @@ Below is a typical config file of ``qrun``.
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
|
||||
After saving the config into `configuration.yaml`, users could start the workflow and test their ideas with a single command below.
|
||||
@@ -111,22 +111,22 @@ If users want to use ``qrun`` under debug mode, please use the following command
|
||||
|
||||
python -m pdb qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
|
||||
.. note::
|
||||
.. note::
|
||||
|
||||
`qrun` will be placed in your $PATH directory when installing ``Qlib``.
|
||||
|
||||
.. note::
|
||||
|
||||
.. note::
|
||||
|
||||
The symbol `&` in `yaml` file stands for an anchor of a field, which is useful when another fields include this parameter as part of the value. Taking the configuration file above as an example, users can directly change the value of `market` and `benchmark` without traversing the entire configuration file.
|
||||
|
||||
|
||||
Configuration File
|
||||
===================
|
||||
==================
|
||||
|
||||
Let's get into details of ``qrun`` in this section.
|
||||
Before using ``qrun``, users need to prepare a configuration file. The following content shows how to prepare each part of the configuration file.
|
||||
|
||||
The design logic of the configuration file is very simple. It predefines fixed workflows and provide this yaml interface to users to define how to initialize each component.
|
||||
The design logic of the configuration file is very simple. It predefines fixed workflows and provide this yaml interface to users to define how to initialize each component.
|
||||
It follow the design of `init_instance_by_config <https://github.com/microsoft/qlib/blob/2aee9e0145decc3e71def70909639b5e5a6f4b58/qlib/utils/__init__.py#L264>`_ . It defines the initialization of each component of Qlib, which typically include the class and the initialization arguments.
|
||||
|
||||
For example, the following yaml and code are equivalent.
|
||||
@@ -166,7 +166,7 @@ For example, the following yaml and code are equivalent.
|
||||
|
||||
|
||||
Qlib Init Section
|
||||
--------------------
|
||||
-----------------
|
||||
|
||||
At first, the configuration file needs to contain several basic parameters which will be used for qlib initialization.
|
||||
|
||||
@@ -181,21 +181,21 @@ The meaning of each field is as follows:
|
||||
Type: str. The URI of the Qlib data. For example, it could be the location where the data loaded by ``get_data.py`` are stored.
|
||||
|
||||
- `region`
|
||||
- If `region` == "us", ``Qlib`` will be initialized in US-stock mode.
|
||||
- If `region` == "us", ``Qlib`` will be initialized in US-stock mode.
|
||||
- If `region` == "cn", ``Qlib`` will be initialized in China-stock mode.
|
||||
|
||||
.. note::
|
||||
|
||||
.. note::
|
||||
|
||||
The value of `region` should be aligned with the data stored in `provider_uri`.
|
||||
|
||||
|
||||
Task Section
|
||||
--------------------
|
||||
------------
|
||||
|
||||
The `task` field in the configuration corresponds to a `task`, which contains the parameters of three different subsections: `Model`, `Dataset` and `Record`.
|
||||
|
||||
Model Section
|
||||
~~~~~~~~~~~~~~~~~~~~
|
||||
~~~~~~~~~~~~~
|
||||
|
||||
In the `task` field, the `model` section describes the parameters of the model to be used for training and inference. For more information about the base ``Model`` class, please refer to `Qlib Model <../component/model.html>`_.
|
||||
|
||||
@@ -224,14 +224,14 @@ The meaning of each field is as follows:
|
||||
Type: str. The path for the model in qlib.
|
||||
|
||||
- `kwargs`
|
||||
The keywords arguments for the model. Please refer to the specific model implementation for more information: `models <https://github.com/microsoft/qlib/blob/main/qlib/contrib/model>`_.
|
||||
The keywords arguments for the model. Please refer to the specific model implementation for more information: `models <https://github.com/microsoft/qlib/blob/main/qlib/contrib/model>`_.
|
||||
|
||||
.. note::
|
||||
|
||||
.. note::
|
||||
|
||||
``Qlib`` provides a util named: ``init_instance_by_config`` to initialize any class inside ``Qlib`` with the configuration includes the fields: `class`, `module_path` and `kwargs`.
|
||||
|
||||
Dataset Section
|
||||
~~~~~~~~~~~~~~~~~~~~
|
||||
~~~~~~~~~~~~~~~
|
||||
|
||||
The `dataset` field describes the parameters for the ``Dataset`` module in ``Qlib`` as well those for the module ``DataHandler``. For more information about the ``Dataset`` module, please refer to `Qlib Data <../component/data.html#dataset>`_.
|
||||
|
||||
@@ -266,7 +266,7 @@ Here is the configuration for the ``Dataset`` module which will take care of dat
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
|
||||
Record Section
|
||||
~~~~~~~~~~~~~~~~~~~~
|
||||
~~~~~~~~~~~~~~
|
||||
|
||||
The `record` field is about the parameters the ``Record`` module in ``Qlib``. ``Record`` is responsible for tracking training process and results such as `information Coefficient (IC)` and `backtest` in a standard format.
|
||||
|
||||
@@ -282,7 +282,7 @@ The following script is the configuration of `backtest` and the `strategy` used
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
backtest:
|
||||
limit_threshold: 0.095
|
||||
@@ -299,13 +299,13 @@ Here is the configuration details of different `Record Template` such as ``Signa
|
||||
|
||||
.. code-block:: YAML
|
||||
|
||||
record:
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
|
||||
For more information about the ``Record`` module in ``Qlib``, user can refer to the related document: `Record <../component/recorder.html#record-template>`_.
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
.. _code_standard:
|
||||
|
||||
=================================
|
||||
=============
|
||||
Code Standard
|
||||
=================================
|
||||
=============
|
||||
|
||||
Docstring
|
||||
=================================
|
||||
=========
|
||||
Please use the `Numpydoc Style <https://stackoverflow.com/a/24385103>`_.
|
||||
|
||||
Continuous Integration
|
||||
=================================
|
||||
Continuous Integration (CI) tools help you stick to the quality standards by running tests every time you push a new commit and reporting the results to a pull request.
|
||||
======================
|
||||
Continuous Integration (CI) tools help you stick to the quality standards by running tests every time you push a new commit and reporting the results to a pull request.
|
||||
|
||||
When you submit a PR request, you can check whether your code passes the CI tests in the "check" section at the bottom of the web page.
|
||||
|
||||
@@ -23,7 +23,7 @@ When you submit a PR request, you can check whether your code passes the CI test
|
||||
python -m black . -l 120
|
||||
|
||||
|
||||
2. Qlib will check your code style pylint. The checking command is implemented in [github action workflow](https://github.com/microsoft/qlib/blob/0e8b94a552f1c457cfa6cd2c1bb3b87ebb3fb279/.github/workflows/test.yml#L66).
|
||||
2. Qlib will check your code style pylint. The checking command is implemented in [github action workflow](https://github.com/microsoft/qlib/blob/0e8b94a552f1c457cfa6cd2c1bb3b87ebb3fb279/.github/workflows/test.yml#L66).
|
||||
Sometime pylint's restrictions are not that reasonable. You can ignore specific errors like this
|
||||
|
||||
.. code-block:: python
|
||||
@@ -45,4 +45,16 @@ When you submit a PR request, you can check whether your code passes the CI test
|
||||
.. code-block:: bash
|
||||
|
||||
pip install -e .[dev]
|
||||
pre-commit install
|
||||
pre-commit install
|
||||
|
||||
|
||||
=================================
|
||||
Development Guidance
|
||||
=================================
|
||||
|
||||
As a developer, you often want make changes to `Qlib` and hope it would reflect directly in your environment without reinstalling it. You can install `Qlib` in editable mode with following command.
|
||||
The `[dev]` option will help you to install some related packages when developing `Qlib` (e.g. pytest, sphinx)
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install -e .[dev]
|
||||
@@ -1,12 +1,12 @@
|
||||
.. _client:
|
||||
|
||||
Qlib Client-Server Framework
|
||||
===================
|
||||
============================
|
||||
|
||||
.. currentmodule:: qlib
|
||||
|
||||
Introduction
|
||||
-----------
|
||||
------------
|
||||
Client-Server is designed to solve following problems
|
||||
|
||||
- Manage the data in a centralized way. Users don't have to manage data of different versions.
|
||||
@@ -159,13 +159,11 @@ Limitations
|
||||
2. The rolling operation expression with parameter `0` can not be updated rightly under mechanism of the client-server framework.
|
||||
|
||||
API
|
||||
********************
|
||||
***
|
||||
|
||||
The client is based on `python-socketio<https://python-socketio.readthedocs.io>`_ which is a framework that supports WebSocket client for Python language. The client can only propose requests and receive results, which do not include any calculating procedure.
|
||||
|
||||
Class
|
||||
--------------------
|
||||
-----
|
||||
|
||||
.. automodule:: qlib.data.client
|
||||
|
||||
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
.. _online:
|
||||
|
||||
Online
|
||||
===================
|
||||
======
|
||||
.. currentmodule:: qlib
|
||||
|
||||
Introduction
|
||||
-------------------
|
||||
------------
|
||||
|
||||
Welcome to use Online, this module simulates what will be like if we do the real trading use our model and strategy.
|
||||
|
||||
@@ -31,11 +31,11 @@ The file structure can be viewed at fileStruct_.
|
||||
|
||||
|
||||
Example
|
||||
-------------------
|
||||
-------
|
||||
|
||||
Let's take an example,
|
||||
|
||||
.. note:: Make sure you have the latest version of `qlib` installed.
|
||||
.. note:: Make sure you have the latest version of `qlib` installed.
|
||||
|
||||
If you want to use the models and data provided by `qlib`, you only need to do as follows.
|
||||
|
||||
@@ -93,7 +93,7 @@ If Your account was saved in "./user_data/", you can see the performance of your
|
||||
Here 'SH000905' represents csi500 and 'SH000300' represents csi300
|
||||
|
||||
Manage your account
|
||||
--------------------
|
||||
-------------------
|
||||
|
||||
Any account processed by `online` should be saved in a folder. you can use commands
|
||||
defined to manage your accounts.
|
||||
@@ -161,7 +161,7 @@ be called at each trading date.
|
||||
>> online update -date 2019-10-16 -path ./user_data/
|
||||
|
||||
API
|
||||
------------------
|
||||
---
|
||||
|
||||
All those operations are based on defined in `qlib.contrib.online.operator`
|
||||
|
||||
@@ -170,7 +170,7 @@ All those operations are based on defined in `qlib.contrib.online.operator`
|
||||
.. _fileStruct:
|
||||
|
||||
File structure
|
||||
------------------
|
||||
--------------
|
||||
|
||||
'user_data' indicates the root of folder.
|
||||
Name that bold indicates it’s a folder, otherwise it’s a document.
|
||||
@@ -214,7 +214,7 @@ Configuration file
|
||||
The configure file used in `online` should contain the model and strategy information.
|
||||
|
||||
About the model
|
||||
~~~~~~~~~~~~~~~~~~~~
|
||||
~~~~~~~~~~~~~~~
|
||||
|
||||
First, your configuration file needs to have a field about the model,
|
||||
this field and its contents determine the model we used when generating score at predict date.
|
||||
@@ -243,7 +243,7 @@ contains 2 methods used in `online` module.
|
||||
|
||||
|
||||
About the strategy
|
||||
~~~~~~~~~~~~~~~~~~~~
|
||||
~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Your need define the strategy used to generate the order list at predict date.
|
||||
|
||||
@@ -259,7 +259,7 @@ Followings are two examples for a TopkAmountStrategy
|
||||
n_drop: 10
|
||||
|
||||
Generated files
|
||||
------------------
|
||||
---------------
|
||||
|
||||
The 'online_generate' command will create the order list at {folder_path}/{user_id}/temp/,
|
||||
the name of that is orderlist_{YYYY-MM-DD}.json, YYYY-MM-DD is the date that those orders to be executed.
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
.. _tuner:
|
||||
|
||||
Tuner
|
||||
===================
|
||||
=====
|
||||
.. currentmodule:: qlib
|
||||
|
||||
Introduction
|
||||
-------------------
|
||||
------------
|
||||
|
||||
Welcome to use Tuner, this document is based on that you can use Estimator proficiently and correctly.
|
||||
|
||||
@@ -41,19 +41,19 @@ We write a simple configuration example as following,
|
||||
tuner_class: QLibTuner
|
||||
qlib_client:
|
||||
auto_mount: False
|
||||
logging_level: INFO
|
||||
logging_level: INFO
|
||||
optimization_criteria:
|
||||
report_type: model
|
||||
report_factor: model_score
|
||||
optim_type: max
|
||||
tuner_pipeline:
|
||||
-
|
||||
model:
|
||||
-
|
||||
model:
|
||||
class: SomeModel
|
||||
space: SomeModelSpace
|
||||
trainer:
|
||||
trainer:
|
||||
class: RollingTrainer
|
||||
strategy:
|
||||
strategy:
|
||||
class: TopkAmountStrategy
|
||||
space: TopkAmountStrategySpace
|
||||
max_evals: 2
|
||||
@@ -166,13 +166,13 @@ Also, there are some optional fields. The meaning of each field is as follows:
|
||||
The class of tuner, str type, must be an already implemented model, such as `QLibTuner` in `qlib`, or a custom tuner, but it must be a subclass of `qlib.contrib.tuner.Tuner`, the default value is `QLibTuner`.
|
||||
|
||||
- `tuner_module_path`
|
||||
The module path, str type, absolute url is also supported, indicates the path of the implementation of tuner. The default value is `qlib.contrib.tuner.tuner`
|
||||
The module path, str type, absolute url is also supported, indicates the path of the implementation of tuner. The default value is `qlib.contrib.tuner.tuner`
|
||||
|
||||
About the optimization criteria
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
You need to designate a factor to optimize, for tuner need a factor to decide which case is better than other cases.
|
||||
Usually, we use the result of `estimator`, such as backtest results and the score of model.
|
||||
Usually, we use the result of `estimator`, such as backtest results and the score of model.
|
||||
|
||||
This part needs contain these fields:
|
||||
|
||||
@@ -203,13 +203,13 @@ The tuner pipeline contains different tuners, and the `tuner` program will proce
|
||||
.. code-block:: YAML
|
||||
|
||||
tuner_pipeline:
|
||||
-
|
||||
model:
|
||||
-
|
||||
model:
|
||||
class: SomeModel
|
||||
space: SomeModelSpace
|
||||
trainer:
|
||||
trainer:
|
||||
class: RollingTrainer
|
||||
strategy:
|
||||
strategy:
|
||||
class: TopkAmountStrategy
|
||||
space: TopkAmountStrategySpace
|
||||
max_evals: 2
|
||||
@@ -249,25 +249,25 @@ You need to use the same dataset to evaluate your different `estimator` experime
|
||||
test_start_date: 2016-07-01
|
||||
test_end_date: 2018-04-30
|
||||
|
||||
- `rolling_period`
|
||||
- `rolling_period`
|
||||
The rolling period, integer type, indicates how many time steps need rolling when rolling the data. The default value is `60`. If you use `RollingTrainer`, this config will be used, or it will be ignored.
|
||||
|
||||
- `train_start_date`
|
||||
Training start time, str type.
|
||||
|
||||
- `train_end_date`
|
||||
- `train_end_date`
|
||||
Training end time, str type.
|
||||
|
||||
- `validate_start_date`
|
||||
- `validate_start_date`
|
||||
Validation start time, str type.
|
||||
|
||||
- `validate_end_date`
|
||||
- `validate_end_date`
|
||||
Validation end time, str type.
|
||||
|
||||
- `test_start_date`
|
||||
- `test_start_date`
|
||||
Test start time, str type.
|
||||
|
||||
- `test_end_date`
|
||||
- `test_end_date`
|
||||
Test end time, str type. If `test_end_date` is `-1` or greater than the last date of the data, the last date of the data will be used as `test_end_date`.
|
||||
|
||||
About the data and backtest
|
||||
@@ -315,11 +315,10 @@ About the data and backtest
|
||||
Experiment Result
|
||||
-----------------
|
||||
|
||||
All the results are stored in experiment file directly, you can check them directly in the corresponding files.
|
||||
All the results are stored in experiment file directly, you can check them directly in the corresponding files.
|
||||
What we save are as following:
|
||||
|
||||
- Global optimal parameters
|
||||
- Local optimal parameters of each tuner
|
||||
- Config file of this `tuner` experiment
|
||||
- Every `estimator` experiments result in the process
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
============================================================
|
||||
======================
|
||||
``Qlib`` Documentation
|
||||
============================================================
|
||||
======================
|
||||
|
||||
``Qlib`` is an AI-oriented quantitative investment platform, which aims to realize the potential, empower the research, and create the value of AI technologies in quantitative investment.
|
||||
|
||||
@@ -24,12 +24,12 @@ Document Structure
|
||||
.. toctree::
|
||||
:maxdepth: 3
|
||||
:caption: FIRST STEPS:
|
||||
|
||||
|
||||
Installation <start/installation.rst>
|
||||
Initialization <start/initialization.rst>
|
||||
Data Retrieval <start/getdata.rst>
|
||||
Custom Model Integration <start/integration.rst>
|
||||
|
||||
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 3
|
||||
@@ -48,7 +48,7 @@ Document Structure
|
||||
.. toctree::
|
||||
:maxdepth: 3
|
||||
:caption: ADVANCED TOPICS:
|
||||
|
||||
|
||||
Building Formulaic Alphas <advanced/alpha.rst>
|
||||
Online & Offline mode <advanced/server.rst>
|
||||
Serialization <advanced/serial.rst>
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
===============================
|
||||
|
||||
Introduction
|
||||
===================
|
||||
============
|
||||
|
||||
.. image:: ../_static/img/logo/white_bg_rec+word.png
|
||||
:align: center
|
||||
@@ -13,8 +13,8 @@ Introduction
|
||||
With ``Qlib``, users can easily try their ideas to create better Quant investment strategies.
|
||||
|
||||
Framework
|
||||
===================
|
||||
|
||||
=========
|
||||
|
||||
.. image:: ../_static/img/framework.svg
|
||||
:align: center
|
||||
|
||||
@@ -27,7 +27,7 @@ At the module level, Qlib is a platform that consists of above components. The c
|
||||
Name Description
|
||||
======================== ==============================================================================
|
||||
`Infrastructure` layer `Infrastructure` layer provides underlying support for Quant research.
|
||||
`DataServer` provides high-performance infrastructure for users to manage
|
||||
`DataServer` provides high-performance infrastructure for users to manage
|
||||
and retrieve raw data. `Trainer` provides flexible interface to control
|
||||
the training process of models which enable algorithms controlling the
|
||||
training process.
|
||||
@@ -35,13 +35,13 @@ Name Description
|
||||
`Workflow` layer `Workflow` layer covers the whole workflow of quantitative investment.
|
||||
`Information Extractor` extracts data for models. `Forecast Model` focuses
|
||||
on producing all kinds of forecast signals (e.g. *alpha*, risk) for other
|
||||
modules. With these signals `Decision Generator` will generate the target
|
||||
modules. With these signals `Decision Generator` will generate the target
|
||||
trading decisions(i.e. portfolio, orders) to be executed by `Execution Env`
|
||||
(i.e. the trading market). There may be multiple levels of `Trading Agent`
|
||||
and `Execution Env` (e.g. an *order executor trading agent and intraday
|
||||
order execution environment* could behave like an interday trading
|
||||
environment and nested in *daily portfolio management trading agent and
|
||||
interday trading environment* )
|
||||
interday trading environment* )
|
||||
|
||||
`Interface` layer `Interface` layer tries to present a user-friendly interface for the underlying
|
||||
system. `Analyser` module will provide users detailed analysis reports of
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
|
||||
===============================
|
||||
===========
|
||||
Quick Start
|
||||
===============================
|
||||
===========
|
||||
|
||||
Introduction
|
||||
==============
|
||||
============
|
||||
|
||||
This ``Quick Start`` guide tries to demonstrate
|
||||
|
||||
@@ -14,7 +14,7 @@ This ``Quick Start`` guide tries to demonstrate
|
||||
|
||||
|
||||
Installation
|
||||
==================
|
||||
============
|
||||
|
||||
Users can easily intsall ``Qlib`` according to the following steps:
|
||||
|
||||
@@ -34,7 +34,7 @@ Users can easily intsall ``Qlib`` according to the following steps:
|
||||
To known more about `installation`, please refer to `Qlib Installation <../start/installation.html>`_.
|
||||
|
||||
Prepare Data
|
||||
==============
|
||||
============
|
||||
|
||||
Load and prepare data by running the following code:
|
||||
|
||||
@@ -47,14 +47,14 @@ This dataset is created by public data collected by crawler scripts in ``scripts
|
||||
To known more about `prepare data`, please refer to `Data Preparation <../component/data.html#data-preparation>`_.
|
||||
|
||||
Auto Quant Research Workflow
|
||||
====================================
|
||||
============================
|
||||
|
||||
``Qlib`` provides a tool named ``qrun`` to run the whole workflow automatically (including building dataset, training models, backtest and evaluation). Users can start an auto quant research workflow and have a graphical reports analysis according to the following steps:
|
||||
``Qlib`` provides a tool named ``qrun`` to run the whole workflow automatically (including building dataset, training models, backtest and evaluation). Users can start an auto quant research workflow and have a graphical reports analysis according to the following steps:
|
||||
|
||||
- Quant Research Workflow:
|
||||
- Quant Research Workflow:
|
||||
- Run ``qrun`` with a config file of the LightGBM model `workflow_config_lightgbm.yaml` as following.
|
||||
|
||||
.. code-block::
|
||||
.. code-block::
|
||||
|
||||
cd examples # Avoid running program under the directory contains `qlib`
|
||||
qrun benchmarks/LightGBM/workflow_config_lightgbm.yaml
|
||||
@@ -64,7 +64,7 @@ Auto Quant Research Workflow
|
||||
The result of ``qrun`` is as follows, which is also the typical result of ``Forecast model(alpha)``. Please refer to `Intraday Trading <../component/backtest.html>`_. for more details about the result.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
|
||||
risk
|
||||
excess_return_without_cost mean 0.000605
|
||||
std 0.005481
|
||||
@@ -77,7 +77,7 @@ Auto Quant Research Workflow
|
||||
information_ratio 1.187411
|
||||
max_drawdown -0.075024
|
||||
|
||||
|
||||
|
||||
To know more about `workflow` and `qrun`, please refer to `Workflow: Workflow Management <../component/workflow.html>`_.
|
||||
|
||||
- Graphical Reports Analysis:
|
||||
@@ -89,6 +89,6 @@ Auto Quant Research Workflow
|
||||
|
||||
|
||||
Custom Model Integration
|
||||
===============================================
|
||||
========================
|
||||
|
||||
``Qlib`` provides a batch of models (such as ``lightGBM`` and ``MLP`` models) as examples of ``Forecast Model``. In addition to the default model, users can integrate their own custom models into ``Qlib``. If users are interested in the custom model, please refer to `Custom Model Integration <../start/integration.html>`_.
|
||||
|
||||
35
docs/make.bat
Normal file
35
docs/make.bat
Normal file
@@ -0,0 +1,35 @@
|
||||
@ECHO OFF
|
||||
|
||||
pushd %~dp0
|
||||
|
||||
REM Command file for Sphinx documentation
|
||||
|
||||
if "%SPHINXBUILD%" == "" (
|
||||
set SPHINXBUILD=sphinx-build
|
||||
)
|
||||
set SOURCEDIR=.
|
||||
set BUILDDIR=_build
|
||||
|
||||
%SPHINXBUILD% >NUL 2>NUL
|
||||
if errorlevel 9009 (
|
||||
echo.
|
||||
echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
|
||||
echo.installed, then set the SPHINXBUILD environment variable to point
|
||||
echo.to the full path of the 'sphinx-build' executable. Alternatively you
|
||||
echo.may add the Sphinx directory to PATH.
|
||||
echo.
|
||||
echo.If you don't have Sphinx installed, grab it from
|
||||
echo.https://www.sphinx-doc.org/
|
||||
exit /b 1
|
||||
)
|
||||
|
||||
if "%1" == "" goto help
|
||||
|
||||
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
|
||||
goto end
|
||||
|
||||
:help
|
||||
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
|
||||
|
||||
:end
|
||||
popd
|
||||
@@ -1,7 +1,7 @@
|
||||
.. _api:
|
||||
================================
|
||||
=============
|
||||
API Reference
|
||||
================================
|
||||
=============
|
||||
|
||||
|
||||
|
||||
@@ -9,32 +9,32 @@ Here you can find all ``Qlib`` interfaces.
|
||||
|
||||
|
||||
Data
|
||||
====================
|
||||
====
|
||||
|
||||
Provider
|
||||
--------------------
|
||||
--------
|
||||
|
||||
.. automodule:: qlib.data.data
|
||||
:members:
|
||||
|
||||
|
||||
Filter
|
||||
--------------------
|
||||
------
|
||||
|
||||
.. automodule:: qlib.data.filter
|
||||
:members:
|
||||
|
||||
Class
|
||||
--------------------
|
||||
-----
|
||||
.. automodule:: qlib.data.base
|
||||
:members:
|
||||
|
||||
Operator
|
||||
--------------------
|
||||
--------
|
||||
.. automodule:: qlib.data.ops
|
||||
:members:
|
||||
|
||||
|
||||
Cache
|
||||
----------------
|
||||
-----
|
||||
.. autoclass:: qlib.data.cache.MemCacheUnit
|
||||
:members:
|
||||
|
||||
@@ -55,7 +55,7 @@ Cache
|
||||
|
||||
|
||||
Storage
|
||||
-------------
|
||||
-------
|
||||
.. autoclass:: qlib.data.storage.storage.BaseStorage
|
||||
:members:
|
||||
|
||||
@@ -82,52 +82,52 @@ Storage
|
||||
|
||||
|
||||
Dataset
|
||||
---------------
|
||||
-------
|
||||
|
||||
Dataset Class
|
||||
~~~~~~~~~~~~~~~~~~~~
|
||||
~~~~~~~~~~~~~
|
||||
.. automodule:: qlib.data.dataset.__init__
|
||||
:members:
|
||||
|
||||
Data Loader
|
||||
~~~~~~~~~~~~~~~~~~~~
|
||||
~~~~~~~~~~~
|
||||
.. automodule:: qlib.data.dataset.loader
|
||||
:members:
|
||||
|
||||
Data Handler
|
||||
~~~~~~~~~~~~~~~~~~~~
|
||||
~~~~~~~~~~~~
|
||||
.. automodule:: qlib.data.dataset.handler
|
||||
:members:
|
||||
|
||||
Processor
|
||||
~~~~~~~~~~~~~~~~~~~~
|
||||
~~~~~~~~~
|
||||
.. automodule:: qlib.data.dataset.processor
|
||||
:members:
|
||||
|
||||
|
||||
Contrib
|
||||
====================
|
||||
=======
|
||||
|
||||
Model
|
||||
--------------------
|
||||
-----
|
||||
.. automodule:: qlib.model.base
|
||||
:members:
|
||||
|
||||
Strategy
|
||||
-------------------
|
||||
--------
|
||||
|
||||
.. automodule:: qlib.contrib.strategy.strategy
|
||||
:members:
|
||||
|
||||
Evaluate
|
||||
-----------------
|
||||
--------
|
||||
|
||||
.. automodule:: qlib.contrib.evaluate
|
||||
:members:
|
||||
|
||||
|
||||
|
||||
Report
|
||||
-----------------
|
||||
------
|
||||
|
||||
.. automodule:: qlib.contrib.report.analysis_position.report
|
||||
:members:
|
||||
@@ -159,103 +159,100 @@ Report
|
||||
|
||||
|
||||
Workflow
|
||||
====================
|
||||
========
|
||||
|
||||
|
||||
Experiment Manager
|
||||
--------------------
|
||||
------------------
|
||||
.. autoclass:: qlib.workflow.expm.ExpManager
|
||||
:members:
|
||||
|
||||
Experiment
|
||||
--------------------
|
||||
----------
|
||||
.. autoclass:: qlib.workflow.exp.Experiment
|
||||
:members:
|
||||
|
||||
Recorder
|
||||
--------------------
|
||||
--------
|
||||
.. autoclass:: qlib.workflow.recorder.Recorder
|
||||
:members:
|
||||
|
||||
Record Template
|
||||
--------------------
|
||||
---------------
|
||||
.. automodule:: qlib.workflow.record_temp
|
||||
:members:
|
||||
|
||||
Task Management
|
||||
====================
|
||||
===============
|
||||
|
||||
|
||||
TaskGen
|
||||
--------------------
|
||||
-------
|
||||
.. automodule:: qlib.workflow.task.gen
|
||||
:members:
|
||||
|
||||
TaskManager
|
||||
--------------------
|
||||
-----------
|
||||
.. automodule:: qlib.workflow.task.manage
|
||||
:members:
|
||||
|
||||
Trainer
|
||||
--------------------
|
||||
-------
|
||||
.. automodule:: qlib.model.trainer
|
||||
:members:
|
||||
|
||||
Collector
|
||||
--------------------
|
||||
---------
|
||||
.. automodule:: qlib.workflow.task.collect
|
||||
:members:
|
||||
|
||||
Group
|
||||
--------------------
|
||||
-----
|
||||
.. automodule:: qlib.model.ens.group
|
||||
:members:
|
||||
|
||||
Ensemble
|
||||
--------------------
|
||||
--------
|
||||
.. automodule:: qlib.model.ens.ensemble
|
||||
:members:
|
||||
|
||||
Utils
|
||||
--------------------
|
||||
-----
|
||||
.. automodule:: qlib.workflow.task.utils
|
||||
:members:
|
||||
|
||||
|
||||
Online Serving
|
||||
====================
|
||||
==============
|
||||
|
||||
|
||||
Online Manager
|
||||
--------------------
|
||||
--------------
|
||||
.. automodule:: qlib.workflow.online.manager
|
||||
:members:
|
||||
|
||||
Online Strategy
|
||||
--------------------
|
||||
---------------
|
||||
.. automodule:: qlib.workflow.online.strategy
|
||||
:members:
|
||||
|
||||
Online Tool
|
||||
--------------------
|
||||
-----------
|
||||
.. automodule:: qlib.workflow.online.utils
|
||||
:members:
|
||||
|
||||
|
||||
RecordUpdater
|
||||
--------------------
|
||||
-------------
|
||||
.. automodule:: qlib.workflow.online.update
|
||||
:members:
|
||||
|
||||
|
||||
Utils
|
||||
====================
|
||||
=====
|
||||
|
||||
Serializable
|
||||
--------------------
|
||||
------------
|
||||
|
||||
.. automodule:: qlib.utils.serial.Serializable
|
||||
:members:
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1,18 +1,18 @@
|
||||
.. _getdata:
|
||||
|
||||
=============================
|
||||
==============
|
||||
Data Retrieval
|
||||
=============================
|
||||
==============
|
||||
|
||||
.. currentmodule:: qlib
|
||||
|
||||
Introduction
|
||||
====================
|
||||
============
|
||||
|
||||
Users can get stock data with ``Qlib``. The following examples demonstrate the basic user interface.
|
||||
|
||||
Examples
|
||||
====================
|
||||
========
|
||||
|
||||
|
||||
``QLib`` Initialization:
|
||||
@@ -30,7 +30,7 @@ If users followed steps in `initialization <initialization.html>`_ and downloade
|
||||
Load trading calendar with given time range and frequency:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
|
||||
>> from qlib.data import D
|
||||
>> D.calendar(start_time='2010-01-01', end_time='2017-12-31', freq='day')[:2]
|
||||
[Timestamp('2010-01-04 00:00:00'), Timestamp('2010-01-05 00:00:00')]
|
||||
@@ -46,7 +46,7 @@ Parse a given market name into a stock pool config:
|
||||
Load instruments of certain stock pool in the given time range:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
|
||||
>> from qlib.data import D
|
||||
>> instruments = D.instruments(market='csi300')
|
||||
>> D.list_instruments(instruments=instruments, start_time='2010-01-01', end_time='2017-12-31', as_list=True)[:6]
|
||||
@@ -79,14 +79,14 @@ For more details about filter, please refer `Filter API <../component/data.html>
|
||||
Load features of certain instruments in a given time range:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
|
||||
>> from qlib.data import D
|
||||
>> instruments = ['SH600000']
|
||||
>> fields = ['$close', '$volume', 'Ref($close, 1)', 'Mean($close, 3)', '$high-$low']
|
||||
>> D.features(instruments, fields, start_time='2010-01-01', end_time='2017-12-31', freq='day').head()
|
||||
|
||||
|
||||
$close $volume Ref($close, 1) Mean($close, 3) $high-$low
|
||||
instrument datetime
|
||||
instrument datetime
|
||||
SH600000 2010-01-04 86.778313 16162960.0 88.825928 88.061483 2.907631
|
||||
2010-01-05 87.433578 28117442.0 86.778313 87.679273 3.235252
|
||||
2010-01-06 85.713585 23632884.0 87.433578 86.641825 1.720009
|
||||
@@ -108,7 +108,7 @@ Load features of certain stock pool in a given time range:
|
||||
>> D.features(instruments, fields, start_time='2010-01-01', end_time='2017-12-31', freq='day').head()
|
||||
|
||||
$close $volume Ref($close, 1) Mean($close, 3) $high-$low
|
||||
instrument datetime
|
||||
instrument datetime
|
||||
SH600655 2010-01-04 2699.567383 158193.328125 2619.070312 2626.097738 124.580566
|
||||
2010-01-08 2612.359619 77501.406250 2584.567627 2623.220133 83.373047
|
||||
2010-01-11 2712.982422 160852.390625 2612.359619 2636.636556 146.621582
|
||||
@@ -127,7 +127,7 @@ For example, it looks quite long and complicated:
|
||||
.. code-block:: python
|
||||
|
||||
>> from qlib.data import D
|
||||
>> data = D.features(["sh600519"], ["(($high / $close) + ($open / $close)) * (($high / $close) + ($open / $close)) / ($high / $close) + ($open / $close)"], start_time="20200101")
|
||||
>> data = D.features(["sh600519"], ["(($high / $close) + ($open / $close)) * (($high / $close) + ($open / $close)) / (($high / $close) + ($open / $close))"], start_time="20200101")
|
||||
|
||||
|
||||
But using string is not the only way to implement the expression. You can also implement expression by code.
|
||||
@@ -147,5 +147,5 @@ Here is an exmaple which does the same thing as above examples.
|
||||
|
||||
|
||||
API
|
||||
====================
|
||||
===
|
||||
To know more about how to use the Data, go to API Reference: `Data API <../reference/api.html#data>`_
|
||||
|
||||
@@ -1,23 +1,23 @@
|
||||
.. _initialization:
|
||||
|
||||
====================
|
||||
===================
|
||||
Qlib Initialization
|
||||
====================
|
||||
===================
|
||||
|
||||
.. currentmodule:: qlib
|
||||
|
||||
|
||||
Initialization
|
||||
=========================
|
||||
==============
|
||||
|
||||
Please follow the steps below to initialize ``Qlib``.
|
||||
|
||||
Download and prepare the Data: execute the following command to download stock data. Please pay `attention` that the data is collected from `Yahoo Finance <https://finance.yahoo.com/lookup>`_ and the data might not be perfect. We recommend users to prepare their own data if they have high-quality datasets. Please refer to `Data <../component/data.html#converting-csv-format-into-qlib-format>`_ for more information about customized dataset.
|
||||
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn
|
||||
|
||||
|
||||
Please refer to `Data Preparation <../component/data.html#data-preparation>`_ for more information about `get_data.py`,
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ Initialize Qlib before calling other APIs: run following code in python.
|
||||
from qlib.constant import REG_CN
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
|
||||
|
||||
.. note::
|
||||
Do not import qlib package in the repository directory of ``Qlib``, otherwise, errors may occur.
|
||||
|
||||
@@ -56,16 +56,16 @@ The following are several important parameters of `qlib.init` (`Qlib` has a lot
|
||||
- `redis_port`
|
||||
Type: int, optional parameter(default: 6379), port of `redis`
|
||||
|
||||
.. note::
|
||||
|
||||
.. note::
|
||||
|
||||
The value of `region` should be aligned with the data stored in `provider_uri`. Currently, ``scripts/get_data.py`` only provides China stock market data. If users want to use the US stock market data, they should prepare their own US-stock data in `provider_uri` and switch to US-stock mode.
|
||||
|
||||
.. note::
|
||||
|
||||
|
||||
If Qlib fails to connect redis via `redis_host` and `redis_port`, cache mechanism will not be used! Please refer to `Cache <../component/data.html#cache>`_ for details.
|
||||
- `exp_manager`
|
||||
Type: dict, optional parameter, the setting of `experiment manager` to be used in qlib. Users can specify an experiment manager class, as well as the tracking URI for all the experiments. However, please be aware that we only support input of a dictionary in the following style for `exp_manager`. For more information about `exp_manager`, users can refer to `Recorder: Experiment Management <../component/recorder.html>`_.
|
||||
|
||||
|
||||
.. code-block:: Python
|
||||
|
||||
# For example, if you want to set your tracking_uri to a <specific folder>, you can initialize qlib below
|
||||
@@ -78,7 +78,7 @@ The following are several important parameters of `qlib.init` (`Qlib` has a lot
|
||||
}
|
||||
})
|
||||
- `mongo`
|
||||
Type: dict, optional parameter, the setting of `MongoDB <https://www.mongodb.com/>`_ which will be used in some features such as `Task Management <../advanced/task_management.html>`_, with high performance and clustered processing.
|
||||
Type: dict, optional parameter, the setting of `MongoDB <https://www.mongodb.com/>`_ which will be used in some features such as `Task Management <../advanced/task_management.html>`_, with high performance and clustered processing.
|
||||
Users need to follow the steps in `installation <https://www.mongodb.com/try/download/community>`_ to install MongoDB firstly and then access it via a URI.
|
||||
Users can access mongodb with credential by setting "task_url" to a string like `"mongodb://%s:%s@%s" % (user, pwd, host + ":" + port)`.
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
.. _installation:
|
||||
|
||||
====================
|
||||
============
|
||||
Installation
|
||||
====================
|
||||
============
|
||||
|
||||
.. currentmodule:: qlib
|
||||
|
||||
@@ -24,7 +24,7 @@ Also, Users can install ``Qlib`` by the source code according to the following s
|
||||
|
||||
- Enter the root directory of ``Qlib``, in which the file ``setup.py`` exists.
|
||||
- Then, please execute the following command to install the environment dependencies and install ``Qlib``:
|
||||
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
$ pip install numpy
|
||||
@@ -34,7 +34,7 @@ Also, Users can install ``Qlib`` by the source code according to the following s
|
||||
|
||||
.. note::
|
||||
It's recommended to use anaconda/miniconda to setup the environment. ``Qlib`` needs lightgbm and pytorch packages, use pip to install them.
|
||||
|
||||
|
||||
|
||||
|
||||
Use the following code to make sure the installation successful:
|
||||
@@ -44,6 +44,3 @@ Use the following code to make sure the installation successful:
|
||||
>>> import qlib
|
||||
>>> qlib.__version__
|
||||
<LATEST VERSION>
|
||||
|
||||
|
||||
=====================
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
=========================================
|
||||
========================
|
||||
Custom Model Integration
|
||||
=========================================
|
||||
========================
|
||||
|
||||
Introduction
|
||||
===================
|
||||
============
|
||||
|
||||
``Qlib``'s `Model Zoo` includes models such as ``LightGBM``, ``MLP``, ``LSTM``, etc.. These models are examples of ``Forecast Model``. In addition to the default models ``Qlib`` provide, users can integrate their own custom models into ``Qlib``.
|
||||
|
||||
@@ -14,7 +14,7 @@ Users can integrate their own custom models according to the following steps.
|
||||
- Test the custom model.
|
||||
|
||||
Custom Model Class
|
||||
===========================
|
||||
==================
|
||||
The Custom models need to inherit `qlib.model.base.Model <../reference/api.html#module-qlib.model.base>`_ and override the methods in it.
|
||||
|
||||
- Override the `__init__` method
|
||||
@@ -36,7 +36,7 @@ The Custom models need to inherit `qlib.model.base.Model <../reference/api.html#
|
||||
- The parameters could include some `optional` parameters with default values, such as `num_boost_round = 1000` for `GBDT`.
|
||||
- Code Example: In the following example, `num_boost_round = 1000` is an optional parameter.
|
||||
.. code-block:: Python
|
||||
|
||||
|
||||
def fit(self, dataset: DatasetH, num_boost_round = 1000, **kwargs):
|
||||
|
||||
# prepare dataset for lgb training and evaluation
|
||||
@@ -101,14 +101,14 @@ The Custom models need to inherit `qlib.model.base.Model <../reference/api.html#
|
||||
)
|
||||
|
||||
Configuration File
|
||||
=======================
|
||||
==================
|
||||
|
||||
The configuration file is described in detail in the `Workflow <../component/workflow.html#complete-example>`_ document. In order to integrate the custom model into ``Qlib``, users need to modify the "model" field in the configuration file. The configuration describes which models to use and how we can initialize it.
|
||||
|
||||
- Example: The following example describes the `model` field of configuration file about the custom lightgbm model mentioned above, where `module_path` is the module path, `class` is the class name, and `args` is the hyperparameter passed into the __init__ method. All parameters in the field is passed to `self._params` by `\*\*kwargs` in `__init__` except `loss = mse`.
|
||||
- Example: The following example describes the `model` field of configuration file about the custom lightgbm model mentioned above, where `module_path` is the module path, `class` is the class name, and `args` is the hyperparameter passed into the __init__ method. All parameters in the field is passed to `self._params` by `\*\*kwargs` in `__init__` except `loss = mse`.
|
||||
|
||||
.. code-block:: YAML
|
||||
|
||||
|
||||
model:
|
||||
class: LGBModel
|
||||
module_path: qlib.contrib.model.gbdt
|
||||
@@ -126,7 +126,7 @@ The configuration file is described in detail in the `Workflow <../component/wor
|
||||
Users could find configuration file of the baselines of the ``Model`` in ``examples/benchmarks``. All the configurations of different models are listed under the corresponding model folder.
|
||||
|
||||
Model Testing
|
||||
=====================
|
||||
=============
|
||||
Assuming that the configuration file is ``examples/benchmarks/LightGBM/workflow_config_lightgbm.yaml``, users can run the following command to test the custom model:
|
||||
|
||||
.. code-block:: bash
|
||||
@@ -136,10 +136,10 @@ Assuming that the configuration file is ``examples/benchmarks/LightGBM/workflow_
|
||||
|
||||
.. note:: ``qrun`` is a built-in command of ``Qlib``.
|
||||
|
||||
Also, ``Model`` can also be tested as a single module. An example has been given in ``examples/workflow_by_code.ipynb``.
|
||||
Also, ``Model`` can also be tested as a single module. An example has been given in ``examples/workflow_by_code.ipynb``.
|
||||
|
||||
|
||||
Reference
|
||||
=====================
|
||||
=========
|
||||
|
||||
To know more about ``Forecast Model``, please refer to `Forecast Model: Model Training & Prediction <../component/model.html>`_ and `Model API <../reference/api.html#module-qlib.model.base>`_.
|
||||
|
||||
@@ -0,0 +1,72 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi500
|
||||
benchmark: &benchmark SH000905
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
start_time: 2017-01-01
|
||||
end_time: 2020-08-01
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
exchange_kwargs:
|
||||
limit_threshold: 0.095
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: CatBoostModel
|
||||
module_path: qlib.contrib.model.catboost_model
|
||||
kwargs:
|
||||
loss: RMSE
|
||||
learning_rate: 0.0421
|
||||
subsample: 0.8789
|
||||
max_depth: 6
|
||||
num_leaves: 100
|
||||
thread_count: 20
|
||||
grow_policy: Lossguide
|
||||
bootstrap_type: Poisson
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha158
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -0,0 +1,79 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi500
|
||||
benchmark: &benchmark SH000905
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors: []
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
start_time: 2017-01-01
|
||||
end_time: 2020-08-01
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
exchange_kwargs:
|
||||
limit_threshold: 0.095
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: CatBoostModel
|
||||
module_path: qlib.contrib.model.catboost_model
|
||||
kwargs:
|
||||
loss: RMSE
|
||||
learning_rate: 0.0421
|
||||
subsample: 0.8789
|
||||
max_depth: 6
|
||||
num_leaves: 100
|
||||
thread_count: 20
|
||||
grow_policy: Lossguide
|
||||
bootstrap_type: Poisson
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha360
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -37,7 +37,7 @@ task:
|
||||
kwargs:
|
||||
base_model: "gbm"
|
||||
loss: mse
|
||||
num_models: 6
|
||||
num_models: 3
|
||||
enable_sr: True
|
||||
enable_fs: True
|
||||
alpha1: 1
|
||||
@@ -53,11 +53,8 @@ task:
|
||||
- 0.4
|
||||
sub_weights:
|
||||
- 1
|
||||
- 0.2
|
||||
- 0.2
|
||||
- 0.2
|
||||
- 0.2
|
||||
- 0.2
|
||||
- 1
|
||||
- 1
|
||||
epochs: 28
|
||||
colsample_bytree: 0.8879
|
||||
learning_rate: 0.2
|
||||
|
||||
@@ -0,0 +1,97 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi500
|
||||
benchmark: &benchmark SH000905
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
start_time: 2017-01-01
|
||||
end_time: 2020-08-01
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
exchange_kwargs:
|
||||
limit_threshold: 0.095
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: DEnsembleModel
|
||||
module_path: qlib.contrib.model.double_ensemble
|
||||
kwargs:
|
||||
base_model: "gbm"
|
||||
loss: mse
|
||||
num_models: 6
|
||||
enable_sr: True
|
||||
enable_fs: True
|
||||
alpha1: 1
|
||||
alpha2: 1
|
||||
bins_sr: 10
|
||||
bins_fs: 5
|
||||
decay: 0.5
|
||||
sample_ratios:
|
||||
- 0.8
|
||||
- 0.7
|
||||
- 0.6
|
||||
- 0.5
|
||||
- 0.4
|
||||
sub_weights:
|
||||
- 1
|
||||
- 0.2
|
||||
- 0.2
|
||||
- 0.2
|
||||
- 0.2
|
||||
- 0.2
|
||||
epochs: 28
|
||||
colsample_bytree: 0.8879
|
||||
learning_rate: 0.2
|
||||
subsample: 0.8789
|
||||
lambda_l1: 205.6999
|
||||
lambda_l2: 580.9768
|
||||
max_depth: 8
|
||||
num_leaves: 210
|
||||
num_threads: 20
|
||||
verbosity: -1
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha158
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -44,7 +44,7 @@ task:
|
||||
kwargs:
|
||||
base_model: "gbm"
|
||||
loss: mse
|
||||
num_models: 6
|
||||
num_models: 3
|
||||
enable_sr: True
|
||||
enable_fs: True
|
||||
alpha1: 1
|
||||
@@ -60,11 +60,8 @@ task:
|
||||
- 0.4
|
||||
sub_weights:
|
||||
- 1
|
||||
- 0.2
|
||||
- 0.2
|
||||
- 0.2
|
||||
- 0.2
|
||||
- 0.2
|
||||
- 1
|
||||
- 1
|
||||
epochs: 136
|
||||
colsample_bytree: 0.8879
|
||||
learning_rate: 0.0421
|
||||
|
||||
@@ -0,0 +1,104 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi500
|
||||
benchmark: &benchmark SH000905
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors: []
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
start_time: 2017-01-01
|
||||
end_time: 2020-08-01
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
exchange_kwargs:
|
||||
limit_threshold: 0.095
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: DEnsembleModel
|
||||
module_path: qlib.contrib.model.double_ensemble
|
||||
kwargs:
|
||||
base_model: "gbm"
|
||||
loss: mse
|
||||
num_models: 6
|
||||
enable_sr: True
|
||||
enable_fs: True
|
||||
alpha1: 1
|
||||
alpha2: 1
|
||||
bins_sr: 10
|
||||
bins_fs: 5
|
||||
decay: 0.5
|
||||
sample_ratios:
|
||||
- 0.8
|
||||
- 0.7
|
||||
- 0.6
|
||||
- 0.5
|
||||
- 0.4
|
||||
sub_weights:
|
||||
- 1
|
||||
- 0.2
|
||||
- 0.2
|
||||
- 0.2
|
||||
- 0.2
|
||||
- 0.2
|
||||
epochs: 136
|
||||
colsample_bytree: 0.8879
|
||||
learning_rate: 0.0421
|
||||
subsample: 0.8789
|
||||
lambda_l1: 205.6999
|
||||
lambda_l2: 580.9768
|
||||
max_depth: 8
|
||||
num_leaves: 210
|
||||
num_threads: 20
|
||||
verbosity: -1
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha360
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -1,4 +1,10 @@
|
||||
# LightGBM
|
||||
* Code: [https://github.com/microsoft/LightGBM](https://github.com/microsoft/LightGBM)
|
||||
* Paper: LightGBM: A Highly Efficient Gradient Boosting
|
||||
Decision Tree. [https://proceedings.neurips.cc/paper/2017/file/6449f44a102fde848669bdd9eb6b76fa-Paper.pdf](https://proceedings.neurips.cc/paper/2017/file/6449f44a102fde848669bdd9eb6b76fa-Paper.pdf).
|
||||
Decision Tree. [https://proceedings.neurips.cc/paper/2017/file/6449f44a102fde848669bdd9eb6b76fa-Paper.pdf](https://proceedings.neurips.cc/paper/2017/file/6449f44a102fde848669bdd9eb6b76fa-Paper.pdf).
|
||||
|
||||
|
||||
# Introductions about the settings/configs.
|
||||
|
||||
`workflow_config_lightgbm_multi_freq.yaml`
|
||||
- It uses data sources of different frequencies (i.e. multiple frequencies) for daily prediction.
|
||||
|
||||
@@ -35,13 +35,13 @@ task:
|
||||
module_path: qlib.contrib.model.gbdt
|
||||
kwargs:
|
||||
loss: mse
|
||||
colsample_bytree: 0.8879
|
||||
learning_rate: 0.2
|
||||
subsample: 0.8789
|
||||
colsample_bytree: 0.9
|
||||
learning_rate: 0.1
|
||||
subsample: 0.9
|
||||
lambda_l1: 205.6999
|
||||
lambda_l2: 580.9768
|
||||
max_depth: 8
|
||||
num_leaves: 210
|
||||
num_leaves: 250
|
||||
num_threads: 20
|
||||
dataset:
|
||||
class: DatasetH
|
||||
|
||||
@@ -0,0 +1,78 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi500
|
||||
benchmark: &benchmark SH000905
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
start_time: 2017-01-01
|
||||
end_time: 2020-08-01
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
exchange_kwargs:
|
||||
limit_threshold: 0.095
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: LinearModel
|
||||
module_path: qlib.contrib.model.linear
|
||||
kwargs:
|
||||
estimator: ols
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha158
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: True
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
102
examples/benchmarks/MLP/workflow_config_mlp_Alpha158_csi500.yaml
Normal file
102
examples/benchmarks/MLP/workflow_config_mlp_Alpha158_csi500.yaml
Normal file
@@ -0,0 +1,102 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi500
|
||||
benchmark: &benchmark SH000905
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors: [
|
||||
{
|
||||
"class" : "DropCol",
|
||||
"kwargs":{"col_list": ["VWAP0"]}
|
||||
},
|
||||
{
|
||||
"class" : "CSZFillna",
|
||||
"kwargs":{"fields_group": "feature"}
|
||||
}
|
||||
]
|
||||
learn_processors: [
|
||||
{
|
||||
"class" : "DropCol",
|
||||
"kwargs":{"col_list": ["VWAP0"]}
|
||||
},
|
||||
{
|
||||
"class" : "DropnaProcessor",
|
||||
"kwargs":{"fields_group": "feature"}
|
||||
},
|
||||
"DropnaLabel",
|
||||
{
|
||||
"class": "CSZScoreNorm",
|
||||
"kwargs": {"fields_group": "label"}
|
||||
}
|
||||
]
|
||||
process_type: "independent"
|
||||
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
start_time: 2017-01-01
|
||||
end_time: 2020-08-01
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
exchange_kwargs:
|
||||
limit_threshold: 0.095
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: DNNModelPytorch
|
||||
module_path: qlib.contrib.model.pytorch_nn
|
||||
kwargs:
|
||||
loss: mse
|
||||
lr: 0.002
|
||||
lr_decay: 0.96
|
||||
lr_decay_steps: 100
|
||||
optimizer: adam
|
||||
max_steps: 8000
|
||||
batch_size: 8192
|
||||
GPU: 0
|
||||
weight_decay: 0.0002
|
||||
pt_model_kwargs:
|
||||
input_dim: 157
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha158
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -0,0 +1,89 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi500
|
||||
benchmark: &benchmark SH000905
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
start_time: 2017-01-01
|
||||
end_time: 2020-08-01
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
exchange_kwargs:
|
||||
limit_threshold: 0.095
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: DNNModelPytorch
|
||||
module_path: qlib.contrib.model.pytorch_nn
|
||||
kwargs:
|
||||
loss: mse
|
||||
lr: 0.002
|
||||
lr_decay: 0.96
|
||||
lr_decay_steps: 100
|
||||
optimizer: adam
|
||||
max_steps: 8000
|
||||
batch_size: 4096
|
||||
GPU: 0
|
||||
pt_model_kwargs:
|
||||
input_dim: 360
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha360
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -43,8 +43,7 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
|
||||
| TFT (Bryan Lim, et al.) | Alpha158(with selected 20 features) | 0.0358±0.00 | 0.2160±0.03 | 0.0116±0.01 | 0.0720±0.03 | 0.0847±0.02 | 0.8131±0.19 | -0.1824±0.03 |
|
||||
| MLP | Alpha158 | 0.0376±0.00 | 0.2846±0.02 | 0.0429±0.00 | 0.3220±0.01 | 0.0895±0.02 | 1.1408±0.23 | -0.1103±0.02 |
|
||||
| LightGBM(Guolin Ke, et al.) | Alpha158 | 0.0448±0.00 | 0.3660±0.00 | 0.0469±0.00 | 0.3877±0.00 | 0.0901±0.00 | 1.0164±0.00 | -0.1038±0.00 |
|
||||
| DoubleEnsemble(Chuheng Zhang, et al.) | Alpha158 | 0.0544±0.00 | 0.4340±0.00 | 0.0523±0.00 | 0.4284±0.01 | 0.1168±0.01 | 1.3384±0.12 | -0.1036±0.01 |
|
||||
|
||||
| DoubleEnsemble(Chuheng Zhang, et al.) | Alpha158 | 0.0521±0.00 | 0.4223±0.01 | 0.0502±0.00 | 0.4117±0.01 | 0.1158±0.01 | 1.3432±0.11 | -0.0920±0.01 |
|
||||
|
||||
### Alpha360 dataset
|
||||
|
||||
@@ -56,7 +55,7 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
|
||||
| Localformer(Juyong Jiang, et al.) | Alpha360 | 0.0404±0.00 | 0.2932±0.04 | 0.0542±0.00 | 0.4110±0.03 | 0.0246±0.02 | 0.3211±0.21 | -0.1095±0.02 |
|
||||
| CatBoost((Liudmila Prokhorenkova, et al.) | Alpha360 | 0.0378±0.00 | 0.2714±0.00 | 0.0467±0.00 | 0.3659±0.00 | 0.0292±0.00 | 0.3781±0.00 | -0.0862±0.00 |
|
||||
| XGBoost(Tianqi Chen, et al.) | Alpha360 | 0.0394±0.00 | 0.2909±0.00 | 0.0448±0.00 | 0.3679±0.00 | 0.0344±0.00 | 0.4527±0.02 | -0.1004±0.00 |
|
||||
| DoubleEnsemble(Chuheng Zhang, et al.) | Alpha360 | 0.0404±0.00 | 0.3023±0.00 | 0.0495±0.00 | 0.3898±0.00 | 0.0468±0.01 | 0.6302±0.20 | -0.0860±0.01 |
|
||||
| DoubleEnsemble(Chuheng Zhang, et al.) | Alpha360 | 0.0390±0.00 | 0.2946±0.01 | 0.0486±0.00 | 0.3836±0.01 | 0.0462±0.01 | 0.6151±0.18 | -0.0915±0.01 |
|
||||
| LightGBM(Guolin Ke, et al.) | Alpha360 | 0.0400±0.00 | 0.3037±0.00 | 0.0499±0.00 | 0.4042±0.00 | 0.0558±0.00 | 0.7632±0.00 | -0.0659±0.00 |
|
||||
| TCN(Shaojie Bai, et al.) | Alpha360 | 0.0441±0.00 | 0.3301±0.02 | 0.0519±0.00 | 0.4130±0.01 | 0.0604±0.02 | 0.8295±0.34 | -0.1018±0.03 |
|
||||
| ALSTM (Yao Qin, et al.) | Alpha360 | 0.0497±0.00 | 0.3829±0.04 | 0.0599±0.00 | 0.4736±0.03 | 0.0626±0.02 | 0.8651±0.31 | -0.0994±0.03 |
|
||||
@@ -75,10 +74,15 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
|
||||
- The base model of DoubleEnsemble is LGBM.
|
||||
- The base model of TCTS is GRU.
|
||||
- About the datasets
|
||||
- Alpha158 is a tabular dataset. There are less spatial relationships between different features. Each feature are carefully desgined by human (a.k.a feature engineering)
|
||||
- Alpha158 is a tabular dataset. There are less spatial relationships between different features. Each feature are carefully designed by human (a.k.a feature engineering)
|
||||
- Alpha360 contains raw price and volue data without much feature engineering. There are strong strong spatial relationships between the features in the time dimension.
|
||||
- The metrics can be categorized into two
|
||||
- Signal-based evaluation: IC, ICIR, Rank IC, Rank ICIR
|
||||
- 
|
||||
- 
|
||||
- 
|
||||
- 
|
||||
- 
|
||||
- Portfolio-based metrics: Annualized Return, Information Ratio, Max Drawdown
|
||||
|
||||
## Results on CSI500
|
||||
@@ -103,16 +107,21 @@ python run_all_model.py run 3 lightgbm Alpha158 csi500 # for models with random
|
||||
```
|
||||
|
||||
### Alpha158 dataset
|
||||
|
||||
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |
|
||||
|------------|----------|-------------|-------------|-------------|-------------|-------------------|-------------------|--------------|
|
||||
| LightGBM | Alpha158 | 0.0377±0.00 | 0.3860±0.00 | 0.0448±0.00 | 0.4675±0.00 | 0.1151±0.00 | 1.3884±0.00 | -0.0898±0.00 |
|
||||
| Linear | Alpha158 | 0.0332±0.00 | 0.3044±0.00 | 0.0462±0.00 | 0.4326±0.00 | 0.0382±0.00 | 0.1723±0.00 | -0.4876±0.00 |
|
||||
| MLP | Alpha158 | 0.0229±0.01 | 0.2181±0.05 | 0.0360±0.00 | 0.3409±0.02 | 0.0043±0.02 | 0.0602±0.27 | -0.2184±0.04 |
|
||||
| LightGBM | Alpha158 | 0.0399±0.00 | 0.4065±0.00 | 0.0482±0.00 | 0.5101±0.00 | 0.1284±0.00 | 1.5650±0.00 | -0.0635±0.00 |
|
||||
| CatBoost | Alpha158 | 0.0345±0.00 | 0.2855±0.00 | 0.0417±0.00 | 0.3740±0.00 | 0.0496±0.00 | 0.5977±0.00 | -0.1496±0.00 |
|
||||
| DoubleEnsemble | Alpha158 | 0.0380±0.00 | 0.3659±0.00 | 0.0442±0.00 | 0.4324±0.00 | 0.0382±0.00 | 0.1723±0.00 | -0.4876±0.00 |
|
||||
|
||||
### Alpha360 dataset
|
||||
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |
|
||||
|------------|----------|-------------|-------------|-------------|-------------|-------------------|-------------------|--------------|
|
||||
| MLP | Alpha360 | 0.0258±0.00 | 0.2021±0.02 | 0.0426±0.00 | 0.3840±0.02 | 0.0022±0.02 | 0.0301±0.26 | -0.2064±0.02 |
|
||||
| LightGBM | Alpha360 | 0.0400±0.00 | 0.3605±0.00 | 0.0536±0.00 | 0.5431±0.00 | 0.0505±0.00 | 0.7658±0.02 | -0.1880±0.00 |
|
||||
|
||||
| CatBoost | Alpha360 | 0.0382±0.00 | 0.3229±0.00 | 0.0489±0.00 | 0.4649±0.00 | 0.0297±0.00 | 0.4227±0.02 | -0.1499±0.01 |
|
||||
| DoubleEnsemble | Alpha360 | 0.0361±0.00 | 0.3092±0.00 | 0.0499±0.00 | 0.4793±0.00 | 0.0382±0.00 | 0.1723±0.02 | -0.4876±0.00 |
|
||||
|
||||
# Contributing
|
||||
|
||||
@@ -129,3 +138,10 @@ If you want to contribute your new models, you can follow the steps below.
|
||||
5. Update the info in the index page in the [news list](https://github.com/microsoft/qlib#newspaper-whats-new----sparkling_heart) and [model list](https://github.com/microsoft/qlib#quant-model-paper-zoo).
|
||||
|
||||
Finally, you can send PR for review. ([here is an example](https://github.com/microsoft/qlib/pull/1040))
|
||||
|
||||
|
||||
# FAQ
|
||||
|
||||
Q: What's the difference between models with name `*.py` and `*_ts.py`?
|
||||
|
||||
A: Models with name `*_ts.py` are designed for `TSDatasetH` (`TSDatasetH` will create time-series automatically from tabular data). Models with name `*.py` are designed for `DatasetH` (`DatasetH` is usually used in tabular data. But users still can apply time-series models on tabular datasets if the columns has time-series relationships).
|
||||
|
||||
@@ -38,6 +38,9 @@
|
||||
" # install qlib\n",
|
||||
" ! pip install --upgrade numpy\n",
|
||||
" ! pip install pyqlib\n",
|
||||
" if 'google.colab' in sys.modules:\n",
|
||||
" # The Google colab environment is a little outdated. We have to downgrade the pyyaml to make it compatible with other packages\n",
|
||||
" ! pip install pyyaml==5.4.1\n",
|
||||
" # reload\n",
|
||||
" site.main()\n",
|
||||
"\n",
|
||||
|
||||
@@ -1,6 +1,12 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
"""
|
||||
Qlib provides two kinds of interfaces.
|
||||
(1) Users could define the Quant research workflow by a simple configuration.
|
||||
(2) Qlib is designed in a modularized way and supports creating research workflow by code just like building blocks.
|
||||
|
||||
The interface of (1) is `qrun XXX.yaml`. The interface of (2) is script like this, which nearly does the same thing as `qrun XXX.yaml`
|
||||
"""
|
||||
import qlib
|
||||
from qlib.constant import REG_CN
|
||||
from qlib.utils import init_instance_by_config, flatten_dict
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# Licensed under the MIT License.
|
||||
from pathlib import Path
|
||||
|
||||
__version__ = "0.8.6"
|
||||
__version__ = "0.8.6.99"
|
||||
__version__bak = __version__ # This version is backup for QlibConfig.reset_qlib_version
|
||||
import os
|
||||
from typing import Union
|
||||
@@ -94,7 +94,7 @@ def _mount_nfs_uri(provider_uri, mount_path, auto_mount: bool = False):
|
||||
else:
|
||||
# Judging system type
|
||||
sys_type = platform.system()
|
||||
if "win" in sys_type.lower():
|
||||
if "windows" in sys_type.lower():
|
||||
# system: window
|
||||
exec_result = os.popen(f"mount -o anon {provider_uri} {mount_path}")
|
||||
result = exec_result.read()
|
||||
@@ -113,6 +113,8 @@ def _mount_nfs_uri(provider_uri, mount_path, auto_mount: bool = False):
|
||||
# system: linux/Unix/Mac
|
||||
# check mount
|
||||
_remote_uri = provider_uri[:-1] if provider_uri.endswith("/") else provider_uri
|
||||
# `mount a /b/c` is different from `mount a /b/c/`. So we convert it into string to make sure handling it accurately
|
||||
mount_path = str(mount_path)
|
||||
_mount_path = mount_path[:-1] if mount_path.endswith("/") else mount_path
|
||||
_check_level_num = 2
|
||||
_is_mount = False
|
||||
|
||||
@@ -5,7 +5,7 @@ from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Generator, List, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Any, Generator, List, Optional, Tuple, Union
|
||||
|
||||
import pandas as pd
|
||||
|
||||
@@ -23,7 +23,6 @@ from ..utils import init_instance_by_config
|
||||
from .backtest import backtest_loop, collect_data_loop
|
||||
from .decision import Order
|
||||
from .exchange import Exchange
|
||||
from .position import Position
|
||||
from .utils import CommonInfrastructure
|
||||
|
||||
# make import more user-friendly by adding `from qlib.backtest import STH`
|
||||
@@ -43,8 +42,8 @@ def get_exchange(
|
||||
close_cost: float = 0.0025,
|
||||
min_cost: float = 5.0,
|
||||
limit_threshold: Union[Tuple[str, str], float, None] = None,
|
||||
deal_price: Union[str, Tuple[str], List[str]] = None,
|
||||
**kwargs,
|
||||
deal_price: Union[str, Tuple[str, str], List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Exchange:
|
||||
"""get_exchange
|
||||
|
||||
@@ -52,14 +51,15 @@ def get_exchange(
|
||||
----------
|
||||
|
||||
# exchange related arguments
|
||||
exchange: Exchange(). It could be None or any types that are acceptable by `init_instance_by_config`.
|
||||
exchange: Exchange
|
||||
It could be None or any types that are acceptable by `init_instance_by_config`.
|
||||
freq: str
|
||||
frequency of data.
|
||||
start_time: Union[pd.Timestamp, str]
|
||||
closed start time for backtest.
|
||||
end_time: Union[pd.Timestamp, str]
|
||||
closed end time for backtest.
|
||||
codes: list|str
|
||||
codes: Union[list, str]
|
||||
list stock_id list or a string of instruments (i.e. all, csi500, sse50)
|
||||
subscribe_fields: list
|
||||
subscribe fields.
|
||||
@@ -70,10 +70,10 @@ def get_exchange(
|
||||
min_cost : float
|
||||
min transaction cost. It is an absolute amount of cost instead of a ratio of your order's deal amount.
|
||||
e.g. You must pay at least 5 yuan of commission regardless of your order's deal amount.
|
||||
deal_price: Union[str, Tuple[str], List[str]]
|
||||
deal_price: Union[str, Tuple[str, str], List[str]]
|
||||
The `deal_price` supports following two types of input
|
||||
- <deal_price> : str
|
||||
- (<buy_price>, <sell_price>): Tuple[str] or List[str]
|
||||
- (<buy_price>, <sell_price>): Tuple[str, str] or List[str]
|
||||
|
||||
<deal_price>, <buy_price> or <sell_price> := <price>
|
||||
<price> := str
|
||||
@@ -151,28 +151,24 @@ def create_account_instance(
|
||||
Postion type.
|
||||
"""
|
||||
if isinstance(account, (int, float)):
|
||||
pos_kwargs = {"init_cash": account}
|
||||
init_cash = account
|
||||
position_dict = {}
|
||||
elif isinstance(account, dict):
|
||||
init_cash = account["cash"]
|
||||
del account["cash"]
|
||||
pos_kwargs = {
|
||||
"init_cash": init_cash,
|
||||
"position_dict": account,
|
||||
}
|
||||
init_cash = account.pop("cash")
|
||||
position_dict = account
|
||||
else:
|
||||
raise ValueError("account must be in (int, float, Position)")
|
||||
raise ValueError("account must be in (int, float, dict)")
|
||||
|
||||
kwargs = {
|
||||
"init_cash": account,
|
||||
"benchmark_config": {
|
||||
return Account(
|
||||
init_cash=init_cash,
|
||||
position_dict=position_dict,
|
||||
pos_type=pos_type,
|
||||
benchmark_config={
|
||||
"benchmark": benchmark,
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
},
|
||||
"pos_type": pos_type,
|
||||
}
|
||||
kwargs.update(pos_kwargs)
|
||||
return Account(**kwargs)
|
||||
)
|
||||
|
||||
|
||||
def get_strategy_executor(
|
||||
@@ -181,7 +177,7 @@ def get_strategy_executor(
|
||||
strategy: Union[str, dict, object, Path],
|
||||
executor: Union[str, dict, object, Path],
|
||||
benchmark: str = "SH000300",
|
||||
account: Union[float, int, Position] = 1e9,
|
||||
account: Union[float, int, dict] = 1e9,
|
||||
exchange_kwargs: dict = {},
|
||||
pos_type: str = "Position",
|
||||
) -> Tuple[BaseStrategy, BaseExecutor]:
|
||||
@@ -222,7 +218,7 @@ def backtest(
|
||||
strategy: Union[str, dict, object, Path],
|
||||
executor: Union[str, dict, object, Path],
|
||||
benchmark: str = "SH000300",
|
||||
account: Union[float, int, Position] = 1e9,
|
||||
account: Union[float, int, dict] = 1e9,
|
||||
exchange_kwargs: dict = {},
|
||||
pos_type: str = "Position",
|
||||
) -> Tuple[PortfolioMetrics, Indicator]:
|
||||
@@ -285,7 +281,7 @@ def collect_data(
|
||||
strategy: Union[str, dict, object, Path],
|
||||
executor: Union[str, dict, object, Path],
|
||||
benchmark: str = "SH000300",
|
||||
account: Union[float, int, Position] = 1e9,
|
||||
account: Union[float, int, dict] = 1e9,
|
||||
exchange_kwargs: dict = {},
|
||||
pos_type: str = "Position",
|
||||
return_value: dict = None,
|
||||
@@ -339,7 +335,7 @@ def format_decisions(
|
||||
|
||||
cur_freq = decisions[0].strategy.trade_calendar.get_freq()
|
||||
|
||||
res = (cur_freq, [])
|
||||
res: Tuple[str, list] = (cur_freq, [])
|
||||
last_dec_idx = 0
|
||||
for i, dec in enumerate(decisions[1:], 1):
|
||||
if dec.strategy.trade_calendar.get_freq() == cur_freq:
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from typing import Dict, List, Tuple
|
||||
from typing import Dict, List, Optional, Tuple, cast
|
||||
|
||||
import pandas as pd
|
||||
|
||||
@@ -11,6 +11,7 @@ from qlib.utils import init_instance_by_config
|
||||
|
||||
from .decision import BaseTradeDecision, Order
|
||||
from .exchange import Exchange
|
||||
from .high_performance_ds import BaseOrderIndicator
|
||||
from .position import BasePosition
|
||||
from .report import Indicator, PortfolioMetrics
|
||||
|
||||
@@ -104,7 +105,7 @@ class Account:
|
||||
|
||||
self._pos_type = pos_type
|
||||
self._port_metr_enabled = port_metr_enabled
|
||||
self.benchmark_config = None # avoid no attribute error
|
||||
self.benchmark_config: dict = {} # avoid no attribute error
|
||||
self.init_vars(init_cash, position_dict, freq, benchmark_config)
|
||||
|
||||
def init_vars(self, init_cash: float, position_dict: dict, freq: str, benchmark_config: dict) -> None:
|
||||
@@ -124,8 +125,8 @@ class Account:
|
||||
self.accum_info = AccumulatedInfo()
|
||||
|
||||
# 2) following variables are not shared between layers
|
||||
self.portfolio_metrics = None
|
||||
self.hist_positions = {}
|
||||
self.portfolio_metrics: Optional[PortfolioMetrics] = None
|
||||
self.hist_positions: Dict[pd.Timestamp, BasePosition] = {}
|
||||
self.reset(freq=freq, benchmark_config=benchmark_config)
|
||||
|
||||
def is_port_metr_enabled(self) -> bool:
|
||||
@@ -171,7 +172,7 @@ class Account:
|
||||
|
||||
self.reset_report(self.freq, self.benchmark_config)
|
||||
|
||||
def get_hist_positions(self) -> dict:
|
||||
def get_hist_positions(self) -> Dict[pd.Timestamp, BasePosition]:
|
||||
return self.hist_positions
|
||||
|
||||
def get_cash(self) -> float:
|
||||
@@ -230,13 +231,15 @@ class Account:
|
||||
"""
|
||||
# update price for stock in the position and the profit from changed_price
|
||||
# NOTE: updating position does not only serve portfolio metrics, it also serve the strategy
|
||||
assert self.current_position is not None
|
||||
|
||||
if not self.current_position.skip_update():
|
||||
stock_list = self.current_position.get_stock_list()
|
||||
for code in stock_list:
|
||||
# if suspend, no new price to be updated, profit is 0
|
||||
if trade_exchange.check_stock_suspended(code, trade_start_time, trade_end_time):
|
||||
continue
|
||||
bar_close = trade_exchange.get_close(code, trade_start_time, trade_end_time)
|
||||
bar_close = cast(float, trade_exchange.get_close(code, trade_start_time, trade_end_time))
|
||||
self.current_position.update_stock_price(stock_id=code, price=bar_close)
|
||||
# update holding day count
|
||||
# NOTE: updating bar_count does not only serve portfolio metrics, it also serve the strategy
|
||||
@@ -249,6 +252,8 @@ class Account:
|
||||
# for the first trade date, account_value - init_cash
|
||||
# self.portfolio_metrics.is_empty() to judge is_first_trade_date
|
||||
# get last_account_value, last_total_cost, last_total_turnover
|
||||
assert self.portfolio_metrics is not None
|
||||
|
||||
if self.portfolio_metrics.is_empty():
|
||||
last_account_value = self.init_cash
|
||||
last_total_cost = 0
|
||||
@@ -299,9 +304,9 @@ class Account:
|
||||
trade_exchange: Exchange,
|
||||
atomic: bool,
|
||||
outer_trade_decision: BaseTradeDecision,
|
||||
trade_info: list = None,
|
||||
inner_order_indicators: List[Dict[str, pd.Series]] = None,
|
||||
decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]] = None,
|
||||
trade_info: list = [],
|
||||
inner_order_indicators: List[BaseOrderIndicator] = [],
|
||||
decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]] = [],
|
||||
indicator_config: dict = {},
|
||||
) -> None:
|
||||
"""update trade indicators and order indicators in each bar end"""
|
||||
@@ -335,9 +340,9 @@ class Account:
|
||||
trade_exchange: Exchange,
|
||||
atomic: bool,
|
||||
outer_trade_decision: BaseTradeDecision,
|
||||
trade_info: list = None,
|
||||
inner_order_indicators: List[Dict[str, pd.Series]] = None,
|
||||
decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]] = None,
|
||||
trade_info: list = [],
|
||||
inner_order_indicators: List[BaseOrderIndicator] = [],
|
||||
decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]] = [],
|
||||
indicator_config: dict = {},
|
||||
) -> None:
|
||||
"""update account at each trading bar step
|
||||
@@ -398,6 +403,7 @@ class Account:
|
||||
def get_portfolio_metrics(self) -> Tuple[pd.DataFrame, dict]:
|
||||
"""get the history portfolio_metrics and positions instance"""
|
||||
if self.is_port_metr_enabled():
|
||||
assert self.portfolio_metrics is not None
|
||||
_portfolio_metrics = self.portfolio_metrics.generate_portfolio_metrics_dataframe()
|
||||
_positions = self.get_hist_positions()
|
||||
return _portfolio_metrics, _positions
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Generator, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Generator, Optional, Tuple, Union, cast
|
||||
|
||||
import pandas as pd
|
||||
|
||||
@@ -36,10 +36,13 @@ def backtest_loop(
|
||||
indicator: Indicator
|
||||
it computes the trading indicator
|
||||
"""
|
||||
return_value = {}
|
||||
return_value: dict = {}
|
||||
for _decision in collect_data_loop(start_time, end_time, trade_strategy, trade_executor, return_value):
|
||||
pass
|
||||
return return_value.get("portfolio_metrics"), return_value.get("indicator")
|
||||
|
||||
portfolio_metrics = cast(PortfolioMetrics, return_value.get("portfolio_metrics"))
|
||||
indicator = cast(Indicator, return_value.get("indicator"))
|
||||
return portfolio_metrics, indicator
|
||||
|
||||
|
||||
def collect_data_loop(
|
||||
|
||||
@@ -4,10 +4,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import abstractmethod
|
||||
from datetime import time
|
||||
from enum import IntEnum
|
||||
|
||||
# try to fix circular imports when enabling type hints
|
||||
from typing import TYPE_CHECKING, ClassVar, List, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Generic, List, Optional, Tuple, TypeVar, Union, cast
|
||||
|
||||
from qlib.backtest.utils import TradeCalendarManager
|
||||
from qlib.data.data import Cal
|
||||
@@ -23,9 +24,11 @@ from dataclasses import dataclass
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
DecisionType = TypeVar("DecisionType")
|
||||
|
||||
|
||||
class OrderDir(IntEnum):
|
||||
# Order direction
|
||||
# Order direction
|
||||
SELL = 0
|
||||
BUY = 1
|
||||
|
||||
@@ -65,7 +68,7 @@ class Order:
|
||||
# - not tradable: the deal_amount == 0 , factor is None
|
||||
# - the stock is suspended and the entire order fails. No cost for this order
|
||||
# - dealt or partially dealt: deal_amount >= 0 and factor is not None
|
||||
deal_amount: Optional[float] = None # `deal_amount` is a non-negative value
|
||||
deal_amount: float = 0.0 # `deal_amount` is a non-negative value
|
||||
factor: Optional[float] = None
|
||||
|
||||
# TODO:
|
||||
@@ -179,8 +182,8 @@ class OrderHelper:
|
||||
return Order(
|
||||
stock_id=code,
|
||||
amount=amount,
|
||||
start_time=start_time if start_time is not None else pd.Timestamp(start_time),
|
||||
end_time=end_time if end_time is not None else pd.Timestamp(end_time),
|
||||
start_time=None if start_time is None else pd.Timestamp(start_time),
|
||||
end_time=None if end_time is None else pd.Timestamp(end_time),
|
||||
direction=direction,
|
||||
)
|
||||
|
||||
@@ -246,7 +249,7 @@ class IdxTradeRange(TradeRange):
|
||||
class TradeRangeByTime(TradeRange):
|
||||
"""This is a helper function for make decisions"""
|
||||
|
||||
def __init__(self, start_time: str, end_time: str) -> None:
|
||||
def __init__(self, start_time: str | time, end_time: str | time) -> None:
|
||||
"""
|
||||
This is a callable class.
|
||||
|
||||
@@ -256,13 +259,13 @@ class TradeRangeByTime(TradeRange):
|
||||
|
||||
Parameters
|
||||
----------
|
||||
start_time : str
|
||||
start_time : str | time
|
||||
e.g. "9:30"
|
||||
end_time : str
|
||||
end_time : str | time
|
||||
e.g. "14:30"
|
||||
"""
|
||||
self.start_time = pd.Timestamp(start_time).time()
|
||||
self.end_time = pd.Timestamp(end_time).time()
|
||||
self.start_time = pd.Timestamp(start_time).time() if isinstance(start_time, str) else start_time
|
||||
self.end_time = pd.Timestamp(end_time).time() if isinstance(end_time, str) else end_time
|
||||
assert self.start_time < self.end_time
|
||||
|
||||
def __call__(self, trade_calendar: TradeCalendarManager) -> Tuple[int, int]:
|
||||
@@ -281,7 +284,7 @@ class TradeRangeByTime(TradeRange):
|
||||
return max(val_start, start_time), min(val_end, end_time)
|
||||
|
||||
|
||||
class BaseTradeDecision:
|
||||
class BaseTradeDecision(Generic[DecisionType]):
|
||||
"""
|
||||
Trade decisions ara made by strategy and executed by executor
|
||||
|
||||
@@ -316,20 +319,21 @@ class BaseTradeDecision:
|
||||
"""
|
||||
self.strategy = strategy
|
||||
self.start_time, self.end_time = strategy.trade_calendar.get_step_time()
|
||||
self.total_step = None # upper strategy has no knowledge about the sub executor before `_init_sub_trading`
|
||||
if isinstance(trade_range, Tuple):
|
||||
# upper strategy has no knowledge about the sub executor before `_init_sub_trading`
|
||||
self.total_step: Optional[int] = None
|
||||
if isinstance(trade_range, tuple):
|
||||
# for Tuple[int, int]
|
||||
trade_range = IdxTradeRange(*trade_range)
|
||||
self.trade_range: TradeRange = trade_range
|
||||
self.trade_range: Optional[TradeRange] = trade_range
|
||||
|
||||
def get_decision(self) -> List[object]:
|
||||
def get_decision(self) -> List[DecisionType]:
|
||||
"""
|
||||
get the **concrete decision** (e.g. execution orders)
|
||||
This will be called by the inner strategy
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[object]:
|
||||
List[DecisionType:
|
||||
The decision result. Typically it is some orders
|
||||
Example:
|
||||
[]:
|
||||
@@ -363,13 +367,13 @@ class BaseTradeDecision:
|
||||
# purpose 2)
|
||||
return self.strategy.update_trade_decision(self, trade_calendar)
|
||||
|
||||
def _get_range_limit(self, **kwargs) -> Tuple[int, int]:
|
||||
def _get_range_limit(self, **kwargs: Any) -> Tuple[int, int]:
|
||||
if self.trade_range is not None:
|
||||
return self.trade_range(trade_calendar=kwargs.get("inner_calendar"))
|
||||
return self.trade_range(trade_calendar=cast(TradeCalendarManager, kwargs.get("inner_calendar")))
|
||||
else:
|
||||
raise NotImplementedError("The decision didn't provide an index range")
|
||||
|
||||
def get_range_limit(self, **kwargs) -> Tuple[int, int]:
|
||||
def get_range_limit(self, **kwargs: Any) -> Tuple[int, int]:
|
||||
"""
|
||||
return the expected step range for limiting the decision execution time
|
||||
Both left and right are **closed**
|
||||
@@ -421,6 +425,7 @@ class BaseTradeDecision:
|
||||
if getattr(self, "total_step", None) is not None:
|
||||
# if `self.update` is called.
|
||||
# Then the _start_idx, _end_idx should be clipped
|
||||
assert self.total_step is not None
|
||||
if _start_idx < 0 or _end_idx >= self.total_step:
|
||||
logger = get_module_logger("decision")
|
||||
logger.warning(
|
||||
@@ -516,7 +521,7 @@ class BaseTradeDecision:
|
||||
inner_trade_decision.trade_range = self.trade_range
|
||||
|
||||
|
||||
class EmptyTradeDecision(BaseTradeDecision):
|
||||
class EmptyTradeDecision(BaseTradeDecision[object]):
|
||||
def get_decision(self) -> List[object]:
|
||||
return []
|
||||
|
||||
@@ -524,23 +529,29 @@ class EmptyTradeDecision(BaseTradeDecision):
|
||||
return True
|
||||
|
||||
|
||||
class TradeDecisionWO(BaseTradeDecision):
|
||||
class TradeDecisionWO(BaseTradeDecision[Order]):
|
||||
"""
|
||||
Trade Decision (W)ith (O)rder.
|
||||
Besides, the time_range is also included.
|
||||
"""
|
||||
|
||||
def __init__(self, order_list: List[Order], strategy: BaseStrategy, trade_range: Tuple[int, int] = None):
|
||||
def __init__(
|
||||
self,
|
||||
order_list: List[Order],
|
||||
strategy: BaseStrategy,
|
||||
trade_range: Union[Tuple[int, int], TradeRange] = None,
|
||||
) -> None:
|
||||
super().__init__(strategy, trade_range=trade_range)
|
||||
self.order_list = order_list
|
||||
self.order_list = cast(List[Order], order_list)
|
||||
start, end = strategy.trade_calendar.get_step_time()
|
||||
for o in order_list:
|
||||
assert isinstance(o, Order)
|
||||
if o.start_time is None:
|
||||
o.start_time = start
|
||||
if o.end_time is None:
|
||||
o.end_time = end
|
||||
|
||||
def get_decision(self) -> List[object]:
|
||||
def get_decision(self) -> List[Order]:
|
||||
return self.order_list
|
||||
|
||||
def __repr__(self) -> str:
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple, Type, Union
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union, cast
|
||||
|
||||
from ..utils.index_data import IndexData
|
||||
|
||||
@@ -32,7 +32,7 @@ class Exchange:
|
||||
start_time: Union[pd.Timestamp, str] = None,
|
||||
end_time: Union[pd.Timestamp, str] = None,
|
||||
codes: Union[list, str] = "all",
|
||||
deal_price: Union[str, Tuple[str], List[str]] = None,
|
||||
deal_price: Union[str, Tuple[str, str], List[str]] = None,
|
||||
subscribe_fields: list = [],
|
||||
limit_threshold: Union[Tuple[str, str], float, None] = None,
|
||||
volume_threshold: Union[tuple, dict] = None,
|
||||
@@ -42,7 +42,7 @@ class Exchange:
|
||||
impact_cost: float = 0.0,
|
||||
extra_quote: pd.DataFrame = None,
|
||||
quote_cls: Type[BaseQuote] = NumpyQuote,
|
||||
**kwargs,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""__init__
|
||||
:param freq: frequency of data
|
||||
@@ -141,7 +141,7 @@ class Exchange:
|
||||
if limit_threshold is None:
|
||||
if C.region == REG_CN:
|
||||
self.logger.warning(f"limit_threshold not set. The stocks hit the limit may be bought/sold")
|
||||
elif self.limit_type == self.LT_FLT and abs(limit_threshold) > 0.1:
|
||||
elif self.limit_type == self.LT_FLT and abs(cast(float, limit_threshold)) > 0.1:
|
||||
if C.region == REG_CN:
|
||||
self.logger.warning(f"limit_threshold may not be set to a reasonable value")
|
||||
|
||||
@@ -150,7 +150,7 @@ class Exchange:
|
||||
deal_price = "$" + deal_price
|
||||
self.buy_price = self.sell_price = deal_price
|
||||
elif isinstance(deal_price, (tuple, list)):
|
||||
self.buy_price, self.sell_price = deal_price
|
||||
self.buy_price, self.sell_price = cast(Tuple[str, str], deal_price)
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
|
||||
@@ -167,10 +167,10 @@ class Exchange:
|
||||
|
||||
necessary_fields = {self.buy_price, self.sell_price, "$close", "$change", "$factor", "$volume"}
|
||||
if self.limit_type == self.LT_TP_EXP:
|
||||
assert isinstance(limit_threshold, tuple)
|
||||
for exp in limit_threshold:
|
||||
necessary_fields.add(exp)
|
||||
all_fields = necessary_fields | set(vol_lt_fields)
|
||||
all_fields = list(all_fields | set(subscribe_fields))
|
||||
all_fields = list(necessary_fields | set(vol_lt_fields) | set(subscribe_fields))
|
||||
|
||||
self.all_fields = all_fields
|
||||
|
||||
@@ -249,9 +249,9 @@ class Exchange:
|
||||
LT_FLT = "float" # float
|
||||
LT_NONE = "none" # none
|
||||
|
||||
def _get_limit_type(self, limit_threshold: Union[Tuple, float, None]) -> str:
|
||||
def _get_limit_type(self, limit_threshold: Union[tuple, float, None]) -> str:
|
||||
"""get limit type"""
|
||||
if isinstance(limit_threshold, Tuple):
|
||||
if isinstance(limit_threshold, tuple):
|
||||
return self.LT_TP_EXP
|
||||
elif isinstance(limit_threshold, float):
|
||||
return self.LT_FLT
|
||||
@@ -268,14 +268,16 @@ class Exchange:
|
||||
self.quote_df["limit_sell"] = False
|
||||
elif limit_type == self.LT_TP_EXP:
|
||||
# set limit
|
||||
limit_threshold = cast(tuple, limit_threshold)
|
||||
self.quote_df["limit_buy"] = self.quote_df[limit_threshold[0]]
|
||||
self.quote_df["limit_sell"] = self.quote_df[limit_threshold[1]]
|
||||
elif limit_type == self.LT_FLT:
|
||||
limit_threshold = cast(float, limit_threshold)
|
||||
self.quote_df["limit_buy"] = self.quote_df["$change"].ge(limit_threshold)
|
||||
self.quote_df["limit_sell"] = self.quote_df["$change"].le(-limit_threshold) # pylint: disable=E1130
|
||||
|
||||
@staticmethod
|
||||
def _get_vol_limit(volume_threshold: Union[tuple, dict]) -> Tuple[Optional[list], Optional[list], set]:
|
||||
def _get_vol_limit(volume_threshold: Union[tuple, dict, None]) -> Tuple[Optional[list], Optional[list], set]:
|
||||
"""
|
||||
preprocess the volume limit.
|
||||
get the fields need to get from qlib.
|
||||
@@ -340,11 +342,11 @@ class Exchange:
|
||||
if direction is None:
|
||||
buy_limit = self.quote.get_data(stock_id, start_time, end_time, field="limit_buy", method="all")
|
||||
sell_limit = self.quote.get_data(stock_id, start_time, end_time, field="limit_sell", method="all")
|
||||
return buy_limit or sell_limit
|
||||
return bool(buy_limit or sell_limit)
|
||||
elif direction == Order.BUY:
|
||||
return self.quote.get_data(stock_id, start_time, end_time, field="limit_buy", method="all")
|
||||
return cast(bool, self.quote.get_data(stock_id, start_time, end_time, field="limit_buy", method="all"))
|
||||
elif direction == Order.SELL:
|
||||
return self.quote.get_data(stock_id, start_time, end_time, field="limit_sell", method="all")
|
||||
return cast(bool, self.quote.get_data(stock_id, start_time, end_time, field="limit_sell", method="all"))
|
||||
else:
|
||||
raise ValueError(f"direction {direction} is not supported!")
|
||||
|
||||
@@ -382,7 +384,7 @@ class Exchange:
|
||||
order: Order,
|
||||
trade_account: Account = None,
|
||||
position: BasePosition = None,
|
||||
dealt_order_amount: defaultdict = defaultdict(float),
|
||||
dealt_order_amount: Dict[str, float] = defaultdict(float),
|
||||
) -> Tuple[float, float, float]:
|
||||
"""
|
||||
Deal order when the actual transaction
|
||||
@@ -426,9 +428,10 @@ class Exchange:
|
||||
stock_id: str,
|
||||
start_time: pd.Timestamp,
|
||||
end_time: pd.Timestamp,
|
||||
field: str,
|
||||
method: str = "ts_data_last",
|
||||
) -> Union[None, int, float, bool, IndexData]:
|
||||
return self.quote.get_data(stock_id, start_time, end_time, method=method) # TODO: missing `field`?
|
||||
return self.quote.get_data(stock_id, start_time, end_time, field=field, method=method)
|
||||
|
||||
def get_close(
|
||||
self,
|
||||
@@ -444,8 +447,8 @@ class Exchange:
|
||||
stock_id: str,
|
||||
start_time: pd.Timestamp,
|
||||
end_time: pd.Timestamp,
|
||||
method: str = "sum",
|
||||
) -> float:
|
||||
method: Optional[str] = "sum",
|
||||
) -> Union[None, int, float, bool, IndexData]:
|
||||
"""get the total deal volume of stock with `stock_id` between the time interval [start_time, end_time)"""
|
||||
return self.quote.get_data(stock_id, start_time, end_time, field="$volume", method=method)
|
||||
|
||||
@@ -455,8 +458,8 @@ class Exchange:
|
||||
start_time: pd.Timestamp,
|
||||
end_time: pd.Timestamp,
|
||||
direction: OrderDir,
|
||||
method: str = "ts_data_last",
|
||||
) -> float:
|
||||
method: Optional[str] = "ts_data_last",
|
||||
) -> Union[None, int, float, bool, IndexData]:
|
||||
if direction == OrderDir.SELL:
|
||||
pstr = self.sell_price
|
||||
elif direction == OrderDir.BUY:
|
||||
@@ -544,7 +547,7 @@ class Exchange:
|
||||
)
|
||||
return amount_dict
|
||||
|
||||
def get_real_deal_amount(self, current_amount: float, target_amount: float, factor: float) -> float:
|
||||
def get_real_deal_amount(self, current_amount: float, target_amount: float, factor: float = None) -> float:
|
||||
"""
|
||||
Calculate the real adjust deal amount when considering the trading unit
|
||||
:param current_amount:
|
||||
@@ -572,7 +575,7 @@ class Exchange:
|
||||
current_position: dict,
|
||||
start_time: pd.Timestamp,
|
||||
end_time: pd.Timestamp,
|
||||
) -> list:
|
||||
) -> List[Order]:
|
||||
"""
|
||||
Note: some future information is used in this function
|
||||
Parameter:
|
||||
@@ -681,6 +684,7 @@ class Exchange:
|
||||
factor = self.get_factor(stock_id=stock_id, start_time=start_time, end_time=end_time)
|
||||
else:
|
||||
raise ValueError(f"`factor` and (`stock_id`, `start_time`, `end_time`) can't both be None")
|
||||
assert factor is not None
|
||||
return factor
|
||||
|
||||
def get_amount_of_trade_unit(
|
||||
@@ -718,12 +722,12 @@ class Exchange:
|
||||
|
||||
def round_amount_by_trade_unit(
|
||||
self,
|
||||
deal_amount,
|
||||
deal_amount: float,
|
||||
factor: float = None,
|
||||
stock_id: str = None,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
):
|
||||
start_time: pd.Timestamp = None,
|
||||
end_time: pd.Timestamp = None,
|
||||
) -> float:
|
||||
"""Parameter
|
||||
Please refer to the docs of get_amount_of_trade_unit
|
||||
deal_amount : float, adjusted amount
|
||||
@@ -741,7 +745,7 @@ class Exchange:
|
||||
return (deal_amount * factor + 0.1) // self.trade_unit * self.trade_unit / factor
|
||||
return deal_amount
|
||||
|
||||
def _clip_amount_by_volume(self, order: Order, dealt_order_amount: dict) -> int:
|
||||
def _clip_amount_by_volume(self, order: Order, dealt_order_amount: dict) -> Optional[float]:
|
||||
"""parse the capacity limit string and return the actual amount of orders that can be executed.
|
||||
NOTE:
|
||||
this function will change the order.deal_amount **inplace**
|
||||
@@ -753,15 +757,12 @@ class Exchange:
|
||||
dealt_order_amount : dict
|
||||
:param dealt_order_amount: the dealt order amount dict with the format of {stock_id: float}
|
||||
"""
|
||||
if order.direction == Order.BUY:
|
||||
vol_limit = self.buy_vol_limit
|
||||
elif order.direction == Order.SELL:
|
||||
vol_limit = self.sell_vol_limit
|
||||
vol_limit = self.buy_vol_limit if order.direction == Order.BUY else self.sell_vol_limit
|
||||
|
||||
if vol_limit is None:
|
||||
return order.deal_amount
|
||||
|
||||
vol_limit_num = []
|
||||
vol_limit_num: List[float] = []
|
||||
for limit in vol_limit:
|
||||
assert isinstance(limit, tuple)
|
||||
if limit[0] == "current":
|
||||
@@ -772,7 +773,7 @@ class Exchange:
|
||||
field=limit[1],
|
||||
method="sum",
|
||||
)
|
||||
vol_limit_num.append(limit_value)
|
||||
vol_limit_num.append(cast(float, limit_value))
|
||||
elif limit[0] == "cum":
|
||||
limit_value = self.quote.get_data(
|
||||
order.stock_id,
|
||||
@@ -790,12 +791,14 @@ class Exchange:
|
||||
if vol_limit_min < orig_deal_amount:
|
||||
self.logger.debug(f"Order clipped due to volume limitation: {order}, {list(zip(vol_limit_num, vol_limit))}")
|
||||
|
||||
def _get_buy_amount_by_cash_limit(self, trade_price, cash, cost_ratio):
|
||||
return None
|
||||
|
||||
def _get_buy_amount_by_cash_limit(self, trade_price: float, cash: float, cost_ratio: float) -> float:
|
||||
"""return the real order amount after cash limit for buying.
|
||||
Parameters
|
||||
----------
|
||||
trade_price : float
|
||||
position : cash
|
||||
cash : float
|
||||
cost_ratio : float
|
||||
|
||||
Return
|
||||
@@ -803,7 +806,7 @@ class Exchange:
|
||||
float
|
||||
the real order amount after cash limit for buying.
|
||||
"""
|
||||
max_trade_amount = 0
|
||||
max_trade_amount = 0.0
|
||||
if cash >= self.min_cost:
|
||||
# critical_price means the stock transaction price when the service fee is equal to min_cost.
|
||||
critical_price = self.min_cost / cost_ratio + self.min_cost
|
||||
@@ -829,8 +832,11 @@ class Exchange:
|
||||
:param dealt_order_amount: the dealt order amount dict with the format of {stock_id: float}
|
||||
:return: trade_price, trade_val, trade_cost
|
||||
"""
|
||||
trade_price = self.get_deal_price(order.stock_id, order.start_time, order.end_time, direction=order.direction)
|
||||
total_trade_val = self.get_volume(order.stock_id, order.start_time, order.end_time) * trade_price
|
||||
trade_price = cast(
|
||||
float,
|
||||
self.get_deal_price(order.stock_id, order.start_time, order.end_time, direction=order.direction),
|
||||
)
|
||||
total_trade_val = cast(float, self.get_volume(order.stock_id, order.start_time, order.end_time)) * trade_price
|
||||
order.factor = self.get_factor(order.stock_id, order.start_time, order.end_time)
|
||||
order.deal_amount = order.amount # set to full amount and clip it step by step
|
||||
# Clipping amount first
|
||||
@@ -897,7 +903,7 @@ class Exchange:
|
||||
order.deal_amount = self.round_amount_by_trade_unit(order.deal_amount, order.factor)
|
||||
|
||||
else:
|
||||
raise NotImplementedError("order type {} error".format(order.type))
|
||||
raise NotImplementedError("order direction {} error".format(order.direction))
|
||||
|
||||
trade_val = order.deal_amount * trade_price
|
||||
trade_cost = max(trade_val * cost_ratio, self.min_cost)
|
||||
|
||||
@@ -4,7 +4,7 @@ import copy
|
||||
from abc import abstractmethod
|
||||
from collections import defaultdict
|
||||
from types import GeneratorType
|
||||
from typing import Generator, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, Generator, List, Tuple, Union, cast
|
||||
|
||||
import pandas as pd
|
||||
|
||||
@@ -16,13 +16,7 @@ from ..strategy.base import BaseStrategy
|
||||
from ..utils import init_instance_by_config
|
||||
from .decision import BaseTradeDecision, Order
|
||||
from .exchange import Exchange
|
||||
from .utils import (
|
||||
BaseInfrastructure,
|
||||
CommonInfrastructure,
|
||||
LevelInfrastructure,
|
||||
TradeCalendarManager,
|
||||
get_start_end_idx,
|
||||
)
|
||||
from .utils import CommonInfrastructure, LevelInfrastructure, TradeCalendarManager, get_start_end_idx
|
||||
|
||||
|
||||
class BaseExecutor:
|
||||
@@ -39,8 +33,8 @@ class BaseExecutor:
|
||||
track_data: bool = False,
|
||||
trade_exchange: Exchange = None,
|
||||
common_infra: CommonInfrastructure = None,
|
||||
settle_type=BasePosition.ST_NO, # TODO: add typehint
|
||||
**kwargs,
|
||||
settle_type: str = BasePosition.ST_NO,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Parameters
|
||||
@@ -127,10 +121,10 @@ class BaseExecutor:
|
||||
get_module_logger("BaseExecutor").warning(f"`common_infra` is not set for {self}")
|
||||
|
||||
# record deal order amount in one day
|
||||
self.dealt_order_amount = defaultdict(float)
|
||||
self.dealt_order_amount: Dict[str, float] = defaultdict(float)
|
||||
self.deal_day = None
|
||||
|
||||
def reset_common_infra(self, common_infra: BaseInfrastructure, copy_trade_account: bool = False) -> None:
|
||||
def reset_common_infra(self, common_infra: CommonInfrastructure, copy_trade_account: bool = False) -> None:
|
||||
"""
|
||||
reset infrastructure for trading
|
||||
- reset trade_account
|
||||
@@ -141,14 +135,15 @@ class BaseExecutor:
|
||||
self.common_infra.update(common_infra)
|
||||
|
||||
if common_infra.has("trade_account"):
|
||||
if copy_trade_account:
|
||||
# NOTE: there is a trick in the code.
|
||||
# shallow copy is used instead of deepcopy.
|
||||
# 1. So positions are shared
|
||||
# 2. Others are not shared, so each level has it own metrics (portfolio and trading metrics)
|
||||
self.trade_account: Account = copy.copy(common_infra.get("trade_account"))
|
||||
else:
|
||||
self.trade_account: Account = common_infra.get("trade_account")
|
||||
# NOTE: there is a trick in the code.
|
||||
# shallow copy is used instead of deepcopy.
|
||||
# 1. So positions are shared
|
||||
# 2. Others are not shared, so each level has it own metrics (portfolio and trading metrics)
|
||||
self.trade_account: Account = (
|
||||
copy.copy(common_infra.get("trade_account"))
|
||||
if copy_trade_account
|
||||
else common_infra.get("trade_account")
|
||||
)
|
||||
self.trade_account.reset(freq=self.time_per_step, port_metr_enabled=self.generate_portfolio_metrics)
|
||||
|
||||
@property
|
||||
@@ -164,7 +159,7 @@ class BaseExecutor:
|
||||
"""
|
||||
return self.level_infra.get("trade_calendar")
|
||||
|
||||
def reset(self, common_infra: CommonInfrastructure = None, **kwargs) -> None:
|
||||
def reset(self, common_infra: CommonInfrastructure = None, **kwargs: Any) -> None:
|
||||
"""
|
||||
- reset `start_time` and `end_time`, used in trade calendar
|
||||
- reset `common_infra`, used to reset `trade_account`, `trade_exchange`, .etc
|
||||
@@ -200,20 +195,17 @@ class BaseExecutor:
|
||||
execute_result : List[object]
|
||||
the executed result for trade decision
|
||||
"""
|
||||
return_value = {}
|
||||
return_value: dict = {}
|
||||
for _decision in self.collect_data(trade_decision, return_value=return_value, level=level):
|
||||
pass
|
||||
return return_value.get("execute_result")
|
||||
return cast(list, return_value.get("execute_result"))
|
||||
|
||||
@abstractmethod
|
||||
def _collect_data(
|
||||
self,
|
||||
trade_decision: BaseTradeDecision,
|
||||
level: int = 0,
|
||||
) -> Union[
|
||||
Generator[BaseTradeDecision, Optional[BaseTradeDecision], Tuple[List[object], dict]],
|
||||
Tuple[List[object], dict],
|
||||
]:
|
||||
) -> Union[Generator[Any, Any, Tuple[List[object], dict]], Tuple[List[object], dict]]:
|
||||
"""
|
||||
Please refer to the doc of collect_data
|
||||
The only difference between `_collect_data` and `collect_data` is that some common steps are moved into
|
||||
@@ -235,7 +227,7 @@ class BaseExecutor:
|
||||
trade_decision: BaseTradeDecision,
|
||||
return_value: dict = None,
|
||||
level: int = 0,
|
||||
) -> Generator[BaseTradeDecision, Optional[BaseTradeDecision], List[object]]:
|
||||
) -> Generator[Any, Any, List[object]]:
|
||||
"""Generator for collecting the trade decision data for rl training
|
||||
|
||||
his function will make a step forward
|
||||
@@ -332,7 +324,7 @@ class NestedExecutor(BaseExecutor):
|
||||
skip_empty_decision: bool = True,
|
||||
align_range_limit: bool = True,
|
||||
common_infra: CommonInfrastructure = None,
|
||||
**kwargs,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Parameters
|
||||
@@ -411,7 +403,7 @@ class NestedExecutor(BaseExecutor):
|
||||
self,
|
||||
trade_decision: BaseTradeDecision,
|
||||
level: int = 0,
|
||||
) -> Generator[BaseTradeDecision, Optional[BaseTradeDecision], Tuple[List[object], dict]]:
|
||||
) -> Generator[Any, Any, Tuple[List[object], dict]]:
|
||||
execute_result = []
|
||||
inner_order_indicators = []
|
||||
decision_list = []
|
||||
@@ -492,8 +484,9 @@ class NestedExecutor(BaseExecutor):
|
||||
inner_exe_res :
|
||||
the execution result of inner task
|
||||
"""
|
||||
self.inner_strategy.post_exe_step(inner_exe_res)
|
||||
|
||||
def get_all_executors(self) -> List[object]:
|
||||
def get_all_executors(self) -> List[BaseExecutor]:
|
||||
"""get all executors, including self and inner_executor.get_all_executors()"""
|
||||
return [self, *self.inner_executor.get_all_executors()]
|
||||
|
||||
@@ -536,7 +529,7 @@ class SimulatorExecutor(BaseExecutor):
|
||||
track_data: bool = False,
|
||||
common_infra: CommonInfrastructure = None,
|
||||
trade_type: str = TT_SERIAL,
|
||||
**kwargs,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Parameters
|
||||
@@ -598,7 +591,7 @@ class SimulatorExecutor(BaseExecutor):
|
||||
|
||||
def _collect_data(self, trade_decision: BaseTradeDecision, level: int = 0) -> Tuple[List[object], dict]:
|
||||
trade_start_time, _ = self.trade_calendar.get_step_time()
|
||||
execute_result = []
|
||||
execute_result: list = []
|
||||
|
||||
for order in self._get_order_iterator(trade_decision):
|
||||
# execute the order.
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import logging
|
||||
from collections import OrderedDict
|
||||
from functools import lru_cache
|
||||
from typing import Callable, Dict, Iterable, List, Text, Union
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Text, Union, cast
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
@@ -19,7 +21,7 @@ from ..utils.time import Freq, is_single_value
|
||||
|
||||
|
||||
class BaseQuote:
|
||||
def __init__(self, quote_df: pd.DataFrame, freq):
|
||||
def __init__(self, quote_df: pd.DataFrame, freq: str) -> None:
|
||||
self.logger = get_module_logger("online operator", level=logging.INFO)
|
||||
|
||||
def get_all_stock(self) -> Iterable:
|
||||
@@ -39,7 +41,7 @@ class BaseQuote:
|
||||
start_time: Union[pd.Timestamp, str],
|
||||
end_time: Union[pd.Timestamp, str],
|
||||
field: Union[str],
|
||||
method: Union[str, None] = None,
|
||||
method: Optional[str] = None,
|
||||
) -> Union[None, int, float, bool, IndexData]:
|
||||
"""get the specific field of stock data during start time and end_time,
|
||||
and apply method to the data.
|
||||
@@ -99,7 +101,7 @@ class BaseQuote:
|
||||
|
||||
|
||||
class PandasQuote(BaseQuote):
|
||||
def __init__(self, quote_df: pd.DataFrame, freq):
|
||||
def __init__(self, quote_df: pd.DataFrame, freq: str) -> None:
|
||||
super().__init__(quote_df=quote_df, freq=freq)
|
||||
quote_dict = {}
|
||||
for stock_id, stock_val in quote_df.groupby(level="instrument"):
|
||||
@@ -124,7 +126,7 @@ class PandasQuote(BaseQuote):
|
||||
|
||||
|
||||
class NumpyQuote(BaseQuote):
|
||||
def __init__(self, quote_df: pd.DataFrame, freq, region="cn"):
|
||||
def __init__(self, quote_df: pd.DataFrame, freq: str, region: str = "cn") -> None:
|
||||
"""NumpyQuote
|
||||
|
||||
Parameters
|
||||
@@ -178,7 +180,8 @@ class NumpyQuote(BaseQuote):
|
||||
data = self._agg_data(data, method)
|
||||
return data
|
||||
|
||||
def _agg_data(self, data: IndexData, method):
|
||||
@staticmethod
|
||||
def _agg_data(data: IndexData, method: str) -> Union[IndexData, np.ndarray, None]:
|
||||
"""Agg data by specific method."""
|
||||
# FIXME: why not call the method of data directly?
|
||||
if method == "sum":
|
||||
@@ -224,31 +227,31 @@ class BaseSingleMetric:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `__init__` method")
|
||||
|
||||
def __add__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric":
|
||||
def __add__(self, other: Union[BaseSingleMetric, int, float]) -> BaseSingleMetric:
|
||||
raise NotImplementedError(f"Please implement the `__add__` method")
|
||||
|
||||
def __radd__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric":
|
||||
def __radd__(self, other: Union[BaseSingleMetric, int, float]) -> BaseSingleMetric:
|
||||
return self + other
|
||||
|
||||
def __sub__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric":
|
||||
def __sub__(self, other: Union[BaseSingleMetric, int, float]) -> BaseSingleMetric:
|
||||
raise NotImplementedError(f"Please implement the `__sub__` method")
|
||||
|
||||
def __rsub__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric":
|
||||
def __rsub__(self, other: Union[BaseSingleMetric, int, float]) -> BaseSingleMetric:
|
||||
raise NotImplementedError(f"Please implement the `__rsub__` method")
|
||||
|
||||
def __mul__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric":
|
||||
def __mul__(self, other: Union[BaseSingleMetric, int, float]) -> BaseSingleMetric:
|
||||
raise NotImplementedError(f"Please implement the `__mul__` method")
|
||||
|
||||
def __truediv__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric":
|
||||
def __truediv__(self, other: Union[BaseSingleMetric, int, float]) -> BaseSingleMetric:
|
||||
raise NotImplementedError(f"Please implement the `__truediv__` method")
|
||||
|
||||
def __eq__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric":
|
||||
def __eq__(self, other: object) -> BaseSingleMetric:
|
||||
raise NotImplementedError(f"Please implement the `__eq__` method")
|
||||
|
||||
def __gt__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric":
|
||||
def __gt__(self, other: Union[BaseSingleMetric, int, float]) -> BaseSingleMetric:
|
||||
raise NotImplementedError(f"Please implement the `__gt__` method")
|
||||
|
||||
def __lt__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric":
|
||||
def __lt__(self, other: Union[BaseSingleMetric, int, float]) -> BaseSingleMetric:
|
||||
raise NotImplementedError(f"Please implement the `__lt__` method")
|
||||
|
||||
def __len__(self) -> int:
|
||||
@@ -265,7 +268,7 @@ class BaseSingleMetric:
|
||||
|
||||
raise NotImplementedError(f"Please implement the `count` method")
|
||||
|
||||
def abs(self) -> "BaseSingleMetric":
|
||||
def abs(self) -> BaseSingleMetric:
|
||||
raise NotImplementedError(f"Please implement the `abs` method")
|
||||
|
||||
@property
|
||||
@@ -274,18 +277,18 @@ class BaseSingleMetric:
|
||||
|
||||
raise NotImplementedError(f"Please implement the `empty` method")
|
||||
|
||||
def add(self, other: "BaseSingleMetric", fill_value: float = None) -> "BaseSingleMetric":
|
||||
def add(self, other: BaseSingleMetric, fill_value: float = None) -> BaseSingleMetric:
|
||||
"""Replace np.NaN with fill_value in two metrics and add them."""
|
||||
|
||||
raise NotImplementedError(f"Please implement the `add` method")
|
||||
|
||||
def replace(self, replace_dict: dict) -> "BaseSingleMetric":
|
||||
def replace(self, replace_dict: dict) -> BaseSingleMetric:
|
||||
"""Replace the value of metric according to replace_dict."""
|
||||
|
||||
raise NotImplementedError(f"Please implement the `replace` method")
|
||||
|
||||
def apply(self, func: dict) -> "BaseSingleMetric":
|
||||
"""Replace the value of metric with func(metric).
|
||||
def apply(self, func: Callable) -> BaseSingleMetric:
|
||||
"""Replace the value of metric with func (metric).
|
||||
Currently, the func is only qlib/backtest/order/Order.parse_dir.
|
||||
"""
|
||||
|
||||
@@ -304,11 +307,11 @@ class BaseOrderIndicator:
|
||||
to inherit the BaseSingleMetric.
|
||||
"""
|
||||
|
||||
def __init__(self, data):
|
||||
self.data = data
|
||||
def __init__(self):
|
||||
self.data = {} # will be created in the subclass
|
||||
self.logger = get_module_logger("online operator")
|
||||
|
||||
def assign(self, col: str, metric: Union[dict, pd.Series]):
|
||||
def assign(self, col: str, metric: Union[dict, pd.Series]) -> None:
|
||||
"""assign one metric.
|
||||
|
||||
Parameters
|
||||
@@ -328,7 +331,7 @@ class BaseOrderIndicator:
|
||||
|
||||
raise NotImplementedError(f"Please implement the 'assign' method")
|
||||
|
||||
def transfer(self, func: Callable, new_col: str = None) -> Union[None, BaseSingleMetric]:
|
||||
def transfer(self, func: Callable, new_col: str = None) -> Optional[BaseSingleMetric]:
|
||||
"""compute new metric with existing metrics.
|
||||
|
||||
Parameters
|
||||
@@ -352,6 +355,7 @@ class BaseOrderIndicator:
|
||||
tmp_metric = func(**func_kwargs)
|
||||
if new_col is not None:
|
||||
self.data[new_col] = tmp_metric
|
||||
return None
|
||||
else:
|
||||
return tmp_metric
|
||||
|
||||
@@ -372,7 +376,7 @@ class BaseOrderIndicator:
|
||||
|
||||
raise NotImplementedError(f"Please implement the 'get_metric_series' method")
|
||||
|
||||
def get_index_data(self, metric) -> SingleData:
|
||||
def get_index_data(self, metric: str) -> SingleData:
|
||||
"""get one metric with the format of SingleData
|
||||
|
||||
Parameters
|
||||
@@ -389,7 +393,12 @@ class BaseOrderIndicator:
|
||||
raise NotImplementedError(f"Please implement the 'get_index_data' method")
|
||||
|
||||
@staticmethod
|
||||
def sum_all_indicators(order_indicator, indicators: list, metrics: Union[str, List[str]], fill_value: float = None):
|
||||
def sum_all_indicators(
|
||||
order_indicator: BaseOrderIndicator,
|
||||
indicators: List[BaseOrderIndicator],
|
||||
metrics: Union[str, List[str]],
|
||||
fill_value: float = 0,
|
||||
) -> None:
|
||||
"""sum indicators with the same metrics.
|
||||
and assign to the order_indicator(BaseOrderIndicator).
|
||||
NOTE: indicators could be a empty list when orders in lower level all fail.
|
||||
@@ -527,16 +536,17 @@ class PandasSingleMetric(SingleMetric):
|
||||
def index(self):
|
||||
return list(self.metric.index)
|
||||
|
||||
def add(self, other, fill_value=None):
|
||||
def add(self, other: BaseSingleMetric, fill_value: float = None) -> PandasSingleMetric:
|
||||
other = cast(PandasSingleMetric, other)
|
||||
return self.__class__(self.metric.add(other.metric, fill_value=fill_value))
|
||||
|
||||
def replace(self, replace_dict: dict):
|
||||
def replace(self, replace_dict: dict) -> PandasSingleMetric:
|
||||
return self.__class__(self.metric.replace(replace_dict))
|
||||
|
||||
def apply(self, func: Callable):
|
||||
def apply(self, func: Callable) -> PandasSingleMetric:
|
||||
return self.__class__(self.metric.apply(func))
|
||||
|
||||
def reindex(self, index, fill_value):
|
||||
def reindex(self, index: Any, fill_value: float) -> PandasSingleMetric:
|
||||
return self.__class__(self.metric.reindex(index, fill_value=fill_value))
|
||||
|
||||
def __repr__(self):
|
||||
@@ -550,13 +560,14 @@ class PandasOrderIndicator(BaseOrderIndicator):
|
||||
Str is the name of metric.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super(PandasOrderIndicator, self).__init__()
|
||||
self.data: Dict[str, PandasSingleMetric] = OrderedDict()
|
||||
|
||||
def assign(self, col: str, metric: Union[dict, pd.Series]):
|
||||
def assign(self, col: str, metric: Union[dict, pd.Series]) -> None:
|
||||
self.data[col] = PandasSingleMetric(metric)
|
||||
|
||||
def get_index_data(self, metric):
|
||||
def get_index_data(self, metric: str) -> SingleData:
|
||||
if metric in self.data:
|
||||
return idd.SingleData(self.data[metric].metric)
|
||||
else:
|
||||
@@ -572,7 +583,12 @@ class PandasOrderIndicator(BaseOrderIndicator):
|
||||
return {k: v.metric for k, v in self.data.items()}
|
||||
|
||||
@staticmethod
|
||||
def sum_all_indicators(order_indicator, indicators: list, metrics: Union[str, List[str]], fill_value=0):
|
||||
def sum_all_indicators(
|
||||
order_indicator: BaseOrderIndicator,
|
||||
indicators: List[BaseOrderIndicator],
|
||||
metrics: Union[str, List[str]],
|
||||
fill_value: float = 0,
|
||||
) -> None:
|
||||
if isinstance(metrics, str):
|
||||
metrics = [metrics]
|
||||
for metric in metrics:
|
||||
@@ -592,13 +608,14 @@ class NumpyOrderIndicator(BaseOrderIndicator):
|
||||
Str is the name of metric.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super(NumpyOrderIndicator, self).__init__()
|
||||
self.data: Dict[str, SingleData] = OrderedDict()
|
||||
|
||||
def assign(self, col: str, metric: dict):
|
||||
def assign(self, col: str, metric: dict) -> None:
|
||||
self.data[col] = idd.SingleData(metric)
|
||||
|
||||
def get_index_data(self, metric):
|
||||
def get_index_data(self, metric: str) -> SingleData:
|
||||
if metric in self.data:
|
||||
return self.data[metric]
|
||||
else:
|
||||
@@ -614,14 +631,18 @@ class NumpyOrderIndicator(BaseOrderIndicator):
|
||||
return tmp_metric_dict
|
||||
|
||||
@staticmethod
|
||||
def sum_all_indicators(order_indicator, indicators: list, metrics: Union[str, List[str]], fill_value=0):
|
||||
def sum_all_indicators(
|
||||
order_indicator: BaseOrderIndicator,
|
||||
indicators: List[BaseOrderIndicator],
|
||||
metrics: Union[str, List[str]],
|
||||
fill_value: float = 0,
|
||||
) -> None:
|
||||
# get all index(stock_id)
|
||||
stocks = set()
|
||||
stock_set: set = set()
|
||||
for indicator in indicators:
|
||||
# set(np.ndarray.tolist()) is faster than set(np.ndarray)
|
||||
stocks = stocks | set(indicator.data[metrics[0]].index.tolist())
|
||||
stocks = list(stocks)
|
||||
stocks.sort()
|
||||
stock_set = stock_set | set(indicator.data[metrics[0]].index.tolist())
|
||||
stocks = sorted(list(stock_set))
|
||||
|
||||
# add metric by index
|
||||
if isinstance(metrics, str):
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
|
||||
from datetime import timedelta
|
||||
from typing import Dict, List, Union
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
@@ -18,9 +18,9 @@ class BasePosition:
|
||||
Please refer to the `Position` class for the position
|
||||
"""
|
||||
|
||||
def __init__(self, *args, cash: float = 0.0, **kwargs) -> None:
|
||||
def __init__(self, *args: Any, cash: float = 0.0, **kwargs: Any) -> None:
|
||||
self._settle_type = self.ST_NO
|
||||
self.position = {}
|
||||
self.position: dict = {}
|
||||
|
||||
def fill_stock_value(self, start_time: Union[str, pd.Timestamp], freq: str, last_days: int = 30) -> None:
|
||||
pass
|
||||
@@ -96,13 +96,13 @@ class BasePosition:
|
||||
def calculate_value(self) -> float:
|
||||
raise NotImplementedError(f"Please implement the `calculate_value` method")
|
||||
|
||||
def get_stock_list(self) -> List:
|
||||
def get_stock_list(self) -> List[str]:
|
||||
"""
|
||||
Get the list of stocks in the position.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `get_stock_list` method")
|
||||
|
||||
def get_stock_price(self, code) -> float:
|
||||
def get_stock_price(self, code: str) -> float:
|
||||
"""
|
||||
get the latest price of the stock
|
||||
|
||||
@@ -113,7 +113,7 @@ class BasePosition:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `get_stock_price` method")
|
||||
|
||||
def get_stock_amount(self, code) -> float:
|
||||
def get_stock_amount(self, code: str) -> float:
|
||||
"""
|
||||
get the amount of the stock
|
||||
|
||||
@@ -144,7 +144,7 @@ class BasePosition:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `get_cash` method")
|
||||
|
||||
def get_stock_amount_dict(self) -> Dict:
|
||||
def get_stock_amount_dict(self) -> dict:
|
||||
"""
|
||||
generate stock amount dict {stock_id : amount of stock}
|
||||
|
||||
@@ -155,7 +155,7 @@ class BasePosition:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `get_stock_amount_dict` method")
|
||||
|
||||
def get_stock_weight_dict(self, only_stock: bool = False) -> Dict:
|
||||
def get_stock_weight_dict(self, only_stock: bool = False) -> dict:
|
||||
"""
|
||||
generate stock weight dict {stock_id : value weight of stock in the position}
|
||||
it is meaningful in the beginning or the end of each trade step
|
||||
@@ -174,7 +174,7 @@ class BasePosition:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `get_stock_weight_dict` method")
|
||||
|
||||
def add_count_all(self, bar) -> None:
|
||||
def add_count_all(self, bar: str) -> None:
|
||||
"""
|
||||
Will be called at the end of each bar on each level
|
||||
|
||||
@@ -195,7 +195,7 @@ class BasePosition:
|
||||
raise NotImplementedError(f"Please implement the `add_count_all` method")
|
||||
|
||||
ST_CASH = "cash"
|
||||
ST_NO = None
|
||||
ST_NO = "None" # String is more typehint friendly than None
|
||||
|
||||
def settle_start(self, settle_type: str) -> None:
|
||||
"""
|
||||
@@ -220,10 +220,10 @@ class BasePosition:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `settle_commit` method")
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return self.__dict__.__str__()
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return self.__dict__.__repr__()
|
||||
|
||||
|
||||
@@ -532,7 +532,7 @@ class InfPosition(BasePosition):
|
||||
def calculate_value(self) -> float:
|
||||
raise NotImplementedError(f"InfPosition doesn't support calculating value")
|
||||
|
||||
def get_stock_list(self) -> list:
|
||||
def get_stock_list(self) -> List[str]:
|
||||
raise NotImplementedError(f"InfPosition doesn't support stock list position")
|
||||
|
||||
def get_stock_price(self, code: str) -> float:
|
||||
@@ -545,10 +545,10 @@ class InfPosition(BasePosition):
|
||||
def get_cash(self, include_settle: bool = False) -> float:
|
||||
return np.inf
|
||||
|
||||
def get_stock_amount_dict(self) -> Dict:
|
||||
def get_stock_amount_dict(self) -> dict:
|
||||
raise NotImplementedError(f"InfPosition doesn't support get_stock_amount_dict")
|
||||
|
||||
def get_stock_weight_dict(self, only_stock: bool = False) -> Dict:
|
||||
def get_stock_weight_dict(self, only_stock: bool = False) -> dict:
|
||||
raise NotImplementedError(f"InfPosition doesn't support get_stock_weight_dict")
|
||||
|
||||
def add_count_all(self, bar: str) -> None:
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
|
||||
import pathlib
|
||||
from collections import OrderedDict
|
||||
from typing import Dict, List, Tuple, Union
|
||||
from typing import Any, Dict, List, Optional, Text, Tuple, Type, Union, cast
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
@@ -15,7 +15,7 @@ from qlib.backtest.exchange import Exchange
|
||||
|
||||
from ..tests.config import CSI300_BENCH
|
||||
from ..utils.resam import get_higher_eq_freq_feature, resam_ts_data
|
||||
from .high_performance_ds import BaseOrderIndicator, NumpyOrderIndicator, SingleMetric
|
||||
from .high_performance_ds import BaseOrderIndicator, BaseSingleMetric, NumpyOrderIndicator
|
||||
|
||||
|
||||
class PortfolioMetrics:
|
||||
@@ -38,7 +38,7 @@ class PortfolioMetrics:
|
||||
update report
|
||||
"""
|
||||
|
||||
def __init__(self, freq: str = "day", benchmark_config: dict = {}):
|
||||
def __init__(self, freq: str = "day", benchmark_config: dict = {}) -> None:
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
@@ -49,13 +49,17 @@ class PortfolioMetrics:
|
||||
- benchmark : Union[str, list, pd.Series]
|
||||
- If `benchmark` is pd.Series, `index` is trading date; the value T is the change from T-1 to T.
|
||||
example:
|
||||
print(D.features(D.instruments('csi500'), ['$close/Ref($close, 1)-1'])['$close/Ref($close, 1)-1'].head())
|
||||
print(
|
||||
D.features(D.instruments('csi500'),
|
||||
['$close/Ref($close, 1)-1'])['$close/Ref($close, 1)-1'].head()
|
||||
)
|
||||
2017-01-04 0.011693
|
||||
2017-01-05 0.000721
|
||||
2017-01-06 -0.004322
|
||||
2017-01-09 0.006874
|
||||
2017-01-10 -0.003350
|
||||
- If `benchmark` is list, will use the daily average change of the stock pool in the list as the 'bench'.
|
||||
- If `benchmark` is list, will use the daily average change of the stock pool in the list as the
|
||||
'bench'.
|
||||
- If `benchmark` is str, will use the daily change as the 'bench'.
|
||||
benchmark code, default is SH000300 CSI300
|
||||
- start_time : Union[str, pd.Timestamp], optional
|
||||
@@ -70,25 +74,26 @@ class PortfolioMetrics:
|
||||
self.init_vars()
|
||||
self.init_bench(freq=freq, benchmark_config=benchmark_config)
|
||||
|
||||
def init_vars(self):
|
||||
self.accounts = OrderedDict() # account position value for each trade time
|
||||
self.returns = OrderedDict() # daily return rate for each trade time
|
||||
self.total_turnovers = OrderedDict() # total turnover for each trade time
|
||||
self.turnovers = OrderedDict() # turnover for each trade time
|
||||
self.total_costs = OrderedDict() # total trade cost for each trade time
|
||||
self.costs = OrderedDict() # trade cost rate for each trade time
|
||||
self.values = OrderedDict() # value for each trade time
|
||||
self.cashes = OrderedDict()
|
||||
self.benches = OrderedDict()
|
||||
self.latest_pm_time = None # pd.TimeStamp
|
||||
def init_vars(self) -> None:
|
||||
self.accounts: dict = OrderedDict() # account position value for each trade time
|
||||
self.returns: dict = OrderedDict() # daily return rate for each trade time
|
||||
self.total_turnovers: dict = OrderedDict() # total turnover for each trade time
|
||||
self.turnovers: dict = OrderedDict() # turnover for each trade time
|
||||
self.total_costs: dict = OrderedDict() # total trade cost for each trade time
|
||||
self.costs: dict = OrderedDict() # trade cost rate for each trade time
|
||||
self.values: dict = OrderedDict() # value for each trade time
|
||||
self.cashes: dict = OrderedDict()
|
||||
self.benches: dict = OrderedDict()
|
||||
self.latest_pm_time: Optional[pd.TimeStamp] = None
|
||||
|
||||
def init_bench(self, freq=None, benchmark_config=None):
|
||||
def init_bench(self, freq: str = None, benchmark_config: dict = None) -> None:
|
||||
if freq is not None:
|
||||
self.freq = freq
|
||||
self.benchmark_config = benchmark_config
|
||||
self.bench = self._cal_benchmark(self.benchmark_config, self.freq)
|
||||
|
||||
def _cal_benchmark(self, benchmark_config, freq):
|
||||
@staticmethod
|
||||
def _cal_benchmark(benchmark_config: Optional[dict], freq: str) -> Optional[pd.Series]:
|
||||
if benchmark_config is None:
|
||||
return None
|
||||
benchmark = benchmark_config.get("benchmark", CSI300_BENCH)
|
||||
@@ -110,7 +115,12 @@ class PortfolioMetrics:
|
||||
raise ValueError(f"The benchmark {_codes} does not exist. Please provide the right benchmark")
|
||||
return _temp_result.groupby(level="datetime")[_temp_result.columns.tolist()[0]].mean().fillna(0)
|
||||
|
||||
def _sample_benchmark(self, bench, trade_start_time, trade_end_time):
|
||||
def _sample_benchmark(
|
||||
self,
|
||||
bench: pd.Series,
|
||||
trade_start_time: Union[str, pd.Timestamp],
|
||||
trade_end_time: Union[str, pd.Timestamp],
|
||||
) -> Optional[float]:
|
||||
if self.bench is None:
|
||||
return None
|
||||
|
||||
@@ -120,35 +130,35 @@ class PortfolioMetrics:
|
||||
_ret = resam_ts_data(bench, trade_start_time, trade_end_time, method=cal_change)
|
||||
return 0.0 if _ret is None else _ret - 1
|
||||
|
||||
def is_empty(self):
|
||||
def is_empty(self) -> bool:
|
||||
return len(self.accounts) == 0
|
||||
|
||||
def get_latest_date(self):
|
||||
def get_latest_date(self) -> pd.Timestamp:
|
||||
return self.latest_pm_time
|
||||
|
||||
def get_latest_account_value(self):
|
||||
def get_latest_account_value(self) -> float:
|
||||
return self.accounts[self.latest_pm_time]
|
||||
|
||||
def get_latest_total_cost(self):
|
||||
def get_latest_total_cost(self) -> Any:
|
||||
return self.total_costs[self.latest_pm_time]
|
||||
|
||||
def get_latest_total_turnover(self):
|
||||
def get_latest_total_turnover(self) -> Any:
|
||||
return self.total_turnovers[self.latest_pm_time]
|
||||
|
||||
def update_portfolio_metrics_record(
|
||||
self,
|
||||
trade_start_time=None,
|
||||
trade_end_time=None,
|
||||
account_value=None,
|
||||
cash=None,
|
||||
return_rate=None,
|
||||
total_turnover=None,
|
||||
turnover_rate=None,
|
||||
total_cost=None,
|
||||
cost_rate=None,
|
||||
stock_value=None,
|
||||
bench_value=None,
|
||||
):
|
||||
trade_start_time: Union[str, pd.Timestamp] = None,
|
||||
trade_end_time: Union[str, pd.Timestamp] = None,
|
||||
account_value: float = None,
|
||||
cash: float = None,
|
||||
return_rate: float = None,
|
||||
total_turnover: float = None,
|
||||
turnover_rate: float = None,
|
||||
total_cost: float = None,
|
||||
cost_rate: float = None,
|
||||
stock_value: float = None,
|
||||
bench_value: float = None,
|
||||
) -> None:
|
||||
# check data
|
||||
if None in [
|
||||
trade_start_time,
|
||||
@@ -185,7 +195,7 @@ class PortfolioMetrics:
|
||||
self.latest_pm_time = trade_start_time
|
||||
# finish pm update in each step
|
||||
|
||||
def generate_portfolio_metrics_dataframe(self):
|
||||
def generate_portfolio_metrics_dataframe(self) -> pd.DataFrame:
|
||||
pm = pd.DataFrame()
|
||||
pm["account"] = pd.Series(self.accounts)
|
||||
pm["return"] = pd.Series(self.returns)
|
||||
@@ -199,19 +209,18 @@ class PortfolioMetrics:
|
||||
pm.index.name = "datetime"
|
||||
return pm
|
||||
|
||||
def save_portfolio_metrics(self, path):
|
||||
def save_portfolio_metrics(self, path: str) -> None:
|
||||
r = self.generate_portfolio_metrics_dataframe()
|
||||
r.to_csv(path)
|
||||
|
||||
def load_portfolio_metrics(self, path):
|
||||
def load_portfolio_metrics(self, path: str) -> None:
|
||||
"""load pm from a file
|
||||
should have format like
|
||||
columns = ['account', 'return', 'total_turnover', 'turnover', 'cost', 'total_cost', 'value', 'cash', 'bench']
|
||||
:param
|
||||
path: str/ pathlib.Path()
|
||||
"""
|
||||
path = pathlib.Path(path)
|
||||
with path.open("rb") as f:
|
||||
with pathlib.Path(path).open("rb") as f:
|
||||
r = pd.read_csv(f, index_col=0)
|
||||
r.index = pd.DatetimeIndex(r.index)
|
||||
|
||||
@@ -261,30 +270,30 @@ class Indicator:
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, order_indicator_cls=NumpyOrderIndicator):
|
||||
def __init__(self, order_indicator_cls: Type[BaseOrderIndicator] = NumpyOrderIndicator) -> None:
|
||||
self.order_indicator_cls = order_indicator_cls
|
||||
|
||||
# order indicator is metrics for a single order for a specific step
|
||||
self.order_indicator_his = OrderedDict()
|
||||
self.order_indicator_his: dict = OrderedDict()
|
||||
self.order_indicator: BaseOrderIndicator = self.order_indicator_cls()
|
||||
|
||||
# trade indicator is metrics for all orders for a specific step
|
||||
self.trade_indicator_his = OrderedDict()
|
||||
self.trade_indicator: Dict[str, float] = OrderedDict()
|
||||
self.trade_indicator_his: dict = OrderedDict()
|
||||
self.trade_indicator: Dict[str, Optional[BaseSingleMetric]] = OrderedDict()
|
||||
|
||||
self._trade_calendar = None
|
||||
|
||||
# def reset(self, trade_calendar: TradeCalendarManager):
|
||||
def reset(self):
|
||||
self.order_indicator: BaseOrderIndicator = self.order_indicator_cls()
|
||||
def reset(self) -> None:
|
||||
self.order_indicator = self.order_indicator_cls()
|
||||
self.trade_indicator = OrderedDict()
|
||||
# self._trade_calendar = trade_calendar
|
||||
|
||||
def record(self, trade_start_time):
|
||||
def record(self, trade_start_time: Union[str, pd.Timestamp]) -> None:
|
||||
self.order_indicator_his[trade_start_time] = self.get_order_indicator()
|
||||
self.trade_indicator_his[trade_start_time] = self.get_trade_indicator()
|
||||
|
||||
def _update_order_trade_info(self, trade_info: list):
|
||||
def _update_order_trade_info(self, trade_info: List[Tuple[Order, float, float, float]]) -> None:
|
||||
amount = dict()
|
||||
deal_amount = dict()
|
||||
trade_price = dict()
|
||||
@@ -313,7 +322,7 @@ class Indicator:
|
||||
self.order_indicator.assign("trade_dir", trade_dir)
|
||||
self.order_indicator.assign("pa", pa)
|
||||
|
||||
def _update_order_fulfill_rate(self):
|
||||
def _update_order_fulfill_rate(self) -> None:
|
||||
def func(deal_amount, amount):
|
||||
# deal_amount is np.NaN or None when there is no inner decision. So full fill rate is 0.
|
||||
tmp_deal_amount = deal_amount.reindex(amount.index, 0)
|
||||
@@ -322,11 +331,11 @@ class Indicator:
|
||||
|
||||
self.order_indicator.transfer(func, "ffr")
|
||||
|
||||
def update_order_indicators(self, trade_info: list):
|
||||
def update_order_indicators(self, trade_info: List[Tuple[Order, float, float, float]]) -> None:
|
||||
self._update_order_trade_info(trade_info=trade_info)
|
||||
self._update_order_fulfill_rate()
|
||||
|
||||
def _agg_order_trade_info(self, inner_order_indicators: List[Dict[str, pd.Series]]):
|
||||
def _agg_order_trade_info(self, inner_order_indicators: List[BaseOrderIndicator]) -> None:
|
||||
# calculate total trade amount with each inner order indicator.
|
||||
def trade_amount_func(deal_amount, trade_price):
|
||||
return deal_amount * trade_price
|
||||
@@ -355,9 +364,9 @@ class Indicator:
|
||||
|
||||
self.order_indicator.transfer(func_apply, "trade_dir")
|
||||
|
||||
def _update_trade_amount(self, outer_trade_decision: BaseTradeDecision):
|
||||
def _update_trade_amount(self, outer_trade_decision: BaseTradeDecision) -> None:
|
||||
# NOTE: these indicator is designed for order execution, so the
|
||||
decision: List[Order] = outer_trade_decision.get_decision()
|
||||
decision: List[Order] = cast(List[Order], outer_trade_decision.get_decision())
|
||||
if len(decision) == 0:
|
||||
self.order_indicator.assign("amount", {})
|
||||
else:
|
||||
@@ -372,7 +381,7 @@ class Indicator:
|
||||
decision: BaseTradeDecision,
|
||||
trade_exchange: Exchange,
|
||||
pa_config: dict = {},
|
||||
):
|
||||
) -> Tuple[Optional[float], Optional[float]]:
|
||||
"""
|
||||
Get the base volume and price information
|
||||
All the base price values are rooted from this function
|
||||
@@ -412,31 +421,35 @@ class Indicator:
|
||||
# NOTE: there are some zeros in the trading price. These cases are known meaningless
|
||||
# for aligning the previous logic, remove it.
|
||||
# remove zero and negative values.
|
||||
price_s = price_s.loc[(price_s > 1e-08).data.astype(np.bool)]
|
||||
assert isinstance(price_s, idd.SingleData)
|
||||
price_s = price_s.loc[(price_s > 1e-08).data.astype(bool)]
|
||||
# NOTE ~(price_s < 1e-08) is different from price_s >= 1e-8
|
||||
# ~(np.NaN < 1e-8) -> ~(False) -> True
|
||||
|
||||
assert isinstance(price_s, idd.SingleData)
|
||||
if agg == "vwap":
|
||||
volume_s = trade_exchange.get_volume(inst, trade_start_time, trade_end_time, method=None)
|
||||
if isinstance(volume_s, (int, float, np.number)):
|
||||
volume_s = idd.SingleData(volume_s, [trade_start_time])
|
||||
assert isinstance(volume_s, idd.SingleData)
|
||||
volume_s = volume_s.reindex(price_s.index)
|
||||
elif agg == "twap":
|
||||
volume_s = idd.SingleData(1, price_s.index)
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
|
||||
assert isinstance(volume_s, idd.SingleData)
|
||||
base_volume = volume_s.sum()
|
||||
base_price = (price_s * volume_s).sum() / base_volume
|
||||
return base_price, base_volume
|
||||
|
||||
def _agg_base_price(
|
||||
self,
|
||||
inner_order_indicators: List[Dict[str, Union[SingleMetric, idd.SingleData]]],
|
||||
inner_order_indicators: List[BaseOrderIndicator],
|
||||
decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]],
|
||||
trade_exchange: Exchange,
|
||||
pa_config: dict = {},
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
# NOTE:!!!!
|
||||
# Strong assumption!!!!!!
|
||||
@@ -444,7 +457,7 @@ class Indicator:
|
||||
|
||||
Parameters
|
||||
----------
|
||||
inner_order_indicators : List[Dict[str, pd.Series]]
|
||||
inner_order_indicators : List[BaseOrderIndicator]
|
||||
the indicators of account of inner executor
|
||||
decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]],
|
||||
a list of decisions according to inner_order_indicators
|
||||
@@ -489,14 +502,17 @@ class Indicator:
|
||||
bv_new = idd.SingleData(bv_new)
|
||||
bp_all.append(bp_new)
|
||||
bv_all.append(bv_new)
|
||||
bp_all = idd.concat(bp_all, axis=1)
|
||||
bv_all = idd.concat(bv_all, axis=1)
|
||||
bp_all_multi_data = idd.concat(bp_all, axis=1)
|
||||
bv_all_multi_data = idd.concat(bv_all, axis=1)
|
||||
|
||||
base_volume = bv_all.sum(axis=1)
|
||||
base_volume = bv_all_multi_data.sum(axis=1)
|
||||
self.order_indicator.assign("base_volume", base_volume.to_dict())
|
||||
self.order_indicator.assign("base_price", ((bp_all * bv_all).sum(axis=1) / base_volume).to_dict())
|
||||
self.order_indicator.assign(
|
||||
"base_price",
|
||||
((bp_all_multi_data * bv_all_multi_data).sum(axis=1) / base_volume).to_dict(),
|
||||
)
|
||||
|
||||
def _agg_order_price_advantage(self):
|
||||
def _agg_order_price_advantage(self) -> None:
|
||||
def if_empty_func(trade_price):
|
||||
return trade_price.empty
|
||||
|
||||
@@ -513,12 +529,12 @@ class Indicator:
|
||||
|
||||
def agg_order_indicators(
|
||||
self,
|
||||
inner_order_indicators: List[Dict[str, pd.Series]],
|
||||
inner_order_indicators: List[BaseOrderIndicator],
|
||||
decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]],
|
||||
outer_trade_decision: BaseTradeDecision,
|
||||
trade_exchange: Exchange,
|
||||
indicator_config={},
|
||||
):
|
||||
indicator_config: dict = {},
|
||||
) -> None:
|
||||
self._agg_order_trade_info(inner_order_indicators)
|
||||
self._update_trade_amount(outer_trade_decision)
|
||||
self._update_order_fulfill_rate()
|
||||
@@ -526,71 +542,66 @@ class Indicator:
|
||||
self._agg_base_price(inner_order_indicators, decision_list, trade_exchange, pa_config=pa_config) # TODO
|
||||
self._agg_order_price_advantage()
|
||||
|
||||
def _cal_trade_fulfill_rate(self, method="mean"):
|
||||
def _cal_trade_fulfill_rate(self, method: str = "mean") -> Optional[BaseSingleMetric]:
|
||||
if method == "mean":
|
||||
|
||||
def func(ffr):
|
||||
return ffr.mean()
|
||||
|
||||
return self.order_indicator.transfer(
|
||||
lambda ffr: ffr.mean(),
|
||||
)
|
||||
elif method == "amount_weighted":
|
||||
|
||||
def func(ffr, deal_amount):
|
||||
return (ffr * deal_amount.abs()).sum() / (deal_amount.abs().sum())
|
||||
|
||||
return self.order_indicator.transfer(
|
||||
lambda ffr, deal_amount: (ffr * deal_amount.abs()).sum() / (deal_amount.abs().sum()),
|
||||
)
|
||||
elif method == "value_weighted":
|
||||
|
||||
def func(ffr, trade_value):
|
||||
return (ffr * trade_value.abs()).sum() / (trade_value.abs().sum())
|
||||
|
||||
return self.order_indicator.transfer(
|
||||
lambda ffr, trade_value: (ffr * trade_value.abs()).sum() / (trade_value.abs().sum()),
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"method {method} is not supported!")
|
||||
return self.order_indicator.transfer(func)
|
||||
|
||||
def _cal_trade_price_advantage(self, method="mean"):
|
||||
def _cal_trade_price_advantage(self, method: str = "mean") -> Optional[BaseSingleMetric]:
|
||||
if method == "mean":
|
||||
|
||||
def func(pa):
|
||||
return pa.mean()
|
||||
|
||||
return self.order_indicator.transfer(lambda pa: pa.mean())
|
||||
elif method == "amount_weighted":
|
||||
|
||||
def func(pa, deal_amount):
|
||||
return (pa * deal_amount.abs()).sum() / (deal_amount.abs().sum())
|
||||
|
||||
return self.order_indicator.transfer(
|
||||
lambda pa, deal_amount: (pa * deal_amount.abs()).sum() / (deal_amount.abs().sum()),
|
||||
)
|
||||
elif method == "value_weighted":
|
||||
|
||||
def func(pa, trade_value):
|
||||
return (pa * trade_value.abs()).sum() / (trade_value.abs().sum())
|
||||
|
||||
return self.order_indicator.transfer(
|
||||
lambda pa, trade_value: (pa * trade_value.abs()).sum() / (trade_value.abs().sum()),
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"method {method} is not supported!")
|
||||
return self.order_indicator.transfer(func)
|
||||
|
||||
def _cal_trade_positive_rate(self):
|
||||
def _cal_trade_positive_rate(self) -> Optional[BaseSingleMetric]:
|
||||
def func(pa):
|
||||
return (pa > 0).sum() / pa.count()
|
||||
|
||||
return self.order_indicator.transfer(func)
|
||||
|
||||
def _cal_deal_amount(self):
|
||||
def _cal_deal_amount(self) -> Optional[BaseSingleMetric]:
|
||||
def func(deal_amount):
|
||||
return deal_amount.abs().sum()
|
||||
|
||||
return self.order_indicator.transfer(func)
|
||||
|
||||
def _cal_trade_value(self):
|
||||
def _cal_trade_value(self) -> Optional[BaseSingleMetric]:
|
||||
def func(trade_value):
|
||||
return trade_value.abs().sum()
|
||||
|
||||
return self.order_indicator.transfer(func)
|
||||
|
||||
def _cal_trade_order_count(self):
|
||||
def _cal_trade_order_count(self) -> Optional[BaseSingleMetric]:
|
||||
def func(amount):
|
||||
return amount.count()
|
||||
|
||||
return self.order_indicator.transfer(func)
|
||||
|
||||
def cal_trade_indicators(self, trade_start_time, freq, indicator_config={}):
|
||||
def cal_trade_indicators(
|
||||
self,
|
||||
trade_start_time: Union[str, pd.Timestamp],
|
||||
freq: str,
|
||||
indicator_config: dict = {},
|
||||
) -> None:
|
||||
show_indicator = indicator_config.get("show_indicator", False)
|
||||
ffr_config = indicator_config.get("ffr_config", {})
|
||||
pa_config = indicator_config.get("pa_config", {})
|
||||
@@ -608,22 +619,22 @@ class Indicator:
|
||||
self.trade_indicator["count"] = order_count
|
||||
if show_indicator:
|
||||
print(
|
||||
"[Indicator({}) {:%Y-%m-%d %H:%M:%S}]: FFR: {}, PA: {}, POS: {}".format(
|
||||
"[Indicator({}) {}]: FFR: {}, PA: {}, POS: {}".format(
|
||||
freq,
|
||||
trade_start_time,
|
||||
trade_start_time
|
||||
if isinstance(trade_start_time, str)
|
||||
else trade_start_time.strftime("%Y-%m-%d %H:%M:%S"),
|
||||
fulfill_rate,
|
||||
price_advantage,
|
||||
positive_rate,
|
||||
),
|
||||
)
|
||||
|
||||
def get_order_indicator(self, raw: bool = True):
|
||||
if raw:
|
||||
return self.order_indicator
|
||||
return self.order_indicator.to_series()
|
||||
def get_order_indicator(self, raw: bool = True) -> Union[BaseOrderIndicator, Dict[Text, pd.Series]]:
|
||||
return self.order_indicator if raw else self.order_indicator.to_series()
|
||||
|
||||
def get_trade_indicator(self):
|
||||
def get_trade_indicator(self) -> Dict[str, Optional[BaseSingleMetric]]:
|
||||
return self.trade_indicator
|
||||
|
||||
def generate_trade_indicators_dataframe(self):
|
||||
def generate_trade_indicators_dataframe(self) -> pd.DataFrame:
|
||||
return pd.DataFrame.from_dict(self.trade_indicator_his, orient="index")
|
||||
|
||||
@@ -22,7 +22,7 @@ class Signal(metaclass=abc.ABCMeta):
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_signal(self, start_time, end_time) -> Union[pd.Series, pd.DataFrame, None]:
|
||||
def get_signal(self, start_time: pd.Timestamp, end_time: pd.Timestamp) -> Union[pd.Series, pd.DataFrame, None]:
|
||||
"""
|
||||
get the signal at the end of the decision step(from `start_time` to `end_time`)
|
||||
|
||||
@@ -39,13 +39,14 @@ class SignalWCache(Signal):
|
||||
SignalWCache will store the prepared signal as a attribute and give the according signal based on input query
|
||||
"""
|
||||
|
||||
def __init__(self, signal: Union[pd.Series, pd.DataFrame]):
|
||||
def __init__(self, signal: Union[pd.Series, pd.DataFrame]) -> None:
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
signal : Union[pd.Series, pd.DataFrame]
|
||||
The expected format of the signal is like the data below (the order of index is not important and can be automatically adjusted)
|
||||
The expected format of the signal is like the data below (the order of index is not important and can be
|
||||
automatically adjusted)
|
||||
|
||||
instrument datetime
|
||||
SH600000 2008-01-02 0.079704
|
||||
@@ -56,8 +57,8 @@ class SignalWCache(Signal):
|
||||
"""
|
||||
self.signal_cache = convert_index_format(signal, level="datetime")
|
||||
|
||||
def get_signal(self, start_time, end_time) -> Union[pd.Series, pd.DataFrame]:
|
||||
# the frequency of the signal may not algin with the decision frequency of strategy
|
||||
def get_signal(self, start_time: pd.Timestamp, end_time: pd.Timestamp) -> Union[pd.Series, pd.DataFrame]:
|
||||
# the frequency of the signal may not align with the decision frequency of strategy
|
||||
# so resampling from the data is necessary
|
||||
# the latest signal leverage more recent data and therefore is used in trading.
|
||||
signal = resam_ts_data(self.signal_cache, start_time=start_time, end_time=end_time, method="last")
|
||||
@@ -65,7 +66,7 @@ class SignalWCache(Signal):
|
||||
|
||||
|
||||
class ModelSignal(SignalWCache):
|
||||
def __init__(self, model: BaseModel, dataset: Dataset):
|
||||
def __init__(self, model: BaseModel, dataset: Dataset) -> None:
|
||||
self.model = model
|
||||
self.dataset = dataset
|
||||
pred_scores = self.model.predict(dataset)
|
||||
@@ -73,7 +74,7 @@ class ModelSignal(SignalWCache):
|
||||
pred_scores = pred_scores.iloc[:, 0]
|
||||
super().__init__(pred_scores)
|
||||
|
||||
def _update_model(self):
|
||||
def _update_model(self) -> None:
|
||||
"""
|
||||
When using online data, update model in each bar as the following steps:
|
||||
- update dataset with online data, the dataset should support online update
|
||||
|
||||
@@ -149,6 +149,8 @@ class TradeCalendarManager:
|
||||
Tuple[int, int]:
|
||||
"""
|
||||
# potential performance issue
|
||||
assert self.level_infra is not None
|
||||
|
||||
day_start = pd.Timestamp(self.start_time.date())
|
||||
day_end = epsilon_change(day_start + pd.Timedelta(days=1))
|
||||
freq = self.level_infra.get("common_infra").get("trade_exchange").freq
|
||||
@@ -182,8 +184,8 @@ class TradeCalendarManager:
|
||||
Tuple[int, int]:
|
||||
the index of the range. **the left and right are closed**
|
||||
"""
|
||||
left = bisect.bisect_right(self._calendar, start_time) - 1
|
||||
right = bisect.bisect_right(self._calendar, end_time) - 1
|
||||
left = bisect.bisect_right(list(self._calendar), start_time) - 1
|
||||
right = bisect.bisect_right(list(self._calendar), end_time) - 1
|
||||
left -= self.start_index
|
||||
right -= self.start_index
|
||||
|
||||
@@ -201,14 +203,14 @@ class TradeCalendarManager:
|
||||
|
||||
|
||||
class BaseInfrastructure:
|
||||
def __init__(self, **kwargs) -> None:
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
self.reset_infra(**kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def get_support_infra(self) -> Set[str]:
|
||||
raise NotImplementedError("`get_support_infra` is not implemented!")
|
||||
|
||||
def reset_infra(self, **kwargs) -> None:
|
||||
def reset_infra(self, **kwargs: Any) -> None:
|
||||
support_infra = self.get_support_infra()
|
||||
for k, v in kwargs.items():
|
||||
if k in support_infra:
|
||||
|
||||
@@ -203,8 +203,14 @@ class MTSDatasetH(DatasetH):
|
||||
|
||||
def _prepare_seg(self, slc, **kwargs):
|
||||
fn = _get_date_parse_fn(self._index[0][1])
|
||||
start_date = fn(slc.start)
|
||||
end_date = fn(slc.stop)
|
||||
if isinstance(slc, slice):
|
||||
start, stop = slc.start, slc.stop
|
||||
elif isinstance(slc, (list, tuple)):
|
||||
start, stop = slc
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
start_date = pd.Timestamp(fn(start))
|
||||
end_date = pd.Timestamp(fn(stop))
|
||||
obj = copy.copy(self) # shallow copy
|
||||
# NOTE: Seriable will disable copy `self._data` so we manually assign them here
|
||||
obj._data = self._data # reference (no copy)
|
||||
|
||||
@@ -259,79 +259,119 @@ class Alpha158(DataHandlerLP):
|
||||
def use(x):
|
||||
return x not in exclude and (include is None or x in include)
|
||||
|
||||
# Some factor ref: https://guorn.com/static/upload/file/3/134065454575605.pdf
|
||||
if use("ROC"):
|
||||
# https://www.investopedia.com/terms/r/rateofchange.asp
|
||||
# Rate of change, the price change in the past d days, divided by latest close price to remove unit
|
||||
fields += ["Ref($close, %d)/$close" % d for d in windows]
|
||||
names += ["ROC%d" % d for d in windows]
|
||||
if use("MA"):
|
||||
# https://www.investopedia.com/ask/answers/071414/whats-difference-between-moving-average-and-weighted-moving-average.asp
|
||||
# Simple Moving Average, the simple moving average in the past d days, divided by latest close price to remove unit
|
||||
fields += ["Mean($close, %d)/$close" % d for d in windows]
|
||||
names += ["MA%d" % d for d in windows]
|
||||
if use("STD"):
|
||||
# The standard diviation of close price for the past d days, divided by latest close price to remove unit
|
||||
fields += ["Std($close, %d)/$close" % d for d in windows]
|
||||
names += ["STD%d" % d for d in windows]
|
||||
if use("BETA"):
|
||||
# The rate of close price change in the past d days, divided by latest close price to remove unit
|
||||
# For example, price increase 10 dollar per day in the past d days, then Slope will be 10.
|
||||
fields += ["Slope($close, %d)/$close" % d for d in windows]
|
||||
names += ["BETA%d" % d for d in windows]
|
||||
if use("RSQR"):
|
||||
# The R-sqaure value of linear regression for the past d days, represent the trend linear
|
||||
fields += ["Rsquare($close, %d)" % d for d in windows]
|
||||
names += ["RSQR%d" % d for d in windows]
|
||||
if use("RESI"):
|
||||
# The redisdual for linear regression for the past d days, represent the trend linearity for past d days.
|
||||
fields += ["Resi($close, %d)/$close" % d for d in windows]
|
||||
names += ["RESI%d" % d for d in windows]
|
||||
if use("MAX"):
|
||||
# The max price for past d days, divided by latest close price to remove unit
|
||||
fields += ["Max($high, %d)/$close" % d for d in windows]
|
||||
names += ["MAX%d" % d for d in windows]
|
||||
if use("LOW"):
|
||||
# The low price for past d days, divided by latest close price to remove unit
|
||||
fields += ["Min($low, %d)/$close" % d for d in windows]
|
||||
names += ["MIN%d" % d for d in windows]
|
||||
if use("QTLU"):
|
||||
# The 80% quantile of past d day's close price, divided by latest close price to remove unit
|
||||
# Used with MIN and MAX
|
||||
fields += ["Quantile($close, %d, 0.8)/$close" % d for d in windows]
|
||||
names += ["QTLU%d" % d for d in windows]
|
||||
if use("QTLD"):
|
||||
# The 20% quantile of past d day's close price, divided by latest close price to remove unit
|
||||
fields += ["Quantile($close, %d, 0.2)/$close" % d for d in windows]
|
||||
names += ["QTLD%d" % d for d in windows]
|
||||
if use("RANK"):
|
||||
# Get the percentile of current close price in past d day's close price.
|
||||
# Represent the current price level comparing to past N days, add additional information to moving average.
|
||||
fields += ["Rank($close, %d)" % d for d in windows]
|
||||
names += ["RANK%d" % d for d in windows]
|
||||
if use("RSV"):
|
||||
# Represent the price position between upper and lower resistent price for past d days.
|
||||
fields += ["($close-Min($low, %d))/(Max($high, %d)-Min($low, %d)+1e-12)" % (d, d, d) for d in windows]
|
||||
names += ["RSV%d" % d for d in windows]
|
||||
if use("IMAX"):
|
||||
# The number of days between current date and previous highest price date.
|
||||
# Part of Aroon Indicator https://www.investopedia.com/terms/a/aroon.asp
|
||||
# The indicator measures the time between highs and the time between lows over a time period.
|
||||
# The idea is that strong uptrends will regularly see new highs, and strong downtrends will regularly see new lows.
|
||||
fields += ["IdxMax($high, %d)/%d" % (d, d) for d in windows]
|
||||
names += ["IMAX%d" % d for d in windows]
|
||||
if use("IMIN"):
|
||||
# The number of days between current date and previous lowest price date.
|
||||
# Part of Aroon Indicator https://www.investopedia.com/terms/a/aroon.asp
|
||||
# The indicator measures the time between highs and the time between lows over a time period.
|
||||
# The idea is that strong uptrends will regularly see new highs, and strong downtrends will regularly see new lows.
|
||||
fields += ["IdxMin($low, %d)/%d" % (d, d) for d in windows]
|
||||
names += ["IMIN%d" % d for d in windows]
|
||||
if use("IMXD"):
|
||||
# The time period between previous lowest-price date occur after highest price date.
|
||||
# Large value suggest downward momemtum.
|
||||
fields += ["(IdxMax($high, %d)-IdxMin($low, %d))/%d" % (d, d, d) for d in windows]
|
||||
names += ["IMXD%d" % d for d in windows]
|
||||
if use("CORR"):
|
||||
# The correlation between absolute close price and log scaled trading volume
|
||||
fields += ["Corr($close, Log($volume+1), %d)" % d for d in windows]
|
||||
names += ["CORR%d" % d for d in windows]
|
||||
if use("CORD"):
|
||||
# The correlation between price change ratio and volume change ratio
|
||||
fields += ["Corr($close/Ref($close,1), Log($volume/Ref($volume, 1)+1), %d)" % d for d in windows]
|
||||
names += ["CORD%d" % d for d in windows]
|
||||
if use("CNTP"):
|
||||
# The percentage of days in past d days that price go up.
|
||||
fields += ["Mean($close>Ref($close, 1), %d)" % d for d in windows]
|
||||
names += ["CNTP%d" % d for d in windows]
|
||||
if use("CNTN"):
|
||||
# The percentage of days in past d days that price go down.
|
||||
fields += ["Mean($close<Ref($close, 1), %d)" % d for d in windows]
|
||||
names += ["CNTN%d" % d for d in windows]
|
||||
if use("CNTD"):
|
||||
# The diff between past up day and past down day
|
||||
fields += ["Mean($close>Ref($close, 1), %d)-Mean($close<Ref($close, 1), %d)" % (d, d) for d in windows]
|
||||
names += ["CNTD%d" % d for d in windows]
|
||||
if use("SUMP"):
|
||||
# The total gain / the absolute total price changed
|
||||
# Similar to RSI indicator. https://www.investopedia.com/terms/r/rsi.asp
|
||||
fields += [
|
||||
"Sum(Greater($close-Ref($close, 1), 0), %d)/(Sum(Abs($close-Ref($close, 1)), %d)+1e-12)" % (d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["SUMP%d" % d for d in windows]
|
||||
if use("SUMN"):
|
||||
# The total lose / the absolute total price changed
|
||||
# Can be derived from SUMP by SUMN = 1 - SUMP
|
||||
# Similar to RSI indicator. https://www.investopedia.com/terms/r/rsi.asp
|
||||
fields += [
|
||||
"Sum(Greater(Ref($close, 1)-$close, 0), %d)/(Sum(Abs($close-Ref($close, 1)), %d)+1e-12)" % (d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["SUMN%d" % d for d in windows]
|
||||
if use("SUMD"):
|
||||
# The diff ratio between total gain and total lose
|
||||
# Similar to RSI indicator. https://www.investopedia.com/terms/r/rsi.asp
|
||||
fields += [
|
||||
"(Sum(Greater($close-Ref($close, 1), 0), %d)-Sum(Greater(Ref($close, 1)-$close, 0), %d))"
|
||||
"/(Sum(Abs($close-Ref($close, 1)), %d)+1e-12)" % (d, d, d)
|
||||
@@ -339,12 +379,15 @@ class Alpha158(DataHandlerLP):
|
||||
]
|
||||
names += ["SUMD%d" % d for d in windows]
|
||||
if use("VMA"):
|
||||
# Simple Volume Moving average: https://www.barchart.com/education/technical-indicators/volume_moving_average
|
||||
fields += ["Mean($volume, %d)/($volume+1e-12)" % d for d in windows]
|
||||
names += ["VMA%d" % d for d in windows]
|
||||
if use("VSTD"):
|
||||
# The standard deviation for volume in past d days.
|
||||
fields += ["Std($volume, %d)/($volume+1e-12)" % d for d in windows]
|
||||
names += ["VSTD%d" % d for d in windows]
|
||||
if use("WVMA"):
|
||||
# The volume weighted price change volatility
|
||||
fields += [
|
||||
"Std(Abs($close/Ref($close, 1)-1)*$volume, %d)/(Mean(Abs($close/Ref($close, 1)-1)*$volume, %d)+1e-12)"
|
||||
% (d, d)
|
||||
@@ -352,6 +395,7 @@ class Alpha158(DataHandlerLP):
|
||||
]
|
||||
names += ["WVMA%d" % d for d in windows]
|
||||
if use("VSUMP"):
|
||||
# The total volume increase / the absolute total volume changed
|
||||
fields += [
|
||||
"Sum(Greater($volume-Ref($volume, 1), 0), %d)/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)"
|
||||
% (d, d)
|
||||
@@ -359,6 +403,8 @@ class Alpha158(DataHandlerLP):
|
||||
]
|
||||
names += ["VSUMP%d" % d for d in windows]
|
||||
if use("VSUMN"):
|
||||
# The total volume increase / the absolute total volume changed
|
||||
# Can be derived from VSUMP by VSUMN = 1 - VSUMP
|
||||
fields += [
|
||||
"Sum(Greater(Ref($volume, 1)-$volume, 0), %d)/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)"
|
||||
% (d, d)
|
||||
@@ -366,6 +412,8 @@ class Alpha158(DataHandlerLP):
|
||||
]
|
||||
names += ["VSUMN%d" % d for d in windows]
|
||||
if use("VSUMD"):
|
||||
# The diff ratio between total volume increase and total volume decrease
|
||||
# RSI indicator for volume
|
||||
fields += [
|
||||
"(Sum(Greater($volume-Ref($volume, 1), 0), %d)-Sum(Greater(Ref($volume, 1)-$volume, 0), %d))"
|
||||
"/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)" % (d, d, d)
|
||||
|
||||
@@ -137,8 +137,7 @@ class HighFreqBacktestHandler(DataHandler):
|
||||
names = []
|
||||
|
||||
template_if = "If(IsNull({1}), {0}, {1})"
|
||||
template_paused = "Select(Gt($hx_paused_num, 1.001), {0})"
|
||||
# template_paused = "{0}"
|
||||
template_paused = "Select(Gt($paused_num, 1.001), {0})"
|
||||
template_fillnan = "FFillNan({0})"
|
||||
fields += [
|
||||
template_fillnan.format(template_paused.format("$close")),
|
||||
@@ -162,3 +161,249 @@ class HighFreqBacktestHandler(DataHandler):
|
||||
names += ["$factor0"]
|
||||
|
||||
return fields, names
|
||||
|
||||
|
||||
class HighFreqOrderHandler(DataHandlerLP):
|
||||
def __init__(
|
||||
self,
|
||||
instruments="csi300",
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
infer_processors=[],
|
||||
learn_processors=[],
|
||||
fit_start_time=None,
|
||||
fit_end_time=None,
|
||||
drop_raw=True,
|
||||
):
|
||||
def check_transform_proc(proc_l):
|
||||
new_l = []
|
||||
for p in proc_l:
|
||||
p["kwargs"].update(
|
||||
{
|
||||
"fit_start_time": fit_start_time,
|
||||
"fit_end_time": fit_end_time,
|
||||
}
|
||||
)
|
||||
new_l.append(p)
|
||||
return new_l
|
||||
|
||||
infer_processors = check_transform_proc(infer_processors)
|
||||
learn_processors = check_transform_proc(learn_processors)
|
||||
|
||||
data_loader = {
|
||||
"class": "QlibDataLoader",
|
||||
"kwargs": {
|
||||
"config": self.get_feature_config(),
|
||||
"swap_level": False,
|
||||
"freq": "1min",
|
||||
},
|
||||
}
|
||||
super().__init__(
|
||||
instruments=instruments,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
data_loader=data_loader,
|
||||
infer_processors=infer_processors,
|
||||
learn_processors=learn_processors,
|
||||
drop_raw=drop_raw,
|
||||
)
|
||||
|
||||
def get_feature_config(self):
|
||||
fields = []
|
||||
names = []
|
||||
|
||||
template_if = "If(IsNull({1}), {0}, {1})"
|
||||
template_ifinf = "If(IsInf({1}), {0}, {1})"
|
||||
template_paused = "Select(Gt($paused_num, 1.001), {0})"
|
||||
|
||||
def get_normalized_price_feature(price_field, shift=0):
|
||||
# norm with the close price of 237th minute of yesterday.
|
||||
if shift == 0:
|
||||
template_norm = "{0}/DayLast(Ref({1}, 243))"
|
||||
else:
|
||||
template_norm = "Ref({0}, " + str(shift) + ")/DayLast(Ref({1}, 243))"
|
||||
|
||||
template_fillnan = "FFillNan({0})"
|
||||
# calculate -> ffill -> remove paused
|
||||
feature_ops = template_paused.format(
|
||||
template_fillnan.format(
|
||||
template_norm.format(template_if.format("$close", price_field), template_fillnan.format("$close"))
|
||||
)
|
||||
)
|
||||
return feature_ops
|
||||
|
||||
def get_normalized_vwap_price_feature(price_field, shift=0):
|
||||
# norm with the close price of 237th minute of yesterday.
|
||||
if shift == 0:
|
||||
template_norm = "{0}/DayLast(Ref({1}, 243))"
|
||||
else:
|
||||
template_norm = "Ref({0}, " + str(shift) + ")/DayLast(Ref({1}, 243))"
|
||||
|
||||
template_fillnan = "FFillNan({0})"
|
||||
# calculate -> ffill -> remove paused
|
||||
feature_ops = template_paused.format(
|
||||
template_fillnan.format(
|
||||
template_norm.format(
|
||||
template_if.format("$close", template_ifinf.format("$close", price_field)),
|
||||
template_fillnan.format("$close"),
|
||||
)
|
||||
)
|
||||
)
|
||||
return feature_ops
|
||||
|
||||
fields += [get_normalized_price_feature("$open", 0)]
|
||||
fields += [get_normalized_price_feature("$high", 0)]
|
||||
fields += [get_normalized_price_feature("$low", 0)]
|
||||
fields += [get_normalized_price_feature("$close", 0)]
|
||||
fields += [get_normalized_vwap_price_feature("$vwap", 0)]
|
||||
names += ["$open", "$high", "$low", "$close", "$vwap"]
|
||||
|
||||
fields += [get_normalized_price_feature("$open", 240)]
|
||||
fields += [get_normalized_price_feature("$high", 240)]
|
||||
fields += [get_normalized_price_feature("$low", 240)]
|
||||
fields += [get_normalized_price_feature("$close", 240)]
|
||||
fields += [get_normalized_vwap_price_feature("$vwap", 240)]
|
||||
names += ["$open_1", "$high_1", "$low_1", "$close_1", "$vwap_1"]
|
||||
|
||||
fields += [get_normalized_price_feature("$bid", 0)]
|
||||
fields += [get_normalized_price_feature("$ask", 0)]
|
||||
names += ["$bid", "$ask"]
|
||||
|
||||
fields += [get_normalized_price_feature("$bid", 240)]
|
||||
fields += [get_normalized_price_feature("$ask", 240)]
|
||||
names += ["$bid_1", "$ask_1"]
|
||||
|
||||
# calculate and fill nan with 0
|
||||
|
||||
def get_volume_feature(volume_field, shift=0):
|
||||
template_gzero = "If(Ge({0}, 0), {0}, 0)"
|
||||
if shift == 0:
|
||||
feature_ops = template_gzero.format(
|
||||
template_paused.format(
|
||||
"If(IsInf({0}), 0, {0})".format(
|
||||
"If(IsNull({0}), 0, {0})".format(
|
||||
"{0}/Ref(DayLast(Mean({0}, 7200)), 240)".format(volume_field)
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
else:
|
||||
feature_ops = template_gzero.format(
|
||||
template_paused.format(
|
||||
"If(IsInf({0}), 0, {0})".format(
|
||||
"If(IsNull({0}), 0, {0})".format(
|
||||
f"Ref({{0}}, {shift})/Ref(DayLast(Mean({{0}}, 7200)), 240)".format(volume_field)
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
return feature_ops
|
||||
|
||||
fields += [get_volume_feature("$volume", 0)]
|
||||
names += ["$volume"]
|
||||
|
||||
fields += [get_volume_feature("$volume", 240)]
|
||||
names += ["$volume_1"]
|
||||
|
||||
fields += [get_volume_feature("$bidV", 0)]
|
||||
fields += [get_volume_feature("$bidV1", 0)]
|
||||
fields += [get_volume_feature("$bidV3", 0)]
|
||||
fields += [get_volume_feature("$bidV5", 0)]
|
||||
fields += [get_volume_feature("$askV", 0)]
|
||||
fields += [get_volume_feature("$askV1", 0)]
|
||||
fields += [get_volume_feature("$askV3", 0)]
|
||||
fields += [get_volume_feature("$askV5", 0)]
|
||||
names += ["$bidV", "$bidV1", "$bidV3", "$bidV5", "$askV", "$askV1", "$askV3", "$askV5"]
|
||||
|
||||
fields += [get_volume_feature("$bidV", 240)]
|
||||
fields += [get_volume_feature("$bidV1", 240)]
|
||||
fields += [get_volume_feature("$bidV3", 240)]
|
||||
fields += [get_volume_feature("$bidV5", 240)]
|
||||
fields += [get_volume_feature("$askV", 240)]
|
||||
fields += [get_volume_feature("$askV1", 240)]
|
||||
fields += [get_volume_feature("$askV3", 240)]
|
||||
fields += [get_volume_feature("$askV5", 240)]
|
||||
names += ["$bidV_1", "$bidV1_1", "$bidV3_1", "$bidV5_1", "$askV_1", "$askV1_1", "$askV3_1", "$askV5_1"]
|
||||
|
||||
return fields, names
|
||||
|
||||
|
||||
class HighFreqBacktestOrderHandler(DataHandler):
|
||||
def __init__(
|
||||
self,
|
||||
instruments="csi300",
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
):
|
||||
data_loader = {
|
||||
"class": "QlibDataLoader",
|
||||
"kwargs": {
|
||||
"config": self.get_feature_config(),
|
||||
"swap_level": False,
|
||||
"freq": "1min",
|
||||
},
|
||||
}
|
||||
super().__init__(
|
||||
instruments=instruments,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
data_loader=data_loader,
|
||||
)
|
||||
|
||||
def get_feature_config(self):
|
||||
fields = []
|
||||
names = []
|
||||
|
||||
template_if = "If(IsNull({1}), {0}, {1})"
|
||||
template_paused = "Select(Gt($hx_paused_num, 1.001), {0})"
|
||||
# template_paused = "{0}"
|
||||
template_fillnan = "FFillNan({0})"
|
||||
fields += [
|
||||
template_fillnan.format(template_paused.format("$close")),
|
||||
]
|
||||
names += ["$close0"]
|
||||
|
||||
fields += [
|
||||
template_paused.format(
|
||||
template_if.format(
|
||||
template_fillnan.format("$close"),
|
||||
"$vwap",
|
||||
)
|
||||
)
|
||||
]
|
||||
names += ["$vwap0"]
|
||||
|
||||
fields += [template_paused.format("If(IsNull({0}), 0, {0})".format("$volume"))]
|
||||
names += ["$volume0"]
|
||||
|
||||
fields += [template_paused.format("If(IsNull({0}), 0, {0})".format("$bid"))]
|
||||
names += ["$bid0"]
|
||||
|
||||
fields += [template_paused.format("If(IsNull({0}), 0, {0})".format("$bidV"))]
|
||||
names += ["$bidV0"]
|
||||
|
||||
fields += [template_paused.format("If(IsNull({0}), 0, {0})".format("$ask"))]
|
||||
names += ["$ask0"]
|
||||
|
||||
fields += [template_paused.format("If(IsNull({0}), 0, {0})".format("$askV"))]
|
||||
names += ["$askV0"]
|
||||
|
||||
fields += [template_paused.format("If(IsNull({0}), 0, {0})".format("($bid + $ask) / 2"))]
|
||||
names += ["$median0"]
|
||||
|
||||
fields += [template_paused.format("If(IsNull({0}), 0, {0})".format("$factor"))]
|
||||
names += ["$factor0"]
|
||||
|
||||
fields += [template_paused.format("If(IsNull({0}), 0, {0})".format("$downlimitmarket"))]
|
||||
names += ["$downlimitmarket0"]
|
||||
|
||||
fields += [template_paused.format("If(IsNull({0}), 0, {0})".format("$uplimitmarket"))]
|
||||
names += ["$uplimitmarket0"]
|
||||
|
||||
fields += [template_paused.format("If(IsNull({0}), 0, {0})".format("$highmarket"))]
|
||||
names += ["$highmarket0"]
|
||||
|
||||
fields += [template_paused.format("If(IsNull({0}), 0, {0})".format("$lowmarket"))]
|
||||
names += ["$lowmarket0"]
|
||||
|
||||
return fields, names
|
||||
|
||||
@@ -339,7 +339,7 @@ def long_short_backtest(
|
||||
for stock in long_stocks:
|
||||
if not trade_exchange.is_stock_tradable(stock_id=stock, trade_date=date):
|
||||
continue
|
||||
profit = trade_exchange.get_quote_info(stock_id=stock, trade_date=date)[profit_str]
|
||||
profit = trade_exchange.get_quote_info(stock_id=stock, start_time=date, end_time=date, field=profit_str)
|
||||
if np.isnan(profit):
|
||||
long_profit.append(0)
|
||||
else:
|
||||
@@ -348,17 +348,17 @@ def long_short_backtest(
|
||||
for stock in short_stocks:
|
||||
if not trade_exchange.is_stock_tradable(stock_id=stock, trade_date=date):
|
||||
continue
|
||||
profit = trade_exchange.get_quote_info(stock_id=stock, trade_date=date)[profit_str]
|
||||
profit = trade_exchange.get_quote_info(stock_id=stock, start_time=date, end_time=date, field=profit_str)
|
||||
if np.isnan(profit):
|
||||
short_profit.append(0)
|
||||
else:
|
||||
short_profit.append(-profit)
|
||||
short_profit.append(profit * -1)
|
||||
|
||||
for stock in list(score.loc(axis=0)[pdate, :].index.get_level_values(level=0)):
|
||||
# exclude the suspend stock
|
||||
if trade_exchange.check_stock_suspended(stock_id=stock, trade_date=date):
|
||||
continue
|
||||
profit = trade_exchange.get_quote_info(stock_id=stock, trade_date=date)[profit_str]
|
||||
profit = trade_exchange.get_quote_info(stock_id=stock, start_time=date, end_time=date, field=profit_str)
|
||||
if np.isnan(profit):
|
||||
all_profit.append(0)
|
||||
else:
|
||||
|
||||
@@ -44,7 +44,7 @@ class DEnsembleModel(Model, FeatureInt):
|
||||
if sample_ratios is None: # the default values for sample_ratios
|
||||
sample_ratios = [0.8, 0.7, 0.6, 0.5, 0.4]
|
||||
if sub_weights is None: # the default values for sub_weights
|
||||
sub_weights = [1.0, 0.2, 0.2, 0.2, 0.2, 0.2]
|
||||
sub_weights = [1] * self.num_models
|
||||
if not len(sample_ratios) == bins_fs:
|
||||
raise ValueError("The length of sample_ratios should be equal to bins_fs.")
|
||||
self.sample_ratios = sample_ratios
|
||||
@@ -87,7 +87,9 @@ class DEnsembleModel(Model, FeatureInt):
|
||||
loss_curve = self.retrieve_loss_curve(model_k, df_train, features)
|
||||
pred_k = self.predict_sub(model_k, df_train, features)
|
||||
pred_sub.iloc[:, k] = pred_k
|
||||
pred_ensemble = pred_sub.iloc[:, : k + 1].mean(axis=1)
|
||||
pred_ensemble = (pred_sub.iloc[:, : k + 1] * self.sub_weights[0 : k + 1]).sum(axis=1) / np.sum(
|
||||
self.sub_weights[0 : k + 1]
|
||||
)
|
||||
loss_values = pd.Series(self.get_loss(y_train.values.squeeze(), pred_ensemble.values))
|
||||
|
||||
if self.enable_sr:
|
||||
@@ -159,8 +161,8 @@ class DEnsembleModel(Model, FeatureInt):
|
||||
h["bins"] = pd.cut(h["h_value"], self.bins_sr)
|
||||
h_avg = h.groupby("bins")["h_value"].mean()
|
||||
weights = pd.Series(np.zeros(N, dtype=float))
|
||||
for i_b, b in enumerate(h_avg.index):
|
||||
weights[h["bins"] == b] = 1.0 / (self.decay**k_th * h_avg[i_b] + 0.1)
|
||||
for b in h_avg.index:
|
||||
weights[h["bins"] == b] = 1.0 / (self.decay**k_th * h_avg[b] + 0.1)
|
||||
return weights
|
||||
|
||||
def feature_selection(self, df_train, loss_values):
|
||||
@@ -246,6 +248,7 @@ class DEnsembleModel(Model, FeatureInt):
|
||||
pd.Series(submodel.predict(x_test.loc[:, feat_sub].values), index=x_test.index)
|
||||
* self.sub_weights[i_sub]
|
||||
)
|
||||
pred = pred / np.sum(self.sub_weights)
|
||||
return pred
|
||||
|
||||
def predict_sub(self, submodel, df_data, features):
|
||||
|
||||
@@ -104,9 +104,9 @@ class TopkDropoutStrategy(BaseSignalStrategy):
|
||||
only_tradable : bool
|
||||
will the strategy only consider the tradable stock when buying and selling.
|
||||
if only_tradable:
|
||||
strategy will make buy sell decision without checking the tradable state of the stock.
|
||||
else:
|
||||
strategy will make decision with the tradable state of the stock info and avoid buy and sell them.
|
||||
else:
|
||||
strategy will make buy sell decision without checking the tradable state of the stock.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self.topk = topk
|
||||
|
||||
@@ -108,14 +108,16 @@ class CalendarProvider(abc.ABC):
|
||||
_, _, si, ei = self.locate_index(start_time, end_time, freq, future)
|
||||
return _calendar[si : ei + 1]
|
||||
|
||||
def locate_index(self, start_time, end_time, freq, future=False):
|
||||
def locate_index(
|
||||
self, start_time: Union[pd.Timestamp, str], end_time: Union[pd.Timestamp, str], freq: str, future: bool = False
|
||||
):
|
||||
"""Locate the start time index and end time index in a calendar under certain frequency.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
start_time : str
|
||||
start_time : pd.Timestamp
|
||||
start of the time range.
|
||||
end_time : str
|
||||
end_time : pd.Timestamp
|
||||
end of the time range.
|
||||
freq : str
|
||||
time frequency, available: year/quarter/month/week/day.
|
||||
|
||||
@@ -32,6 +32,7 @@ except ValueError:
|
||||
|
||||
np.seterr(invalid="ignore")
|
||||
|
||||
|
||||
#################### Element-Wise Operator ####################
|
||||
|
||||
|
||||
@@ -62,6 +63,39 @@ class ElemOperator(ExpressionOps):
|
||||
return self.feature.get_extended_window_size()
|
||||
|
||||
|
||||
class ChangeInstrument(ElemOperator):
|
||||
"""Change Instrument Operator
|
||||
In some case, one may want to change to another instrument when calculating, for example, to
|
||||
calculate beta of a stock with respect to a market index.
|
||||
This would require changing the calculation of features from the stock (original instrument) to
|
||||
the index (reference instrument)
|
||||
Parameters
|
||||
----------
|
||||
instrument: new instrument for which the downstream operations should be performed upon.
|
||||
i.e., SH000300 (CSI300 index), or ^GPSC (SP500 index).
|
||||
|
||||
feature: the feature to be calculated for the new instrument.
|
||||
Returns
|
||||
----------
|
||||
Expression
|
||||
feature operation output
|
||||
"""
|
||||
|
||||
def __init__(self, instrument, feature):
|
||||
self.instrument = instrument
|
||||
self.feature = feature
|
||||
|
||||
def __str__(self):
|
||||
return "{}('{}',{})".format(type(self).__name__, self.instrument, self.feature)
|
||||
|
||||
def load(self, instrument, start_index, end_index, *args):
|
||||
# the first `instrument` is ignored
|
||||
return super().load(self.instrument, start_index, end_index, *args)
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, *args):
|
||||
return self.feature.load(instrument, start_index, end_index, *args)
|
||||
|
||||
|
||||
class NpElemOperator(ElemOperator):
|
||||
"""Numpy Element-wise Operator
|
||||
|
||||
@@ -1535,6 +1569,7 @@ class TResample(ElemOperator):
|
||||
|
||||
TOpsList = [TResample]
|
||||
OpsList = [
|
||||
ChangeInstrument,
|
||||
Rolling,
|
||||
Ref,
|
||||
Max,
|
||||
|
||||
@@ -102,14 +102,22 @@ class FileCalendarStorage(FileStorageMixin, CalendarStorage):
|
||||
self._freq_file_cache = freq
|
||||
return self._freq_file_cache
|
||||
|
||||
def _read_calendar(self, skip_rows: int = 0, n_rows: int = None) -> List[CalVT]:
|
||||
def _read_calendar(self) -> List[CalVT]:
|
||||
# NOTE:
|
||||
# if we want to accelerate partial reading calendar
|
||||
# we can add parameters like `skip_rows: int = 0, n_rows: int = None` to the interface.
|
||||
# Currently, it is not supported for the txt-based calendar
|
||||
|
||||
if not self.uri.exists():
|
||||
self._write_calendar(values=[])
|
||||
with self.uri.open("rb") as fp:
|
||||
return [
|
||||
str(x)
|
||||
for x in np.loadtxt(fp, str, skiprows=skip_rows, max_rows=n_rows, delimiter="\n", encoding="utf-8")
|
||||
]
|
||||
|
||||
with self.uri.open("r") as fp:
|
||||
res = []
|
||||
for line in fp.readlines():
|
||||
line = line.strip()
|
||||
if len(line) > 0:
|
||||
res.append(line)
|
||||
return res
|
||||
|
||||
def _write_calendar(self, values: Iterable[CalVT], mode: str = "wb"):
|
||||
with self.uri.open(mode=mode) as fp:
|
||||
|
||||
@@ -12,7 +12,7 @@ In ``DelayTrainer``, the first step is only to save some necessary info to model
|
||||
"""
|
||||
|
||||
import socket
|
||||
from typing import Callable, List
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
@@ -219,7 +219,13 @@ class TrainerR(Trainer):
|
||||
STATUS_BEGIN = "begin_task_train"
|
||||
STATUS_END = "end_task_train"
|
||||
|
||||
def __init__(self, experiment_name: str = None, train_func: Callable = task_train, call_in_subproc: bool = False):
|
||||
def __init__(
|
||||
self,
|
||||
experiment_name: Optional[str] = None,
|
||||
train_func: Callable = task_train,
|
||||
call_in_subproc: bool = False,
|
||||
default_rec_name: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Init TrainerR.
|
||||
|
||||
@@ -230,6 +236,7 @@ class TrainerR(Trainer):
|
||||
"""
|
||||
super().__init__()
|
||||
self.experiment_name = experiment_name
|
||||
self.default_rec_name = default_rec_name
|
||||
self.train_func = train_func
|
||||
self._call_in_subproc = call_in_subproc
|
||||
|
||||
@@ -259,7 +266,7 @@ class TrainerR(Trainer):
|
||||
if self._call_in_subproc:
|
||||
get_module_logger("TrainerR").info("running models in sub process (for forcing release memroy).")
|
||||
train_func = call_in_subproc(train_func, C)
|
||||
rec = train_func(task, experiment_name, **kwargs)
|
||||
rec = train_func(task, experiment_name, recorder_name=self.default_rec_name, **kwargs)
|
||||
rec.set_tags(**{self.STATUS_KEY: self.STATUS_BEGIN})
|
||||
recs.append(rec)
|
||||
return recs
|
||||
@@ -286,7 +293,9 @@ class DelayTrainerR(TrainerR):
|
||||
A delayed implementation based on TrainerR, which means `train` method may only do some preparation and `end_train` method can do the real model fitting.
|
||||
"""
|
||||
|
||||
def __init__(self, experiment_name: str = None, train_func=begin_task_train, end_train_func=end_task_train):
|
||||
def __init__(
|
||||
self, experiment_name: str = None, train_func=begin_task_train, end_train_func=end_task_train, **kwargs
|
||||
):
|
||||
"""
|
||||
Init TrainerRM.
|
||||
|
||||
@@ -295,7 +304,7 @@ class DelayTrainerR(TrainerR):
|
||||
train_func (Callable, optional): default train method. Defaults to `begin_task_train`.
|
||||
end_train_func (Callable, optional): default end_train method. Defaults to `end_task_train`.
|
||||
"""
|
||||
super().__init__(experiment_name, train_func)
|
||||
super().__init__(experiment_name, train_func, **kwargs)
|
||||
self.end_train_func = end_train_func
|
||||
self.delay = True
|
||||
|
||||
@@ -344,7 +353,12 @@ class TrainerRM(Trainer):
|
||||
TM_ID = "_id in TaskManager"
|
||||
|
||||
def __init__(
|
||||
self, experiment_name: str = None, task_pool: str = None, train_func=task_train, skip_run_task: bool = False
|
||||
self,
|
||||
experiment_name: str = None,
|
||||
task_pool: str = None,
|
||||
train_func=task_train,
|
||||
skip_run_task: bool = False,
|
||||
default_rec_name: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Init TrainerR.
|
||||
@@ -363,6 +377,7 @@ class TrainerRM(Trainer):
|
||||
self.task_pool = task_pool
|
||||
self.train_func = train_func
|
||||
self.skip_run_task = skip_run_task
|
||||
self.default_rec_name = default_rec_name
|
||||
|
||||
def train(
|
||||
self,
|
||||
@@ -371,6 +386,7 @@ class TrainerRM(Trainer):
|
||||
experiment_name: str = None,
|
||||
before_status: str = TaskManager.STATUS_WAITING,
|
||||
after_status: str = TaskManager.STATUS_DONE,
|
||||
default_rec_name: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> List[Recorder]:
|
||||
"""
|
||||
@@ -398,6 +414,8 @@ class TrainerRM(Trainer):
|
||||
train_func = self.train_func
|
||||
if experiment_name is None:
|
||||
experiment_name = self.experiment_name
|
||||
if default_rec_name is None:
|
||||
default_rec_name = self.default_rec_name
|
||||
task_pool = self.task_pool
|
||||
if task_pool is None:
|
||||
task_pool = experiment_name
|
||||
@@ -412,6 +430,7 @@ class TrainerRM(Trainer):
|
||||
experiment_name=experiment_name,
|
||||
before_status=before_status,
|
||||
after_status=after_status,
|
||||
recorder_name=default_rec_name,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -480,6 +499,7 @@ class DelayTrainerRM(TrainerRM):
|
||||
train_func=begin_task_train,
|
||||
end_train_func=end_task_train,
|
||||
skip_run_task: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Init DelayTrainerRM.
|
||||
@@ -494,7 +514,7 @@ class DelayTrainerRM(TrainerRM):
|
||||
Only run_task in the worker. Otherwise skip run_task.
|
||||
E.g. Starting trainer on a CPU VM and then waiting tasks to be finished on GPU VMs.
|
||||
"""
|
||||
super().__init__(experiment_name, task_pool, train_func)
|
||||
super().__init__(experiment_name, task_pool, train_func, **kwargs)
|
||||
self.end_train_func = end_train_func
|
||||
self.delay = True
|
||||
self.skip_run_task = skip_run_task
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Generic, TYPE_CHECKING, TypeVar
|
||||
from typing import Optional, TYPE_CHECKING, Generic, TypeVar
|
||||
|
||||
from qlib.typehint import final
|
||||
|
||||
@@ -21,7 +21,7 @@ AuxInfoType = TypeVar("AuxInfoType")
|
||||
class AuxiliaryInfoCollector(Generic[StateType, AuxInfoType]):
|
||||
"""Override this class to collect customized auxiliary information from environment."""
|
||||
|
||||
env: EnvWrapper | None = None
|
||||
env: Optional[EnvWrapper] = None
|
||||
|
||||
@final
|
||||
def __call__(self, simulator_state: StateType) -> AuxInfoType:
|
||||
|
||||
58
qlib/rl/data/exchange_wrapper.py
Normal file
58
qlib/rl/data/exchange_wrapper.py
Normal file
@@ -0,0 +1,58 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from typing import cast
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from qlib.backtest import Exchange, Order
|
||||
from .pickle_styled import IntradayBacktestData
|
||||
|
||||
|
||||
class QlibIntradayBacktestData(IntradayBacktestData):
|
||||
"""Backtest data for Qlib simulator"""
|
||||
|
||||
def __init__(self, order: Order, exchange: Exchange, start_time: pd.Timestamp, end_time: pd.Timestamp) -> None:
|
||||
super(QlibIntradayBacktestData, self).__init__()
|
||||
self._order = order
|
||||
self._exchange = exchange
|
||||
self._start_time = start_time
|
||||
self._end_time = end_time
|
||||
|
||||
self._deal_price = cast(
|
||||
pd.Series,
|
||||
self._exchange.get_deal_price(
|
||||
self._order.stock_id,
|
||||
self._start_time,
|
||||
self._end_time,
|
||||
direction=self._order.direction,
|
||||
method=None,
|
||||
),
|
||||
)
|
||||
self._volume = cast(
|
||||
pd.Series,
|
||||
self._exchange.get_volume(
|
||||
self._order.stock_id,
|
||||
self._start_time,
|
||||
self._end_time,
|
||||
method=None,
|
||||
),
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"Order: {self._order}, Exchange: {self._exchange}, "
|
||||
f"Start time: {self._start_time}, End time: {self._end_time}"
|
||||
)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._deal_price)
|
||||
|
||||
def get_deal_price(self) -> pd.Series:
|
||||
return self._deal_price
|
||||
|
||||
def get_volume(self) -> pd.Series:
|
||||
return self._volume
|
||||
|
||||
def get_time_index(self) -> pd.DatetimeIndex:
|
||||
return pd.DatetimeIndex([e[1] for e in list(self._exchange.quote_df.index)])
|
||||
@@ -19,19 +19,19 @@ This file shows resemblence to qlib.backtest.high_performance_ds. We might merge
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import abstractmethod
|
||||
from functools import lru_cache
|
||||
from typing import List, Sequence, cast
|
||||
from pathlib import Path
|
||||
from typing import List, Sequence, cast
|
||||
|
||||
import cachetools
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from cachetools.keys import hashkey
|
||||
|
||||
from qlib.backtest.decision import OrderDir, Order
|
||||
from qlib.backtest.decision import Order, OrderDir
|
||||
from qlib.typehint import Literal
|
||||
|
||||
|
||||
DealPriceType = Literal["bid_or_ask", "bid_or_ask_fill", "close"]
|
||||
"""Several ad-hoc deal price.
|
||||
``bid_or_ask``: If sell, use column ``$bid0``; if buy, use column ``$ask0``.
|
||||
@@ -40,7 +40,7 @@ DealPriceType = Literal["bid_or_ask", "bid_or_ask_fill", "close"]
|
||||
"""
|
||||
|
||||
|
||||
def _infer_processed_data_column_names(shape: int) -> list[str]:
|
||||
def _infer_processed_data_column_names(shape: int) -> List[str]:
|
||||
if shape == 16:
|
||||
return [
|
||||
"$open",
|
||||
@@ -87,7 +87,36 @@ def _read_pickle(filename_without_suffix: Path) -> pd.DataFrame:
|
||||
|
||||
|
||||
class IntradayBacktestData:
|
||||
"""Raw market data that is often used in backtesting (thus called BacktestData)."""
|
||||
"""
|
||||
Raw market data that is often used in backtesting (thus called BacktestData).
|
||||
|
||||
Base class for all types of backtest data. Currently, each type of simulator has its corresponding backtest
|
||||
data type.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def __repr__(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def __len__(self) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_deal_price(self) -> pd.Series:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_volume(self) -> pd.Series:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_time_index(self) -> pd.DatetimeIndex:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class SimpleIntradayBacktestData(IntradayBacktestData):
|
||||
"""Backtest data for simple simulator"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -95,8 +124,10 @@ class IntradayBacktestData:
|
||||
stock_id: str,
|
||||
date: pd.Timestamp,
|
||||
deal_price: DealPriceType = "close",
|
||||
order_dir: int | None = None,
|
||||
):
|
||||
order_dir: int = None,
|
||||
) -> None:
|
||||
super(SimpleIntradayBacktestData, self).__init__()
|
||||
|
||||
backtest = _read_pickle(data_dir / stock_id)
|
||||
backtest = backtest.loc[pd.IndexSlice[stock_id, :, date]]
|
||||
|
||||
@@ -105,13 +136,13 @@ class IntradayBacktestData:
|
||||
|
||||
self.data: pd.DataFrame = backtest
|
||||
self.deal_price_type: DealPriceType = deal_price
|
||||
self.order_dir: int | None = order_dir
|
||||
self.order_dir = order_dir
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
with pd.option_context("memory_usage", False, "display.max_info_columns", 1, "display.large_repr", "info"):
|
||||
return f"{self.__class__.__name__}({self.data})"
|
||||
|
||||
def __len__(self):
|
||||
def __len__(self) -> int:
|
||||
return len(self.data)
|
||||
|
||||
def get_deal_price(self) -> pd.Series:
|
||||
@@ -162,7 +193,14 @@ class IntradayProcessedData:
|
||||
"""Processed data for "yesterday".
|
||||
Number of records must be ``time_length``, and columns must be ``feature_dim``."""
|
||||
|
||||
def __init__(self, data_dir: Path, stock_id: str, date: pd.Timestamp, feature_dim: int, time_index: pd.Index):
|
||||
def __init__(
|
||||
self,
|
||||
data_dir: Path,
|
||||
stock_id: str,
|
||||
date: pd.Timestamp,
|
||||
feature_dim: int,
|
||||
time_index: pd.Index,
|
||||
) -> None:
|
||||
proc = _read_pickle(data_dir / stock_id)
|
||||
# We have to infer the names here because,
|
||||
# unfortunately they are not included in the original data.
|
||||
@@ -190,16 +228,20 @@ class IntradayProcessedData:
|
||||
assert len(self.today.columns) == len(self.yesterday.columns) == feature_dim
|
||||
assert len(self.today) == len(self.yesterday) == time_length
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
with pd.option_context("memory_usage", False, "display.max_info_columns", 1, "display.large_repr", "info"):
|
||||
return f"{self.__class__.__name__}({self.today}, {self.yesterday})"
|
||||
|
||||
|
||||
@lru_cache(maxsize=100) # 100 * 50K = 5MB
|
||||
def load_intraday_backtest_data(
|
||||
data_dir: Path, stock_id: str, date: pd.Timestamp, deal_price: DealPriceType = "close", order_dir: int | None = None
|
||||
) -> IntradayBacktestData:
|
||||
return IntradayBacktestData(data_dir, stock_id, date, deal_price, order_dir)
|
||||
def load_simple_intraday_backtest_data(
|
||||
data_dir: Path,
|
||||
stock_id: str,
|
||||
date: pd.Timestamp,
|
||||
deal_price: DealPriceType = "close",
|
||||
order_dir: int = None,
|
||||
) -> SimpleIntradayBacktestData:
|
||||
return SimpleIntradayBacktestData(data_dir, stock_id, date, deal_price, order_dir)
|
||||
|
||||
|
||||
@cachetools.cached( # type: ignore
|
||||
@@ -207,13 +249,19 @@ def load_intraday_backtest_data(
|
||||
key=lambda data_dir, stock_id, date, _, __: hashkey(data_dir, stock_id, date),
|
||||
)
|
||||
def load_intraday_processed_data(
|
||||
data_dir: Path, stock_id: str, date: pd.Timestamp, feature_dim: int, time_index: pd.Index
|
||||
data_dir: Path,
|
||||
stock_id: str,
|
||||
date: pd.Timestamp,
|
||||
feature_dim: int,
|
||||
time_index: pd.Index,
|
||||
) -> IntradayProcessedData:
|
||||
return IntradayProcessedData(data_dir, stock_id, date, feature_dim, time_index)
|
||||
|
||||
|
||||
def load_orders(
|
||||
order_path: Path, start_time: pd.Timestamp | None = None, end_time: pd.Timestamp | None = None
|
||||
order_path: Path,
|
||||
start_time: pd.Timestamp = None,
|
||||
end_time: pd.Timestamp = None,
|
||||
) -> Sequence[Order]:
|
||||
"""Load orders, and set start time and end time for the orders."""
|
||||
|
||||
@@ -248,10 +296,10 @@ def load_orders(
|
||||
Order(
|
||||
row["instrument"],
|
||||
row["amount"],
|
||||
int(row["order_type"]),
|
||||
OrderDir(int(row["order_type"])),
|
||||
row["datetime"].replace(hour=start_time.hour, minute=start_time.minute, second=start_time.second),
|
||||
row["datetime"].replace(hour=end_time.hour, minute=end_time.minute, second=end_time.second),
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
return orders
|
||||
|
||||
@@ -1,7 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""Train, test, inference utilities.
|
||||
|
||||
The APIs in this directory are NOT considered final and are subject to change!
|
||||
"""
|
||||
@@ -1,99 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from typing import Callable, Sequence
|
||||
|
||||
from tianshou.data import Collector
|
||||
from tianshou.policy import BasePolicy
|
||||
|
||||
from qlib.constant import INF
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.rl.simulator import InitialStateType, Simulator
|
||||
from qlib.rl.interpreter import StateInterpreter, ActionInterpreter
|
||||
from qlib.rl.reward import Reward
|
||||
from qlib.rl.utils import DataQueue, EnvWrapper, FiniteEnvType, LogCollector, LogWriter, vectorize_env
|
||||
|
||||
|
||||
_logger = get_module_logger(__name__)
|
||||
|
||||
|
||||
def backtest(
|
||||
simulator_fn: Callable[[InitialStateType], Simulator],
|
||||
state_interpreter: StateInterpreter,
|
||||
action_interpreter: ActionInterpreter,
|
||||
initial_states: Sequence[InitialStateType],
|
||||
policy: BasePolicy,
|
||||
logger: LogWriter | list[LogWriter],
|
||||
reward: Reward | None = None,
|
||||
finite_env_type: FiniteEnvType = "subproc",
|
||||
concurrency: int = 2,
|
||||
) -> None:
|
||||
"""Backtest with the parallelism provided by RL framework.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
simulator_fn
|
||||
Callable receiving initial seed, returning a simulator.
|
||||
state_interpreter
|
||||
Interprets the state of simulators.
|
||||
action_interpreter
|
||||
Interprets the policy actions.
|
||||
initial_states
|
||||
Initial states to iterate over. Every state will be run exactly once.
|
||||
policy
|
||||
Policy to test against.
|
||||
logger
|
||||
Logger to record the backtest results. Logger must be present because
|
||||
without logger, all information will be lost.
|
||||
reward
|
||||
Optional reward function. For backtest, this is for testing the rewards
|
||||
and logging them only.
|
||||
finite_env_type
|
||||
Type of finite env implementation.
|
||||
concurrency
|
||||
Parallel workers.
|
||||
"""
|
||||
|
||||
# To save bandwidth
|
||||
min_loglevel = min(lg.loglevel for lg in logger) if isinstance(logger, list) else logger.loglevel
|
||||
|
||||
def env_factory():
|
||||
# FIXME: state_interpreter and action_interpreter are stateful (having a weakref of env),
|
||||
# and could be thread unsafe.
|
||||
# I'm not sure whether it's a design flaw.
|
||||
# I'll rethink about this when designing the trainer.
|
||||
|
||||
if finite_env_type == "dummy":
|
||||
# We could only experience the "threading-unsafe" problem in dummy.
|
||||
state = copy.deepcopy(state_interpreter)
|
||||
action = copy.deepcopy(action_interpreter)
|
||||
rew = copy.deepcopy(reward)
|
||||
else:
|
||||
state, action, rew = state_interpreter, action_interpreter, reward
|
||||
|
||||
return EnvWrapper(
|
||||
simulator_fn,
|
||||
state,
|
||||
action,
|
||||
seed_iterator,
|
||||
rew,
|
||||
logger=LogCollector(min_loglevel=min_loglevel),
|
||||
)
|
||||
|
||||
with DataQueue(initial_states) as seed_iterator:
|
||||
vector_env = vectorize_env(
|
||||
env_factory,
|
||||
finite_env_type,
|
||||
concurrency,
|
||||
logger,
|
||||
)
|
||||
|
||||
policy.eval()
|
||||
|
||||
with vector_env.collector_guard():
|
||||
test_collector = Collector(policy, vector_env)
|
||||
_logger.info("All ready. Start backtest.")
|
||||
test_collector.collect(n_step=INF * len(vector_env))
|
||||
@@ -1,4 +1,4 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
# TBD
|
||||
# TODO: find a better way to organize contents under this module.
|
||||
20
qlib/rl/from_neutrader/config.py
Normal file
20
qlib/rl/from_neutrader/config.py
Normal file
@@ -0,0 +1,20 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
|
||||
# TODO: In the future we should merge the dataclass-based config with Qlib's dict-based config.
|
||||
@dataclass
|
||||
class ExchangeConfig:
|
||||
limit_threshold: Union[float, Tuple[str, str]]
|
||||
deal_price: Union[str, Tuple[str, str]]
|
||||
volume_threshold: dict
|
||||
open_cost: float = 0.0005
|
||||
close_cost: float = 0.0015
|
||||
min_cost: float = 5.0
|
||||
trade_unit: Optional[float] = 100.0
|
||||
cash_limit: Optional[Union[Path, float]] = None
|
||||
generate_report: bool = False
|
||||
109
qlib/rl/from_neutrader/feature.py
Normal file
109
qlib/rl/from_neutrader/feature.py
Normal file
@@ -0,0 +1,109 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import collections
|
||||
from typing import List, Optional
|
||||
|
||||
import pandas as pd
|
||||
|
||||
import qlib
|
||||
from qlib.config import REG_CN
|
||||
from qlib.contrib.ops.high_freq import BFillNan, Cut, Date, DayCumsum, DayLast, FFillNan, IsInf, IsNull, Select
|
||||
from qlib.data.dataset import DatasetH
|
||||
|
||||
|
||||
class LRUCache:
|
||||
def __init__(self, pool_size: int = 200):
|
||||
self.pool_size = pool_size
|
||||
self.contents: dict = {}
|
||||
self.keys: collections.deque = collections.deque()
|
||||
|
||||
def put(self, key, item):
|
||||
if self.has(key):
|
||||
self.keys.remove(key)
|
||||
self.keys.append(key)
|
||||
self.contents[key] = item
|
||||
while len(self.contents) > self.pool_size:
|
||||
self.contents.pop(self.keys.popleft())
|
||||
|
||||
def get(self, key):
|
||||
return self.contents[key]
|
||||
|
||||
def has(self, key):
|
||||
return key in self.contents
|
||||
|
||||
|
||||
class DataWrapper:
|
||||
def __init__(
|
||||
self,
|
||||
feature_dataset: DatasetH,
|
||||
backtest_dataset: DatasetH,
|
||||
columns_today: List[str],
|
||||
columns_yesterday: List[str],
|
||||
_internal: bool = False,
|
||||
):
|
||||
assert _internal, "Init function of data wrapper is for internal use only."
|
||||
|
||||
self.feature_dataset = feature_dataset
|
||||
self.backtest_dataset = backtest_dataset
|
||||
self.columns_today = columns_today
|
||||
self.columns_yesterday = columns_yesterday
|
||||
|
||||
# TODO: We might have the chance to merge them.
|
||||
self.feature_cache = LRUCache()
|
||||
self.backtest_cache = LRUCache()
|
||||
|
||||
def get(self, stock_id: str, date: pd.Timestamp, backtest: bool = False) -> pd.DataFrame:
|
||||
start_time, end_time = date.replace(hour=0, minute=0, second=0), date.replace(hour=23, minute=59, second=59)
|
||||
|
||||
if backtest:
|
||||
dataset = self.backtest_dataset
|
||||
cache = self.backtest_cache
|
||||
else:
|
||||
dataset = self.feature_dataset
|
||||
cache = self.feature_cache
|
||||
|
||||
if cache.has((start_time, end_time, stock_id)):
|
||||
return cache.get((start_time, end_time, stock_id))
|
||||
data = dataset.handler.fetch(pd.IndexSlice[stock_id, start_time:end_time], level=None)
|
||||
cache.put((start_time, end_time, stock_id), data)
|
||||
return data
|
||||
|
||||
|
||||
def init_qlib(config: dict, part: Optional[str] = None) -> None:
|
||||
provider_uri_map = {
|
||||
"day": config["provider_uri_day"].as_posix(),
|
||||
"1min": config["provider_uri_1min"].as_posix(),
|
||||
}
|
||||
qlib.init(
|
||||
region=REG_CN,
|
||||
auto_mount=False,
|
||||
custom_ops=[DayLast, FFillNan, BFillNan, Date, Select, IsNull, IsInf, Cut, DayCumsum],
|
||||
expression_cache=None,
|
||||
calendar_provider={
|
||||
"class": "LocalCalendarProvider",
|
||||
"module_path": "qlib.data.data",
|
||||
"kwargs": {
|
||||
"backend": {
|
||||
"class": "FileCalendarStorage",
|
||||
"module_path": "qlib.data.storage.file_storage",
|
||||
"kwargs": {"provider_uri_map": provider_uri_map},
|
||||
},
|
||||
},
|
||||
},
|
||||
feature_provider={
|
||||
"class": "LocalFeatureProvider",
|
||||
"module_path": "qlib.data.data",
|
||||
"kwargs": {
|
||||
"backend": {
|
||||
"class": "FileFeatureStorage",
|
||||
"module_path": "qlib.data.storage.file_storage",
|
||||
"kwargs": {"provider_uri_map": provider_uri_map},
|
||||
},
|
||||
},
|
||||
},
|
||||
provider_uri=provider_uri_map,
|
||||
kernels=1,
|
||||
redis_port=-1,
|
||||
clear_mem_cache=False, # init_qlib will be called for multiple times. Keep the cache for improving performance
|
||||
)
|
||||
@@ -3,13 +3,13 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, TypeVar, Generic, Any
|
||||
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar
|
||||
|
||||
import numpy as np
|
||||
|
||||
from qlib.typehint import final
|
||||
|
||||
from .simulator import StateType, ActType
|
||||
from .simulator import ActType, StateType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .utils.env_wrapper import EnvWrapper
|
||||
@@ -40,7 +40,7 @@ class Interpreter:
|
||||
class StateInterpreter(Generic[StateType, ObsType], Interpreter):
|
||||
"""State Interpreter that interpret execution result of qlib executor into rl env state"""
|
||||
|
||||
env: EnvWrapper | None = None
|
||||
env: Optional[EnvWrapper] = None
|
||||
|
||||
@property
|
||||
def observation_space(self) -> gym.Space:
|
||||
@@ -74,7 +74,7 @@ class StateInterpreter(Generic[StateType, ObsType], Interpreter):
|
||||
class ActionInterpreter(Generic[StateType, PolicyActType, ActType], Interpreter):
|
||||
"""Action Interpreter that interpret rl agent action into qlib orders"""
|
||||
|
||||
env: "EnvWrapper" | None = None
|
||||
env: Optional[EnvWrapper] = None
|
||||
|
||||
@property
|
||||
def action_space(self) -> gym.Space:
|
||||
@@ -141,10 +141,10 @@ def _gym_space_contains(space: gym.Space, x: Any) -> None:
|
||||
|
||||
|
||||
class GymSpaceValidationError(Exception):
|
||||
def __init__(self, message: str, space: gym.Space, x: Any):
|
||||
def __init__(self, message: str, space: gym.Space, x: Any) -> None:
|
||||
self.message = message
|
||||
self.space = space
|
||||
self.x = x
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return f"{self.message}\n Space: {self.space}\n Sample: {self.x}"
|
||||
|
||||
@@ -9,4 +9,5 @@ Multi-asset is on the way.
|
||||
from .interpreter import *
|
||||
from .network import *
|
||||
from .policy import *
|
||||
from .reward import *
|
||||
from .simulator_simple import *
|
||||
|
||||
@@ -5,15 +5,15 @@ from __future__ import annotations
|
||||
|
||||
import math
|
||||
from pathlib import Path
|
||||
from typing import Any, cast
|
||||
from typing import Any, List, cast
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from gym import spaces
|
||||
|
||||
from qlib.constant import EPS
|
||||
from qlib.rl.interpreter import StateInterpreter, ActionInterpreter
|
||||
from qlib.rl.data import pickle_styled
|
||||
from qlib.rl.interpreter import ActionInterpreter, StateInterpreter
|
||||
from qlib.typehint import TypedDict
|
||||
|
||||
from .simulator_simple import SAOEState
|
||||
@@ -99,18 +99,18 @@ class FullHistoryStateInterpreter(StateInterpreter[SAOEState, FullHistoryObs]):
|
||||
"data_processed": self._mask_future_info(processed.today, state.cur_time),
|
||||
"data_processed_prev": processed.yesterday,
|
||||
"acquiring": state.order.direction == state.order.BUY,
|
||||
"cur_tick": min(np.sum(state.ticks_index < state.cur_time), self.data_ticks - 1),
|
||||
"cur_tick": min(int(np.sum(state.ticks_index < state.cur_time)), self.data_ticks - 1),
|
||||
"cur_step": min(self.env.status["cur_step"], self.max_step - 1),
|
||||
"num_step": self.max_step,
|
||||
"target": state.order.amount,
|
||||
"position": state.position,
|
||||
"position_history": position_history[: self.max_step],
|
||||
}
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
@property
|
||||
def observation_space(self):
|
||||
def observation_space(self) -> spaces.Dict:
|
||||
space = {
|
||||
"data_processed": spaces.Box(-np.inf, np.inf, shape=(self.data_ticks, self.data_dim)),
|
||||
"data_processed_prev": spaces.Box(-np.inf, np.inf, shape=(self.data_ticks, self.data_dim)),
|
||||
@@ -147,11 +147,11 @@ class CurrentStepStateInterpreter(StateInterpreter[SAOEState, CurrentStateObs]):
|
||||
The key list is not full. You can add more if more information is needed by your policy.
|
||||
"""
|
||||
|
||||
def __init__(self, max_step: int):
|
||||
def __init__(self, max_step: int) -> None:
|
||||
self.max_step = max_step
|
||||
|
||||
@property
|
||||
def observation_space(self):
|
||||
def observation_space(self) -> spaces.Dict:
|
||||
space = {
|
||||
"acquiring": spaces.Discrete(2),
|
||||
"cur_step": spaces.Box(0, self.max_step - 1, shape=(), dtype=np.int32),
|
||||
@@ -165,13 +165,11 @@ class CurrentStepStateInterpreter(StateInterpreter[SAOEState, CurrentStateObs]):
|
||||
assert self.env is not None
|
||||
assert self.env.status["cur_step"] <= self.max_step
|
||||
obs = CurrentStateObs(
|
||||
{
|
||||
"acquiring": state.order.direction == state.order.BUY,
|
||||
"cur_step": self.env.status["cur_step"],
|
||||
"num_step": self.max_step,
|
||||
"target": state.order.amount,
|
||||
"position": state.position,
|
||||
}
|
||||
acquiring=state.order.direction == state.order.BUY,
|
||||
cur_step=self.env.status["cur_step"],
|
||||
num_step=self.max_step,
|
||||
target=state.order.amount,
|
||||
position=state.position,
|
||||
)
|
||||
return obs
|
||||
|
||||
@@ -188,7 +186,7 @@ class CategoricalActionInterpreter(ActionInterpreter[SAOEState, int, float]):
|
||||
i.e., $[0, 1/n, 2/n, \\ldots, n/n]$.
|
||||
"""
|
||||
|
||||
def __init__(self, values: int | list[float]):
|
||||
def __init__(self, values: int | List[float]) -> None:
|
||||
if isinstance(values, int):
|
||||
values = [i / values for i in range(0, values + 1)]
|
||||
self.action_values = values
|
||||
@@ -203,7 +201,7 @@ class CategoricalActionInterpreter(ActionInterpreter[SAOEState, int, float]):
|
||||
|
||||
|
||||
class TwapRelativeActionInterpreter(ActionInterpreter[SAOEState, float, float]):
|
||||
"""Convert a continous ratio to deal amount.
|
||||
"""Convert a continuous ratio to deal amount.
|
||||
|
||||
The ratio is relative to TWAP on the remainder of the day.
|
||||
For example, there are 5 steps left, and the left position is 300.
|
||||
|
||||
@@ -3,13 +3,14 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import cast
|
||||
from typing import List, Tuple, cast
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from tianshou.data import Batch
|
||||
|
||||
from qlib.typehint import Literal
|
||||
|
||||
from .interpreter import FullHistoryObs
|
||||
|
||||
__all__ = ["Recurrent"]
|
||||
@@ -18,7 +19,7 @@ __all__ = ["Recurrent"]
|
||||
class Recurrent(nn.Module):
|
||||
"""The network architecture proposed in `OPD <https://seqml.github.io/opd/opd_aaai21_supplement.pdf>`_.
|
||||
|
||||
At every timestep the input of policy network is divided into two parts,
|
||||
At every time step the input of policy network is divided into two parts,
|
||||
the public variables and the private variables. which are handled by ``raw_rnn``
|
||||
and ``pri_rnn`` in this network, respectively.
|
||||
|
||||
@@ -33,7 +34,7 @@ class Recurrent(nn.Module):
|
||||
output_dim: int = 32,
|
||||
rnn_type: Literal["rnn", "lstm", "gru"] = "gru",
|
||||
rnn_num_layers: int = 1,
|
||||
):
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.hidden_dim = hidden_dim
|
||||
@@ -62,10 +63,10 @@ class Recurrent(nn.Module):
|
||||
nn.ReLU(),
|
||||
)
|
||||
|
||||
def _init_extra_branches(self):
|
||||
def _init_extra_branches(self) -> None:
|
||||
pass
|
||||
|
||||
def _source_features(self, obs: FullHistoryObs, device: torch.device) -> tuple[list[torch.Tensor], torch.Tensor]:
|
||||
def _source_features(self, obs: FullHistoryObs, device: torch.device) -> Tuple[List[torch.Tensor], torch.Tensor]:
|
||||
bs, _, data_dim = obs["data_processed"].size()
|
||||
data = torch.cat((torch.zeros(bs, 1, data_dim, device=device), obs["data_processed"]), 1)
|
||||
cur_step = obs["cur_step"].long()
|
||||
|
||||
@@ -1,16 +1,17 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional, cast
|
||||
from typing import Any, Dict, Generator, Iterable, Optional, Tuple, cast
|
||||
|
||||
import numpy as np
|
||||
import gym
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from gym.spaces import Discrete
|
||||
from tianshou.data import Batch, to_torch
|
||||
from tianshou.policy import PPOPolicy, BasePolicy
|
||||
from tianshou.data import Batch, ReplayBuffer, to_torch
|
||||
from tianshou.policy import BasePolicy, PPOPolicy
|
||||
|
||||
__all__ = ["AllOne", "PPO"]
|
||||
|
||||
@@ -18,29 +19,39 @@ __all__ = ["AllOne", "PPO"]
|
||||
# baselines #
|
||||
|
||||
|
||||
class NonlearnablePolicy(BasePolicy):
|
||||
class NonLearnablePolicy(BasePolicy):
|
||||
"""Tianshou's BasePolicy with empty ``learn`` and ``process_fn``.
|
||||
|
||||
This could be moved outside in future.
|
||||
"""
|
||||
|
||||
def __init__(self, obs_space: gym.Space, action_space: gym.Space):
|
||||
def __init__(self, obs_space: gym.Space, action_space: gym.Space) -> None:
|
||||
super().__init__()
|
||||
|
||||
def learn(self, batch, batch_size, repeat):
|
||||
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, Any]:
|
||||
pass
|
||||
|
||||
def process_fn(self, batch, buffer, indice):
|
||||
def process_fn(
|
||||
self,
|
||||
batch: Batch,
|
||||
buffer: ReplayBuffer,
|
||||
indices: np.ndarray,
|
||||
) -> Batch:
|
||||
pass
|
||||
|
||||
|
||||
class AllOne(NonlearnablePolicy):
|
||||
class AllOne(NonLearnablePolicy):
|
||||
"""Forward returns a batch full of 1.
|
||||
|
||||
Useful when implementing some baselines (e.g., TWAP).
|
||||
"""
|
||||
|
||||
def forward(self, batch, state=None, **kwargs):
|
||||
def forward(
|
||||
self,
|
||||
batch: Batch,
|
||||
state: dict | Batch | np.ndarray = None,
|
||||
**kwargs: Any,
|
||||
) -> Batch:
|
||||
return Batch(act=np.full(len(batch), 1.0), state=state)
|
||||
|
||||
|
||||
@@ -48,24 +59,34 @@ class AllOne(NonlearnablePolicy):
|
||||
|
||||
|
||||
class PPOActor(nn.Module):
|
||||
def __init__(self, extractor: nn.Module, action_dim: int):
|
||||
def __init__(self, extractor: nn.Module, action_dim: int) -> None:
|
||||
super().__init__()
|
||||
self.extractor = extractor
|
||||
self.layer_out = nn.Sequential(nn.Linear(cast(int, extractor.output_dim), action_dim), nn.Softmax(dim=-1))
|
||||
|
||||
def forward(self, obs, state=None, info={}):
|
||||
def forward(
|
||||
self,
|
||||
obs: torch.Tensor,
|
||||
state: torch.Tensor = None,
|
||||
info: dict = {},
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
feature = self.extractor(to_torch(obs, device=auto_device(self)))
|
||||
out = self.layer_out(feature)
|
||||
return out, state
|
||||
|
||||
|
||||
class PPOCritic(nn.Module):
|
||||
def __init__(self, extractor: nn.Module):
|
||||
def __init__(self, extractor: nn.Module) -> None:
|
||||
super().__init__()
|
||||
self.extractor = extractor
|
||||
self.value_out = nn.Linear(cast(int, extractor.output_dim), 1)
|
||||
|
||||
def forward(self, obs, state=None, info={}):
|
||||
def forward(
|
||||
self,
|
||||
obs: torch.Tensor,
|
||||
state: torch.Tensor = None,
|
||||
info: dict = {},
|
||||
) -> torch.Tensor:
|
||||
feature = self.extractor(to_torch(obs, device=auto_device(self)))
|
||||
return self.value_out(feature).squeeze(dim=-1)
|
||||
|
||||
@@ -93,18 +114,20 @@ class PPO(PPOPolicy):
|
||||
max_grad_norm: float = 100.0,
|
||||
reward_normalization: bool = True,
|
||||
eps_clip: float = 0.3,
|
||||
value_clip: float = True,
|
||||
value_clip: bool = True,
|
||||
vf_coef: float = 1.0,
|
||||
gae_lambda: float = 1.0,
|
||||
max_batchsize: int = 256,
|
||||
max_batch_size: int = 256,
|
||||
deterministic_eval: bool = True,
|
||||
weight_file: Optional[Path] = None,
|
||||
):
|
||||
) -> None:
|
||||
assert isinstance(action_space, Discrete)
|
||||
actor = PPOActor(network, action_space.n)
|
||||
critic = PPOCritic(network)
|
||||
optimizer = torch.optim.Adam(
|
||||
chain_dedup(actor.parameters(), critic.parameters()), lr=lr, weight_decay=weight_decay
|
||||
chain_dedup(actor.parameters(), critic.parameters()),
|
||||
lr=lr,
|
||||
weight_decay=weight_decay,
|
||||
)
|
||||
super().__init__(
|
||||
actor,
|
||||
@@ -118,7 +141,7 @@ class PPO(PPOPolicy):
|
||||
value_clip=value_clip,
|
||||
vf_coef=vf_coef,
|
||||
gae_lambda=gae_lambda,
|
||||
max_batchsize=max_batchsize,
|
||||
max_batchsize=max_batch_size,
|
||||
deterministic_eval=deterministic_eval,
|
||||
observation_space=obs_space,
|
||||
action_space=action_space,
|
||||
@@ -136,7 +159,7 @@ def auto_device(module: nn.Module) -> torch.device:
|
||||
return torch.device("cpu") # fallback to cpu
|
||||
|
||||
|
||||
def load_weight(policy, path):
|
||||
def load_weight(policy: nn.Module, path: Path) -> None:
|
||||
assert isinstance(policy, nn.Module), "Policy has to be an nn.Module to load weight."
|
||||
loaded_weight = torch.load(path, map_location="cpu")
|
||||
try:
|
||||
@@ -149,7 +172,7 @@ def load_weight(policy, path):
|
||||
policy.load_state_dict(loaded_weight)
|
||||
|
||||
|
||||
def chain_dedup(*iterables):
|
||||
def chain_dedup(*iterables: Iterable) -> Generator[Any, None, None]:
|
||||
seen = set()
|
||||
for iterable in iterables:
|
||||
for i in iterable:
|
||||
|
||||
47
qlib/rl/order_execution/reward.py
Normal file
47
qlib/rl/order_execution/reward.py
Normal file
@@ -0,0 +1,47 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import cast
|
||||
|
||||
import numpy as np
|
||||
|
||||
from qlib.rl.reward import Reward
|
||||
|
||||
from .simulator_simple import SAOEMetrics, SAOEState
|
||||
|
||||
__all__ = ["PAPenaltyReward"]
|
||||
|
||||
|
||||
class PAPenaltyReward(Reward[SAOEState]):
|
||||
"""Encourage higher PAs, but penalize stacking all the amounts within a very short time.
|
||||
Formally, for each time step, the reward is :math:`(PA_t * vol_t / target - vol_t^2 * penalty)`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
penalty
|
||||
The penalty for large volume in a short time.
|
||||
"""
|
||||
|
||||
def __init__(self, penalty: float = 100.0):
|
||||
self.penalty = penalty
|
||||
|
||||
def reward(self, simulator_state: SAOEState) -> float:
|
||||
whole_order = simulator_state.order.amount
|
||||
assert whole_order > 0
|
||||
last_step = cast(SAOEMetrics, simulator_state.history_steps.reset_index().iloc[-1].to_dict())
|
||||
pa = last_step["pa"] * last_step["amount"] / whole_order
|
||||
|
||||
# Inspect the "break-down" of the latest step: trading amount at every tick
|
||||
last_step_breakdown = simulator_state.history_exec.loc[last_step["datetime"] :]
|
||||
penalty = -self.penalty * ((last_step_breakdown["amount"] / whole_order) ** 2).sum()
|
||||
|
||||
reward = pa + penalty
|
||||
|
||||
# Throw error in case of NaN
|
||||
assert not (np.isnan(reward) or np.isinf(reward)), f"Invalid reward for simulator state: {simulator_state}"
|
||||
|
||||
self.log("reward/pa", pa)
|
||||
self.log("reward/penalty", penalty)
|
||||
return reward
|
||||
@@ -1,4 +1,424 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""Placeholder for qlib-based simulator."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Callable, cast, Generator, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from qlib.backtest.decision import BaseTradeDecision, Order, OrderHelper, TradeDecisionWO, TradeRange, TradeRangeByTime
|
||||
from qlib.backtest.executor import BaseExecutor, NestedExecutor
|
||||
from qlib.backtest.utils import CommonInfrastructure
|
||||
from qlib.constant import EPS
|
||||
from qlib.rl.data.exchange_wrapper import QlibIntradayBacktestData
|
||||
from qlib.rl.from_neutrader.config import ExchangeConfig
|
||||
from qlib.rl.from_neutrader.feature import init_qlib
|
||||
from qlib.rl.order_execution.simulator_simple import SAOEMetrics, SAOEState
|
||||
from qlib.rl.order_execution.utils import (
|
||||
dataframe_append,
|
||||
get_common_infra,
|
||||
get_portfolio_and_indicator,
|
||||
get_ticks_slice,
|
||||
price_advantage,
|
||||
)
|
||||
from qlib.rl.simulator import Simulator
|
||||
from qlib.strategy.base import BaseStrategy
|
||||
|
||||
|
||||
class DecomposedStrategy(BaseStrategy):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.execute_order: Optional[Order] = None
|
||||
self.execute_result: List[Tuple[Order, float, float, float]] = []
|
||||
|
||||
def generate_trade_decision(self, execute_result: list = None) -> Generator[Any, Any, BaseTradeDecision]:
|
||||
# Once the following line is executed, this DecomposedStrategy (self) will be yielded to the outside
|
||||
# of the entire executor, and the execution will be suspended. When the execution is resumed by `send()`,
|
||||
# the sent item will be captured by `exec_vol`. The outside policy could communicate with the inner
|
||||
# level strategy through this way.
|
||||
exec_vol = yield self
|
||||
|
||||
oh = self.trade_exchange.get_order_helper()
|
||||
order = oh.create(self._order.stock_id, exec_vol, self._order.direction)
|
||||
|
||||
self.execute_order = order
|
||||
|
||||
return TradeDecisionWO([order], self)
|
||||
|
||||
def alter_outer_trade_decision(self, outer_trade_decision: BaseTradeDecision) -> BaseTradeDecision:
|
||||
return outer_trade_decision
|
||||
|
||||
def post_exe_step(self, execute_result: list) -> None:
|
||||
self.execute_result = execute_result
|
||||
|
||||
def reset(self, outer_trade_decision: TradeDecisionWO = None, **kwargs: Any) -> None:
|
||||
super().reset(outer_trade_decision=outer_trade_decision, **kwargs)
|
||||
if outer_trade_decision is not None:
|
||||
order_list = outer_trade_decision.order_list
|
||||
assert len(order_list) == 1
|
||||
self._order = order_list[0]
|
||||
|
||||
|
||||
class SingleOrderStrategy(BaseStrategy):
|
||||
# this logic is copied from FileOrderStrategy
|
||||
def __init__(
|
||||
self,
|
||||
common_infra: CommonInfrastructure,
|
||||
order: Order,
|
||||
trade_range: TradeRange,
|
||||
instrument: str,
|
||||
) -> None:
|
||||
super().__init__(common_infra=common_infra)
|
||||
self._order = order
|
||||
self._trade_range = trade_range
|
||||
self._instrument = instrument
|
||||
|
||||
def alter_outer_trade_decision(self, outer_trade_decision: BaseTradeDecision) -> BaseTradeDecision:
|
||||
return outer_trade_decision
|
||||
|
||||
def generate_trade_decision(self, execute_result: list = None) -> TradeDecisionWO:
|
||||
oh: OrderHelper = self.common_infra.get("trade_exchange").get_order_helper()
|
||||
order_list = [
|
||||
oh.create(
|
||||
code=self._instrument,
|
||||
amount=self._order.amount,
|
||||
direction=self._order.direction,
|
||||
),
|
||||
]
|
||||
return TradeDecisionWO(order_list, self, self._trade_range)
|
||||
|
||||
|
||||
# TODO: move these to the configuration files
|
||||
FINEST_GRANULARITY = "1min"
|
||||
COARSEST_GRANULARITY = "1day"
|
||||
|
||||
|
||||
class StateMaintainer:
|
||||
"""
|
||||
Maintain states of the environment.
|
||||
|
||||
Example usage::
|
||||
|
||||
maintainer = StateMaintainer(...) # in reset
|
||||
maintainer.update(...) # in step
|
||||
# get states in get_state from maintainer
|
||||
"""
|
||||
|
||||
def __init__(self, order: Order, time_per_step: str, tick_index: pd.DatetimeIndex, twap_price: float) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.position = order.amount
|
||||
self._order = order
|
||||
self._time_per_step = time_per_step
|
||||
self._tick_index = tick_index
|
||||
self._twap_price = twap_price
|
||||
|
||||
metric_keys = list(SAOEMetrics.__annotations__.keys()) # pylint: disable=no-member
|
||||
self.history_exec = pd.DataFrame(columns=metric_keys).set_index("datetime")
|
||||
self.history_steps = pd.DataFrame(columns=metric_keys).set_index("datetime")
|
||||
self.metrics: Optional[SAOEMetrics] = None
|
||||
|
||||
def update(
|
||||
self,
|
||||
inner_executor: BaseExecutor,
|
||||
inner_strategy: DecomposedStrategy,
|
||||
done: bool,
|
||||
all_indicators: dict,
|
||||
) -> None:
|
||||
execute_order = inner_strategy.execute_order
|
||||
execute_result = inner_strategy.execute_result
|
||||
exec_vol = np.array([e[0].deal_amount for e in execute_result])
|
||||
num_step = len(execute_result)
|
||||
|
||||
assert execute_order is not None
|
||||
|
||||
if num_step == 0:
|
||||
market_volume = np.array([])
|
||||
market_price = np.array([])
|
||||
datetime_list = pd.DatetimeIndex([])
|
||||
else:
|
||||
market_volume = np.array(
|
||||
inner_executor.trade_exchange.get_volume(
|
||||
execute_order.stock_id,
|
||||
execute_result[0][0].start_time,
|
||||
execute_result[-1][0].start_time,
|
||||
method=None,
|
||||
),
|
||||
)
|
||||
|
||||
trade_value = all_indicators[FINEST_GRANULARITY].iloc[-num_step:]["value"].values
|
||||
deal_amount = all_indicators[FINEST_GRANULARITY].iloc[-num_step:]["deal_amount"].values
|
||||
market_price = trade_value / deal_amount
|
||||
|
||||
datetime_list = all_indicators[FINEST_GRANULARITY].index[-num_step:]
|
||||
|
||||
assert market_price.shape == market_volume.shape == exec_vol.shape
|
||||
|
||||
self.history_exec = dataframe_append(
|
||||
self.history_exec,
|
||||
self._collect_multi_order_metric(
|
||||
order=self._order,
|
||||
datetime=datetime_list,
|
||||
market_vol=market_volume,
|
||||
market_price=market_price,
|
||||
exec_vol=exec_vol,
|
||||
pa=all_indicators[self._time_per_step].iloc[-1]["pa"],
|
||||
),
|
||||
)
|
||||
|
||||
self.history_steps = dataframe_append(
|
||||
self.history_steps,
|
||||
[
|
||||
self._collect_single_order_metric(
|
||||
execute_order,
|
||||
execute_order.start_time,
|
||||
market_volume,
|
||||
market_price,
|
||||
exec_vol.sum(),
|
||||
exec_vol,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
if done:
|
||||
self.metrics = self._collect_single_order_metric(
|
||||
self._order,
|
||||
self._tick_index[0], # start time
|
||||
self.history_exec["market_volume"],
|
||||
self.history_exec["market_price"],
|
||||
self.history_steps["amount"].sum(),
|
||||
self.history_exec["deal_amount"],
|
||||
)
|
||||
|
||||
# TODO: check whether we need this. Can we get this information from Account?
|
||||
# Do this at the end
|
||||
self.position -= exec_vol.sum()
|
||||
|
||||
def _collect_multi_order_metric(
|
||||
self,
|
||||
order: Order,
|
||||
datetime: pd.Timestamp,
|
||||
market_vol: np.ndarray,
|
||||
market_price: np.ndarray,
|
||||
exec_vol: np.ndarray,
|
||||
pa: float,
|
||||
) -> SAOEMetrics:
|
||||
return SAOEMetrics(
|
||||
# It should have the same keys with SAOEMetrics,
|
||||
# but the values do not necessarily have the annotated type.
|
||||
# Some values could be vectorized (e.g., exec_vol).
|
||||
stock_id=order.stock_id,
|
||||
datetime=datetime,
|
||||
direction=order.direction,
|
||||
market_volume=market_vol,
|
||||
market_price=market_price,
|
||||
amount=exec_vol,
|
||||
inner_amount=exec_vol,
|
||||
deal_amount=exec_vol,
|
||||
trade_price=market_price,
|
||||
trade_value=market_price * exec_vol,
|
||||
position=self.position - np.cumsum(exec_vol),
|
||||
ffr=exec_vol / order.amount,
|
||||
pa=pa,
|
||||
)
|
||||
|
||||
def _collect_single_order_metric(
|
||||
self,
|
||||
order: Order,
|
||||
datetime: pd.Timestamp,
|
||||
market_vol: np.ndarray,
|
||||
market_price: np.ndarray,
|
||||
amount: float, # intended to trade such amount
|
||||
exec_vol: np.ndarray,
|
||||
) -> SAOEMetrics:
|
||||
assert len(market_vol) == len(market_price) == len(exec_vol)
|
||||
|
||||
if np.abs(np.sum(exec_vol)) < EPS:
|
||||
exec_avg_price = 0.0
|
||||
else:
|
||||
exec_avg_price = cast(float, np.average(market_price, weights=exec_vol)) # could be nan
|
||||
if hasattr(exec_avg_price, "item"): # could be numpy scalar
|
||||
exec_avg_price = exec_avg_price.item() # type: ignore
|
||||
|
||||
exec_sum = exec_vol.sum()
|
||||
return SAOEMetrics(
|
||||
stock_id=order.stock_id,
|
||||
datetime=datetime,
|
||||
direction=order.direction,
|
||||
market_volume=market_vol.sum(),
|
||||
market_price=market_price.mean() if len(market_price) > 0 else np.nan,
|
||||
amount=amount,
|
||||
inner_amount=exec_sum,
|
||||
deal_amount=exec_sum, # in this simulator, there's no other restrictions
|
||||
trade_price=exec_avg_price,
|
||||
trade_value=float(np.sum(market_price * exec_vol)),
|
||||
position=self.position - exec_sum,
|
||||
ffr=float(exec_sum / order.amount),
|
||||
pa=price_advantage(exec_avg_price, self._twap_price, order.direction),
|
||||
)
|
||||
|
||||
|
||||
class SingleAssetOrderExecutionQlib(Simulator[Order, SAOEState, float]):
|
||||
"""Single-asset order execution (SAOE) simulator which is implemented based on Qlib backtest tools.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
order (Order):
|
||||
The seed to start an SAOE simulator is an order.
|
||||
time_per_step (str):
|
||||
A string to describe the time granularity of each step. Current support "1min", "30min", and "1day"
|
||||
qlib_config (dict):
|
||||
Configuration used to initialize Qlib.
|
||||
inner_executor_fn (Callable[[str, CommonInfrastructure], BaseExecutor]):
|
||||
Function used to get the inner level executor.
|
||||
exchange_config (ExchangeConfig):
|
||||
Configuration used to create the Exchange instance.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
order: Order,
|
||||
time_per_step: str, # "1min", "30min", "1day"
|
||||
qlib_config: dict,
|
||||
inner_executor_fn: Callable[[str, CommonInfrastructure], BaseExecutor],
|
||||
exchange_config: ExchangeConfig,
|
||||
) -> None:
|
||||
assert time_per_step in ("1min", "30min", "1day")
|
||||
|
||||
super().__init__(initial=order)
|
||||
|
||||
assert order.start_time.date() == order.end_time.date(), "Start date and end date must be the same."
|
||||
|
||||
self._order = order
|
||||
self._order_date = pd.Timestamp(order.start_time.date())
|
||||
self._trade_range = TradeRangeByTime(order.start_time.time(), order.end_time.time())
|
||||
self._qlib_config = qlib_config
|
||||
self._inner_executor_fn = inner_executor_fn
|
||||
self._exchange_config = exchange_config
|
||||
|
||||
self._time_per_step = time_per_step
|
||||
self._ticks_per_step = int(pd.Timedelta(time_per_step).total_seconds() // 60)
|
||||
|
||||
self._executor: Optional[NestedExecutor] = None
|
||||
self._collect_data_loop: Optional[Generator] = None
|
||||
|
||||
self._done = False
|
||||
|
||||
self._inner_strategy = DecomposedStrategy()
|
||||
|
||||
self.reset(self._order)
|
||||
|
||||
def reset(self, order: Order) -> None:
|
||||
instrument = order.stock_id
|
||||
|
||||
# TODO: Check this logic. Make sure we need to do this every time we reset the simulator.
|
||||
init_qlib(self._qlib_config, instrument)
|
||||
|
||||
common_infra = get_common_infra(
|
||||
self._exchange_config,
|
||||
trade_date=pd.Timestamp(self._order_date),
|
||||
codes=[instrument],
|
||||
)
|
||||
|
||||
# TODO: We can leverage interfaces like (https://tinyurl.com/y8f8fhv4) to create trading environment.
|
||||
# TODO: By aligning the interface to create environments with Qlib, it will be easier to share the config and
|
||||
# TODO: code between backtesting and training.
|
||||
self._inner_executor = self._inner_executor_fn(self._time_per_step, common_infra)
|
||||
self._executor = NestedExecutor(
|
||||
time_per_step=COARSEST_GRANULARITY,
|
||||
inner_executor=self._inner_executor,
|
||||
inner_strategy=self._inner_strategy,
|
||||
track_data=True,
|
||||
common_infra=common_infra,
|
||||
)
|
||||
|
||||
exchange = self._inner_executor.trade_exchange
|
||||
self._ticks_index = pd.DatetimeIndex([e[1] for e in list(exchange.quote_df.index)])
|
||||
self._ticks_for_order = get_ticks_slice(
|
||||
self._ticks_index,
|
||||
self._order.start_time,
|
||||
self._order.end_time,
|
||||
include_end=True,
|
||||
)
|
||||
|
||||
self._backtest_data = QlibIntradayBacktestData(
|
||||
order=self._order,
|
||||
exchange=exchange,
|
||||
start_time=self._ticks_for_order[0],
|
||||
end_time=self._ticks_for_order[-1],
|
||||
)
|
||||
|
||||
self.twap_price = self._backtest_data.get_deal_price().mean()
|
||||
|
||||
top_strategy = SingleOrderStrategy(common_infra, order, self._trade_range, instrument)
|
||||
self._executor.reset(start_time=pd.Timestamp(self._order_date), end_time=pd.Timestamp(self._order_date))
|
||||
top_strategy.reset(level_infra=self._executor.get_level_infra())
|
||||
|
||||
self._collect_data_loop = self._executor.collect_data(top_strategy.generate_trade_decision(), level=0)
|
||||
assert isinstance(self._collect_data_loop, Generator)
|
||||
|
||||
self._iter_strategy(action=None)
|
||||
self._done = False
|
||||
|
||||
self._maintainer = StateMaintainer(
|
||||
order=self._order,
|
||||
time_per_step=self._time_per_step,
|
||||
tick_index=self._ticks_index,
|
||||
twap_price=self.twap_price,
|
||||
)
|
||||
|
||||
def _iter_strategy(self, action: float = None) -> DecomposedStrategy:
|
||||
"""Iterate the _collect_data_loop until we get the next yield DecomposedStrategy."""
|
||||
assert self._collect_data_loop is not None
|
||||
|
||||
strategy = next(self._collect_data_loop) if action is None else self._collect_data_loop.send(action)
|
||||
while not isinstance(strategy, DecomposedStrategy):
|
||||
strategy = next(self._collect_data_loop) if action is None else self._collect_data_loop.send(action)
|
||||
assert isinstance(strategy, DecomposedStrategy)
|
||||
return strategy
|
||||
|
||||
def step(self, action: float) -> None:
|
||||
"""Execute one step or SAOE.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
action (float):
|
||||
The amount you wish to deal. The simulator doesn't guarantee all the amount to be successfully dealt.
|
||||
"""
|
||||
|
||||
assert not self._done, "Simulator has already done!"
|
||||
|
||||
try:
|
||||
self._iter_strategy(action=action)
|
||||
except StopIteration:
|
||||
self._done = True
|
||||
|
||||
assert self._executor is not None
|
||||
_, all_indicators = get_portfolio_and_indicator(self._executor)
|
||||
|
||||
self._maintainer.update(
|
||||
inner_executor=self._inner_executor,
|
||||
inner_strategy=self._inner_strategy,
|
||||
done=self._done,
|
||||
all_indicators=all_indicators,
|
||||
)
|
||||
|
||||
def get_state(self) -> SAOEState:
|
||||
return SAOEState(
|
||||
order=self._order,
|
||||
cur_time=self._inner_executor.trade_calendar.get_step_time()[0],
|
||||
position=self._maintainer.position,
|
||||
history_exec=self._maintainer.history_exec,
|
||||
history_steps=self._maintainer.history_steps,
|
||||
metrics=self._maintainer.metrics,
|
||||
backtest_data=self._backtest_data,
|
||||
ticks_per_step=self._ticks_per_step,
|
||||
ticks_index=self._ticks_index,
|
||||
ticks_for_order=self._ticks_for_order,
|
||||
)
|
||||
|
||||
def done(self) -> bool:
|
||||
return self._done
|
||||
|
||||
@@ -4,18 +4,20 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import NamedTuple, Any, TypeVar, cast
|
||||
from typing import Any, NamedTuple, Optional, TypeVar, cast
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from qlib.backtest.decision import Order, OrderDir
|
||||
from qlib.constant import EPS
|
||||
from qlib.rl.data.pickle_styled import DealPriceType, IntradayBacktestData, load_simple_intraday_backtest_data
|
||||
from qlib.rl.simulator import Simulator
|
||||
from qlib.rl.data.pickle_styled import IntradayBacktestData, load_intraday_backtest_data, DealPriceType
|
||||
from qlib.rl.utils import LogLevel
|
||||
from qlib.typehint import TypedDict
|
||||
|
||||
# TODO: Integrating Qlib's native data with simulator_simple
|
||||
|
||||
__all__ = ["SAOEMetrics", "SAOEState", "SingleAssetOrderExecution"]
|
||||
|
||||
ONE_SEC = pd.Timedelta("1s") # use 1 second to exclude the right interval point
|
||||
@@ -33,40 +35,40 @@ class SAOEMetrics(TypedDict):
|
||||
|
||||
stock_id: str
|
||||
"""Stock ID of this record."""
|
||||
datetime: pd.Timestamp
|
||||
datetime: pd.Timestamp | pd.DatetimeIndex # TODO: check this
|
||||
"""Datetime of this record (this is index in the dataframe)."""
|
||||
direction: int
|
||||
"""Direction of the order. 0 for sell, 1 for buy."""
|
||||
|
||||
# Market information.
|
||||
market_volume: float
|
||||
market_volume: np.ndarray | float
|
||||
"""(total) market volume traded in the period."""
|
||||
market_price: float
|
||||
market_price: np.ndarray | float
|
||||
"""Deal price. If it's a period of time, this is the average market deal price."""
|
||||
|
||||
# Strategy records.
|
||||
|
||||
amount: float
|
||||
amount: np.ndarray | float
|
||||
"""Total amount (volume) strategy intends to trade."""
|
||||
inner_amount: float
|
||||
inner_amount: np.ndarray | float
|
||||
"""Total amount that the lower-level strategy intends to trade
|
||||
(might be larger than amount, e.g., to ensure ffr)."""
|
||||
|
||||
deal_amount: float
|
||||
deal_amount: np.ndarray | float
|
||||
"""Amount that successfully takes effect (must be less than inner_amount)."""
|
||||
trade_price: float
|
||||
trade_price: np.ndarray | float
|
||||
"""The average deal price for this strategy."""
|
||||
trade_value: float
|
||||
"""Total worth of trading. In the simple simulaton, trade_value = deal_amount * price."""
|
||||
position: float
|
||||
trade_value: np.ndarray | float
|
||||
"""Total worth of trading. In the simple simulation, trade_value = deal_amount * price."""
|
||||
position: np.ndarray | float
|
||||
"""Position left after this "period"."""
|
||||
|
||||
# Accumulated metrics
|
||||
|
||||
ffr: float
|
||||
ffr: np.ndarray | float
|
||||
"""Completed how much percent of the daily order."""
|
||||
|
||||
pa: float
|
||||
pa: np.ndarray | float
|
||||
"""Price advantage compared to baseline (i.e., trade with baseline market price).
|
||||
The baseline is trade price when using TWAP strategy to execute this order.
|
||||
Please note that there could be data leak here).
|
||||
@@ -87,7 +89,7 @@ class SAOEState(NamedTuple):
|
||||
history_steps: pd.DataFrame
|
||||
"""See :attr:`SingleAssetOrderExecution.history_steps`."""
|
||||
|
||||
metrics: SAOEMetrics | None
|
||||
metrics: Optional[SAOEMetrics]
|
||||
"""Daily metric, only available when the trading is in "done" state."""
|
||||
|
||||
backtest_data: IntradayBacktestData
|
||||
@@ -114,13 +116,13 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
|
||||
If such fine granularity is not needed, use ``ticks_per_step`` to
|
||||
lengthen the ticks for each step.
|
||||
|
||||
In each step, the traded amount are "equally" splitted to each tick,
|
||||
then bounded by volume maximum exeuction volume (i.e., ``vol_threshold``),
|
||||
In each step, the traded amount are "equally" separated to each tick,
|
||||
then bounded by volume maximum execution volume (i.e., ``vol_threshold``),
|
||||
and if it's the last step, try to ensure all the amount to be executed.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
initial
|
||||
order
|
||||
The seed to start an SAOE simulator is an order.
|
||||
ticks_per_step
|
||||
How many ticks per step.
|
||||
@@ -131,13 +133,16 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
|
||||
"""
|
||||
|
||||
history_exec: pd.DataFrame
|
||||
"""All execution history at every possible time ticks. See :class:`SAOEMetrics` for available columns."""
|
||||
"""All execution history at every possible time ticks. See :class:`SAOEMetrics` for available columns.
|
||||
Index is ``datetime``.
|
||||
"""
|
||||
|
||||
history_steps: pd.DataFrame
|
||||
"""Positions at each step. The position before first step is also recorded.
|
||||
See :class:`SAOEMetrics` for available columns."""
|
||||
See :class:`SAOEMetrics` for available columns.
|
||||
Index is ``datetime``, which is the **starting** time of each step."""
|
||||
|
||||
metrics: SAOEMetrics | None
|
||||
metrics: Optional[SAOEMetrics]
|
||||
"""Metrics. Only available when done."""
|
||||
|
||||
twap_price: float
|
||||
@@ -156,15 +161,21 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
|
||||
data_dir: Path,
|
||||
ticks_per_step: int = 30,
|
||||
deal_price_type: DealPriceType = "close",
|
||||
vol_threshold: float | None = None,
|
||||
vol_threshold: Optional[float] = None,
|
||||
) -> None:
|
||||
super().__init__(initial=order)
|
||||
|
||||
self.order = order
|
||||
self.ticks_per_step: int = ticks_per_step
|
||||
self.deal_price_type = deal_price_type
|
||||
self.vol_threshold = vol_threshold
|
||||
self.data_dir = data_dir
|
||||
self.backtest_data = load_intraday_backtest_data(
|
||||
self.data_dir, order.stock_id, pd.Timestamp(order.start_time.date()), self.deal_price_type, order.direction
|
||||
self.backtest_data = load_simple_intraday_backtest_data(
|
||||
self.data_dir,
|
||||
order.stock_id,
|
||||
pd.Timestamp(order.start_time.date()),
|
||||
self.deal_price_type,
|
||||
order.direction,
|
||||
)
|
||||
|
||||
self.ticks_index = self.backtest_data.get_time_index()
|
||||
@@ -185,9 +196,9 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
|
||||
self.history_steps = pd.DataFrame(columns=metric_keys).set_index("datetime")
|
||||
self.metrics = None
|
||||
|
||||
self.market_price: np.ndarray | None = None
|
||||
self.market_vol: np.ndarray | None = None
|
||||
self.market_vol_limit: np.ndarray | None = None
|
||||
self.market_price: Optional[np.ndarray] = None
|
||||
self.market_vol: Optional[np.ndarray] = None
|
||||
self.market_vol_limit: Optional[np.ndarray] = None
|
||||
|
||||
def step(self, amount: float) -> None:
|
||||
"""Execute one step or SAOE.
|
||||
@@ -202,7 +213,8 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
|
||||
|
||||
self.market_price = self.market_vol = None # avoid misuse
|
||||
exec_vol = self._split_exec_vol(amount)
|
||||
assert self.market_price is not None and self.market_vol is not None
|
||||
assert self.market_price is not None
|
||||
assert self.market_vol is not None
|
||||
|
||||
ticks_position = self.position - np.cumsum(exec_vol)
|
||||
|
||||
@@ -360,7 +372,7 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
|
||||
inner_amount=exec_vol.sum(),
|
||||
deal_amount=exec_vol.sum(), # in this simulator, there's no other restrictions
|
||||
trade_price=exec_avg_price,
|
||||
trade_value=np.sum(market_price * exec_vol),
|
||||
trade_value=float(np.sum(market_price * exec_vol)),
|
||||
position=self.position,
|
||||
ffr=float(exec_vol.sum() / self.order.amount),
|
||||
pa=price_advantage(exec_avg_price, self.twap_price, self.order.direction),
|
||||
@@ -383,7 +395,9 @@ _float_or_ndarray = TypeVar("_float_or_ndarray", float, np.ndarray)
|
||||
|
||||
|
||||
def price_advantage(
|
||||
exec_price: _float_or_ndarray, baseline_price: float, direction: OrderDir | int
|
||||
exec_price: _float_or_ndarray,
|
||||
baseline_price: float,
|
||||
direction: OrderDir | int,
|
||||
) -> _float_or_ndarray:
|
||||
if baseline_price == 0: # something is wrong with data. Should be nan here
|
||||
if isinstance(exec_price, float):
|
||||
|
||||
111
qlib/rl/order_execution/utils.py
Normal file
111
qlib/rl/order_execution/utils.py
Normal file
@@ -0,0 +1,111 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, List, Tuple, cast
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from qlib.backtest import CommonInfrastructure, get_exchange
|
||||
from qlib.backtest.account import Account
|
||||
from qlib.backtest.decision import OrderDir
|
||||
from qlib.backtest.executor import BaseExecutor
|
||||
from qlib.rl.from_neutrader.config import ExchangeConfig
|
||||
from qlib.rl.order_execution.simulator_simple import ONE_SEC, _float_or_ndarray
|
||||
from qlib.utils.time import Freq
|
||||
|
||||
|
||||
def get_common_infra(
|
||||
config: ExchangeConfig,
|
||||
trade_date: pd.Timestamp,
|
||||
codes: List[str],
|
||||
cash_limit: float = None,
|
||||
) -> CommonInfrastructure:
|
||||
# need to specify a range here for acceleration
|
||||
if cash_limit is None:
|
||||
trade_account = Account(init_cash=int(1e12), benchmark_config={}, pos_type="InfPosition")
|
||||
else:
|
||||
trade_account = Account(
|
||||
init_cash=cash_limit,
|
||||
benchmark_config={},
|
||||
pos_type="Position",
|
||||
position_dict={code: {"amount": 1e12, "price": 1.0} for code in codes},
|
||||
)
|
||||
|
||||
exchange = get_exchange(
|
||||
codes=codes,
|
||||
freq="1min",
|
||||
limit_threshold=config.limit_threshold,
|
||||
deal_price=config.deal_price,
|
||||
open_cost=config.open_cost,
|
||||
close_cost=config.close_cost,
|
||||
min_cost=config.min_cost if config.trade_unit is not None else 0,
|
||||
start_time=trade_date,
|
||||
end_time=trade_date + pd.DateOffset(1),
|
||||
trade_unit=config.trade_unit,
|
||||
volume_threshold=config.volume_threshold,
|
||||
)
|
||||
|
||||
return CommonInfrastructure(trade_account=trade_account, trade_exchange=exchange)
|
||||
|
||||
|
||||
def get_ticks_slice(
|
||||
ticks_index: pd.DatetimeIndex,
|
||||
start: pd.Timestamp,
|
||||
end: pd.Timestamp,
|
||||
include_end: bool = False,
|
||||
) -> pd.DatetimeIndex:
|
||||
if not include_end:
|
||||
end = end - ONE_SEC
|
||||
return ticks_index[ticks_index.slice_indexer(start, end)]
|
||||
|
||||
|
||||
def dataframe_append(df: pd.DataFrame, other: Any) -> pd.DataFrame:
|
||||
# dataframe.append is deprecated
|
||||
other_df = pd.DataFrame(other).set_index("datetime")
|
||||
other_df.index.name = "datetime"
|
||||
|
||||
res = pd.concat([df, other_df], axis=0)
|
||||
return res
|
||||
|
||||
|
||||
def price_advantage(
|
||||
exec_price: _float_or_ndarray,
|
||||
baseline_price: float,
|
||||
direction: OrderDir | int,
|
||||
) -> _float_or_ndarray:
|
||||
if baseline_price == 0: # something is wrong with data. Should be nan here
|
||||
if isinstance(exec_price, float):
|
||||
return 0.0
|
||||
else:
|
||||
return np.zeros_like(exec_price)
|
||||
if direction == OrderDir.BUY:
|
||||
res = (1 - exec_price / baseline_price) * 10000
|
||||
elif direction == OrderDir.SELL:
|
||||
res = (exec_price / baseline_price - 1) * 10000
|
||||
else:
|
||||
raise ValueError(f"Unexpected order direction: {direction}")
|
||||
res_wo_nan: np.ndarray = np.nan_to_num(res, nan=0.0)
|
||||
if res_wo_nan.size == 1:
|
||||
return res_wo_nan.item()
|
||||
else:
|
||||
return cast(_float_or_ndarray, res_wo_nan)
|
||||
|
||||
|
||||
def get_portfolio_and_indicator(executor: BaseExecutor) -> Tuple[dict, dict]:
|
||||
all_executors = executor.get_all_executors()
|
||||
all_portfolio_metrics = {
|
||||
"{}{}".format(*Freq.parse(_executor.time_per_step)): _executor.trade_account.get_portfolio_metrics()
|
||||
for _executor in all_executors
|
||||
if _executor.trade_account.is_port_metr_enabled()
|
||||
}
|
||||
|
||||
all_indicators = {}
|
||||
for _executor in all_executors:
|
||||
key = "{}{}".format(*Freq.parse(_executor.time_per_step))
|
||||
all_indicators[key] = _executor.trade_account.get_trade_indicator().generate_trade_indicators_dataframe()
|
||||
all_indicators[key + "_obj"] = _executor.trade_account.get_trade_indicator()
|
||||
|
||||
return all_portfolio_metrics, all_indicators
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Generic, Any, TypeVar, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any, Dict, Generic, Optional, Tuple, TypeVar
|
||||
|
||||
from qlib.typehint import final
|
||||
|
||||
@@ -20,7 +20,7 @@ class Reward(Generic[SimulatorState]):
|
||||
Subclass should implement ``reward(simulator_state)`` to implement their own reward calculation recipe.
|
||||
"""
|
||||
|
||||
env: EnvWrapper | None = None
|
||||
env: Optional[EnvWrapper] = None
|
||||
|
||||
@final
|
||||
def __call__(self, simulator_state: SimulatorState) -> float:
|
||||
@@ -30,14 +30,15 @@ class Reward(Generic[SimulatorState]):
|
||||
"""Implement this method for your own reward."""
|
||||
raise NotImplementedError("Implement reward calculation recipe in `reward()`.")
|
||||
|
||||
def log(self, name, value):
|
||||
def log(self, name: str, value: Any) -> None:
|
||||
assert self.env is not None
|
||||
self.env.logger.add_scalar(name, value)
|
||||
|
||||
|
||||
class RewardCombination(Reward):
|
||||
"""Combination of multiple reward."""
|
||||
|
||||
def __init__(self, rewards: dict[str, tuple[Reward, float]]):
|
||||
def __init__(self, rewards: Dict[str, Tuple[Reward, float]]) -> None:
|
||||
self.rewards = rewards
|
||||
|
||||
def reward(self, simulator_state: Any) -> float:
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TypeVar, Generic, Any, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar
|
||||
|
||||
from .seed import InitialStateType
|
||||
|
||||
@@ -49,7 +49,7 @@ class Simulator(Generic[InitialStateType, StateType, ActType]):
|
||||
Simulators are discouraged to use this, because it's prone to induce errors.
|
||||
"""
|
||||
|
||||
env: EnvWrapper | None = None
|
||||
env: Optional[EnvWrapper] = None
|
||||
|
||||
def __init__(self, initial: InitialStateType, **kwargs: Any) -> None:
|
||||
pass
|
||||
|
||||
9
qlib/rl/trainer/__init__.py
Normal file
9
qlib/rl/trainer/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""Train, test, inference utilities."""
|
||||
|
||||
from .api import backtest, train
|
||||
from .callbacks import EarlyStopping, Checkpoint
|
||||
from .trainer import Trainer
|
||||
from .vessel import TrainingVessel, TrainingVesselBase
|
||||
118
qlib/rl/trainer/api.py
Normal file
118
qlib/rl/trainer/api.py
Normal file
@@ -0,0 +1,118 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Callable, Sequence, cast
|
||||
|
||||
from tianshou.policy import BasePolicy
|
||||
|
||||
from qlib.rl.interpreter import ActionInterpreter, StateInterpreter
|
||||
from qlib.rl.reward import Reward
|
||||
from qlib.rl.simulator import InitialStateType, Simulator
|
||||
from qlib.rl.utils import FiniteEnvType, LogWriter
|
||||
|
||||
from .trainer import Trainer
|
||||
from .vessel import TrainingVessel
|
||||
|
||||
|
||||
def train(
|
||||
simulator_fn: Callable[[InitialStateType], Simulator],
|
||||
state_interpreter: StateInterpreter,
|
||||
action_interpreter: ActionInterpreter,
|
||||
initial_states: Sequence[InitialStateType],
|
||||
policy: BasePolicy,
|
||||
reward: Reward,
|
||||
vessel_kwargs: dict[str, Any],
|
||||
trainer_kwargs: dict[str, Any],
|
||||
) -> None:
|
||||
"""Train a policy with the parallelism provided by RL framework.
|
||||
|
||||
Experimental API. Parameters might change shortly.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
simulator_fn
|
||||
Callable receiving initial seed, returning a simulator.
|
||||
state_interpreter
|
||||
Interprets the state of simulators.
|
||||
action_interpreter
|
||||
Interprets the policy actions.
|
||||
initial_states
|
||||
Initial states to iterate over. Every state will be run exactly once.
|
||||
policy
|
||||
Policy to train against.
|
||||
reward
|
||||
Reward function.
|
||||
vessel_kwargs
|
||||
Keyword arguments passed to :class:`TrainingVessel`, like ``episode_per_iter``.
|
||||
trainer_kwargs
|
||||
Keyword arguments passed to :class:`Trainer`, like ``finite_env_type``, ``concurrency``.
|
||||
"""
|
||||
|
||||
vessel = TrainingVessel(
|
||||
simulator_fn=simulator_fn,
|
||||
state_interpreter=state_interpreter,
|
||||
action_interpreter=action_interpreter,
|
||||
policy=policy,
|
||||
train_initial_states=initial_states,
|
||||
reward=reward, # ignore none
|
||||
**vessel_kwargs,
|
||||
)
|
||||
trainer = Trainer(**trainer_kwargs)
|
||||
trainer.fit(vessel)
|
||||
|
||||
|
||||
def backtest(
|
||||
simulator_fn: Callable[[InitialStateType], Simulator],
|
||||
state_interpreter: StateInterpreter,
|
||||
action_interpreter: ActionInterpreter,
|
||||
initial_states: Sequence[InitialStateType],
|
||||
policy: BasePolicy,
|
||||
logger: LogWriter | list[LogWriter],
|
||||
reward: Reward | None = None,
|
||||
finite_env_type: FiniteEnvType = "subproc",
|
||||
concurrency: int = 2,
|
||||
) -> None:
|
||||
"""Backtest with the parallelism provided by RL framework.
|
||||
|
||||
Experimental API. Parameters might change shortly.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
simulator_fn
|
||||
Callable receiving initial seed, returning a simulator.
|
||||
state_interpreter
|
||||
Interprets the state of simulators.
|
||||
action_interpreter
|
||||
Interprets the policy actions.
|
||||
initial_states
|
||||
Initial states to iterate over. Every state will be run exactly once.
|
||||
policy
|
||||
Policy to test against.
|
||||
logger
|
||||
Logger to record the backtest results. Logger must be present because
|
||||
without logger, all information will be lost.
|
||||
reward
|
||||
Optional reward function. For backtest, this is for testing the rewards
|
||||
and logging them only.
|
||||
finite_env_type
|
||||
Type of finite env implementation.
|
||||
concurrency
|
||||
Parallel workers.
|
||||
"""
|
||||
|
||||
vessel = TrainingVessel(
|
||||
simulator_fn=simulator_fn,
|
||||
state_interpreter=state_interpreter,
|
||||
action_interpreter=action_interpreter,
|
||||
policy=policy,
|
||||
test_initial_states=initial_states,
|
||||
reward=cast(Reward, reward), # ignore none
|
||||
)
|
||||
trainer = Trainer(
|
||||
finite_env_type=finite_env_type,
|
||||
concurrency=concurrency,
|
||||
loggers=logger,
|
||||
)
|
||||
trainer.test(vessel)
|
||||
267
qlib/rl/trainer/callbacks.py
Normal file
267
qlib/rl/trainer/callbacks.py
Normal file
@@ -0,0 +1,267 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""Callbacks to insert customized recipes during the training.
|
||||
Mimicks the hooks of Keras / PyTorch-Lightning, but tailored for the context of RL.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import shutil
|
||||
import time
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.typehint import Literal
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .trainer import Trainer
|
||||
from .vessel import TrainingVesselBase
|
||||
|
||||
|
||||
_logger = get_module_logger(__name__)
|
||||
|
||||
|
||||
class Callback:
|
||||
"""Base class of all callbacks."""
|
||||
|
||||
def on_fit_start(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
|
||||
"""Called before the whole fit process begins."""
|
||||
|
||||
def on_fit_end(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
|
||||
"""Called after the whole fit process ends."""
|
||||
|
||||
def on_train_start(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
|
||||
"""Called when each collect for training begins."""
|
||||
|
||||
def on_train_end(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
|
||||
"""Called when the training ends.
|
||||
To access all outputs produced during training, cache the data in either trainer and vessel,
|
||||
and post-process them in this hook.
|
||||
"""
|
||||
|
||||
def on_validate_start(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
|
||||
"""Called when every run for validation begins."""
|
||||
|
||||
def on_validate_end(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
|
||||
"""Called when the validation ends."""
|
||||
|
||||
def on_test_start(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
|
||||
"""Called when every run of testing begins."""
|
||||
|
||||
def on_test_end(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
|
||||
"""Called when the testing ends."""
|
||||
|
||||
def on_iter_start(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
|
||||
"""Called when every iteration (i.e., collect) starts."""
|
||||
|
||||
def on_iter_end(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
|
||||
"""Called upon every end of iteration.
|
||||
This is called **after** the bump of ``current_iter``,
|
||||
when the previous iteration is considered complete.
|
||||
"""
|
||||
|
||||
def state_dict(self) -> Any:
|
||||
"""Get a state dict of the callback for pause and resume."""
|
||||
|
||||
def load_state_dict(self, state_dict: Any) -> None:
|
||||
"""Resume the callback from a saved state dict."""
|
||||
|
||||
|
||||
class EarlyStopping(Callback):
|
||||
"""Stop training when a monitored metric has stopped improving.
|
||||
|
||||
The earlystopping callback will be triggered each time validation ends.
|
||||
It will examine the metrics produced in validation,
|
||||
and get the metric with name ``monitor` (``monitor`` is ``reward`` by default),
|
||||
to check whether it's no longer increasing / decreasing.
|
||||
It takes ``min_delta`` and ``patience`` if applicable.
|
||||
If it's found to be not increasing / decreasing any more.
|
||||
``trainer.should_stop`` will be set to true,
|
||||
and the training terminates.
|
||||
|
||||
Implementation reference: https://github.com/keras-team/keras/blob/v2.9.0/keras/callbacks.py#L1744-L1893
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
monitor: str = "reward",
|
||||
min_delta: float = 0.0,
|
||||
patience: int = 0,
|
||||
mode: Literal["min", "max"] = "max",
|
||||
baseline: float | None = None,
|
||||
restore_best_weights: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.monitor = monitor
|
||||
self.patience = patience
|
||||
self.baseline = baseline
|
||||
self.min_delta = abs(min_delta)
|
||||
self.restore_best_weights = restore_best_weights
|
||||
self.best_weights: Any | None = None
|
||||
|
||||
if mode not in ["min", "max"]:
|
||||
raise ValueError("Unsupported earlystopping mode: " + mode)
|
||||
|
||||
if mode == "min":
|
||||
self.monitor_op = np.less
|
||||
elif mode == "max":
|
||||
self.monitor_op = np.greater
|
||||
|
||||
if self.monitor_op == np.greater:
|
||||
self.min_delta *= 1
|
||||
else:
|
||||
self.min_delta *= -1
|
||||
|
||||
def state_dict(self) -> dict:
|
||||
return {"wait": self.wait, "best": self.best, "best_weights": self.best_weights, "best_iter": self.best_iter}
|
||||
|
||||
def load_state_dict(self, state_dict: dict) -> None:
|
||||
self.wait = state_dict["wait"]
|
||||
self.best = state_dict["best"]
|
||||
self.best_weights = state_dict["best_weights"]
|
||||
self.best_iter = state_dict["best_iter"]
|
||||
|
||||
def on_fit_start(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
|
||||
# Allow instances to be re-used
|
||||
self.wait = 0
|
||||
self.best = np.inf if self.monitor_op == np.less else -np.inf
|
||||
self.best_weights = None
|
||||
self.best_iter = 0
|
||||
|
||||
def on_validate_end(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
|
||||
current = self.get_monitor_value(trainer)
|
||||
if current is None:
|
||||
return
|
||||
if self.restore_best_weights and self.best_weights is None:
|
||||
# Restore the weights after first iteration if no progress is ever made.
|
||||
self.best_weights = copy.deepcopy(vessel.state_dict())
|
||||
|
||||
self.wait += 1
|
||||
if self._is_improvement(current, self.best):
|
||||
self.best = current
|
||||
self.best_iter = trainer.current_iter
|
||||
if self.restore_best_weights:
|
||||
self.best_weights = copy.deepcopy(vessel.state_dict())
|
||||
# Only restart wait if we beat both the baseline and our previous best.
|
||||
if self.baseline is None or self._is_improvement(current, self.baseline):
|
||||
self.wait = 0
|
||||
|
||||
# Only check after the first epoch.
|
||||
if self.wait >= self.patience and trainer.current_iter > 0:
|
||||
trainer.should_stop = True
|
||||
_logger.info(f"On iteration %d: early stopping", trainer.current_iter + 1)
|
||||
if self.restore_best_weights and self.best_weights is not None:
|
||||
_logger.info("Restoring model weights from the end of the best iteration: %d", self.best_iter + 1)
|
||||
vessel.load_state_dict(self.best_weights)
|
||||
|
||||
def get_monitor_value(self, trainer: Trainer) -> Any:
|
||||
monitor_value = trainer.metrics.get(self.monitor)
|
||||
if monitor_value is None:
|
||||
_logger.warning(
|
||||
"Early stopping conditioned on metric `%s` which is not available. Available metrics are: %s",
|
||||
self.monitor,
|
||||
",".join(list(trainer.metrics.keys())),
|
||||
)
|
||||
return monitor_value
|
||||
|
||||
def _is_improvement(self, monitor_value, reference_value):
|
||||
return self.monitor_op(monitor_value - self.min_delta, reference_value)
|
||||
|
||||
|
||||
class Checkpoint(Callback):
|
||||
"""Save checkpoints periodically for persistence and recovery.
|
||||
|
||||
Reference: https://github.com/PyTorchLightning/pytorch-lightning/blob/bfa8b7be/pytorch_lightning/callbacks/model_checkpoint.py
|
||||
|
||||
Parameters
|
||||
----------
|
||||
dirpath
|
||||
Directory to save the checkpoint file.
|
||||
filename
|
||||
Checkpoint filename. Can contain named formatting options to be auto-filled.
|
||||
For example: ``{iter:03d}-{reward:.2f}.pth``.
|
||||
Supported argument names are:
|
||||
|
||||
- iter (int)
|
||||
- metrics in ``trainer.metrics``
|
||||
- time string, in the format of ``%Y%m%d%H%M%S``
|
||||
save_latest
|
||||
Save the latest checkpoint in ``latest.pth``.
|
||||
If ``link``, ``latest.pth`` will be created as a softlink.
|
||||
If ``copy``, ``latest.pth`` will be stored as an individual copy.
|
||||
Set to none to disable this.
|
||||
every_n_iters
|
||||
Checkpoints are saved at the end of every n iterations of training,
|
||||
after validation if applicable.
|
||||
time_interval
|
||||
Maximum time (seconds) before checkpoints save again.
|
||||
save_on_fit_end
|
||||
Save one last checkpoint at the end to fit.
|
||||
Do nothing if a checkpoint is already saved there.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dirpath: Path,
|
||||
filename: str = "{iter:03d}.pth",
|
||||
save_latest: Literal["link", "copy"] | None = "link",
|
||||
every_n_iters: int | None = None,
|
||||
time_interval: int | None = None,
|
||||
save_on_fit_end: bool = True,
|
||||
):
|
||||
self.dirpath = Path(dirpath)
|
||||
self.filename = filename
|
||||
self.save_latest = save_latest
|
||||
self.every_n_iters = every_n_iters
|
||||
self.time_interval = time_interval
|
||||
self.save_on_fit_end = save_on_fit_end
|
||||
|
||||
self._last_checkpoint_name: str | None = None
|
||||
self._last_checkpoint_iter: int | None = None
|
||||
self._last_checkpoint_time: float | None = None
|
||||
|
||||
def on_fit_end(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
|
||||
if self.save_on_fit_end and (trainer.current_iter != self._last_checkpoint_iter):
|
||||
self._save_checkpoint(trainer)
|
||||
|
||||
def on_iter_end(self, trainer: Trainer, vessel: TrainingVesselBase) -> None:
|
||||
should_save_ckpt = False
|
||||
if self.every_n_iters is not None and (trainer.current_iter + 1) % self.every_n_iters == 0:
|
||||
should_save_ckpt = True
|
||||
if self.time_interval is not None and (
|
||||
self._last_checkpoint_time is None or (time.time() - self._last_checkpoint_time) >= self.time_interval
|
||||
):
|
||||
should_save_ckpt = True
|
||||
if should_save_ckpt:
|
||||
self._save_checkpoint(trainer)
|
||||
|
||||
def _save_checkpoint(self, trainer: Trainer) -> None:
|
||||
self.dirpath.mkdir(exist_ok=True, parents=True)
|
||||
self._last_checkpoint_name = self._new_checkpoint_name(trainer)
|
||||
self._last_checkpoint_iter = trainer.current_iter
|
||||
self._last_checkpoint_time = time.time()
|
||||
torch.save(trainer.state_dict(), self.dirpath / self._last_checkpoint_name)
|
||||
|
||||
latest_pth = self.dirpath / "latest.pth"
|
||||
|
||||
# Remove first before saving
|
||||
if self.save_latest and latest_pth.exists():
|
||||
latest_pth.unlink()
|
||||
|
||||
if self.save_latest == "link":
|
||||
latest_pth.symlink_to(self.dirpath / self._last_checkpoint_name)
|
||||
elif self.save_latest == "copy":
|
||||
shutil.copyfile(self.dirpath / self._last_checkpoint_name, latest_pth)
|
||||
|
||||
def _new_checkpoint_name(self, trainer: Trainer) -> str:
|
||||
return self.filename.format(
|
||||
iter=trainer.current_iter, time=datetime.now().strftime("%Y%m%d%H%M%S"), **trainer.metrics
|
||||
)
|
||||
343
qlib/rl/trainer/trainer.py
Normal file
343
qlib/rl/trainer/trainer.py
Normal file
@@ -0,0 +1,343 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from contextlib import AbstractContextManager, contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Any, Iterable, Sequence, TypeVar, cast
|
||||
|
||||
import torch
|
||||
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.rl.simulator import InitialStateType
|
||||
from qlib.rl.utils import EnvWrapper, FiniteEnvType, LogBuffer, LogCollector, LogLevel, LogWriter, vectorize_env
|
||||
from qlib.rl.utils.finite_env import FiniteVectorEnv
|
||||
from qlib.typehint import Literal
|
||||
|
||||
from .callbacks import Callback
|
||||
from .vessel import TrainingVesselBase
|
||||
|
||||
_logger = get_module_logger(__name__)
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class Trainer:
|
||||
"""
|
||||
Utility to train a policy on a particular task.
|
||||
|
||||
Different from traditional DL trainer, the iteration of this trainer is "collect",
|
||||
rather than "epoch", or "mini-batch".
|
||||
In each collect, :class:`Collector` collects a number of policy-env interactions, and accumulates
|
||||
them into a replay buffer. This buffer is used as the "data" to train the policy.
|
||||
At the end of each collect, the policy is *updated* several times.
|
||||
|
||||
The API has some resemblence with `PyTorch Lightning <https://pytorch-lightning.readthedocs.io/>`__,
|
||||
but it's essentially different because this trainer is built for RL applications, and thus
|
||||
most configurations are under RL context.
|
||||
We are still looking for ways to incorporate existing trainer libraries, because it looks like
|
||||
big efforts to build a trainer as powerful as those libraries, and also, that's not our primary goal.
|
||||
|
||||
It's essentially different
|
||||
`tianshou's built-in trainers <https://tianshou.readthedocs.io/en/master/api/tianshou.trainer.html>`__,
|
||||
as it's far much more complicated than that.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
max_iters
|
||||
Maximum iterations before stopping.
|
||||
val_every_n_iters
|
||||
Perform validation every n iterations (i.e., training collects).
|
||||
logger
|
||||
Logger to record the backtest results. Logger must be present because
|
||||
without logger, all information will be lost.
|
||||
finite_env_type
|
||||
Type of finite env implementation.
|
||||
concurrency
|
||||
Parallel workers.
|
||||
fast_dev_run
|
||||
Create a subset for debugging.
|
||||
How this is implemented depends on the implementation of training vessel.
|
||||
For :class:`~qlib.rl.vessel.TrainingVessel`, if greater than zero,
|
||||
a random subset sized ``fast_dev_run`` will be used
|
||||
instead of ``train_initial_states`` and ``val_initial_states``.
|
||||
"""
|
||||
|
||||
should_stop: bool
|
||||
"""Set to stop the training."""
|
||||
|
||||
metrics: dict
|
||||
"""Numeric metrics of produced in train/val/test.
|
||||
In the middle of training / validation, metrics will be of the latest episode.
|
||||
When each iteration of training / validation finishes, metrics will be the aggregation
|
||||
of all episodes encountered in this iteration.
|
||||
|
||||
Cleared on every new iteration of training.
|
||||
|
||||
In fit, validation metrics will be prefixed with ``val/``.
|
||||
"""
|
||||
|
||||
current_iter: int
|
||||
"""Current iteration (collect) of training."""
|
||||
|
||||
loggers: list[LogWriter]
|
||||
"""A list of log writers."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
max_iters: int | None = None,
|
||||
val_every_n_iters: int | None = None,
|
||||
loggers: LogWriter | list[LogWriter] | None = None,
|
||||
callbacks: list[Callback] | None = None,
|
||||
finite_env_type: FiniteEnvType = "subproc",
|
||||
concurrency: int = 2,
|
||||
fast_dev_run: int | None = None,
|
||||
):
|
||||
self.max_iters = max_iters
|
||||
self.val_every_n_iters = val_every_n_iters
|
||||
|
||||
if isinstance(loggers, list):
|
||||
self.loggers = loggers
|
||||
elif isinstance(loggers, LogWriter):
|
||||
self.loggers = [loggers]
|
||||
else:
|
||||
self.loggers = []
|
||||
|
||||
self.loggers.append(LogBuffer(self._metrics_callback, loglevel=self._min_loglevel()))
|
||||
|
||||
self.callbacks: list[Callback] = callbacks if callbacks is not None else []
|
||||
self.finite_env_type = finite_env_type
|
||||
self.concurrency = concurrency
|
||||
self.fast_dev_run = fast_dev_run
|
||||
|
||||
self.current_stage: Literal["train", "val", "test"] = "train"
|
||||
|
||||
self.vessel: TrainingVesselBase = cast(TrainingVesselBase, None)
|
||||
|
||||
def initialize(self):
|
||||
"""Initialize the whole training process.
|
||||
|
||||
The states here should be synchronized with state_dict.
|
||||
"""
|
||||
self.should_stop = False
|
||||
self.current_iter = 0
|
||||
self.current_episode = 0
|
||||
self.current_stage = "train"
|
||||
|
||||
def initialize_iter(self):
|
||||
"""Initialize one iteration / collect."""
|
||||
self.metrics = {}
|
||||
|
||||
def state_dict(self) -> dict:
|
||||
"""Putting every states of current training into a dict, at best effort.
|
||||
|
||||
It doesn't try to handle all the possible kinds of states in the middle of one training collect.
|
||||
For most cases at the end of each iteration, things should be usually correct.
|
||||
|
||||
Note that it's also intended behavior that replay buffer data in the collector will be lost.
|
||||
"""
|
||||
return {
|
||||
"vessel": self.vessel.state_dict(),
|
||||
"callbacks": {name: callback.state_dict() for name, callback in self.named_callbacks().items()},
|
||||
"loggers": {name: logger.state_dict() for name, logger in self.named_loggers().items()},
|
||||
"should_stop": self.should_stop,
|
||||
"current_iter": self.current_iter,
|
||||
"current_episode": self.current_episode,
|
||||
"current_stage": self.current_stage,
|
||||
"metrics": self.metrics,
|
||||
}
|
||||
|
||||
def load_state_dict(self, state_dict: dict) -> None:
|
||||
"""Load all states into current trainer."""
|
||||
self.vessel.load_state_dict(state_dict["vessel"])
|
||||
for name, callback in self.named_callbacks().items():
|
||||
callback.load_state_dict(state_dict["callbacks"][name])
|
||||
for name, logger in self.named_loggers().items():
|
||||
logger.load_state_dict(state_dict["loggers"][name])
|
||||
self.should_stop = state_dict["should_stop"]
|
||||
self.current_iter = state_dict["current_iter"]
|
||||
self.current_episode = state_dict["current_episode"]
|
||||
self.current_stage = state_dict["current_stage"]
|
||||
self.metrics = state_dict["metrics"]
|
||||
|
||||
def named_callbacks(self) -> dict[str, Callback]:
|
||||
"""Retrieve a collection of callbacks where each one has a name.
|
||||
Useful when saving checkpoints.
|
||||
"""
|
||||
return _named_collection(self.callbacks)
|
||||
|
||||
def named_loggers(self) -> dict[str, LogWriter]:
|
||||
"""Retrieve a collection of loggers where each one has a name.
|
||||
Useful when saving checkpoints.
|
||||
"""
|
||||
return _named_collection(self.loggers)
|
||||
|
||||
def fit(self, vessel: TrainingVesselBase, ckpt_path: Path | None = None) -> None:
|
||||
"""Train the RL policy upon the defined simulator.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
vessel
|
||||
A bundle of all elements used in training.
|
||||
ckpt_path
|
||||
Load a pre-trained / paused training checkpoint.
|
||||
"""
|
||||
self.vessel = vessel
|
||||
vessel.assign_trainer(self)
|
||||
|
||||
if ckpt_path is not None:
|
||||
_logger.info("Resuming states from %s", str(ckpt_path))
|
||||
self.load_state_dict(torch.load(ckpt_path))
|
||||
else:
|
||||
self.initialize()
|
||||
|
||||
self._call_callback_hooks("on_fit_start")
|
||||
|
||||
while not self.should_stop:
|
||||
self.initialize_iter()
|
||||
|
||||
self._call_callback_hooks("on_iter_start")
|
||||
|
||||
self.current_stage = "train"
|
||||
self._call_callback_hooks("on_train_start")
|
||||
|
||||
# TODO
|
||||
# Add a feature that supports reloading the training environment every few iterations.
|
||||
with _wrap_context(vessel.train_seed_iterator()) as iterator:
|
||||
vector_env = self.venv_from_iterator(iterator)
|
||||
self.vessel.train(vector_env)
|
||||
|
||||
self._call_callback_hooks("on_train_end")
|
||||
|
||||
if self.val_every_n_iters is not None and (self.current_iter + 1) % self.val_every_n_iters == 0:
|
||||
# Implementation of validation loop
|
||||
self.current_stage = "val"
|
||||
self._call_callback_hooks("on_validate_start")
|
||||
with _wrap_context(vessel.val_seed_iterator()) as iterator:
|
||||
vector_env = self.venv_from_iterator(iterator)
|
||||
self.vessel.validate(vector_env)
|
||||
|
||||
self._call_callback_hooks("on_validate_end")
|
||||
|
||||
# This iteration is considered complete.
|
||||
# Bumping the current iteration counter.
|
||||
self.current_iter += 1
|
||||
|
||||
if self.max_iters is not None and self.current_iter >= self.max_iters:
|
||||
self.should_stop = True
|
||||
|
||||
self._call_callback_hooks("on_iter_end")
|
||||
|
||||
self._call_callback_hooks("on_fit_end")
|
||||
|
||||
def test(self, vessel: TrainingVesselBase) -> None:
|
||||
"""Test the RL policy against the simulator.
|
||||
|
||||
The simulator will be fed with data generated in ``test_seed_iterator``.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
vessel
|
||||
A bundle of all related elements.
|
||||
"""
|
||||
self.vessel = vessel
|
||||
vessel.assign_trainer(self)
|
||||
|
||||
self.initialize_iter()
|
||||
|
||||
self.current_stage = "test"
|
||||
self._call_callback_hooks("on_test_start")
|
||||
with _wrap_context(vessel.test_seed_iterator()) as iterator:
|
||||
vector_env = self.venv_from_iterator(iterator)
|
||||
self.vessel.test(vector_env)
|
||||
self._call_callback_hooks("on_test_end")
|
||||
|
||||
def venv_from_iterator(self, iterator: Iterable[InitialStateType]) -> FiniteVectorEnv:
|
||||
"""Create a vectorized environment from iterator and the training vessel."""
|
||||
|
||||
def env_factory():
|
||||
# FIXME: state_interpreter and action_interpreter are stateful (having a weakref of env),
|
||||
# and could be thread unsafe.
|
||||
# I'm not sure whether it's a design flaw.
|
||||
# I'll rethink about this when designing the trainer.
|
||||
|
||||
if self.finite_env_type == "dummy":
|
||||
# We could only experience the "threading-unsafe" problem in dummy.
|
||||
state = copy.deepcopy(self.vessel.state_interpreter)
|
||||
action = copy.deepcopy(self.vessel.action_interpreter)
|
||||
rew = copy.deepcopy(self.vessel.reward)
|
||||
else:
|
||||
state = self.vessel.state_interpreter
|
||||
action = self.vessel.action_interpreter
|
||||
rew = self.vessel.reward
|
||||
|
||||
return EnvWrapper(
|
||||
self.vessel.simulator_fn,
|
||||
state,
|
||||
action,
|
||||
iterator,
|
||||
rew,
|
||||
logger=LogCollector(min_loglevel=self._min_loglevel()),
|
||||
)
|
||||
|
||||
return vectorize_env(
|
||||
env_factory,
|
||||
self.finite_env_type,
|
||||
self.concurrency,
|
||||
self.loggers,
|
||||
)
|
||||
|
||||
def _metrics_callback(self, on_episode: bool, on_collect: bool, log_buffer: LogBuffer) -> None:
|
||||
if on_episode:
|
||||
# Update the global counter.
|
||||
self.current_episode = log_buffer.global_episode
|
||||
metrics = log_buffer.episode_metrics()
|
||||
elif on_collect:
|
||||
# Update the latest metrics.
|
||||
metrics = log_buffer.collect_metrics()
|
||||
if self.current_stage == "val":
|
||||
metrics = {"val/" + name: value for name, value in metrics.items()}
|
||||
self.metrics.update(metrics)
|
||||
|
||||
def _call_callback_hooks(self, hook_name: str, *args: Any, **kwargs: Any) -> None:
|
||||
for callback in self.callbacks:
|
||||
fn = getattr(callback, hook_name)
|
||||
fn(self, self.vessel, *args, **kwargs)
|
||||
|
||||
def _min_loglevel(self):
|
||||
if not self.loggers:
|
||||
return LogLevel.PERIODIC
|
||||
else:
|
||||
# To save bandwidth
|
||||
return min(lg.loglevel for lg in self.loggers)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _wrap_context(obj):
|
||||
"""Make any object a (possibly dummy) context manager."""
|
||||
|
||||
if isinstance(obj, AbstractContextManager):
|
||||
# obj has __enter__ and __exit__
|
||||
with obj as ctx:
|
||||
yield ctx
|
||||
else:
|
||||
yield obj
|
||||
|
||||
|
||||
def _named_collection(seq: Sequence[T]) -> dict[str, T]:
|
||||
"""Convert a list into a dict, where each item is named with its type."""
|
||||
res = {}
|
||||
for item in seq:
|
||||
typename = type(item).__name__.lower()
|
||||
if typename not in res:
|
||||
res[typename] = item
|
||||
else:
|
||||
# names are auto-labelled as earlystop1, earlystop2, ...
|
||||
for retry in range(1, 1000):
|
||||
if f"{typename}{retry}" not in res:
|
||||
res[f"{typename}{retry}"] = item
|
||||
return res
|
||||
216
qlib/rl/trainer/vessel.py
Normal file
216
qlib/rl/trainer/vessel.py
Normal file
@@ -0,0 +1,216 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import weakref
|
||||
from typing import TYPE_CHECKING, Any, Callable, ContextManager, Dict, Generic, Iterable, Sequence, TypeVar, cast
|
||||
|
||||
import numpy as np
|
||||
from tianshou.data import Collector, VectorReplayBuffer
|
||||
from tianshou.env import BaseVectorEnv
|
||||
from tianshou.policy import BasePolicy
|
||||
|
||||
from qlib.constant import INF
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.rl.interpreter import ActionInterpreter, ActType, ObsType, PolicyActType, StateInterpreter, StateType
|
||||
from qlib.rl.reward import Reward
|
||||
from qlib.rl.simulator import InitialStateType, Simulator
|
||||
from qlib.rl.utils import DataQueue
|
||||
from qlib.rl.utils.finite_env import FiniteVectorEnv
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .trainer import Trainer
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
_logger = get_module_logger(__name__)
|
||||
|
||||
|
||||
class SeedIteratorNotAvailable(BaseException):
|
||||
pass
|
||||
|
||||
|
||||
class TrainingVesselBase(Generic[InitialStateType, StateType, ActType, ObsType, PolicyActType]):
|
||||
"""A ship that contains simulator, interpreter, and policy, will be sent to trainer.
|
||||
This class controls algorithm-related parts of training, while trainer is responsible for runtime part.
|
||||
|
||||
The ship also defines the most important logic of the core training part,
|
||||
and (optionally) some callbacks to insert customized logics at specific events.
|
||||
"""
|
||||
|
||||
simulator_fn: Callable[[InitialStateType], Simulator[InitialStateType, StateType, ActType]]
|
||||
state_interpreter: StateInterpreter[StateType, ObsType]
|
||||
action_interpreter: ActionInterpreter[StateType, PolicyActType, ActType]
|
||||
policy: BasePolicy
|
||||
reward: Reward
|
||||
trainer: Trainer
|
||||
|
||||
def assign_trainer(self, trainer: Trainer) -> None:
|
||||
self.trainer = weakref.proxy(trainer) # type: ignore
|
||||
|
||||
def train_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]:
|
||||
"""Override this to create a seed iterator for training.
|
||||
If the iterable is a context manager, the whole training will be invoked in the with-block,
|
||||
and the iterator will be automatically closed after the training is done."""
|
||||
raise SeedIteratorNotAvailable("Seed iterator for training is not available.")
|
||||
|
||||
def val_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]:
|
||||
"""Override this to create a seed iterator for validation."""
|
||||
raise SeedIteratorNotAvailable("Seed iterator for validation is not available.")
|
||||
|
||||
def test_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]:
|
||||
"""Override this to create a seed iterator for testing."""
|
||||
raise SeedIteratorNotAvailable("Seed iterator for testing is not available.")
|
||||
|
||||
def train(self, vector_env: BaseVectorEnv) -> dict[str, Any]:
|
||||
"""Implement this to train one iteration. In RL, one iteration usually refers to one collect."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def validate(self, vector_env: FiniteVectorEnv) -> dict[str, Any]:
|
||||
"""Implement this to validate the policy once."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def test(self, vector_env: FiniteVectorEnv) -> dict[str, Any]:
|
||||
"""Implement this to evaluate the policy on test environment once."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def log(self, name: str, value: Any) -> None:
|
||||
# FIXME: this is a workaround to make the log at least show somewhere.
|
||||
# Need a refactor in logger to formalize this.
|
||||
if isinstance(value, (np.ndarray, list)):
|
||||
value = np.mean(value)
|
||||
_logger.info(f"[Iter {self.trainer.current_iter + 1}] {name} = {value}")
|
||||
|
||||
def log_dict(self, data: dict[str, Any]) -> None:
|
||||
for name, value in data.items():
|
||||
self.log(name, value)
|
||||
|
||||
def state_dict(self) -> dict:
|
||||
"""Return a checkpoint of current vessel state."""
|
||||
return {"policy": self.policy.state_dict()}
|
||||
|
||||
def load_state_dict(self, state_dict: dict) -> None:
|
||||
"""Restore a checkpoint from a previously saved state dict."""
|
||||
self.policy.load_state_dict(state_dict["policy"])
|
||||
|
||||
|
||||
class TrainingVessel(TrainingVesselBase):
|
||||
"""The default implementation of training vessel.
|
||||
|
||||
``__init__`` accepts a sequence of initial states so that iterator can be created.
|
||||
``train``, ``validate``, ``test`` each do one collect (and also update in train).
|
||||
By default, the train initial states will be repeated infinitely during training,
|
||||
and collector will control the number of episodes for each iteration.
|
||||
In validation and testing, the val / test initial states will be used exactly once.
|
||||
|
||||
Extra hyper-parameters (only used in train) include:
|
||||
|
||||
- ``buffer_size``: Size of replay buffer.
|
||||
- ``episode_per_iter``: Episodes per collect at training. Can be overridden by fast dev run.
|
||||
- ``update_kwargs``: Keyword arguments appearing in ``policy.update``.
|
||||
For example, ``dict(repeat=10, batch_size=64)``.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
simulator_fn: Callable[[InitialStateType], Simulator[InitialStateType, StateType, ActType]],
|
||||
state_interpreter: StateInterpreter[StateType, ObsType],
|
||||
action_interpreter: ActionInterpreter[StateType, PolicyActType, ActType],
|
||||
policy: BasePolicy,
|
||||
reward: Reward,
|
||||
train_initial_states: Sequence[InitialStateType] | None = None,
|
||||
val_initial_states: Sequence[InitialStateType] | None = None,
|
||||
test_initial_states: Sequence[InitialStateType] | None = None,
|
||||
buffer_size: int = 20000,
|
||||
episode_per_iter: int = 1000,
|
||||
update_kwargs: dict[str, Any] = cast(Dict[str, Any], None),
|
||||
):
|
||||
self.simulator_fn = simulator_fn # type: ignore
|
||||
self.state_interpreter = state_interpreter
|
||||
self.action_interpreter = action_interpreter
|
||||
self.policy = policy
|
||||
self.reward = reward
|
||||
self.train_initial_states = train_initial_states
|
||||
self.val_initial_states = val_initial_states
|
||||
self.test_initial_states = test_initial_states
|
||||
self.buffer_size = buffer_size
|
||||
self.episode_per_iter = episode_per_iter
|
||||
self.update_kwargs = update_kwargs or {}
|
||||
|
||||
def train_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]:
|
||||
if self.train_initial_states is not None:
|
||||
_logger.info("Training initial states collection size: %d", len(self.train_initial_states))
|
||||
# Implement fast_dev_run here.
|
||||
train_initial_states = self._random_subset("train", self.train_initial_states, self.trainer.fast_dev_run)
|
||||
return DataQueue(train_initial_states, repeat=-1, shuffle=True)
|
||||
return super().train_seed_iterator()
|
||||
|
||||
def val_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]:
|
||||
if self.val_initial_states is not None:
|
||||
_logger.info("Validation initial states collection size: %d", len(self.val_initial_states))
|
||||
val_initial_states = self._random_subset("val", self.val_initial_states, self.trainer.fast_dev_run)
|
||||
return DataQueue(val_initial_states, repeat=1)
|
||||
return super().val_seed_iterator()
|
||||
|
||||
def test_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]:
|
||||
if self.test_initial_states is not None:
|
||||
_logger.info("Testing initial states collection size: %d", len(self.test_initial_states))
|
||||
test_initial_states = self._random_subset("test", self.test_initial_states, self.trainer.fast_dev_run)
|
||||
return DataQueue(test_initial_states, repeat=1)
|
||||
return super().test_seed_iterator()
|
||||
|
||||
def train(self, vector_env: FiniteVectorEnv) -> dict[str, Any]:
|
||||
"""Create a collector and collects ``episode_per_iter`` episodes.
|
||||
Update the policy on the collected replay buffer.
|
||||
"""
|
||||
self.policy.train()
|
||||
|
||||
with vector_env.collector_guard():
|
||||
collector = Collector(self.policy, vector_env, VectorReplayBuffer(self.buffer_size, len(vector_env)))
|
||||
|
||||
# Number of episodes collected in each training iteration can be overridden by fast dev run.
|
||||
if self.trainer.fast_dev_run is not None:
|
||||
episodes = self.trainer.fast_dev_run
|
||||
else:
|
||||
episodes = self.episode_per_iter
|
||||
|
||||
col_result = collector.collect(n_episode=episodes)
|
||||
update_result = self.policy.update(sample_size=0, buffer=collector.buffer, **self.update_kwargs)
|
||||
res = {**col_result, **update_result}
|
||||
self.log_dict(res)
|
||||
return res
|
||||
|
||||
def validate(self, vector_env: FiniteVectorEnv) -> dict[str, Any]:
|
||||
self.policy.eval()
|
||||
|
||||
with vector_env.collector_guard():
|
||||
test_collector = Collector(self.policy, vector_env)
|
||||
res = test_collector.collect(n_step=INF * len(vector_env))
|
||||
self.log_dict(res)
|
||||
return res
|
||||
|
||||
def test(self, vector_env: FiniteVectorEnv) -> dict[str, Any]:
|
||||
self.policy.eval()
|
||||
|
||||
with vector_env.collector_guard():
|
||||
test_collector = Collector(self.policy, vector_env)
|
||||
res = test_collector.collect(n_step=INF * len(vector_env))
|
||||
self.log_dict(res)
|
||||
return res
|
||||
|
||||
@staticmethod
|
||||
def _random_subset(name: str, collection: Sequence[T], size: int | None) -> Sequence[T]:
|
||||
if size is None:
|
||||
# Size = None -> original collection
|
||||
return collection
|
||||
order = np.random.permutation(len(collection))
|
||||
res = [collection[o] for o in order[:size]]
|
||||
_logger.info(
|
||||
"Fast running in development mode. Cut %s initial states from %d to %d.",
|
||||
name,
|
||||
len(collection),
|
||||
len(res),
|
||||
)
|
||||
return res
|
||||
@@ -1,7 +1,21 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from .data_queue import *
|
||||
from .env_wrapper import *
|
||||
from .finite_env import *
|
||||
from .log import *
|
||||
from .data_queue import DataQueue
|
||||
from .env_wrapper import EnvWrapper, EnvWrapperStatus
|
||||
from .finite_env import FiniteEnvType, vectorize_env
|
||||
from .log import ConsoleWriter, CsvWriter, LogBuffer, LogCollector, LogLevel, LogWriter
|
||||
|
||||
__all__ = [
|
||||
"LogLevel",
|
||||
"DataQueue",
|
||||
"EnvWrapper",
|
||||
"FiniteEnvType",
|
||||
"LogCollector",
|
||||
"LogWriter",
|
||||
"vectorize_env",
|
||||
"ConsoleWriter",
|
||||
"CsvWriter",
|
||||
"EnvWrapperStatus",
|
||||
"LogBuffer",
|
||||
]
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
from __future__ import annotations
|
||||
|
||||
import multiprocessing
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import warnings
|
||||
from queue import Empty
|
||||
from typing import TypeVar, Generic, Sequence, cast
|
||||
from typing import Any, Generator, Generic, Sequence, TypeVar, cast
|
||||
|
||||
from qlib.log import get_module_logger
|
||||
|
||||
@@ -60,7 +62,7 @@ class DataQueue(Generic[T]):
|
||||
shuffle: bool = True,
|
||||
producer_num_workers: int = 0,
|
||||
queue_maxsize: int = 0,
|
||||
):
|
||||
) -> None:
|
||||
if queue_maxsize == 0:
|
||||
if os.cpu_count() is not None:
|
||||
queue_maxsize = cast(int, os.cpu_count())
|
||||
@@ -78,14 +80,14 @@ class DataQueue(Generic[T]):
|
||||
self._queue: multiprocessing.Queue = multiprocessing.Queue(maxsize=queue_maxsize)
|
||||
self._done = multiprocessing.Value("i", 0)
|
||||
|
||||
def __enter__(self):
|
||||
def __enter__(self) -> DataQueue:
|
||||
self.activate()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.cleanup()
|
||||
|
||||
def cleanup(self):
|
||||
def cleanup(self) -> None:
|
||||
with self._done.get_lock():
|
||||
self._done.value += 1
|
||||
for repeat in range(500):
|
||||
@@ -105,7 +107,7 @@ class DataQueue(Generic[T]):
|
||||
break
|
||||
_logger.debug(f"Remaining items in queue collection done. Empty: {self._queue.empty()}")
|
||||
|
||||
def get(self, block=True):
|
||||
def get(self, block: bool = True) -> Any:
|
||||
if not hasattr(self, "_first_get"):
|
||||
self._first_get = True
|
||||
if self._first_get:
|
||||
@@ -120,17 +122,17 @@ class DataQueue(Generic[T]):
|
||||
if self._done.value:
|
||||
raise StopIteration # pylint: disable=raise-missing-from
|
||||
|
||||
def put(self, obj, block=True, timeout=None):
|
||||
return self._queue.put(obj, block=block, timeout=timeout)
|
||||
def put(self, obj: Any, block: bool = True, timeout: int = None) -> None:
|
||||
self._queue.put(obj, block=block, timeout=timeout)
|
||||
|
||||
def mark_as_done(self):
|
||||
def mark_as_done(self) -> None:
|
||||
with self._done.get_lock():
|
||||
self._done.value = 1
|
||||
|
||||
def done(self):
|
||||
def done(self) -> int:
|
||||
return self._done.value
|
||||
|
||||
def activate(self):
|
||||
def activate(self) -> DataQueue:
|
||||
if self._activated:
|
||||
raise ValueError("DataQueue can not activate twice.")
|
||||
thread = threading.Thread(target=self._producer, daemon=True)
|
||||
@@ -138,18 +140,20 @@ class DataQueue(Generic[T]):
|
||||
self._activated = True
|
||||
return self
|
||||
|
||||
def __del__(self):
|
||||
def __del__(self) -> None:
|
||||
_logger.debug(f"__del__ of {__name__}.DataQueue")
|
||||
self.cleanup()
|
||||
|
||||
def __iter__(self):
|
||||
def __iter__(self) -> Generator[Any, None, None]:
|
||||
if not self._activated:
|
||||
raise ValueError(
|
||||
"Need to call activate() to launch a daemon worker to produce data into data queue before using it."
|
||||
"Need to call activate() to launch a daemon worker "
|
||||
"to produce data into data queue before using it. "
|
||||
"You probably have forgotten to use the DataQueue in a with block.",
|
||||
)
|
||||
return self._consumer()
|
||||
|
||||
def _consumer(self):
|
||||
def _consumer(self) -> Generator[Any, None, None]:
|
||||
while True:
|
||||
try:
|
||||
yield self.get()
|
||||
@@ -157,23 +161,25 @@ class DataQueue(Generic[T]):
|
||||
_logger.debug("Data consumer timed-out from get.")
|
||||
return
|
||||
|
||||
def _producer(self):
|
||||
def _producer(self) -> None:
|
||||
# pytorch dataloader is used here only because we need its sampler and multi-processing
|
||||
from torch.utils.data import DataLoader, Dataset # pylint: disable=import-outside-toplevel
|
||||
|
||||
dataloader = DataLoader(
|
||||
cast(Dataset[T], self.dataset),
|
||||
batch_size=None,
|
||||
num_workers=self.producer_num_workers,
|
||||
shuffle=self.shuffle,
|
||||
collate_fn=lambda t: t, # identity collate fn
|
||||
)
|
||||
repeat = 10**18 if self.repeat == -1 else self.repeat
|
||||
for _rep in range(repeat):
|
||||
for data in dataloader:
|
||||
if self._done.value:
|
||||
# Already done.
|
||||
return
|
||||
self._queue.put(data)
|
||||
_logger.debug(f"Dataloader loop done. Repeat {_rep}.")
|
||||
self.mark_as_done()
|
||||
try:
|
||||
dataloader = DataLoader(
|
||||
cast(Dataset[T], self.dataset),
|
||||
batch_size=None,
|
||||
num_workers=self.producer_num_workers,
|
||||
shuffle=self.shuffle,
|
||||
collate_fn=lambda t: t, # identity collate fn
|
||||
)
|
||||
repeat = 10**18 if self.repeat == -1 else self.repeat
|
||||
for _rep in range(repeat):
|
||||
for data in dataloader:
|
||||
if self._done.value:
|
||||
# Already done.
|
||||
return
|
||||
self._queue.put(data)
|
||||
_logger.debug(f"Dataloader loop done. Repeat {_rep}.")
|
||||
finally:
|
||||
self.mark_as_done()
|
||||
|
||||
@@ -4,14 +4,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import weakref
|
||||
from typing import Callable, Any, Iterable, Iterator, Generic, cast
|
||||
from typing import Any, Callable, Dict, Generic, Iterable, Iterator, Optional, Tuple, cast
|
||||
|
||||
import gym
|
||||
from gym import Space
|
||||
|
||||
from qlib.rl.aux_info import AuxiliaryInfoCollector
|
||||
from qlib.rl.simulator import Simulator, InitialStateType, StateType, ActType
|
||||
from qlib.rl.interpreter import StateInterpreter, ActionInterpreter, PolicyActType, ObsType
|
||||
from qlib.rl.interpreter import ActionInterpreter, ObsType, PolicyActType, StateInterpreter
|
||||
from qlib.rl.reward import Reward
|
||||
from qlib.rl.simulator import ActType, InitialStateType, Simulator, StateType
|
||||
from qlib.typehint import TypedDict
|
||||
|
||||
from .finite_env import generate_nan_observation
|
||||
@@ -28,7 +29,7 @@ class InfoDict(TypedDict):
|
||||
|
||||
aux_info: dict
|
||||
"""Any information depends on auxiliary info collector."""
|
||||
log: dict[str, Any]
|
||||
log: Dict[str, Any]
|
||||
"""Collected by LogCollector."""
|
||||
|
||||
|
||||
@@ -42,14 +43,15 @@ class EnvWrapperStatus(TypedDict):
|
||||
|
||||
cur_step: int
|
||||
done: bool
|
||||
initial_state: Any | None
|
||||
initial_state: Optional[Any]
|
||||
obs_history: list
|
||||
action_history: list
|
||||
reward_history: list
|
||||
|
||||
|
||||
class EnvWrapper(
|
||||
gym.Env[ObsType, PolicyActType], Generic[InitialStateType, StateType, ActType, ObsType, PolicyActType]
|
||||
gym.Env[ObsType, PolicyActType],
|
||||
Generic[InitialStateType, StateType, ActType, ObsType, PolicyActType],
|
||||
):
|
||||
"""Qlib-based RL environment, subclassing ``gym.Env``.
|
||||
A wrapper of components, including simulator, state-interpreter, action-interpreter, reward.
|
||||
@@ -97,11 +99,11 @@ class EnvWrapper(
|
||||
simulator_fn: Callable[..., Simulator[InitialStateType, StateType, ActType]],
|
||||
state_interpreter: StateInterpreter[StateType, ObsType],
|
||||
action_interpreter: ActionInterpreter[StateType, PolicyActType, ActType],
|
||||
seed_iterator: Iterable[InitialStateType] | None,
|
||||
reward_fn: Reward | None = None,
|
||||
aux_info_collector: AuxiliaryInfoCollector[StateType, Any] | None = None,
|
||||
logger: LogCollector | None = None,
|
||||
):
|
||||
seed_iterator: Optional[Iterable[InitialStateType]],
|
||||
reward_fn: Reward = None,
|
||||
aux_info_collector: AuxiliaryInfoCollector[StateType, Any] = None,
|
||||
logger: LogCollector = None,
|
||||
) -> None:
|
||||
# Assign weak reference to wrapper.
|
||||
#
|
||||
# Use weak reference here, because:
|
||||
@@ -135,11 +137,11 @@ class EnvWrapper(
|
||||
self.status: EnvWrapperStatus = cast(EnvWrapperStatus, None)
|
||||
|
||||
@property
|
||||
def action_space(self):
|
||||
def action_space(self) -> Space:
|
||||
return self.action_interpreter.action_space
|
||||
|
||||
@property
|
||||
def observation_space(self):
|
||||
def observation_space(self) -> Space:
|
||||
return self.state_interpreter.observation_space
|
||||
|
||||
def reset(self, **kwargs: Any) -> ObsType:
|
||||
@@ -191,7 +193,7 @@ class EnvWrapper(
|
||||
self.seed_iterator = None
|
||||
return generate_nan_observation(self.observation_space)
|
||||
|
||||
def step(self, policy_action: PolicyActType, **kwargs: Any) -> tuple[ObsType, float, bool, InfoDict]:
|
||||
def step(self, policy_action: PolicyActType, **kwargs: Any) -> Tuple[ObsType, float, bool, InfoDict]:
|
||||
"""Environment step.
|
||||
|
||||
See the code along with comments to get a sequence of things happening here.
|
||||
@@ -245,5 +247,5 @@ class EnvWrapper(
|
||||
info_dict = InfoDict(log=self.logger.logs(), aux_info=aux_info)
|
||||
return obs, rew, done, info_dict
|
||||
|
||||
def render(self):
|
||||
def render(self, mode: str = "human") -> None:
|
||||
raise NotImplementedError("Render is not implemented in EnvWrapper.")
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user