mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 14:01:28 +08:00
Compare commits
95 Commits
mini_proje
...
6cma
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1f5f3a6af0 | ||
|
|
2f8fc8d28a | ||
|
|
3e9ccd3ad2 | ||
|
|
94268619c4 | ||
|
|
8d60a6a02b | ||
|
|
7234308651 | ||
|
|
acf5df27ce | ||
|
|
37a59f28d3 | ||
|
|
b084c352f5 | ||
|
|
9e22e5168b | ||
|
|
dceff7b471 | ||
|
|
7f1e8c5206 | ||
|
|
46264dfec9 | ||
|
|
754799ab05 | ||
|
|
32c3070b73 | ||
|
|
40de67265a | ||
|
|
e6f9a94fc5 | ||
|
|
73937863f1 | ||
|
|
d010219ba6 | ||
|
|
4fc8a5f25f | ||
|
|
0e8bfcb5d3 | ||
|
|
e457ca8511 | ||
|
|
4dbb8ecb86 | ||
|
|
653c082e7a | ||
|
|
f98e04ca9d | ||
|
|
76f2fb1a1a | ||
|
|
5eb5ac1f1f | ||
|
|
6295939346 | ||
|
|
5f3e322784 | ||
|
|
691b7f1f60 | ||
|
|
d8fc9aea6b | ||
|
|
d8764660dc | ||
|
|
7f08e6c7b3 | ||
|
|
0f3abfed74 | ||
|
|
44ce91ee9d | ||
|
|
ebb8ec34f3 | ||
|
|
4fe3ffccfd | ||
|
|
2f5ce3dc01 | ||
|
|
756bd0f65b | ||
|
|
667fb0e4d9 | ||
|
|
f326f83fae | ||
|
|
cbd69fb0ed | ||
|
|
5e3924d7a6 | ||
|
|
57f9813f85 | ||
|
|
26d24b5b23 | ||
|
|
b8bb3adbc8 | ||
|
|
ea10da32ba | ||
|
|
577923a9f0 | ||
|
|
9d8a8c6f13 | ||
|
|
d44175e425 | ||
|
|
5b73b80293 | ||
|
|
6a47416a2d | ||
|
|
4f5ae4d224 | ||
|
|
7e5bab599a | ||
|
|
c4ee9ff882 | ||
|
|
cc01812c62 | ||
|
|
0c4db8b0f8 | ||
|
|
e47b0f1c50 | ||
|
|
994f89319d | ||
|
|
b51e881be3 | ||
|
|
8802653bb9 | ||
|
|
82afd6a67a | ||
|
|
4001a5d157 | ||
|
|
ff2154c618 | ||
|
|
a82cc0b129 | ||
|
|
3b471a0fe3 | ||
|
|
b46bf8283d | ||
|
|
e182124e75 | ||
|
|
35794846ff | ||
|
|
49a5bccfec | ||
|
|
59fbf23a71 | ||
|
|
94e420f755 | ||
|
|
2fae407b19 | ||
|
|
67d618d4b2 | ||
|
|
fb5888be9e | ||
|
|
08de1a1874 | ||
|
|
1861c8edaf | ||
|
|
3c62d131a5 | ||
|
|
bc06f0301e | ||
|
|
84e9df2603 | ||
|
|
216a8ec2de | ||
|
|
54928e956d | ||
|
|
4fa37a96bf | ||
|
|
b22ecd145b | ||
|
|
bee05f56ef | ||
|
|
e762548295 | ||
|
|
7c64ea5c08 | ||
|
|
80bf16f9a8 | ||
|
|
dffaeaf07b | ||
|
|
157ef529bb | ||
|
|
ae85562a03 | ||
|
|
1d65d28b28 | ||
|
|
e78fe48a26 | ||
|
|
8480407150 | ||
|
|
a25b736c64 |
6
.github/labeler.yml
vendored
Normal file
6
.github/labeler.yml
vendored
Normal file
@@ -0,0 +1,6 @@
|
||||
documentation:
|
||||
- 'docs/**/*'
|
||||
- '**/*.md'
|
||||
|
||||
waiting for triage:
|
||||
- any: ['**/*', '!docs/**/*', '!**/*.md']
|
||||
14
.github/workflows/labeler.yml
vendored
Normal file
14
.github/workflows/labeler.yml
vendored
Normal file
@@ -0,0 +1,14 @@
|
||||
name: "Add label automatically"
|
||||
on:
|
||||
- pull_request_target
|
||||
|
||||
jobs:
|
||||
triage:
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: write
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/labeler@v4
|
||||
with:
|
||||
repo-token: "${{ secrets.GITHUB_TOKEN }}"
|
||||
2
.github/workflows/test_qlib_from_pip.yml
vendored
2
.github/workflows/test_qlib_from_pip.yml
vendored
@@ -13,7 +13,7 @@ jobs:
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [windows-latest, ubuntu-18.04, ubuntu-20.04, macos-11, macos-latest]
|
||||
os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-11, macos-latest]
|
||||
# not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129
|
||||
python-version: [3.7, 3.8]
|
||||
|
||||
|
||||
31
.github/workflows/test_qlib_from_source.yml
vendored
31
.github/workflows/test_qlib_from_source.yml
vendored
@@ -14,7 +14,7 @@ jobs:
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [windows-latest, ubuntu-18.04, ubuntu-20.04, macos-11, macos-latest]
|
||||
os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-11, macos-latest]
|
||||
# not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129
|
||||
python-version: [3.7, 3.8]
|
||||
|
||||
@@ -28,8 +28,10 @@ jobs:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Update pip to the latest version
|
||||
# pip release version 23.1 on Apr.15 2023, CI failed to run, Please refer to #1495 ofr detailed logs.
|
||||
# The pip version has been temporarily fixed to 23.0.1
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install pip==23.0.1
|
||||
|
||||
- name: Installing pytorch for macos
|
||||
if: ${{ matrix.os == 'macos-11' || matrix.os == 'macos-latest' }}
|
||||
@@ -37,15 +39,13 @@ jobs:
|
||||
python -m pip install torch torchvision torchaudio
|
||||
|
||||
- name: Installing pytorch for ubuntu
|
||||
if: ${{ matrix.os == 'ubuntu-18.04' || matrix.os == 'ubuntu-20.04' }}
|
||||
if: ${{ matrix.os == 'ubuntu-20.04' || matrix.os == 'ubuntu-22.04' }}
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
|
||||
|
||||
- name: Installing pytorch for windows
|
||||
if: ${{ matrix.os == 'windows-latest' }}
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install torch torchvision torchaudio
|
||||
|
||||
- name: Set up Python tools
|
||||
@@ -60,7 +60,7 @@ jobs:
|
||||
- name: Make html with sphinx
|
||||
run: |
|
||||
cd docs
|
||||
sphinx-build -b html . build
|
||||
sphinx-build -W --keep-going -b html . _build
|
||||
cd ..
|
||||
|
||||
# Check Qlib with pylint
|
||||
@@ -87,9 +87,10 @@ jobs:
|
||||
# E1102: not-callable
|
||||
# E1136: unsubscriptable-object
|
||||
# References for parameters: https://github.com/PyCQA/pylint/issues/4577#issuecomment-1000245962
|
||||
# We use sys.setrecursionlimit(2000) to make the recursion depth larger to ensure that pylint works properly (the default recursion depth is 1000).
|
||||
- name: Check Qlib with pylint
|
||||
run: |
|
||||
pylint --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0612,W0621,W0622,W0703,W1309,E1102,E1136 --const-rgx='[a-z_][a-z0-9_]{2,30}$' qlib --init-hook "import astroid; astroid.context.InferenceContext.max_inferred = 500"
|
||||
pylint --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0612,W0621,W0622,W0703,W1309,E1102,E1136 --const-rgx='[a-z_][a-z0-9_]{2,30}$' qlib --init-hook "import astroid; astroid.context.InferenceContext.max_inferred = 500; import sys; sys.setrecursionlimit(2000)"
|
||||
|
||||
# The following flake8 error codes were ignored:
|
||||
# E501 line too long
|
||||
@@ -119,6 +120,11 @@ jobs:
|
||||
run: |
|
||||
mypy qlib --install-types --non-interactive || true
|
||||
mypy qlib --verbose
|
||||
|
||||
- name: Check Qlib ipynb with nbqa
|
||||
run: |
|
||||
nbqa black . -l 120 --check --diff
|
||||
nbqa pylint . --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0612,W0621,W0622,W0703,W1309,E1102,E1136,W0719,W0104,W0404,C0412,W0611,C0410 --const-rgx='[a-z_][a-z0-9_]{2,30}$'
|
||||
|
||||
- name: Test data downloads
|
||||
run: |
|
||||
@@ -137,12 +143,15 @@ jobs:
|
||||
brew unlink libomp
|
||||
brew install libomp.rb
|
||||
|
||||
# Run after data downloads
|
||||
- name: Check Qlib ipynb with nbconvert
|
||||
run: |
|
||||
# add more ipynb files in future
|
||||
jupyter nbconvert --to notebook --execute examples/workflow_by_code.ipynb
|
||||
|
||||
- name: Test workflow by config (install from source)
|
||||
run: |
|
||||
# Version 0.52.0 of numba must be installed manually in CI, otherwise it will cause incompatibility with the latest version of numpy.
|
||||
python -m pip install numba==0.52.0
|
||||
# You must update numpy manually, because when installing python tools, it will try to uninstall numpy and cause CI to fail.
|
||||
python -m pip install --upgrade numpy
|
||||
python -m pip install numba
|
||||
python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
|
||||
- name: Unit tests with Pytest
|
||||
|
||||
@@ -14,7 +14,7 @@ jobs:
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [windows-latest, ubuntu-18.04, ubuntu-20.04, macos-11, macos-latest]
|
||||
os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-11, macos-latest]
|
||||
# not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129
|
||||
python-version: [3.7, 3.8]
|
||||
|
||||
@@ -28,9 +28,10 @@ jobs:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Set up Python tools
|
||||
# pip release version 23.1 on Apr.15 2023, CI failed to run, Please refer to #1495 ofr detailed logs.
|
||||
# The pip version has been temporarily fixed to 23.0.1
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
# python -m pip is necessary to upgrade pip.
|
||||
python -m pip install pip==23.0.1
|
||||
pip install --upgrade cython numpy
|
||||
pip install -e .[dev]
|
||||
|
||||
|
||||
6
.gitignore
vendored
6
.gitignore
vendored
@@ -10,7 +10,6 @@ _build
|
||||
build/
|
||||
dist/
|
||||
|
||||
|
||||
*.pkl
|
||||
*.hd5
|
||||
*.csv
|
||||
@@ -24,6 +23,11 @@ qlib/VERSION.txt
|
||||
qlib/data/_libs/expanding.cpp
|
||||
qlib/data/_libs/rolling.cpp
|
||||
examples/estimator/estimator_example/
|
||||
examples/rl/data/
|
||||
examples/rl/checkpoints/
|
||||
examples/rl/outputs/
|
||||
examples/rl_order_execution/data/
|
||||
examples/rl_order_execution/outputs/
|
||||
|
||||
*.egg-info/
|
||||
|
||||
|
||||
12
CHANGES.rst
12
CHANGES.rst
@@ -85,7 +85,7 @@ Version 0.4.0
|
||||
-------------
|
||||
- Add `data` package that holds all data-related codes
|
||||
- Reform the data provider structure
|
||||
- Create a server for data centralized management `qlib-server<https://amc-msra.visualstudio.com/trading-algo/_git/qlib-server>`_
|
||||
- Create a server for data centralized management `qlib-server <https://amc-msra.visualstudio.com/trading-algo/_git/qlib-server>`_
|
||||
- Add a `ClientProvider` to work with server
|
||||
- Add a pluggable cache mechanism
|
||||
- Add a recursive backtracking algorithm to inspect the furthest reference date for an expression
|
||||
@@ -166,12 +166,12 @@ Version 0.8.0
|
||||
- Nested decision execution framework is supported
|
||||
- There are lots of changes for daily trading, it is hard to list all of them. But a few important changes could be noticed
|
||||
- The trading limitation is more accurate;
|
||||
- In `previous version <https://github.com/microsoft/qlib/blob/v0.7.2/qlib/contrib/backtest/exchange.py#L160>`_, longing and shorting actions share the same action.
|
||||
- In `current version <https://github.com/microsoft/qlib/blob/7c31012b507a3823117bddcc693fc64899460b2a/qlib/backtest/exchange.py#L304>`_, the trading limitation is different between logging and shorting action.
|
||||
- In `previous version <https://github.com/microsoft/qlib/blob/v0.7.2/qlib/contrib/backtest/exchange.py#L160>`__, longing and shorting actions share the same action.
|
||||
- In `current version <https://github.com/microsoft/qlib/blob/7c31012b507a3823117bddcc693fc64899460b2a/qlib/backtest/exchange.py#L304>`__, the trading limitation is different between logging and shorting action.
|
||||
- The constant is different when calculating annualized metrics.
|
||||
- `Current version <https://github.com/microsoft/qlib/blob/7c31012b507a3823117bddcc693fc64899460b2a/qlib/contrib/evaluate.py#L42>`_ uses more accurate constant than `previous version <https://github.com/microsoft/qlib/blob/v0.7.2/qlib/contrib/evaluate.py#L22>`_
|
||||
- `A new version <https://github.com/microsoft/qlib/blob/7c31012b507a3823117bddcc693fc64899460b2a/qlib/tests/data.py#L17>`_ of data is released. Due to the unstability of Yahoo data source, the data may be different after downloading data again.
|
||||
- Users could check out the backtesting results between `Current version <https://github.com/microsoft/qlib/tree/7c31012b507a3823117bddcc693fc64899460b2a/examples/benchmarks>`_ and `previous version <https://github.com/microsoft/qlib/tree/v0.7.2/examples/benchmarks>`_
|
||||
- `Current version <https://github.com/microsoft/qlib/blob/7c31012b507a3823117bddcc693fc64899460b2a/qlib/contrib/evaluate.py#L42>`_ uses more accurate constant than `previous version <https://github.com/microsoft/qlib/blob/v0.7.2/qlib/contrib/evaluate.py#L22>`__
|
||||
- `A new version <https://github.com/microsoft/qlib/blob/7c31012b507a3823117bddcc693fc64899460b2a/qlib/tests/data.py#L17>`__ of data is released. Due to the unstability of Yahoo data source, the data may be different after downloading data again.
|
||||
- Users could check out the backtesting results between `Current version <https://github.com/microsoft/qlib/tree/7c31012b507a3823117bddcc693fc64899460b2a/examples/benchmarks>`__ and `previous version <https://github.com/microsoft/qlib/tree/v0.7.2/examples/benchmarks>`__
|
||||
|
||||
|
||||
Other Versions
|
||||
|
||||
41
README.md
41
README.md
@@ -11,6 +11,8 @@
|
||||
Recent released features
|
||||
| Feature | Status |
|
||||
| -- | ------ |
|
||||
| Release Qlib v0.9.0 | :octocat: [Released](https://github.com/microsoft/qlib/releases/tag/v0.9.0) on Dec 9, 2022 |
|
||||
| RL Learning Framework | :hammer: :chart_with_upwards_trend: Released on Nov 10, 2022. [#1332](https://github.com/microsoft/qlib/pull/1332), [#1322](https://github.com/microsoft/qlib/pull/1322), [#1316](https://github.com/microsoft/qlib/pull/1316),[#1299](https://github.com/microsoft/qlib/pull/1299),[#1263](https://github.com/microsoft/qlib/pull/1263), [#1244](https://github.com/microsoft/qlib/pull/1244), [#1169](https://github.com/microsoft/qlib/pull/1169), [#1125](https://github.com/microsoft/qlib/pull/1125), [#1076](https://github.com/microsoft/qlib/pull/1076)|
|
||||
| HIST and IGMTF models | :chart_with_upwards_trend: [Released](https://github.com/microsoft/qlib/pull/1040) on Apr 10, 2022 |
|
||||
| Qlib [notebook tutorial](https://github.com/microsoft/qlib/tree/main/examples/tutorial) | 📖 [Released](https://github.com/microsoft/qlib/pull/1037) on Apr 7, 2022 |
|
||||
| Ibovespa index data | :rice: [Released](https://github.com/microsoft/qlib/pull/990) on Apr 6, 2022 |
|
||||
@@ -40,13 +42,11 @@ Features released before 2021 are not listed here.
|
||||
<img src="http://fintech.msra.cn/images_v070/logo/1.png" />
|
||||
</p>
|
||||
|
||||
Qlib is an open-source, AI-oriented quantitative investment platform that aims to realize the potential, empower research, and create value using AI technologies in quantitative investment, from exploring ideas to implementing productions. Qlib supports diverse machine learning modeling paradigms, including supervised learning, market dynamics modeling, and reinforcement learning.
|
||||
|
||||
Qlib is an AI-oriented quantitative investment platform, which aims to realize the potential, empower the research, and create the value of AI technologies in quantitative investment.
|
||||
An increasing number of SOTA Quant research works/papers in diverse paradigms are being released in Qlib to collaboratively solve key challenges in quantitative investment. For example, 1) using supervised learning to mine the market's complex non-linear patterns from rich and heterogeneous financial data, 2) modeling the dynamic nature of the financial market using adaptive concept drift technology, and 3) using reinforcement learning to model continuous investment decisions and assist investors in optimizing their trading strategies.
|
||||
|
||||
It contains the full ML pipeline of data processing, model training, back-testing; and covers the entire chain of quantitative investment: alpha seeking, risk modeling, portfolio optimization, and order execution.
|
||||
|
||||
With Qlib, users can easily try ideas to create better Quant investment strategies.
|
||||
|
||||
For more details, please refer to our paper ["Qlib: An AI-oriented Quantitative Investment Platform"](https://arxiv.org/abs/2009.11189).
|
||||
|
||||
|
||||
@@ -67,6 +67,7 @@ For more details, please refer to our paper ["Qlib: An AI-oriented Quantitative
|
||||
<li type="circle"><a href="#auto-quant-research-workflow">Auto Quant Research Workflow</a></li>
|
||||
<li type="circle"><a href="#building-customized-quant-research-workflow-by-code">Building Customized Quant Research Workflow by Code</a></li></ul>
|
||||
<li><a href="#quant-dataset-zoo"><strong>Quant Dataset Zoo</strong></a></li>
|
||||
<li><a href="#learning-framework">Learning Framework</a></li>
|
||||
<li><a href="#more-about-qlib">More About Qlib</a></li>
|
||||
<li><a href="#offline-mode-and-online-mode">Offline Mode and Online Mode</a>
|
||||
<ul>
|
||||
@@ -105,21 +106,16 @@ Your feedbacks about the features are very important.
|
||||
# Framework of Qlib
|
||||
|
||||
<div style="align: center">
|
||||
<img src="docs/_static/img/framework.svg" />
|
||||
<img src="docs/_static/img/framework-abstract.jpg" />
|
||||
</div>
|
||||
|
||||
At the module level, Qlib is a platform that consists of the above components. The components are designed as loose-coupled modules, and each component could be used stand-alone.
|
||||
The high-level framework of Qlib can be found above(users can find the [detailed framework](https://qlib.readthedocs.io/en/latest/introduction/introduction.html#framework) of Qlib's design when getting into nitty gritty).
|
||||
The components are designed as loose-coupled modules, and each component could be used stand-alone.
|
||||
|
||||
| Name | Description |
|
||||
| ------ | ----- |
|
||||
| `Infrastructure` layer | `Infrastructure` layer provides underlying support for Quant research. `DataServer` provides a high-performance infrastructure for users to manage and retrieve raw data. `Trainer` provides a flexible interface to control the training process of models, which enable algorithms to control the training process. |
|
||||
| `Workflow` layer | `Workflow` layer covers the whole workflow of quantitative investment. `Information Extractor` extracts data for models. `Forecast Model` focuses on producing all kinds of forecast signals (e.g. _alpha_, risk) for other modules. With these signals `Decision Generator` will generate the target trading decisions(i.e. portfolio, orders) to be executed by `Execution Env` (i.e. the trading market). There may be multiple levels of `Trading Agent` and `Execution Env` (e.g. an _order executor trading agent and intraday order execution environment_ could behave like an interday trading environment and nested in _daily portfolio management trading agent and interday trading environment_ ) |
|
||||
| `Interface` layer | `Interface` layer tries to present a user-friendly interface for the underlying system. `Analyser` module will provide users detailed analysis reports of forecasting signals, portfolios and execution results |
|
||||
|
||||
* The modules with hand-drawn style are under development and will be released in the future.
|
||||
* The modules with dashed borders are highly user-customizable and extendible.
|
||||
|
||||
(p.s. framework image is created with https://draw.io/)
|
||||
Qlib provides a strong infrastructure to support Quant research. [Data](https://qlib.readthedocs.io/en/latest/component/data.html) is always an important part.
|
||||
A strong learning framework is designed to support diverse learning paradigms (e.g. [reinforcement learning](https://qlib.readthedocs.io/en/latest/component/rl.html), [supervised learning](https://qlib.readthedocs.io/en/latest/component/workflow.html#model-section)) and patterns at different levels(e.g. [market dynamic modeling](https://qlib.readthedocs.io/en/latest/component/meta.html)).
|
||||
By modeling the market, [trading strategies](https://qlib.readthedocs.io/en/latest/component/strategy.html) will generate trade decisions that will be executed. Multiple trading strategies and executors in different levels or granularities can be [nested to be optimized and run together](https://qlib.readthedocs.io/en/latest/component/highfreq.html).
|
||||
At last, a comprehensive [analysis](https://qlib.readthedocs.io/en/latest/component/report.html) will be provided and the model can be [served online](https://qlib.readthedocs.io/en/latest/component/online.html) in a low cost.
|
||||
|
||||
|
||||
# Quick Start
|
||||
@@ -170,7 +166,7 @@ Also, users can install the latest dev version ``Qlib`` by the source code accor
|
||||
git clone https://github.com/microsoft/qlib.git && cd qlib
|
||||
pip install .
|
||||
```
|
||||
**Note**: You can install Qlib with `python setup.py install` as well. But it is not the recommanded approach. It will skip `pip` and cause obscure problems. For example, **only** the command ``pip install .`` **can** overwrite the stable version installed by ``pip install pyqlib``, while the command ``python setup.py install`` **can't**.
|
||||
**Note**: You can install Qlib with `python setup.py install` as well. But it is not the recommended approach. It will skip `pip` and cause obscure problems. For example, **only** the command ``pip install .`` **can** overwrite the stable version installed by ``pip install pyqlib``, while the command ``python setup.py install`` **can't**.
|
||||
|
||||
**Tips**: If you fail to install `Qlib` or run the examples in your environment, comparing your steps and the [CI workflow](.github/workflows/test_qlib_from_source.yml) may help you find the problem.
|
||||
|
||||
@@ -404,6 +400,17 @@ Dataset plays a very important role in Quant. Here is a list of the datasets bui
|
||||
[Here](https://qlib.readthedocs.io/en/latest/advanced/alpha.html) is a tutorial to build dataset with `Qlib`.
|
||||
Your PR to build new Quant dataset is highly welcomed.
|
||||
|
||||
|
||||
# Learning Framework
|
||||
Qlib is high customizable and a lot of its components are learnable.
|
||||
The learnable components are instances of `Forecast Model` and `Trading Agent`. They are learned based on the `Learning Framework` layer and then applied to multiple scenarios in `Workflow` layer.
|
||||
The learning framework leverages the `Workflow` layer as well(e.g. sharing `Information Extractor`, creating environments based on `Execution Env`).
|
||||
|
||||
Based on learning paradigms, they can be categorized into reinforcement learning and supervised learning.
|
||||
- For supervised learning, the detailed docs can be found [here](https://qlib.readthedocs.io/en/latest/component/model.html).
|
||||
- For reinforcement learning, the detailed docs can be found [here](https://qlib.readthedocs.io/en/latest/component/rl.html). Qlib's RL learning framework leverages `Execution Env` in `Workflow` layer to create environments. It's worth noting that `NestedExecutor` is supported as well. This empowers users to optimize different level of strategies/models/agents together (e.g. optimizing an order execution strategy for a specific portfolio management strategy).
|
||||
|
||||
|
||||
# More About Qlib
|
||||
If you want to have a quick glance at the most frequently used components of qlib, you can try notebooks [here](examples/tutorial/).
|
||||
|
||||
|
||||
@@ -17,4 +17,5 @@ help:
|
||||
# Catch-all target: route all unknown targets to Sphinx using the new
|
||||
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
|
||||
%: Makefile
|
||||
pip install -r requirements.txt
|
||||
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
|
||||
|
||||
BIN
docs/_static/img/QlibRL_framework.png
vendored
Normal file
BIN
docs/_static/img/QlibRL_framework.png
vendored
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 91 KiB |
BIN
docs/_static/img/RL_framework.png
vendored
Normal file
BIN
docs/_static/img/RL_framework.png
vendored
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 30 KiB |
BIN
docs/_static/img/framework-abstract.jpg
vendored
Normal file
BIN
docs/_static/img/framework-abstract.jpg
vendored
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 65 KiB |
2
docs/_static/img/framework.svg
vendored
2
docs/_static/img/framework.svg
vendored
File diff suppressed because one or more lines are too long
|
Before Width: | Height: | Size: 98 KiB After Width: | Height: | Size: 144 KiB |
@@ -38,7 +38,7 @@ Example
|
||||
|
||||
DIF = \frac{EMA(CLOSE, 12) - EMA(CLOSE, 26)}{CLOSE}
|
||||
|
||||
`DEA`means a 9-period EMA of the DIF.
|
||||
`DEA` means a 9-period EMA of the DIF.
|
||||
|
||||
.. math::
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ With this module, users can run their ``task`` automatically at different period
|
||||
|
||||
This whole process can be used in `Online Serving <../component/online.html>`_.
|
||||
|
||||
An example of the entire process is shown `here <https://github.com/microsoft/qlib/tree/main/examples/model_rolling/task_manager_rolling.py>`_.
|
||||
An example of the entire process is shown `here <https://github.com/microsoft/qlib/tree/main/examples/model_rolling/task_manager_rolling.py>`__.
|
||||
|
||||
Task Generating
|
||||
===============
|
||||
@@ -31,9 +31,10 @@ Here is the base class of ``TaskGen``:
|
||||
|
||||
.. autoclass:: qlib.workflow.task.gen.TaskGen
|
||||
:members:
|
||||
:noindex:
|
||||
|
||||
``Qlib`` provides a class `RollingGen <https://github.com/microsoft/qlib/tree/main/qlib/workflow/task/gen.py>`_ to generate a list of ``task`` of the dataset in different date segments.
|
||||
This class allows users to verify the effect of data from different periods on the model in one experiment. More information is `here <../reference/api.html#TaskGen>`_.
|
||||
This class allows users to verify the effect of data from different periods on the model in one experiment. More information is `here <../reference/api.html#TaskGen>`__.
|
||||
|
||||
Task Storing
|
||||
============
|
||||
@@ -53,8 +54,9 @@ Users need to provide the MongoDB URL and database name for using ``TaskManager`
|
||||
|
||||
.. autoclass:: qlib.workflow.task.manage.TaskManager
|
||||
:members:
|
||||
:noindex:
|
||||
|
||||
More information of ``Task Manager`` can be found in `here <../reference/api.html#TaskManager>`_.
|
||||
More information of ``Task Manager`` can be found in `here <../reference/api.html#TaskManager>`__.
|
||||
|
||||
Task Training
|
||||
=============
|
||||
@@ -64,11 +66,13 @@ An easy way to get the ``task_func`` is using ``qlib.model.trainer.task_train``
|
||||
It will run the whole workflow defined by ``task``, which includes *Model*, *Dataset*, *Record*.
|
||||
|
||||
.. autofunction:: qlib.workflow.task.manage.run_task
|
||||
:noindex:
|
||||
|
||||
Meanwhile, ``Qlib`` provides a module called ``Trainer``.
|
||||
|
||||
.. autoclass:: qlib.model.trainer.Trainer
|
||||
:members:
|
||||
:noindex:
|
||||
|
||||
``Trainer`` will train a list of tasks and return a list of model recorders.
|
||||
``Qlib`` offer two kinds of Trainer, TrainerR is the simplest way and TrainerRM is based on TaskManager to help manager tasks lifecycle automatically.
|
||||
|
||||
@@ -24,8 +24,8 @@ The introduction of ``Data Layer`` includes the following parts.
|
||||
Here is a typical example of Qlib data workflow
|
||||
|
||||
- Users download data and converting data into Qlib format(with filename suffix `.bin`). In this step, typically only some basic data are stored on disk(such as OHLCV).
|
||||
- Creating some basic features based on Qlib's expression Engine(e.g. "Ref($close, 60) / $close", the return of last 60 trading days). Supported operators in the expression engine can be found `here <https://github.com/microsoft/qlib/blob/main/qlib/data/ops.py>`_. This step is typically implemented in Qlib's `Data Loader <https://qlib.readthedocs.io/en/latest/component/data.html#data-loader>`_ which is a component of `Data Handler <https://qlib.readthedocs.io/en/latest/component/data.html#data-handler>`_ .
|
||||
- If users require more complicated data processing (e.g. data normalization), `Data Handler <https://qlib.readthedocs.io/en/latest/component/data.html#data-handler>`_ support user-customized processors to process data(some predefined processors can be found `here <https://github.com/microsoft/qlib/blob/main/qlib/data/dataset/processor.py>`_). The processors are different from operators in expression engine. It is designed for some complicated data processing methods which is hard to supported in operators in expression engine.
|
||||
- Creating some basic features based on Qlib's expression Engine(e.g. "Ref($close, 60) / $close", the return of last 60 trading days). Supported operators in the expression engine can be found `here <https://github.com/microsoft/qlib/blob/main/qlib/data/ops.py>`__. This step is typically implemented in Qlib's `Data Loader <https://qlib.readthedocs.io/en/latest/component/data.html#data-loader>`_ which is a component of `Data Handler <https://qlib.readthedocs.io/en/latest/component/data.html#data-handler>`_ .
|
||||
- If users require more complicated data processing (e.g. data normalization), `Data Handler <https://qlib.readthedocs.io/en/latest/component/data.html#data-handler>`_ support user-customized processors to process data(some predefined processors can be found `here <https://github.com/microsoft/qlib/blob/main/qlib/data/dataset/processor.py>`__). The processors are different from operators in expression engine. It is designed for some complicated data processing methods which is hard to supported in operators in expression engine.
|
||||
- At last, `Dataset <https://qlib.readthedocs.io/en/latest/component/data.html#dataset>`_ is responsible to prepare model-specific dataset from the processed data of Data Handler
|
||||
|
||||
Data Preparation
|
||||
@@ -37,7 +37,7 @@ Qlib Format Data
|
||||
We've specially designed a data structure to manage financial data, please refer to the `File storage design section in Qlib paper <https://arxiv.org/abs/2009.11189>`_ for detailed information.
|
||||
Such data will be stored with filename suffix `.bin` (We'll call them `.bin` file, `.bin` format, or qlib format). `.bin` file is designed for scientific computing on finance data.
|
||||
|
||||
``Qlib`` provides two different off-the-shelf datasets, which can be accessed through this `link <https://github.com/microsoft/qlib/blob/main/qlib/contrib/data/handler.py>`_:
|
||||
``Qlib`` provides two different off-the-shelf datasets, which can be accessed through this `link <https://github.com/microsoft/qlib/blob/main/qlib/contrib/data/handler.py>`__:
|
||||
|
||||
======================== ================= ================
|
||||
Dataset US Market China Market
|
||||
@@ -47,7 +47,7 @@ Alpha360 √ √
|
||||
Alpha158 √ √
|
||||
======================== ================= ================
|
||||
|
||||
Also, ``Qlib`` provides a high-frequency dataset. Users can run a high-frequency dataset example through this `link <https://github.com/microsoft/qlib/tree/main/examples/highfreq>`_.
|
||||
Also, ``Qlib`` provides a high-frequency dataset. Users can run a high-frequency dataset example through this `link <https://github.com/microsoft/qlib/tree/main/examples/highfreq>`__.
|
||||
|
||||
Qlib Format Dataset
|
||||
-------------------
|
||||
@@ -332,6 +332,7 @@ Here are some interfaces of the ``QlibDataLoader`` class:
|
||||
|
||||
.. autoclass:: qlib.data.dataset.loader.DataLoader
|
||||
:members:
|
||||
:noindex:
|
||||
|
||||
API
|
||||
---
|
||||
@@ -361,6 +362,7 @@ Here are some important interfaces that ``DataHandlerLP`` provides:
|
||||
|
||||
.. autoclass:: qlib.data.dataset.handler.DataHandlerLP
|
||||
:members: __init__, fetch, get_cols
|
||||
:noindex:
|
||||
|
||||
|
||||
If users want to load features and labels by config, users can define a new handler and call the static method `parse_config_to_fields` of ``qlib.contrib.data.handler.Alpha158``.
|
||||
@@ -451,6 +453,7 @@ The ``DatasetH`` class is the `dataset` with `Data Handler`. Here is the most im
|
||||
|
||||
.. autoclass:: qlib.data.dataset.__init__.DatasetH
|
||||
:members:
|
||||
:noindex:
|
||||
|
||||
API
|
||||
---
|
||||
@@ -470,9 +473,11 @@ Global Memory Cache
|
||||
|
||||
.. autoclass:: qlib.data.cache.MemCacheUnit
|
||||
:members:
|
||||
:noindex:
|
||||
|
||||
.. autoclass:: qlib.data.cache.MemCache
|
||||
:members:
|
||||
:noindex:
|
||||
|
||||
|
||||
ExpressionCache
|
||||
@@ -487,6 +492,7 @@ The following shows the details about the interfaces:
|
||||
|
||||
.. autoclass:: qlib.data.cache.ExpressionCache
|
||||
:members:
|
||||
:noindex:
|
||||
|
||||
``Qlib`` has currently provided implemented disk cache `DiskExpressionCache` which inherits from `ExpressionCache` . The expressions data will be stored in the disk.
|
||||
|
||||
@@ -502,6 +508,7 @@ The following shows the details about the interfaces:
|
||||
|
||||
.. autoclass:: qlib.data.cache.DatasetCache
|
||||
:members:
|
||||
:noindex:
|
||||
|
||||
``Qlib`` has currently provided implemented disk cache `DiskDatasetCache` which inherits from `DatasetCache` . The datasets' data will be stored in the disk.
|
||||
|
||||
@@ -512,7 +519,7 @@ Data and Cache File Structure
|
||||
|
||||
We've specially designed a file structure to manage data and cache, please refer to the `File storage design section in Qlib paper <https://arxiv.org/abs/2009.11189>`_ for detailed information. The file structure of data and cache is listed as follows.
|
||||
|
||||
.. code-block:: json
|
||||
.. code-block::
|
||||
|
||||
- data/
|
||||
[raw data] updated by data providers
|
||||
|
||||
@@ -8,31 +8,33 @@ Design of Nested Decision Execution Framework for High-Frequency Trading
|
||||
Introduction
|
||||
============
|
||||
|
||||
Daily trading (e.g. portfolio management) and intraday trading (e.g. orders execution) are two hot topics in Quant investment and usually studied separately.
|
||||
Daily trading (e.g. portfolio management) and intraday trading (e.g. orders execution) are two hot topics in Quant investment and are usually studied separately.
|
||||
|
||||
To get the join trading performance of daily and intraday trading, they must interact with each other and run backtest jointly.
|
||||
In order to support the joint backtest strategies in multiple levels, a corresponding framework is required. None of the publicly available high-frequency trading frameworks considers multi-level joint trading, which make the backtesting aforementioned inaccurate.
|
||||
In order to support the joint backtest strategies at multiple levels, a corresponding framework is required. None of the publicly available high-frequency trading frameworks considers multi-level joint trading, which makes the backtesting aforementioned inaccurate.
|
||||
|
||||
Besides backtesting, the optimization of strategies from different levels is not standalone and can be affected by each other.
|
||||
For example, the best portfolio management strategy may change with the performance of order executions(e.g. a portfolio with higher turnover may becomes a better choice when we improve the order execution strategies).
|
||||
To achieve the overall good performance , it is necessary to consider the interaction of strategies in different level.
|
||||
For example, the best portfolio management strategy may change with the performance of order executions(e.g. a portfolio with higher turnover may become a better choice when we improve the order execution strategies).
|
||||
To achieve overall good performance, it is necessary to consider the interaction of strategies at a different levels.
|
||||
|
||||
Therefore, building a new framework for trading in multiple levels becomes necessary to solve the various problems mentioned above, for which we designed a nested decision execution framework that consider the interaction of strategies.
|
||||
Therefore, building a new framework for trading on multiple levels becomes necessary to solve the various problems mentioned above, for which we designed a nested decision execution framework that considers the interaction of strategies.
|
||||
|
||||
.. image:: ../_static/img/framework.svg
|
||||
|
||||
The design of the framework is shown in the yellow part in the middle of the figure above. Each level consists of ``Trading Agent`` and ``Execution Env``. ``Trading Agent`` has its own data processing module (``Information Extractor``), forecasting module (``Forecast Model``) and decision generator (``Decision Generator``). The trading algorithm generates the decisions by the ``Decision Generator`` based on the forecast signals output by the ``Forecast Module``, and the decisions generated by the trading algorithm are passed to the ``Execution Env``, which returns the execution results.
|
||||
|
||||
The frequency of trading algorithm, decision content and execution environment can be customized by users (e.g. intraday trading, daily-frequency trading, weekly-frequency trading), and the execution environment can be nested with finer-grained trading algorithm and execution environment inside (i.e. sub-workflow in the figure, e.g. daily-frequency orders can be turned into finer-grained decisions by splitting orders within the day). The flexibility of nested decision execution framework makes it easy for users to explore the effects of combining different levels of trading strategies and break down the optimization barriers between different levels of trading algorithm.
|
||||
The frequency of the trading algorithm, decision content and execution environment can be customized by users (e.g. intraday trading, daily-frequency trading, weekly-frequency trading), and the execution environment can be nested with finer-grained trading algorithm and execution environment inside (i.e. sub-workflow in the figure, e.g. daily-frequency orders can be turned into finer-grained decisions by splitting orders within the day). The flexibility of the nested decision execution framework makes it easy for users to explore the effects of combining different levels of trading strategies and break down the optimization barriers between different levels of the trading algorithm.
|
||||
|
||||
The optimization for the nested decision execution framework can be implemented with the support of `QlibRL <https://qlib.readthedocs.io/en/latest/component/rl.html>`_. To know more about how to use the QlibRL, go to API Reference: `RL API <../reference/api.html#rl>`_.
|
||||
|
||||
Example
|
||||
=======
|
||||
|
||||
An example of nested decision execution framework for high-frequency can be found `here <https://github.com/microsoft/qlib/blob/main/examples/nested_decision_execution/workflow.py>`_.
|
||||
An example of a nested decision execution framework for high-frequency can be found `here <https://github.com/microsoft/qlib/blob/main/examples/nested_decision_execution/workflow.py>`_.
|
||||
|
||||
|
||||
Besides, the above examples, here are some other related work about high-frequency trading in Qlib.
|
||||
Besides, the above examples, here are some other related works about high-frequency trading in Qlib.
|
||||
|
||||
- `Prediction with high-frequency data <https://github.com/microsoft/qlib/tree/main/examples/highfreq#benchmarks-performance-predicting-the-price-trend-in-high-frequency-data>`_
|
||||
- `Examples <https://github.com/microsoft/qlib/blob/main/examples/orderbook_data/>`_ to extract features form high-frequency data without fixed frequency.
|
||||
- `Examples <https://github.com/microsoft/qlib/blob/main/examples/orderbook_data/>`_ to extract features from high-frequency data without fixed frequency.
|
||||
- `A paper <https://github.com/microsoft/qlib/tree/high-freq-execution#high-frequency-execution>`_ for high-frequency trading.
|
||||
|
||||
@@ -20,6 +20,7 @@ The base class provides the following interfaces:
|
||||
|
||||
.. autoclass:: qlib.model.base.Model
|
||||
:members:
|
||||
:noindex:
|
||||
|
||||
``Qlib`` also provides a base class `qlib.model.base.ModelFT <../reference/api.html#qlib.model.base.ModelFT>`_, which includes the method for finetuning the model.
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
.. _online:
|
||||
.. _online_serving:
|
||||
|
||||
==============
|
||||
Online Serving
|
||||
@@ -32,21 +32,25 @@ Online Manager
|
||||
|
||||
.. automodule:: qlib.workflow.online.manager
|
||||
:members:
|
||||
:noindex:
|
||||
|
||||
Online Strategy
|
||||
===============
|
||||
|
||||
.. automodule:: qlib.workflow.online.strategy
|
||||
:members:
|
||||
:noindex:
|
||||
|
||||
Online Tool
|
||||
===========
|
||||
|
||||
.. automodule:: qlib.workflow.online.utils
|
||||
:members:
|
||||
:noindex:
|
||||
|
||||
Updater
|
||||
=======
|
||||
|
||||
.. automodule:: qlib.workflow.online.update
|
||||
:members:
|
||||
:noindex:
|
||||
|
||||
@@ -61,6 +61,7 @@ The ``ExpManager`` module in ``Qlib`` is responsible for managing different expe
|
||||
|
||||
.. autoclass:: qlib.workflow.expm.ExpManager
|
||||
:members: get_exp, list_experiments
|
||||
:noindex:
|
||||
|
||||
For other interfaces such as `create_exp`, `delete_exp`, please refer to `Experiment Manager API <../reference/api.html#experiment-manager>`_.
|
||||
|
||||
@@ -71,6 +72,7 @@ The ``Experiment`` class is solely responsible for a single experiment, and it w
|
||||
|
||||
.. autoclass:: qlib.workflow.exp.Experiment
|
||||
:members: get_recorder, list_recorders
|
||||
:noindex:
|
||||
|
||||
For other interfaces such as `search_records`, `delete_recorder`, please refer to `Experiment API <../reference/api.html#experiment>`_.
|
||||
|
||||
@@ -85,6 +87,7 @@ Here are some important APIs that are not included in the ``QlibRecorder``:
|
||||
|
||||
.. autoclass:: qlib.workflow.recorder.Recorder
|
||||
:members: list_artifacts, list_metrics, list_params, list_tags
|
||||
:noindex:
|
||||
|
||||
For other interfaces such as `save_objects`, `load_object`, please refer to `Recorder API <../reference/api.html#recorder>`_.
|
||||
|
||||
@@ -107,7 +110,7 @@ Here is a simple example of what is done in ``SigAnaRecord``, which users can re
|
||||
|
||||
- ``PortAnaRecord``: This class generates the results of `backtest`. The detailed information about `backtest` as well as the available `strategy`, users can refer to `Strategy <../component/strategy.html>`_ and `Backtest <../component/backtest.html>`_.
|
||||
|
||||
Here is a simple exampke of what is done in ``PortAnaRecord``, which users can refer to if they want to do backtest based on their own prediction and label.
|
||||
Here is a simple example of what is done in ``PortAnaRecord``, which users can refer to if they want to do backtest based on their own prediction and label.
|
||||
|
||||
.. code-block:: Python
|
||||
|
||||
|
||||
@@ -51,6 +51,7 @@ API
|
||||
|
||||
.. automodule:: qlib.contrib.report.analysis_position.report
|
||||
:members:
|
||||
:noindex:
|
||||
|
||||
Graphical Result
|
||||
~~~~~~~~~~~~~~~~
|
||||
@@ -93,6 +94,7 @@ API
|
||||
|
||||
.. automodule:: qlib.contrib.report.analysis_position.score_ic
|
||||
:members:
|
||||
:noindex:
|
||||
|
||||
|
||||
Graphical Result
|
||||
@@ -151,6 +153,7 @@ API
|
||||
|
||||
.. automodule:: qlib.contrib.report.analysis_position.risk_analysis
|
||||
:members:
|
||||
:noindex:
|
||||
|
||||
|
||||
Graphical Result
|
||||
@@ -174,6 +177,7 @@ Graphical Result
|
||||
The `Information Ratio` without cost.
|
||||
- `excess_return_with_cost`
|
||||
The `Information Ratio` with cost.
|
||||
|
||||
To know more about `Information Ratio`, please refer to `Information Ratio – IR <https://www.investopedia.com/terms/i/informationratio.asp>`_.
|
||||
- `max_drawdown`
|
||||
- `excess_return_without_cost`
|
||||
@@ -269,6 +273,7 @@ API
|
||||
|
||||
.. automodule:: qlib.contrib.report.analysis_model.analysis_model_performance
|
||||
:members:
|
||||
:noindex:
|
||||
|
||||
|
||||
Graphical Results
|
||||
|
||||
49
docs/component/rl/framework.rst
Normal file
49
docs/component/rl/framework.rst
Normal file
@@ -0,0 +1,49 @@
|
||||
The Framework of QlibRL
|
||||
=======================
|
||||
|
||||
QlibRL contains a full set of components that cover the entire lifecycle of an RL pipeline, including building the simulator of the market, shaping states & actions, training policies (strategies), and backtesting strategies in the simulated environment.
|
||||
|
||||
QlibRL is basically implemented with the support of Tianshou and Gym frameworks. The high-level structure of QlibRL is demonstrated below:
|
||||
|
||||
.. image:: ../../_static/img/QlibRL_framework.png
|
||||
:width: 600
|
||||
:align: center
|
||||
|
||||
Here, we briefly introduce each component in the figure.
|
||||
|
||||
EnvWrapper
|
||||
------------
|
||||
EnvWrapper is the complete capsulation of the simulated environment. It receives actions from outside (policy/strategy/agent), simulates the changes in the market, and then replies rewards and updated states, thus forming an interaction loop.
|
||||
|
||||
In QlibRL, EnvWrapper is a subclass of gym.Env, so it implements all necessary interfaces of gym.Env. Any classes or pipelines that accept gym.Env should also accept EnvWrapper. Developers do not need to implement their own EnvWrapper to build their own environment. Instead, they only need to implement 4 components of the EnvWrapper:
|
||||
|
||||
- `Simulator`
|
||||
The simulator is the core component responsible for the environment simulation. Developers could implement all the logic that is directly related to the environment simulation in the Simulator in any way they like. In QlibRL, there are already two implementations of Simulator for single asset trading: 1) ``SingleAssetOrderExecution``, which is built based on Qlib's backtest toolkits and hence considers a lot of practical trading details but is slow. 2) ``SimpleSingleAssetOrderExecution``, which is built based on a simplified trading simulator, which ignores a lot of details (e.g. trading limitations, rounding) but is quite fast.
|
||||
- `State interpreter`
|
||||
The state interpreter is responsible for "interpret" states in the original format (format provided by the simulator) into states in a format that the policy could understand. For example, transform unstructured raw features into numerical tensors.
|
||||
- `Action interpreter`
|
||||
The action interpreter is similar to the state interpreter. But instead of states, it interprets actions generated by the policy, from the format provided by the policy to the format that is acceptable to the simulator.
|
||||
- `Reward function`
|
||||
The reward function returns a numerical reward to the policy after each time the policy takes an action.
|
||||
|
||||
EnvWrapper will organically organize these components. Such decomposition allows for better flexibility in development. For example, if the developers want to train multiple types of policies in the same environment, they only need to design one simulator and design different state interpreters/action interpreters/reward functions for different types of policies.
|
||||
|
||||
QlibRL has well-defined base classes for all these 4 components. All the developers need to do is define their own components by inheriting the base classes and then implementing all interfaces required by the base classes. The API for the above base components can be found `here <../../reference/api.html#module-qlib.rl>`__.
|
||||
|
||||
Policy
|
||||
------------
|
||||
QlibRL directly uses Tianshou's policy. Developers could use policies provided by Tianshou off the shelf, or implement their own policies by inheriting Tianshou's policies.
|
||||
|
||||
Training Vessel & Trainer
|
||||
-------------------------
|
||||
As stated by their names, training vessels and trainers are helper classes used in training. A training vessel is a ship that contains a simulator/interpreters/reward function/policy, and it controls algorithm-related parts of training. Correspondingly, the trainer is responsible for controlling the runtime parts of training.
|
||||
|
||||
As you may have noticed, a training vessel itself holds all the required components to build an EnvWrapper rather than holding an instance of EnvWrapper directly. This allows the training vessel to create duplicates of EnvWrapper dynamically when necessary (for example, under parallel training).
|
||||
|
||||
With a training vessel, the trainer could finally launch the training pipeline by simple, Scikit-learn-like interfaces (i.e., ``trainer.fit()``).
|
||||
|
||||
The API for Trainer and TrainingVessel and can be found `here <../../reference/api.html#module-qlib.rl.trainer>`__.
|
||||
|
||||
The RL module is designed in a loosely-coupled way. Currently, RL examples are integrated with concrete business logic.
|
||||
But the core part of RL is much simpler than what you see.
|
||||
To demonstrate the simple core of RL, `a dedicated notebook <https://github.com/microsoft/qlib/tree/main/examples/rl/simple_example.ipynb>`__ for RL without business loss is created.
|
||||
50
docs/component/rl/overall.rst
Normal file
50
docs/component/rl/overall.rst
Normal file
@@ -0,0 +1,50 @@
|
||||
=====================================================
|
||||
Reinforcement Learning in Quantitative Trading
|
||||
=====================================================
|
||||
|
||||
Reinforcement Learning
|
||||
======================
|
||||
Different from supervised learning tasks such as classification tasks and regression tasks. Another important paradigm in machine learning is Reinforcement Learning,
|
||||
which attempts to optimize an accumulative numerical reward signal by directly interacting with the environment under a few assumptions such as Markov Decision Process(MDP).
|
||||
|
||||
As demonstrated in the following figure, an RL system consists of four elements, 1)the agent 2) the environment the agent interacts with 3) the policy that the agent follows to take actions on the environment and 4)the reward signal from the environment to the agent.
|
||||
In general, the agent can perceive and interpret its environment, take actions and learn through reward, to seek long-term and maximum overall reward to achieve an optimal solution.
|
||||
|
||||
.. image:: ../../_static/img/RL_framework.png
|
||||
:width: 300
|
||||
:align: center
|
||||
|
||||
RL attempts to learn to produce actions by trial and error.
|
||||
By sampling actions and then observing which one leads to our desired outcome, a policy is obtained to generate optimal actions.
|
||||
In contrast to supervised learning, RL learns this not from a label but from a time-delayed label called a reward.
|
||||
This scalar value lets us know whether the current outcome is good or bad.
|
||||
In a word, the target of RL is to take actions to maximize reward.
|
||||
|
||||
The Qlib Reinforcement Learning toolkit (QlibRL) is an RL platform for quantitative investment, which provides support to implement the RL algorithms in Qlib.
|
||||
|
||||
|
||||
Potential Application Scenarios in Quantitative Trading
|
||||
=======================================================
|
||||
RL methods have already achieved outstanding achievement in many applications, such as game playing, resource allocating, recommendation, marketing and advertising, etc.
|
||||
Investment is always a continuous process, taking the stock market as an example, investors need to control their positions and stock holdings by one or more buying and selling behaviors, to maximize the investment returns.
|
||||
Besides, each buy and sell decision is made by investors after fully considering the overall market information and stock information.
|
||||
From the view of an investor, the process could be described as a continuous decision-making process generated according to interaction with the market, such problems could be solved by the RL algorithms.
|
||||
Following are some scenarios where RL can potentially be used in quantitative investment.
|
||||
|
||||
Portfolio Construction
|
||||
----------------------
|
||||
Portfolio construction is a process of selecting securities optimally by taking a minimum risk to achieve maximum returns. With an RL-based solution, an agent allocates stocks at every time step by obtaining information for each stock and the market. The key is to develop of policy for building a portfolio and make the policy able to pick the optimal portfolio.
|
||||
|
||||
Order Execution
|
||||
---------------
|
||||
As a fundamental problem in algorithmic trading, order execution aims at fulfilling a specific trading order, either liquidation or acquirement, for a given instrument. Essentially, the goal of order execution is twofold: it not only requires to fulfill the whole order but also targets a more economical execution with maximizing profit gain (or minimizing capital loss). The order execution with only one order of liquidation or acquirement is called single-asset order execution.
|
||||
|
||||
Considering stock investment always aim to pursue long-term maximized profits, it usually manifests as a sequential process of continuously adjusting the asset portfolios, execution for multiple orders, including order of liquidation and acquirement, brings more constraints and makes the sequence of execution for different orders should be considered, e.g. before executing an order to buy some stocks, we have to sell at least one stock. The order execution with multiple assets is called multi-asset order execution.
|
||||
|
||||
According to the order execution’s trait of sequential decision-making, an RL-based solution could be applied to solve the order execution. With an RL-based solution, an agent optimizes execution strategy by interacting with the market environment.
|
||||
|
||||
With QlibRL, the RL algorithm in the above scenarios can be easily implemented.
|
||||
|
||||
Nested Portfolio Construction and Order Executor
|
||||
------------------------------------------------
|
||||
QlibRL makes it possible to jointly optimize different levels of strategies/models/agents. Take `Nested Decision Execution Framework <https://github.com/microsoft/qlib/blob/main/examples/nested_decision_execution>`_ as an example, the optimization of order execution strategy and portfolio management strategies can interact with each other to maximize returns.
|
||||
175
docs/component/rl/quickstart.rst
Normal file
175
docs/component/rl/quickstart.rst
Normal file
@@ -0,0 +1,175 @@
|
||||
|
||||
Quick Start
|
||||
============
|
||||
.. currentmodule:: qlib
|
||||
|
||||
QlibRL provides an example of an implementation of a single asset order execution task and the following is an example of the config file to train with QlibRL.
|
||||
|
||||
.. code-block:: yaml
|
||||
|
||||
simulator:
|
||||
# Each step contains 30mins
|
||||
time_per_step: 30
|
||||
# Upper bound of volume, should be null or a float between 0 and 1, if it is a float, represent upper bound is calculated by the percentage of the market volume
|
||||
vol_limit: null
|
||||
env:
|
||||
# Concurrent environment workers.
|
||||
concurrency: 1
|
||||
# dummy or subproc or shmem. Corresponding to `parallelism in tianshou <https://tianshou.readthedocs.io/en/master/api/tianshou.env.html#vectorenv>`_.
|
||||
parallel_mode: dummy
|
||||
action_interpreter:
|
||||
class: CategoricalActionInterpreter
|
||||
kwargs:
|
||||
# Candidate actions, it can be a list with length L: [a_1, a_2,..., a_L] or an integer n, in which case the list of length n+1 is auto-generated, i.e., [0, 1/n, 2/n,..., n/n].
|
||||
values: 14
|
||||
# Total number of steps (an upper-bound estimation)
|
||||
max_step: 8
|
||||
module_path: qlib.rl.order_execution.interpreter
|
||||
state_interpreter:
|
||||
class: FullHistoryStateInterpreter
|
||||
kwargs:
|
||||
# Number of dimensions in data.
|
||||
data_dim: 6
|
||||
# Equal to the total number of records. For example, in SAOE per minute, data_ticks is the length of the day in minutes.
|
||||
data_ticks: 240
|
||||
# The total number of steps (an upper-bound estimation). For example, 390min / 30min-per-step = 13 steps.
|
||||
max_step: 8
|
||||
# Provider of the processed data.
|
||||
processed_data_provider:
|
||||
class: PickleProcessedDataProvider
|
||||
module_path: qlib.rl.data.pickle_styled
|
||||
kwargs:
|
||||
data_dir: ./data/pickle_dataframe/feature
|
||||
module_path: qlib.rl.order_execution.interpreter
|
||||
reward:
|
||||
class: PAPenaltyReward
|
||||
kwargs:
|
||||
# The penalty for a large volume in a short time.
|
||||
penalty: 100.0
|
||||
module_path: qlib.rl.order_execution.reward
|
||||
data:
|
||||
source:
|
||||
order_dir: ./data/training_order_split
|
||||
data_dir: ./data/pickle_dataframe/backtest
|
||||
# number of time indexes
|
||||
total_time: 240
|
||||
# start time index
|
||||
default_start_time: 0
|
||||
# end time index
|
||||
default_end_time: 240
|
||||
proc_data_dim: 6
|
||||
num_workers: 0
|
||||
queue_size: 20
|
||||
network:
|
||||
class: Recurrent
|
||||
module_path: qlib.rl.order_execution.network
|
||||
policy:
|
||||
class: PPO
|
||||
kwargs:
|
||||
lr: 0.0001
|
||||
module_path: qlib.rl.order_execution.policy
|
||||
runtime:
|
||||
seed: 42
|
||||
use_cuda: false
|
||||
trainer:
|
||||
max_epoch: 2
|
||||
# Number of episodes collected in each training iteration
|
||||
repeat_per_collect: 5
|
||||
earlystop_patience: 2
|
||||
# Episodes per collect at training.
|
||||
episode_per_collect: 20
|
||||
batch_size: 16
|
||||
# Perform validation every n iterations
|
||||
val_every_n_epoch: 1
|
||||
checkpoint_path: ./checkpoints
|
||||
checkpoint_every_n_iters: 1
|
||||
|
||||
|
||||
And the config file for backtesting:
|
||||
|
||||
.. code-block:: yaml
|
||||
|
||||
order_file: ./data/backtest_orders.csv
|
||||
start_time: "9:45"
|
||||
end_time: "14:44"
|
||||
qlib:
|
||||
provider_uri_1min: ./data/bin
|
||||
feature_root_dir: ./data/pickle
|
||||
# feature generated by today's information
|
||||
feature_columns_today: [
|
||||
"$open", "$high", "$low", "$close", "$vwap", "$volume",
|
||||
]
|
||||
# feature generated by yesterday's information
|
||||
feature_columns_yesterday: [
|
||||
"$open_v1", "$high_v1", "$low_v1", "$close_v1", "$vwap_v1", "$volume_v1",
|
||||
]
|
||||
exchange:
|
||||
# the expression for buying and selling stock limitation
|
||||
limit_threshold: ['$close == 0', '$close == 0']
|
||||
# deal price for buying and selling
|
||||
deal_price: ["If($close == 0, $vwap, $close)", "If($close == 0, $vwap, $close)"]
|
||||
volume_threshold:
|
||||
# volume limits are both buying and selling, "cum" means that this is a cumulative value over time
|
||||
all: ["cum", "0.2 * DayCumsum($volume, '9:45', '14:44')"]
|
||||
# the volume limits of buying
|
||||
buy: ["current", "$close"]
|
||||
# the volume limits of selling, "current" means that this is a real-time value and will not accumulate over time
|
||||
sell: ["current", "$close"]
|
||||
strategies:
|
||||
30min:
|
||||
class: TWAPStrategy
|
||||
module_path: qlib.contrib.strategy.rule_strategy
|
||||
kwargs: {}
|
||||
1day:
|
||||
class: SAOEIntStrategy
|
||||
module_path: qlib.rl.order_execution.strategy
|
||||
kwargs:
|
||||
state_interpreter:
|
||||
class: FullHistoryStateInterpreter
|
||||
module_path: qlib.rl.order_execution.interpreter
|
||||
kwargs:
|
||||
max_step: 8
|
||||
data_ticks: 240
|
||||
data_dim: 6
|
||||
processed_data_provider:
|
||||
class: PickleProcessedDataProvider
|
||||
module_path: qlib.rl.data.pickle_styled
|
||||
kwargs:
|
||||
data_dir: ./data/pickle_dataframe/feature
|
||||
action_interpreter:
|
||||
class: CategoricalActionInterpreter
|
||||
module_path: qlib.rl.order_execution.interpreter
|
||||
kwargs:
|
||||
values: 14
|
||||
max_step: 8
|
||||
network:
|
||||
class: Recurrent
|
||||
module_path: qlib.rl.order_execution.network
|
||||
kwargs: {}
|
||||
policy:
|
||||
class: PPO
|
||||
module_path: qlib.rl.order_execution.policy
|
||||
kwargs:
|
||||
lr: 1.0e-4
|
||||
# Local path to the latest model. The model is generated during training, so please run training first if you want to run backtest with a trained policy. You could also remove this parameter file to run backtest with a randomly initialized policy.
|
||||
weight_file: ./checkpoints/latest.pth
|
||||
# Concurrent environment workers.
|
||||
concurrency: 5
|
||||
|
||||
With the above config files, you can start training the agent by the following command:
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ python -m qlib.rl.contrib.train_onpolicy.py --config_path train_config.yml
|
||||
|
||||
After the training, you can backtest with the following command:
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ python -m qlib.rl.contrib.backtest.py --config_path backtest_config.yml
|
||||
|
||||
In that case, :class:`~qlib.rl.order_execution.simulator_qlib.SingleAssetOrderExecution` and :class:`~qlib.rl.order_execution.simulator_simple.SingleAssetOrderExecutionSimple` as examples for simulator, :class:`qlib.rl.order_execution.interpreter.FullHistoryStateInterpreter` and :class:`qlib.rl.order_execution.interpreter.CategoricalActionInterpreter` as examples for interpreter, :class:`qlib.rl.order_execution.policy.PPO` as an example for policy, and :class:`qlib.rl.order_execution.reward.PAPenaltyReward` as an example for reward.
|
||||
For the single asset order execution task, if developers have already defined their simulator/interpreters/reward function/policy, they could launch the training and backtest pipeline by simply modifying the corresponding settings in the config files.
|
||||
The details about the example can be found `here <https://github.com/microsoft/qlib/blob/main/examples/rl/README.md>`_.
|
||||
|
||||
In the future, we will provide more examples for different scenarios such as RL-based portfolio construction.
|
||||
10
docs/component/rl/toctree.rst
Normal file
10
docs/component/rl/toctree.rst
Normal file
@@ -0,0 +1,10 @@
|
||||
.. _rl:
|
||||
|
||||
========================================================================
|
||||
Reinforcement Learning in Quantitative Trading
|
||||
========================================================================
|
||||
|
||||
.. toctree::
|
||||
Overall <overall>
|
||||
Quick Start <quickstart>
|
||||
Framework <framework>
|
||||
@@ -80,6 +80,7 @@ TopkDropoutStrategy
|
||||
In most cases, ``TopkDrop`` algorithm sells and buys `Drop` stocks every trading day, which yields a turnover rate of 2$\times$`Drop`/$K$.
|
||||
|
||||
The following images illustrate a typical scenario.
|
||||
|
||||
.. image:: ../_static/img/topk_drop.png
|
||||
:alt: Topk-Drop
|
||||
|
||||
|
||||
@@ -77,7 +77,7 @@ language = "en_US"
|
||||
# List of patterns, relative to source directory, that match files and
|
||||
# directories to ignore when looking for source files.
|
||||
# This patterns also effect to html_static_path and html_extra_path
|
||||
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
|
||||
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", "hidden"]
|
||||
|
||||
# The name of the Pygments (syntax highlighting) style to use.
|
||||
pygments_style = "sphinx"
|
||||
|
||||
@@ -15,7 +15,8 @@ Continuous Integration (CI) tools help you stick to the quality standards by run
|
||||
When you submit a PR request, you can check whether your code passes the CI tests in the "check" section at the bottom of the web page.
|
||||
|
||||
1. Qlib will check the code format with black. The PR will raise error if your code does not align to the standard of Qlib(e.g. a common error is the mixed use of space and tab).
|
||||
You can fix the bug by inputing the following code in the command line.
|
||||
|
||||
You can fix the bug by inputting the following code in the command line.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
@@ -32,7 +33,8 @@ When you submit a PR request, you can check whether your code passes the CI test
|
||||
|
||||
|
||||
3. Qlib will check your code style flake8. The checking command is implemented in [github action workflow](https://github.com/microsoft/qlib/blob/0e8b94a552f1c457cfa6cd2c1bb3b87ebb3fb279/.github/workflows/test.yml#L73).
|
||||
You can fix the bug by inputing the following code in the command line.
|
||||
|
||||
You can fix the bug by inputing the following code in the command line.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
@@ -40,7 +42,8 @@ When you submit a PR request, you can check whether your code passes the CI test
|
||||
|
||||
|
||||
4. Qlib has integrated pre-commit, which will make it easier for developers to format their code.
|
||||
Just run the following two commands, and the code will be automatically formatted using black and flake8 when the git commit command is executed.
|
||||
|
||||
Just run the following two commands, and the code will be automatically formatted using black and flake8 when the git commit command is executed.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
|
||||
@@ -81,6 +81,7 @@ If running on Windows, open **NFS** features and write correct **mount_path**, i
|
||||
* Open ``Programs and Features``.
|
||||
* Click ``Turn Windows features on or off``.
|
||||
* Scroll down and check the option ``Services for NFS``, then click OK
|
||||
|
||||
Reference address: https://graspingtech.com/mount-nfs-share-windows-10/
|
||||
2.config correct mount_path
|
||||
* In windows, mount path must be not exist path and root path,
|
||||
@@ -161,7 +162,7 @@ Limitations
|
||||
API
|
||||
***
|
||||
|
||||
The client is based on `python-socketio<https://python-socketio.readthedocs.io>`_ which is a framework that supports WebSocket client for Python language. The client can only propose requests and receive results, which do not include any calculating procedure.
|
||||
The client is based on `python-socketio <https://python-socketio.readthedocs.io>`_ which is a framework that supports WebSocket client for Python language. The client can only propose requests and receive results, which do not include any calculating procedure.
|
||||
|
||||
Class
|
||||
-----
|
||||
|
||||
@@ -33,7 +33,7 @@ Document Structure
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 3
|
||||
:caption: COMPONENTS:
|
||||
:caption: MAIN COMPONENTS:
|
||||
|
||||
Workflow: Workflow Management <component/workflow.rst>
|
||||
Data Layer: Data Framework & Usage <component/data.rst>
|
||||
@@ -44,10 +44,11 @@ Document Structure
|
||||
Qlib Recorder: Experiment Management <component/recorder.rst>
|
||||
Analysis: Evaluation & Results Analysis <component/report.rst>
|
||||
Online Serving: Online Management & Strategy & Tool <component/online.rst>
|
||||
Reinforcement Learning <component/rl/toctree>
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 3
|
||||
:caption: ADVANCED TOPICS:
|
||||
:caption: OTHER COMPONENTS/FEATURES/TOPICS:
|
||||
|
||||
Building Formulaic Alphas <advanced/alpha.rst>
|
||||
Online & Offline mode <advanced/server.rst>
|
||||
@@ -55,6 +56,12 @@ Document Structure
|
||||
Task Management <advanced/task_management.rst>
|
||||
Point-In-Time database <advanced/PIT.rst>
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 3
|
||||
:caption: FOR DEVELOPERS:
|
||||
|
||||
Code Standard & Development Guidance <developer/code_standard_and_dev_guide.rst>
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 3
|
||||
:caption: REFERENCE:
|
||||
|
||||
@@ -15,38 +15,56 @@ With ``Qlib``, users can easily try their ideas to create better Quant investmen
|
||||
Framework
|
||||
=========
|
||||
|
||||
|
||||
.. image:: ../_static/img/framework.svg
|
||||
:align: center
|
||||
|
||||
|
||||
At the module level, Qlib is a platform that consists of above components. The components are designed as loose-coupled modules and each component could be used stand-alone.
|
||||
|
||||
This framework may be intimidating for new users to Qlib. It tries to accurately include a lot of details of Qlib's design.
|
||||
For users new to Qlib, you can skip it first and read it later.
|
||||
|
||||
|
||||
======================== ==============================================================================
|
||||
Name Description
|
||||
======================== ==============================================================================
|
||||
`Infrastructure` layer `Infrastructure` layer provides underlying support for Quant research.
|
||||
`DataServer` provides high-performance infrastructure for users to manage
|
||||
and retrieve raw data. `Trainer` provides flexible interface to control
|
||||
the training process of models which enable algorithms controlling the
|
||||
training process.
|
||||
|
||||
`Workflow` layer `Workflow` layer covers the whole workflow of quantitative investment.
|
||||
`Information Extractor` extracts data for models. `Forecast Model` focuses
|
||||
on producing all kinds of forecast signals (e.g. *alpha*, risk) for other
|
||||
modules. With these signals `Decision Generator` will generate the target
|
||||
trading decisions(i.e. portfolio, orders) to be executed by `Execution Env`
|
||||
(i.e. the trading market). There may be multiple levels of `Trading Agent`
|
||||
and `Execution Env` (e.g. an *order executor trading agent and intraday
|
||||
order execution environment* could behave like an interday trading
|
||||
environment and nested in *daily portfolio management trading agent and
|
||||
interday trading environment* )
|
||||
=========================== ==============================================================================
|
||||
Name Description
|
||||
=========================== ==============================================================================
|
||||
`Infrastructure` layer `Infrastructure` layer provides underlying support for Quant research.
|
||||
`DataServer` provides high-performance infrastructure for users to manage
|
||||
and retrieve raw data. `Trainer` provides flexible interface to control
|
||||
the training process of models which enable algorithms controlling the
|
||||
training process.
|
||||
|
||||
`Interface` layer `Interface` layer tries to present a user-friendly interface for the underlying
|
||||
system. `Analyser` module will provide users detailed analysis reports of
|
||||
forecasting signals, portfolios and execution results
|
||||
======================== ==============================================================================
|
||||
`Learning Framework` layer The `Forecast Model` and `Trading Agent` are learnable. They are learned
|
||||
based on the `Learning Framework` layer and then applied to multiple scenarios
|
||||
in `Workflow` layer. The supported learning paradigms can be categorized into
|
||||
reinforcement learning and supervised learning. The learning framework
|
||||
leverages the `Workflow` layer as well(e.g. sharing `Information Extractor`,
|
||||
creating environments based on `Execution Env`).
|
||||
|
||||
`Workflow` layer `Workflow` layer covers the whole workflow of quantitative investment.
|
||||
Both supervised-learning-based strategies and RL-based Strategies
|
||||
are supported.
|
||||
`Information Extractor` extracts data for models. `Forecast Model` focuses
|
||||
on producing all kinds of forecast signals (e.g. *alpha*, risk) for other
|
||||
modules. With these signals `Decision Generator` will generate the target
|
||||
trading decisions(i.e. portfolio, orders)
|
||||
If RL-based Strategies are adopted, the `Policy` is learned in a end-to-end way,
|
||||
the trading deicsions are generated directly.
|
||||
Decisions will be executed by `Execution Env`
|
||||
(i.e. the trading market). There may be multiple levels of `Strategy`
|
||||
and `Executor` (e.g. an *order executor trading strategy and intraday order executor*
|
||||
could behave like an interday trading loop and be nested in
|
||||
*daily portfolio management trading strategy and interday trading executor*
|
||||
trading loop)
|
||||
|
||||
`Interface` layer `Interface` layer tries to present a user-friendly interface for the underlying
|
||||
system. `Analyser` module will provide users detailed analysis reports of
|
||||
forecasting signals, portfolios and execution results
|
||||
=========================== ==============================================================================
|
||||
|
||||
- The modules with hand-drawn style are under development and will be released in the future.
|
||||
- The modules with dashed borders are highly user-customizable and extendible.
|
||||
|
||||
(p.s. framework image is created with https://draw.io/)
|
||||
|
||||
@@ -21,6 +21,7 @@ Users can easily intsall ``Qlib`` according to the following steps:
|
||||
- Before installing ``Qlib`` from source, users need to install some dependencies:
|
||||
|
||||
.. code-block::
|
||||
|
||||
pip install numpy
|
||||
pip install --upgrade cython
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
.. _api:
|
||||
|
||||
=============
|
||||
API Reference
|
||||
=============
|
||||
@@ -116,7 +117,7 @@ Model
|
||||
Strategy
|
||||
--------
|
||||
|
||||
.. automodule:: qlib.contrib.strategy.strategy
|
||||
.. automodule:: qlib.contrib.strategy
|
||||
:members:
|
||||
|
||||
Evaluate
|
||||
@@ -254,5 +255,38 @@ Utils
|
||||
Serializable
|
||||
------------
|
||||
|
||||
.. automodule:: qlib.utils.serial.Serializable
|
||||
.. automodule:: qlib.utils.serial
|
||||
:members:
|
||||
|
||||
RL
|
||||
==============
|
||||
|
||||
Base Component
|
||||
--------------
|
||||
.. automodule:: qlib.rl
|
||||
:members:
|
||||
:imported-members:
|
||||
|
||||
Strategy
|
||||
--------
|
||||
.. automodule:: qlib.rl.strategy
|
||||
:members:
|
||||
:imported-members:
|
||||
|
||||
Trainer
|
||||
-------
|
||||
.. automodule:: qlib.rl.trainer
|
||||
:members:
|
||||
:imported-members:
|
||||
|
||||
Order Execution
|
||||
---------------
|
||||
.. automodule:: qlib.rl.order_execution
|
||||
:members:
|
||||
:imported-members:
|
||||
|
||||
Utils
|
||||
---------------
|
||||
.. automodule:: qlib.rl.utils
|
||||
:members:
|
||||
:imported-members:
|
||||
@@ -4,3 +4,4 @@ numpy
|
||||
scipy
|
||||
scikit-learn
|
||||
pandas
|
||||
tianshou
|
||||
|
||||
@@ -83,15 +83,14 @@ Load features of certain instruments in a given time range:
|
||||
>> from qlib.data import D
|
||||
>> instruments = ['SH600000']
|
||||
>> fields = ['$close', '$volume', 'Ref($close, 1)', 'Mean($close, 3)', '$high-$low']
|
||||
>> D.features(instruments, fields, start_time='2010-01-01', end_time='2017-12-31', freq='day').head()
|
||||
|
||||
$close $volume Ref($close, 1) Mean($close, 3) $high-$low
|
||||
instrument datetime
|
||||
SH600000 2010-01-04 86.778313 16162960.0 88.825928 88.061483 2.907631
|
||||
2010-01-05 87.433578 28117442.0 86.778313 87.679273 3.235252
|
||||
2010-01-06 85.713585 23632884.0 87.433578 86.641825 1.720009
|
||||
2010-01-07 83.788803 20813402.0 85.713585 85.645322 3.030487
|
||||
2010-01-08 84.730675 16044853.0 83.788803 84.744354 2.047623
|
||||
>> D.features(instruments, fields, start_time='2010-01-01', end_time='2017-12-31', freq='day').head().to_string()
|
||||
' $close $volume Ref($close, 1) Mean($close, 3) $high-$low
|
||||
... instrument datetime
|
||||
... SH600000 2010-01-04 86.778313 16162960.0 88.825928 88.061483 2.907631
|
||||
... 2010-01-05 87.433578 28117442.0 86.778313 87.679273 3.235252
|
||||
... 2010-01-06 85.713585 23632884.0 87.433578 86.641825 1.720009
|
||||
... 2010-01-07 83.788803 20813402.0 85.713585 85.645322 3.030487
|
||||
... 2010-01-08 84.730675 16044853.0 83.788803 84.744354 2.047623'
|
||||
|
||||
Load features of certain stock pool in a given time range:
|
||||
|
||||
@@ -105,15 +104,14 @@ Load features of certain stock pool in a given time range:
|
||||
>> expressionDFilter = ExpressionDFilter(rule_expression='$close>Ref($close,1)')
|
||||
>> instruments = D.instruments(market='csi300', filter_pipe=[nameDFilter, expressionDFilter])
|
||||
>> fields = ['$close', '$volume', 'Ref($close, 1)', 'Mean($close, 3)', '$high-$low']
|
||||
>> D.features(instruments, fields, start_time='2010-01-01', end_time='2017-12-31', freq='day').head()
|
||||
|
||||
$close $volume Ref($close, 1) Mean($close, 3) $high-$low
|
||||
instrument datetime
|
||||
SH600655 2010-01-04 2699.567383 158193.328125 2619.070312 2626.097738 124.580566
|
||||
2010-01-08 2612.359619 77501.406250 2584.567627 2623.220133 83.373047
|
||||
2010-01-11 2712.982422 160852.390625 2612.359619 2636.636556 146.621582
|
||||
2010-01-12 2788.688232 164587.937500 2712.982422 2704.676758 128.413818
|
||||
2010-01-13 2790.604004 145460.453125 2788.688232 2764.091553 128.413818
|
||||
>> D.features(instruments, fields, start_time='2010-01-01', end_time='2017-12-31', freq='day').head().to_string()
|
||||
' $close $volume Ref($close, 1) Mean($close, 3) $high-$low
|
||||
... instrument datetime
|
||||
... SH600655 2010-01-04 2699.567383 158193.328125 2619.070312 2626.097738 124.580566
|
||||
... 2010-01-08 2612.359619 77501.406250 2584.567627 2623.220133 83.373047
|
||||
... 2010-01-11 2712.982422 160852.390625 2612.359619 2636.636556 146.621582
|
||||
... 2010-01-12 2788.688232 164587.937500 2712.982422 2704.676758 128.413818
|
||||
... 2010-01-13 2790.604004 145460.453125 2788.688232 2764.091553 128.413818'
|
||||
|
||||
|
||||
For more details about features, please refer `Feature API <../component/data.html>`_.
|
||||
|
||||
@@ -21,84 +21,88 @@ The Custom models need to inherit `qlib.model.base.Model <../reference/api.html#
|
||||
- ``Qlib`` passes the initialized parameters to the \_\_init\_\_ method.
|
||||
- The hyperparameters of model in the configuration must be consistent with those defined in the `__init__` method.
|
||||
- Code Example: In the following example, the hyperparameters of model in the configuration file should contain parameters such as `loss:mse`.
|
||||
.. code-block:: Python
|
||||
|
||||
def __init__(self, loss='mse', **kwargs):
|
||||
if loss not in {'mse', 'binary'}:
|
||||
raise NotImplementedError
|
||||
self._scorer = mean_squared_error if loss == 'mse' else roc_auc_score
|
||||
self._params.update(objective=loss, **kwargs)
|
||||
self._model = None
|
||||
.. code-block:: Python
|
||||
|
||||
def __init__(self, loss='mse', **kwargs):
|
||||
if loss not in {'mse', 'binary'}:
|
||||
raise NotImplementedError
|
||||
self._scorer = mean_squared_error if loss == 'mse' else roc_auc_score
|
||||
self._params.update(objective=loss, **kwargs)
|
||||
self._model = None
|
||||
|
||||
- Override the `fit` method
|
||||
- ``Qlib`` calls the fit method to train the model.
|
||||
- The parameters must include training feature `dataset`, which is designed in the interface.
|
||||
- The parameters could include some `optional` parameters with default values, such as `num_boost_round = 1000` for `GBDT`.
|
||||
- Code Example: In the following example, `num_boost_round = 1000` is an optional parameter.
|
||||
.. code-block:: Python
|
||||
|
||||
def fit(self, dataset: DatasetH, num_boost_round = 1000, **kwargs):
|
||||
.. code-block:: Python
|
||||
|
||||
# prepare dataset for lgb training and evaluation
|
||||
df_train, df_valid = dataset.prepare(
|
||||
["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
|
||||
)
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
def fit(self, dataset: DatasetH, num_boost_round = 1000, **kwargs):
|
||||
|
||||
# Lightgbm need 1D array as its label
|
||||
if y_train.values.ndim == 2 and y_train.values.shape[1] == 1:
|
||||
y_train, y_valid = np.squeeze(y_train.values), np.squeeze(y_valid.values)
|
||||
else:
|
||||
raise ValueError("LightGBM doesn't support multi-label training")
|
||||
# prepare dataset for lgb training and evaluation
|
||||
df_train, df_valid = dataset.prepare(
|
||||
["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
|
||||
)
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
|
||||
dtrain = lgb.Dataset(x_train.values, label=y_train)
|
||||
dvalid = lgb.Dataset(x_valid.values, label=y_valid)
|
||||
# Lightgbm need 1D array as its label
|
||||
if y_train.values.ndim == 2 and y_train.values.shape[1] == 1:
|
||||
y_train, y_valid = np.squeeze(y_train.values), np.squeeze(y_valid.values)
|
||||
else:
|
||||
raise ValueError("LightGBM doesn't support multi-label training")
|
||||
|
||||
# fit the model
|
||||
self.model = lgb.train(
|
||||
self.params,
|
||||
dtrain,
|
||||
num_boost_round=num_boost_round,
|
||||
valid_sets=[dtrain, dvalid],
|
||||
valid_names=["train", "valid"],
|
||||
early_stopping_rounds=early_stopping_rounds,
|
||||
verbose_eval=verbose_eval,
|
||||
evals_result=evals_result,
|
||||
**kwargs
|
||||
)
|
||||
dtrain = lgb.Dataset(x_train.values, label=y_train)
|
||||
dvalid = lgb.Dataset(x_valid.values, label=y_valid)
|
||||
|
||||
# fit the model
|
||||
self.model = lgb.train(
|
||||
self.params,
|
||||
dtrain,
|
||||
num_boost_round=num_boost_round,
|
||||
valid_sets=[dtrain, dvalid],
|
||||
valid_names=["train", "valid"],
|
||||
early_stopping_rounds=early_stopping_rounds,
|
||||
verbose_eval=verbose_eval,
|
||||
evals_result=evals_result,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
- Override the `predict` method
|
||||
- The parameters must include the parameter `dataset`, which will be userd to get the test dataset.
|
||||
- Return the `prediction score`.
|
||||
- Please refer to `Model API <../reference/api.html#module-qlib.model.base>`_ for the parameter types of the fit method.
|
||||
- Code Example: In the following example, users need to use `LightGBM` to predict the label(such as `preds`) of test data `x_test` and return it.
|
||||
.. code-block:: Python
|
||||
|
||||
def predict(self, dataset: DatasetH, **kwargs)-> pandas.Series:
|
||||
if self.model is None:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
x_test = dataset.prepare("test", col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
return pd.Series(self.model.predict(x_test.values), index=x_test.index)
|
||||
.. code-block:: Python
|
||||
|
||||
def predict(self, dataset: DatasetH, **kwargs)-> pandas.Series:
|
||||
if self.model is None:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
x_test = dataset.prepare("test", col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
return pd.Series(self.model.predict(x_test.values), index=x_test.index)
|
||||
|
||||
- Override the `finetune` method (Optional)
|
||||
- This method is optional to the users. When users want to use this method on their own models, they should inherit the ``ModelFT`` base class, which includes the interface of `finetune`.
|
||||
- The parameters must include the parameter `dataset`.
|
||||
- Code Example: In the following example, users will use `LightGBM` as the model and finetune it.
|
||||
.. code-block:: Python
|
||||
|
||||
def finetune(self, dataset: DatasetH, num_boost_round=10, verbose_eval=20):
|
||||
# Based on existing model and finetune by train more rounds
|
||||
dtrain, _ = self._prepare_data(dataset)
|
||||
self.model = lgb.train(
|
||||
self.params,
|
||||
dtrain,
|
||||
num_boost_round=num_boost_round,
|
||||
init_model=self.model,
|
||||
valid_sets=[dtrain],
|
||||
valid_names=["train"],
|
||||
verbose_eval=verbose_eval,
|
||||
)
|
||||
.. code-block:: Python
|
||||
|
||||
def finetune(self, dataset: DatasetH, num_boost_round=10, verbose_eval=20):
|
||||
# Based on existing model and finetune by train more rounds
|
||||
dtrain, _ = self._prepare_data(dataset)
|
||||
self.model = lgb.train(
|
||||
self.params,
|
||||
dtrain,
|
||||
num_boost_round=num_boost_round,
|
||||
init_model=self.model,
|
||||
valid_sets=[dtrain],
|
||||
valid_names=["train"],
|
||||
verbose_eval=verbose_eval,
|
||||
)
|
||||
|
||||
Configuration File
|
||||
==================
|
||||
@@ -107,21 +111,21 @@ The configuration file is described in detail in the `Workflow <../component/wor
|
||||
|
||||
- Example: The following example describes the `model` field of configuration file about the custom lightgbm model mentioned above, where `module_path` is the module path, `class` is the class name, and `args` is the hyperparameter passed into the __init__ method. All parameters in the field is passed to `self._params` by `\*\*kwargs` in `__init__` except `loss = mse`.
|
||||
|
||||
.. code-block:: YAML
|
||||
.. code-block:: YAML
|
||||
|
||||
model:
|
||||
class: LGBModel
|
||||
module_path: qlib.contrib.model.gbdt
|
||||
args:
|
||||
loss: mse
|
||||
colsample_bytree: 0.8879
|
||||
learning_rate: 0.0421
|
||||
subsample: 0.8789
|
||||
lambda_l1: 205.6999
|
||||
lambda_l2: 580.9768
|
||||
max_depth: 8
|
||||
num_leaves: 210
|
||||
num_threads: 20
|
||||
model:
|
||||
class: LGBModel
|
||||
module_path: qlib.contrib.model.gbdt
|
||||
args:
|
||||
loss: mse
|
||||
colsample_bytree: 0.8879
|
||||
learning_rate: 0.0421
|
||||
subsample: 0.8789
|
||||
lambda_l1: 205.6999
|
||||
lambda_l2: 580.9768
|
||||
max_depth: 8
|
||||
num_leaves: 210
|
||||
num_threads: 20
|
||||
|
||||
Users could find configuration file of the baselines of the ``Model`` in ``examples/benchmarks``. All the configurations of different models are listed under the corresponding model folder.
|
||||
|
||||
|
||||
@@ -0,0 +1,95 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
start_time: 2017-01-01
|
||||
end_time: 2020-08-01
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
exchange_kwargs:
|
||||
limit_threshold: 0.095
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: DEnsembleModel
|
||||
module_path: qlib.contrib.model.double_ensemble
|
||||
kwargs:
|
||||
base_model: "gbm"
|
||||
loss: mse
|
||||
num_models: 3
|
||||
enable_sr: True
|
||||
enable_fs: True
|
||||
alpha1: 1
|
||||
alpha2: 1
|
||||
bins_sr: 10
|
||||
bins_fs: 5
|
||||
decay: 0.5
|
||||
sample_ratios:
|
||||
- 0.8
|
||||
- 0.7
|
||||
- 0.6
|
||||
- 0.5
|
||||
- 0.4
|
||||
sub_weights:
|
||||
- 1
|
||||
- 1
|
||||
- 1
|
||||
epochs: 1000
|
||||
early_stopping_rounds: 50
|
||||
colsample_bytree: 0.8879
|
||||
learning_rate: 0.2
|
||||
subsample: 0.8789
|
||||
lambda_l1: 205.6999
|
||||
lambda_l2: 580.9768
|
||||
max_depth: 8
|
||||
num_leaves: 210
|
||||
num_threads: 20
|
||||
verbosity: -1
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha158
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -5,6 +5,8 @@ from qlib.data.inst_processor import InstProcessor
|
||||
|
||||
|
||||
class Resample1minProcessor(InstProcessor):
|
||||
"""This processor tries to resample the data. It will reasmple the data from 1min freq to day freq by selecting a specific miniute"""
|
||||
|
||||
def __init__(self, hour: int, minute: int, **kwargs):
|
||||
self.hour = hour
|
||||
self.minute = minute
|
||||
|
||||
@@ -29,13 +29,13 @@ class Avg15minHandler(DataHandlerLP):
|
||||
fit_end_time=None,
|
||||
process_type=DataHandlerLP.PTYPE_A,
|
||||
filter_pipe=None,
|
||||
inst_processor=None,
|
||||
inst_processors=None,
|
||||
**kwargs,
|
||||
):
|
||||
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
|
||||
learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)
|
||||
data_loader = Avg15minLoader(
|
||||
config=self.loader_config(), filter_pipe=filter_pipe, freq=freq, inst_processor=inst_processor
|
||||
config=self.loader_config(), filter_pipe=filter_pipe, freq=freq, inst_processors=inst_processors
|
||||
)
|
||||
super().__init__(
|
||||
instruments=instruments,
|
||||
|
||||
@@ -18,7 +18,7 @@ data_handler_config: &data_handler_config
|
||||
label: day
|
||||
feature: 1min
|
||||
# with label as reference
|
||||
inst_processor:
|
||||
inst_processors:
|
||||
feature:
|
||||
- class: Resample1minProcessor
|
||||
module_path: features_sample.py
|
||||
|
||||
@@ -19,7 +19,7 @@ data_handler_config: &data_handler_config
|
||||
feature_15min: 1min
|
||||
feature_day: day
|
||||
# with label as reference
|
||||
inst_processor:
|
||||
inst_processors:
|
||||
feature_15min:
|
||||
- class: ResampleNProcessor
|
||||
module_path: features_resample_N.py
|
||||
|
||||
@@ -64,8 +64,6 @@ task:
|
||||
kwargs:
|
||||
loss: mse
|
||||
lr: 0.002
|
||||
lr_decay: 0.96
|
||||
lr_decay_steps: 100
|
||||
optimizer: adam
|
||||
max_steps: 8000
|
||||
batch_size: 8192
|
||||
|
||||
@@ -64,8 +64,6 @@ task:
|
||||
kwargs:
|
||||
loss: mse
|
||||
lr: 0.002
|
||||
lr_decay: 0.96
|
||||
lr_decay_steps: 100
|
||||
optimizer: adam
|
||||
max_steps: 8000
|
||||
batch_size: 8192
|
||||
|
||||
@@ -52,8 +52,6 @@ task:
|
||||
kwargs:
|
||||
loss: mse
|
||||
lr: 0.002
|
||||
lr_decay: 0.96
|
||||
lr_decay_steps: 100
|
||||
optimizer: adam
|
||||
max_steps: 8000
|
||||
batch_size: 4096
|
||||
|
||||
@@ -52,8 +52,6 @@ task:
|
||||
kwargs:
|
||||
loss: mse
|
||||
lr: 0.002
|
||||
lr_decay: 0.96
|
||||
lr_decay_steps: 100
|
||||
optimizer: adam
|
||||
max_steps: 8000
|
||||
batch_size: 4096
|
||||
|
||||
@@ -25,59 +25,65 @@
|
||||
"import seaborn as sns\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"import matplotlib\n",
|
||||
"sns.set(style='white')\n",
|
||||
"matplotlib.rcParams['pdf.fonttype'] = 42\n",
|
||||
"matplotlib.rcParams['ps.fonttype'] = 42\n",
|
||||
"\n",
|
||||
"sns.set(style=\"white\")\n",
|
||||
"matplotlib.rcParams[\"pdf.fonttype\"] = 42\n",
|
||||
"matplotlib.rcParams[\"ps.fonttype\"] = 42\n",
|
||||
"\n",
|
||||
"from tqdm.auto import tqdm\n",
|
||||
"from joblib import Parallel, delayed\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def func(x, N=80):\n",
|
||||
" ret = x.ret.copy()\n",
|
||||
" x = x.rank(pct=True)\n",
|
||||
" x['ret'] = ret\n",
|
||||
" x[\"ret\"] = ret\n",
|
||||
" diff = x.score.sub(x.label)\n",
|
||||
" r = x.nlargest(N, columns='score').ret.mean()\n",
|
||||
" r -= x.nsmallest(N, columns='score').ret.mean()\n",
|
||||
" return pd.Series({\n",
|
||||
" 'MSE': diff.pow(2).mean(), \n",
|
||||
" 'MAE': diff.abs().mean(), \n",
|
||||
" 'IC': x.score.corr(x.label),\n",
|
||||
" 'R': r\n",
|
||||
" })\n",
|
||||
" \n",
|
||||
" r = x.nlargest(N, columns=\"score\").ret.mean()\n",
|
||||
" r -= x.nsmallest(N, columns=\"score\").ret.mean()\n",
|
||||
" return pd.Series(\n",
|
||||
" {\n",
|
||||
" \"MSE\": diff.pow(2).mean(),\n",
|
||||
" \"MAE\": diff.abs().mean(),\n",
|
||||
" \"IC\": x.score.corr(x.label),\n",
|
||||
" \"R\": r,\n",
|
||||
" }\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"ret = pd.read_pickle(\"data/ret.pkl\").clip(-0.1, 0.1)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def backtest(fname, **kwargs):\n",
|
||||
" pred = pd.read_pickle(fname).loc['2018-09-21':'2020-06-30'] # test period\n",
|
||||
" pred['ret'] = ret\n",
|
||||
" pred = pd.read_pickle(fname).loc[\"2018-09-21\":\"2020-06-30\"] # test period\n",
|
||||
" pred[\"ret\"] = ret\n",
|
||||
" dates = pred.index.unique(level=0)\n",
|
||||
" res = Parallel(n_jobs=-1)(delayed(func)(pred.loc[d], **kwargs) for d in dates)\n",
|
||||
" res = {\n",
|
||||
" dates[i]: res[i]\n",
|
||||
" for i in range(len(dates))\n",
|
||||
" }\n",
|
||||
" res = {dates[i]: res[i] for i in range(len(dates))}\n",
|
||||
" res = pd.DataFrame(res).T\n",
|
||||
" r = res['R'].copy()\n",
|
||||
" r = res[\"R\"].copy()\n",
|
||||
" r.index = pd.to_datetime(r.index)\n",
|
||||
" r = r.reindex(pd.date_range(r.index[0], r.index[-1])).fillna(0) # paper use 365 days\n",
|
||||
" return {\n",
|
||||
" 'MSE': res['MSE'].mean(),\n",
|
||||
" 'MAE': res['MAE'].mean(),\n",
|
||||
" 'IC': res['IC'].mean(),\n",
|
||||
" 'ICIR': res['IC'].mean()/res['IC'].std(),\n",
|
||||
" 'AR': r.mean()*365,\n",
|
||||
" 'AV': r.std()*365**0.5,\n",
|
||||
" 'SR': r.mean()/r.std()*365**0.5,\n",
|
||||
" 'MDD': (r.cumsum().cummax() - r.cumsum()).max()\n",
|
||||
" \"MSE\": res[\"MSE\"].mean(),\n",
|
||||
" \"MAE\": res[\"MAE\"].mean(),\n",
|
||||
" \"IC\": res[\"IC\"].mean(),\n",
|
||||
" \"ICIR\": res[\"IC\"].mean() / res[\"IC\"].std(),\n",
|
||||
" \"AR\": r.mean() * 365,\n",
|
||||
" \"AV\": r.std() * 365**0.5,\n",
|
||||
" \"SR\": r.mean() / r.std() * 365**0.5,\n",
|
||||
" \"MDD\": (r.cumsum().cummax() - r.cumsum()).max(),\n",
|
||||
" }, r\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def fmt(x, p=3, scale=1, std=False):\n",
|
||||
" _fmt = '{:.%df}'%p\n",
|
||||
" _fmt = \"{:.%df}\" % p\n",
|
||||
" string = _fmt.format((x.mean() if not isinstance(x, (float, np.floating)) else x) * scale)\n",
|
||||
" if std and len(x) > 1:\n",
|
||||
" string += ' ('+_fmt.format(x.std()*scale)+')'\n",
|
||||
" string += \" (\" + _fmt.format(x.std() * scale) + \")\"\n",
|
||||
" return string\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def backtest_multi(files, **kwargs):\n",
|
||||
" res = []\n",
|
||||
" pnl = []\n",
|
||||
@@ -88,14 +94,14 @@
|
||||
" res = pd.DataFrame(res)\n",
|
||||
" pnl = pd.concat(pnl, axis=1)\n",
|
||||
" return {\n",
|
||||
" 'MSE': fmt(res['MSE'], std=True),\n",
|
||||
" 'MAE': fmt(res['MAE'], std=True),\n",
|
||||
" 'IC': fmt(res['IC']),\n",
|
||||
" 'ICIR': fmt(res['ICIR']),\n",
|
||||
" 'AR': fmt(res['AR'], scale=100, p=1)+'%',\n",
|
||||
" 'VR': fmt(res['AV'], scale=100, p=1)+'%',\n",
|
||||
" 'SR': fmt(res['SR']),\n",
|
||||
" 'MDD': fmt(res['MDD'], scale=100, p=1)+'%'\n",
|
||||
" \"MSE\": fmt(res[\"MSE\"], std=True),\n",
|
||||
" \"MAE\": fmt(res[\"MAE\"], std=True),\n",
|
||||
" \"IC\": fmt(res[\"IC\"]),\n",
|
||||
" \"ICIR\": fmt(res[\"ICIR\"]),\n",
|
||||
" \"AR\": fmt(res[\"AR\"], scale=100, p=1) + \"%\",\n",
|
||||
" \"VR\": fmt(res[\"AV\"], scale=100, p=1) + \"%\",\n",
|
||||
" \"SR\": fmt(res[\"SR\"]),\n",
|
||||
" \"MDD\": fmt(res[\"MDD\"], scale=100, p=1) + \"%\",\n",
|
||||
" }, pnl"
|
||||
]
|
||||
},
|
||||
@@ -124,16 +130,20 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"exps = {\n",
|
||||
" 'Linear': ['output/Linear/pred.pkl'],\n",
|
||||
" 'LightGBM': ['output/GBDT/lr0.05_leaves128/pred.pkl'],\n",
|
||||
" 'MLP': glob.glob('output/search/MLP/hs128_bs512_do0.3_lr0.001_seed*/pred.pkl'),\n",
|
||||
" 'SFM': glob.glob('output/search/SFM/hs32_bs512_do0.5_lr0.001_seed*/pred.pkl'),\n",
|
||||
" 'ALSTM': glob.glob('output/search/LSTM_Attn/hs256_bs1024_do0.1_lr0.0002_seed*/pred.pkl'),\n",
|
||||
" 'Trans.': glob.glob('output/search/Transformer/head4_hs64_bs1024_do0.1_lr0.0002_seed*/pred.pkl'),\n",
|
||||
" 'ALSTM+TS':glob.glob('output/LSTM_Attn_TS/hs256_bs1024_do0.1_lr0.0002_seed*/pred.pkl'),\n",
|
||||
" 'Trans.+TS':glob.glob('output/Transformer_TS/head4_hs64_bs1024_do0.1_lr0.0002_seed*/pred.pkl'),\n",
|
||||
" 'ALSTM+TRA(Ours)': glob.glob('output/search/finetune/LSTM_Attn_tra/K10_traHs16_traSrcLR_TPE_traLamb2.0_hs256_bs1024_do0.1_lr0.0001_seed*/pred.pkl'),\n",
|
||||
" 'Trans.+TRA(Ours)': glob.glob('output/search/finetune/Transformer_tra/K3_traHs16_traSrcLR_TPE_traLamb1.0_head4_hs64_bs512_do0.1_lr0.0005_seed*/pred.pkl')\n",
|
||||
" \"Linear\": [\"output/Linear/pred.pkl\"],\n",
|
||||
" \"LightGBM\": [\"output/GBDT/lr0.05_leaves128/pred.pkl\"],\n",
|
||||
" \"MLP\": glob.glob(\"output/search/MLP/hs128_bs512_do0.3_lr0.001_seed*/pred.pkl\"),\n",
|
||||
" \"SFM\": glob.glob(\"output/search/SFM/hs32_bs512_do0.5_lr0.001_seed*/pred.pkl\"),\n",
|
||||
" \"ALSTM\": glob.glob(\"output/search/LSTM_Attn/hs256_bs1024_do0.1_lr0.0002_seed*/pred.pkl\"),\n",
|
||||
" \"Trans.\": glob.glob(\"output/search/Transformer/head4_hs64_bs1024_do0.1_lr0.0002_seed*/pred.pkl\"),\n",
|
||||
" \"ALSTM+TS\": glob.glob(\"output/LSTM_Attn_TS/hs256_bs1024_do0.1_lr0.0002_seed*/pred.pkl\"),\n",
|
||||
" \"Trans.+TS\": glob.glob(\"output/Transformer_TS/head4_hs64_bs1024_do0.1_lr0.0002_seed*/pred.pkl\"),\n",
|
||||
" \"ALSTM+TRA(Ours)\": glob.glob(\n",
|
||||
" \"output/search/finetune/LSTM_Attn_tra/K10_traHs16_traSrcLR_TPE_traLamb2.0_hs256_bs1024_do0.1_lr0.0001_seed*/pred.pkl\"\n",
|
||||
" ),\n",
|
||||
" \"Trans.+TRA(Ours)\": glob.glob(\n",
|
||||
" \"output/search/finetune/Transformer_tra/K3_traHs16_traSrcLR_TPE_traLamb1.0_head4_hs64_bs512_do0.1_lr0.0005_seed*/pred.pkl\"\n",
|
||||
" ),\n",
|
||||
"}"
|
||||
]
|
||||
},
|
||||
@@ -160,14 +170,8 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"res = {\n",
|
||||
" name: backtest_multi(exps[name])\n",
|
||||
" for name in tqdm(exps)\n",
|
||||
"}\n",
|
||||
"report = pd.DataFrame({\n",
|
||||
" k: v[0]\n",
|
||||
" for k, v in res.items()\n",
|
||||
"}).T"
|
||||
"res = {name: backtest_multi(exps[name]) for name in tqdm(exps)}\n",
|
||||
"report = pd.DataFrame({k: v[0] for k, v in res.items()}).T"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -385,24 +389,40 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"df = pd.read_pickle('output/search/finetune/Transformer_tra/K3_traHs16_traSrcLR_TPE_traLamb0.0_head4_hs64_bs512_do0.1_lr0.0005_seed1000/pred.pkl')\n",
|
||||
"code = 'SH600157'\n",
|
||||
"date = '2018-09-28'\n",
|
||||
"df = pd.read_pickle(\n",
|
||||
" \"output/search/finetune/Transformer_tra/K3_traHs16_traSrcLR_TPE_traLamb0.0_head4_hs64_bs512_do0.1_lr0.0005_seed1000/pred.pkl\"\n",
|
||||
")\n",
|
||||
"code = \"SH600157\"\n",
|
||||
"date = \"2018-09-28\"\n",
|
||||
"lookbackperiod = 50\n",
|
||||
"\n",
|
||||
"prob = df.iloc[:, -3:].loc(axis=0)[:, code].reset_index(level=1, drop=True).loc[date:].iloc[:lookbackperiod]\n",
|
||||
"pred = df.loc[:,[\"score_0\",\"score_1\",\"score_2\",\"label\"]].loc(axis=0)[:, code].reset_index(level=1, drop=True).loc[date:].iloc[:lookbackperiod]\n",
|
||||
"e_all = pred.iloc[:,:-1].sub(pred.iloc[:,-1], axis=0).pow(2)\n",
|
||||
"pred = (\n",
|
||||
" df.loc[:, [\"score_0\", \"score_1\", \"score_2\", \"label\"]]\n",
|
||||
" .loc(axis=0)[:, code]\n",
|
||||
" .reset_index(level=1, drop=True)\n",
|
||||
" .loc[date:]\n",
|
||||
" .iloc[:lookbackperiod]\n",
|
||||
")\n",
|
||||
"e_all = pred.iloc[:, :-1].sub(pred.iloc[:, -1], axis=0).pow(2)\n",
|
||||
"e_all = e_all.sub(e_all.min(axis=1), axis=0)\n",
|
||||
"e_all.columns = [r'$\\theta_%d$'%d for d in range(1, 4)]\n",
|
||||
"e_all.columns = [r\"$\\theta_%d$\" % d for d in range(1, 4)]\n",
|
||||
"prob = pd.Series(np.argmax(prob.values, axis=1), index=prob.index).rolling(7).mean().round()\n",
|
||||
"\n",
|
||||
"fig, axes = plt.subplots(1, 2, figsize=(7, 3))\n",
|
||||
"e_all.plot(ax=axes[0], xlabel='', rot=30)\n",
|
||||
"prob.plot(ax=axes[1], xlabel='', rot=30, color='red', linestyle='None', marker='^', markersize=5)\n",
|
||||
"e_all.plot(ax=axes[0], xlabel=\"\", rot=30)\n",
|
||||
"prob.plot(\n",
|
||||
" ax=axes[1],\n",
|
||||
" xlabel=\"\",\n",
|
||||
" rot=30,\n",
|
||||
" color=\"red\",\n",
|
||||
" linestyle=\"None\",\n",
|
||||
" marker=\"^\",\n",
|
||||
" markersize=5,\n",
|
||||
")\n",
|
||||
"plt.yticks(np.array([0, 1, 2]), e_all.columns.values)\n",
|
||||
"axes[0].set_ylabel('Predictor Loss')\n",
|
||||
"axes[1].set_ylabel('Router Selection')\n",
|
||||
"axes[0].set_ylabel(\"Predictor Loss\")\n",
|
||||
"axes[1].set_ylabel(\"Router Selection\")\n",
|
||||
"plt.tight_layout()\n",
|
||||
"# plt.savefig('select.pdf', bbox_inches='tight')\n",
|
||||
"plt.show()"
|
||||
@@ -428,10 +448,18 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"exps = {\n",
|
||||
" 'Random': glob.glob('output/search/LSTM_Attn_tra/K10_traHs16_traSrcNONE_traLamb1.0_hs256_bs1024_do0.1_lr0.0001_seed*/pred.pkl'),\n",
|
||||
" 'LR': glob.glob('output/search/LSTM_Attn_tra/K10_traHs16_traSrcLR_traLamb1.0_hs256_bs1024_do0.1_lr0.0001_seed*/pred.pkl'),\n",
|
||||
" 'TPE': glob.glob('output/search/LSTM_Attn_tra/K10_traHs16_traSrcTPE_traLamb1.0_hs256_bs1024_do0.1_lr0.0001_seed*/pred.pkl'),\n",
|
||||
" 'LR+TPE': glob.glob('output/search/finetune/LSTM_Attn_tra/K10_traHs16_traSrcLR_TPE_traLamb2.0_hs256_bs1024_do0.1_lr0.0001_seed*/pred.pkl')\n",
|
||||
" \"Random\": glob.glob(\n",
|
||||
" \"output/search/LSTM_Attn_tra/K10_traHs16_traSrcNONE_traLamb1.0_hs256_bs1024_do0.1_lr0.0001_seed*/pred.pkl\"\n",
|
||||
" ),\n",
|
||||
" \"LR\": glob.glob(\n",
|
||||
" \"output/search/LSTM_Attn_tra/K10_traHs16_traSrcLR_traLamb1.0_hs256_bs1024_do0.1_lr0.0001_seed*/pred.pkl\"\n",
|
||||
" ),\n",
|
||||
" \"TPE\": glob.glob(\n",
|
||||
" \"output/search/LSTM_Attn_tra/K10_traHs16_traSrcTPE_traLamb1.0_hs256_bs1024_do0.1_lr0.0001_seed*/pred.pkl\"\n",
|
||||
" ),\n",
|
||||
" \"LR+TPE\": glob.glob(\n",
|
||||
" \"output/search/finetune/LSTM_Attn_tra/K10_traHs16_traSrcLR_TPE_traLamb2.0_hs256_bs1024_do0.1_lr0.0001_seed*/pred.pkl\"\n",
|
||||
" ),\n",
|
||||
"}"
|
||||
]
|
||||
},
|
||||
@@ -456,14 +484,8 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"res = {\n",
|
||||
" name: backtest_multi(exps[name])\n",
|
||||
" for name in tqdm(exps)\n",
|
||||
"}\n",
|
||||
"report = pd.DataFrame({\n",
|
||||
" k: v[0]\n",
|
||||
" for k, v in res.items()\n",
|
||||
"}).T"
|
||||
"res = {name: backtest_multi(exps[name]) for name in tqdm(exps)}\n",
|
||||
"report = pd.DataFrame({k: v[0] for k, v in res.items()}).T"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -597,18 +619,22 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"a = pd.read_pickle('output/search/finetune/Transformer_tra/K3_traHs16_traSrcLR_TPE_traLamb0.0_head4_hs64_bs512_do0.1_lr0.0005_seed3000/pred.pkl')\n",
|
||||
"b = pd.read_pickle('output/search/finetune/Transformer_tra/K3_traHs16_traSrcLR_TPE_traLamb2.0_head4_hs64_bs512_do0.1_lr0.0005_seed3000/pred.pkl')\n",
|
||||
"a = pd.read_pickle(\n",
|
||||
" \"output/search/finetune/Transformer_tra/K3_traHs16_traSrcLR_TPE_traLamb0.0_head4_hs64_bs512_do0.1_lr0.0005_seed3000/pred.pkl\"\n",
|
||||
")\n",
|
||||
"b = pd.read_pickle(\n",
|
||||
" \"output/search/finetune/Transformer_tra/K3_traHs16_traSrcLR_TPE_traLamb2.0_head4_hs64_bs512_do0.1_lr0.0005_seed3000/pred.pkl\"\n",
|
||||
")\n",
|
||||
"a = a.iloc[:, -3:]\n",
|
||||
"b = b.iloc[:, -3:]\n",
|
||||
"b = np.eye(3)[b.values.argmax(axis=1)]\n",
|
||||
"a = np.eye(3)[a.values.argmax(axis=1)]\n",
|
||||
"\n",
|
||||
"res = pd.DataFrame({\n",
|
||||
" 'with OT': b.sum(axis=0) / b.sum(),\n",
|
||||
" 'without OT': a.sum(axis=0)/ a.sum() \n",
|
||||
"},index=[r'$\\theta_1$',r'$\\theta_2$',r'$\\theta_3$'])\n",
|
||||
"res.plot.bar(rot=30, figsize=(5, 4), color=['b', 'g'])\n",
|
||||
"res = pd.DataFrame(\n",
|
||||
" {\"with OT\": b.sum(axis=0) / b.sum(), \"without OT\": a.sum(axis=0) / a.sum()},\n",
|
||||
" index=[r\"$\\theta_1$\", r\"$\\theta_2$\", r\"$\\theta_3$\"],\n",
|
||||
")\n",
|
||||
"res.plot.bar(rot=30, figsize=(5, 4), color=[\"b\", \"g\"])\n",
|
||||
"del a, b"
|
||||
]
|
||||
},
|
||||
@@ -633,11 +659,19 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"exps = {\n",
|
||||
" 'K=1': glob.glob('output/search/LSTM_Attn/hs256_bs1024_do0.1_lr0.0002_seed*/info.json'),\n",
|
||||
" 'K=3': glob.glob('output/search/finetune/LSTM_Attn_tra/K3_traHs16_traSrcLR_TPE_traLamb2.0_hs256_bs1024_do0.1_lr0.0001_seed*/info.json'),\n",
|
||||
" 'K=5': glob.glob('output/search/finetune/LSTM_Attn_tra/K5_traHs16_traSrcLR_TPE_traLamb2.0_hs256_bs1024_do0.1_lr0.0001_seed*/info.json'),\n",
|
||||
" 'K=10': glob.glob('output/search/finetune/LSTM_Attn_tra/K10_traHs16_traSrcLR_TPE_traLamb2.0_hs256_bs1024_do0.1_lr0.0001_seed*/info.json'),\n",
|
||||
" 'K=20': glob.glob('output/search/finetune/LSTM_Attn_tra/K20_traHs16_traSrcLR_TPE_traLamb2.0_hs256_bs1024_do0.1_lr0.0001_seed*/info.json')\n",
|
||||
" \"K=1\": glob.glob(\"output/search/LSTM_Attn/hs256_bs1024_do0.1_lr0.0002_seed*/info.json\"),\n",
|
||||
" \"K=3\": glob.glob(\n",
|
||||
" \"output/search/finetune/LSTM_Attn_tra/K3_traHs16_traSrcLR_TPE_traLamb2.0_hs256_bs1024_do0.1_lr0.0001_seed*/info.json\"\n",
|
||||
" ),\n",
|
||||
" \"K=5\": glob.glob(\n",
|
||||
" \"output/search/finetune/LSTM_Attn_tra/K5_traHs16_traSrcLR_TPE_traLamb2.0_hs256_bs1024_do0.1_lr0.0001_seed*/info.json\"\n",
|
||||
" ),\n",
|
||||
" \"K=10\": glob.glob(\n",
|
||||
" \"output/search/finetune/LSTM_Attn_tra/K10_traHs16_traSrcLR_TPE_traLamb2.0_hs256_bs1024_do0.1_lr0.0001_seed*/info.json\"\n",
|
||||
" ),\n",
|
||||
" \"K=20\": glob.glob(\n",
|
||||
" \"output/search/finetune/LSTM_Attn_tra/K20_traHs16_traSrcLR_TPE_traLamb2.0_hs256_bs1024_do0.1_lr0.0001_seed*/info.json\"\n",
|
||||
" ),\n",
|
||||
"}"
|
||||
]
|
||||
},
|
||||
@@ -649,16 +683,11 @@
|
||||
"source": [
|
||||
"report = dict()\n",
|
||||
"for k, v in exps.items():\n",
|
||||
" \n",
|
||||
" tmp = dict()\n",
|
||||
" for fname in v:\n",
|
||||
" with open(fname) as f:\n",
|
||||
" info = json.load(f)\n",
|
||||
" tmp[fname] = (\n",
|
||||
" {\n",
|
||||
" \"IC\":info[\"metric\"][\"IC\"],\n",
|
||||
" \"MSE\":info[\"metric\"][\"MSE\"]\n",
|
||||
" })\n",
|
||||
" tmp[fname] = {\"IC\": info[\"metric\"][\"IC\"], \"MSE\": info[\"metric\"][\"MSE\"]}\n",
|
||||
" tmp = pd.DataFrame(tmp).T\n",
|
||||
" report[k] = tmp.mean()\n",
|
||||
"report = pd.DataFrame(report).T"
|
||||
@@ -681,13 +710,14 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"fig, axes = plt.subplots(1, 2, figsize=(6,3)); axes = axes.flatten()\n",
|
||||
"report['IC'].plot.bar(rot=30, ax=axes[0])\n",
|
||||
"fig, axes = plt.subplots(1, 2, figsize=(6, 3))\n",
|
||||
"axes = axes.flatten()\n",
|
||||
"report[\"IC\"].plot.bar(rot=30, ax=axes[0])\n",
|
||||
"axes[0].set_ylim(0.045, 0.062)\n",
|
||||
"axes[0].set_title('IC performance')\n",
|
||||
"report['MSE'].astype(float).plot.bar(rot=30, ax=axes[1], color='green')\n",
|
||||
"axes[0].set_title(\"IC performance\")\n",
|
||||
"report[\"MSE\"].astype(float).plot.bar(rot=30, ax=axes[1], color=\"green\")\n",
|
||||
"axes[1].set_ylim(0.155, 0.1585)\n",
|
||||
"axes[1].set_title('MSE performance')\n",
|
||||
"axes[1].set_title(\"MSE performance\")\n",
|
||||
"plt.tight_layout()\n",
|
||||
"# plt.savefig('sensitivity.pdf')"
|
||||
]
|
||||
|
||||
107
examples/benchmarks_dynamic/DDG-DA/vis_data.py
Normal file
107
examples/benchmarks_dynamic/DDG-DA/vis_data.py
Normal file
@@ -0,0 +1,107 @@
|
||||
import pickle
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
|
||||
sns.set(color_codes=True)
|
||||
plt.rcParams["font.sans-serif"] = "SimHei"
|
||||
plt.rcParams["axes.unicode_minus"] = False
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
# tqdm.pandas() # for progress_apply
|
||||
# %matplotlib inline
|
||||
# %load_ext autoreload
|
||||
|
||||
|
||||
# # Meta Input
|
||||
|
||||
# +
|
||||
with open("./internal_data_s20.pkl", "rb") as f:
|
||||
data = pickle.load(f)
|
||||
|
||||
data.data_ic_df.columns.names = ["start_date", "end_date"]
|
||||
|
||||
data_sim = data.data_ic_df.droplevel(axis=1, level="end_date")
|
||||
|
||||
data_sim.index.name = "test datetime"
|
||||
# -
|
||||
|
||||
plt.figure(figsize=(40, 20))
|
||||
sns.heatmap(data_sim)
|
||||
|
||||
plt.figure(figsize=(40, 20))
|
||||
sns.heatmap(data_sim.rolling(20).mean())
|
||||
|
||||
# # Meta Model
|
||||
|
||||
from qlib import auto_init
|
||||
|
||||
auto_init()
|
||||
from qlib.workflow import R
|
||||
|
||||
exp = R.get_exp(experiment_name="DDG-DA")
|
||||
meta_rec = exp.list_recorders(rtype="list", max_results=1)[0]
|
||||
meta_m = meta_rec.load_object("model")
|
||||
|
||||
pd.DataFrame(meta_m.tn.twm.linear.weight.detach().numpy()).T[0].plot()
|
||||
|
||||
pd.DataFrame(meta_m.tn.twm.linear.weight.detach().numpy()).T[0].rolling(5).mean().plot()
|
||||
|
||||
# # Meta Output
|
||||
|
||||
# +
|
||||
with open("./tasks_s20.pkl", "rb") as f:
|
||||
tasks = pickle.load(f)
|
||||
|
||||
task_df = {}
|
||||
for t in tasks:
|
||||
test_seg = t["dataset"]["kwargs"]["segments"]["test"]
|
||||
if None not in test_seg:
|
||||
# The last rolling is skipped.
|
||||
task_df[test_seg] = t["reweighter"].time_weight
|
||||
task_df = pd.concat(task_df)
|
||||
|
||||
task_df.index.names = ["OS_start", "OS_end", "IS_start", "IS_end"]
|
||||
task_df = task_df.droplevel(["OS_end", "IS_end"])
|
||||
task_df = task_df.unstack("OS_start")
|
||||
# -
|
||||
|
||||
plt.figure(figsize=(40, 20))
|
||||
sns.heatmap(task_df.T)
|
||||
|
||||
plt.figure(figsize=(40, 20))
|
||||
sns.heatmap(task_df.rolling(10).mean().T)
|
||||
|
||||
# # Sub Models
|
||||
#
|
||||
# NOTE:
|
||||
# - this section assumes that the model is Linear model!!
|
||||
# - Other models does not support this analysis
|
||||
|
||||
exp = R.get_exp(experiment_name="rolling_ds")
|
||||
|
||||
|
||||
def show_linear_weight(exp):
|
||||
coef_df = {}
|
||||
for r in exp.list_recorders("list"):
|
||||
t = r.load_object("task")
|
||||
if None in t["dataset"]["kwargs"]["segments"]["test"]:
|
||||
continue
|
||||
m = r.load_object("params.pkl")
|
||||
coef_df[t["dataset"]["kwargs"]["segments"]["test"]] = pd.Series(m.coef_)
|
||||
|
||||
coef_df = pd.concat(coef_df)
|
||||
|
||||
coef_df.index.names = ["test_start", "test_end", "coef_idx"]
|
||||
|
||||
coef_df = coef_df.droplevel("test_end").unstack("coef_idx").T
|
||||
|
||||
plt.figure(figsize=(40, 20))
|
||||
sns.heatmap(coef_df)
|
||||
plt.show()
|
||||
|
||||
|
||||
show_linear_weight(R.get_exp(experiment_name="rolling_ds"))
|
||||
|
||||
show_linear_weight(R.get_exp(experiment_name="rolling_models"))
|
||||
@@ -10,8 +10,10 @@ import pandas as pd
|
||||
import fire
|
||||
import sys
|
||||
import pickle
|
||||
from typing import Optional
|
||||
from qlib import auto_init
|
||||
from qlib.model.trainer import TrainerR
|
||||
from qlib.typehint import Literal
|
||||
from qlib.utils import init_instance_by_config
|
||||
from qlib.workflow import R
|
||||
from qlib.tests.data import GetData
|
||||
@@ -30,7 +32,33 @@ class DDGDA:
|
||||
- `rm -r mlruns`
|
||||
"""
|
||||
|
||||
def __init__(self, sim_task_model="linear", forecast_model="linear"):
|
||||
def __init__(
|
||||
self,
|
||||
sim_task_model: Literal["linear", "gbdt"] = "linear",
|
||||
forecast_model: Literal["linear", "gbdt"] = "linear",
|
||||
h_path: Optional[str] = None,
|
||||
test_end: Optional[str] = None,
|
||||
train_start: Optional[str] = None,
|
||||
meta_1st_train_end: Optional[str] = None,
|
||||
task_ext_conf: Optional[dict] = None,
|
||||
alpha: float = 0.0,
|
||||
proxy_hd: str = "handler_proxy.pkl",
|
||||
):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
||||
train_start: Optional[str]
|
||||
the start datetime for data. It is used in training start time (for both tasks & meta learing)
|
||||
test_end: Optional[str]
|
||||
the end datetime for data. It is used in test end time
|
||||
meta_1st_train_end: Optional[str]
|
||||
the datetime of training end of the first meta_task
|
||||
alpha: float
|
||||
Setting the L2 regularization for ridge
|
||||
The `alpha` is only passed to MetaModelDS (it is not passed to sim_task_model currently..)
|
||||
"""
|
||||
self.step = 20
|
||||
# NOTE:
|
||||
# the horizon must match the meaning in the base task template
|
||||
@@ -38,10 +66,19 @@ class DDGDA:
|
||||
self.meta_exp_name = "DDG-DA"
|
||||
self.sim_task_model = sim_task_model # The model to capture the distribution of data.
|
||||
self.forecast_model = forecast_model # downstream forecasting models' type
|
||||
self.rb_kwargs = {
|
||||
"h_path": h_path,
|
||||
"test_end": test_end,
|
||||
"train_start": train_start,
|
||||
"task_ext_conf": task_ext_conf,
|
||||
}
|
||||
self.alpha = alpha
|
||||
self.meta_1st_train_end = meta_1st_train_end
|
||||
self.proxy_hd = proxy_hd
|
||||
|
||||
def get_feature_importance(self):
|
||||
# this must be lightGBM, because it needs to get the feature importance
|
||||
rb = RollingBenchmark(model_type="gbdt")
|
||||
rb = RollingBenchmark(model_type="gbdt", **self.rb_kwargs)
|
||||
task = rb.basic_task()
|
||||
|
||||
with R.start(experiment_name="feature_importance"):
|
||||
@@ -69,7 +106,7 @@ class DDGDA:
|
||||
fi = self.get_feature_importance()
|
||||
col_selected = fi.nlargest(topk)
|
||||
|
||||
rb = RollingBenchmark(model_type=self.sim_task_model)
|
||||
rb = RollingBenchmark(model_type=self.sim_task_model, **self.rb_kwargs)
|
||||
task = rb.basic_task()
|
||||
dataset = init_instance_by_config(task["dataset"])
|
||||
prep_ds = dataset.prepare(slice(None), col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
|
||||
@@ -96,7 +133,7 @@ class DDGDA:
|
||||
"kwargs": {"config": DIRNAME / "fea_label_df.pkl"},
|
||||
}
|
||||
)
|
||||
handler.to_pickle(DIRNAME / "handler_proxy.pkl", dump_all=True)
|
||||
handler.to_pickle(DIRNAME / self.proxy_hd, dump_all=True)
|
||||
|
||||
@property
|
||||
def _internal_data_path(self):
|
||||
@@ -108,7 +145,7 @@ class DDGDA:
|
||||
This function will dump the input data for meta model
|
||||
"""
|
||||
# According to the experiments, the choice of the model type is very important for achieving good results
|
||||
rb = RollingBenchmark(model_type=self.sim_task_model)
|
||||
rb = RollingBenchmark(model_type=self.sim_task_model, **self.rb_kwargs)
|
||||
sim_task = rb.basic_task()
|
||||
|
||||
if self.sim_task_model == "gbdt":
|
||||
@@ -122,24 +159,27 @@ class DDGDA:
|
||||
with self._internal_data_path.open("wb") as f:
|
||||
pickle.dump(internal_data, f)
|
||||
|
||||
def train_meta_model(self):
|
||||
def train_meta_model(self, fill_method="max"):
|
||||
"""
|
||||
training a meta model based on a simplified linear proxy model;
|
||||
"""
|
||||
|
||||
# 1) leverage the simplified proxy forecasting model to train meta model.
|
||||
# - Only the dataset part is important, in current version of meta model will integrate the
|
||||
rb = RollingBenchmark(model_type=self.sim_task_model)
|
||||
rb = RollingBenchmark(model_type=self.sim_task_model, **self.rb_kwargs)
|
||||
sim_task = rb.basic_task()
|
||||
train_start = self.rb_kwargs.get("train_start", "2008-01-01")
|
||||
train_end = "2010-12-31" if self.meta_1st_train_end is None else self.meta_1st_train_end
|
||||
test_start = (pd.Timestamp(train_end) + pd.Timedelta(days=1)).strftime("%Y-%m-%d")
|
||||
proxy_forecast_model_task = {
|
||||
# "model": "qlib.contrib.model.linear.LinearModel",
|
||||
"dataset": {
|
||||
"class": "qlib.data.dataset.DatasetH",
|
||||
"kwargs": {
|
||||
"handler": f"file://{(DIRNAME / 'handler_proxy.pkl').absolute()}",
|
||||
"handler": f"file://{(DIRNAME / self.proxy_hd).absolute()}",
|
||||
"segments": {
|
||||
"train": ("2008-01-01", "2010-12-31"),
|
||||
"test": ("2011-01-01", sim_task["dataset"]["kwargs"]["segments"]["test"][1]),
|
||||
"train": (train_start, train_end),
|
||||
"test": (test_start, sim_task["dataset"]["kwargs"]["segments"]["test"][1]),
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -156,7 +196,7 @@ class DDGDA:
|
||||
segments=0.62, # keep test period consistent with the dataset yaml
|
||||
trunc_days=1 + self.horizon,
|
||||
hist_step_n=30,
|
||||
fill_method="max",
|
||||
fill_method=fill_method,
|
||||
rolling_ext_days=0,
|
||||
)
|
||||
# NOTE:
|
||||
@@ -165,12 +205,15 @@ class DDGDA:
|
||||
# So the misalignment will not affect the effectiveness of the method.
|
||||
with self._internal_data_path.open("rb") as f:
|
||||
internal_data = pickle.load(f)
|
||||
|
||||
md = MetaDatasetDS(exp_name=internal_data, **kwargs)
|
||||
|
||||
# 3) train and logging meta model
|
||||
with R.start(experiment_name=self.meta_exp_name):
|
||||
R.log_params(**kwargs)
|
||||
mm = MetaModelDS(step=self.step, hist_step_n=kwargs["hist_step_n"], lr=0.001, max_epoch=200, seed=43)
|
||||
mm = MetaModelDS(
|
||||
step=self.step, hist_step_n=kwargs["hist_step_n"], lr=0.001, max_epoch=100, seed=43, alpha=self.alpha
|
||||
)
|
||||
mm.fit(md)
|
||||
R.save_objects(model=mm)
|
||||
|
||||
@@ -203,7 +246,7 @@ class DDGDA:
|
||||
hist_step_n = int(param["hist_step_n"])
|
||||
fill_method = param.get("fill_method", "max")
|
||||
|
||||
rb = RollingBenchmark(model_type=self.forecast_model)
|
||||
rb = RollingBenchmark(model_type=self.forecast_model, **self.rb_kwargs)
|
||||
task_l = rb.create_rolling_tasks()
|
||||
|
||||
# 2.2) create meta dataset for final dataset
|
||||
@@ -233,13 +276,13 @@ class DDGDA:
|
||||
"""
|
||||
with self._task_path.open("rb") as f:
|
||||
tasks = pickle.load(f)
|
||||
rb = RollingBenchmark(rolling_exp="rolling_ds", model_type=self.forecast_model)
|
||||
rb = RollingBenchmark(rolling_exp="rolling_ds", model_type=self.forecast_model, **self.rb_kwargs)
|
||||
rb.train_rolling_tasks(tasks)
|
||||
rb.ens_rolling()
|
||||
rb.update_rolling_rec()
|
||||
|
||||
def run_all(self):
|
||||
# 1) file: handler_proxy.pkl
|
||||
# 1) file: handler_proxy.pkl (self.proxy_hd)
|
||||
self.dump_data_for_proxy_model()
|
||||
# 2)
|
||||
# file: internal_data_s20.pkl
|
||||
|
||||
@@ -4,15 +4,21 @@ So adapting the forecasting models/strategies to market dynamics is very importa
|
||||
|
||||
The table below shows the performances of different solutions on different forecasting models.
|
||||
|
||||
## Alpha158 dataset
|
||||
## Alpha158 Dataset
|
||||
Here is the [crowd sourced version of qlib data](data_collector/crowd_source/README.md): https://github.com/chenditc/investment_data/releases
|
||||
```bash
|
||||
wget https://github.com/chenditc/investment_data/releases/download/20220720/qlib_bin.tar.gz
|
||||
tar -zxvf qlib_bin.tar.gz -C ~/.qlib/qlib_data/cn_data --strip-components=2
|
||||
```
|
||||
|
||||
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |
|
||||
|------------------|---------|----|------|---------|-----------|-------------------|-------------------|--------------|
|
||||
| RR[Linear] |Alpha158 |0.088|0.570|0.102 |0.622 |0.077 |1.175 |-0.086 |
|
||||
| DDG-DA[Linear] |Alpha158 |0.093|0.622|0.106 |0.670 |0.085 |1.213 |-0.093 |
|
||||
| RR[LightGBM] |Alpha158 |0.079|0.566|0.088 |0.592 |0.075 |1.226 |-0.096 |
|
||||
| DDG-DA[LightGBM] |Alpha158 |0.084|0.639|0.093 |0.664 |0.099 |1.442 |-0.071 |
|
||||
| RR[Linear] |Alpha158 |0.089|0.577|0.102 |0.627 |0.093 |1.458 |-0.073 |
|
||||
| DDG-DA[Linear] |Alpha158 |0.096|0.636|0.107 |0.677 |0.067 |0.996 |-0.091 |
|
||||
| RR[LightGBM] |Alpha158 |0.082|0.589|0.091 |0.626 |0.077 |1.320 |-0.091 |
|
||||
| DDG-DA[LightGBM] |Alpha158 |0.085|0.658|0.094 |0.686 |0.115 |1.792 |-0.068 |
|
||||
|
||||
- The label horizon of the `Alpha158` dataset is set to 20.
|
||||
- The rolling time intervals are set to 20 trading days.
|
||||
- The test rolling periods are from January 2017 to August 2020.
|
||||
- The results are based on the crowd-sourced version. The Yahoo version of qlib data does not contain `VWAP`, so all related factors are missing and filled with 0, which leads to a rank-deficient matrix (a matrix does not have full rank) and makes lower-level optimization of DDG-DA can not be solved.
|
||||
|
||||
@@ -1,13 +1,17 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
from typing import Optional
|
||||
from qlib.model.ens.ensemble import RollingEnsemble
|
||||
from qlib.utils import init_instance_by_config
|
||||
import fire
|
||||
import yaml
|
||||
import pandas as pd
|
||||
from qlib import auto_init
|
||||
from pathlib import Path
|
||||
from tqdm.auto import tqdm
|
||||
from qlib.model.trainer import TrainerR
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.utils.data import update_config
|
||||
from qlib.workflow import R
|
||||
from qlib.tests.data import GetData
|
||||
|
||||
@@ -25,11 +29,40 @@ class RollingBenchmark:
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, rolling_exp="rolling_models", model_type="linear") -> None:
|
||||
def __init__(
|
||||
self,
|
||||
rolling_exp: str = "rolling_models",
|
||||
model_type: str = "linear",
|
||||
h_path: Optional[str] = None,
|
||||
train_start: Optional[str] = None,
|
||||
test_end: Optional[str] = None,
|
||||
task_ext_conf: Optional[dict] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
rolling_exp : str
|
||||
The name for the experiments for rolling
|
||||
model_type : str
|
||||
The model to be boosted.
|
||||
h_path : Optional[str]
|
||||
the dumped data handler;
|
||||
test_end : Optional[str]
|
||||
the test end for the data. It is typically used together with the handler
|
||||
train_start : Optional[str]
|
||||
the train start for the data. It is typically used together with the handler.
|
||||
task_ext_conf : Optional[dict]
|
||||
some option to update the
|
||||
"""
|
||||
self.step = 20
|
||||
self.horizon = 20
|
||||
self.rolling_exp = rolling_exp
|
||||
self.model_type = model_type
|
||||
self.h_path = h_path
|
||||
self.train_start = train_start
|
||||
self.test_end = test_end
|
||||
self.logger = get_module_logger("RollingBenchmark")
|
||||
self.task_ext_conf = task_ext_conf
|
||||
|
||||
def basic_task(self):
|
||||
"""For fast training rolling"""
|
||||
@@ -42,6 +75,10 @@ class RollingBenchmark:
|
||||
h_path = DIRNAME / "linear_alpha158_handler_horizon{}.pkl".format(self.horizon)
|
||||
else:
|
||||
raise AssertionError("Model type is not supported!")
|
||||
|
||||
if self.h_path is not None:
|
||||
h_path = Path(self.h_path)
|
||||
|
||||
with conf_path.open("r") as f:
|
||||
conf = yaml.safe_load(f)
|
||||
|
||||
@@ -52,6 +89,9 @@ class RollingBenchmark:
|
||||
|
||||
task = conf["task"]
|
||||
|
||||
if self.task_ext_conf is not None:
|
||||
task = update_config(task, self.task_ext_conf)
|
||||
|
||||
if not h_path.exists():
|
||||
h_conf = task["dataset"]["kwargs"]["handler"]
|
||||
h = init_instance_by_config(h_conf)
|
||||
@@ -59,6 +99,15 @@ class RollingBenchmark:
|
||||
|
||||
task["dataset"]["kwargs"]["handler"] = f"file://{h_path}"
|
||||
task["record"] = ["qlib.workflow.record_temp.SignalRecord"]
|
||||
|
||||
if self.train_start is not None:
|
||||
seg = task["dataset"]["kwargs"]["segments"]["train"]
|
||||
task["dataset"]["kwargs"]["segments"]["train"] = pd.Timestamp(self.train_start), seg[1]
|
||||
|
||||
if self.test_end is not None:
|
||||
seg = task["dataset"]["kwargs"]["segments"]["test"]
|
||||
task["dataset"]["kwargs"]["segments"]["test"] = seg[0], pd.Timestamp(self.test_end)
|
||||
self.logger.info(task)
|
||||
return task
|
||||
|
||||
def create_rolling_tasks(self):
|
||||
@@ -93,7 +142,7 @@ class RollingBenchmark:
|
||||
"""
|
||||
Evaluate the combined rolling results
|
||||
"""
|
||||
for rid, rec in R.list_recorders(experiment_name=self.COMB_EXP).items():
|
||||
for _, rec in R.list_recorders(experiment_name=self.COMB_EXP).items():
|
||||
for rt_cls in SigAnaRecord, PortAnaRecord:
|
||||
rt = rt_cls(recorder=rec, skip_existing=True)
|
||||
rt.generate()
|
||||
|
||||
357
examples/rl/simple_example.ipynb
Normal file
357
examples/rl/simple_example.ipynb
Normal file
File diff suppressed because one or more lines are too long
100
examples/rl_order_execution/README.md
Normal file
100
examples/rl_order_execution/README.md
Normal file
@@ -0,0 +1,100 @@
|
||||
# RL Example for Order Execution
|
||||
|
||||
This folder comprises an example of Reinforcement Learning (RL) workflows for order execution scenario, including both training workflows and backtest workflows.
|
||||
|
||||
## Data Processing
|
||||
|
||||
### Get Data
|
||||
|
||||
```
|
||||
python -m qlib.run.get_data qlib_data qlib_data --target_dir ./data/bin --region hs300 --interval 5min
|
||||
```
|
||||
|
||||
### Generate Pickle-Style Data
|
||||
|
||||
To run codes in this example, we need data in pickle format. To achieve this, run following commands (might need a few minutes to finish):
|
||||
|
||||
[//]: # (TODO: Instead of dumping dataframe with different format (like `_gen_dataset` and `_gen_day_dataset` in `qlib/contrib/data/highfreq_provider.py`), we encourage to implement different subclass of `Dataset` and `DataHandler`. This will keep the workflow cleaner and interfaces more consistent, and move all the complexity to the subclass.)
|
||||
|
||||
```
|
||||
python scripts/gen_pickle_data.py -c scripts/pickle_data_config.yml
|
||||
python scripts/gen_training_orders.py
|
||||
python scripts/merge_orders.py
|
||||
```
|
||||
|
||||
When finished, the structure under `data/` should be:
|
||||
|
||||
```
|
||||
data
|
||||
├── bin
|
||||
├── orders
|
||||
└── pickle
|
||||
```
|
||||
|
||||
## Training
|
||||
|
||||
Each training task is specified by a config file. The config file for task `TASKNAME` is `exp_configs/train_TASKNAME.yml`. This example provides two training tasks:
|
||||
|
||||
- **PPO**: Method proposed by IJCAL 2020 paper "[An End-to-End Optimal Trade Execution Framework based on Proximal Policy Optimization](https://www.ijcai.org/proceedings/2020/0627.pdf)".
|
||||
- **OPDS**: Method proposed by AAAI 2021 paper "[Universal Trading for Order Execution with Oracle Policy Distillation](https://arxiv.org/abs/2103.10860)".
|
||||
|
||||
The main differece between these two methods is their reward functions. Please see their config files for details.
|
||||
|
||||
Take OPDS as an example, to run the training workflow, run:
|
||||
|
||||
```
|
||||
python -m qlib.rl.contrib.train_onpolicy --config_path exp_configs/train_opds.yml --run_backtest
|
||||
```
|
||||
|
||||
Metrics, logs, and checkpoints will be stored under `outputs/opds` (configured by `exp_configs/train_opds.yml`).
|
||||
|
||||
## Backtest
|
||||
|
||||
Once the training workflow has completed, the trained model can be used for the backtesting workflow. Still taking OPDS as an example, once training is finished, the latest checkpoint of the model can be found at `outputs/opds/checkpoints/latest.pth`. To run backtest workflow:
|
||||
|
||||
1. Uncomment the `weight_file` parameter in `exp_configs/train_opds.yml` (it is commented by default). While it is possible to run the backtesting workflow without setting a checkpoint, this will lead to randomly initialized model results, thus making them meaningless.
|
||||
2. Run `python -m qlib.rl.contrib.backtest --config_path exp_configs/backtest_opds.yml`.
|
||||
|
||||
The backtest result is stored in `outputs/checkpoints/backtest_result.csv`.
|
||||
|
||||
In addition to OPDS and PPO, we also provide TWAP ([Time-weighted average price](https://en.wikipedia.org/wiki/Time-weighted_average_price)) as a weak baseline. The config file for TWAP is `exp_configs/backtest_twap.yml`.
|
||||
|
||||
### Gap between backtest and training pipeline's testing
|
||||
|
||||
It is worthy to notice that the results of the backtesting process may differ from the results of the testing process used during training.
|
||||
This is because different simulators are used to simulate market conditions during training and backtesting.
|
||||
In training pipeline, the simplified simulator called `SingleAssetOrderExecutionSimple` is used for efficiency reasons.
|
||||
`SingleAssetOrderExecutionSimple` makes no restriction to trading amounts.
|
||||
No matter what the amount of the order is, it can be completely executed.
|
||||
However, during backtesting, a more realistic simulator called `SingleAssetOrderExecution` is used.
|
||||
It takes into account practical constraints in more real-world scenarios (for example, the trading volume must be a multiple of the smallest trading unit).
|
||||
As a result, the amount of an order that is actually executed during backtesting may differ from the amount expected to be executed.
|
||||
|
||||
If you would like to obtain results that are exactly the same as those obtained during testing in the training pipeline, you could run training pipeline with only backtest phrase.
|
||||
In order to do this:
|
||||
- Modify the training config. Add the path of the checkpoint you want to use (see following for an example).
|
||||
- Run `python -m qlib.rl.contrib.train_onpolicy --config_path PATH/TO/CONFIG --run_backtest --no_training`
|
||||
|
||||
```yaml
|
||||
...
|
||||
policy:
|
||||
class: PPO # PPO, DQN
|
||||
kwargs:
|
||||
lr: 0.0001
|
||||
weight_file: PATH/TO/CHECKPOINT
|
||||
module_path: qlib.rl.order_execution.policy
|
||||
...
|
||||
```
|
||||
|
||||
## Benchmarks (TBD)
|
||||
|
||||
To accurately evaluate the performance of models using Reinforcement Learning algorithms, it's best to run experiments multiple times and compute the average performance across all trials. However, given the time-consuming nature of model training, this is not always feasible. An alternative approach is to run each training task only once, selecting the 10 checkpoints with the highest validation performance to simulate multiple trials. In this example, we use "Price Advantage (PA)" as the metric for selecting these checkpoints. The average performance of these 10 checkpoints on the testing set is as follows:
|
||||
|
||||
| **Model** | **PA mean with std.** |
|
||||
|-----------------------------|-----------------------|
|
||||
| OPDS (with PPO policy) | 0.4785 ± 0.7815 |
|
||||
| OPDS (with DQN policy) | -0.0114 ± 0.5780 |
|
||||
| PPO | -1.0935 ± 0.0922 |
|
||||
| TWAP | ≈ 0.0 ± 0.0 |
|
||||
|
||||
The table above also includes TWAP as a rule-based baseline. The ideal PA of TWAP should be 0.0, however, in this example, the order execution is divided into two steps: first, the order is split equally among each half hour, and then each five minutes within each half hour. Since trading is forbidden during the last five minutes of the day, this approach may slightly differ from traditional TWAP over the course of a full day (as there are 5 minutes missing in the last "half hour"). Therefore, the PA of TWAP can be considered as a number that is close to 0.0. To verify this, you may run a TWAP backtest and check the results.
|
||||
53
examples/rl_order_execution/exp_configs/backtest_opds.yml
Executable file
53
examples/rl_order_execution/exp_configs/backtest_opds.yml
Executable file
@@ -0,0 +1,53 @@
|
||||
order_file: ./data/orders/test_orders.pkl
|
||||
start_time: "9:30"
|
||||
end_time: "14:54"
|
||||
data_granularity: "5min"
|
||||
qlib:
|
||||
provider_uri_5min: ./data/bin/
|
||||
exchange:
|
||||
limit_threshold: null
|
||||
deal_price: ["$close", "$close"]
|
||||
volume_threshold: null
|
||||
strategies:
|
||||
1day:
|
||||
class: SAOEIntStrategy
|
||||
kwargs:
|
||||
data_granularity: 5
|
||||
action_interpreter:
|
||||
class: CategoricalActionInterpreter
|
||||
kwargs:
|
||||
max_step: 8
|
||||
values: 4
|
||||
module_path: qlib.rl.order_execution.interpreter
|
||||
network:
|
||||
class: Recurrent
|
||||
kwargs: {}
|
||||
module_path: qlib.rl.order_execution.network
|
||||
policy:
|
||||
class: PPO # PPO, DQN
|
||||
kwargs:
|
||||
lr: 0.0001
|
||||
# Restore `weight_file` once the training workflow finishes. You can change the checkpoint file you want to use.
|
||||
# weight_file: outputs/opds/checkpoints/latest.pth
|
||||
module_path: qlib.rl.order_execution.policy
|
||||
state_interpreter:
|
||||
class: FullHistoryStateInterpreter
|
||||
kwargs:
|
||||
data_dim: 5
|
||||
data_ticks: 48
|
||||
max_step: 8
|
||||
processed_data_provider:
|
||||
class: HandlerProcessedDataProvider
|
||||
kwargs:
|
||||
data_dir: ./data/pickle/
|
||||
feature_columns_today: ["$high", "$low", "$open", "$close", "$volume"]
|
||||
feature_columns_yesterday: ["$high_1", "$low_1", "$open_1", "$close_1", "$volume_1"]
|
||||
module_path: qlib.rl.data.native
|
||||
module_path: qlib.rl.order_execution.interpreter
|
||||
module_path: qlib.rl.order_execution.strategy
|
||||
30min:
|
||||
class: TWAPStrategy
|
||||
kwargs: {}
|
||||
module_path: qlib.contrib.strategy.rule_strategy
|
||||
concurrency: 16
|
||||
output_dir: outputs/opds/
|
||||
53
examples/rl_order_execution/exp_configs/backtest_ppo.yml
Executable file
53
examples/rl_order_execution/exp_configs/backtest_ppo.yml
Executable file
@@ -0,0 +1,53 @@
|
||||
order_file: ./data/orders/test_orders.pkl
|
||||
start_time: "9:30"
|
||||
end_time: "14:54"
|
||||
data_granularity: "5min"
|
||||
qlib:
|
||||
provider_uri_5min: ./data/bin/
|
||||
exchange:
|
||||
limit_threshold: null
|
||||
deal_price: ["$close", "$close"]
|
||||
volume_threshold: null
|
||||
strategies:
|
||||
1day:
|
||||
class: SAOEIntStrategy
|
||||
kwargs:
|
||||
data_granularity: 5
|
||||
action_interpreter:
|
||||
class: CategoricalActionInterpreter
|
||||
kwargs:
|
||||
max_step: 8
|
||||
values: 4
|
||||
module_path: qlib.rl.order_execution.interpreter
|
||||
network:
|
||||
class: Recurrent
|
||||
kwargs: {}
|
||||
module_path: qlib.rl.order_execution.network
|
||||
policy:
|
||||
class: PPO # PPO, DQN
|
||||
kwargs:
|
||||
lr: 0.0001
|
||||
# Restore `weight_file` once the training workflow finishes. You can change the checkpoint file you want to use.
|
||||
# weight_file: outputs/ppo/checkpoints/latest.pth
|
||||
module_path: qlib.rl.order_execution.policy
|
||||
state_interpreter:
|
||||
class: FullHistoryStateInterpreter
|
||||
kwargs:
|
||||
data_dim: 5
|
||||
data_ticks: 48
|
||||
max_step: 8
|
||||
processed_data_provider:
|
||||
class: HandlerProcessedDataProvider
|
||||
kwargs:
|
||||
data_dir: ./data/pickle/
|
||||
feature_columns_today: ["$high", "$low", "$open", "$close", "$volume"]
|
||||
feature_columns_yesterday: ["$high_1", "$low_1", "$open_1", "$close_1", "$volume_1"]
|
||||
module_path: qlib.rl.data.native
|
||||
module_path: qlib.rl.order_execution.interpreter
|
||||
module_path: qlib.rl.order_execution.strategy
|
||||
30min:
|
||||
class: TWAPStrategy
|
||||
kwargs: {}
|
||||
module_path: qlib.contrib.strategy.rule_strategy
|
||||
concurrency: 16
|
||||
output_dir: outputs/ppo/
|
||||
21
examples/rl_order_execution/exp_configs/backtest_twap.yml
Executable file
21
examples/rl_order_execution/exp_configs/backtest_twap.yml
Executable file
@@ -0,0 +1,21 @@
|
||||
order_file: ./data/orders/test_orders.pkl
|
||||
start_time: "9:30"
|
||||
end_time: "14:54"
|
||||
data_granularity: "5min"
|
||||
qlib:
|
||||
provider_uri_5min: ./data/bin/
|
||||
exchange:
|
||||
limit_threshold: null
|
||||
deal_price: ["$close", "$close"]
|
||||
volume_threshold: null
|
||||
strategies:
|
||||
1day:
|
||||
class: TWAPStrategy
|
||||
kwargs: {}
|
||||
module_path: qlib.contrib.strategy.rule_strategy
|
||||
30min:
|
||||
class: TWAPStrategy
|
||||
kwargs: {}
|
||||
module_path: qlib.contrib.strategy.rule_strategy
|
||||
concurrency: 16
|
||||
output_dir: outputs/twap/
|
||||
66
examples/rl_order_execution/exp_configs/train_opds.yml
Executable file
66
examples/rl_order_execution/exp_configs/train_opds.yml
Executable file
@@ -0,0 +1,66 @@
|
||||
simulator:
|
||||
data_granularity: 5
|
||||
time_per_step: 30
|
||||
vol_limit: null
|
||||
env:
|
||||
concurrency: 32
|
||||
parallel_mode: dummy
|
||||
action_interpreter:
|
||||
class: CategoricalActionInterpreter
|
||||
kwargs:
|
||||
values: 4
|
||||
max_step: 8
|
||||
module_path: qlib.rl.order_execution.interpreter
|
||||
state_interpreter:
|
||||
class: FullHistoryStateInterpreter
|
||||
kwargs:
|
||||
data_dim: 5
|
||||
data_ticks: 48 # 48 = 240 min / 5 min
|
||||
max_step: 8
|
||||
processed_data_provider:
|
||||
class: HandlerProcessedDataProvider
|
||||
kwargs:
|
||||
data_dir: ./data/pickle/
|
||||
feature_columns_today: ["$high", "$low", "$open", "$close", "$volume"]
|
||||
feature_columns_yesterday: ["$high_1", "$low_1", "$open_1", "$close_1", "$volume_1"]
|
||||
backtest: false
|
||||
module_path: qlib.rl.data.native
|
||||
module_path: qlib.rl.order_execution.interpreter
|
||||
reward:
|
||||
class: PAPenaltyReward
|
||||
kwargs:
|
||||
penalty: 4.0
|
||||
scale: 0.01
|
||||
module_path: qlib.rl.order_execution.reward
|
||||
data:
|
||||
source:
|
||||
order_dir: ./data/orders
|
||||
feature_root_dir: ./data/pickle/
|
||||
feature_columns_today: ["$close0", "$volume0"]
|
||||
feature_columns_yesterday: []
|
||||
total_time: 240
|
||||
default_start_time_index: 0
|
||||
default_end_time_index: 235
|
||||
proc_data_dim: 5
|
||||
num_workers: 0
|
||||
queue_size: 20
|
||||
network:
|
||||
class: Recurrent
|
||||
module_path: qlib.rl.order_execution.network
|
||||
policy:
|
||||
class: PPO # PPO, DQN
|
||||
kwargs:
|
||||
lr: 0.0001
|
||||
module_path: qlib.rl.order_execution.policy
|
||||
runtime:
|
||||
seed: 42
|
||||
use_cuda: false
|
||||
trainer:
|
||||
max_epoch: 500
|
||||
repeat_per_collect: 25
|
||||
earlystop_patience: 50
|
||||
episode_per_collect: 10000
|
||||
batch_size: 1024
|
||||
val_every_n_epoch: 4
|
||||
checkpoint_path: ./outputs/opds
|
||||
checkpoint_every_n_iters: 1
|
||||
67
examples/rl_order_execution/exp_configs/train_ppo.yml
Executable file
67
examples/rl_order_execution/exp_configs/train_ppo.yml
Executable file
@@ -0,0 +1,67 @@
|
||||
simulator:
|
||||
data_granularity: 5
|
||||
time_per_step: 30
|
||||
vol_limit: null
|
||||
env:
|
||||
concurrency: 32
|
||||
parallel_mode: dummy
|
||||
action_interpreter:
|
||||
class: CategoricalActionInterpreter
|
||||
kwargs:
|
||||
values: 4
|
||||
max_step: 8
|
||||
module_path: qlib.rl.order_execution.interpreter
|
||||
state_interpreter:
|
||||
class: FullHistoryStateInterpreter
|
||||
kwargs:
|
||||
data_dim: 5
|
||||
data_ticks: 48 # 48 = 240 min / 5 min
|
||||
max_step: 8
|
||||
processed_data_provider:
|
||||
class: HandlerProcessedDataProvider
|
||||
kwargs:
|
||||
data_dir: ./data/pickle/
|
||||
feature_columns_today: ["$high", "$low", "$open", "$close", "$volume"]
|
||||
feature_columns_yesterday: ["$high_1", "$low_1", "$open_1", "$close_1", "$volume_1"]
|
||||
backtest: false
|
||||
module_path: qlib.rl.data.native
|
||||
module_path: qlib.rl.order_execution.interpreter
|
||||
reward:
|
||||
class: PPOReward
|
||||
kwargs:
|
||||
max_step: 8
|
||||
start_time_index: 0
|
||||
end_time_index: 46 # 46 = (240 - 5) min / 5 min - 1
|
||||
module_path: qlib.rl.order_execution.reward
|
||||
data:
|
||||
source:
|
||||
order_dir: ./data/orders
|
||||
feature_root_dir: ./data/pickle/
|
||||
feature_columns_today: ["$close0", "$volume0"]
|
||||
feature_columns_yesterday: []
|
||||
total_time: 240
|
||||
default_start_time_index: 0
|
||||
default_end_time_index: 235
|
||||
proc_data_dim: 5
|
||||
num_workers: 0
|
||||
queue_size: 20
|
||||
network:
|
||||
class: Recurrent
|
||||
module_path: qlib.rl.order_execution.network
|
||||
policy:
|
||||
class: PPO # PPO, DQN
|
||||
kwargs:
|
||||
lr: 0.0001
|
||||
module_path: qlib.rl.order_execution.policy
|
||||
runtime:
|
||||
seed: 42
|
||||
use_cuda: false
|
||||
trainer:
|
||||
max_epoch: 500
|
||||
repeat_per_collect: 25
|
||||
earlystop_patience: 50
|
||||
episode_per_collect: 10000
|
||||
batch_size: 1024
|
||||
val_every_n_epoch: 4
|
||||
checkpoint_path: ./outputs/ppo
|
||||
checkpoint_every_n_iters: 1
|
||||
46
examples/rl_order_execution/scripts/gen_pickle_data.py
Executable file
46
examples/rl_order_execution/scripts/gen_pickle_data.py
Executable file
@@ -0,0 +1,46 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import yaml
|
||||
import argparse
|
||||
import os
|
||||
import shutil
|
||||
from copy import deepcopy
|
||||
|
||||
from qlib.contrib.data.highfreq_provider import HighFreqProvider
|
||||
|
||||
loader = yaml.FullLoader
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-c", "--config", type=str, default="config.yml")
|
||||
parser.add_argument("-d", "--dest", type=str, default=".")
|
||||
parser.add_argument("-s", "--split", type=str, choices=["none", "date", "stock", "both"], default="stock")
|
||||
args = parser.parse_args()
|
||||
|
||||
conf = yaml.load(open(args.config), Loader=loader)
|
||||
|
||||
for k, v in conf.items():
|
||||
if isinstance(v, dict) and "path" in v:
|
||||
v["path"] = os.path.join(args.dest, v["path"])
|
||||
provider = HighFreqProvider(**conf)
|
||||
|
||||
# Gen dataframe
|
||||
if "feature_conf" in conf:
|
||||
feature = provider._gen_dataframe(deepcopy(provider.feature_conf))
|
||||
if "backtest_conf" in conf:
|
||||
backtest = provider._gen_dataframe(deepcopy(provider.backtest_conf))
|
||||
|
||||
provider.feature_conf["path"] = os.path.splitext(provider.feature_conf["path"])[0] + "/"
|
||||
provider.backtest_conf["path"] = os.path.splitext(provider.backtest_conf["path"])[0] + "/"
|
||||
# Split by date
|
||||
if args.split == "date" or args.split == "both":
|
||||
provider._gen_day_dataset(deepcopy(provider.feature_conf), "feature")
|
||||
provider._gen_day_dataset(deepcopy(provider.backtest_conf), "backtest")
|
||||
|
||||
# Split by stock
|
||||
if args.split == "stock" or args.split == "both":
|
||||
provider._gen_stock_dataset(deepcopy(provider.feature_conf), "feature")
|
||||
provider._gen_stock_dataset(deepcopy(provider.backtest_conf), "backtest")
|
||||
|
||||
shutil.rmtree("stat/", ignore_errors=True)
|
||||
53
examples/rl_order_execution/scripts/gen_training_orders.py
Executable file
53
examples/rl_order_execution/scripts/gen_training_orders.py
Executable file
@@ -0,0 +1,53 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
DATA_PATH = Path(os.path.join("data", "pickle", "backtest"))
|
||||
OUTPUT_PATH = Path(os.path.join("data", "orders"))
|
||||
|
||||
|
||||
def generate_order(stock: str, start_idx: int, end_idx: int) -> bool:
|
||||
dataset = pd.read_pickle(DATA_PATH / f"{stock}.pkl")
|
||||
df = dataset.handler.fetch(level=None).reset_index()
|
||||
if len(df) == 0 or df.isnull().values.any() or min(df["$volume0"]) < 1e-5:
|
||||
return False
|
||||
|
||||
df["date"] = df["datetime"].dt.date.astype("datetime64")
|
||||
df = df.set_index(["instrument", "datetime", "date"])
|
||||
df = df.groupby("date").take(range(start_idx, end_idx)).droplevel(level=0)
|
||||
|
||||
order_all = pd.DataFrame(df.groupby(level=(2, 0)).mean().dropna())
|
||||
order_all["amount"] = np.random.lognormal(-3.28, 1.14) * order_all["$volume0"]
|
||||
order_all = order_all[order_all["amount"] > 0.0]
|
||||
order_all["order_type"] = 0
|
||||
order_all = order_all.drop(columns=["$volume0"])
|
||||
|
||||
order_train = order_all[order_all.index.get_level_values(0) <= pd.Timestamp("2021-06-30")]
|
||||
order_test = order_all[order_all.index.get_level_values(0) > pd.Timestamp("2021-06-30")]
|
||||
order_valid = order_test[order_test.index.get_level_values(0) <= pd.Timestamp("2021-09-30")]
|
||||
order_test = order_test[order_test.index.get_level_values(0) > pd.Timestamp("2021-09-30")]
|
||||
|
||||
for order, tag in zip((order_train, order_valid, order_test, order_all), ("train", "valid", "test", "all")):
|
||||
path = OUTPUT_PATH / tag
|
||||
os.makedirs(path, exist_ok=True)
|
||||
if len(order) > 0:
|
||||
order.to_pickle(path / f"{stock}.pkl.target")
|
||||
return True
|
||||
|
||||
|
||||
np.random.seed(1234)
|
||||
file_list = sorted(os.listdir(DATA_PATH))
|
||||
stocks = [f.replace(".pkl", "") for f in file_list]
|
||||
np.random.shuffle(stocks)
|
||||
|
||||
cnt = 0
|
||||
for stock in stocks:
|
||||
if generate_order(stock, 0, 240 // 5 - 1):
|
||||
cnt += 1
|
||||
if cnt == 100:
|
||||
break
|
||||
15
examples/rl_order_execution/scripts/merge_orders.py
Executable file
15
examples/rl_order_execution/scripts/merge_orders.py
Executable file
@@ -0,0 +1,15 @@
|
||||
import pickle
|
||||
import os
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
|
||||
for tag in ["test", "valid"]:
|
||||
files = os.listdir(os.path.join("data/orders/", tag))
|
||||
dfs = []
|
||||
for f in tqdm(files):
|
||||
df = pickle.load(open(os.path.join("data/orders/", tag, f), "rb"))
|
||||
df = df.drop(["$close0"], axis=1)
|
||||
dfs.append(df)
|
||||
|
||||
total_df = pd.concat(dfs)
|
||||
pickle.dump(total_df, open(os.path.join("data", "orders", f"{tag}_orders.pkl"), "wb"))
|
||||
77
examples/rl_order_execution/scripts/pickle_data_config.yml
Executable file
77
examples/rl_order_execution/scripts/pickle_data_config.yml
Executable file
@@ -0,0 +1,77 @@
|
||||
# start & end time for training/validation/test datasets
|
||||
start_time: !!str &start 2020-01-01
|
||||
end_time: !!str &end 2021-12-31
|
||||
train_end_time: !!str &tend 2021-06-30
|
||||
valid_start_time: !!str &vstart 2021-07-01
|
||||
valid_end_time: !!str &vend 2021-09-30
|
||||
test_start_time: !!str &tstart 2021-10-01
|
||||
# the instrument set
|
||||
instruments: &ins csi300s19_22
|
||||
# qlib related configuration
|
||||
qlib_conf:
|
||||
provider_uri:
|
||||
5min: ./data/bin # path to generated qlib bin
|
||||
redis_port: 233
|
||||
feature_conf:
|
||||
path: ./data/pickle/feature.pkl # output path of feature
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: HighFreqGeneralHandler
|
||||
module_path: qlib.contrib.data.highfreq_handler
|
||||
kwargs:
|
||||
start_time: *start
|
||||
end_time: *end
|
||||
fit_start_time: *start
|
||||
fit_end_time: *tend
|
||||
instruments: *ins
|
||||
day_length: 240 # how many minutes in one trading day
|
||||
freq: 5min
|
||||
columns: ["$open", "$high", "$low", "$close"]
|
||||
infer_processors:
|
||||
- class: HighFreqNorm
|
||||
module_path: qlib.contrib.data.highfreq_processor
|
||||
kwargs:
|
||||
feature_save_dir: ./stat/ # output path of statistics of features (for feature normalization)
|
||||
norm_groups:
|
||||
price: 8
|
||||
volume: 2
|
||||
inst_processors:
|
||||
- class: TimeRangeFlt
|
||||
module_path: qlib.data.dataset.processor
|
||||
kwargs:
|
||||
start_time: "2020-01-01"
|
||||
end_time: "2021-12-31"
|
||||
freq: 5min
|
||||
segments:
|
||||
train: !!python/tuple [*start, *tend]
|
||||
valid: !!python/tuple [*vstart, *vend]
|
||||
test: !!python/tuple [*tstart, *end]
|
||||
backtest_conf:
|
||||
path: ./data/pickle/backtest.pkl # output path of backtest
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: HighFreqGeneralBacktestHandler
|
||||
module_path: qlib.contrib.data.highfreq_handler
|
||||
kwargs:
|
||||
start_time: *start
|
||||
end_time: *end
|
||||
instruments: *ins
|
||||
day_length: 240
|
||||
freq: 5min
|
||||
columns: ["$close", "$volume"]
|
||||
inst_processors:
|
||||
- class: TimeRangeFlt
|
||||
module_path: qlib.data.dataset.processor
|
||||
kwargs:
|
||||
start_time: "2020-01-01"
|
||||
end_time: "2021-12-31"
|
||||
freq: 5min
|
||||
segments:
|
||||
train: !!python/tuple [*start, *tend]
|
||||
valid: !!python/tuple [*vstart, *vend]
|
||||
test: !!python/tuple [*tstart, *end]
|
||||
freq: 5min
|
||||
@@ -253,7 +253,7 @@ class ModelRunner:
|
||||
default "" indicates that
|
||||
qlib_uri : str
|
||||
the uri to install qlib with pip
|
||||
it could be url on the we or local path (NOTE: the local path must be a absolute path)
|
||||
it could be URI on the remote or local path (NOTE: the local path must be an absolute path)
|
||||
exp_folder_name: str
|
||||
the name of the experiment folder
|
||||
wait_before_rm_env : bool
|
||||
|
||||
@@ -88,6 +88,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from qlib.tests.data import GetData\n",
|
||||
"\n",
|
||||
"GetData().qlib_data(exists_skip=True)"
|
||||
]
|
||||
},
|
||||
@@ -99,6 +100,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import qlib\n",
|
||||
"\n",
|
||||
"qlib.init()"
|
||||
]
|
||||
},
|
||||
@@ -134,7 +136,8 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from qlib.data import D\n",
|
||||
"D.calendar(start_time='2010-01-01', end_time='2017-12-31', freq='day')[:2] # calendar data"
|
||||
"\n",
|
||||
"print(D.calendar(start_time=\"2010-01-01\", end_time=\"2017-12-31\", freq=\"day\")[:2]) # calendar data"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -152,7 +155,12 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"df = D.features(['SH601216'], ['$open', '$high', '$low', '$close', '$factor'], start_time='2020-05-01', end_time='2020-05-31') "
|
||||
"df = D.features(\n",
|
||||
" [\"SH601216\"],\n",
|
||||
" [\"$open\", \"$high\", \"$low\", \"$close\", \"$factor\"],\n",
|
||||
" start_time=\"2020-05-01\",\n",
|
||||
" end_time=\"2020-05-31\",\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -163,11 +171,18 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import plotly.graph_objects as go\n",
|
||||
"fig = go.Figure(data=[go.Candlestick(x=df.index.get_level_values(\"datetime\"),\n",
|
||||
" open=df['$open'],\n",
|
||||
" high=df['$high'],\n",
|
||||
" low=df['$low'],\n",
|
||||
" close=df['$close'])])\n",
|
||||
"\n",
|
||||
"fig = go.Figure(\n",
|
||||
" data=[\n",
|
||||
" go.Candlestick(\n",
|
||||
" x=df.index.get_level_values(\"datetime\"),\n",
|
||||
" open=df[\"$open\"],\n",
|
||||
" high=df[\"$high\"],\n",
|
||||
" low=df[\"$low\"],\n",
|
||||
" close=df[\"$close\"],\n",
|
||||
" )\n",
|
||||
" ]\n",
|
||||
")\n",
|
||||
"fig.show()"
|
||||
]
|
||||
},
|
||||
@@ -197,11 +212,18 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import plotly.graph_objects as go\n",
|
||||
"fig = go.Figure(data=[go.Candlestick(x=df.index.get_level_values(\"datetime\"),\n",
|
||||
" open=df['$open'] / df['$factor'],\n",
|
||||
" high=df['$high'] / df['$factor'],\n",
|
||||
" low=df['$low'] / df['$factor'],\n",
|
||||
" close=df['$close'] / df['$factor'])])\n",
|
||||
"\n",
|
||||
"fig = go.Figure(\n",
|
||||
" data=[\n",
|
||||
" go.Candlestick(\n",
|
||||
" x=df.index.get_level_values(\"datetime\"),\n",
|
||||
" open=df[\"$open\"] / df[\"$factor\"],\n",
|
||||
" high=df[\"$high\"] / df[\"$factor\"],\n",
|
||||
" low=df[\"$low\"] / df[\"$factor\"],\n",
|
||||
" close=df[\"$close\"] / df[\"$factor\"],\n",
|
||||
" )\n",
|
||||
" ]\n",
|
||||
")\n",
|
||||
"fig.show()"
|
||||
]
|
||||
},
|
||||
@@ -240,7 +262,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# dynamic universe\n",
|
||||
"universe = D.list_instruments(D.instruments('csi100'), start_time='2010-01-01', end_time='2020-12-31')\n",
|
||||
"universe = D.list_instruments(D.instruments(\"csi100\"), start_time=\"2010-01-01\", end_time=\"2020-12-31\")\n",
|
||||
"pprint(universe)"
|
||||
]
|
||||
},
|
||||
@@ -271,8 +293,8 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"df = D.features(D.instruments('csi100'), ['$close'], start_time='2010-01-01', end_time='2020-12-31') \n",
|
||||
"df.groupby('datetime').size().plot()"
|
||||
"df = D.features(D.instruments(\"csi100\"), [\"$close\"], start_time=\"2010-01-01\", end_time=\"2020-12-31\")\n",
|
||||
"df.groupby(\"datetime\").size().plot()"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -313,8 +335,7 @@
|
||||
" !cd ../../scripts/data_collector/pit/ && pip install -r requirements.txt\n",
|
||||
" !cd ../../scripts/data_collector/pit/ && python collector.py download_data --source_dir ~/.qlib/stock_data/source/pit --start 2000-01-01 --end 2020-01-01 --interval quarterly --symbol_regex \"^(600519|000725).*\"\n",
|
||||
" !cd ../../scripts/data_collector/pit/ && python collector.py normalize_data --interval quarterly --source_dir ~/.qlib/stock_data/source/pit --normalize_dir ~/.qlib/stock_data/source/pit_normalized\n",
|
||||
" !cd ../../scripts/ && python dump_pit.py dump --csv_path ~/.qlib/stock_data/source/pit_normalized --qlib_dir ~/.qlib/qlib_data/cn_data --interval quarterly\n",
|
||||
" pass"
|
||||
" !cd ../../scripts/ && python dump_pit.py dump --csv_path ~/.qlib/stock_data/source/pit_normalized --qlib_dir ~/.qlib/qlib_data/cn_data --interval quarterly"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -338,7 +359,13 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"instruments = [\"sh600519\"]\n",
|
||||
"data = D.features(instruments, ['P($$roewa_q)'], start_time=\"2019-01-01\", end_time=\"2019-07-19\", freq=\"day\")"
|
||||
"data = D.features(\n",
|
||||
" instruments,\n",
|
||||
" [\"P($$roewa_q)\"],\n",
|
||||
" start_time=\"2019-01-01\",\n",
|
||||
" end_time=\"2019-07-19\",\n",
|
||||
" freq=\"day\",\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -366,7 +393,10 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"D.features([\"sh600519\"], ['(EMA($close, 12) - EMA($close, 26))/$close - EMA((EMA($close, 12) - EMA($close, 26))/$close, 9)/$close'])"
|
||||
"D.features(\n",
|
||||
" [\"sh600519\"],\n",
|
||||
" [\"(EMA($close, 12) - EMA($close, 26))/$close - EMA((EMA($close, 12) - EMA($close, 26))/$close, 9)/$close\"],\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -418,7 +448,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"qdl = QlibDataLoader(config=(['$close / Ref($close, 10)'], ['RET10']))"
|
||||
"qdl = QlibDataLoader(config=([\"$close / Ref($close, 10)\"], [\"RET10\"]))"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -428,7 +458,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"qdl.load(instruments=['sh600519'], start_time='20190101', end_time='20191231')"
|
||||
"qdl.load(instruments=[\"sh600519\"], start_time=\"20190101\", end_time=\"20191231\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -456,7 +486,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"df = qdl.load(instruments=['sh600519'], start_time='20190101', end_time='20191231')"
|
||||
"df = qdl.load(instruments=[\"sh600519\"], start_time=\"20190101\", end_time=\"20191231\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -476,7 +506,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"df.plot(kind='hist')"
|
||||
"df.plot(kind=\"hist\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -508,9 +538,16 @@
|
||||
"source": [
|
||||
"# NOTE: normally, the training & validation time range will be `fit_start_time` , `fit_end_time`\n",
|
||||
"# however,all the components are decomposed, so the training & validation time range is unknown when preprocessing.\n",
|
||||
"dh = DataHandlerLP(instruments=['sh600519'], start_time='20170101', end_time='20191231',\n",
|
||||
" infer_processors=[ZScoreNorm(fit_start_time='20170101', fit_end_time='20181231'), Fillna()],\n",
|
||||
" data_loader=qdl)"
|
||||
"dh = DataHandlerLP(\n",
|
||||
" instruments=[\"sh600519\"],\n",
|
||||
" start_time=\"20170101\",\n",
|
||||
" end_time=\"20191231\",\n",
|
||||
" infer_processors=[\n",
|
||||
" ZScoreNorm(fit_start_time=\"20170101\", fit_end_time=\"20181231\"),\n",
|
||||
" Fillna(),\n",
|
||||
" ],\n",
|
||||
" data_loader=qdl,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -550,7 +587,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"df.plot(kind='hist')"
|
||||
"df.plot(kind=\"hist\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -586,7 +623,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"ds = DatasetH(dh, segments={\"train\": ('20180101', '20181231'), \"valid\": ('20190101', '20191231')})"
|
||||
"ds = DatasetH(dh, segments={\"train\": (\"20180101\", \"20181231\"), \"valid\": (\"20190101\", \"20191231\")})"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -596,7 +633,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"ds.prepare('train')"
|
||||
"ds.prepare(\"train\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -606,7 +643,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"ds.prepare('valid')"
|
||||
"ds.prepare(\"valid\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -628,8 +665,12 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"ds = TSDatasetH(step_len=10, handler=dh, segments={\"train\": ('20180101', '20181231'), \"valid\": ('20190101', '20191231')})\n",
|
||||
"train_sampler = ds.prepare('train')"
|
||||
"ds = TSDatasetH(\n",
|
||||
" step_len=10,\n",
|
||||
" handler=dh,\n",
|
||||
" segments={\"train\": (\"20180101\", \"20181231\"), \"valid\": (\"20190101\", \"20191231\")},\n",
|
||||
")\n",
|
||||
"train_sampler = ds.prepare(\"train\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -649,7 +690,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_sampler[0] # Retrieving the first example"
|
||||
"train_sampler[0] # Retrieving the first example"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -659,7 +700,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_sampler['2018-01-08', 'sh600519'] # get the time series by <'timestamp', 'instrument_id'> index"
|
||||
"train_sampler[\"2018-01-08\", \"sh600519\"] # get the time series by <'timestamp', 'instrument_id'> index"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -682,11 +723,11 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"handler_kwargs = {\n",
|
||||
" \"start_time\": \"2008-01-01\",\n",
|
||||
" \"end_time\": \"2020-08-01\",\n",
|
||||
" \"fit_start_time\": \"2008-01-01\",\n",
|
||||
" \"fit_end_time\": \"2014-12-31\",\n",
|
||||
" \"instruments\": MARKET,\n",
|
||||
" \"start_time\": \"2008-01-01\",\n",
|
||||
" \"end_time\": \"2020-08-01\",\n",
|
||||
" \"fit_start_time\": \"2008-01-01\",\n",
|
||||
" \"fit_end_time\": \"2014-12-31\",\n",
|
||||
" \"instruments\": MARKET,\n",
|
||||
"}\n",
|
||||
"handler_conf = {\n",
|
||||
" \"class\": \"Alpha158\",\n",
|
||||
@@ -735,6 +776,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from qlib.contrib.data.handler import Alpha158\n",
|
||||
"\n",
|
||||
"hd = Alpha158(**handler_kwargs)"
|
||||
]
|
||||
},
|
||||
@@ -826,7 +868,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"hd.process_type # appending type"
|
||||
"hd.process_type # appending type"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -857,16 +899,16 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"dataset_conf = {\n",
|
||||
" \"class\": \"DatasetH\",\n",
|
||||
" \"module_path\": \"qlib.data.dataset\",\n",
|
||||
" \"kwargs\": {\n",
|
||||
" \"handler\": hd,\n",
|
||||
" \"segments\": {\n",
|
||||
" \"train\": (\"2008-01-01\", \"2014-12-31\"),\n",
|
||||
" \"valid\": (\"2015-01-01\", \"2016-12-31\"),\n",
|
||||
" \"test\": (\"2017-01-01\", \"2020-08-01\"),\n",
|
||||
" },\n",
|
||||
" \"class\": \"DatasetH\",\n",
|
||||
" \"module_path\": \"qlib.data.dataset\",\n",
|
||||
" \"kwargs\": {\n",
|
||||
" \"handler\": hd,\n",
|
||||
" \"segments\": {\n",
|
||||
" \"train\": (\"2008-01-01\", \"2014-12-31\"),\n",
|
||||
" \"valid\": (\"2015-01-01\", \"2016-12-31\"),\n",
|
||||
" \"test\": (\"2017-01-01\", \"2020-08-01\"),\n",
|
||||
" },\n",
|
||||
" },\n",
|
||||
"}"
|
||||
]
|
||||
},
|
||||
@@ -908,7 +950,8 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model = init_instance_by_config({\n",
|
||||
"model = init_instance_by_config(\n",
|
||||
" {\n",
|
||||
" \"class\": \"LGBModel\",\n",
|
||||
" \"module_path\": \"qlib.contrib.model.gbdt\",\n",
|
||||
" \"kwargs\": {\n",
|
||||
@@ -922,7 +965,8 @@
|
||||
" \"num_leaves\": 210,\n",
|
||||
" \"num_threads\": 20,\n",
|
||||
" },\n",
|
||||
"})"
|
||||
" }\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -938,7 +982,7 @@
|
||||
" R.save_objects(trained_model=model)\n",
|
||||
"\n",
|
||||
" rec = R.get_recorder()\n",
|
||||
" rid = rec.id # save the record id\n",
|
||||
" rid = rec.id # save the record id\n",
|
||||
"\n",
|
||||
" # Inference and saving signal\n",
|
||||
" sr = SignalRecord(model, dataset, rec)\n",
|
||||
@@ -1001,12 +1045,11 @@
|
||||
"\n",
|
||||
"# backtest and analysis\n",
|
||||
"with R.start(experiment_name=EXP_NAME, recorder_id=rid, resume=True):\n",
|
||||
"\n",
|
||||
" # signal-based analysis\n",
|
||||
" rec = R.get_recorder()\n",
|
||||
" sar = SigAnaRecord(rec)\n",
|
||||
" sar.generate()\n",
|
||||
" \n",
|
||||
"\n",
|
||||
" # portfolio-based analysis: backtest\n",
|
||||
" par = PortAnaRecord(rec, port_analysis_config, \"day\")\n",
|
||||
" par.generate()"
|
||||
@@ -1137,7 +1180,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"label_df = dataset.prepare(\"test\", col_set=\"label\")\n",
|
||||
"label_df.columns = ['label']"
|
||||
"label_df.columns = [\"label\"]"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -38,7 +38,7 @@
|
||||
" # install qlib\n",
|
||||
" ! pip install --upgrade numpy\n",
|
||||
" ! pip install pyqlib\n",
|
||||
" if 'google.colab' in sys.modules:\n",
|
||||
" if \"google.colab\" in sys.modules:\n",
|
||||
" # The Google colab environment is a little outdated. We have to downgrade the pyyaml to make it compatible with other packages\n",
|
||||
" ! pip install pyyaml==5.4.1\n",
|
||||
" # reload\n",
|
||||
@@ -50,7 +50,8 @@
|
||||
" scripts_dir = Path(\"~/tmp/qlib_code/scripts\").expanduser().resolve()\n",
|
||||
" scripts_dir.mkdir(parents=True, exist_ok=True)\n",
|
||||
" import requests\n",
|
||||
" with requests.get(\"https://raw.githubusercontent.com/microsoft/qlib/main/scripts/get_data.py\") as resp:\n",
|
||||
"\n",
|
||||
" with requests.get(\"https://raw.githubusercontent.com/microsoft/qlib/main/scripts/get_data.py\", timeout=10) as resp:\n",
|
||||
" with open(scripts_dir.joinpath(\"get_data.py\"), \"wb\") as fp:\n",
|
||||
" fp.write(resp.content)"
|
||||
]
|
||||
@@ -61,14 +62,13 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"import qlib\n",
|
||||
"import pandas as pd\n",
|
||||
"from qlib.constant import REG_CN\n",
|
||||
"from qlib.utils import exists_qlib_data, init_instance_by_config\n",
|
||||
"from qlib.workflow import R\n",
|
||||
"from qlib.workflow.record_temp import SignalRecord, PortAnaRecord\n",
|
||||
"from qlib.utils import flatten_dict\n"
|
||||
"from qlib.utils import flatten_dict"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -86,6 +86,7 @@
|
||||
" print(f\"Qlib data is not found in {provider_uri}\")\n",
|
||||
" sys.path.append(str(scripts_dir))\n",
|
||||
" from get_data import GetData\n",
|
||||
"\n",
|
||||
" GetData().qlib_data(target_dir=provider_uri, region=REG_CN)\n",
|
||||
"qlib.init(provider_uri=provider_uri, region=REG_CN)"
|
||||
]
|
||||
@@ -169,7 +170,7 @@
|
||||
" R.log_params(**flatten_dict(task))\n",
|
||||
" model.fit(dataset)\n",
|
||||
" R.save_objects(trained_model=model)\n",
|
||||
" rid = R.get_recorder().id\n"
|
||||
" rid = R.get_recorder().id"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -238,7 +239,7 @@
|
||||
"\n",
|
||||
" # backtest & analysis\n",
|
||||
" par = PortAnaRecord(recorder, port_analysis_config, \"day\")\n",
|
||||
" par.generate()\n"
|
||||
" par.generate()"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -256,6 +257,7 @@
|
||||
"source": [
|
||||
"from qlib.contrib.report import analysis_model, analysis_position\n",
|
||||
"from qlib.data import D\n",
|
||||
"\n",
|
||||
"recorder = R.get_recorder(recorder_id=ba_rid, experiment_name=\"backtest_analysis\")\n",
|
||||
"print(recorder)\n",
|
||||
"pred_df = recorder.load_object(\"pred.pkl\")\n",
|
||||
@@ -317,7 +319,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"label_df = dataset.prepare(\"test\", col_set=\"label\")\n",
|
||||
"label_df.columns = ['label']"
|
||||
"label_df.columns = [\"label\"]"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# Licensed under the MIT License.
|
||||
from pathlib import Path
|
||||
|
||||
__version__ = "0.8.6.99"
|
||||
__version__ = "0.9.1.99"
|
||||
__version__bak = __version__ # This version is backup for QlibConfig.reset_qlib_version
|
||||
import os
|
||||
from typing import Union
|
||||
@@ -34,8 +34,7 @@ def init(default_conf="client", **kwargs):
|
||||
from .config import C # pylint: disable=C0415
|
||||
from .data.cache import H # pylint: disable=C0415
|
||||
|
||||
# FIXME: this logger ignored the level in config
|
||||
logger = get_module_logger("Initialization", level=logging.INFO)
|
||||
logger = get_module_logger("Initialization")
|
||||
|
||||
skip_if_reg = kwargs.pop("skip_if_reg", False)
|
||||
if skip_if_reg and C.registered:
|
||||
@@ -48,6 +47,7 @@ def init(default_conf="client", **kwargs):
|
||||
if clear_mem_cache:
|
||||
H.clear()
|
||||
C.set(default_conf, **kwargs)
|
||||
get_module_logger.setLevel(C.logging_level)
|
||||
|
||||
# mount nfs
|
||||
for _freq, provider_uri in C.provider_uri.items():
|
||||
|
||||
@@ -10,7 +10,6 @@ from typing import TYPE_CHECKING, Any, Generator, List, Optional, Tuple, Union
|
||||
import pandas as pd
|
||||
|
||||
from .account import Account
|
||||
from .report import Indicator, PortfolioMetrics
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..strategy.base import BaseStrategy
|
||||
@@ -20,7 +19,7 @@ if TYPE_CHECKING:
|
||||
from ..config import C
|
||||
from ..log import get_module_logger
|
||||
from ..utils import init_instance_by_config
|
||||
from .backtest import backtest_loop, collect_data_loop
|
||||
from .backtest import INDICATOR_METRIC, PORT_METRIC, backtest_loop, collect_data_loop
|
||||
from .decision import Order
|
||||
from .exchange import Exchange
|
||||
from .utils import CommonInfrastructure
|
||||
@@ -41,8 +40,8 @@ def get_exchange(
|
||||
open_cost: float = 0.0015,
|
||||
close_cost: float = 0.0025,
|
||||
min_cost: float = 5.0,
|
||||
limit_threshold: Union[Tuple[str, str], float, None] = None,
|
||||
deal_price: Union[str, Tuple[str, str], List[str]] = None,
|
||||
limit_threshold: Union[Tuple[str, str], float, None] | None = None,
|
||||
deal_price: Union[str, Tuple[str, str], List[str]] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Exchange:
|
||||
"""get_exchange
|
||||
@@ -114,7 +113,7 @@ def get_exchange(
|
||||
def create_account_instance(
|
||||
start_time: Union[pd.Timestamp, str],
|
||||
end_time: Union[pd.Timestamp, str],
|
||||
benchmark: str,
|
||||
benchmark: Optional[str],
|
||||
account: Union[float, int, dict],
|
||||
pos_type: str = "Position",
|
||||
) -> Account:
|
||||
@@ -163,7 +162,9 @@ def create_account_instance(
|
||||
init_cash=init_cash,
|
||||
position_dict=position_dict,
|
||||
pos_type=pos_type,
|
||||
benchmark_config={
|
||||
benchmark_config={}
|
||||
if benchmark is None
|
||||
else {
|
||||
"benchmark": benchmark,
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
@@ -176,9 +177,9 @@ def get_strategy_executor(
|
||||
end_time: Union[pd.Timestamp, str],
|
||||
strategy: Union[str, dict, object, Path],
|
||||
executor: Union[str, dict, object, Path],
|
||||
benchmark: str = "SH000300",
|
||||
benchmark: Optional[str] = "SH000300",
|
||||
account: Union[float, int, dict] = 1e9,
|
||||
exchange_kwargs: dict = {},
|
||||
exchange_kwargs: Union[dict, Exchange] = {}, # TODO: rename parameter
|
||||
pos_type: str = "Position",
|
||||
) -> Tuple[BaseStrategy, BaseExecutor]:
|
||||
|
||||
@@ -196,12 +197,15 @@ def get_strategy_executor(
|
||||
pos_type=pos_type,
|
||||
)
|
||||
|
||||
exchange_kwargs = copy.copy(exchange_kwargs)
|
||||
if "start_time" not in exchange_kwargs:
|
||||
exchange_kwargs["start_time"] = start_time
|
||||
if "end_time" not in exchange_kwargs:
|
||||
exchange_kwargs["end_time"] = end_time
|
||||
trade_exchange = get_exchange(**exchange_kwargs)
|
||||
if isinstance(exchange_kwargs, Exchange):
|
||||
trade_exchange = exchange_kwargs
|
||||
else:
|
||||
exchange_kwargs = copy.copy(exchange_kwargs)
|
||||
if "start_time" not in exchange_kwargs:
|
||||
exchange_kwargs["start_time"] = start_time
|
||||
if "end_time" not in exchange_kwargs:
|
||||
exchange_kwargs["end_time"] = end_time
|
||||
trade_exchange = get_exchange(**exchange_kwargs)
|
||||
|
||||
common_infra = CommonInfrastructure(trade_account=trade_account, trade_exchange=trade_exchange)
|
||||
trade_strategy = init_instance_by_config(strategy, accept_types=BaseStrategy)
|
||||
@@ -221,7 +225,7 @@ def backtest(
|
||||
account: Union[float, int, dict] = 1e9,
|
||||
exchange_kwargs: dict = {},
|
||||
pos_type: str = "Position",
|
||||
) -> Tuple[PortfolioMetrics, Indicator]:
|
||||
) -> Tuple[PORT_METRIC, INDICATOR_METRIC]:
|
||||
"""initialize the strategy and executor, then backtest function for the interaction of the outermost strategy and
|
||||
executor in the nested decision execution
|
||||
|
||||
@@ -242,7 +246,7 @@ def backtest(
|
||||
benchmark: str
|
||||
the benchmark for reporting.
|
||||
account : Union[float, int, Position]
|
||||
information for describing how to creating the account
|
||||
information for describing how to create the account
|
||||
For `float` or `int`:
|
||||
Using Account with only initial cash
|
||||
For `Position`:
|
||||
@@ -254,9 +258,9 @@ def backtest(
|
||||
|
||||
Returns
|
||||
-------
|
||||
portfolio_metrics_dict: Dict[PortfolioMetrics]
|
||||
portfolio_dict: PORT_METRIC
|
||||
it records the trading portfolio_metrics information
|
||||
indicator_dict: Dict[Indicator]
|
||||
indicator_dict: INDICATOR_METRIC
|
||||
it computes the trading indicator
|
||||
It is organized in a dict format
|
||||
|
||||
@@ -271,8 +275,7 @@ def backtest(
|
||||
exchange_kwargs,
|
||||
pos_type=pos_type,
|
||||
)
|
||||
portfolio_metrics, indicator = backtest_loop(start_time, end_time, trade_strategy, trade_executor)
|
||||
return portfolio_metrics, indicator
|
||||
return backtest_loop(start_time, end_time, trade_strategy, trade_executor)
|
||||
|
||||
|
||||
def collect_data(
|
||||
@@ -284,7 +287,7 @@ def collect_data(
|
||||
account: Union[float, int, dict] = 1e9,
|
||||
exchange_kwargs: dict = {},
|
||||
pos_type: str = "Position",
|
||||
return_value: dict = None,
|
||||
return_value: dict | None = None,
|
||||
) -> Generator[object, None, None]:
|
||||
"""initialize the strategy and executor, then collect the trade decision data for rl training
|
||||
|
||||
@@ -345,4 +348,4 @@ def format_decisions(
|
||||
return res
|
||||
|
||||
|
||||
__all__ = ["Order", "backtest"]
|
||||
__all__ = ["Order", "backtest", "get_strategy_executor"]
|
||||
|
||||
@@ -152,7 +152,9 @@ class Account:
|
||||
# trading related metrics(e.g. high-frequency trading)
|
||||
self.indicator = Indicator()
|
||||
|
||||
def reset(self, freq: str = None, benchmark_config: dict = None, port_metr_enabled: bool = None) -> None:
|
||||
def reset(
|
||||
self, freq: str | None = None, benchmark_config: dict | None = None, port_metr_enabled: bool | None = None
|
||||
) -> None:
|
||||
"""reset freq and report of account
|
||||
|
||||
Parameters
|
||||
@@ -236,7 +238,7 @@ class Account:
|
||||
if not self.current_position.skip_update():
|
||||
stock_list = self.current_position.get_stock_list()
|
||||
for code in stock_list:
|
||||
# if suspend, no new price to be updated, profit is 0
|
||||
# if suspended, no new price to be updated, profit is 0
|
||||
if trade_exchange.check_stock_suspended(code, trade_start_time, trade_end_time):
|
||||
continue
|
||||
bar_close = cast(float, trade_exchange.get_close(code, trade_start_time, trade_end_time))
|
||||
|
||||
@@ -3,12 +3,12 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Generator, Optional, Tuple, Union, cast
|
||||
from typing import Dict, TYPE_CHECKING, Generator, Optional, Tuple, Union, cast
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from qlib.backtest.decision import BaseTradeDecision
|
||||
from qlib.backtest.report import Indicator, PortfolioMetrics
|
||||
from qlib.backtest.report import Indicator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from qlib.strategy.base import BaseStrategy
|
||||
@@ -19,30 +19,35 @@ from tqdm.auto import tqdm
|
||||
from ..utils.time import Freq
|
||||
|
||||
|
||||
PORT_METRIC = Dict[str, Tuple[pd.DataFrame, dict]]
|
||||
INDICATOR_METRIC = Dict[str, Tuple[pd.DataFrame, Indicator]]
|
||||
|
||||
|
||||
def backtest_loop(
|
||||
start_time: Union[pd.Timestamp, str],
|
||||
end_time: Union[pd.Timestamp, str],
|
||||
trade_strategy: BaseStrategy,
|
||||
trade_executor: BaseExecutor,
|
||||
) -> Tuple[PortfolioMetrics, Indicator]:
|
||||
) -> Tuple[PORT_METRIC, INDICATOR_METRIC]:
|
||||
"""backtest function for the interaction of the outermost strategy and executor in the nested decision execution
|
||||
|
||||
please refer to the docs of `collect_data_loop`
|
||||
|
||||
Returns
|
||||
-------
|
||||
portfolio_metrics: PortfolioMetrics
|
||||
portfolio_dict: PORT_METRIC
|
||||
it records the trading portfolio_metrics information
|
||||
indicator: Indicator
|
||||
indicator_dict: INDICATOR_METRIC
|
||||
it computes the trading indicator
|
||||
"""
|
||||
return_value: dict = {}
|
||||
for _decision in collect_data_loop(start_time, end_time, trade_strategy, trade_executor, return_value):
|
||||
pass
|
||||
|
||||
portfolio_metrics = cast(PortfolioMetrics, return_value.get("portfolio_metrics"))
|
||||
indicator = cast(Indicator, return_value.get("indicator"))
|
||||
return portfolio_metrics, indicator
|
||||
portfolio_dict = cast(PORT_METRIC, return_value.get("portfolio_dict"))
|
||||
indicator_dict = cast(INDICATOR_METRIC, return_value.get("indicator_dict"))
|
||||
|
||||
return portfolio_dict, indicator_dict
|
||||
|
||||
|
||||
def collect_data_loop(
|
||||
@@ -50,7 +55,8 @@ def collect_data_loop(
|
||||
end_time: Union[pd.Timestamp, str],
|
||||
trade_strategy: BaseStrategy,
|
||||
trade_executor: BaseExecutor,
|
||||
return_value: dict = None,
|
||||
return_value: dict | None = None,
|
||||
show_progress: bool = True,
|
||||
) -> Generator[BaseTradeDecision, Optional[BaseTradeDecision], None]:
|
||||
"""Generator for collecting the trade decision data for rl training
|
||||
|
||||
@@ -69,6 +75,8 @@ def collect_data_loop(
|
||||
the outermost executor
|
||||
return_value : dict
|
||||
used for backtest_loop
|
||||
show_progress: bool
|
||||
whether to show execution progress
|
||||
|
||||
Yields
|
||||
-------
|
||||
@@ -78,23 +86,29 @@ def collect_data_loop(
|
||||
trade_executor.reset(start_time=start_time, end_time=end_time)
|
||||
trade_strategy.reset(level_infra=trade_executor.get_level_infra())
|
||||
|
||||
with tqdm(total=trade_executor.trade_calendar.get_trade_len(), desc="backtest loop") as bar:
|
||||
disable = not show_progress
|
||||
with tqdm(total=trade_executor.trade_calendar.get_trade_len(), desc="backtest loop", disable=disable) as bar:
|
||||
_execute_result = None
|
||||
while not trade_executor.finished():
|
||||
_trade_decision: BaseTradeDecision = trade_strategy.generate_trade_decision(_execute_result)
|
||||
_execute_result = yield from trade_executor.collect_data(_trade_decision, level=0)
|
||||
trade_strategy.post_exe_step(_execute_result)
|
||||
bar.update(1)
|
||||
trade_strategy.post_upper_level_exe_step()
|
||||
|
||||
if return_value is not None:
|
||||
all_executors = trade_executor.get_all_executors()
|
||||
all_portfolio_metrics = {
|
||||
"{}{}".format(*Freq.parse(_executor.time_per_step)): _executor.trade_account.get_portfolio_metrics()
|
||||
for _executor in all_executors
|
||||
if _executor.trade_account.is_port_metr_enabled()
|
||||
}
|
||||
all_indicators = {}
|
||||
for _executor in all_executors:
|
||||
key = "{}{}".format(*Freq.parse(_executor.time_per_step))
|
||||
all_indicators[key] = _executor.trade_account.get_trade_indicator().generate_trade_indicators_dataframe()
|
||||
all_indicators[key + "_obj"] = _executor.trade_account.get_trade_indicator()
|
||||
return_value.update({"portfolio_metrics": all_portfolio_metrics, "indicator": all_indicators})
|
||||
|
||||
portfolio_dict: PORT_METRIC = {}
|
||||
indicator_dict: INDICATOR_METRIC = {}
|
||||
|
||||
for executor in all_executors:
|
||||
key = "{}{}".format(*Freq.parse(executor.time_per_step))
|
||||
if executor.trade_account.is_port_metr_enabled():
|
||||
portfolio_dict[key] = executor.trade_account.get_portfolio_metrics()
|
||||
|
||||
indicator_df = executor.trade_account.get_trade_indicator().generate_trade_indicators_dataframe()
|
||||
indicator_obj = executor.trade_account.get_trade_indicator()
|
||||
indicator_dict[key] = (indicator_df, indicator_obj)
|
||||
|
||||
return_value.update({"portfolio_dict": portfolio_dict, "indicator_dict": indicator_dict})
|
||||
|
||||
@@ -135,6 +135,21 @@ class Order:
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
|
||||
@property
|
||||
def key_by_day(self) -> tuple:
|
||||
"""A hashable & unique key to identify this order, under the granularity in day."""
|
||||
return self.stock_id, self.date, self.direction
|
||||
|
||||
@property
|
||||
def key(self) -> tuple:
|
||||
"""A hashable & unique key to identify this order."""
|
||||
return self.stock_id, self.start_time, self.end_time, self.direction
|
||||
|
||||
@property
|
||||
def date(self) -> pd.Timestamp:
|
||||
"""Date of the order."""
|
||||
return pd.Timestamp(self.start_time.replace(hour=0, minute=0, second=0))
|
||||
|
||||
|
||||
class OrderHelper:
|
||||
"""
|
||||
@@ -239,7 +254,7 @@ class IdxTradeRange(TradeRange):
|
||||
self._start_idx = start_idx
|
||||
self._end_idx = end_idx
|
||||
|
||||
def __call__(self, trade_calendar: TradeCalendarManager = None) -> Tuple[int, int]:
|
||||
def __call__(self, trade_calendar: TradeCalendarManager | None = None) -> Tuple[int, int]:
|
||||
return self._start_idx, self._end_idx
|
||||
|
||||
def clip_time_range(self, start_time: pd.Timestamp, end_time: pd.Timestamp) -> Tuple[pd.Timestamp, pd.Timestamp]:
|
||||
@@ -286,7 +301,7 @@ class TradeRangeByTime(TradeRange):
|
||||
|
||||
class BaseTradeDecision(Generic[DecisionType]):
|
||||
"""
|
||||
Trade decisions ara made by strategy and executed by executor
|
||||
Trade decisions are made by strategy and executed by executor
|
||||
|
||||
Motivation:
|
||||
Here are several typical scenarios for `BaseTradeDecision`
|
||||
@@ -300,7 +315,7 @@ class BaseTradeDecision(Generic[DecisionType]):
|
||||
2. Same as `case 1.3`
|
||||
"""
|
||||
|
||||
def __init__(self, strategy: BaseStrategy, trade_range: Union[Tuple[int, int], TradeRange] = None) -> None:
|
||||
def __init__(self, strategy: BaseStrategy, trade_range: Union[Tuple[int, int], TradeRange, None] = None) -> None:
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
@@ -539,7 +554,7 @@ class TradeDecisionWO(BaseTradeDecision[Order]):
|
||||
self,
|
||||
order_list: List[Order],
|
||||
strategy: BaseStrategy,
|
||||
trade_range: Union[Tuple[int, int], TradeRange] = None,
|
||||
trade_range: Union[Tuple[int, int], TradeRange, None] = None,
|
||||
) -> None:
|
||||
super().__init__(strategy, trade_range=trade_range)
|
||||
self.order_list = cast(List[Order], order_list)
|
||||
@@ -561,3 +576,21 @@ class TradeDecisionWO(BaseTradeDecision[Order]):
|
||||
f"trade_range: {self.trade_range}; "
|
||||
f"order_list[{len(self.order_list)}]"
|
||||
)
|
||||
|
||||
|
||||
class TradeDecisionWithDetails(TradeDecisionWO):
|
||||
"""
|
||||
Decision with detail information.
|
||||
Detail information is used to generate execution reports.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
order_list: List[Order],
|
||||
strategy: BaseStrategy,
|
||||
trade_range: Optional[Tuple[int, int]] = None,
|
||||
details: Optional[Any] = None,
|
||||
) -> None:
|
||||
super().__init__(order_list, strategy, trade_range)
|
||||
|
||||
self.details = details
|
||||
|
||||
@@ -18,7 +18,7 @@ import pandas as pd
|
||||
from qlib.backtest.position import BasePosition
|
||||
|
||||
from ..config import C
|
||||
from ..constant import REG_CN
|
||||
from ..constant import REG_CN, REG_TW
|
||||
from ..data.data import D
|
||||
from ..log import get_module_logger
|
||||
from .decision import Order, OrderDir, OrderHelper
|
||||
@@ -26,16 +26,25 @@ from .high_performance_ds import BaseQuote, NumpyQuote
|
||||
|
||||
|
||||
class Exchange:
|
||||
# `quote_df` is a pd.DataFrame class that contains basic information for backtesting
|
||||
# After some processing, the data will later be maintained by `quote_cls` object for faster data retrieving.
|
||||
# Some conventions for `quote_df`
|
||||
# - $close is for calculating the total value at end of each day.
|
||||
# - if $close is None, the stock on that day is regarded as suspended.
|
||||
# - $factor is for rounding to the trading unit;
|
||||
# - if any $factor is missing when $close exists, trading unit rounding will be disabled
|
||||
quote_df: pd.DataFrame
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
freq: str = "day",
|
||||
start_time: Union[pd.Timestamp, str] = None,
|
||||
end_time: Union[pd.Timestamp, str] = None,
|
||||
codes: Union[list, str] = "all",
|
||||
deal_price: Union[str, Tuple[str, str], List[str]] = None,
|
||||
deal_price: Union[str, Tuple[str, str], List[str], None] = None,
|
||||
subscribe_fields: list = [],
|
||||
limit_threshold: Union[Tuple[str, str], float, None] = None,
|
||||
volume_threshold: Union[tuple, dict] = None,
|
||||
volume_threshold: Union[tuple, dict, None] = None,
|
||||
open_cost: float = 0.0015,
|
||||
close_cost: float = 0.0025,
|
||||
min_cost: float = 5.0,
|
||||
@@ -132,17 +141,17 @@ class Exchange:
|
||||
if deal_price is None:
|
||||
deal_price = C.deal_price
|
||||
|
||||
# we have some verbose information here. So logging is enable
|
||||
# we have some verbose information here. So logging is enabled
|
||||
self.logger = get_module_logger("online operator")
|
||||
|
||||
# TODO: the quote, trade_dates, codes are not necessary.
|
||||
# It is just for performance consideration.
|
||||
self.limit_type = self._get_limit_type(limit_threshold)
|
||||
if limit_threshold is None:
|
||||
if C.region == REG_CN:
|
||||
if C.region in [REG_CN, REG_TW]:
|
||||
self.logger.warning(f"limit_threshold not set. The stocks hit the limit may be bought/sold")
|
||||
elif self.limit_type == self.LT_FLT and abs(cast(float, limit_threshold)) > 0.1:
|
||||
if C.region == REG_CN:
|
||||
if C.region in [REG_CN, REG_TW]:
|
||||
self.logger.warning(f"limit_threshold may not be set to a reasonable value")
|
||||
|
||||
if isinstance(deal_price, str):
|
||||
@@ -159,6 +168,7 @@ class Exchange:
|
||||
self.codes = codes
|
||||
# Necessary fields
|
||||
# $close is for calculating the total value at end of each day.
|
||||
# - if $close is None, the stock on that day is regarded as suspended.
|
||||
# $factor is for rounding to the trading unit
|
||||
# $change is for calculating the limit of the stock
|
||||
|
||||
@@ -167,7 +177,7 @@ class Exchange:
|
||||
|
||||
necessary_fields = {self.buy_price, self.sell_price, "$close", "$change", "$factor", "$volume"}
|
||||
if self.limit_type == self.LT_TP_EXP:
|
||||
assert isinstance(limit_threshold, tuple)
|
||||
assert isinstance(limit_threshold, tuple) or (isinstance(limit_threshold, list) and len(limit_threshold) == 2)
|
||||
for exp in limit_threshold:
|
||||
necessary_fields.add(exp)
|
||||
all_fields = list(necessary_fields | set(vol_lt_fields) | set(subscribe_fields))
|
||||
@@ -199,7 +209,7 @@ class Exchange:
|
||||
self.end_time,
|
||||
freq=self.freq,
|
||||
disk_cache=True,
|
||||
).dropna(subset=["$close"])
|
||||
)
|
||||
self.quote_df.columns = self.all_fields
|
||||
|
||||
# check buy_price data and sell_price data
|
||||
@@ -209,7 +219,7 @@ class Exchange:
|
||||
self.logger.warning("{} field data contains nan.".format(pstr))
|
||||
|
||||
# update trade_w_adj_price
|
||||
if self.quote_df["$factor"].isna().any():
|
||||
if (self.quote_df["$factor"].isna() & ~self.quote_df["$close"].isna()).any():
|
||||
# The 'factor.day.bin' file not exists, and `factor` field contains `nan`
|
||||
# Use adjusted price
|
||||
self.trade_w_adj_price = True
|
||||
@@ -245,14 +255,17 @@ class Exchange:
|
||||
assert set(self.extra_quote.columns) == set(self.quote_df.columns) - {"$change"}
|
||||
self.quote_df = pd.concat([self.quote_df, self.extra_quote], sort=False, axis=0)
|
||||
|
||||
LT_TP_EXP = "(exp)" # Tuple[str, str]
|
||||
LT_FLT = "float" # float
|
||||
LT_NONE = "none" # none
|
||||
LT_TP_EXP = "(exp)" # Tuple[str, str]: the limitation is calculated by a Qlib expression.
|
||||
LT_FLT = "float" # float: the trading limitation is based on `abs($change) < limit_threshold`
|
||||
LT_NONE = "none" # none: there is no trading limitation
|
||||
|
||||
def _get_limit_type(self, limit_threshold: Union[tuple, float, None]) -> str:
|
||||
"""get limit type"""
|
||||
if isinstance(limit_threshold, tuple):
|
||||
return self.LT_TP_EXP
|
||||
if isinstance(limit_threshold, list):
|
||||
assert len(limit_threshold) == 2
|
||||
return self.LT_TP_EXP
|
||||
elif isinstance(limit_threshold, float):
|
||||
return self.LT_FLT
|
||||
elif limit_threshold is None:
|
||||
@@ -261,20 +274,25 @@ class Exchange:
|
||||
raise NotImplementedError(f"This type of `limit_threshold` is not supported")
|
||||
|
||||
def _update_limit(self, limit_threshold: Union[Tuple, float, None]) -> None:
|
||||
# $close may contain NaN, the nan indicates that the stock is not tradable at that timestamp
|
||||
suspended = self.quote_df["$close"].isna()
|
||||
# check limit_threshold
|
||||
limit_type = self._get_limit_type(limit_threshold)
|
||||
if limit_type == self.LT_NONE:
|
||||
self.quote_df["limit_buy"] = False
|
||||
self.quote_df["limit_sell"] = False
|
||||
self.quote_df["limit_buy"] = suspended
|
||||
self.quote_df["limit_sell"] = suspended
|
||||
elif limit_type == self.LT_TP_EXP:
|
||||
# set limit
|
||||
limit_threshold = cast(tuple, limit_threshold)
|
||||
self.quote_df["limit_buy"] = self.quote_df[limit_threshold[0]]
|
||||
self.quote_df["limit_sell"] = self.quote_df[limit_threshold[1]]
|
||||
# astype bool is necessary, because quote_df is an expression and could be float
|
||||
self.quote_df["limit_buy"] = self.quote_df[limit_threshold[0]].astype("bool") | suspended
|
||||
self.quote_df["limit_sell"] = self.quote_df[limit_threshold[1]].astype("bool") | suspended
|
||||
elif limit_type == self.LT_FLT:
|
||||
limit_threshold = cast(float, limit_threshold)
|
||||
self.quote_df["limit_buy"] = self.quote_df["$change"].ge(limit_threshold)
|
||||
self.quote_df["limit_sell"] = self.quote_df["$change"].le(-limit_threshold) # pylint: disable=E1130
|
||||
self.quote_df["limit_buy"] = self.quote_df["$change"].ge(limit_threshold) | suspended
|
||||
self.quote_df["limit_sell"] = (
|
||||
self.quote_df["$change"].le(-limit_threshold) | suspended
|
||||
) # pylint: disable=E1130
|
||||
|
||||
@staticmethod
|
||||
def _get_vol_limit(volume_threshold: Union[tuple, dict, None]) -> Tuple[Optional[list], Optional[list], set]:
|
||||
@@ -310,7 +328,7 @@ class Exchange:
|
||||
|
||||
assert isinstance(volume_threshold, dict)
|
||||
for key, vol_limit in volume_threshold.items():
|
||||
assert isinstance(vol_limit, tuple)
|
||||
assert isinstance(vol_limit, tuple) or (isinstance(vol_limit, list) and len(vol_limit) == 2)
|
||||
fields.add(vol_limit[1])
|
||||
|
||||
if key in ("buy", "all"):
|
||||
@@ -325,7 +343,7 @@ class Exchange:
|
||||
stock_id: str,
|
||||
start_time: pd.Timestamp,
|
||||
end_time: pd.Timestamp,
|
||||
direction: int = None,
|
||||
direction: int | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Parameters
|
||||
@@ -338,8 +356,18 @@ class Exchange:
|
||||
- if direction is None, check if tradable for buying and selling.
|
||||
- if direction == Order.BUY, check the if tradable for buying
|
||||
- if direction == Order.SELL, check the sell limit for selling.
|
||||
|
||||
Returns
|
||||
-------
|
||||
True: the trading of the stock is limited (maybe hit the highest/lowest price), hence the stock is not tradable
|
||||
False: the trading of the stock is not limited, hence the stock may be tradable
|
||||
"""
|
||||
# NOTE:
|
||||
# **all** is used when checking limitation.
|
||||
# For example, the stock trading is limited in a day if every minute is limited in a day if every minute is limited.
|
||||
if direction is None:
|
||||
# The trading limitation is related to the trading direction
|
||||
# if the direction is not provided, then any limitation from buy or sell will result in trading limitation
|
||||
buy_limit = self.quote.get_data(stock_id, start_time, end_time, field="limit_buy", method="all")
|
||||
sell_limit = self.quote.get_data(stock_id, start_time, end_time, field="limit_sell", method="all")
|
||||
return bool(buy_limit or sell_limit)
|
||||
@@ -356,10 +384,24 @@ class Exchange:
|
||||
start_time: pd.Timestamp,
|
||||
end_time: pd.Timestamp,
|
||||
) -> bool:
|
||||
"""if stock is suspended(hence not tradable), True will be returned"""
|
||||
# is suspended
|
||||
if stock_id in self.quote.get_all_stock():
|
||||
return self.quote.get_data(stock_id, start_time, end_time, "$close") is None
|
||||
# suspended stocks are represented by None $close stock
|
||||
# The $close may contain NaN,
|
||||
close = self.quote.get_data(stock_id, start_time, end_time, "$close")
|
||||
if close is None:
|
||||
# if no close record exists
|
||||
return True
|
||||
elif isinstance(close, IndexData):
|
||||
# **any** non-NaN $close represents trading opportunity may exist
|
||||
# if all returned is nan, then the stock is suspended
|
||||
return cast(bool, cast(IndexData, close).isna().all())
|
||||
else:
|
||||
# it is single value, make sure is not None
|
||||
return np.isnan(close)
|
||||
else:
|
||||
# if the stock is not in the stock list, then it is not tradable and regarded as suspended
|
||||
return True
|
||||
|
||||
def is_stock_tradable(
|
||||
@@ -367,7 +409,7 @@ class Exchange:
|
||||
stock_id: str,
|
||||
start_time: pd.Timestamp,
|
||||
end_time: pd.Timestamp,
|
||||
direction: int = None,
|
||||
direction: int | None = None,
|
||||
) -> bool:
|
||||
# check if stock can be traded
|
||||
return not (
|
||||
@@ -382,8 +424,8 @@ class Exchange:
|
||||
def deal_order(
|
||||
self,
|
||||
order: Order,
|
||||
trade_account: Account = None,
|
||||
position: BasePosition = None,
|
||||
trade_account: Account | None = None,
|
||||
position: BasePosition | None = None,
|
||||
dealt_order_amount: Dict[str, float] = defaultdict(float),
|
||||
) -> Tuple[float, float, float]:
|
||||
"""
|
||||
@@ -501,8 +543,8 @@ class Exchange:
|
||||
direction: OrderDir = OrderDir.BUY,
|
||||
) -> dict:
|
||||
"""
|
||||
The generate the target position according to the weight and the cash.
|
||||
NOTE: All the cash will assigned to the tradable stock.
|
||||
Generates the target position according to the weight and the cash.
|
||||
NOTE: All the cash will be assigned to the tradable stock.
|
||||
Parameter:
|
||||
weight_position : dict {stock_id : weight}; allocate cash by weight_position
|
||||
among then, weight must be in this range: 0 < weight < 1
|
||||
@@ -547,7 +589,7 @@ class Exchange:
|
||||
)
|
||||
return amount_dict
|
||||
|
||||
def get_real_deal_amount(self, current_amount: float, target_amount: float, factor: float = None) -> float:
|
||||
def get_real_deal_amount(self, current_amount: float, target_amount: float, factor: float | None = None) -> float:
|
||||
"""
|
||||
Calculate the real adjust deal amount when considering the trading unit
|
||||
:param current_amount:
|
||||
@@ -600,7 +642,7 @@ class Exchange:
|
||||
random.shuffle(sorted_ids)
|
||||
for stock_id in sorted_ids:
|
||||
|
||||
# Do not generate order for the nontradable stocks
|
||||
# Do not generate order for the non-tradable stocks
|
||||
if not self.is_stock_tradable(stock_id=stock_id, start_time=start_time, end_time=end_time):
|
||||
continue
|
||||
|
||||
@@ -673,8 +715,8 @@ class Exchange:
|
||||
|
||||
def _get_factor_or_raise_error(
|
||||
self,
|
||||
factor: float = None,
|
||||
stock_id: str = None,
|
||||
factor: float | None = None,
|
||||
stock_id: str | None = None,
|
||||
start_time: pd.Timestamp = None,
|
||||
end_time: pd.Timestamp = None,
|
||||
) -> float:
|
||||
@@ -689,8 +731,8 @@ class Exchange:
|
||||
|
||||
def get_amount_of_trade_unit(
|
||||
self,
|
||||
factor: float = None,
|
||||
stock_id: str = None,
|
||||
factor: float | None = None,
|
||||
stock_id: str | None = None,
|
||||
start_time: pd.Timestamp = None,
|
||||
end_time: pd.Timestamp = None,
|
||||
) -> Optional[float]:
|
||||
@@ -723,8 +765,8 @@ class Exchange:
|
||||
def round_amount_by_trade_unit(
|
||||
self,
|
||||
deal_amount: float,
|
||||
factor: float = None,
|
||||
stock_id: str = None,
|
||||
factor: float | None = None,
|
||||
stock_id: str | None = None,
|
||||
start_time: pd.Timestamp = None,
|
||||
end_time: pd.Timestamp = None,
|
||||
) -> float:
|
||||
@@ -764,7 +806,7 @@ class Exchange:
|
||||
|
||||
vol_limit_num: List[float] = []
|
||||
for limit in vol_limit:
|
||||
assert isinstance(limit, tuple)
|
||||
assert isinstance(limit, tuple) or (isinstance(limit, list) and len(limit) == 2)
|
||||
if limit[0] == "current":
|
||||
limit_value = self.quote.get_data(
|
||||
order.stock_id,
|
||||
|
||||
@@ -31,8 +31,8 @@ class BaseExecutor:
|
||||
generate_portfolio_metrics: bool = False,
|
||||
verbose: bool = False,
|
||||
track_data: bool = False,
|
||||
trade_exchange: Exchange = None,
|
||||
common_infra: CommonInfrastructure = None,
|
||||
trade_exchange: Exchange | None = None,
|
||||
common_infra: CommonInfrastructure | None = None,
|
||||
settle_type: str = BasePosition.ST_NO,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
@@ -114,7 +114,7 @@ class BaseExecutor:
|
||||
self.track_data = track_data
|
||||
self._trade_exchange = trade_exchange
|
||||
self.level_infra = LevelInfrastructure()
|
||||
self.level_infra.reset_infra(common_infra=common_infra)
|
||||
self.level_infra.reset_infra(common_infra=common_infra, executor=self)
|
||||
self._settle_type = settle_type
|
||||
self.reset(start_time=start_time, end_time=end_time, common_infra=common_infra)
|
||||
if common_infra is None:
|
||||
@@ -134,6 +134,8 @@ class BaseExecutor:
|
||||
else:
|
||||
self.common_infra.update(common_infra)
|
||||
|
||||
self.level_infra.reset_infra(common_infra=self.common_infra)
|
||||
|
||||
if common_infra.has("trade_account"):
|
||||
# NOTE: there is a trick in the code.
|
||||
# shallow copy is used instead of deepcopy.
|
||||
@@ -159,7 +161,7 @@ class BaseExecutor:
|
||||
"""
|
||||
return self.level_infra.get("trade_calendar")
|
||||
|
||||
def reset(self, common_infra: CommonInfrastructure = None, **kwargs: Any) -> None:
|
||||
def reset(self, common_infra: CommonInfrastructure | None = None, **kwargs: Any) -> None:
|
||||
"""
|
||||
- reset `start_time` and `end_time`, used in trade calendar
|
||||
- reset `common_infra`, used to reset `trade_account`, `trade_exchange`, .etc
|
||||
@@ -225,7 +227,7 @@ class BaseExecutor:
|
||||
def collect_data(
|
||||
self,
|
||||
trade_decision: BaseTradeDecision,
|
||||
return_value: dict = None,
|
||||
return_value: dict | None = None,
|
||||
level: int = 0,
|
||||
) -> Generator[Any, Any, List[object]]:
|
||||
"""Generator for collecting the trade decision data for rl training
|
||||
@@ -256,6 +258,7 @@ class BaseExecutor:
|
||||
object
|
||||
trade decision
|
||||
"""
|
||||
|
||||
if self.track_data:
|
||||
yield trade_decision
|
||||
|
||||
@@ -296,6 +299,7 @@ class BaseExecutor:
|
||||
|
||||
if return_value is not None:
|
||||
return_value.update({"execute_result": res})
|
||||
|
||||
return res
|
||||
|
||||
def get_all_executors(self) -> List[BaseExecutor]:
|
||||
@@ -323,7 +327,7 @@ class NestedExecutor(BaseExecutor):
|
||||
track_data: bool = False,
|
||||
skip_empty_decision: bool = True,
|
||||
align_range_limit: bool = True,
|
||||
common_infra: CommonInfrastructure = None,
|
||||
common_infra: CommonInfrastructure | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
@@ -396,7 +400,7 @@ class NestedExecutor(BaseExecutor):
|
||||
trade_decision = updated_trade_decision
|
||||
# NEW UPDATE
|
||||
# create a hook for inner strategy to update outer decision
|
||||
self.inner_strategy.alter_outer_trade_decision(trade_decision)
|
||||
trade_decision = self.inner_strategy.alter_outer_trade_decision(trade_decision)
|
||||
return trade_decision
|
||||
|
||||
def _collect_data(
|
||||
@@ -473,6 +477,9 @@ class NestedExecutor(BaseExecutor):
|
||||
# do nothing and just step forward
|
||||
sub_cal.step()
|
||||
|
||||
# Let inner strategy know that the outer level execution is done.
|
||||
self.inner_strategy.post_upper_level_exe_step()
|
||||
|
||||
return execute_result, {"inner_order_indicators": inner_order_indicators, "decision_list": decision_list}
|
||||
|
||||
def post_inner_exe_step(self, inner_exe_res: List[object]) -> None:
|
||||
@@ -527,7 +534,7 @@ class SimulatorExecutor(BaseExecutor):
|
||||
generate_portfolio_metrics: bool = False,
|
||||
verbose: bool = False,
|
||||
track_data: bool = False,
|
||||
common_infra: CommonInfrastructure = None,
|
||||
common_infra: CommonInfrastructure | None = None,
|
||||
trade_type: str = TT_SERIAL,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
@@ -580,20 +587,18 @@ class SimulatorExecutor(BaseExecutor):
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
return order_it
|
||||
|
||||
def _update_dealt_order_amount(self, order: Order) -> None:
|
||||
"""update date and dealt order amount in the day."""
|
||||
|
||||
now_deal_day = self.trade_calendar.get_step_time()[0].floor(freq="D")
|
||||
if self.deal_day is None or now_deal_day > self.deal_day:
|
||||
self.dealt_order_amount = defaultdict(float)
|
||||
self.deal_day = now_deal_day
|
||||
self.dealt_order_amount[order.stock_id] += order.deal_amount
|
||||
|
||||
def _collect_data(self, trade_decision: BaseTradeDecision, level: int = 0) -> Tuple[List[object], dict]:
|
||||
trade_start_time, _ = self.trade_calendar.get_step_time()
|
||||
execute_result: list = []
|
||||
|
||||
for order in self._get_order_iterator(trade_decision):
|
||||
# Each time we move into a new date, clear `self.dealt_order_amount` since it only maintains intraday
|
||||
# information.
|
||||
now_deal_day = self.trade_calendar.get_step_time()[0].floor(freq="D")
|
||||
if self.deal_day is None or now_deal_day > self.deal_day:
|
||||
self.dealt_order_amount = defaultdict(float)
|
||||
self.deal_day = now_deal_day
|
||||
|
||||
# execute the order.
|
||||
# NOTE: The trade_account will be changed in this function
|
||||
trade_val, trade_cost, trade_price = self.trade_exchange.deal_order(
|
||||
@@ -602,7 +607,9 @@ class SimulatorExecutor(BaseExecutor):
|
||||
dealt_order_amount=self.dealt_order_amount,
|
||||
)
|
||||
execute_result.append((order, trade_val, trade_cost, trade_price))
|
||||
self._update_dealt_order_amount(order)
|
||||
|
||||
self.dealt_order_amount[order.stock_id] += order.deal_amount
|
||||
|
||||
if self.verbose:
|
||||
print(
|
||||
"[I {:%Y-%m-%d %H:%M:%S}]: {} {}, price {:.2f}, amount {}, deal_amount {}, factor {}, "
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import timedelta
|
||||
from typing import Any, Dict, List, Union
|
||||
@@ -320,7 +321,7 @@ class Position(BasePosition):
|
||||
self.position[stock]["price"] = price_dict[stock]
|
||||
self.position["now_account_value"] = self.calculate_value()
|
||||
|
||||
def _init_stock(self, stock_id: str, amount: float, price: float = None) -> None:
|
||||
def _init_stock(self, stock_id: str, amount: float, price: float | None = None) -> None:
|
||||
"""
|
||||
initialization the stock in current position
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pathlib
|
||||
from collections import OrderedDict
|
||||
@@ -86,7 +87,7 @@ class PortfolioMetrics:
|
||||
self.benches: dict = OrderedDict()
|
||||
self.latest_pm_time: Optional[pd.TimeStamp] = None
|
||||
|
||||
def init_bench(self, freq: str = None, benchmark_config: dict = None) -> None:
|
||||
def init_bench(self, freq: str | None = None, benchmark_config: dict | None = None) -> None:
|
||||
if freq is not None:
|
||||
self.freq = freq
|
||||
self.benchmark_config = benchmark_config
|
||||
@@ -149,15 +150,15 @@ class PortfolioMetrics:
|
||||
self,
|
||||
trade_start_time: Union[str, pd.Timestamp] = None,
|
||||
trade_end_time: Union[str, pd.Timestamp] = None,
|
||||
account_value: float = None,
|
||||
cash: float = None,
|
||||
return_rate: float = None,
|
||||
total_turnover: float = None,
|
||||
turnover_rate: float = None,
|
||||
total_cost: float = None,
|
||||
cost_rate: float = None,
|
||||
stock_value: float = None,
|
||||
bench_value: float = None,
|
||||
account_value: float | None = None,
|
||||
cash: float | None = None,
|
||||
return_rate: float | None = None,
|
||||
total_turnover: float | None = None,
|
||||
turnover_rate: float | None = None,
|
||||
total_cost: float | None = None,
|
||||
cost_rate: float | None = None,
|
||||
stock_value: float | None = None,
|
||||
bench_value: float | None = None,
|
||||
) -> None:
|
||||
# check data
|
||||
if None in [
|
||||
|
||||
@@ -3,9 +3,8 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import bisect
|
||||
from abc import abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Set, Tuple, Union
|
||||
from typing import Any, Set, Tuple, TYPE_CHECKING, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -32,7 +31,7 @@ class TradeCalendarManager:
|
||||
freq: str,
|
||||
start_time: Union[str, pd.Timestamp] = None,
|
||||
end_time: Union[str, pd.Timestamp] = None,
|
||||
level_infra: LevelInfrastructure = None,
|
||||
level_infra: LevelInfrastructure | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Parameters
|
||||
@@ -100,7 +99,7 @@ class TradeCalendarManager:
|
||||
def get_trade_step(self) -> int:
|
||||
return self.trade_step
|
||||
|
||||
def get_step_time(self, trade_step: int = None, shift: int = 0) -> Tuple[pd.Timestamp, pd.Timestamp]:
|
||||
def get_step_time(self, trade_step: int | None = None, shift: int = 0) -> Tuple[pd.Timestamp, pd.Timestamp]:
|
||||
"""
|
||||
Get the left and right endpoints of the trade_step'th trading interval
|
||||
|
||||
@@ -184,8 +183,8 @@ class TradeCalendarManager:
|
||||
Tuple[int, int]:
|
||||
the index of the range. **the left and right are closed**
|
||||
"""
|
||||
left = bisect.bisect_right(list(self._calendar), start_time) - 1
|
||||
right = bisect.bisect_right(list(self._calendar), end_time) - 1
|
||||
left = int(np.searchsorted(self._calendar, start_time, side="right") - 1)
|
||||
right = int(np.searchsorted(self._calendar, end_time, side="right") - 1)
|
||||
left -= self.start_index
|
||||
right -= self.start_index
|
||||
|
||||
@@ -248,7 +247,7 @@ class LevelInfrastructure(BaseInfrastructure):
|
||||
sub_level_infra:
|
||||
- **NOTE**: this will only work after _init_sub_trading !!!
|
||||
"""
|
||||
return {"trade_calendar", "sub_level_infra", "common_infra"}
|
||||
return {"trade_calendar", "sub_level_infra", "common_infra", "executor"}
|
||||
|
||||
def reset_cal(
|
||||
self,
|
||||
|
||||
@@ -75,7 +75,8 @@ class Config:
|
||||
def set_conf_from_C(self, config_c):
|
||||
self.update(**config_c.__dict__["_config"])
|
||||
|
||||
def register_from_C(self, config, skip_register=True):
|
||||
@staticmethod
|
||||
def register_from_C(config, skip_register=True):
|
||||
from .utils import set_log_with_config # pylint: disable=C0415
|
||||
|
||||
if C.registered and skip_register:
|
||||
@@ -146,6 +147,7 @@ _default_config = {
|
||||
"redis_host": "127.0.0.1",
|
||||
"redis_port": 6379,
|
||||
"redis_task_db": 1,
|
||||
"redis_password": None,
|
||||
# This value can be reset via qlib.init
|
||||
"logging_level": logging.INFO,
|
||||
# Global configuration of qlib log
|
||||
@@ -172,6 +174,9 @@ _default_config = {
|
||||
}
|
||||
},
|
||||
"loggers": {"qlib": {"level": logging.DEBUG, "handlers": ["console"]}},
|
||||
# To let qlib work with other packages, we shouldn't disable existing loggers.
|
||||
# Note that this param is default to True according to the documentation of logging.
|
||||
"disable_existing_loggers": False,
|
||||
},
|
||||
# Default config for experiment manager
|
||||
"exp_manager": {
|
||||
@@ -199,7 +204,7 @@ _default_config = {
|
||||
"task_url": "mongodb://localhost:27017/",
|
||||
"task_db_name": "default_task_db",
|
||||
},
|
||||
# Shift minute for highfreq minite data, used in backtest
|
||||
# Shift minute for highfreq minute data, used in backtest
|
||||
# if min_data_shift == 0, use default market time [9:30, 11:29, 1:00, 2:59]
|
||||
# if min_data_shift != 0, use shifted market time [9:30, 11:29, 1:00, 2:59] - shift*minute
|
||||
"min_data_shift": 0,
|
||||
@@ -408,8 +413,7 @@ class QlibConfig(Config):
|
||||
if _logging_config:
|
||||
set_log_with_config(_logging_config)
|
||||
|
||||
# FIXME: this logger ignored the level in config
|
||||
logger = get_module_logger("Initialization", level=logging.INFO)
|
||||
logger = get_module_logger("Initialization", kwargs.get("logging_level", self.logging_level))
|
||||
logger.info(f"default_conf: {default_conf}.")
|
||||
|
||||
self.set_mode(default_conf)
|
||||
|
||||
@@ -2,6 +2,11 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
# REGION CONST
|
||||
from typing import TypeVar
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
REG_CN = "cn"
|
||||
REG_US = "us"
|
||||
REG_TW = "tw"
|
||||
@@ -10,4 +15,8 @@ REG_TW = "tw"
|
||||
EPS = 1e-12
|
||||
|
||||
# Infinity in integer
|
||||
INF = 10**18
|
||||
INF = int(1e18)
|
||||
ONE_DAY = pd.Timedelta("1day")
|
||||
ONE_MIN = pd.Timedelta("1min")
|
||||
EPS_T = pd.Timedelta("1s") # use 1 second to exclude the right interval point
|
||||
float_or_ndarray = TypeVar("float_or_ndarray", float, np.ndarray)
|
||||
|
||||
@@ -56,8 +56,8 @@ class Alpha360(DataHandlerLP):
|
||||
fit_start_time=None,
|
||||
fit_end_time=None,
|
||||
filter_pipe=None,
|
||||
inst_processor=None,
|
||||
**kwargs,
|
||||
inst_processors=None,
|
||||
**kwargs
|
||||
):
|
||||
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
|
||||
learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)
|
||||
@@ -67,11 +67,11 @@ class Alpha360(DataHandlerLP):
|
||||
"kwargs": {
|
||||
"config": {
|
||||
"feature": self.get_feature_config(),
|
||||
"label": kwargs.get("label", self.get_label_config()),
|
||||
"label": kwargs.pop("label", self.get_label_config()),
|
||||
},
|
||||
"filter_pipe": filter_pipe,
|
||||
"freq": freq,
|
||||
"inst_processor": inst_processor,
|
||||
"inst_processors": inst_processors,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -82,12 +82,14 @@ class Alpha360(DataHandlerLP):
|
||||
data_loader=data_loader,
|
||||
learn_processors=learn_processors,
|
||||
infer_processors=infer_processors,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
def get_label_config(self):
|
||||
return (["Ref($close, -2)/Ref($close, -1) - 1"], ["LABEL0"])
|
||||
return ["Ref($close, -2)/Ref($close, -1) - 1"], ["LABEL0"]
|
||||
|
||||
def get_feature_config(self):
|
||||
@staticmethod
|
||||
def get_feature_config():
|
||||
# NOTE:
|
||||
# Alpha360 tries to provide a dataset with original price data
|
||||
# the original price data includes the prices and volume in the last 60 days.
|
||||
@@ -99,33 +101,33 @@ class Alpha360(DataHandlerLP):
|
||||
names = []
|
||||
|
||||
for i in range(59, 0, -1):
|
||||
fields += ["Ref($close, %d)/$close" % (i)]
|
||||
names += ["CLOSE%d" % (i)]
|
||||
fields += ["Ref($close, %d)/$close" % i]
|
||||
names += ["CLOSE%d" % i]
|
||||
fields += ["$close/$close"]
|
||||
names += ["CLOSE0"]
|
||||
for i in range(59, 0, -1):
|
||||
fields += ["Ref($open, %d)/$close" % (i)]
|
||||
names += ["OPEN%d" % (i)]
|
||||
fields += ["Ref($open, %d)/$close" % i]
|
||||
names += ["OPEN%d" % i]
|
||||
fields += ["$open/$close"]
|
||||
names += ["OPEN0"]
|
||||
for i in range(59, 0, -1):
|
||||
fields += ["Ref($high, %d)/$close" % (i)]
|
||||
names += ["HIGH%d" % (i)]
|
||||
fields += ["Ref($high, %d)/$close" % i]
|
||||
names += ["HIGH%d" % i]
|
||||
fields += ["$high/$close"]
|
||||
names += ["HIGH0"]
|
||||
for i in range(59, 0, -1):
|
||||
fields += ["Ref($low, %d)/$close" % (i)]
|
||||
names += ["LOW%d" % (i)]
|
||||
fields += ["Ref($low, %d)/$close" % i]
|
||||
names += ["LOW%d" % i]
|
||||
fields += ["$low/$close"]
|
||||
names += ["LOW0"]
|
||||
for i in range(59, 0, -1):
|
||||
fields += ["Ref($vwap, %d)/$close" % (i)]
|
||||
names += ["VWAP%d" % (i)]
|
||||
fields += ["Ref($vwap, %d)/$close" % i]
|
||||
names += ["VWAP%d" % i]
|
||||
fields += ["$vwap/$close"]
|
||||
names += ["VWAP0"]
|
||||
for i in range(59, 0, -1):
|
||||
fields += ["Ref($volume, %d)/($volume+1e-12)" % (i)]
|
||||
names += ["VOLUME%d" % (i)]
|
||||
fields += ["Ref($volume, %d)/($volume+1e-12)" % i]
|
||||
names += ["VOLUME%d" % i]
|
||||
fields += ["$volume/($volume+1e-12)"]
|
||||
names += ["VOLUME0"]
|
||||
|
||||
@@ -134,7 +136,7 @@ class Alpha360(DataHandlerLP):
|
||||
|
||||
class Alpha360vwap(Alpha360):
|
||||
def get_label_config(self):
|
||||
return (["Ref($vwap, -2)/Ref($vwap, -1) - 1"], ["LABEL0"])
|
||||
return ["Ref($vwap, -2)/Ref($vwap, -1) - 1"], ["LABEL0"]
|
||||
|
||||
|
||||
class Alpha158(DataHandlerLP):
|
||||
@@ -150,8 +152,8 @@ class Alpha158(DataHandlerLP):
|
||||
fit_end_time=None,
|
||||
process_type=DataHandlerLP.PTYPE_A,
|
||||
filter_pipe=None,
|
||||
inst_processor=None,
|
||||
**kwargs,
|
||||
inst_processors=None,
|
||||
**kwargs
|
||||
):
|
||||
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
|
||||
learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)
|
||||
@@ -161,11 +163,11 @@ class Alpha158(DataHandlerLP):
|
||||
"kwargs": {
|
||||
"config": {
|
||||
"feature": self.get_feature_config(),
|
||||
"label": kwargs.get("label", self.get_label_config()),
|
||||
"label": kwargs.pop("label", self.get_label_config()),
|
||||
},
|
||||
"filter_pipe": filter_pipe,
|
||||
"freq": freq,
|
||||
"inst_processor": inst_processor,
|
||||
"inst_processors": inst_processors,
|
||||
},
|
||||
}
|
||||
super().__init__(
|
||||
@@ -176,6 +178,7 @@ class Alpha158(DataHandlerLP):
|
||||
infer_processors=infer_processors,
|
||||
learn_processors=learn_processors,
|
||||
process_type=process_type,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
def get_feature_config(self):
|
||||
@@ -190,7 +193,7 @@ class Alpha158(DataHandlerLP):
|
||||
return self.parse_config_to_fields(conf)
|
||||
|
||||
def get_label_config(self):
|
||||
return (["Ref($close, -2)/Ref($close, -1) - 1"], ["LABEL0"])
|
||||
return ["Ref($close, -2)/Ref($close, -1) - 1"], ["LABEL0"]
|
||||
|
||||
@staticmethod
|
||||
def parse_config_to_fields(config):
|
||||
@@ -426,4 +429,4 @@ class Alpha158(DataHandlerLP):
|
||||
|
||||
class Alpha158vwap(Alpha158):
|
||||
def get_label_config(self):
|
||||
return (["Ref($vwap, -2)/Ref($vwap, -1) - 1"], ["LABEL0"])
|
||||
return ["Ref($vwap, -2)/Ref($vwap, -1) - 1"], ["LABEL0"]
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
from qlib.data.dataset.handler import DataHandler, DataHandlerLP
|
||||
|
||||
from .handler import check_transform_proc
|
||||
|
||||
EPSILON = 1e-4
|
||||
|
||||
|
||||
@@ -15,20 +17,9 @@ class HighFreqHandler(DataHandlerLP):
|
||||
fit_end_time=None,
|
||||
drop_raw=True,
|
||||
):
|
||||
def check_transform_proc(proc_l):
|
||||
new_l = []
|
||||
for p in proc_l:
|
||||
p["kwargs"].update(
|
||||
{
|
||||
"fit_start_time": fit_start_time,
|
||||
"fit_end_time": fit_end_time,
|
||||
}
|
||||
)
|
||||
new_l.append(p)
|
||||
return new_l
|
||||
|
||||
infer_processors = check_transform_proc(infer_processors)
|
||||
learn_processors = check_transform_proc(learn_processors)
|
||||
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
|
||||
learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)
|
||||
|
||||
data_loader = {
|
||||
"class": "QlibDataLoader",
|
||||
@@ -53,7 +44,7 @@ class HighFreqHandler(DataHandlerLP):
|
||||
names = []
|
||||
|
||||
template_if = "If(IsNull({1}), {0}, {1})"
|
||||
template_paused = "Select(Gt($hx_paused_num, 1.001), {0})"
|
||||
template_paused = "Select(Gt($paused_num, 1.001), {0})"
|
||||
|
||||
def get_normalized_price_feature(price_field, shift=0):
|
||||
# norm with the close price of 237th minute of yesterday.
|
||||
@@ -110,6 +101,102 @@ class HighFreqHandler(DataHandlerLP):
|
||||
return fields, names
|
||||
|
||||
|
||||
class HighFreqGeneralHandler(DataHandlerLP):
|
||||
def __init__(
|
||||
self,
|
||||
instruments="csi300",
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
infer_processors=[],
|
||||
learn_processors=[],
|
||||
fit_start_time=None,
|
||||
fit_end_time=None,
|
||||
drop_raw=True,
|
||||
day_length=240,
|
||||
freq="1min",
|
||||
columns=["$open", "$high", "$low", "$close", "$vwap"],
|
||||
inst_processors=None,
|
||||
):
|
||||
self.day_length = day_length
|
||||
self.columns = columns
|
||||
|
||||
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
|
||||
learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)
|
||||
|
||||
data_loader = {
|
||||
"class": "QlibDataLoader",
|
||||
"kwargs": {
|
||||
"config": self.get_feature_config(),
|
||||
"swap_level": False,
|
||||
"freq": freq,
|
||||
"inst_processors": inst_processors,
|
||||
},
|
||||
}
|
||||
super().__init__(
|
||||
instruments=instruments,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
data_loader=data_loader,
|
||||
infer_processors=infer_processors,
|
||||
learn_processors=learn_processors,
|
||||
drop_raw=drop_raw,
|
||||
)
|
||||
|
||||
def get_feature_config(self):
|
||||
fields = []
|
||||
names = []
|
||||
|
||||
template_if = "If(IsNull({1}), {0}, {1})"
|
||||
template_paused = f"Cut({{0}}, {self.day_length * 2}, None)"
|
||||
|
||||
def get_normalized_price_feature(price_field, shift=0):
|
||||
# norm with the close price of 237th minute of yesterday.
|
||||
if shift == 0:
|
||||
template_norm = f"{{0}}/DayLast(Ref({{1}}, {self.day_length * 2}))"
|
||||
else:
|
||||
template_norm = f"Ref({{0}}, " + str(shift) + f")/DayLast(Ref({{1}}, {self.day_length}))"
|
||||
|
||||
template_fillnan = "FFillNan({0})"
|
||||
# calculate -> ffill -> remove paused
|
||||
feature_ops = template_paused.format(
|
||||
template_fillnan.format(
|
||||
template_norm.format(template_if.format("$close", price_field), template_fillnan.format("$close"))
|
||||
)
|
||||
)
|
||||
return feature_ops
|
||||
|
||||
for column_name in self.columns:
|
||||
fields.append(get_normalized_price_feature(column_name, 0))
|
||||
names.append(column_name)
|
||||
|
||||
for column_name in self.columns:
|
||||
fields.append(get_normalized_price_feature(column_name, self.day_length))
|
||||
names.append(column_name + "_1")
|
||||
|
||||
# calculate and fill nan with 0
|
||||
fields += [
|
||||
template_paused.format(
|
||||
"If(IsNull({0}), 0, {0})".format(
|
||||
f"{{0}}/Ref(DayLast(Mean({{0}}, {self.day_length * 30})), {self.day_length})".format("$volume")
|
||||
)
|
||||
)
|
||||
]
|
||||
names += ["$volume"]
|
||||
|
||||
fields += [
|
||||
template_paused.format(
|
||||
"If(IsNull({0}), 0, {0})".format(
|
||||
f"Ref({{0}}, {self.day_length})/Ref(DayLast(Mean({{0}}, {self.day_length * 30})), {self.day_length})".format(
|
||||
"$volume"
|
||||
)
|
||||
)
|
||||
)
|
||||
]
|
||||
names += ["$volume_1"]
|
||||
|
||||
return fields, names
|
||||
|
||||
|
||||
class HighFreqBacktestHandler(DataHandler):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -163,6 +250,61 @@ class HighFreqBacktestHandler(DataHandler):
|
||||
return fields, names
|
||||
|
||||
|
||||
class HighFreqGeneralBacktestHandler(DataHandler):
|
||||
def __init__(
|
||||
self,
|
||||
instruments="csi300",
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
day_length=240,
|
||||
freq="1min",
|
||||
columns=["$close", "$vwap", "$volume"],
|
||||
inst_processors=None,
|
||||
):
|
||||
self.day_length = day_length
|
||||
self.columns = set(columns)
|
||||
data_loader = {
|
||||
"class": "QlibDataLoader",
|
||||
"kwargs": {
|
||||
"config": self.get_feature_config(),
|
||||
"swap_level": False,
|
||||
"freq": freq,
|
||||
"inst_processors": inst_processors,
|
||||
},
|
||||
}
|
||||
super().__init__(
|
||||
instruments=instruments,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
data_loader=data_loader,
|
||||
)
|
||||
|
||||
def get_feature_config(self):
|
||||
fields = []
|
||||
names = []
|
||||
|
||||
if "$close" in self.columns:
|
||||
template_paused = f"Cut({{0}}, {self.day_length * 2}, None)"
|
||||
template_fillnan = "FFillNan({0})"
|
||||
template_if = "If(IsNull({1}), {0}, {1})"
|
||||
fields += [
|
||||
template_paused.format(template_fillnan.format("$close")),
|
||||
]
|
||||
names += ["$close0"]
|
||||
|
||||
if "$vwap" in self.columns:
|
||||
fields += [
|
||||
template_paused.format(template_if.format(template_fillnan.format("$close"), "$vwap")),
|
||||
]
|
||||
names += ["$vwap0"]
|
||||
|
||||
if "$volume" in self.columns:
|
||||
fields += [template_paused.format("If(IsNull({0}), 0, {0})".format("$volume"))]
|
||||
names += ["$volume0"]
|
||||
|
||||
return fields, names
|
||||
|
||||
|
||||
class HighFreqOrderHandler(DataHandlerLP):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -173,22 +315,12 @@ class HighFreqOrderHandler(DataHandlerLP):
|
||||
learn_processors=[],
|
||||
fit_start_time=None,
|
||||
fit_end_time=None,
|
||||
inst_processors=None,
|
||||
drop_raw=True,
|
||||
):
|
||||
def check_transform_proc(proc_l):
|
||||
new_l = []
|
||||
for p in proc_l:
|
||||
p["kwargs"].update(
|
||||
{
|
||||
"fit_start_time": fit_start_time,
|
||||
"fit_end_time": fit_end_time,
|
||||
}
|
||||
)
|
||||
new_l.append(p)
|
||||
return new_l
|
||||
|
||||
infer_processors = check_transform_proc(infer_processors)
|
||||
learn_processors = check_transform_proc(learn_processors)
|
||||
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
|
||||
learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)
|
||||
|
||||
data_loader = {
|
||||
"class": "QlibDataLoader",
|
||||
@@ -196,6 +328,7 @@ class HighFreqOrderHandler(DataHandlerLP):
|
||||
"config": self.get_feature_config(),
|
||||
"swap_level": False,
|
||||
"freq": "1min",
|
||||
"inst_processors": inst_processors,
|
||||
},
|
||||
}
|
||||
super().__init__(
|
||||
@@ -355,8 +488,7 @@ class HighFreqBacktestOrderHandler(DataHandler):
|
||||
names = []
|
||||
|
||||
template_if = "If(IsNull({1}), {0}, {1})"
|
||||
template_paused = "Select(Gt($hx_paused_num, 1.001), {0})"
|
||||
# template_paused = "{0}"
|
||||
template_paused = "Select(Gt($paused_num, 1.001), {0})"
|
||||
template_fillnan = "FFillNan({0})"
|
||||
fields += [
|
||||
template_fillnan.format(template_paused.format("$close")),
|
||||
|
||||
@@ -4,6 +4,7 @@ import datetime
|
||||
from typing import Optional
|
||||
|
||||
import qlib
|
||||
from qlib import get_module_logger
|
||||
from qlib.data import D
|
||||
from qlib.config import REG_CN
|
||||
from qlib.utils import init_instance_by_config
|
||||
@@ -12,7 +13,6 @@ from qlib.data.data import Cal
|
||||
from qlib.contrib.ops.high_freq import get_calendar_day, DayLast, FFillNan, BFillNan, Date, Select, IsNull, IsInf, Cut
|
||||
import pickle as pkl
|
||||
from joblib import Parallel, delayed
|
||||
from utilsd.logging import print_log
|
||||
|
||||
|
||||
class HighFreqProvider:
|
||||
@@ -28,6 +28,7 @@ class HighFreqProvider:
|
||||
feature_conf: dict,
|
||||
label_conf: Optional[dict] = None,
|
||||
backtest_conf: dict = None,
|
||||
freq: str = "1min",
|
||||
**kwargs,
|
||||
) -> None:
|
||||
self.start_time = start_time
|
||||
@@ -41,6 +42,8 @@ class HighFreqProvider:
|
||||
self.label_conf = label_conf
|
||||
self.backtest_conf = backtest_conf
|
||||
self.qlib_conf = qlib_conf
|
||||
self.logger = get_module_logger("HighFreqProvider")
|
||||
self.freq = freq
|
||||
|
||||
def get_pre_datasets(self):
|
||||
"""Generate the training, validation and test datasets for prediction
|
||||
@@ -115,8 +118,8 @@ class HighFreqProvider:
|
||||
# This code used the copy-on-write feature of Linux
|
||||
# to avoid calculating the calendar multiple times in the subprocess.
|
||||
# This code may accelerate, but may be not useful on Windows and Mac Os
|
||||
Cal.calendar(freq="1min")
|
||||
get_calendar_day(freq="1min")
|
||||
Cal.calendar(freq=self.freq)
|
||||
get_calendar_day(freq=self.freq)
|
||||
|
||||
def _gen_dataframe(self, config, datasets=["train", "valid", "test"]):
|
||||
try:
|
||||
@@ -125,7 +128,7 @@ class HighFreqProvider:
|
||||
raise ValueError("Must specify the path to save the dataset.") from e
|
||||
if os.path.isfile(path):
|
||||
start = time.time()
|
||||
print_log("Dataset exists, load from disk.", __name__)
|
||||
self.logger.info(f"[{__name__}]Dataset exists, load from disk.")
|
||||
|
||||
# res = dataset.prepare(['train', 'valid', 'test'])
|
||||
with open(path, "rb") as f:
|
||||
@@ -134,11 +137,11 @@ class HighFreqProvider:
|
||||
res = [data[i] for i in datasets]
|
||||
else:
|
||||
res = data.prepare(datasets)
|
||||
print_log(f"Data loaded, time cost: {time.time() - start:.2f}", __name__)
|
||||
self.logger.info(f"[{__name__}]Data loaded, time cost: {time.time() - start:.2f}")
|
||||
else:
|
||||
if not os.path.exists(os.path.dirname(path)):
|
||||
os.makedirs(os.path.dirname(path))
|
||||
print_log("Generating dataset", __name__)
|
||||
self.logger.info(f"[{__name__}]Generating dataset")
|
||||
start_time = time.time()
|
||||
self._prepare_calender_cache()
|
||||
dataset = init_instance_by_config(config)
|
||||
@@ -157,7 +160,7 @@ class HighFreqProvider:
|
||||
with open(path[:-4] + "test.pkl", "wb") as f:
|
||||
pkl.dump(testset, f)
|
||||
res = [data[i] for i in datasets]
|
||||
print_log(f"Data generated, time cost: {(time.time() - start_time):.2f}", __name__)
|
||||
self.logger.info(f"[{__name__}]Data generated, time cost: {(time.time() - start_time):.2f}")
|
||||
return res
|
||||
|
||||
def _gen_data(self, config, datasets=["train", "valid", "test"]):
|
||||
@@ -167,7 +170,7 @@ class HighFreqProvider:
|
||||
raise ValueError("Must specify the path to save the dataset.") from e
|
||||
if os.path.isfile(path):
|
||||
start = time.time()
|
||||
print_log("Dataset exists, load from disk.", __name__)
|
||||
self.logger.info(f"[{__name__}]Dataset exists, load from disk.")
|
||||
|
||||
# res = dataset.prepare(['train', 'valid', 'test'])
|
||||
with open(path, "rb") as f:
|
||||
@@ -176,18 +179,18 @@ class HighFreqProvider:
|
||||
res = [data[i] for i in datasets]
|
||||
else:
|
||||
res = data.prepare(datasets)
|
||||
print_log(f"Data loaded, time cost: {time.time() - start:.2f}", __name__)
|
||||
self.logger.info(f"[{__name__}]Data loaded, time cost: {time.time() - start:.2f}")
|
||||
else:
|
||||
if not os.path.exists(os.path.dirname(path)):
|
||||
os.makedirs(os.path.dirname(path))
|
||||
print_log("Generating dataset", __name__)
|
||||
self.logger.info(f"[{__name__}]Generating dataset")
|
||||
start_time = time.time()
|
||||
self._prepare_calender_cache()
|
||||
dataset = init_instance_by_config(config)
|
||||
dataset.config(dump_all=True, recursive=True)
|
||||
dataset.to_pickle(path)
|
||||
res = dataset.prepare(datasets)
|
||||
print_log(f"Data generated, time cost: {(time.time() - start_time):.2f}", __name__)
|
||||
self.logger.info(f"[{__name__}]Data generated, time cost: {(time.time() - start_time):.2f}")
|
||||
return res
|
||||
|
||||
def _gen_dataset(self, config):
|
||||
@@ -197,21 +200,21 @@ class HighFreqProvider:
|
||||
raise ValueError("Must specify the path to save the dataset.") from e
|
||||
if os.path.isfile(path):
|
||||
start = time.time()
|
||||
print_log("Dataset exists, load from disk.", __name__)
|
||||
self.logger.info(f"[{__name__}]Dataset exists, load from disk.")
|
||||
|
||||
with open(path, "rb") as f:
|
||||
dataset = pkl.load(f)
|
||||
print_log(f"Data loaded, time cost: {time.time() - start:.2f}", __name__)
|
||||
self.logger.info(f"[{__name__}]Data loaded, time cost: {time.time() - start:.2f}")
|
||||
else:
|
||||
start = time.time()
|
||||
if not os.path.exists(os.path.dirname(path)):
|
||||
os.makedirs(os.path.dirname(path))
|
||||
print_log("Generating dataset", __name__)
|
||||
self.logger.info(f"[{__name__}]Generating dataset")
|
||||
self._prepare_calender_cache()
|
||||
dataset = init_instance_by_config(config)
|
||||
print_log(f"Dataset init, time cost: {time.time() - start:.2f}", __name__)
|
||||
self.logger.info(f"[{__name__}]Dataset init, time cost: {time.time() - start:.2f}")
|
||||
dataset.prepare(["train", "valid", "test"])
|
||||
print_log(f"Dataset prepared, time cost: {time.time() - start:.2f}", __name__)
|
||||
self.logger.info(f"[{__name__}]Dataset prepared, time cost: {time.time() - start:.2f}")
|
||||
dataset.config(dump_all=True, recursive=True)
|
||||
dataset.to_pickle(path)
|
||||
return dataset
|
||||
@@ -224,22 +227,22 @@ class HighFreqProvider:
|
||||
|
||||
if os.path.isfile(path + "tmp_dataset.pkl"):
|
||||
start = time.time()
|
||||
print_log("Dataset exists, load from disk.", __name__)
|
||||
self.logger.info(f"[{__name__}]Dataset exists, load from disk.")
|
||||
else:
|
||||
start = time.time()
|
||||
if not os.path.exists(os.path.dirname(path)):
|
||||
os.makedirs(os.path.dirname(path))
|
||||
print_log("Generating dataset", __name__)
|
||||
self.logger.info(f"[{__name__}]Generating dataset")
|
||||
self._prepare_calender_cache()
|
||||
dataset = init_instance_by_config(config)
|
||||
print_log(f"Dataset init, time cost: {time.time() - start:.2f}", __name__)
|
||||
self.logger.info(f"[{__name__}]Dataset init, time cost: {time.time() - start:.2f}")
|
||||
dataset.config(dump_all=False, recursive=True)
|
||||
dataset.to_pickle(path + "tmp_dataset.pkl")
|
||||
|
||||
with open(path + "tmp_dataset.pkl", "rb") as f:
|
||||
new_dataset = pkl.load(f)
|
||||
|
||||
time_list = D.calendar(start_time=self.start_time, end_time=self.end_time, freq="1min")[::240]
|
||||
time_list = D.calendar(start_time=self.start_time, end_time=self.end_time, freq=self.freq)[::240]
|
||||
|
||||
def generate_dataset(times):
|
||||
if os.path.isfile(path + times.strftime("%Y-%m-%d") + ".pkl"):
|
||||
@@ -265,15 +268,15 @@ class HighFreqProvider:
|
||||
|
||||
if os.path.isfile(path + "tmp_dataset.pkl"):
|
||||
start = time.time()
|
||||
print_log("Dataset exists, load from disk.", __name__)
|
||||
self.logger.info(f"[{__name__}]Dataset exists, load from disk.")
|
||||
else:
|
||||
start = time.time()
|
||||
if not os.path.exists(os.path.dirname(path)):
|
||||
os.makedirs(os.path.dirname(path))
|
||||
print_log("Generating dataset", __name__)
|
||||
self.logger.info(f"[{__name__}]Generating dataset")
|
||||
self._prepare_calender_cache()
|
||||
dataset = init_instance_by_config(config)
|
||||
print_log(f"Dataset init, time cost: {time.time() - start:.2f}", __name__)
|
||||
self.logger.info(f"[{__name__}]Dataset init, time cost: {time.time() - start:.2f}")
|
||||
dataset.config(dump_all=False, recursive=True)
|
||||
dataset.to_pickle(path + "tmp_dataset.pkl")
|
||||
|
||||
@@ -282,7 +285,7 @@ class HighFreqProvider:
|
||||
|
||||
instruments = D.instruments(market="all")
|
||||
stock_list = D.list_instruments(
|
||||
instruments=instruments, start_time=self.start_time, end_time=self.end_time, freq="1min", as_list=True
|
||||
instruments=instruments, start_time=self.start_time, end_time=self.end_time, freq=self.freq, as_list=True
|
||||
)
|
||||
|
||||
def generate_dataset(stock):
|
||||
|
||||
@@ -96,9 +96,11 @@ def indicator_analysis(df, method="mean"):
|
||||
index: Index(datetime)
|
||||
method : str, optional
|
||||
statistics method of pa/ffr, by default "mean"
|
||||
|
||||
- if method is 'mean', count the mean statistical value of each trade indicator
|
||||
- if method is 'amount_weighted', count the deal_amount weighted mean statistical value of each trade indicator
|
||||
- if method is 'value_weighted', count the value weighted mean statistical value of each trade indicator
|
||||
|
||||
Note: statistics method of pos is always "mean"
|
||||
|
||||
Returns
|
||||
@@ -154,6 +156,7 @@ def backtest_daily(
|
||||
E.g.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# dict
|
||||
strategy = {
|
||||
"class": "TopkDropoutStrategy",
|
||||
@@ -180,16 +183,19 @@ def backtest_daily(
|
||||
# 3) specify module path with class name
|
||||
# - "a.b.c.ClassName" getattr(<a.b.c.module>, "ClassName")() will be used.
|
||||
|
||||
|
||||
executor : Union[str, dict, BaseExecutor]
|
||||
for initializing the outermost executor.
|
||||
benchmark: str
|
||||
the benchmark for reporting.
|
||||
account : Union[float, int, Position]
|
||||
information for describing how to creating the account
|
||||
|
||||
For `float` or `int`:
|
||||
|
||||
Using Account with only initial cash
|
||||
|
||||
For `Position`:
|
||||
|
||||
Using Account with a Position
|
||||
exchange_kwargs : dict
|
||||
the kwargs for initializing Exchange
|
||||
@@ -283,8 +289,8 @@ def long_short_backtest(
|
||||
NOTE: This will be faster with offline qlib.
|
||||
:return: The result of backtest, it is represented by a dict.
|
||||
{ "long": long_returns(excess),
|
||||
"short": short_returns(excess),
|
||||
"long_short": long_short_returns}
|
||||
"short": short_returns(excess),
|
||||
"long_short": long_short_returns}
|
||||
"""
|
||||
if get_level_index(pred, level="datetime") == 1:
|
||||
pred = pred.swaplevel().sort_index()
|
||||
|
||||
@@ -55,8 +55,10 @@ class InternalData:
|
||||
# The handler is initialized for only once.
|
||||
if not trainer.has_worker():
|
||||
self.dh = init_task_handler(perf_task_tpl)
|
||||
self.dh.config(dump_all=False) # in some cases, the data handler are saved to disk with `dump_all=True`
|
||||
else:
|
||||
self.dh = init_instance_by_config(perf_task_tpl["dataset"]["kwargs"]["handler"])
|
||||
assert self.dh.dump_all is False # otherwise, it will save all the detailed data
|
||||
|
||||
seg = perf_task_tpl["dataset"]["kwargs"]["segments"]
|
||||
|
||||
@@ -77,7 +79,7 @@ class InternalData:
|
||||
get_module_logger("Internal Data").info("the data has been initialized")
|
||||
else:
|
||||
# train new models
|
||||
assert 0 == len(recorders), "An empty experiment is required for setup `InternalData``"
|
||||
assert 0 == len(recorders), "An empty experiment is required for setup `InternalData`"
|
||||
trainer.train(gen_task)
|
||||
|
||||
# 2) extract the similarity matrix
|
||||
@@ -119,6 +121,7 @@ class MetaTaskDS(MetaTask):
|
||||
|
||||
def __init__(self, task: dict, meta_info: pd.DataFrame, mode: str = MetaTask.PROC_MODE_FULL, fill_method="max"):
|
||||
"""
|
||||
|
||||
The description of the processed data
|
||||
|
||||
time_perf: A array with shape <hist_step_n * step, data pieces> -> data piece performance
|
||||
@@ -132,6 +135,10 @@ class MetaTaskDS(MetaTask):
|
||||
[0., 0., 0., ..., 0., 0., 1.],
|
||||
[0., 0., 0., ..., 0., 0., 1.]])
|
||||
|
||||
Parameters
|
||||
----------
|
||||
meta_info: pd.DataFrame
|
||||
please refer to the docs of _prepare_meta_ipt for detailed explanation.
|
||||
"""
|
||||
super().__init__(task, meta_info)
|
||||
self.fill_method = fill_method
|
||||
@@ -180,12 +187,41 @@ class MetaTaskDS(MetaTask):
|
||||
self.processed_meta_input = data_to_tensor(self.processed_meta_input)
|
||||
|
||||
def _get_processed_meta_info(self):
|
||||
meta_info_norm = self.meta_info.sub(self.meta_info.mean(axis=1), axis=0) # .fillna(0.)
|
||||
if self.fill_method == "max":
|
||||
meta_info_norm = meta_info_norm.T.fillna(
|
||||
meta_info_norm.max(axis=1)
|
||||
).T # fill it with row max to align with previous implementation
|
||||
meta_info_norm = self.meta_info.sub(self.meta_info.mean(axis=1), axis=0)
|
||||
if self.fill_method.startswith("max"):
|
||||
suffix = self.fill_method.lstrip("max")
|
||||
if suffix == "seg":
|
||||
fill_value = {}
|
||||
for col in meta_info_norm.columns:
|
||||
fill_value[col] = meta_info_norm.loc[meta_info_norm[col].isna(), :].dropna(axis=1).mean().max()
|
||||
fill_value = pd.Series(fill_value).sort_index()
|
||||
# The NaN Values are filled segment-wise. Below is an exampleof fill_value
|
||||
# 2009-01-05 2009-02-06 0.145809
|
||||
# 2009-02-09 2009-03-06 0.148005
|
||||
# 2009-03-09 2009-04-03 0.090385
|
||||
# 2009-04-07 2009-05-05 0.114318
|
||||
# 2009-05-06 2009-06-04 0.119328
|
||||
# ...
|
||||
meta_info_norm = meta_info_norm.fillna(fill_value)
|
||||
else:
|
||||
if len(suffix) > 0:
|
||||
get_module_logger("MetaTaskDS").warning(
|
||||
f"fill_method={self.fill_method}; the info after can't be correctly parsed. Please check your parameters."
|
||||
)
|
||||
fill_value = meta_info_norm.max(axis=1)
|
||||
# fill it with row max to align with previous implementation
|
||||
# This will magnify the data similarity when data is in daily freq
|
||||
|
||||
# the fill value corresponds to data like this
|
||||
# It get a performance value for each day.
|
||||
# The performance value are get from other models on this day
|
||||
# 2009-01-16 0.276320
|
||||
# 2009-01-19 0.280603
|
||||
# ...
|
||||
# 2011-06-27 0.203773
|
||||
meta_info_norm = meta_info_norm.T.fillna(fill_value).T
|
||||
elif self.fill_method == "zero":
|
||||
# It will fillna(0.0) at the end.
|
||||
pass
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
@@ -286,7 +322,33 @@ class MetaDatasetDS(MetaTaskDataset):
|
||||
logger.warning(f"ValueError: {e}")
|
||||
assert len(self.meta_task_l) > 0, "No meta tasks found. Please check the data and setting"
|
||||
|
||||
def _prepare_meta_ipt(self, task):
|
||||
def _prepare_meta_ipt(self, task) -> pd.DataFrame:
|
||||
"""
|
||||
Please refer to `self.internal_data.setup` for detailed information about `self.internal_data.data_ic_df`
|
||||
|
||||
Indices with format below can be successfully sliced by `ic_df.loc[:end, pd.IndexSlice[:, :end]]`
|
||||
|
||||
2021-06-21 2021-06-04 .. 2021-03-22 2021-03-08
|
||||
2021-07-02 2021-06-18 .. 2021-04-02 None
|
||||
|
||||
Returns
|
||||
-------
|
||||
a pd.DataFrame with similar content below.
|
||||
- each column corresponds to a trained model named by the training data range
|
||||
- each row corresponds to a day of data tested by the models of the columns
|
||||
- The rows cells that overlaps with the data used by columns are masked
|
||||
|
||||
|
||||
2009-01-05 2009-02-09 ... 2011-04-27 2011-05-26
|
||||
2009-02-06 2009-03-06 ... 2011-05-25 2011-06-23
|
||||
datetime ...
|
||||
2009-01-13 NaN 0.310639 ... -0.169057 0.137792
|
||||
2009-01-14 NaN 0.261086 ... -0.143567 0.082581
|
||||
... ... ... ... ... ...
|
||||
2011-06-30 -0.054907 -0.020219 ... -0.023226 NaN
|
||||
2011-07-01 -0.075762 -0.026626 ... -0.003167 NaN
|
||||
|
||||
"""
|
||||
ic_df = self.internal_data.data_ic_df
|
||||
|
||||
segs = task["dataset"]["kwargs"]["segments"]
|
||||
@@ -294,15 +356,19 @@ class MetaDatasetDS(MetaTaskDataset):
|
||||
ic_df_avail = ic_df.loc[:end, pd.IndexSlice[:, :end]]
|
||||
|
||||
# meta data set focus on the **information** instead of preprocess
|
||||
# 1) filter the future info
|
||||
def mask_future(s):
|
||||
"""mask future information"""
|
||||
# from qlib.utils import get_date_by_shift
|
||||
# 1) filter the overlap info
|
||||
def mask_overlap(s):
|
||||
"""
|
||||
mask overlap information
|
||||
data after self.name[end] with self.trunc_days that contains future info are also considered as overlap info
|
||||
|
||||
Approximately the diagnal + horizon length of data are masked.
|
||||
"""
|
||||
start, end = s.name
|
||||
end = get_date_by_shift(trading_date=end, shift=self.trunc_days - 1, future=True)
|
||||
return s.mask((s.index >= start) & (s.index <= end))
|
||||
|
||||
ic_df_avail = ic_df_avail.apply(mask_future) # apply to each col
|
||||
ic_df_avail = ic_df_avail.apply(mask_overlap) # apply to each col
|
||||
|
||||
# 2) filter the info with too long periods
|
||||
total_len = self.step * self.hist_step_n
|
||||
|
||||
@@ -52,6 +52,7 @@ class MetaModelDS(MetaTaskModel):
|
||||
lr=0.0001,
|
||||
max_epoch=100,
|
||||
seed=43,
|
||||
alpha=0.0,
|
||||
):
|
||||
self.step = step
|
||||
self.hist_step_n = hist_step_n
|
||||
@@ -61,6 +62,7 @@ class MetaModelDS(MetaTaskModel):
|
||||
self.lr = lr
|
||||
self.max_epoch = max_epoch
|
||||
self.fitted = False
|
||||
self.alpha = alpha
|
||||
torch.manual_seed(seed)
|
||||
|
||||
def run_epoch(self, phase, task_list, epoch, opt, loss_l, ignore_weight=False):
|
||||
@@ -144,7 +146,11 @@ class MetaModelDS(MetaTaskModel):
|
||||
) # debug: record when the test phase starts
|
||||
|
||||
self.tn = PredNet(
|
||||
step=self.step, hist_step_n=self.hist_step_n, clip_weight=self.clip_weight, clip_method=self.clip_method
|
||||
step=self.step,
|
||||
hist_step_n=self.hist_step_n,
|
||||
clip_weight=self.clip_weight,
|
||||
clip_method=self.clip_method,
|
||||
alpha=self.alpha,
|
||||
)
|
||||
|
||||
opt = optim.Adam(self.tn.parameters(), lr=self.lr)
|
||||
|
||||
@@ -41,11 +41,18 @@ class TimeWeightMeta(SingleMetaBase):
|
||||
|
||||
|
||||
class PredNet(nn.Module):
|
||||
def __init__(self, step, hist_step_n, clip_weight=None, clip_method="tanh"):
|
||||
def __init__(self, step, hist_step_n, clip_weight=None, clip_method="tanh", alpha: float = 0.0):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
alpha : float
|
||||
the regularization for sub model (useful when align meta model with linear submodel)
|
||||
"""
|
||||
super().__init__()
|
||||
self.step = step
|
||||
self.twm = TimeWeightMeta(hist_step_n=hist_step_n, clip_weight=clip_weight, clip_method=clip_method)
|
||||
self.init_paramters(hist_step_n)
|
||||
self.alpha = alpha
|
||||
|
||||
def get_sample_weights(self, X, time_perf, time_belong, ignore_weight=False):
|
||||
weights = torch.from_numpy(np.ones(X.shape[0])).float().to(X.device)
|
||||
@@ -59,7 +66,7 @@ class PredNet(nn.Module):
|
||||
"""Please refer to the docs of MetaTaskDS for the description of the variables"""
|
||||
weights = self.get_sample_weights(X, time_perf, time_belong, ignore_weight=ignore_weight)
|
||||
X_w = X.T * weights.view(1, -1)
|
||||
theta = torch.inverse(X_w @ X) @ X_w @ y
|
||||
theta = torch.inverse(X_w @ X + self.alpha * torch.eye(X_w.shape[0])) @ X_w @ y
|
||||
return X_test @ theta, weights
|
||||
|
||||
def init_paramters(self, hist_step_n):
|
||||
|
||||
@@ -5,6 +5,9 @@ import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from qlib.constant import EPS
|
||||
from qlib.log import get_module_logger
|
||||
|
||||
|
||||
class ICLoss(nn.Module):
|
||||
def forward(self, pred, y, idx, skip_size=50):
|
||||
@@ -24,6 +27,7 @@ class ICLoss(nn.Module):
|
||||
diff_point.append(i)
|
||||
prev = date
|
||||
diff_point.append(None)
|
||||
# The lengths of diff_point will be one more larger then diff_point
|
||||
|
||||
ic_all = 0.0
|
||||
skip_n = 0
|
||||
@@ -34,13 +38,23 @@ class ICLoss(nn.Module):
|
||||
skip_n += 1
|
||||
continue
|
||||
y_focus = y[start_i:end_i]
|
||||
if pred_focus.std() < EPS or y_focus.std() < EPS:
|
||||
# These cases often happend at the end of test data.
|
||||
# Usually caused by fillna(0.)
|
||||
skip_n += 1
|
||||
continue
|
||||
|
||||
ic_day = torch.dot(
|
||||
(pred_focus - pred_focus.mean()) / np.sqrt(pred_focus.shape[0]) / pred_focus.std(),
|
||||
(y_focus - y_focus.mean()) / np.sqrt(y_focus.shape[0]) / y_focus.std(),
|
||||
)
|
||||
ic_all += ic_day
|
||||
if len(diff_point) - 1 - skip_n <= 0:
|
||||
raise ValueError("No enough data for calculating iC")
|
||||
raise ValueError("No enough data for calculating IC")
|
||||
if skip_n > 0:
|
||||
get_module_logger("ICLoss").info(
|
||||
f"{skip_n} days are skipped due to zero std or small scale of valid samples."
|
||||
)
|
||||
ic_mean = ic_all / (len(diff_point) - 1 - skip_n)
|
||||
return -ic_mean # ic loss
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ try:
|
||||
from .catboost_model import CatBoostModel
|
||||
except ModuleNotFoundError:
|
||||
CatBoostModel = None
|
||||
print("Please install necessary libs for CatBoostModel.")
|
||||
print("ModuleNotFoundError. CatBoostModel are skipped. (optional: maybe installing CatBoostModel can fix it.)")
|
||||
try:
|
||||
from .double_ensemble import DEnsembleModel
|
||||
from .gbdt import LGBModel
|
||||
|
||||
@@ -30,6 +30,7 @@ class DEnsembleModel(Model, FeatureInt):
|
||||
sample_ratios=None,
|
||||
sub_weights=None,
|
||||
epochs=100,
|
||||
early_stopping_rounds=None,
|
||||
**kwargs
|
||||
):
|
||||
self.base_model = base_model # "gbm" or "mlp", specifically, we use lgbm for "gbm"
|
||||
@@ -59,6 +60,7 @@ class DEnsembleModel(Model, FeatureInt):
|
||||
self.params = {"objective": loss}
|
||||
self.params.update(kwargs)
|
||||
self.loss = loss
|
||||
self.early_stopping_rounds = early_stopping_rounds
|
||||
|
||||
def fit(self, dataset: DatasetH):
|
||||
df_train, df_valid = dataset.prepare(
|
||||
@@ -103,14 +105,19 @@ class DEnsembleModel(Model, FeatureInt):
|
||||
def train_submodel(self, df_train, df_valid, weights, features):
|
||||
dtrain, dvalid = self._prepare_data_gbm(df_train, df_valid, weights, features)
|
||||
evals_result = dict()
|
||||
|
||||
callbacks = [lgb.log_evaluation(20), lgb.record_evaluation(evals_result)]
|
||||
if self.early_stopping_rounds:
|
||||
callbacks.append(lgb.early_stopping(self.early_stopping_rounds))
|
||||
self.logger.info("Training with early_stopping...")
|
||||
|
||||
model = lgb.train(
|
||||
self.params,
|
||||
dtrain,
|
||||
num_boost_round=self.epochs,
|
||||
valid_sets=[dtrain, dvalid],
|
||||
valid_names=["train", "valid"],
|
||||
verbose_eval=20,
|
||||
evals_result=evals_result,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
evals_result["train"] = list(evals_result["train"].values())[0]
|
||||
evals_result["valid"] = list(evals_result["valid"].values())[0]
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Text, Union
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.data.dataset.weight import Reweighter
|
||||
from scipy.optimize import nnls
|
||||
from sklearn.linear_model import LinearRegression, Ridge, Lasso
|
||||
@@ -29,7 +30,7 @@ class LinearModel(Model):
|
||||
RIDGE = "ridge"
|
||||
LASSO = "lasso"
|
||||
|
||||
def __init__(self, estimator="ols", alpha=0.0, fit_intercept=False):
|
||||
def __init__(self, estimator="ols", alpha=0.0, fit_intercept=False, include_valid: bool = False):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
@@ -39,6 +40,9 @@ class LinearModel(Model):
|
||||
l1 or l2 regularization parameter
|
||||
fit_intercept : bool
|
||||
whether fit intercept
|
||||
include_valid: bool
|
||||
Should the validation data be included for training?
|
||||
The validation data should be included
|
||||
"""
|
||||
assert estimator in [self.OLS, self.NNLS, self.RIDGE, self.LASSO], f"unsupported estimator `{estimator}`"
|
||||
self.estimator = estimator
|
||||
@@ -49,9 +53,16 @@ class LinearModel(Model):
|
||||
self.fit_intercept = fit_intercept
|
||||
|
||||
self.coef_ = None
|
||||
self.include_valid = include_valid
|
||||
|
||||
def fit(self, dataset: DatasetH, reweighter: Reweighter = None):
|
||||
df_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
|
||||
if self.include_valid:
|
||||
try:
|
||||
df_valid = dataset.prepare("valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
|
||||
df_train = pd.concat([df_train, df_valid])
|
||||
except KeyError:
|
||||
get_module_logger("LinearModel").info("include_valid=True, but valid does not exist")
|
||||
if df_train.empty:
|
||||
raise ValueError("Empty data from dataset, please check your dataset config.")
|
||||
if reweighter is not None:
|
||||
|
||||
@@ -28,7 +28,7 @@ class ADARNN(Model):
|
||||
d_feat : int
|
||||
input dimension for each time step
|
||||
metric: str
|
||||
the evaluate metric used in early stop
|
||||
the evaluation metric used in early stop
|
||||
optimizer : str
|
||||
optimizer name
|
||||
GPU : str
|
||||
@@ -56,7 +56,7 @@ class ADARNN(Model):
|
||||
n_splits=2,
|
||||
GPU=0,
|
||||
seed=None,
|
||||
**kwargs
|
||||
**_
|
||||
):
|
||||
# Set logger.
|
||||
self.logger = get_module_logger("ADARNN")
|
||||
@@ -81,7 +81,7 @@ class ADARNN(Model):
|
||||
self.optimizer = optimizer.lower()
|
||||
self.loss = loss
|
||||
self.n_splits = n_splits
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.device = torch.device("cuda:%d" % GPU if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.seed = seed
|
||||
|
||||
self.logger.info(
|
||||
@@ -213,7 +213,8 @@ class ADARNN(Model):
|
||||
weight_mat = self.transform_type(out_weight_list)
|
||||
return weight_mat, None
|
||||
|
||||
def calc_all_metrics(self, pred):
|
||||
@staticmethod
|
||||
def calc_all_metrics(pred):
|
||||
"""pred is a pandas dataframe that has two attributes: score (pred) and label (real)"""
|
||||
res = {}
|
||||
ic = pred.groupby(level="datetime").apply(lambda x: x.label.corr(x.score))
|
||||
@@ -259,8 +260,6 @@ class ADARNN(Model):
|
||||
|
||||
save_path = get_or_create_path(save_path)
|
||||
stop_steps = 0
|
||||
best_score = -np.inf
|
||||
best_epoch = 0
|
||||
evals_result["train"] = []
|
||||
evals_result["valid"] = []
|
||||
|
||||
@@ -400,7 +399,7 @@ class AdaRNN(nn.Module):
|
||||
self.model_type = model_type
|
||||
self.trans_loss = trans_loss
|
||||
self.len_seq = len_seq
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.device = torch.device("cuda:%d" % GPU if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
in_size = self.n_input
|
||||
|
||||
features = nn.ModuleList()
|
||||
@@ -499,7 +498,8 @@ class AdaRNN(nn.Module):
|
||||
res = self.softmax(weight).squeeze()
|
||||
return res
|
||||
|
||||
def get_features(self, output_list):
|
||||
@staticmethod
|
||||
def get_features(output_list):
|
||||
fea_list_src, fea_list_tar = [], []
|
||||
for fea in output_list:
|
||||
fea_list_src.append(fea[0 : fea.size(0) // 2])
|
||||
@@ -561,7 +561,7 @@ class TransferLoss:
|
||||
"""
|
||||
self.loss_type = loss_type
|
||||
self.input_dim = input_dim
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.device = torch.device("cuda:%d" % GPU if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
|
||||
def compute(self, X, Y):
|
||||
"""Compute adaptation loss
|
||||
@@ -676,7 +676,8 @@ class MMD_loss(nn.Module):
|
||||
self.fix_sigma = None
|
||||
self.kernel_type = kernel_type
|
||||
|
||||
def guassian_kernel(self, source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
|
||||
@staticmethod
|
||||
def guassian_kernel(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
|
||||
n_samples = int(source.size()[0]) + int(target.size()[0])
|
||||
total = torch.cat([source, target], dim=0)
|
||||
total0 = total.unsqueeze(0).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))
|
||||
@@ -691,7 +692,8 @@ class MMD_loss(nn.Module):
|
||||
kernel_val = [torch.exp(-L2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list]
|
||||
return sum(kernel_val)
|
||||
|
||||
def linear_mmd(self, X, Y):
|
||||
@staticmethod
|
||||
def linear_mmd(X, Y):
|
||||
delta = X.mean(axis=0) - Y.mean(axis=0)
|
||||
loss = delta.dot(delta.T)
|
||||
return loss
|
||||
|
||||
@@ -36,7 +36,7 @@ class ADD(Model):
|
||||
d_feat : int
|
||||
input dimensions for each time step
|
||||
metric : str
|
||||
the evaluate metric used in early stop
|
||||
the evaluation metric used in early stop
|
||||
optimizer : str
|
||||
optimizer name
|
||||
GPU : int
|
||||
|
||||
@@ -30,7 +30,7 @@ class ALSTM(Model):
|
||||
d_feat : int
|
||||
input dimension for each time step
|
||||
metric: str
|
||||
the evaluate metric used in early stop
|
||||
the evaluation metric used in early stop
|
||||
optimizer : str
|
||||
optimizer name
|
||||
GPU : int
|
||||
|
||||
@@ -33,7 +33,7 @@ class ALSTM(Model):
|
||||
d_feat : int
|
||||
input dimension for each time step
|
||||
metric: str
|
||||
the evaluate metric used in early stop
|
||||
the evaluation metric used in early stop
|
||||
optimizer : str
|
||||
optimizer name
|
||||
GPU : int
|
||||
|
||||
@@ -33,7 +33,7 @@ class GATs(Model):
|
||||
d_feat : int
|
||||
input dimensions for each time step
|
||||
metric : str
|
||||
the evaluate metric used in early stop
|
||||
the evaluation metric used in early stop
|
||||
optimizer : str
|
||||
optimizer name
|
||||
GPU : int
|
||||
|
||||
@@ -50,7 +50,7 @@ class GATs(Model):
|
||||
d_feat : int
|
||||
input dimensions for each time step
|
||||
metric : str
|
||||
the evaluate metric used in early stop
|
||||
the evaluation metric used in early stop
|
||||
optimizer : str
|
||||
optimizer name
|
||||
GPU : int
|
||||
|
||||
@@ -30,7 +30,7 @@ class GRU(Model):
|
||||
d_feat : int
|
||||
input dimension for each time step
|
||||
metric: str
|
||||
the evaluate metric used in early stop
|
||||
the evaluation metric used in early stop
|
||||
optimizer : str
|
||||
optimizer name
|
||||
GPU : str
|
||||
|
||||
@@ -31,7 +31,7 @@ class GRU(Model):
|
||||
d_feat : int
|
||||
input dimension for each time step
|
||||
metric: str
|
||||
the evaluate metric used in early stop
|
||||
the evaluation metric used in early stop
|
||||
optimizer : str
|
||||
optimizer name
|
||||
GPU : str
|
||||
|
||||
@@ -34,7 +34,7 @@ class HIST(Model):
|
||||
d_feat : int
|
||||
input dimensions for each time step
|
||||
metric : str
|
||||
the evaluate metric used in early stop
|
||||
the evaluation metric used in early stop
|
||||
optimizer : str
|
||||
optimizer name
|
||||
GPU : str
|
||||
|
||||
@@ -32,7 +32,7 @@ class IGMTF(Model):
|
||||
d_feat : int
|
||||
input dimension for each time step
|
||||
metric: str
|
||||
the evaluate metric used in early stop
|
||||
the evaluation metric used in early stop
|
||||
optimizer : str
|
||||
optimizer name
|
||||
GPU : str
|
||||
|
||||
@@ -29,7 +29,7 @@ class LSTM(Model):
|
||||
d_feat : int
|
||||
input dimension for each time step
|
||||
metric: str
|
||||
the evaluate metric used in early stop
|
||||
the evaluation metric used in early stop
|
||||
optimizer : str
|
||||
optimizer name
|
||||
GPU : str
|
||||
|
||||
@@ -30,7 +30,7 @@ class LSTM(Model):
|
||||
d_feat : int
|
||||
input dimension for each time step
|
||||
metric: str
|
||||
the evaluate metric used in early stop
|
||||
the evaluation metric used in early stop
|
||||
optimizer : str
|
||||
optimizer name
|
||||
GPU : str
|
||||
|
||||
@@ -47,10 +47,6 @@ class DNNModelPytorch(Model):
|
||||
layer sizes
|
||||
lr : float
|
||||
learning rate
|
||||
lr_decay : float
|
||||
learning rate decay
|
||||
lr_decay_steps : int
|
||||
learning rate decay steps
|
||||
optimizer : str
|
||||
optimizer name
|
||||
GPU : int
|
||||
@@ -64,8 +60,6 @@ class DNNModelPytorch(Model):
|
||||
batch_size=2000,
|
||||
early_stop_rounds=50,
|
||||
eval_steps=20,
|
||||
lr_decay=0.96,
|
||||
lr_decay_steps=100,
|
||||
optimizer="gd",
|
||||
loss="mse",
|
||||
GPU=0,
|
||||
@@ -93,8 +87,6 @@ class DNNModelPytorch(Model):
|
||||
self.batch_size = batch_size
|
||||
self.early_stop_rounds = early_stop_rounds
|
||||
self.eval_steps = eval_steps
|
||||
self.lr_decay = lr_decay
|
||||
self.lr_decay_steps = lr_decay_steps
|
||||
self.optimizer = optimizer.lower()
|
||||
self.loss_type = loss
|
||||
if isinstance(GPU, str):
|
||||
@@ -116,8 +108,6 @@ class DNNModelPytorch(Model):
|
||||
f"\nbatch_size : {batch_size}"
|
||||
f"\nearly_stop_rounds : {early_stop_rounds}"
|
||||
f"\neval_steps : {eval_steps}"
|
||||
f"\nlr_decay : {lr_decay}"
|
||||
f"\nlr_decay_steps : {lr_decay_steps}"
|
||||
f"\noptimizer : {optimizer}"
|
||||
f"\nloss_type : {loss}"
|
||||
f"\nseed : {seed}"
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user