mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-17 11:18:24 +08:00
Compare commits
79 Commits
v0.8.4
...
mini_proje
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a02ac95538 | ||
|
|
cc94c32db6 | ||
|
|
9a40fd3cdc | ||
|
|
c4281121e3 | ||
|
|
2de9903200 | ||
|
|
2cf842bcfe | ||
|
|
9e381493c2 | ||
|
|
a73b60d05a | ||
|
|
64979ad769 | ||
|
|
c5cf8fb9cc | ||
|
|
5d579d1a20 | ||
|
|
3c9c76b384 | ||
|
|
9d0a8f61d1 | ||
|
|
701b18af1b | ||
|
|
84ff662a26 | ||
|
|
00e40e775b | ||
|
|
45fe5e6974 | ||
|
|
366a9c33f3 | ||
|
|
982e0da715 | ||
|
|
cd5e5d5235 | ||
|
|
caea495f40 | ||
|
|
d934c8caba | ||
|
|
a139986f4e | ||
|
|
12c3de42d0 | ||
|
|
fe0f9427f2 | ||
|
|
a973e4fb66 | ||
|
|
c60366addd | ||
|
|
41447f320b | ||
|
|
e1271a83f7 | ||
|
|
30b531086c | ||
|
|
87926513cb | ||
|
|
7bfc7e1797 | ||
|
|
85e7cdcac3 | ||
|
|
08fd1d3f42 | ||
|
|
defd6758f6 | ||
|
|
61cc1a3867 | ||
|
|
655ed982cf | ||
|
|
2952c443ca | ||
|
|
7f1293ec34 | ||
|
|
73438807f9 | ||
|
|
962751c72d | ||
|
|
56cfa480dc | ||
|
|
6edd0bf298 | ||
|
|
fe155703b0 | ||
|
|
3c4f4bfd44 | ||
|
|
5200ff520a | ||
|
|
30e457119c | ||
|
|
243e516cf1 | ||
|
|
e229b567ad | ||
|
|
f129bfef5d | ||
|
|
9dd5e07819 | ||
|
|
00ed35fc1b | ||
|
|
3f53a097b0 | ||
|
|
fb230a8097 | ||
|
|
ff4724e248 | ||
|
|
73d90f7f44 | ||
|
|
b7988e6428 | ||
|
|
8efc8b92ef | ||
|
|
f2a5ecd98a | ||
|
|
705354cc28 | ||
|
|
1b5d0d4d6d | ||
|
|
f4a481945b | ||
|
|
5f18ba7970 | ||
|
|
2681c61c60 | ||
|
|
776b0c5bb4 | ||
|
|
829ad9f5e9 | ||
|
|
921c13cc90 | ||
|
|
0f519f6053 | ||
|
|
2ed806c846 | ||
|
|
d2f0bebf60 | ||
|
|
615a381038 | ||
|
|
568a88fddb | ||
|
|
058f976727 | ||
|
|
faa99f30fa | ||
|
|
837067b9e1 | ||
|
|
3a911bc09b | ||
|
|
90be21bb40 | ||
|
|
40dd84857c | ||
|
|
74cc21fc2c |
2
.github/PULL_REQUEST_TEMPLATE.md
vendored
2
.github/PULL_REQUEST_TEMPLATE.md
vendored
@@ -8,7 +8,7 @@
|
||||
<!--- Why is this change required? What problem does it solve? -->
|
||||
|
||||
## How Has This Been Tested?
|
||||
<! --- Put an `x` in all the boxes that apply: --->
|
||||
<!--- Put an `x` in all the boxes that apply: --->
|
||||
- [ ] Pass the test by running: `pytest qlib/tests/test_all_pipeline.py` under upper directory of `qlib`.
|
||||
- [ ] If you are adding a new feature, test on your own test scripts.
|
||||
|
||||
|
||||
49
.github/workflows/test.yml
vendored
49
.github/workflows/test.yml
vendored
@@ -35,6 +35,15 @@ jobs:
|
||||
pip install numpy==1.19.5 ruamel.yaml
|
||||
pip install pyqlib --ignore-installed
|
||||
|
||||
- name: Make html with sphinx
|
||||
run: |
|
||||
pip install -U sphinx
|
||||
pip install sphinx_rtd_theme readthedocs_sphinx_ext
|
||||
pip install --exists-action=w --no-cache-dir -r docs/requirements.txt
|
||||
cd docs
|
||||
sphinx-build -b html . build
|
||||
cd ..
|
||||
|
||||
# Check Qlib with pylint
|
||||
# TODO: These problems we will solve in the future. Important among them are: W0221, W0223, W0237, E1102
|
||||
# C0103: invalid-name
|
||||
@@ -65,9 +74,44 @@ jobs:
|
||||
pip install pylint
|
||||
pylint --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0201,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0612,W0621,W0622,W0703,W1309,E1102,E1136 --const-rgx='[a-z_][a-z0-9_]{2,30}$' qlib --init-hook "import astroid; astroid.context.InferenceContext.max_inferred = 500"
|
||||
|
||||
# The following flake8 error codes were ignored:
|
||||
# E501 line too long
|
||||
# Description: We have used black to limit the length of each line to 120.
|
||||
# F541 f-string is missing placeholders
|
||||
# Description: The same thing is done when using pylint for detection.
|
||||
# E266 too many leading '#' for block comment
|
||||
# Description: To make the code more readable, a lot of "#" is used.
|
||||
# This error code appears centrally in:
|
||||
# qlib/backtest/executor.py
|
||||
# qlib/data/ops.py
|
||||
# qlib/utils/__init__.py
|
||||
# E402 module level import not at top of file
|
||||
# Description: There are times when module level import is not available at the top of the file.
|
||||
# W503 line break before binary operator
|
||||
# Description: Since black formats the length of each line of code, it has to perform a line break when a line of arithmetic is too long.
|
||||
# E731 do not assign a lambda expression, use a def
|
||||
# Description: Restricts the use of lambda expressions, but at some point lambda expressions are required.
|
||||
# E203 whitespace before ':'
|
||||
# Description: If there is whitespace before ":", it cannot pass the black check.
|
||||
- name: Check Qlib with flake8
|
||||
run: |
|
||||
pip install --upgrade pip
|
||||
pip install flake8
|
||||
flake8 --ignore=E501,F541,E266,E402,W503,E731,E203 --per-file-ignores="__init__.py:F401,F403" qlib
|
||||
|
||||
# https://github.com/python/mypy/issues/10600
|
||||
- name: Check Qlib with mypy
|
||||
run: |
|
||||
pip install mypy
|
||||
mypy qlib --install-types --non-interactive || true
|
||||
mypy qlib
|
||||
|
||||
- name: Test data downloads
|
||||
run: |
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
|
||||
python scripts/get_data.py qlib_data --name qlib_data_simple --target_dir ~/.qlib/qlib_data/cn_data_simple --interval 1d --region cn
|
||||
python -c "import os; userpath=os.path.expanduser('~'); os.rename(userpath + '/.qlib/qlib_data/cn_data_simple', userpath + '/.qlib/qlib_data/cn_data')"
|
||||
azcopy copy https://qlibpublic.blob.core.windows.net/data /tmp/qlibpublic --recursive
|
||||
mv /tmp/qlibpublic/data tests/.data
|
||||
|
||||
- name: Test workflow by config (install from pip)
|
||||
run: |
|
||||
@@ -78,6 +122,7 @@ jobs:
|
||||
- name: Install Qlib from source
|
||||
run: |
|
||||
pip install --upgrade cython jupyter jupyter_contrib_nbextensions numpy scipy scikit-learn # installing without this line will cause errors on GitHub Actions, while instsalling locally won't
|
||||
pip install gym tianshou torch
|
||||
pip install -e .
|
||||
|
||||
- name: Install test dependencies
|
||||
@@ -87,10 +132,10 @@ jobs:
|
||||
|
||||
- name: Unit tests with Pytest
|
||||
run: |
|
||||
pip install -r scripts/data_collector/pit/requirements.txt
|
||||
cd tests
|
||||
python -m pytest . --durations=10
|
||||
|
||||
- name: Test workflow by config (install from source)
|
||||
run: |
|
||||
python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
|
||||
|
||||
21
.github/workflows/test_macos.yml
vendored
21
.github/workflows/test_macos.yml
vendored
@@ -34,10 +34,24 @@ jobs:
|
||||
python -m black qlib -l 120 --check --diff
|
||||
# Test Qlib installed with pip
|
||||
|
||||
- name: Check Qlib with flake8
|
||||
run: |
|
||||
pip install --upgrade pip
|
||||
pip install flake8
|
||||
flake8 --ignore=E501,F541,E266,E402,W503,E731,E203 --per-file-ignores="__init__.py:F401,F403" qlib
|
||||
|
||||
- name: Install Qlib with pip
|
||||
run: |
|
||||
python -m pip install numpy==1.19.5
|
||||
python -m pip install pyqlib --ignore-installed ruamel.yaml numpy
|
||||
- name: Make html with sphnix
|
||||
run: |
|
||||
pip install -U sphinx
|
||||
pip install sphinx_rtd_theme readthedocs_sphinx_ext
|
||||
pip install --exists-action=w --no-cache-dir -r docs/requirements.txt
|
||||
cd docs
|
||||
sphinx-build -b html . build
|
||||
cd ..
|
||||
- name: Install Lightgbm for MacOS
|
||||
run: |
|
||||
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Microsoft/qlib/main/.github/brew_install.sh)"
|
||||
@@ -49,7 +63,10 @@ jobs:
|
||||
brew install libomp.rb
|
||||
- name: Test data downloads
|
||||
run: |
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
|
||||
python scripts/get_data.py qlib_data --name qlib_data_simple --target_dir ~/.qlib/qlib_data/cn_data_simple --interval 1d --region cn
|
||||
python -c "import os; userpath=os.path.expanduser('~'); os.rename(userpath + '/.qlib/qlib_data/cn_data_simple', userpath + '/.qlib/qlib_data/cn_data')"
|
||||
azcopy copy https://qlibpublic.blob.core.windows.net/data /tmp/qlibpublic --recursive
|
||||
mv /tmp/qlibpublic/data tests/.data
|
||||
- name: Test workflow by config (install from pip)
|
||||
run: |
|
||||
python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
@@ -60,6 +77,7 @@ jobs:
|
||||
python -m pip install --upgrade cython
|
||||
python -m pip install numpy jupyter jupyter_contrib_nbextensions
|
||||
python -m pip install -U scipy scikit-learn # installing without this line will cause errors on GitHub Actions, while instsalling locally won't
|
||||
python -m pip install gym tianshou torch
|
||||
pip install -e .
|
||||
- name: Install test dependencies
|
||||
run: |
|
||||
@@ -68,6 +86,7 @@ jobs:
|
||||
python -m pip install black pytest
|
||||
- name: Unit tests with Pytest
|
||||
run: |
|
||||
pip install -r scripts/data_collector/pit/requirements.txt
|
||||
cd tests
|
||||
python -m pytest . --durations=0
|
||||
- name: Test workflow by config (install from source)
|
||||
|
||||
5
.gitignore
vendored
5
.gitignore
vendored
@@ -27,6 +27,10 @@ examples/estimator/estimator_example/
|
||||
|
||||
*.egg-info/
|
||||
|
||||
# test related
|
||||
test-output.xml
|
||||
.output
|
||||
.data
|
||||
|
||||
# special software
|
||||
mlruns/
|
||||
@@ -34,6 +38,7 @@ mlruns/
|
||||
tags
|
||||
|
||||
.pytest_cache/
|
||||
.mypy_cache/
|
||||
.vscode/
|
||||
|
||||
*.swp
|
||||
|
||||
17
.mypy.ini
Normal file
17
.mypy.ini
Normal file
@@ -0,0 +1,17 @@
|
||||
[mypy]
|
||||
exclude = (?x)(
|
||||
^qlib/backtest
|
||||
| ^qlib/contrib
|
||||
| ^qlib/data
|
||||
| ^qlib/model
|
||||
| ^qlib/strategy
|
||||
| ^qlib/tests
|
||||
| ^qlib/utils
|
||||
| ^qlib/workflow
|
||||
| ^qlib/config\.py$
|
||||
| ^qlib/log\.py$
|
||||
| ^qlib/__init__\.py$
|
||||
)
|
||||
ignore_missing_imports = true
|
||||
disallow_incomplete_defs = true
|
||||
follow_imports = skip
|
||||
12
.pre-commit-config.yaml
Normal file
12
.pre-commit-config.yaml
Normal file
@@ -0,0 +1,12 @@
|
||||
repos:
|
||||
- repo: https://github.com/psf/black
|
||||
rev: 22.1.0
|
||||
hooks:
|
||||
- id: black
|
||||
args: ["qlib", "-l 120"]
|
||||
|
||||
- repo: https://github.com/PyCQA/flake8
|
||||
rev: 4.0.1
|
||||
hooks:
|
||||
- id: flake8
|
||||
args: ["--ignore=E501,F541,E266,E402,W503,E731,E203"]
|
||||
28
README.md
28
README.md
@@ -11,7 +11,11 @@
|
||||
Recent released features
|
||||
| Feature | Status |
|
||||
| -- | ------ |
|
||||
| Arctic Provider Backend & Orderbook data example | :hammer: [Rleased](https://github.com/microsoft/qlib/pull/744) on Jan 17, 2022 |
|
||||
| HIST and IGMTF models | :chart_with_upwards_trend: [Released](https://github.com/microsoft/qlib/pull/1040) on Apr 10, 2022 |
|
||||
| Qlib [notebook tutorial](https://github.com/microsoft/qlib/tree/main/examples/tutorial) | 📖 [Released](https://github.com/microsoft/qlib/pull/1037) on Apr 7, 2022 |
|
||||
| Ibovespa index data | :rice: [Released](https://github.com/microsoft/qlib/pull/990) on Apr 6, 2022 |
|
||||
| Point-in-Time database | :hammer: [Released](https://github.com/microsoft/qlib/pull/343) on Mar 10, 2022 |
|
||||
| Arctic Provider Backend & Orderbook data example | :hammer: [Released](https://github.com/microsoft/qlib/pull/744) on Jan 17, 2022 |
|
||||
| Meta-Learning-based framework & DDG-DA | :chart_with_upwards_trend: :hammer: [Released](https://github.com/microsoft/qlib/pull/743) on Jan 10, 2022 |
|
||||
| Planning-based portfolio optimization | :hammer: [Released](https://github.com/microsoft/qlib/pull/754) on Dec 28, 2021 |
|
||||
| Release Qlib v0.8.0 | :octocat: [Released](https://github.com/microsoft/qlib/releases/tag/v0.8.0) on Dec 8, 2021 |
|
||||
@@ -28,7 +32,7 @@ Recent released features
|
||||
| High-frequency data processing example | :hammer: [Released](https://github.com/microsoft/qlib/pull/257) on Feb 5, 2021 |
|
||||
| High-frequency trading example | :chart_with_upwards_trend: [Part of code released](https://github.com/microsoft/qlib/pull/227) on Jan 28, 2021 |
|
||||
| High-frequency data(1min) | :rice: [Released](https://github.com/microsoft/qlib/pull/221) on Jan 27, 2021 |
|
||||
| Tabnet Model | :chart_with_upwards_trend: [Released](https://github.com/microsoft/qlib/pull/205) on Jan 22, 2021 |
|
||||
| Tabnet Model | :chart_with_upwards_trend: [Released](https://github.com/microsoft/qlib/pull/205) on Jan 22, 2021 |
|
||||
|
||||
Features released before 2021 are not listed here.
|
||||
|
||||
@@ -95,9 +99,8 @@ For more details, please refer to our paper ["Qlib: An AI-oriented Quantitative
|
||||
# Plans
|
||||
New features under development(order by estimated release time).
|
||||
Your feedbacks about the features are very important.
|
||||
| Feature | Status |
|
||||
| -- | ------ |
|
||||
| Point-in-Time database | Under review: https://github.com/microsoft/qlib/pull/343 |
|
||||
<!-- | Feature | Status | -->
|
||||
<!-- | -- | ------ | -->
|
||||
|
||||
# Framework of Qlib
|
||||
|
||||
@@ -105,7 +108,6 @@ Your feedbacks about the features are very important.
|
||||
<img src="docs/_static/img/framework.svg" />
|
||||
</div>
|
||||
|
||||
|
||||
At the module level, Qlib is a platform that consists of the above components. The components are designed as loose-coupled modules, and each component could be used stand-alone.
|
||||
|
||||
| Name | Description |
|
||||
@@ -117,6 +119,8 @@ At the module level, Qlib is a platform that consists of the above components. T
|
||||
* The modules with hand-drawn style are under development and will be released in the future.
|
||||
* The modules with dashed borders are highly user-customizable and extendible.
|
||||
|
||||
(p.s. framework image is created with https://draw.io/)
|
||||
|
||||
|
||||
# Quick Start
|
||||
|
||||
@@ -336,6 +340,8 @@ Here is a list of models built on `Qlib`.
|
||||
- [TCN based on pytorch (Shaojie Bai, et al. 2018)](examples/benchmarks/TCN/)
|
||||
- [ADARNN based on pytorch (YunTao Du, et al. 2021)](examples/benchmarks/ADARNN/)
|
||||
- [ADD based on pytorch (Hongshun Tang, et al.2020)](examples/benchmarks/ADD/)
|
||||
- [IGMTF based on pytorch (Wentao Xu, et al.2021)](examples/benchmarks/IGMTF/)
|
||||
- [HIST based on pytorch (Wentao Xu, et al.2021)](examples/benchmarks/HIST/)
|
||||
|
||||
Your PR of new Quant models is highly welcomed.
|
||||
|
||||
@@ -384,6 +390,8 @@ Dataset plays a very important role in Quant. Here is a list of the datasets bui
|
||||
Your PR to build new Quant dataset is highly welcomed.
|
||||
|
||||
# More About Qlib
|
||||
If you want to have a quick glance at the most frequently used components of qlib, you can try notebooks [here](examples/tutorial/).
|
||||
|
||||
The detailed documents are organized in [docs](docs/).
|
||||
[Sphinx](http://www.sphinx-doc.org) and the readthedocs theme is required to build the documentation in html formats.
|
||||
```bash
|
||||
@@ -466,9 +474,13 @@ If you don't know how to start to contribute, you can refer to the following exa
|
||||
| Docs | [Improve docs quality](https://github.com/microsoft/qlib/pull/797/files) ; [Fix a typo](https://github.com/microsoft/qlib/pull/774) |
|
||||
| Feature | Implement a [requested feature](https://github.com/microsoft/qlib/projects) like [this](https://github.com/microsoft/qlib/pull/754); [Refactor interfaces](https://github.com/microsoft/qlib/pull/539/files) |
|
||||
| Dataset | [Add a dataset](https://github.com/microsoft/qlib/pull/733) |
|
||||
| Models | [Implement a new model](https://github.com/microsoft/qlib/pull/689) |
|
||||
| Models | [Implement a new model](https://github.com/microsoft/qlib/pull/689), [some instructions to contribute models](https://github.com/microsoft/qlib/tree/main/examples/benchmarks#contributing) |
|
||||
|
||||
If you would like to become one of Qlib's maintainers to contribute more (e.g. help merge PR, triage issues), please contact us by email([qlib@microsoft.com](mailto:qlib@microsoft.com)). We are glad to help you to set the right permission.
|
||||
[Good first issues](https://github.com/microsoft/qlib/labels/good%20first%20issue) are labelled to indicate that they are easy to start your contributions.
|
||||
|
||||
You can find some impefect implementation in Qlib by `rg 'TODO|FIXME' qlib`
|
||||
|
||||
If you would like to become one of Qlib's maintainers to contribute more (e.g. help merge PR, triage issues), please contact us by email([qlib@microsoft.com](mailto:qlib@microsoft.com)). We are glad to help to upgrade your permission.
|
||||
|
||||
## Licence
|
||||
Most contributions require you to agree to a
|
||||
|
||||
136
docs/advanced/PIT.rst
Normal file
136
docs/advanced/PIT.rst
Normal file
@@ -0,0 +1,136 @@
|
||||
.. _pit:
|
||||
|
||||
===========================
|
||||
(P)oint-(I)n-(T)ime Database
|
||||
===========================
|
||||
.. currentmodule:: qlib
|
||||
|
||||
|
||||
Introduction
|
||||
------------
|
||||
Point-in-time data is a very important consideration when performing any sort of historical market analysis.
|
||||
|
||||
For example, let’s say we are backtesting a trading strategy and we are using the past five years of historical data as our input.
|
||||
Our model is assumed to trade once a day, at the market close, and we’ll say we are calculating the trading signal for 1 January 2020 in our backtest. At that point, we should only have data for 1 January 2020, 31 December 2019, 30 December 2019 etc.
|
||||
|
||||
In financial data (especially financial reports), the same piece of data may be amended for multiple times overtime. If we only use the latest version for historical backtesting, data leakage will happen.
|
||||
Point-in-time database is designed for solving this problem to make sure user get the right version of data at any historical timestamp. It will keep the performance of online trading and historical backtesting the same.
|
||||
|
||||
|
||||
|
||||
Data Preparation
|
||||
----------------
|
||||
|
||||
Qlib provides a crawler to help users to download financial data and then a converter to dump the data in Qlib format.
|
||||
Please follow `scripts/data_collector/pit/README.md <https://github.com/microsoft/qlib/tree/main/scripts/data_collector/pit/>`_ to download and convert data.
|
||||
Besides, you can find some additional usage examples there.
|
||||
|
||||
|
||||
File-based design for PIT data
|
||||
------------------------------
|
||||
|
||||
Qlib provides a file-based storage for PIT data.
|
||||
|
||||
For each feature, it contains 4 columns, i.e. date, period, value, _next.
|
||||
Each row corresponds to a statement.
|
||||
|
||||
The meaning of each feature with filename like `XXX_a.data`:
|
||||
|
||||
- `date`: the statement's date of publication.
|
||||
- `period`: the period of the statement. (e.g. it will be quarterly frequency in most of the markets)
|
||||
- If it is an annual period, it will be an integer corresponding to the year
|
||||
- If it is an quarterly periods, it will be an integer like `<year><index of quarter>`. The last two decimal digits represents the index of quarter. Others represent the year.
|
||||
- `value`: the described value
|
||||
- `_next`: the byte index of the next occurance of the field.
|
||||
|
||||
Besides the feature data, an index `XXX_a.index` is included to speed up the querying performance
|
||||
|
||||
The statements are soted by the `date` in ascending order from the beginning of the file.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# the data format from XXXX.data
|
||||
array([(20070428, 200701, 0.090219 , 4294967295),
|
||||
(20070817, 200702, 0.13933 , 4294967295),
|
||||
(20071023, 200703, 0.24586301, 4294967295),
|
||||
(20080301, 200704, 0.3479 , 80),
|
||||
(20080313, 200704, 0.395989 , 4294967295),
|
||||
(20080422, 200801, 0.100724 , 4294967295),
|
||||
(20080828, 200802, 0.24996801, 4294967295),
|
||||
(20081027, 200803, 0.33412001, 4294967295),
|
||||
(20090325, 200804, 0.39011699, 4294967295),
|
||||
(20090421, 200901, 0.102675 , 4294967295),
|
||||
(20090807, 200902, 0.230712 , 4294967295),
|
||||
(20091024, 200903, 0.30072999, 4294967295),
|
||||
(20100402, 200904, 0.33546099, 4294967295),
|
||||
(20100426, 201001, 0.083825 , 4294967295),
|
||||
(20100812, 201002, 0.200545 , 4294967295),
|
||||
(20101029, 201003, 0.260986 , 4294967295),
|
||||
(20110321, 201004, 0.30739301, 4294967295),
|
||||
(20110423, 201101, 0.097411 , 4294967295),
|
||||
(20110831, 201102, 0.24825101, 4294967295),
|
||||
(20111018, 201103, 0.318919 , 4294967295),
|
||||
(20120323, 201104, 0.4039 , 420),
|
||||
(20120411, 201104, 0.403925 , 4294967295),
|
||||
(20120426, 201201, 0.112148 , 4294967295),
|
||||
(20120810, 201202, 0.26484701, 4294967295),
|
||||
(20121026, 201203, 0.370487 , 4294967295),
|
||||
(20130329, 201204, 0.45004699, 4294967295),
|
||||
(20130418, 201301, 0.099958 , 4294967295),
|
||||
(20130831, 201302, 0.21044201, 4294967295),
|
||||
(20131016, 201303, 0.30454299, 4294967295),
|
||||
(20140325, 201304, 0.394328 , 4294967295),
|
||||
(20140425, 201401, 0.083217 , 4294967295),
|
||||
(20140829, 201402, 0.16450299, 4294967295),
|
||||
(20141030, 201403, 0.23408499, 4294967295),
|
||||
(20150421, 201404, 0.319612 , 4294967295),
|
||||
(20150421, 201501, 0.078494 , 4294967295),
|
||||
(20150828, 201502, 0.137504 , 4294967295),
|
||||
(20151023, 201503, 0.201709 , 4294967295),
|
||||
(20160324, 201504, 0.26420501, 4294967295),
|
||||
(20160421, 201601, 0.073664 , 4294967295),
|
||||
(20160827, 201602, 0.136576 , 4294967295),
|
||||
(20161029, 201603, 0.188062 , 4294967295),
|
||||
(20170415, 201604, 0.244385 , 4294967295),
|
||||
(20170425, 201701, 0.080614 , 4294967295),
|
||||
(20170728, 201702, 0.15151 , 4294967295),
|
||||
(20171026, 201703, 0.25416601, 4294967295),
|
||||
(20180328, 201704, 0.32954201, 4294967295),
|
||||
(20180428, 201801, 0.088887 , 4294967295),
|
||||
(20180802, 201802, 0.170563 , 4294967295),
|
||||
(20181029, 201803, 0.25522 , 4294967295),
|
||||
(20190329, 201804, 0.34464401, 4294967295),
|
||||
(20190425, 201901, 0.094737 , 4294967295),
|
||||
(20190713, 201902, 0. , 1040),
|
||||
(20190718, 201902, 0.175322 , 4294967295),
|
||||
(20191016, 201903, 0.25581899, 4294967295)],
|
||||
dtype=[('date', '<u4'), ('period', '<u4'), ('value', '<f8'), ('_next', '<u4')])
|
||||
# - each row contains 20 byte
|
||||
|
||||
|
||||
# The data format from XXXX.index. It consists of two parts
|
||||
# 1) the start index of the data. So the first part of the info will be like
|
||||
2007
|
||||
# 2) the remain index data will be like information below
|
||||
# - The data indicate the **byte index** of first data update of a period.
|
||||
# - e.g. Because the info at both byte 80 and 100 corresponds to 200704. The byte index of first occurance (i.e. 100) is recorded in the data.
|
||||
array([ 0, 20, 40, 60, 100,
|
||||
120, 140, 160, 180, 200,
|
||||
220, 240, 260, 280, 300,
|
||||
320, 340, 360, 380, 400,
|
||||
440, 460, 480, 500, 520,
|
||||
540, 560, 580, 600, 620,
|
||||
640, 660, 680, 700, 720,
|
||||
740, 760, 780, 800, 820,
|
||||
840, 860, 880, 900, 920,
|
||||
940, 960, 980, 1000, 1020,
|
||||
1060, 4294967295], dtype=uint32)
|
||||
|
||||
|
||||
|
||||
|
||||
Known limitations:
|
||||
|
||||
- Currently, the PIT database is designed for quarterly or annually factors, which can handle fundamental data of financial reports in most markets.
|
||||
- Qlib leverage the file name to identify the type of the data. File with name like `XXX_q.data` corresponds to quarterly data. File with name like `XXX_a.data` corresponds to annual data.
|
||||
- The caclulation of PIT is not performed in the optimal way. There is great potential to boost the performance of PIT data calcuation.
|
||||
@@ -437,7 +437,7 @@ Dataset
|
||||
|
||||
The ``Dataset`` module in ``Qlib`` aims to prepare data for model training and inferencing.
|
||||
|
||||
The motivation of this module is that we want to maximize the flexibility of of different models to handle data that are suitable for themselves. This module gives the model the flexibility to process their data in an unique way. For instance, models such as ``GBDT`` may work well on data that contains `nan` or `None` value, while neural networks such as ``MLP`` will break down on such data.
|
||||
The motivation of this module is that we want to maximize the flexibility of different models to handle data that are suitable for themselves. This module gives the model the flexibility to process their data in an unique way. For instance, models such as ``GBDT`` may work well on data that contains `nan` or `None` value, while neural networks such as ``MLP`` will break down on such data.
|
||||
|
||||
If user's model need process its data in a different way, user could implement his own ``Dataset`` class. If the model's
|
||||
data processing is not special, ``DatasetH`` can be used directly.
|
||||
|
||||
@@ -143,3 +143,9 @@ Here is a simple exampke of what is done in ``PortAnaRecord``, which users can r
|
||||
print(analysis_df)
|
||||
|
||||
For more information about the APIs, please refer to `Record Template API <../reference/api.html#module-qlib.workflow.record_temp>`_.
|
||||
|
||||
|
||||
|
||||
Known Limitations
|
||||
=================
|
||||
- The Python objects are saved based on pickle, which may results in issues when the environment dumping objects and loading objects are different.
|
||||
|
||||
@@ -20,6 +20,9 @@ Introduction
|
||||
- model_performance_graph
|
||||
|
||||
|
||||
All of the accumulated profit metrics(e.g. return, max drawdown) in Qlib are calculated by summation.
|
||||
This avoids the metrics or the plots being skewed exponentially over time.
|
||||
|
||||
Graphical Reports
|
||||
===================
|
||||
|
||||
@@ -101,7 +104,7 @@ Graphical Result
|
||||
- Axis Y:
|
||||
- `ic`
|
||||
The `Pearson correlation coefficient` series between `label` and `prediction score`.
|
||||
In the above example, the `label` is formulated as `Ref($close, -1)/$close - 1`. Please refer to `Data Feature <data.html#feature>`_ for more details.
|
||||
In the above example, the `label` is formulated as `Ref($close, -2)/Ref($close, -1)-1`. Please refer to `Data Feature <data.html#feature>`_ for more details.
|
||||
|
||||
- `rank_ic`
|
||||
The `Spearman's rank correlation coefficient` series between `label` and `prediction score`.
|
||||
|
||||
@@ -24,11 +24,8 @@ BaseStrategy
|
||||
|
||||
Qlib provides a base class ``qlib.strategy.base.BaseStrategy``. All strategy classes need to inherit the base class and implement its interface.
|
||||
|
||||
- `get_risk_degree`
|
||||
Return the proportion of your total value you will use in investment. Dynamically risk_degree will result in Market timing.
|
||||
|
||||
- `generate_order_list`
|
||||
Return the order list.
|
||||
- `generate_trade_decision`
|
||||
generate_trade_decision is a key interface that generates trade decisions in each trading bar.
|
||||
The frequency to call this method depends on the executor frequency("time_per_step"="day" by default). But the trading frequency can be decided by users' implementation.
|
||||
For example, if the user wants to trading in weekly while the `time_per_step` is "day" in executor, user can return non-empty TradeDecision weekly(otherwise return empty like `this <https://github.com/microsoft/qlib/blob/main/qlib/contrib/strategy/signal_strategy.py#L132>`_ ).
|
||||
|
||||
@@ -69,18 +66,24 @@ TopkDropoutStrategy
|
||||
- Adopt the ``Topk-Drop`` algorithm to calculate the target amount of each stock
|
||||
|
||||
.. note::
|
||||
``Topk-Drop`` algorithm:
|
||||
There are two parameters for the ``Topk-Drop`` algorithm:
|
||||
|
||||
- `Topk`: The number of stocks held
|
||||
- `Drop`: The number of stocks sold on each trading day
|
||||
|
||||
Currently, the number of held stocks is `Topk`.
|
||||
On each trading day, the `Drop` number of held stocks with the worst `prediction score` will be sold, and the same number of unheld stocks with the best `prediction score` will be bought.
|
||||
|
||||
In general, the number of stocks currently held is `Topk`, with the exception of being zero at the beginning period of trading.
|
||||
For each trading day, let $d$ be the number of the instruments currently held and with a rank $\gt K$ when ranked by the prediction scores from high to low.
|
||||
Then `d` number of stocks currently held with the worst `prediction score` will be sold, and the same number of unheld stocks with the best `prediction score` will be bought.
|
||||
|
||||
In general, $d=$`Drop`, especially when the pool of the candidate instruments is large, $K$ is large, and `Drop` is small.
|
||||
|
||||
In most cases, ``TopkDrop`` algorithm sells and buys `Drop` stocks every trading day, which yields a turnover rate of 2$\times$`Drop`/$K$.
|
||||
|
||||
The following images illustrate a typical scenario.
|
||||
.. image:: ../_static/img/topk_drop.png
|
||||
:alt: Topk-Drop
|
||||
|
||||
``TopkDrop`` algorithm sells `Drop` stocks every trading day, which guarantees a fixed turnover rate.
|
||||
|
||||
|
||||
- Generate the order list from the target amount
|
||||
|
||||
@@ -164,12 +167,9 @@ Running backtest
|
||||
start_time="2017-01-01", end_time="2020-08-01", strategy=strategy_obj
|
||||
)
|
||||
analysis = dict()
|
||||
analysis["excess_return_without_cost"] = risk_analysis(
|
||||
report_normal["return"] - report_normal["bench"], freq=analysis_freq
|
||||
)
|
||||
analysis["excess_return_with_cost"] = risk_analysis(
|
||||
report_normal["return"] - report_normal["bench"] - report_normal["cost"], freq=analysis_freq
|
||||
)
|
||||
# default frequency will be daily (i.e. "day")
|
||||
analysis["excess_return_without_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"])
|
||||
analysis["excess_return_with_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"] - report_normal["cost"])
|
||||
|
||||
analysis_df = pd.concat(analysis) # type: pd.DataFrame
|
||||
pprint(analysis_df)
|
||||
|
||||
@@ -233,7 +233,7 @@ The meaning of each field is as follows:
|
||||
Dataset Section
|
||||
~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
The `dataset` field describes the parameters for the ``Dataset`` module in ``Qlib`` as well those for the module ``DataHandler``. For more information about the ``Dataset`` module, please refer to `Qlib Model <../component/data.html#dataset>`_.
|
||||
The `dataset` field describes the parameters for the ``Dataset`` module in ``Qlib`` as well those for the module ``DataHandler``. For more information about the ``Dataset`` module, please refer to `Qlib Data <../component/data.html#dataset>`_.
|
||||
|
||||
The keywords arguments configuration of the ``DataHandler`` is as follows:
|
||||
|
||||
@@ -248,7 +248,7 @@ The keywords arguments configuration of the ``DataHandler`` is as follows:
|
||||
|
||||
Users can refer to the document of `DataHandler <../component/data.html#datahandler>`_ for more information about the meaning of each field in the configuration.
|
||||
|
||||
Here is the configuration for the ``Dataset`` module which will take care of data preprossing and slicing during the training and testing phase.
|
||||
Here is the configuration for the ``Dataset`` module which will take care of data preprocessing and slicing during the training and testing phase.
|
||||
|
||||
.. code-block:: YAML
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ When you submit a PR request, you can check whether your code passes the CI test
|
||||
1. Qlib will check the code format with black. The PR will raise error if your code does not align to the standard of Qlib(e.g. a common error is the mixed use of space and tab).
|
||||
You can fix the bug by inputing the following code in the command line.
|
||||
|
||||
.. code-block:: python
|
||||
.. code-block:: bash
|
||||
|
||||
pip install black
|
||||
python -m black . -l 120
|
||||
@@ -30,3 +30,19 @@ When you submit a PR request, you can check whether your code passes the CI test
|
||||
|
||||
return -ICLoss()(pred, target, index) # pylint: disable=E1130
|
||||
|
||||
|
||||
3. Qlib will check your code style flake8. The checking command is implemented in [github action workflow](https://github.com/microsoft/qlib/blob/0e8b94a552f1c457cfa6cd2c1bb3b87ebb3fb279/.github/workflows/test.yml#L73).
|
||||
You can fix the bug by inputing the following code in the command line.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
flake8 --ignore E501,F541,E402,F401,W503,E741,E266,E203,E302,E731,E262,F523,F821,F811,F841,E713,E265,W291,E712,E722,W293 qlib
|
||||
|
||||
|
||||
4. Qlib has integrated pre-commit, which will make it easier for developers to format their code.
|
||||
Just run the following two commands, and the code will be automatically formatted using black and flake8 when the git commit command is executed.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install -e .[dev]
|
||||
pre-commit install
|
||||
@@ -53,6 +53,7 @@ Document Structure
|
||||
Online & Offline mode <advanced/server.rst>
|
||||
Serialization <advanced/serial.rst>
|
||||
Task Management <advanced/task_management.rst>
|
||||
Point-In-Time database <advanced/PIT.rst>
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 3
|
||||
|
||||
@@ -6,3 +6,4 @@
|
||||
|
||||
[https://www.ijcai.org/Proceedings/2017/0366.pdf](https://www.ijcai.org/Proceedings/2017/0366.pdf)
|
||||
|
||||
- NOTE: Current version of implementation is just a simplified version of ALSTM. It is an LSTM with attention.
|
||||
|
||||
3
examples/benchmarks/HIST/README.md
Normal file
3
examples/benchmarks/HIST/README.md
Normal file
@@ -0,0 +1,3 @@
|
||||
# HIST
|
||||
* Code: [https://github.com/Wentao-Xu/HIST](https://github.com/Wentao-Xu/HIST)
|
||||
* Paper: [HIST: A Graph-based Framework for Stock Trend Forecasting via Mining Concept-Oriented Shared InformationAdaRNN: Adaptive Learning and Forecasting for Time Series](https://arxiv.org/abs/2110.13716).
|
||||
BIN
examples/benchmarks/HIST/qlib_csi300_stock_index.npy
Normal file
BIN
examples/benchmarks/HIST/qlib_csi300_stock_index.npy
Normal file
Binary file not shown.
4
examples/benchmarks/HIST/requirements.txt
Normal file
4
examples/benchmarks/HIST/requirements.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
pandas==1.1.2
|
||||
numpy==1.21.0
|
||||
scikit_learn==0.23.2
|
||||
torch==1.7.0
|
||||
92
examples/benchmarks/HIST/workflow_config_hist_Alpha360.yaml
Normal file
92
examples/benchmarks/HIST/workflow_config_hist_Alpha360.yaml
Normal file
@@ -0,0 +1,92 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
start_time: 2017-01-01
|
||||
end_time: 2020-08-01
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
exchange_kwargs:
|
||||
limit_threshold: 0.095
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: HIST
|
||||
module_path: qlib.contrib.model.pytorch_hist
|
||||
kwargs:
|
||||
d_feat: 6
|
||||
hidden_size: 64
|
||||
num_layers: 2
|
||||
dropout: 0
|
||||
n_epochs: 200
|
||||
lr: 1e-4
|
||||
early_stop: 20
|
||||
metric: ic
|
||||
loss: mse
|
||||
base_model: LSTM
|
||||
model_path: "benchmarks/LSTM/model_lstm_csi300.pkl"
|
||||
stock2concept: "benchmarks/HIST/qlib_csi300_stock2concept.npy"
|
||||
stock_index: "benchmarks/HIST/qlib_csi300_stock_index.npy"
|
||||
GPU: 0
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha360
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
4
examples/benchmarks/IGMTF/README.md
Normal file
4
examples/benchmarks/IGMTF/README.md
Normal file
@@ -0,0 +1,4 @@
|
||||
# IGMTF
|
||||
* Code: [https://github.com/Wentao-Xu/IGMTF](https://github.com/Wentao-Xu/IGMTF)
|
||||
* Paper: [IGMTF: An Instance-wise Graph-based Framework for
|
||||
Multivariate Time Series Forecasting](https://arxiv.org/abs/2109.06489).
|
||||
4
examples/benchmarks/IGMTF/requirements.txt
Normal file
4
examples/benchmarks/IGMTF/requirements.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
pandas==1.1.2
|
||||
numpy==1.21.0
|
||||
scikit_learn==0.23.2
|
||||
torch==1.7.0
|
||||
@@ -0,0 +1,89 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
start_time: 2017-01-01
|
||||
end_time: 2020-08-01
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
exchange_kwargs:
|
||||
limit_threshold: 0.095
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: IGMTF
|
||||
module_path: qlib.contrib.model.pytorch_igmtf
|
||||
kwargs:
|
||||
d_feat: 6
|
||||
hidden_size: 64
|
||||
num_layers: 2
|
||||
dropout: 0
|
||||
n_epochs: 200
|
||||
lr: 1e-4
|
||||
early_stop: 20
|
||||
metric: ic
|
||||
loss: mse
|
||||
base_model: LSTM
|
||||
model_path: "benchmarks/LSTM/model_lstm_csi300.pkl"
|
||||
GPU: 0
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha360
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -65,6 +65,9 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
|
||||
| GATs (Petar Velickovic, et al.) | Alpha360 | 0.0476±0.00 | 0.3508±0.02 | 0.0598±0.00 | 0.4604±0.01 | 0.0824±0.02 | 1.1079±0.26 | -0.0894±0.03 |
|
||||
| TCTS(Xueqing Wu, et al.) | Alpha360 | 0.0508±0.00 | 0.3931±0.04 | 0.0599±0.00 | 0.4756±0.03 | 0.0893±0.03 | 1.2256±0.36 | -0.0857±0.02 |
|
||||
| TRA(Hengxu Lin, et al.) | Alpha360 | 0.0485±0.00 | 0.3787±0.03 | 0.0587±0.00 | 0.4756±0.03 | 0.0920±0.03 | 1.2789±0.42 | -0.0834±0.02 |
|
||||
| IGMTF(Wentao Xu, et al.) | Alpha360 | 0.0480±0.00 | 0.3589±0.02 | 0.0606±0.00 | 0.4773±0.01 | 0.0946±0.02 | 1.3509±0.25 | -0.0716±0.02 |
|
||||
| HIST(Wentao Xu, et al.) | Alpha360 | 0.0522±0.00 | 0.3530±0.01 | 0.0667±0.00 | 0.4576±0.01 | 0.0987±0.02 | 1.3726±0.27 | -0.0681±0.01 |
|
||||
|
||||
|
||||
- The selected 20 features are based on the feature importance of a lightgbm-based model.
|
||||
- The base model of DoubleEnsemble is LGBM.
|
||||
@@ -75,3 +78,20 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
|
||||
- The metrics can be categorized into two
|
||||
- Signal-based evaluation: IC, ICIR, Rank IC, Rank ICIR
|
||||
- Portfolio-based metrics: Annualized Return, Information Ratio, Max Drawdown
|
||||
|
||||
|
||||
# Contributing
|
||||
|
||||
Your contributions to new models are highly welcome!
|
||||
|
||||
If you want to contribute your new models, you can follow the steps below.
|
||||
1. Create a folder for your model
|
||||
2. The folder contains following items(you can refer to [this example](https://github.com/microsoft/qlib/tree/main/examples/benchmarks/TCTS)).
|
||||
- `requirements.txt`: required dependencies.
|
||||
- `README.md`: a brief introduction to your models
|
||||
- `workflow_config_<model name>_<dataset>.yaml`: a configuration which can read by `qrun`. You are encouraged to run your model in all datasets.
|
||||
3. You can integrate your model as a module [in this folder](https://github.com/microsoft/qlib/tree/main/qlib/contrib/model).
|
||||
4. Please updated your results in the benchmark tables, e.g. [Alpha360](#alpha158-dataset), [Alpha158](#alpha158-dataset)(the values of each metric are the mean and std calculated based on 20 runs with different random seeds, if you don't have enough computational resource, you can ask for help in the PR).
|
||||
5. Update the info in the index page in the [news list](https://github.com/microsoft/qlib#newspaper-whats-new----sparkling_heart) and [model list](https://github.com/microsoft/qlib#quant-model-paper-zoo).
|
||||
|
||||
Finally, you can send PR for review. ([here is an example](https://github.com/microsoft/qlib/pull/1040))
|
||||
|
||||
@@ -6,8 +6,7 @@ import torch
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from qlib.utils import init_instance_by_config
|
||||
from qlib.data.dataset import DatasetH, DataHandler
|
||||
from qlib.data.dataset import DatasetH
|
||||
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
@@ -95,7 +94,7 @@ class MTSDatasetH(DatasetH):
|
||||
shuffle=True,
|
||||
pin_memory=False,
|
||||
drop_last=False,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
assert horizon > 0, "please specify `horizon` to avoid data leakage"
|
||||
@@ -150,8 +149,15 @@ class MTSDatasetH(DatasetH):
|
||||
|
||||
def _prepare_seg(self, slc, **kwargs):
|
||||
fn = _get_date_parse_fn(self._index[0][1])
|
||||
start_date = fn(slc.start)
|
||||
end_date = fn(slc.stop)
|
||||
|
||||
if isinstance(slc, slice):
|
||||
start, stop = slc.start, slc.stop
|
||||
elif isinstance(slc, (list, tuple)):
|
||||
start, stop = slc
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
start_date = fn(start)
|
||||
end_date = fn(stop)
|
||||
obj = copy.copy(self) # shallow copy
|
||||
# NOTE: Seriable will disable copy `self._data` so we manually assign them here
|
||||
obj._data = self._data
|
||||
|
||||
@@ -9,13 +9,10 @@ from qlib.data.dataset.handler import DataHandlerLP
|
||||
import pandas as pd
|
||||
import fire
|
||||
import sys
|
||||
from tqdm.auto import tqdm
|
||||
import yaml
|
||||
import pickle
|
||||
from qlib import auto_init
|
||||
from qlib.model.trainer import TrainerR, task_train
|
||||
from qlib.model.trainer import TrainerR
|
||||
from qlib.utils import init_instance_by_config
|
||||
from qlib.workflow.task.gen import RollingGen, task_generator
|
||||
from qlib.workflow import R
|
||||
from qlib.tests.data import GetData
|
||||
|
||||
@@ -47,9 +44,10 @@ class DDGDA:
|
||||
rb = RollingBenchmark(model_type="gbdt")
|
||||
task = rb.basic_task()
|
||||
|
||||
model = init_instance_by_config(task["model"])
|
||||
dataset = init_instance_by_config(task["dataset"])
|
||||
model.fit(dataset)
|
||||
with R.start(experiment_name="feature_importance"):
|
||||
model = init_instance_by_config(task["model"])
|
||||
dataset = init_instance_by_config(task["dataset"])
|
||||
model.fit(dataset)
|
||||
|
||||
fi = model.get_feature_importance()
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
This example is about how can simulate the OnlineManager based on rolling tasks.
|
||||
This example is about how can simulate the OnlineManager based on rolling tasks.
|
||||
"""
|
||||
|
||||
from pprint import pprint
|
||||
@@ -15,6 +15,10 @@ from qlib.workflow.online.strategy import RollingStrategy
|
||||
from qlib.workflow.task.gen import RollingGen
|
||||
from qlib.workflow.task.manage import TaskManager
|
||||
from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG_ONLINE, CSI100_RECORD_XGBOOST_TASK_CONFIG_ONLINE
|
||||
import pandas as pd
|
||||
from qlib.contrib.evaluate import backtest_daily
|
||||
from qlib.contrib.evaluate import risk_analysis
|
||||
from qlib.contrib.strategy import TopkDropoutStrategy
|
||||
|
||||
|
||||
class OnlineSimulationExample:
|
||||
@@ -30,6 +34,7 @@ class OnlineSimulationExample:
|
||||
start_time="2018-09-10",
|
||||
end_time="2018-10-31",
|
||||
tasks=None,
|
||||
trainer="TrainerR",
|
||||
):
|
||||
"""
|
||||
Init OnlineManagerExample.
|
||||
@@ -60,7 +65,13 @@ class OnlineSimulationExample:
|
||||
self.rolling_gen = RollingGen(
|
||||
step=rolling_step, rtype=RollingGen.ROLL_SD, ds_extra_mod_func=None
|
||||
) # The rolling tasks generator, ds_extra_mod_func is None because we just need to simulate to 2018-10-31 and needn't change the handler end time.
|
||||
self.trainer = TrainerRM(self.exp_name, self.task_pool) # Also can be TrainerR, TrainerRM, DelayTrainerR
|
||||
if trainer == "TrainerRM":
|
||||
self.trainer = TrainerRM(self.exp_name, self.task_pool)
|
||||
elif trainer == "TrainerR":
|
||||
self.trainer = TrainerR(self.exp_name)
|
||||
else:
|
||||
# TODO: support all the trainers: TrainerR, TrainerRM, DelayTrainerR
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
self.rolling_online_manager = OnlineManager(
|
||||
RollingStrategy(exp_name, task_template=tasks, rolling_gen=self.rolling_gen),
|
||||
trainer=self.trainer,
|
||||
@@ -70,7 +81,8 @@ class OnlineSimulationExample:
|
||||
|
||||
# Reset all things to the first status, be careful to save important data
|
||||
def reset(self):
|
||||
TaskManager(self.task_pool).remove()
|
||||
if isinstance(self.trainer, TrainerRM):
|
||||
TaskManager(self.task_pool).remove()
|
||||
exp = R.get_exp(experiment_name=self.exp_name)
|
||||
for rid in exp.list_recorders():
|
||||
exp.delete_recorder(rid)
|
||||
@@ -84,7 +96,30 @@ class OnlineSimulationExample:
|
||||
print("========== collect results ==========")
|
||||
print(self.rolling_online_manager.get_collector()())
|
||||
print("========== signals ==========")
|
||||
print(self.rolling_online_manager.get_signals())
|
||||
signals = self.rolling_online_manager.get_signals()
|
||||
print(signals)
|
||||
# Backtesting
|
||||
# - the code is based on this example https://qlib.readthedocs.io/en/latest/component/strategy.html
|
||||
CSI300_BENCH = "SH000903"
|
||||
STRATEGY_CONFIG = {
|
||||
"topk": 30,
|
||||
"n_drop": 3,
|
||||
"signal": signals.to_frame("score"),
|
||||
}
|
||||
strategy_obj = TopkDropoutStrategy(**STRATEGY_CONFIG)
|
||||
report_normal, positions_normal = backtest_daily(
|
||||
start_time=signals.index.get_level_values("datetime").min(),
|
||||
end_time=signals.index.get_level_values("datetime").max(),
|
||||
strategy=strategy_obj,
|
||||
)
|
||||
analysis = dict()
|
||||
analysis["excess_return_without_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"])
|
||||
analysis["excess_return_with_cost"] = risk_analysis(
|
||||
report_normal["return"] - report_normal["bench"] - report_normal["cost"]
|
||||
)
|
||||
|
||||
analysis_df = pd.concat(analysis) # type: pd.DataFrame
|
||||
pprint(analysis_df)
|
||||
|
||||
def worker(self):
|
||||
# train tasks by other progress or machines for multiprocessing
|
||||
|
||||
1218
examples/tutorial/detailed_workflow.ipynb
Normal file
1218
examples/tutorial/detailed_workflow.ipynb
Normal file
File diff suppressed because it is too large
Load Diff
@@ -256,7 +256,6 @@
|
||||
"recorder = R.get_recorder(recorder_id=ba_rid, experiment_name=\"backtest_analysis\")\n",
|
||||
"print(recorder)\n",
|
||||
"pred_df = recorder.load_object(\"pred.pkl\")\n",
|
||||
"pred_df_dates = pred_df.index.get_level_values(level='datetime')\n",
|
||||
"report_normal_df = recorder.load_object(\"portfolio_analysis/report_normal_1day.pkl\")\n",
|
||||
"positions = recorder.load_object(\"portfolio_analysis/positions_normal_1day.pkl\")\n",
|
||||
"analysis_df = recorder.load_object(\"portfolio_analysis/port_analysis_1day.pkl\")"
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# Licensed under the MIT License.
|
||||
from pathlib import Path
|
||||
|
||||
__version__ = "0.8.4"
|
||||
__version__ = "0.8.5.99"
|
||||
__version__bak = __version__ # This version is backup for QlibConfig.reset_qlib_version
|
||||
import os
|
||||
from typing import Union
|
||||
@@ -12,6 +12,7 @@ import platform
|
||||
import subprocess
|
||||
from .log import get_module_logger
|
||||
|
||||
|
||||
# init qlib
|
||||
def init(default_conf="client", **kwargs):
|
||||
"""
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from __future__ import annotations
|
||||
import copy
|
||||
from typing import List, Tuple, Union, TYPE_CHECKING
|
||||
@@ -323,3 +324,6 @@ def format_decisions(
|
||||
last_dec_idx = i
|
||||
res[1].append((decisions[last_dec_idx], format_decisions(decisions[last_dec_idx + 1 :])))
|
||||
return res
|
||||
|
||||
|
||||
__all__ = ["Order"]
|
||||
|
||||
@@ -242,7 +242,7 @@ class BaseExecutor:
|
||||
if self.track_data:
|
||||
yield trade_decision
|
||||
|
||||
atomic = not issubclass(self.__class__, NestedExecutor) # issubclass(A, A) is True
|
||||
atomic = not issubclass(self.__class__, NestedExecutor) # issubclass(A, A) is True
|
||||
|
||||
if atomic and trade_decision.get_range_limit(default_value=None) is not None:
|
||||
raise ValueError("atomic executor doesn't support specify `range_limit`")
|
||||
|
||||
@@ -28,7 +28,6 @@ class Signal(metaclass=abc.ABCMeta):
|
||||
Union[pd.Series, pd.DataFrame, None]:
|
||||
returns None if no signal in the specific day
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class SignalWCache(Signal):
|
||||
|
||||
@@ -22,7 +22,7 @@ from pathlib import Path
|
||||
from typing import Callable, Optional, Union
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from qlib.constant import REG_CN, REG_US
|
||||
from qlib.constant import REG_CN, REG_US, REG_TW
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from qlib.utils.time import Freq
|
||||
@@ -92,6 +92,7 @@ _default_config = {
|
||||
"calendar_provider": "LocalCalendarProvider",
|
||||
"instrument_provider": "LocalInstrumentProvider",
|
||||
"feature_provider": "LocalFeatureProvider",
|
||||
"pit_provider": "LocalPITProvider",
|
||||
"expression_provider": "LocalExpressionProvider",
|
||||
"dataset_provider": "LocalDatasetProvider",
|
||||
"provider": "LocalProvider",
|
||||
@@ -108,7 +109,6 @@ _default_config = {
|
||||
"provider_uri": "",
|
||||
# cache
|
||||
"expression_cache": None,
|
||||
"dataset_cache": None,
|
||||
"calendar_cache": None,
|
||||
# for simple dataset cache
|
||||
"local_cache_path": None,
|
||||
@@ -171,6 +171,18 @@ _default_config = {
|
||||
"default_exp_name": "Experiment",
|
||||
},
|
||||
},
|
||||
"pit_record_type": {
|
||||
"date": "I", # uint32
|
||||
"period": "I", # uint32
|
||||
"value": "d", # float64
|
||||
"index": "I", # uint32
|
||||
},
|
||||
"pit_record_nan": {
|
||||
"date": 0,
|
||||
"period": 0,
|
||||
"value": float("NAN"),
|
||||
"index": 0xFFFFFFFF,
|
||||
},
|
||||
# Default config for MongoDB
|
||||
"mongo": {
|
||||
"task_url": "mongodb://localhost:27017/",
|
||||
@@ -184,20 +196,12 @@ _default_config = {
|
||||
|
||||
MODE_CONF = {
|
||||
"server": {
|
||||
# data provider config
|
||||
"calendar_provider": "LocalCalendarProvider",
|
||||
"instrument_provider": "LocalInstrumentProvider",
|
||||
"feature_provider": "LocalFeatureProvider",
|
||||
"expression_provider": "LocalExpressionProvider",
|
||||
"dataset_provider": "LocalDatasetProvider",
|
||||
"provider": "LocalProvider",
|
||||
# config it in qlib.init()
|
||||
"provider_uri": "",
|
||||
# redis
|
||||
"redis_host": "127.0.0.1",
|
||||
"redis_port": 6379,
|
||||
"redis_task_db": 1,
|
||||
"kernels": NUM_USABLE_CPU,
|
||||
# cache
|
||||
"expression_cache": DISK_EXPRESSION_CACHE,
|
||||
"dataset_cache": DISK_DATASET_CACHE,
|
||||
@@ -205,25 +209,15 @@ MODE_CONF = {
|
||||
"mount_path": None,
|
||||
},
|
||||
"client": {
|
||||
# data provider config
|
||||
"calendar_provider": "LocalCalendarProvider",
|
||||
"instrument_provider": "LocalInstrumentProvider",
|
||||
"feature_provider": "LocalFeatureProvider",
|
||||
"expression_provider": "LocalExpressionProvider",
|
||||
"dataset_provider": "LocalDatasetProvider",
|
||||
"provider": "LocalProvider",
|
||||
# config it in user's own code
|
||||
"provider_uri": "~/.qlib/qlib_data/cn_data",
|
||||
# cache
|
||||
# Using parameter 'remote' to announce the client is using server_cache, and the writing access will be disabled.
|
||||
# Disable cache by default. Avoid introduce advanced features for beginners
|
||||
"expression_cache": None,
|
||||
"dataset_cache": None,
|
||||
# SimpleDatasetCache directory
|
||||
"local_cache_path": Path("~/.cache/qlib_simple_cache").expanduser().resolve(),
|
||||
"calendar_cache": None,
|
||||
# client config
|
||||
"kernels": NUM_USABLE_CPU,
|
||||
"mount_path": None,
|
||||
"auto_mount": False, # The nfs is already mounted on our server[auto_mount: False].
|
||||
# The nfs should be auto-mounted by qlib on other
|
||||
@@ -257,6 +251,11 @@ _default_region_config = {
|
||||
"limit_threshold": None,
|
||||
"deal_price": "close",
|
||||
},
|
||||
REG_TW: {
|
||||
"trade_unit": 1000,
|
||||
"limit_threshold": 0.1,
|
||||
"deal_price": "close",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -4,6 +4,10 @@
|
||||
# REGION CONST
|
||||
REG_CN = "cn"
|
||||
REG_US = "us"
|
||||
REG_TW = "tw"
|
||||
|
||||
# Epsilon for avoiding division by zero.
|
||||
EPS = 1e-12
|
||||
|
||||
# Infinity in integer
|
||||
INF = 10**18
|
||||
|
||||
164
qlib/contrib/data/highfreq_handler.py
Normal file
164
qlib/contrib/data/highfreq_handler.py
Normal file
@@ -0,0 +1,164 @@
|
||||
from qlib.data.dataset.handler import DataHandler, DataHandlerLP
|
||||
|
||||
EPSILON = 1e-4
|
||||
|
||||
|
||||
class HighFreqHandler(DataHandlerLP):
|
||||
def __init__(
|
||||
self,
|
||||
instruments="csi300",
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
infer_processors=[],
|
||||
learn_processors=[],
|
||||
fit_start_time=None,
|
||||
fit_end_time=None,
|
||||
drop_raw=True,
|
||||
):
|
||||
def check_transform_proc(proc_l):
|
||||
new_l = []
|
||||
for p in proc_l:
|
||||
p["kwargs"].update(
|
||||
{
|
||||
"fit_start_time": fit_start_time,
|
||||
"fit_end_time": fit_end_time,
|
||||
}
|
||||
)
|
||||
new_l.append(p)
|
||||
return new_l
|
||||
|
||||
infer_processors = check_transform_proc(infer_processors)
|
||||
learn_processors = check_transform_proc(learn_processors)
|
||||
|
||||
data_loader = {
|
||||
"class": "QlibDataLoader",
|
||||
"kwargs": {
|
||||
"config": self.get_feature_config(),
|
||||
"swap_level": False,
|
||||
"freq": "1min",
|
||||
},
|
||||
}
|
||||
super().__init__(
|
||||
instruments=instruments,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
data_loader=data_loader,
|
||||
infer_processors=infer_processors,
|
||||
learn_processors=learn_processors,
|
||||
drop_raw=drop_raw,
|
||||
)
|
||||
|
||||
def get_feature_config(self):
|
||||
fields = []
|
||||
names = []
|
||||
|
||||
template_if = "If(IsNull({1}), {0}, {1})"
|
||||
template_paused = "Select(Gt($hx_paused_num, 1.001), {0})"
|
||||
|
||||
def get_normalized_price_feature(price_field, shift=0):
|
||||
# norm with the close price of 237th minute of yesterday.
|
||||
if shift == 0:
|
||||
template_norm = "{0}/DayLast(Ref({1}, 243))"
|
||||
else:
|
||||
template_norm = "Ref({0}, " + str(shift) + ")/DayLast(Ref({1}, 243))"
|
||||
|
||||
template_fillnan = "FFillNan({0})"
|
||||
# calculate -> ffill -> remove paused
|
||||
feature_ops = template_paused.format(
|
||||
template_fillnan.format(
|
||||
template_norm.format(template_if.format("$close", price_field), template_fillnan.format("$close"))
|
||||
)
|
||||
)
|
||||
return feature_ops
|
||||
|
||||
fields += [get_normalized_price_feature("$open", 0)]
|
||||
fields += [get_normalized_price_feature("$high", 0)]
|
||||
fields += [get_normalized_price_feature("$low", 0)]
|
||||
fields += [get_normalized_price_feature("$close", 0)]
|
||||
fields += [get_normalized_price_feature("$vwap", 0)]
|
||||
names += ["$open", "$high", "$low", "$close", "$vwap"]
|
||||
|
||||
fields += [get_normalized_price_feature("$open", 240)]
|
||||
fields += [get_normalized_price_feature("$high", 240)]
|
||||
fields += [get_normalized_price_feature("$low", 240)]
|
||||
fields += [get_normalized_price_feature("$close", 240)]
|
||||
fields += [get_normalized_price_feature("$vwap", 240)]
|
||||
names += ["$open_1", "$high_1", "$low_1", "$close_1", "$vwap_1"]
|
||||
|
||||
# calculate and fill nan with 0
|
||||
template_gzero = "If(Ge({0}, 0), {0}, 0)"
|
||||
fields += [
|
||||
template_gzero.format(
|
||||
template_paused.format(
|
||||
"If(IsNull({0}), 0, {0})".format("{0}/Ref(DayLast(Mean({0}, 7200)), 240)".format("$volume"))
|
||||
)
|
||||
)
|
||||
]
|
||||
names += ["$volume"]
|
||||
|
||||
fields += [
|
||||
template_gzero.format(
|
||||
template_paused.format(
|
||||
"If(IsNull({0}), 0, {0})".format(
|
||||
"Ref({0}, 240)/Ref(DayLast(Mean({0}, 7200)), 240)".format("$volume")
|
||||
)
|
||||
)
|
||||
)
|
||||
]
|
||||
names += ["$volume_1"]
|
||||
|
||||
return fields, names
|
||||
|
||||
|
||||
class HighFreqBacktestHandler(DataHandler):
|
||||
def __init__(
|
||||
self,
|
||||
instruments="csi300",
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
):
|
||||
data_loader = {
|
||||
"class": "QlibDataLoader",
|
||||
"kwargs": {
|
||||
"config": self.get_feature_config(),
|
||||
"swap_level": False,
|
||||
"freq": "1min",
|
||||
},
|
||||
}
|
||||
super().__init__(
|
||||
instruments=instruments,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
data_loader=data_loader,
|
||||
)
|
||||
|
||||
def get_feature_config(self):
|
||||
fields = []
|
||||
names = []
|
||||
|
||||
template_if = "If(IsNull({1}), {0}, {1})"
|
||||
template_paused = "Select(Gt($hx_paused_num, 1.001), {0})"
|
||||
# template_paused = "{0}"
|
||||
template_fillnan = "FFillNan({0})"
|
||||
fields += [
|
||||
template_fillnan.format(template_paused.format("$close")),
|
||||
]
|
||||
names += ["$close0"]
|
||||
|
||||
fields += [
|
||||
template_paused.format(
|
||||
template_if.format(
|
||||
template_fillnan.format("$close"),
|
||||
"$vwap",
|
||||
)
|
||||
)
|
||||
]
|
||||
names += ["$vwap0"]
|
||||
|
||||
fields += [template_paused.format("If(IsNull({0}), 0, {0})".format("$volume"))]
|
||||
names += ["$volume0"]
|
||||
|
||||
fields += [template_paused.format("If(IsNull({0}), 0, {0})".format("$factor"))]
|
||||
names += ["$factor0"]
|
||||
|
||||
return fields, names
|
||||
81
qlib/contrib/data/highfreq_processor.py
Normal file
81
qlib/contrib/data/highfreq_processor.py
Normal file
@@ -0,0 +1,81 @@
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from qlib.data.dataset.processor import Processor
|
||||
from qlib.data.dataset.utils import fetch_df_by_index
|
||||
from typing import Dict
|
||||
|
||||
|
||||
class HighFreqTrans(Processor):
|
||||
def __init__(self, dtype: str = "bool"):
|
||||
self.dtype = dtype
|
||||
|
||||
def fit(self, df_features):
|
||||
pass
|
||||
|
||||
def __call__(self, df_features):
|
||||
if self.dtype == "bool":
|
||||
return df_features.astype(np.int8)
|
||||
else:
|
||||
return df_features.astype(np.float32)
|
||||
|
||||
|
||||
class HighFreqNorm(Processor):
|
||||
def __init__(
|
||||
self,
|
||||
fit_start_time: pd.Timestamp,
|
||||
fit_end_time: pd.Timestamp,
|
||||
feature_save_dir: str,
|
||||
norm_groups: Dict[str, int],
|
||||
):
|
||||
|
||||
self.fit_start_time = fit_start_time
|
||||
self.fit_end_time = fit_end_time
|
||||
self.feature_save_dir = feature_save_dir
|
||||
self.norm_groups = norm_groups
|
||||
|
||||
def fit(self, df_features) -> None:
|
||||
if os.path.exists(self.feature_save_dir) and len(os.listdir(self.feature_save_dir)) != 0:
|
||||
return
|
||||
os.makedirs(self.feature_save_dir)
|
||||
fetch_df = fetch_df_by_index(df_features, slice(self.fit_start_time, self.fit_end_time), level="datetime")
|
||||
del df_features
|
||||
index = 0
|
||||
names = {}
|
||||
for name, dim in self.norm_groups.items():
|
||||
names[name] = slice(index, index + dim)
|
||||
index += dim
|
||||
for name, name_val in names.items():
|
||||
df_values = fetch_df.iloc(axis=1)[name_val].values
|
||||
if name.endswith("volume"):
|
||||
df_values = np.log1p(df_values)
|
||||
self.feature_mean = np.nanmean(df_values)
|
||||
np.save(self.feature_save_dir + name + "_mean.npy", self.feature_mean)
|
||||
df_values = df_values - self.feature_mean
|
||||
self.feature_std = np.nanstd(np.absolute(df_values))
|
||||
np.save(self.feature_save_dir + name + "_std.npy", self.feature_std)
|
||||
df_values = df_values / self.feature_std
|
||||
np.save(self.feature_save_dir + name + "_vmax.npy", np.nanmax(df_values))
|
||||
np.save(self.feature_save_dir + name + "_vmin.npy", np.nanmin(df_values))
|
||||
return
|
||||
|
||||
def __call__(self, df_features):
|
||||
if "date" in df_features:
|
||||
df_features.droplevel("date", inplace=True)
|
||||
df_values = df_features.values
|
||||
index = 0
|
||||
names = {}
|
||||
for name, dim in self.norm_groups.items():
|
||||
names[name] = slice(index, index + dim)
|
||||
index += dim
|
||||
for name, name_val in names.items():
|
||||
feature_mean = np.load(self.feature_save_dir + name + "_mean.npy")
|
||||
feature_std = np.load(self.feature_save_dir + name + "_std.npy")
|
||||
|
||||
if name.endswith("volume"):
|
||||
df_values[:, name_val] = np.log1p(df_values[:, name_val])
|
||||
df_values[:, name_val] -= feature_mean
|
||||
df_values[:, name_val] /= feature_std
|
||||
df_features = pd.DataFrame(data=df_values, index=df_features.index, columns=df_features.columns)
|
||||
return df_features.fillna(0)
|
||||
301
qlib/contrib/data/highfreq_provider.py
Normal file
301
qlib/contrib/data/highfreq_provider.py
Normal file
@@ -0,0 +1,301 @@
|
||||
import os
|
||||
import time
|
||||
import datetime
|
||||
from typing import Optional
|
||||
|
||||
import qlib
|
||||
from qlib.data import D
|
||||
from qlib.config import REG_CN
|
||||
from qlib.utils import init_instance_by_config
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
from qlib.data.data import Cal
|
||||
from qlib.contrib.ops.high_freq import get_calendar_day, DayLast, FFillNan, BFillNan, Date, Select, IsNull, IsInf, Cut
|
||||
import pickle as pkl
|
||||
from joblib import Parallel, delayed
|
||||
from utilsd.logging import print_log
|
||||
|
||||
|
||||
class HighFreqProvider:
|
||||
def __init__(
|
||||
self,
|
||||
start_time: str,
|
||||
end_time: str,
|
||||
train_end_time: str,
|
||||
valid_start_time: str,
|
||||
valid_end_time: str,
|
||||
test_start_time: str,
|
||||
qlib_conf: dict,
|
||||
feature_conf: dict,
|
||||
label_conf: Optional[dict] = None,
|
||||
backtest_conf: dict = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
self.start_time = start_time
|
||||
self.end_time = end_time
|
||||
self.test_start_time = test_start_time
|
||||
self.train_end_time = train_end_time
|
||||
self.valid_start_time = valid_start_time
|
||||
self.valid_end_time = valid_end_time
|
||||
self._init_qlib(qlib_conf)
|
||||
self.feature_conf = feature_conf
|
||||
self.label_conf = label_conf
|
||||
self.backtest_conf = backtest_conf
|
||||
self.qlib_conf = qlib_conf
|
||||
|
||||
def get_pre_datasets(self):
|
||||
"""Generate the training, validation and test datasets for prediction
|
||||
|
||||
Returns:
|
||||
Tuple[BaseDataset, BaseDataset, BaseDataset]: The training and test datasets
|
||||
"""
|
||||
|
||||
dict_feature_path = self.feature_conf["path"]
|
||||
train_feature_path = dict_feature_path[:-4] + "_train.pkl"
|
||||
valid_feature_path = dict_feature_path[:-4] + "_valid.pkl"
|
||||
test_feature_path = dict_feature_path[:-4] + "_test.pkl"
|
||||
|
||||
dict_label_path = self.label_conf["path"]
|
||||
train_label_path = dict_label_path[:-4] + "_train.pkl"
|
||||
valid_label_path = dict_label_path[:-4] + "_valid.pkl"
|
||||
test_label_path = dict_label_path[:-4] + "_test.pkl"
|
||||
|
||||
if (
|
||||
not os.path.isfile(train_feature_path)
|
||||
or not os.path.isfile(valid_feature_path)
|
||||
or not os.path.isfile(test_feature_path)
|
||||
):
|
||||
xtrain, xvalid, xtest = self._gen_data(self.feature_conf)
|
||||
xtrain.to_pickle(train_feature_path)
|
||||
xvalid.to_pickle(valid_feature_path)
|
||||
xtest.to_pickle(test_feature_path)
|
||||
del xtrain, xvalid, xtest
|
||||
|
||||
if (
|
||||
not os.path.isfile(train_label_path)
|
||||
or not os.path.isfile(valid_label_path)
|
||||
or not os.path.isfile(test_label_path)
|
||||
):
|
||||
ytrain, yvalid, ytest = self._gen_data(self.label_conf)
|
||||
ytrain.to_pickle(train_label_path)
|
||||
yvalid.to_pickle(valid_label_path)
|
||||
ytest.to_pickle(test_label_path)
|
||||
del ytrain, yvalid, ytest
|
||||
|
||||
feature = {
|
||||
"train": train_feature_path,
|
||||
"valid": valid_feature_path,
|
||||
"test": test_feature_path,
|
||||
}
|
||||
|
||||
label = {
|
||||
"train": train_label_path,
|
||||
"valid": valid_label_path,
|
||||
"test": test_label_path,
|
||||
}
|
||||
|
||||
return feature, label
|
||||
|
||||
def get_backtest(self, **kwargs) -> None:
|
||||
self._gen_data(self.backtest_conf)
|
||||
|
||||
def _init_qlib(self, qlib_conf):
|
||||
"""initialize qlib"""
|
||||
|
||||
qlib.init(
|
||||
region=REG_CN,
|
||||
auto_mount=False,
|
||||
custom_ops=[DayLast, FFillNan, BFillNan, Date, Select, IsNull, IsInf, Cut],
|
||||
expression_cache=None,
|
||||
**qlib_conf,
|
||||
)
|
||||
|
||||
def _prepare_calender_cache(self):
|
||||
"""preload the calendar for cache"""
|
||||
|
||||
# This code used the copy-on-write feature of Linux
|
||||
# to avoid calculating the calendar multiple times in the subprocess.
|
||||
# This code may accelerate, but may be not useful on Windows and Mac Os
|
||||
Cal.calendar(freq="1min")
|
||||
get_calendar_day(freq="1min")
|
||||
|
||||
def _gen_dataframe(self, config, datasets=["train", "valid", "test"]):
|
||||
try:
|
||||
path = config.pop("path")
|
||||
except KeyError as e:
|
||||
raise ValueError("Must specify the path to save the dataset.") from e
|
||||
if os.path.isfile(path):
|
||||
start = time.time()
|
||||
print_log("Dataset exists, load from disk.", __name__)
|
||||
|
||||
# res = dataset.prepare(['train', 'valid', 'test'])
|
||||
with open(path, "rb") as f:
|
||||
data = pkl.load(f)
|
||||
if isinstance(data, dict):
|
||||
res = [data[i] for i in datasets]
|
||||
else:
|
||||
res = data.prepare(datasets)
|
||||
print_log(f"Data loaded, time cost: {time.time() - start:.2f}", __name__)
|
||||
else:
|
||||
if not os.path.exists(os.path.dirname(path)):
|
||||
os.makedirs(os.path.dirname(path))
|
||||
print_log("Generating dataset", __name__)
|
||||
start_time = time.time()
|
||||
self._prepare_calender_cache()
|
||||
dataset = init_instance_by_config(config)
|
||||
trainset, validset, testset = dataset.prepare(["train", "valid", "test"])
|
||||
data = {
|
||||
"train": trainset,
|
||||
"valid": validset,
|
||||
"test": testset,
|
||||
}
|
||||
with open(path, "wb") as f:
|
||||
pkl.dump(data, f)
|
||||
with open(path[:-4] + "train.pkl", "wb") as f:
|
||||
pkl.dump(trainset, f)
|
||||
with open(path[:-4] + "valid.pkl", "wb") as f:
|
||||
pkl.dump(validset, f)
|
||||
with open(path[:-4] + "test.pkl", "wb") as f:
|
||||
pkl.dump(testset, f)
|
||||
res = [data[i] for i in datasets]
|
||||
print_log(f"Data generated, time cost: {(time.time() - start_time):.2f}", __name__)
|
||||
return res
|
||||
|
||||
def _gen_data(self, config, datasets=["train", "valid", "test"]):
|
||||
try:
|
||||
path = config.pop("path")
|
||||
except KeyError as e:
|
||||
raise ValueError("Must specify the path to save the dataset.") from e
|
||||
if os.path.isfile(path):
|
||||
start = time.time()
|
||||
print_log("Dataset exists, load from disk.", __name__)
|
||||
|
||||
# res = dataset.prepare(['train', 'valid', 'test'])
|
||||
with open(path, "rb") as f:
|
||||
data = pkl.load(f)
|
||||
if isinstance(data, dict):
|
||||
res = [data[i] for i in datasets]
|
||||
else:
|
||||
res = data.prepare(datasets)
|
||||
print_log(f"Data loaded, time cost: {time.time() - start:.2f}", __name__)
|
||||
else:
|
||||
if not os.path.exists(os.path.dirname(path)):
|
||||
os.makedirs(os.path.dirname(path))
|
||||
print_log("Generating dataset", __name__)
|
||||
start_time = time.time()
|
||||
self._prepare_calender_cache()
|
||||
dataset = init_instance_by_config(config)
|
||||
dataset.config(dump_all=True, recursive=True)
|
||||
dataset.to_pickle(path)
|
||||
res = dataset.prepare(datasets)
|
||||
print_log(f"Data generated, time cost: {(time.time() - start_time):.2f}", __name__)
|
||||
return res
|
||||
|
||||
def _gen_dataset(self, config):
|
||||
try:
|
||||
path = config.pop("path")
|
||||
except KeyError as e:
|
||||
raise ValueError("Must specify the path to save the dataset.") from e
|
||||
if os.path.isfile(path):
|
||||
start = time.time()
|
||||
print_log("Dataset exists, load from disk.", __name__)
|
||||
|
||||
with open(path, "rb") as f:
|
||||
dataset = pkl.load(f)
|
||||
print_log(f"Data loaded, time cost: {time.time() - start:.2f}", __name__)
|
||||
else:
|
||||
start = time.time()
|
||||
if not os.path.exists(os.path.dirname(path)):
|
||||
os.makedirs(os.path.dirname(path))
|
||||
print_log("Generating dataset", __name__)
|
||||
self._prepare_calender_cache()
|
||||
dataset = init_instance_by_config(config)
|
||||
print_log(f"Dataset init, time cost: {time.time() - start:.2f}", __name__)
|
||||
dataset.prepare(["train", "valid", "test"])
|
||||
print_log(f"Dataset prepared, time cost: {time.time() - start:.2f}", __name__)
|
||||
dataset.config(dump_all=True, recursive=True)
|
||||
dataset.to_pickle(path)
|
||||
return dataset
|
||||
|
||||
def _gen_day_dataset(self, config, conf_type):
|
||||
try:
|
||||
path = config.pop("path")
|
||||
except KeyError as e:
|
||||
raise ValueError("Must specify the path to save the dataset.") from e
|
||||
|
||||
if os.path.isfile(path + "tmp_dataset.pkl"):
|
||||
start = time.time()
|
||||
print_log("Dataset exists, load from disk.", __name__)
|
||||
else:
|
||||
start = time.time()
|
||||
if not os.path.exists(os.path.dirname(path)):
|
||||
os.makedirs(os.path.dirname(path))
|
||||
print_log("Generating dataset", __name__)
|
||||
self._prepare_calender_cache()
|
||||
dataset = init_instance_by_config(config)
|
||||
print_log(f"Dataset init, time cost: {time.time() - start:.2f}", __name__)
|
||||
dataset.config(dump_all=False, recursive=True)
|
||||
dataset.to_pickle(path + "tmp_dataset.pkl")
|
||||
|
||||
with open(path + "tmp_dataset.pkl", "rb") as f:
|
||||
new_dataset = pkl.load(f)
|
||||
|
||||
time_list = D.calendar(start_time=self.start_time, end_time=self.end_time, freq="1min")[::240]
|
||||
|
||||
def generate_dataset(times):
|
||||
if os.path.isfile(path + times.strftime("%Y-%m-%d") + ".pkl"):
|
||||
print("exist " + times.strftime("%Y-%m-%d"))
|
||||
return
|
||||
self._init_qlib(self.qlib_conf)
|
||||
end_times = times + datetime.timedelta(days=1)
|
||||
new_dataset.handler.config(**{"start_time": times, "end_time": end_times})
|
||||
if conf_type == "backtest":
|
||||
new_dataset.handler.setup_data()
|
||||
else:
|
||||
new_dataset.handler.setup_data(init_type=DataHandlerLP.IT_LS)
|
||||
new_dataset.config(dump_all=True, recursive=True)
|
||||
new_dataset.to_pickle(path + times.strftime("%Y-%m-%d") + ".pkl")
|
||||
|
||||
Parallel(n_jobs=8)(delayed(generate_dataset)(times) for times in time_list)
|
||||
|
||||
def _gen_stock_dataset(self, config, conf_type):
|
||||
try:
|
||||
path = config.pop("path")
|
||||
except KeyError as e:
|
||||
raise ValueError("Must specify the path to save the dataset.") from e
|
||||
|
||||
if os.path.isfile(path + "tmp_dataset.pkl"):
|
||||
start = time.time()
|
||||
print_log("Dataset exists, load from disk.", __name__)
|
||||
else:
|
||||
start = time.time()
|
||||
if not os.path.exists(os.path.dirname(path)):
|
||||
os.makedirs(os.path.dirname(path))
|
||||
print_log("Generating dataset", __name__)
|
||||
self._prepare_calender_cache()
|
||||
dataset = init_instance_by_config(config)
|
||||
print_log(f"Dataset init, time cost: {time.time() - start:.2f}", __name__)
|
||||
dataset.config(dump_all=False, recursive=True)
|
||||
dataset.to_pickle(path + "tmp_dataset.pkl")
|
||||
|
||||
with open(path + "tmp_dataset.pkl", "rb") as f:
|
||||
new_dataset = pkl.load(f)
|
||||
|
||||
instruments = D.instruments(market="all")
|
||||
stock_list = D.list_instruments(
|
||||
instruments=instruments, start_time=self.start_time, end_time=self.end_time, freq="1min", as_list=True
|
||||
)
|
||||
|
||||
def generate_dataset(stock):
|
||||
if os.path.isfile(path + stock + ".pkl"):
|
||||
print("exist " + stock)
|
||||
return
|
||||
self._init_qlib(self.qlib_conf)
|
||||
new_dataset.handler.config(**{"instruments": [stock]})
|
||||
if conf_type == "backtest":
|
||||
new_dataset.handler.setup_data()
|
||||
else:
|
||||
new_dataset.handler.setup_data(init_type=DataHandlerLP.IT_LS)
|
||||
new_dataset.config(dump_all=True, recursive=True)
|
||||
new_dataset.to_pickle(path + stock + ".pkl")
|
||||
|
||||
Parallel(n_jobs=32)(delayed(generate_dataset)(stock) for stock in stock_list)
|
||||
@@ -1,7 +1,7 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
import pandas as pd
|
||||
from typing import Dict, Iterable
|
||||
from typing import Dict, Iterable, Union
|
||||
|
||||
|
||||
def align_index(df_dict, join):
|
||||
@@ -24,6 +24,10 @@ class SepDataFrame:
|
||||
SepDataFrame tries to act like a DataFrame whose column with multiindex
|
||||
"""
|
||||
|
||||
# TODO:
|
||||
# SepDataFrame try to behave like pandas dataframe, but it is still not them same
|
||||
# Contributions are welcome to make it more complete.
|
||||
|
||||
def __init__(self, df_dict: Dict[str, pd.DataFrame], join: str, skip_align=False):
|
||||
"""
|
||||
initialize the data based on the dataframe dictionary
|
||||
@@ -77,14 +81,37 @@ class SepDataFrame:
|
||||
|
||||
def _update_join(self):
|
||||
if self.join not in self:
|
||||
self.join = next(iter(self._df_dict.keys()))
|
||||
if len(self._df_dict) > 0:
|
||||
self.join = next(iter(self._df_dict.keys()))
|
||||
else:
|
||||
# NOTE: this will change the behavior of previous reindex when all the keys are empty
|
||||
self.join = None
|
||||
|
||||
def __getitem__(self, item):
|
||||
# TODO: behave more like pandas when multiindex
|
||||
return self._df_dict[item]
|
||||
|
||||
def __setitem__(self, item: str, df: pd.DataFrame):
|
||||
def __setitem__(self, item: str, df: Union[pd.DataFrame, pd.Series]):
|
||||
# TODO: consider the join behavior
|
||||
self._df_dict[item] = df
|
||||
if not isinstance(item, tuple):
|
||||
self._df_dict[item] = df
|
||||
else:
|
||||
# NOTE: corner case of MultiIndex
|
||||
_df_dict_key, *col_name = item
|
||||
col_name = tuple(col_name)
|
||||
if _df_dict_key in self._df_dict:
|
||||
if len(col_name) == 1:
|
||||
col_name = col_name[0]
|
||||
self._df_dict[_df_dict_key][col_name] = df
|
||||
else:
|
||||
if isinstance(df, pd.Series):
|
||||
if len(col_name) == 1:
|
||||
col_name = col_name[0]
|
||||
self._df_dict[_df_dict_key] = df.to_frame(col_name)
|
||||
else:
|
||||
df_copy = df.copy() # avoid changing df
|
||||
df_copy.columns = pd.MultiIndex.from_tuples([(*col_name, *idx) for idx in df.columns.to_list()])
|
||||
self._df_dict[_df_dict_key] = df_copy
|
||||
|
||||
def __delitem__(self, item: str):
|
||||
del self._df_dict[item]
|
||||
@@ -164,14 +191,14 @@ import builtins
|
||||
|
||||
|
||||
def _isinstance(instance, cls):
|
||||
if isinstance_orig(instance, SepDataFrame): # pylint: disable=E0602
|
||||
if isinstance_orig(instance, SepDataFrame): # pylint: disable=E0602 # noqa: F821
|
||||
if isinstance(cls, Iterable):
|
||||
for c in cls:
|
||||
if c is pd.DataFrame:
|
||||
return True
|
||||
elif cls is pd.DataFrame:
|
||||
return True
|
||||
return isinstance_orig(instance, cls) # pylint: disable=E0602
|
||||
return isinstance_orig(instance, cls) # pylint: disable=E0602 # noqa: F821
|
||||
|
||||
|
||||
builtins.isinstance_orig = builtins.isinstance
|
||||
|
||||
@@ -123,7 +123,7 @@ def pred_autocorr(pred: pd.Series, lag=1, inst_col="instrument", date_col="datet
|
||||
"""
|
||||
if isinstance(pred, pd.DataFrame):
|
||||
pred = pred.iloc[:, 0]
|
||||
get_module_logger("pred_autocorr").warning("Only the first column in {pred.columns} of `pred` is kept")
|
||||
get_module_logger("pred_autocorr").warning(f"Only the first column in {pred.columns} of `pred` is kept")
|
||||
pred_ustk = pred.sort_index().unstack(inst_col)
|
||||
corr_s = {}
|
||||
for (idx, cur), (_, prev) in zip(pred_ustk.iterrows(), pred_ustk.shift(lag).iterrows()):
|
||||
|
||||
@@ -2,3 +2,6 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from .data_selection import MetaTaskDS, MetaDatasetDS, MetaModelDS
|
||||
|
||||
|
||||
__all__ = ["MetaTaskDS", "MetaDatasetDS", "MetaModelDS"]
|
||||
|
||||
@@ -3,3 +3,6 @@
|
||||
|
||||
from .dataset import MetaDatasetDS, MetaTaskDS
|
||||
from .model import MetaModelDS
|
||||
|
||||
|
||||
__all__ = ["MetaDatasetDS", "MetaTaskDS", "MetaModelDS"]
|
||||
|
||||
@@ -10,7 +10,6 @@ from tqdm.auto import tqdm
|
||||
import copy
|
||||
from typing import Union, List
|
||||
|
||||
from ....data.dataset.weight import Reweighter
|
||||
from ....model.meta.dataset import MetaTaskDataset
|
||||
from ....model.meta.model import MetaTaskModel
|
||||
from ....workflow import R
|
||||
@@ -18,8 +17,8 @@ from .utils import ICLoss
|
||||
from .dataset import MetaDatasetDS
|
||||
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.data.dataset.weight import Reweighter
|
||||
from qlib.model.meta.task import MetaTask
|
||||
from qlib.data.dataset.weight import Reweighter
|
||||
from qlib.contrib.meta.data_selection.net import PredNet
|
||||
|
||||
logger = get_module_logger("data selection")
|
||||
@@ -98,7 +97,6 @@ class MetaModelDS(MetaTaskModel):
|
||||
|
||||
if phase == "train":
|
||||
opt.zero_grad()
|
||||
norm_loss = nn.MSELoss()
|
||||
loss.backward()
|
||||
opt.step()
|
||||
elif phase == "test":
|
||||
|
||||
@@ -10,17 +10,19 @@ try:
|
||||
from .gbdt import LGBModel
|
||||
except ModuleNotFoundError:
|
||||
DEnsembleModel, LGBModel = None, None
|
||||
print("Please install necessary libs for DEnsembleModel and LGBModel, such as lightgbm.")
|
||||
print(
|
||||
"ModuleNotFoundError. DEnsembleModel and LGBModel are skipped. (optional: maybe installing lightgbm can fix it.)"
|
||||
)
|
||||
try:
|
||||
from .xgboost import XGBModel
|
||||
except ModuleNotFoundError:
|
||||
XGBModel = None
|
||||
print("Please install necessary libs for XGBModel, such as xgboost.")
|
||||
print("ModuleNotFoundError. XGBModel is skipped(optional: maybe installing xgboost can fix it).")
|
||||
try:
|
||||
from .linear import LinearModel
|
||||
except ModuleNotFoundError:
|
||||
LinearModel = None
|
||||
print("Please install necessary libs for LinearModel, such as scipy and sklearn.")
|
||||
print("ModuleNotFoundError. LinearModel is skipped(optional: maybe installing scipy and sklearn can fix it).")
|
||||
# import pytorch models
|
||||
try:
|
||||
from .pytorch_alstm import ALSTM
|
||||
@@ -36,6 +38,6 @@ try:
|
||||
pytorch_classes = (ALSTM, GATs, GRU, LSTM, DNNModelPytorch, TabnetModel, SFM_Model, TCN, ADD)
|
||||
except ModuleNotFoundError:
|
||||
pytorch_classes = ()
|
||||
print("Please install necessary libs for PyTorch models.")
|
||||
print("ModuleNotFoundError. PyTorch models are skipped (optional: maybe installing pytorch can fix it).")
|
||||
|
||||
all_model_classes = (CatBoostModel, DEnsembleModel, LGBModel, XGBModel, LinearModel) + pytorch_classes
|
||||
|
||||
@@ -249,7 +249,7 @@ class DEnsembleModel(Model, FeatureInt):
|
||||
return pred
|
||||
|
||||
def predict_sub(self, submodel, df_data, features):
|
||||
x_data, y_data = df_data["feature"].loc[:, features], df_data["label"]
|
||||
x_data = df_data["feature"].loc[:, features]
|
||||
pred_sub = pd.Series(submodel.predict(x_data.values), index=x_data.index)
|
||||
return pred_sub
|
||||
|
||||
|
||||
@@ -68,17 +68,19 @@ class LGBModel(ModelFT, LightGBMFInt):
|
||||
evals_result = {} # in case of unsafety of Python default values
|
||||
ds_l = self._prepare_data(dataset, reweighter)
|
||||
ds, names = list(zip(*ds_l))
|
||||
early_stopping_callback = lgb.early_stopping(
|
||||
self.early_stopping_rounds if early_stopping_rounds is None else early_stopping_rounds
|
||||
)
|
||||
# NOTE: if you encounter error here. Please upgrade your lightgbm
|
||||
verbose_eval_callback = lgb.log_evaluation(period=verbose_eval)
|
||||
evals_result_callback = lgb.record_evaluation(evals_result)
|
||||
self.model = lgb.train(
|
||||
self.params,
|
||||
ds[0], # training dataset
|
||||
num_boost_round=self.num_boost_round if num_boost_round is None else num_boost_round,
|
||||
valid_sets=ds,
|
||||
valid_names=names,
|
||||
early_stopping_rounds=(
|
||||
self.early_stopping_rounds if early_stopping_rounds is None else early_stopping_rounds
|
||||
),
|
||||
verbose_eval=verbose_eval,
|
||||
evals_result=evals_result,
|
||||
callbacks=[early_stopping_callback, verbose_eval_callback, evals_result_callback],
|
||||
**kwargs,
|
||||
)
|
||||
for k in names:
|
||||
@@ -110,6 +112,7 @@ class LGBModel(ModelFT, LightGBMFInt):
|
||||
dtrain, _ = self._prepare_data(dataset, reweighter) # pylint: disable=W0632
|
||||
if dtrain.empty:
|
||||
raise ValueError("Empty data from dataset, please check your dataset config.")
|
||||
verbose_eval_callback = lgb.log_evaluation(period=verbose_eval)
|
||||
self.model = lgb.train(
|
||||
self.params,
|
||||
dtrain,
|
||||
@@ -117,5 +120,5 @@ class LGBModel(ModelFT, LightGBMFInt):
|
||||
init_model=self.model,
|
||||
valid_sets=[dtrain],
|
||||
valid_names=["train"],
|
||||
verbose_eval=verbose_eval,
|
||||
callbacks=[verbose_eval_callback],
|
||||
)
|
||||
|
||||
@@ -110,18 +110,21 @@ class HFLGBModel(ModelFT, LightGBMFInt):
|
||||
num_boost_round=1000,
|
||||
early_stopping_rounds=50,
|
||||
verbose_eval=20,
|
||||
evals_result=dict(),
|
||||
evals_result=None,
|
||||
):
|
||||
if evals_result is None:
|
||||
evals_result = dict()
|
||||
dtrain, dvalid = self._prepare_data(dataset)
|
||||
early_stopping_callback = lgb.early_stopping(early_stopping_rounds)
|
||||
verbose_eval_callback = lgb.log_evaluation(period=verbose_eval)
|
||||
evals_result_callback = lgb.record_evaluation(evals_result)
|
||||
self.model = lgb.train(
|
||||
self.params,
|
||||
dtrain,
|
||||
num_boost_round=num_boost_round,
|
||||
valid_sets=[dtrain, dvalid],
|
||||
valid_names=["train", "valid"],
|
||||
early_stopping_rounds=early_stopping_rounds,
|
||||
verbose_eval=verbose_eval,
|
||||
evals_result=evals_result,
|
||||
callbacks=[early_stopping_callback, verbose_eval_callback, evals_result_callback],
|
||||
)
|
||||
evals_result["train"] = list(evals_result["train"].values())[0]
|
||||
evals_result["valid"] = list(evals_result["valid"].values())[0]
|
||||
@@ -147,6 +150,7 @@ class HFLGBModel(ModelFT, LightGBMFInt):
|
||||
"""
|
||||
# Based on existing model and finetune by train more rounds
|
||||
dtrain, _ = self._prepare_data(dataset)
|
||||
verbose_eval_callback = lgb.log_evaluation(period=verbose_eval)
|
||||
self.model = lgb.train(
|
||||
self.params,
|
||||
dtrain,
|
||||
@@ -154,5 +158,5 @@ class HFLGBModel(ModelFT, LightGBMFInt):
|
||||
init_model=self.model,
|
||||
valid_sets=[dtrain],
|
||||
valid_names=["train"],
|
||||
verbose_eval=verbose_eval,
|
||||
callbacks=[verbose_eval_callback],
|
||||
)
|
||||
|
||||
@@ -144,7 +144,7 @@ class ADARNN(Model):
|
||||
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
|
||||
|
||||
self.fitted = False
|
||||
self.model.cuda()
|
||||
self.model.to(self.device)
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
@@ -153,7 +153,7 @@ class ADARNN(Model):
|
||||
def train_AdaRNN(self, train_loader_list, epoch, dist_old=None, weight_mat=None):
|
||||
self.model.train()
|
||||
criterion = nn.MSELoss()
|
||||
dist_mat = torch.zeros(self.num_layers, self.len_seq).cuda()
|
||||
dist_mat = torch.zeros(self.num_layers, self.len_seq).to(self.device)
|
||||
len_loader = np.inf
|
||||
for loader in train_loader_list:
|
||||
if len(loader) < len_loader:
|
||||
@@ -165,7 +165,7 @@ class ADARNN(Model):
|
||||
list_label = []
|
||||
for data in data_all:
|
||||
# feature :[36, 24, 6]
|
||||
feature, label_reg = data[0].cuda().float(), data[1].cuda().float()
|
||||
feature, label_reg = data[0].to(self.device).float(), data[1].to(self.device).float()
|
||||
list_feat.append(feature)
|
||||
list_label.append(label_reg)
|
||||
flag = False
|
||||
@@ -179,7 +179,7 @@ class ADARNN(Model):
|
||||
if flag:
|
||||
continue
|
||||
|
||||
total_loss = torch.zeros(1).cuda()
|
||||
total_loss = torch.zeros(1).to(self.device)
|
||||
for i, n in enumerate(index):
|
||||
feature_s = list_feat[n[0]]
|
||||
feature_t = list_feat[n[1]]
|
||||
@@ -325,7 +325,7 @@ class ADARNN(Model):
|
||||
else:
|
||||
end = begin + self.batch_size
|
||||
|
||||
x_batch = torch.from_numpy(x_values[begin:end]).float().cuda()
|
||||
x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
pred = self.model.predict(x_batch).detach().cpu().numpy()
|
||||
@@ -335,7 +335,7 @@ class ADARNN(Model):
|
||||
return pd.Series(np.concatenate(preds), index=index)
|
||||
|
||||
def transform_type(self, init_weight):
|
||||
weight = torch.ones(self.num_layers, self.len_seq).cuda()
|
||||
weight = torch.ones(self.num_layers, self.len_seq).to(self.device)
|
||||
for i in range(self.num_layers):
|
||||
for j in range(self.len_seq):
|
||||
weight[i, j] = init_weight[i][j].item()
|
||||
@@ -389,6 +389,7 @@ class AdaRNN(nn.Module):
|
||||
len_seq=9,
|
||||
model_type="AdaRNN",
|
||||
trans_loss="mmd",
|
||||
GPU=0,
|
||||
):
|
||||
super(AdaRNN, self).__init__()
|
||||
self.use_bottleneck = use_bottleneck
|
||||
@@ -399,6 +400,7 @@ class AdaRNN(nn.Module):
|
||||
self.model_type = model_type
|
||||
self.trans_loss = trans_loss
|
||||
self.len_seq = len_seq
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
in_size = self.n_input
|
||||
|
||||
features = nn.ModuleList()
|
||||
@@ -455,7 +457,7 @@ class AdaRNN(nn.Module):
|
||||
|
||||
out_list_all, out_weight_list = out[1], out[2]
|
||||
out_list_s, out_list_t = self.get_features(out_list_all)
|
||||
loss_transfer = torch.zeros((1,)).cuda()
|
||||
loss_transfer = torch.zeros((1,)).to(self.device)
|
||||
for i, n in enumerate(out_list_s):
|
||||
criterion_transder = TransferLoss(loss_type=self.trans_loss, input_dim=n.shape[2])
|
||||
h_start = 0
|
||||
@@ -516,12 +518,12 @@ class AdaRNN(nn.Module):
|
||||
|
||||
out_list_all = out[1]
|
||||
out_list_s, out_list_t = self.get_features(out_list_all)
|
||||
loss_transfer = torch.zeros((1,)).cuda()
|
||||
loss_transfer = torch.zeros((1,)).to(self.device)
|
||||
if weight_mat is None:
|
||||
weight = (1.0 / self.len_seq * torch.ones(self.num_layers, self.len_seq)).cuda()
|
||||
weight = (1.0 / self.len_seq * torch.ones(self.num_layers, self.len_seq)).to(self.device)
|
||||
else:
|
||||
weight = weight_mat
|
||||
dist_mat = torch.zeros(self.num_layers, self.len_seq).cuda()
|
||||
dist_mat = torch.zeros(self.num_layers, self.len_seq).to(self.device)
|
||||
for i, n in enumerate(out_list_s):
|
||||
criterion_transder = TransferLoss(loss_type=self.trans_loss, input_dim=n.shape[2])
|
||||
for j in range(self.len_seq):
|
||||
@@ -553,12 +555,13 @@ class AdaRNN(nn.Module):
|
||||
|
||||
|
||||
class TransferLoss:
|
||||
def __init__(self, loss_type="cosine", input_dim=512):
|
||||
def __init__(self, loss_type="cosine", input_dim=512, GPU=0):
|
||||
"""
|
||||
Supported loss_type: mmd(mmd_lin), mmd_rbf, coral, cosine, kl, js, mine, adv
|
||||
"""
|
||||
self.loss_type = loss_type
|
||||
self.input_dim = input_dim
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
|
||||
def compute(self, X, Y):
|
||||
"""Compute adaptation loss
|
||||
@@ -574,7 +577,7 @@ class TransferLoss:
|
||||
mmdloss = MMD_loss(kernel_type="linear")
|
||||
loss = mmdloss(X, Y)
|
||||
elif self.loss_type == "coral":
|
||||
loss = CORAL(X, Y)
|
||||
loss = CORAL(X, Y, self.device)
|
||||
elif self.loss_type in ("cosine", "cos"):
|
||||
loss = 1 - cosine(X, Y)
|
||||
elif self.loss_type == "kl":
|
||||
@@ -582,10 +585,10 @@ class TransferLoss:
|
||||
elif self.loss_type == "js":
|
||||
loss = js(X, Y)
|
||||
elif self.loss_type == "mine":
|
||||
mine_model = Mine_estimator(input_dim=self.input_dim, hidden_dim=60).cuda()
|
||||
mine_model = Mine_estimator(input_dim=self.input_dim, hidden_dim=60).to(self.device)
|
||||
loss = mine_model(X, Y)
|
||||
elif self.loss_type == "adv":
|
||||
loss = adv(X, Y, input_dim=self.input_dim, hidden_dim=32)
|
||||
loss = adv(X, Y, self.device, input_dim=self.input_dim, hidden_dim=32)
|
||||
elif self.loss_type == "mmd_rbf":
|
||||
mmdloss = MMD_loss(kernel_type="rbf")
|
||||
loss = mmdloss(X, Y)
|
||||
@@ -630,12 +633,12 @@ class Discriminator(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
def adv(source, target, input_dim=256, hidden_dim=512):
|
||||
def adv(source, target, device, input_dim=256, hidden_dim=512):
|
||||
domain_loss = nn.BCELoss()
|
||||
# !!! Pay attention to .cuda !!!
|
||||
adv_net = Discriminator(input_dim, hidden_dim).cuda()
|
||||
domain_src = torch.ones(len(source)).cuda()
|
||||
domain_tar = torch.zeros(len(target)).cuda()
|
||||
adv_net = Discriminator(input_dim, hidden_dim).to(device)
|
||||
domain_src = torch.ones(len(source)).to(device)
|
||||
domain_tar = torch.zeros(len(target)).to(device)
|
||||
domain_src, domain_tar = domain_src.view(domain_src.shape[0], 1), domain_tar.view(domain_tar.shape[0], 1)
|
||||
reverse_src = ReverseLayerF.apply(source, 1)
|
||||
reverse_tar = ReverseLayerF.apply(target, 1)
|
||||
@@ -646,16 +649,16 @@ def adv(source, target, input_dim=256, hidden_dim=512):
|
||||
return loss
|
||||
|
||||
|
||||
def CORAL(source, target):
|
||||
def CORAL(source, target, device):
|
||||
d = source.size(1)
|
||||
ns, nt = source.size(0), target.size(0)
|
||||
|
||||
# source covariance
|
||||
tmp_s = torch.ones((1, ns)).cuda() @ source
|
||||
tmp_s = torch.ones((1, ns)).to(device) @ source
|
||||
cs = (source.t() @ source - (tmp_s.t() @ tmp_s) / ns) / (ns - 1)
|
||||
|
||||
# target covariance
|
||||
tmp_t = torch.ones((1, nt)).cuda() @ target
|
||||
tmp_t = torch.ones((1, nt)).to(device) @ target
|
||||
ct = (target.t() @ target - (tmp_t.t() @ tmp_t) / nt) / (nt - 1)
|
||||
|
||||
# frobenius norm
|
||||
|
||||
501
qlib/contrib/model/pytorch_hist.py
Normal file
501
qlib/contrib/model/pytorch_hist.py
Normal file
@@ -0,0 +1,501 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Text, Union
|
||||
import urllib.request
|
||||
import copy
|
||||
from ...utils import get_or_create_path
|
||||
from ...log import get_module_logger
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from .pytorch_utils import count_parameters
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
from ...contrib.model.pytorch_lstm import LSTMModel
|
||||
from ...contrib.model.pytorch_gru import GRUModel
|
||||
|
||||
|
||||
class HIST(Model):
|
||||
"""HIST Model
|
||||
|
||||
Parameters
|
||||
----------
|
||||
lr : float
|
||||
learning rate
|
||||
d_feat : int
|
||||
input dimensions for each time step
|
||||
metric : str
|
||||
the evaluate metric used in early stop
|
||||
optimizer : str
|
||||
optimizer name
|
||||
GPU : str
|
||||
the GPU ID(s) used for training
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
d_feat=6,
|
||||
hidden_size=64,
|
||||
num_layers=2,
|
||||
dropout=0.0,
|
||||
n_epochs=200,
|
||||
lr=0.001,
|
||||
metric="",
|
||||
early_stop=20,
|
||||
loss="mse",
|
||||
base_model="GRU",
|
||||
model_path=None,
|
||||
stock2concept=None,
|
||||
stock_index=None,
|
||||
optimizer="adam",
|
||||
GPU=0,
|
||||
seed=None,
|
||||
**kwargs
|
||||
):
|
||||
# Set logger.
|
||||
self.logger = get_module_logger("HIST")
|
||||
self.logger.info("HIST pytorch version...")
|
||||
|
||||
# set hyper-parameters.
|
||||
self.d_feat = d_feat
|
||||
self.hidden_size = hidden_size
|
||||
self.num_layers = num_layers
|
||||
self.dropout = dropout
|
||||
self.n_epochs = n_epochs
|
||||
self.lr = lr
|
||||
self.metric = metric
|
||||
self.early_stop = early_stop
|
||||
self.optimizer = optimizer.lower()
|
||||
self.loss = loss
|
||||
self.base_model = base_model
|
||||
self.model_path = model_path
|
||||
self.stock2concept = stock2concept
|
||||
self.stock_index = stock_index
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.seed = seed
|
||||
|
||||
self.logger.info(
|
||||
"HIST parameters setting:"
|
||||
"\nd_feat : {}"
|
||||
"\nhidden_size : {}"
|
||||
"\nnum_layers : {}"
|
||||
"\ndropout : {}"
|
||||
"\nn_epochs : {}"
|
||||
"\nlr : {}"
|
||||
"\nmetric : {}"
|
||||
"\nearly_stop : {}"
|
||||
"\noptimizer : {}"
|
||||
"\nloss_type : {}"
|
||||
"\nbase_model : {}"
|
||||
"\nmodel_path : {}"
|
||||
"\nstock2concept : {}"
|
||||
"\nstock_index : {}"
|
||||
"\nuse_GPU : {}"
|
||||
"\nseed : {}".format(
|
||||
d_feat,
|
||||
hidden_size,
|
||||
num_layers,
|
||||
dropout,
|
||||
n_epochs,
|
||||
lr,
|
||||
metric,
|
||||
early_stop,
|
||||
optimizer.lower(),
|
||||
loss,
|
||||
base_model,
|
||||
model_path,
|
||||
stock2concept,
|
||||
stock_index,
|
||||
GPU,
|
||||
seed,
|
||||
)
|
||||
)
|
||||
|
||||
if self.seed is not None:
|
||||
np.random.seed(self.seed)
|
||||
torch.manual_seed(self.seed)
|
||||
|
||||
self.HIST_model = HISTModel(
|
||||
d_feat=self.d_feat,
|
||||
hidden_size=self.hidden_size,
|
||||
num_layers=self.num_layers,
|
||||
dropout=self.dropout,
|
||||
base_model=self.base_model,
|
||||
)
|
||||
self.logger.info("model:\n{:}".format(self.HIST_model))
|
||||
self.logger.info("model size: {:.4f} MB".format(count_parameters(self.HIST_model)))
|
||||
if optimizer.lower() == "adam":
|
||||
self.train_optimizer = optim.Adam(self.HIST_model.parameters(), lr=self.lr)
|
||||
elif optimizer.lower() == "gd":
|
||||
self.train_optimizer = optim.SGD(self.HIST_model.parameters(), lr=self.lr)
|
||||
else:
|
||||
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
|
||||
|
||||
self.fitted = False
|
||||
self.HIST_model.to(self.device)
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
return self.device != torch.device("cpu")
|
||||
|
||||
def mse(self, pred, label):
|
||||
loss = (pred - label) ** 2
|
||||
return torch.mean(loss)
|
||||
|
||||
def loss_fn(self, pred, label):
|
||||
mask = ~torch.isnan(label)
|
||||
|
||||
if self.loss == "mse":
|
||||
return self.mse(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown loss `%s`" % self.loss)
|
||||
|
||||
def metric_fn(self, pred, label):
|
||||
|
||||
mask = torch.isfinite(label)
|
||||
|
||||
if self.metric == "ic":
|
||||
x = pred[mask]
|
||||
y = label[mask]
|
||||
|
||||
vx = x - torch.mean(x)
|
||||
vy = y - torch.mean(y)
|
||||
return torch.sum(vx * vy) / (torch.sqrt(torch.sum(vx**2)) * torch.sqrt(torch.sum(vy**2)))
|
||||
|
||||
if self.metric == ("", "loss"):
|
||||
return -self.loss_fn(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown metric `%s`" % self.metric)
|
||||
|
||||
def get_daily_inter(self, df, shuffle=False):
|
||||
# organize the train data into daily batches
|
||||
daily_count = df.groupby(level=0).size().values
|
||||
daily_index = np.roll(np.cumsum(daily_count), 1)
|
||||
daily_index[0] = 0
|
||||
if shuffle:
|
||||
# shuffle data
|
||||
daily_shuffle = list(zip(daily_index, daily_count))
|
||||
np.random.shuffle(daily_shuffle)
|
||||
daily_index, daily_count = zip(*daily_shuffle)
|
||||
return daily_index, daily_count
|
||||
|
||||
def train_epoch(self, x_train, y_train, stock_index):
|
||||
|
||||
stock2concept_matrix = np.load(self.stock2concept)
|
||||
x_train_values = x_train.values
|
||||
y_train_values = np.squeeze(y_train.values)
|
||||
stock_index = stock_index.values
|
||||
stock_index[np.isnan(stock_index)] = 733
|
||||
self.HIST_model.train()
|
||||
|
||||
# organize the train data into daily batches
|
||||
daily_index, daily_count = self.get_daily_inter(x_train, shuffle=True)
|
||||
|
||||
for idx, count in zip(daily_index, daily_count):
|
||||
batch = slice(idx, idx + count)
|
||||
feature = torch.from_numpy(x_train_values[batch]).float().to(self.device)
|
||||
concept_matrix = torch.from_numpy(stock2concept_matrix[stock_index[batch]]).float().to(self.device)
|
||||
label = torch.from_numpy(y_train_values[batch]).float().to(self.device)
|
||||
pred = self.HIST_model(feature, concept_matrix)
|
||||
loss = self.loss_fn(pred, label)
|
||||
|
||||
self.train_optimizer.zero_grad()
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_value_(self.HIST_model.parameters(), 3.0)
|
||||
self.train_optimizer.step()
|
||||
|
||||
def test_epoch(self, data_x, data_y, stock_index):
|
||||
|
||||
# prepare training data
|
||||
stock2concept_matrix = np.load(self.stock2concept)
|
||||
x_values = data_x.values
|
||||
y_values = np.squeeze(data_y.values)
|
||||
stock_index = stock_index.values
|
||||
stock_index[np.isnan(stock_index)] = 733
|
||||
self.HIST_model.eval()
|
||||
|
||||
scores = []
|
||||
losses = []
|
||||
|
||||
# organize the test data into daily batches
|
||||
daily_index, daily_count = self.get_daily_inter(data_x, shuffle=False)
|
||||
|
||||
for idx, count in zip(daily_index, daily_count):
|
||||
batch = slice(idx, idx + count)
|
||||
feature = torch.from_numpy(x_values[batch]).float().to(self.device)
|
||||
concept_matrix = torch.from_numpy(stock2concept_matrix[stock_index[batch]]).float().to(self.device)
|
||||
label = torch.from_numpy(y_values[batch]).float().to(self.device)
|
||||
with torch.no_grad():
|
||||
pred = self.HIST_model(feature, concept_matrix)
|
||||
loss = self.loss_fn(pred, label)
|
||||
losses.append(loss.item())
|
||||
|
||||
score = self.metric_fn(pred, label)
|
||||
scores.append(score.item())
|
||||
|
||||
return np.mean(losses), np.mean(scores)
|
||||
|
||||
def fit(
|
||||
self,
|
||||
dataset: DatasetH,
|
||||
evals_result=dict(),
|
||||
save_path=None,
|
||||
):
|
||||
df_train, df_valid, df_test = dataset.prepare(
|
||||
["train", "valid", "test"],
|
||||
col_set=["feature", "label"],
|
||||
data_key=DataHandlerLP.DK_L,
|
||||
)
|
||||
if df_train.empty or df_valid.empty:
|
||||
raise ValueError("Empty data from dataset, please check your dataset config.")
|
||||
|
||||
if not os.path.exists(self.stock2concept):
|
||||
url = "http://fintech.msra.cn/stock_data/downloads/qlib_csi300_stock2concept.npy"
|
||||
urllib.request.urlretrieve(url, self.stock2concept)
|
||||
|
||||
stock_index = np.load(self.stock_index, allow_pickle=True).item()
|
||||
df_train["stock_index"] = 733
|
||||
df_train["stock_index"] = df_train.index.get_level_values("instrument").map(stock_index)
|
||||
df_valid["stock_index"] = 733
|
||||
df_valid["stock_index"] = df_valid.index.get_level_values("instrument").map(stock_index)
|
||||
|
||||
x_train, y_train, stock_index_train = df_train["feature"], df_train["label"], df_train["stock_index"]
|
||||
x_valid, y_valid, stock_index_valid = df_valid["feature"], df_valid["label"], df_valid["stock_index"]
|
||||
|
||||
save_path = get_or_create_path(save_path)
|
||||
|
||||
stop_steps = 0
|
||||
best_score = -np.inf
|
||||
best_epoch = 0
|
||||
evals_result["train"] = []
|
||||
evals_result["valid"] = []
|
||||
|
||||
# load pretrained base_model
|
||||
if self.base_model == "LSTM":
|
||||
pretrained_model = LSTMModel()
|
||||
elif self.base_model == "GRU":
|
||||
pretrained_model = GRUModel()
|
||||
else:
|
||||
raise ValueError("unknown base model name `%s`" % self.base_model)
|
||||
|
||||
if self.model_path is not None:
|
||||
self.logger.info("Loading pretrained model...")
|
||||
pretrained_model.load_state_dict(torch.load(self.model_path))
|
||||
|
||||
model_dict = self.HIST_model.state_dict()
|
||||
pretrained_dict = {k: v for k, v in pretrained_model.state_dict().items() if k in model_dict}
|
||||
model_dict.update(pretrained_dict)
|
||||
self.HIST_model.load_state_dict(model_dict)
|
||||
self.logger.info("Loading pretrained model Done...")
|
||||
|
||||
# train
|
||||
self.logger.info("training...")
|
||||
self.fitted = True
|
||||
|
||||
for step in range(self.n_epochs):
|
||||
self.logger.info("Epoch%d:", step)
|
||||
self.logger.info("training...")
|
||||
self.train_epoch(x_train, y_train, stock_index_train)
|
||||
|
||||
self.logger.info("evaluating...")
|
||||
train_loss, train_score = self.test_epoch(x_train, y_train, stock_index_train)
|
||||
val_loss, val_score = self.test_epoch(x_valid, y_valid, stock_index_valid)
|
||||
self.logger.info("train %.6f, valid %.6f" % (train_score, val_score))
|
||||
evals_result["train"].append(train_score)
|
||||
evals_result["valid"].append(val_score)
|
||||
|
||||
if val_score > best_score:
|
||||
best_score = val_score
|
||||
stop_steps = 0
|
||||
best_epoch = step
|
||||
best_param = copy.deepcopy(self.HIST_model.state_dict())
|
||||
else:
|
||||
stop_steps += 1
|
||||
if stop_steps >= self.early_stop:
|
||||
self.logger.info("early stop")
|
||||
break
|
||||
|
||||
self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
|
||||
self.HIST_model.load_state_dict(best_param)
|
||||
torch.save(best_param, save_path)
|
||||
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
stock2concept_matrix = np.load(self.stock2concept)
|
||||
stock_index = np.load(self.stock_index, allow_pickle=True).item()
|
||||
df_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
df_test["stock_index"] = 733
|
||||
df_test["stock_index"] = df_test.index.get_level_values("instrument").map(stock_index)
|
||||
stock_index_test = df_test["stock_index"].values
|
||||
stock_index_test[np.isnan(stock_index_test)] = 733
|
||||
stock_index_test = stock_index_test.astype("int")
|
||||
df_test = df_test.drop(["stock_index"], axis=1)
|
||||
index = df_test.index
|
||||
|
||||
self.HIST_model.eval()
|
||||
x_values = df_test.values
|
||||
preds = []
|
||||
|
||||
# organize the data into daily batches
|
||||
daily_index, daily_count = self.get_daily_inter(df_test, shuffle=False)
|
||||
|
||||
for idx, count in zip(daily_index, daily_count):
|
||||
batch = slice(idx, idx + count)
|
||||
x_batch = torch.from_numpy(x_values[batch]).float().to(self.device)
|
||||
concept_matrix = torch.from_numpy(stock2concept_matrix[stock_index_test[batch]]).float().to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
pred = self.HIST_model(x_batch, concept_matrix).detach().cpu().numpy()
|
||||
|
||||
preds.append(pred)
|
||||
|
||||
return pd.Series(np.concatenate(preds), index=index)
|
||||
|
||||
|
||||
class HISTModel(nn.Module):
|
||||
def __init__(self, d_feat=6, hidden_size=64, num_layers=2, dropout=0.0, base_model="GRU"):
|
||||
super().__init__()
|
||||
|
||||
self.d_feat = d_feat
|
||||
self.hidden_size = hidden_size
|
||||
|
||||
if base_model == "GRU":
|
||||
self.rnn = nn.GRU(
|
||||
input_size=d_feat,
|
||||
hidden_size=hidden_size,
|
||||
num_layers=num_layers,
|
||||
batch_first=True,
|
||||
dropout=dropout,
|
||||
)
|
||||
elif base_model == "LSTM":
|
||||
self.rnn = nn.LSTM(
|
||||
input_size=d_feat,
|
||||
hidden_size=hidden_size,
|
||||
num_layers=num_layers,
|
||||
batch_first=True,
|
||||
dropout=dropout,
|
||||
)
|
||||
else:
|
||||
raise ValueError("unknown base model name `%s`" % base_model)
|
||||
|
||||
self.fc_es = nn.Linear(hidden_size, hidden_size)
|
||||
torch.nn.init.xavier_uniform_(self.fc_es.weight)
|
||||
self.fc_is = nn.Linear(hidden_size, hidden_size)
|
||||
torch.nn.init.xavier_uniform_(self.fc_is.weight)
|
||||
|
||||
self.fc_es_middle = nn.Linear(hidden_size, hidden_size)
|
||||
torch.nn.init.xavier_uniform_(self.fc_es_middle.weight)
|
||||
self.fc_is_middle = nn.Linear(hidden_size, hidden_size)
|
||||
torch.nn.init.xavier_uniform_(self.fc_is_middle.weight)
|
||||
|
||||
self.fc_es_fore = nn.Linear(hidden_size, hidden_size)
|
||||
torch.nn.init.xavier_uniform_(self.fc_es_fore.weight)
|
||||
self.fc_is_fore = nn.Linear(hidden_size, hidden_size)
|
||||
torch.nn.init.xavier_uniform_(self.fc_is_fore.weight)
|
||||
self.fc_indi_fore = nn.Linear(hidden_size, hidden_size)
|
||||
torch.nn.init.xavier_uniform_(self.fc_indi_fore.weight)
|
||||
|
||||
self.fc_es_back = nn.Linear(hidden_size, hidden_size)
|
||||
torch.nn.init.xavier_uniform_(self.fc_es_back.weight)
|
||||
self.fc_is_back = nn.Linear(hidden_size, hidden_size)
|
||||
torch.nn.init.xavier_uniform_(self.fc_is_back.weight)
|
||||
self.fc_indi = nn.Linear(hidden_size, hidden_size)
|
||||
torch.nn.init.xavier_uniform_(self.fc_indi.weight)
|
||||
|
||||
self.leaky_relu = nn.LeakyReLU()
|
||||
self.softmax_s2t = torch.nn.Softmax(dim=0)
|
||||
self.softmax_t2s = torch.nn.Softmax(dim=1)
|
||||
|
||||
self.fc_out_es = nn.Linear(hidden_size, 1)
|
||||
self.fc_out_is = nn.Linear(hidden_size, 1)
|
||||
self.fc_out_indi = nn.Linear(hidden_size, 1)
|
||||
self.fc_out = nn.Linear(hidden_size, 1)
|
||||
|
||||
def cal_cos_similarity(self, x, y): # the 2nd dimension of x and y are the same
|
||||
xy = x.mm(torch.t(y))
|
||||
x_norm = torch.sqrt(torch.sum(x * x, dim=1)).reshape(-1, 1)
|
||||
y_norm = torch.sqrt(torch.sum(y * y, dim=1)).reshape(-1, 1)
|
||||
cos_similarity = xy / (x_norm.mm(torch.t(y_norm)) + 1e-6)
|
||||
return cos_similarity
|
||||
|
||||
def forward(self, x, concept_matrix):
|
||||
device = torch.device(torch.get_device(x))
|
||||
|
||||
x_hidden = x.reshape(len(x), self.d_feat, -1) # [N, F, T]
|
||||
x_hidden = x_hidden.permute(0, 2, 1) # [N, T, F]
|
||||
x_hidden, _ = self.rnn(x_hidden)
|
||||
x_hidden = x_hidden[:, -1, :]
|
||||
|
||||
# Predefined Concept Module
|
||||
|
||||
stock_to_concept = concept_matrix
|
||||
|
||||
stock_to_concept_sum = torch.sum(stock_to_concept, 0).reshape(1, -1).repeat(stock_to_concept.shape[0], 1)
|
||||
stock_to_concept_sum = stock_to_concept_sum.mul(concept_matrix)
|
||||
|
||||
stock_to_concept_sum = stock_to_concept_sum + (
|
||||
torch.ones(stock_to_concept.shape[0], stock_to_concept.shape[1]).to(device)
|
||||
)
|
||||
stock_to_concept = stock_to_concept / stock_to_concept_sum
|
||||
hidden = torch.t(stock_to_concept).mm(x_hidden)
|
||||
|
||||
hidden = hidden[hidden.sum(1) != 0]
|
||||
|
||||
concept_to_stock = self.cal_cos_similarity(x_hidden, hidden)
|
||||
concept_to_stock = self.softmax_t2s(concept_to_stock)
|
||||
|
||||
e_shared_info = concept_to_stock.mm(hidden)
|
||||
e_shared_info = self.fc_es(e_shared_info)
|
||||
|
||||
e_shared_back = self.fc_es_back(e_shared_info)
|
||||
output_es = self.fc_es_fore(e_shared_info)
|
||||
output_es = self.leaky_relu(output_es)
|
||||
|
||||
# Hidden Concept Module
|
||||
i_shared_info = x_hidden - e_shared_back
|
||||
hidden = i_shared_info
|
||||
i_stock_to_concept = self.cal_cos_similarity(i_shared_info, hidden)
|
||||
dim = i_stock_to_concept.shape[0]
|
||||
diag = i_stock_to_concept.diagonal(0)
|
||||
i_stock_to_concept = i_stock_to_concept * (torch.ones(dim, dim) - torch.eye(dim)).to(device)
|
||||
row = torch.linspace(0, dim - 1, dim).to(device).long()
|
||||
column = i_stock_to_concept.max(1)[1].long()
|
||||
value = i_stock_to_concept.max(1)[0]
|
||||
i_stock_to_concept[row, column] = 10
|
||||
i_stock_to_concept[i_stock_to_concept != 10] = 0
|
||||
i_stock_to_concept[row, column] = value
|
||||
i_stock_to_concept = i_stock_to_concept + torch.diag_embed((i_stock_to_concept.sum(0) != 0).float() * diag)
|
||||
hidden = torch.t(i_shared_info).mm(i_stock_to_concept).t()
|
||||
hidden = hidden[hidden.sum(1) != 0]
|
||||
|
||||
i_concept_to_stock = self.cal_cos_similarity(i_shared_info, hidden)
|
||||
i_concept_to_stock = self.softmax_t2s(i_concept_to_stock)
|
||||
i_shared_info = i_concept_to_stock.mm(hidden)
|
||||
i_shared_info = self.fc_is(i_shared_info)
|
||||
|
||||
i_shared_back = self.fc_is_back(i_shared_info)
|
||||
output_is = self.fc_is_fore(i_shared_info)
|
||||
output_is = self.leaky_relu(output_is)
|
||||
|
||||
# Individual Information Module
|
||||
individual_info = x_hidden - e_shared_back - i_shared_back
|
||||
output_indi = individual_info
|
||||
output_indi = self.fc_indi(output_indi)
|
||||
output_indi = self.leaky_relu(output_indi)
|
||||
|
||||
# Stock Trend Prediction
|
||||
all_info = output_es + output_is + output_indi
|
||||
pred_all = self.fc_out(all_info).squeeze()
|
||||
|
||||
return pred_all
|
||||
446
qlib/contrib/model/pytorch_igmtf.py
Normal file
446
qlib/contrib/model/pytorch_igmtf.py
Normal file
@@ -0,0 +1,446 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Text, Union
|
||||
import copy
|
||||
from ...utils import get_or_create_path
|
||||
from ...log import get_module_logger
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
|
||||
from .pytorch_utils import count_parameters
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
from ...contrib.model.pytorch_lstm import LSTMModel
|
||||
from ...contrib.model.pytorch_gru import GRUModel
|
||||
|
||||
|
||||
class IGMTF(Model):
|
||||
"""IGMTF Model
|
||||
|
||||
Parameters
|
||||
----------
|
||||
d_feat : int
|
||||
input dimension for each time step
|
||||
metric: str
|
||||
the evaluate metric used in early stop
|
||||
optimizer : str
|
||||
optimizer name
|
||||
GPU : str
|
||||
the GPU ID(s) used for training
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
d_feat=6,
|
||||
hidden_size=64,
|
||||
num_layers=2,
|
||||
dropout=0.0,
|
||||
n_epochs=200,
|
||||
lr=0.001,
|
||||
metric="",
|
||||
early_stop=20,
|
||||
loss="mse",
|
||||
base_model="GRU",
|
||||
model_path=None,
|
||||
optimizer="adam",
|
||||
GPU=0,
|
||||
seed=None,
|
||||
**kwargs
|
||||
):
|
||||
# Set logger.
|
||||
self.logger = get_module_logger("IGMTF")
|
||||
self.logger.info("IMGTF pytorch version...")
|
||||
|
||||
# set hyper-parameters.
|
||||
self.d_feat = d_feat
|
||||
self.hidden_size = hidden_size
|
||||
self.num_layers = num_layers
|
||||
self.dropout = dropout
|
||||
self.n_epochs = n_epochs
|
||||
self.lr = lr
|
||||
self.metric = metric
|
||||
self.early_stop = early_stop
|
||||
self.optimizer = optimizer.lower()
|
||||
self.loss = loss
|
||||
self.base_model = base_model
|
||||
self.model_path = model_path
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.seed = seed
|
||||
|
||||
self.logger.info(
|
||||
"IGMTF parameters setting:"
|
||||
"\nd_feat : {}"
|
||||
"\nhidden_size : {}"
|
||||
"\nnum_layers : {}"
|
||||
"\ndropout : {}"
|
||||
"\nn_epochs : {}"
|
||||
"\nlr : {}"
|
||||
"\nmetric : {}"
|
||||
"\nearly_stop : {}"
|
||||
"\noptimizer : {}"
|
||||
"\nloss_type : {}"
|
||||
"\nbase_model : {}"
|
||||
"\nmodel_path : {}"
|
||||
"\nvisible_GPU : {}"
|
||||
"\nuse_GPU : {}"
|
||||
"\nseed : {}".format(
|
||||
d_feat,
|
||||
hidden_size,
|
||||
num_layers,
|
||||
dropout,
|
||||
n_epochs,
|
||||
lr,
|
||||
metric,
|
||||
early_stop,
|
||||
optimizer.lower(),
|
||||
loss,
|
||||
base_model,
|
||||
model_path,
|
||||
GPU,
|
||||
self.use_gpu,
|
||||
seed,
|
||||
)
|
||||
)
|
||||
|
||||
if self.seed is not None:
|
||||
np.random.seed(self.seed)
|
||||
torch.manual_seed(self.seed)
|
||||
|
||||
self.igmtf_model = IGMTFModel(
|
||||
d_feat=self.d_feat,
|
||||
hidden_size=self.hidden_size,
|
||||
num_layers=self.num_layers,
|
||||
dropout=self.dropout,
|
||||
base_model=self.base_model,
|
||||
)
|
||||
self.logger.info("model:\n{:}".format(self.igmtf_model))
|
||||
self.logger.info("model size: {:.4f} MB".format(count_parameters(self.igmtf_model)))
|
||||
|
||||
if optimizer.lower() == "adam":
|
||||
self.train_optimizer = optim.Adam(self.igmtf_model.parameters(), lr=self.lr)
|
||||
elif optimizer.lower() == "gd":
|
||||
self.train_optimizer = optim.SGD(self.igmtf_model.parameters(), lr=self.lr)
|
||||
else:
|
||||
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
|
||||
|
||||
self.fitted = False
|
||||
self.igmtf_model.to(self.device)
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
return self.device != torch.device("cpu")
|
||||
|
||||
def mse(self, pred, label):
|
||||
loss = (pred - label) ** 2
|
||||
return torch.mean(loss)
|
||||
|
||||
def loss_fn(self, pred, label):
|
||||
mask = ~torch.isnan(label)
|
||||
|
||||
if self.loss == "mse":
|
||||
return self.mse(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown loss `%s`" % self.loss)
|
||||
|
||||
def metric_fn(self, pred, label):
|
||||
|
||||
mask = torch.isfinite(label)
|
||||
|
||||
if self.metric == "ic":
|
||||
x = pred[mask]
|
||||
y = label[mask]
|
||||
|
||||
vx = x - torch.mean(x)
|
||||
vy = y - torch.mean(y)
|
||||
return torch.sum(vx * vy) / (torch.sqrt(torch.sum(vx**2)) * torch.sqrt(torch.sum(vy**2)))
|
||||
|
||||
if self.metric == ("", "loss"):
|
||||
return -self.loss_fn(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown metric `%s`" % self.metric)
|
||||
|
||||
def get_daily_inter(self, df, shuffle=False):
|
||||
# organize the train data into daily batches
|
||||
daily_count = df.groupby(level=0).size().values
|
||||
daily_index = np.roll(np.cumsum(daily_count), 1)
|
||||
daily_index[0] = 0
|
||||
if shuffle:
|
||||
# shuffle data
|
||||
daily_shuffle = list(zip(daily_index, daily_count))
|
||||
np.random.shuffle(daily_shuffle)
|
||||
daily_index, daily_count = zip(*daily_shuffle)
|
||||
return daily_index, daily_count
|
||||
|
||||
def get_train_hidden(self, x_train):
|
||||
x_train_values = x_train.values
|
||||
daily_index, daily_count = self.get_daily_inter(x_train, shuffle=True)
|
||||
self.igmtf_model.eval()
|
||||
train_hidden = []
|
||||
train_hidden_day = []
|
||||
|
||||
for idx, count in zip(daily_index, daily_count):
|
||||
batch = slice(idx, idx + count)
|
||||
feature = torch.from_numpy(x_train_values[batch]).float().to(self.device)
|
||||
out = self.igmtf_model(feature, get_hidden=True)
|
||||
train_hidden.append(out.detach().cpu())
|
||||
train_hidden_day.append(out.detach().cpu().mean(dim=0).unsqueeze(dim=0))
|
||||
|
||||
train_hidden = np.asarray(train_hidden, dtype=object)
|
||||
train_hidden_day = torch.cat(train_hidden_day)
|
||||
|
||||
return train_hidden, train_hidden_day
|
||||
|
||||
def train_epoch(self, x_train, y_train, train_hidden, train_hidden_day):
|
||||
|
||||
x_train_values = x_train.values
|
||||
y_train_values = np.squeeze(y_train.values)
|
||||
|
||||
self.igmtf_model.train()
|
||||
|
||||
daily_index, daily_count = self.get_daily_inter(x_train, shuffle=True)
|
||||
|
||||
for idx, count in zip(daily_index, daily_count):
|
||||
batch = slice(idx, idx + count)
|
||||
feature = torch.from_numpy(x_train_values[batch]).float().to(self.device)
|
||||
label = torch.from_numpy(y_train_values[batch]).float().to(self.device)
|
||||
pred = self.igmtf_model(feature, train_hidden=train_hidden, train_hidden_day=train_hidden_day)
|
||||
loss = self.loss_fn(pred, label)
|
||||
|
||||
self.train_optimizer.zero_grad()
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_value_(self.igmtf_model.parameters(), 3.0)
|
||||
self.train_optimizer.step()
|
||||
|
||||
def test_epoch(self, data_x, data_y, train_hidden, train_hidden_day):
|
||||
|
||||
# prepare training data
|
||||
x_values = data_x.values
|
||||
y_values = np.squeeze(data_y.values)
|
||||
|
||||
self.igmtf_model.eval()
|
||||
|
||||
scores = []
|
||||
losses = []
|
||||
|
||||
daily_index, daily_count = self.get_daily_inter(data_x, shuffle=False)
|
||||
|
||||
for idx, count in zip(daily_index, daily_count):
|
||||
batch = slice(idx, idx + count)
|
||||
feature = torch.from_numpy(x_values[batch]).float().to(self.device)
|
||||
label = torch.from_numpy(y_values[batch]).float().to(self.device)
|
||||
|
||||
pred = self.igmtf_model(feature, train_hidden=train_hidden, train_hidden_day=train_hidden_day)
|
||||
loss = self.loss_fn(pred, label)
|
||||
losses.append(loss.item())
|
||||
|
||||
score = self.metric_fn(pred, label)
|
||||
scores.append(score.item())
|
||||
|
||||
return np.mean(losses), np.mean(scores)
|
||||
|
||||
def fit(
|
||||
self,
|
||||
dataset: DatasetH,
|
||||
evals_result=dict(),
|
||||
save_path=None,
|
||||
):
|
||||
|
||||
df_train, df_valid = dataset.prepare(
|
||||
["train", "valid"],
|
||||
col_set=["feature", "label"],
|
||||
data_key=DataHandlerLP.DK_L,
|
||||
)
|
||||
if df_train.empty or df_valid.empty:
|
||||
raise ValueError("Empty data from dataset, please check your dataset config.")
|
||||
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
|
||||
save_path = get_or_create_path(save_path)
|
||||
stop_steps = 0
|
||||
train_loss = 0
|
||||
best_score = -np.inf
|
||||
best_epoch = 0
|
||||
evals_result["train"] = []
|
||||
evals_result["valid"] = []
|
||||
|
||||
# load pretrained base_model
|
||||
if self.base_model == "LSTM":
|
||||
pretrained_model = LSTMModel()
|
||||
elif self.base_model == "GRU":
|
||||
pretrained_model = GRUModel()
|
||||
else:
|
||||
raise ValueError("unknown base model name `%s`" % self.base_model)
|
||||
|
||||
if self.model_path is not None:
|
||||
self.logger.info("Loading pretrained model...")
|
||||
pretrained_model.load_state_dict(torch.load(self.model_path, map_location=self.device))
|
||||
|
||||
model_dict = self.igmtf_model.state_dict()
|
||||
pretrained_dict = {
|
||||
k: v for k, v in pretrained_model.state_dict().items() if k in model_dict # pylint: disable=E1135
|
||||
}
|
||||
model_dict.update(pretrained_dict)
|
||||
self.igmtf_model.load_state_dict(model_dict)
|
||||
self.logger.info("Loading pretrained model Done...")
|
||||
|
||||
# train
|
||||
self.logger.info("training...")
|
||||
self.fitted = True
|
||||
|
||||
for step in range(self.n_epochs):
|
||||
self.logger.info("Epoch%d:", step)
|
||||
self.logger.info("training...")
|
||||
train_hidden, train_hidden_day = self.get_train_hidden(x_train)
|
||||
self.train_epoch(x_train, y_train, train_hidden, train_hidden_day)
|
||||
self.logger.info("evaluating...")
|
||||
train_loss, train_score = self.test_epoch(x_train, y_train, train_hidden, train_hidden_day)
|
||||
val_loss, val_score = self.test_epoch(x_valid, y_valid, train_hidden, train_hidden_day)
|
||||
self.logger.info("train %.6f, valid %.6f" % (train_score, val_score))
|
||||
evals_result["train"].append(train_score)
|
||||
evals_result["valid"].append(val_score)
|
||||
|
||||
if val_score > best_score:
|
||||
best_score = val_score
|
||||
stop_steps = 0
|
||||
best_epoch = step
|
||||
best_param = copy.deepcopy(self.igmtf_model.state_dict())
|
||||
else:
|
||||
stop_steps += 1
|
||||
if stop_steps >= self.early_stop:
|
||||
self.logger.info("early stop")
|
||||
break
|
||||
|
||||
self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
|
||||
self.igmtf_model.load_state_dict(best_param)
|
||||
torch.save(best_param, save_path)
|
||||
|
||||
if self.use_gpu:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
x_train = dataset.prepare("train", col_set="feature", data_key=DataHandlerLP.DK_L)
|
||||
train_hidden, train_hidden_day = self.get_train_hidden(x_train)
|
||||
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
index = x_test.index
|
||||
self.igmtf_model.eval()
|
||||
x_values = x_test.values
|
||||
preds = []
|
||||
|
||||
daily_index, daily_count = self.get_daily_inter(x_test, shuffle=False)
|
||||
|
||||
for idx, count in zip(daily_index, daily_count):
|
||||
batch = slice(idx, idx + count)
|
||||
x_batch = torch.from_numpy(x_values[batch]).float().to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
pred = (
|
||||
self.igmtf_model(x_batch, train_hidden=train_hidden, train_hidden_day=train_hidden_day)
|
||||
.detach()
|
||||
.cpu()
|
||||
.numpy()
|
||||
)
|
||||
|
||||
preds.append(pred)
|
||||
|
||||
return pd.Series(np.concatenate(preds), index=index)
|
||||
|
||||
|
||||
class IGMTFModel(nn.Module):
|
||||
def __init__(self, d_feat=6, hidden_size=64, num_layers=2, dropout=0.0, base_model="GRU"):
|
||||
super().__init__()
|
||||
|
||||
if base_model == "GRU":
|
||||
self.rnn = nn.GRU(
|
||||
input_size=d_feat,
|
||||
hidden_size=hidden_size,
|
||||
num_layers=num_layers,
|
||||
batch_first=True,
|
||||
dropout=dropout,
|
||||
)
|
||||
elif base_model == "LSTM":
|
||||
self.rnn = nn.LSTM(
|
||||
input_size=d_feat,
|
||||
hidden_size=hidden_size,
|
||||
num_layers=num_layers,
|
||||
batch_first=True,
|
||||
dropout=dropout,
|
||||
)
|
||||
else:
|
||||
raise ValueError("unknown base model name `%s`" % base_model)
|
||||
self.lins = nn.Sequential()
|
||||
for i in range(2):
|
||||
self.lins.add_module("linear" + str(i), nn.Linear(hidden_size, hidden_size))
|
||||
self.lins.add_module("leakyrelu" + str(i), nn.LeakyReLU())
|
||||
self.fc_output = nn.Linear(hidden_size * 2, hidden_size * 2)
|
||||
self.project1 = nn.Linear(hidden_size, hidden_size, bias=False)
|
||||
self.project2 = nn.Linear(hidden_size, hidden_size, bias=False)
|
||||
self.fc_out_pred = nn.Linear(hidden_size * 2, 1)
|
||||
|
||||
self.leaky_relu = nn.LeakyReLU()
|
||||
self.d_feat = d_feat
|
||||
|
||||
def cal_cos_similarity(self, x, y): # the 2nd dimension of x and y are the same
|
||||
xy = x.mm(torch.t(y))
|
||||
x_norm = torch.sqrt(torch.sum(x * x, dim=1)).reshape(-1, 1)
|
||||
y_norm = torch.sqrt(torch.sum(y * y, dim=1)).reshape(-1, 1)
|
||||
cos_similarity = xy / (x_norm.mm(torch.t(y_norm)) + 1e-6)
|
||||
return cos_similarity
|
||||
|
||||
def sparse_dense_mul(self, s, d):
|
||||
i = s._indices()
|
||||
v = s._values()
|
||||
dv = d[i[0, :], i[1, :]] # get values from relevant entries of dense matrix
|
||||
return torch.sparse.FloatTensor(i, v * dv, s.size())
|
||||
|
||||
def forward(self, x, get_hidden=False, train_hidden=None, train_hidden_day=None, k_day=10, n_neighbor=10):
|
||||
# x: [N, F*T]
|
||||
device = x.device
|
||||
x = x.reshape(len(x), self.d_feat, -1) # [N, F, T]
|
||||
x = x.permute(0, 2, 1) # [N, T, F]
|
||||
out, _ = self.rnn(x)
|
||||
out = out[:, -1, :]
|
||||
out = self.lins(out)
|
||||
mini_batch_out = out
|
||||
if get_hidden is True:
|
||||
return mini_batch_out
|
||||
|
||||
mini_batch_out_day = torch.mean(mini_batch_out, dim=0).unsqueeze(0)
|
||||
day_similarity = self.cal_cos_similarity(mini_batch_out_day, train_hidden_day.to(device))
|
||||
day_index = torch.topk(day_similarity, k_day, dim=1)[1]
|
||||
sample_train_hidden = train_hidden[day_index.long().cpu()].squeeze()
|
||||
sample_train_hidden = torch.cat(list(sample_train_hidden)).to(device)
|
||||
sample_train_hidden = self.lins(sample_train_hidden)
|
||||
cos_similarity = self.cal_cos_similarity(self.project1(mini_batch_out), self.project2(sample_train_hidden))
|
||||
|
||||
row = (
|
||||
torch.linspace(0, x.shape[0] - 1, x.shape[0])
|
||||
.reshape([-1, 1])
|
||||
.repeat(1, n_neighbor)
|
||||
.reshape(1, -1)
|
||||
.to(device)
|
||||
)
|
||||
column = torch.topk(cos_similarity, n_neighbor, dim=1)[1].reshape(1, -1)
|
||||
mask = torch.sparse_coo_tensor(
|
||||
torch.cat([row, column]),
|
||||
torch.ones([row.shape[1]]).to(device) / n_neighbor,
|
||||
(x.shape[0], sample_train_hidden.shape[0]),
|
||||
)
|
||||
cos_similarity = self.sparse_dense_mul(mask, cos_similarity)
|
||||
|
||||
agg_out = torch.sparse.mm(cos_similarity, self.project2(sample_train_hidden))
|
||||
# out = self.fc_out(out).squeeze()
|
||||
out = self.fc_out_pred(torch.cat([mini_batch_out, agg_out], axis=1)).squeeze()
|
||||
return out
|
||||
@@ -84,7 +84,7 @@ class SFM_Model(nn.Module):
|
||||
if len(self.states) == 0: # hasn't initialized yet
|
||||
self.init_states(x)
|
||||
self.get_constants(x)
|
||||
p_tm1 = self.states[0]
|
||||
p_tm1 = self.states[0] # noqa: F841
|
||||
h_tm1 = self.states[1]
|
||||
S_re_tm1 = self.states[2]
|
||||
S_im_tm1 = self.states[3]
|
||||
|
||||
@@ -477,10 +477,10 @@ class TabNet(nn.Module):
|
||||
sparse_loss = []
|
||||
out = torch.zeros(x.size(0), self.n_d).to(x.device)
|
||||
for step in self.steps:
|
||||
x_te, l = step(x, x_a, priors)
|
||||
x_te, loss = step(x, x_a, priors)
|
||||
out += F.relu(x_te[:, : self.n_d]) # split the feature from feat_transformer
|
||||
x_a = x_te[:, self.n_d :]
|
||||
sparse_loss.append(l)
|
||||
sparse_loss.append(loss)
|
||||
return self.fc(out), sum(sparse_loss)
|
||||
|
||||
|
||||
|
||||
@@ -145,7 +145,7 @@ class TCTS(Model):
|
||||
|
||||
init_fore_model = copy.deepcopy(self.fore_model)
|
||||
for p in init_fore_model.parameters():
|
||||
p.init_fore_model = False
|
||||
p.requires_grad = False
|
||||
|
||||
self.fore_model.train()
|
||||
self.weight_model.train()
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
# pylint: skip-file
|
||||
# flake8: noqa
|
||||
|
||||
'''
|
||||
TODO:
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
# pylint: skip-file
|
||||
# flake8: noqa
|
||||
|
||||
import yaml
|
||||
import pathlib
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
# pylint: skip-file
|
||||
# flake8: noqa
|
||||
|
||||
import random
|
||||
import pandas as pd
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
# pylint: skip-file
|
||||
# flake8: noqa
|
||||
|
||||
import fire
|
||||
import pandas as pd
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
# pylint: skip-file
|
||||
# flake8: noqa
|
||||
|
||||
import logging
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
# pylint: skip-file
|
||||
# flake8: noqa
|
||||
|
||||
import pathlib
|
||||
import pickle
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from datetime import datetime
|
||||
|
||||
from qlib.data.cache import H
|
||||
from qlib.data.data import Cal
|
||||
from qlib.data.ops import ElemOperator
|
||||
from qlib.data.ops import ElemOperator, PairOperator
|
||||
from qlib.utils.time import time_to_day_index
|
||||
|
||||
|
||||
@@ -35,6 +36,17 @@ def get_calendar_day(freq="1min", future=False):
|
||||
return _calendar
|
||||
|
||||
|
||||
def get_calendar_minute(freq="day", future=False):
|
||||
"""Load High-Freq Calendar Minute Using Memcache"""
|
||||
flag = f"{freq}_future_{future}_day"
|
||||
if flag in H["c"]:
|
||||
_calendar = H["c"][flag]
|
||||
else:
|
||||
_calendar = np.array(list(map(lambda x: x.minute // 30, Cal.load_calendar(freq, future))))
|
||||
H["c"][flag] = _calendar
|
||||
return _calendar
|
||||
|
||||
|
||||
class DayCumsum(ElemOperator):
|
||||
"""DayCumsum Operator during start time and end time.
|
||||
|
||||
@@ -83,3 +95,181 @@ class DayCumsum(ElemOperator):
|
||||
_calendar = get_calendar_day(freq=freq)
|
||||
series = self.feature.load(instrument, start_index, end_index, freq)
|
||||
return series.groupby(_calendar[series.index]).transform(self.period_cusum)
|
||||
|
||||
|
||||
class DayLast(ElemOperator):
|
||||
"""DayLast Operator
|
||||
|
||||
Parameters
|
||||
----------
|
||||
feature : Expression
|
||||
feature instance
|
||||
|
||||
Returns
|
||||
----------
|
||||
feature:
|
||||
a series of that each value equals the last value of its day
|
||||
"""
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
_calendar = get_calendar_day(freq=freq)
|
||||
series = self.feature.load(instrument, start_index, end_index, freq)
|
||||
return series.groupby(_calendar[series.index]).transform("last")
|
||||
|
||||
|
||||
class FFillNan(ElemOperator):
|
||||
"""FFillNan Operator
|
||||
|
||||
Parameters
|
||||
----------
|
||||
feature : Expression
|
||||
feature instance
|
||||
|
||||
Returns
|
||||
----------
|
||||
feature:
|
||||
a forward fill nan feature
|
||||
"""
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
series = self.feature.load(instrument, start_index, end_index, freq)
|
||||
return series.fillna(method="ffill")
|
||||
|
||||
|
||||
class BFillNan(ElemOperator):
|
||||
"""BFillNan Operator
|
||||
|
||||
Parameters
|
||||
----------
|
||||
feature : Expression
|
||||
feature instance
|
||||
|
||||
Returns
|
||||
----------
|
||||
feature:
|
||||
a backfoward fill nan feature
|
||||
"""
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
series = self.feature.load(instrument, start_index, end_index, freq)
|
||||
return series.fillna(method="bfill")
|
||||
|
||||
|
||||
class Date(ElemOperator):
|
||||
"""Date Operator
|
||||
|
||||
Parameters
|
||||
----------
|
||||
feature : Expression
|
||||
feature instance
|
||||
|
||||
Returns
|
||||
----------
|
||||
feature:
|
||||
a series of that each value is the date corresponding to feature.index
|
||||
"""
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
_calendar = get_calendar_day(freq=freq)
|
||||
series = self.feature.load(instrument, start_index, end_index, freq)
|
||||
return pd.Series(_calendar[series.index], index=series.index)
|
||||
|
||||
|
||||
class Select(PairOperator):
|
||||
"""Select Operator
|
||||
|
||||
Parameters
|
||||
----------
|
||||
feature_left : Expression
|
||||
feature instance, select condition
|
||||
feature_right : Expression
|
||||
feature instance, select value
|
||||
|
||||
Returns
|
||||
----------
|
||||
feature:
|
||||
value(feature_right) that meets the condition(feature_left)
|
||||
|
||||
"""
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
series_condition = self.feature_left.load(instrument, start_index, end_index, freq)
|
||||
series_feature = self.feature_right.load(instrument, start_index, end_index, freq)
|
||||
return series_feature.loc[series_condition]
|
||||
|
||||
|
||||
class IsNull(ElemOperator):
|
||||
"""IsNull Operator
|
||||
|
||||
Parameters
|
||||
----------
|
||||
feature : Expression
|
||||
feature instance
|
||||
|
||||
Returns
|
||||
----------
|
||||
feature:
|
||||
A series indicating whether the feature is nan
|
||||
"""
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
series = self.feature.load(instrument, start_index, end_index, freq)
|
||||
return series.isnull()
|
||||
|
||||
|
||||
class IsInf(ElemOperator):
|
||||
"""IsInf Operator
|
||||
|
||||
Parameters
|
||||
----------
|
||||
feature : Expression
|
||||
feature instance
|
||||
|
||||
Returns
|
||||
----------
|
||||
feature:
|
||||
A series indicating whether the feature is inf
|
||||
"""
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
series = self.feature.load(instrument, start_index, end_index, freq)
|
||||
return np.isinf(series)
|
||||
|
||||
|
||||
class Cut(ElemOperator):
|
||||
"""Cut Operator
|
||||
|
||||
Parameters
|
||||
----------
|
||||
feature : Expression
|
||||
feature instance
|
||||
l : int
|
||||
l > 0, delete the first l elements of feature (default is None, which means 0)
|
||||
r : int
|
||||
r < 0, delete the last -r elements of feature (default is None, which means 0)
|
||||
Returns
|
||||
----------
|
||||
feature:
|
||||
A series with the first l and last -r elements deleted from the feature.
|
||||
Note: It is deleted from the raw data, not the sliced data
|
||||
"""
|
||||
|
||||
def __init__(self, feature, left=None, right=None):
|
||||
self.left = left
|
||||
self.right = right
|
||||
if (self.left is not None and self.left <= 0) or (self.right is not None and self.right >= 0):
|
||||
raise ValueError("Cut operator l shoud > 0 and r should < 0")
|
||||
|
||||
super(Cut, self).__init__(feature)
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
series = self.feature.load(instrument, start_index, end_index, freq)
|
||||
return series.iloc[self.left : self.right]
|
||||
|
||||
def get_extended_window_size(self):
|
||||
ll = 0 if self.left is None else self.left
|
||||
rr = 0 if self.right is None else abs(self.right)
|
||||
lft_etd, rght_etd = self.feature.get_extended_window_size()
|
||||
lft_etd = lft_etd + ll
|
||||
rght_etd = rght_etd + rr
|
||||
return lft_etd, rght_etd
|
||||
|
||||
@@ -2,3 +2,6 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from .analysis_model_performance import model_performance_graph
|
||||
|
||||
|
||||
__all__ = ["model_performance_graph"]
|
||||
|
||||
@@ -6,3 +6,6 @@ from .score_ic import score_ic_graph
|
||||
from .report import report_graph
|
||||
from .rank_label import rank_label_graph
|
||||
from .risk_analysis import risk_analysis_graph
|
||||
|
||||
|
||||
__all__ = ["cumulative_return_graph", "score_ic_graph", "report_graph", "rank_label_graph", "risk_analysis_graph"]
|
||||
|
||||
@@ -68,9 +68,9 @@ def parse_position(position: dict = None) -> pd.DataFrame:
|
||||
if not _trading_day_sell_df.empty:
|
||||
_trading_day_sell_df["status"] = -1
|
||||
_trading_day_sell_df["date"] = _trading_date
|
||||
_trading_day_df = _trading_day_df.append(_trading_day_sell_df, sort=False)
|
||||
_trading_day_df = pd.concat([_trading_day_df, _trading_day_sell_df], sort=False)
|
||||
|
||||
result_df = result_df.append(_trading_day_df, sort=True)
|
||||
result_df = pd.concat([result_df, _trading_day_df], sort=True)
|
||||
|
||||
previous_data = dict(
|
||||
date=_trading_date,
|
||||
|
||||
@@ -85,7 +85,7 @@ def _get_monthly_risk_analysis_with_report(report_normal_df: pd.DataFrame) -> pd
|
||||
# _m_report_long_short,
|
||||
pd.Timestamp(year=gp_m[0], month=gp_m[1], day=month_days),
|
||||
)
|
||||
_monthly_df = _monthly_df.append(_temp_df, sort=False)
|
||||
_monthly_df = pd.concat([_monthly_df, _temp_df], sort=False)
|
||||
|
||||
return _monthly_df
|
||||
|
||||
|
||||
@@ -15,3 +15,14 @@ from .rule_strategy import (
|
||||
)
|
||||
|
||||
from .cost_control import SoftTopkStrategy
|
||||
|
||||
|
||||
__all__ = [
|
||||
"TopkDropoutStrategy",
|
||||
"WeightStrategyBase",
|
||||
"EnhancedIndexingStrategy",
|
||||
"TWAPStrategy",
|
||||
"SBBStrategyBase",
|
||||
"SBBStrategyEMA",
|
||||
"SoftTopkStrategy",
|
||||
]
|
||||
|
||||
@@ -4,3 +4,6 @@
|
||||
from .base import BaseOptimizer
|
||||
from .optimizer import PortfolioOptimizer
|
||||
from .enhanced_indexing import EnhancedIndexingOptimizer
|
||||
|
||||
|
||||
__all__ = ["BaseOptimizer", "PortfolioOptimizer", "EnhancedIndexingOptimizer"]
|
||||
|
||||
@@ -131,10 +131,10 @@ class TopkDropoutStrategy(BaseSignalStrategy):
|
||||
if self.only_tradable:
|
||||
# If The strategy only consider tradable stock when make decision
|
||||
# It needs following actions to filter stocks
|
||||
def get_first_n(l, n, reverse=False):
|
||||
def get_first_n(li, n, reverse=False):
|
||||
cur_n = 0
|
||||
res = []
|
||||
for si in reversed(l) if reverse else l:
|
||||
for si in reversed(li) if reverse else li:
|
||||
if self.trade_exchange.is_stock_tradable(
|
||||
stock_id=si, start_time=trade_start_time, end_time=trade_end_time
|
||||
):
|
||||
@@ -144,13 +144,13 @@ class TopkDropoutStrategy(BaseSignalStrategy):
|
||||
break
|
||||
return res[::-1] if reverse else res
|
||||
|
||||
def get_last_n(l, n):
|
||||
return get_first_n(l, n, reverse=True)
|
||||
def get_last_n(li, n):
|
||||
return get_first_n(li, n, reverse=True)
|
||||
|
||||
def filter_stock(l):
|
||||
def filter_stock(li):
|
||||
return [
|
||||
si
|
||||
for si in l
|
||||
for si in li
|
||||
if self.trade_exchange.is_stock_tradable(
|
||||
stock_id=si, start_time=trade_start_time, end_time=trade_end_time
|
||||
)
|
||||
@@ -158,14 +158,14 @@ class TopkDropoutStrategy(BaseSignalStrategy):
|
||||
|
||||
else:
|
||||
# Otherwise, the stock will make decision with out the stock tradable info
|
||||
def get_first_n(l, n):
|
||||
return list(l)[:n]
|
||||
def get_first_n(li, n):
|
||||
return list(li)[:n]
|
||||
|
||||
def get_last_n(l, n):
|
||||
return list(l)[-n:]
|
||||
def get_last_n(li, n):
|
||||
return list(li)[-n:]
|
||||
|
||||
def filter_stock(l):
|
||||
return l
|
||||
def filter_stock(li):
|
||||
return li
|
||||
|
||||
current_temp = copy.deepcopy(self.trade_position)
|
||||
# generate order list for this adjust date
|
||||
@@ -203,7 +203,7 @@ class TopkDropoutStrategy(BaseSignalStrategy):
|
||||
candi = filter_stock(last)
|
||||
try:
|
||||
sell = pd.Index(np.random.choice(candi, self.n_drop, replace=False) if len(last) else [])
|
||||
except ValueError: # No enough candidates
|
||||
except ValueError: # No enough candidates
|
||||
sell = candi
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
|
||||
@@ -1 +1,2 @@
|
||||
# pylint: skip-file
|
||||
# flake8: noqa
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
# pylint: skip-file
|
||||
# flake8: noqa
|
||||
|
||||
import yaml
|
||||
import copy
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
# pylint: skip-file
|
||||
# flake8: noqa
|
||||
|
||||
# coding=utf-8
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
# pylint: skip-file
|
||||
# flake8: noqa
|
||||
|
||||
import os
|
||||
import json
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
# pylint: skip-file
|
||||
# flake8: noqa
|
||||
|
||||
from hyperopt import hp
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
# pylint: skip-file
|
||||
# flake8: noqa
|
||||
|
||||
import os
|
||||
import yaml
|
||||
|
||||
@@ -2,3 +2,6 @@
|
||||
# Licensed under the MIT License.
|
||||
from .record_temp import MultiSegRecord
|
||||
from .record_temp import SignalMseRecord
|
||||
|
||||
|
||||
__all__ = ["MultiSegRecord", "SignalMseRecord"]
|
||||
|
||||
@@ -15,6 +15,7 @@ from .data import (
|
||||
LocalCalendarProvider,
|
||||
LocalInstrumentProvider,
|
||||
LocalFeatureProvider,
|
||||
LocalPITProvider,
|
||||
LocalExpressionProvider,
|
||||
LocalDatasetProvider,
|
||||
ClientCalendarProvider,
|
||||
@@ -34,3 +35,32 @@ from .cache import (
|
||||
DatasetURICache,
|
||||
MemoryCalendarCache,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"D",
|
||||
"CalendarProvider",
|
||||
"InstrumentProvider",
|
||||
"FeatureProvider",
|
||||
"ExpressionProvider",
|
||||
"DatasetProvider",
|
||||
"LocalCalendarProvider",
|
||||
"LocalInstrumentProvider",
|
||||
"LocalFeatureProvider",
|
||||
"LocalPITProvider",
|
||||
"LocalExpressionProvider",
|
||||
"LocalDatasetProvider",
|
||||
"ClientCalendarProvider",
|
||||
"ClientInstrumentProvider",
|
||||
"ClientDatasetProvider",
|
||||
"BaseProvider",
|
||||
"LocalProvider",
|
||||
"ClientProvider",
|
||||
"ExpressionCache",
|
||||
"DatasetCache",
|
||||
"DiskExpressionCache",
|
||||
"DiskDatasetCache",
|
||||
"SimpleDatasetCache",
|
||||
"DatasetURICache",
|
||||
"MemoryCalendarCache",
|
||||
]
|
||||
|
||||
@@ -6,12 +6,20 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import abc
|
||||
|
||||
import pandas as pd
|
||||
from ..log import get_module_logger
|
||||
|
||||
|
||||
class Expression(abc.ABC):
|
||||
"""Expression base class"""
|
||||
"""
|
||||
Expression base class
|
||||
|
||||
Expression is designed to handle the calculation of data with the format below
|
||||
data with two dimension for each instrument,
|
||||
- feature
|
||||
- time: it could be observation time or period time.
|
||||
- period time is designed for Point-in-time database. For example, the period time maybe 2014Q4, its value can observed for multiple times(different value may be observed at different time due to amendment).
|
||||
"""
|
||||
|
||||
def __str__(self):
|
||||
return type(self).__name__
|
||||
@@ -104,6 +112,11 @@ class Expression(abc.ABC):
|
||||
|
||||
return Power(self, other)
|
||||
|
||||
def __rpow__(self, other):
|
||||
from .ops import Power # pylint: disable=C0415
|
||||
|
||||
return Power(other, self)
|
||||
|
||||
def __and__(self, other):
|
||||
from .ops import And # pylint: disable=C0415
|
||||
|
||||
@@ -124,8 +137,18 @@ class Expression(abc.ABC):
|
||||
|
||||
return Or(other, self)
|
||||
|
||||
def load(self, instrument, start_index, end_index, freq):
|
||||
def load(self, instrument, start_index, end_index, *args):
|
||||
"""load feature
|
||||
This function is responsible for loading feature/expression based on the expression engine.
|
||||
|
||||
The concrete implementation will be separated into two parts:
|
||||
1) caching data, handle errors.
|
||||
- This part is shared by all the expressions and implemented in Expression
|
||||
2) processing and calculating data based on the specific expression.
|
||||
- This part is different in each expression and implemented in each expression
|
||||
|
||||
Expression Engine is shared by different data.
|
||||
Different data will have different extra information for `args`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -135,8 +158,18 @@ class Expression(abc.ABC):
|
||||
feature start index [in calendar].
|
||||
end_index : str
|
||||
feature end index [in calendar].
|
||||
freq : str
|
||||
feature frequency.
|
||||
|
||||
*args may contain following information:
|
||||
1) if it is used in basic expression engine data, it contains following arguments
|
||||
freq: str
|
||||
feature frequency.
|
||||
|
||||
2) if is used in PIT data, it contains following arguments
|
||||
cur_pit:
|
||||
it is designed for the point-in-time data.
|
||||
period: int
|
||||
This is used for query specific period.
|
||||
The period is represented with int in Qlib. (e.g. 202001 may represent the first quarter in 2020)
|
||||
|
||||
Returns
|
||||
----------
|
||||
@@ -146,26 +179,26 @@ class Expression(abc.ABC):
|
||||
from .cache import H # pylint: disable=C0415
|
||||
|
||||
# cache
|
||||
args = str(self), instrument, start_index, end_index, freq
|
||||
if args in H["f"]:
|
||||
return H["f"][args]
|
||||
cache_key = str(self), instrument, start_index, end_index, *args
|
||||
if cache_key in H["f"]:
|
||||
return H["f"][cache_key]
|
||||
if start_index is not None and end_index is not None and start_index > end_index:
|
||||
raise ValueError("Invalid index range: {} {}".format(start_index, end_index))
|
||||
try:
|
||||
series = self._load_internal(instrument, start_index, end_index, freq)
|
||||
series = self._load_internal(instrument, start_index, end_index, *args)
|
||||
except Exception as e:
|
||||
get_module_logger("data").debug(
|
||||
f"Loading data error: instrument={instrument}, expression={str(self)}, "
|
||||
f"start_index={start_index}, end_index={end_index}, freq={freq}. "
|
||||
f"start_index={start_index}, end_index={end_index}, args={args}. "
|
||||
f"error info: {str(e)}"
|
||||
)
|
||||
raise
|
||||
series.name = str(self)
|
||||
H["f"][args] = series
|
||||
H["f"][cache_key] = series
|
||||
return series
|
||||
|
||||
@abc.abstractmethod
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
def _load_internal(self, instrument, start_index, end_index, *args) -> pd.Series:
|
||||
raise NotImplementedError("This function must be implemented in your newly defined feature")
|
||||
|
||||
@abc.abstractmethod
|
||||
@@ -225,6 +258,16 @@ class Feature(Expression):
|
||||
return 0, 0
|
||||
|
||||
|
||||
class PFeature(Feature):
|
||||
def __str__(self):
|
||||
return "$$" + self._name
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, cur_time, period=None):
|
||||
from .data import PITD # pylint: disable=C0415
|
||||
|
||||
return PITD.period_feature(instrument, str(self), start_index, end_index, cur_time, period)
|
||||
|
||||
|
||||
class ExpressionOps(Expression):
|
||||
"""Operator Expression
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@ from ..utils import (
|
||||
|
||||
from ..log import get_module_logger
|
||||
from .base import Feature
|
||||
from .ops import Operators # pylint: disable=W0611
|
||||
from .ops import Operators # pylint: disable=W0611 # noqa: F401
|
||||
|
||||
|
||||
class QlibCacheException(RuntimeError):
|
||||
@@ -528,7 +528,7 @@ class DiskExpressionCache(ExpressionCache):
|
||||
CacheUtils.visit(cache_path)
|
||||
series = read_bin(cache_path, start_index, end_index)
|
||||
return series
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
series = None
|
||||
self.logger.error("reading %s file error : %s" % (cache_path, traceback.format_exc()))
|
||||
return series
|
||||
@@ -1068,7 +1068,7 @@ class SimpleDatasetCache(DatasetCache):
|
||||
super(SimpleDatasetCache, self).__init__(provider)
|
||||
try:
|
||||
self.local_cache_path: Path = Path(C["local_cache_path"]).expanduser().resolve()
|
||||
except (KeyError, TypeError) as e:
|
||||
except (KeyError, TypeError):
|
||||
self.logger.error("Assign a local_cache_path in config if you want to use this cache mechanism")
|
||||
raise
|
||||
self.logger.info(
|
||||
|
||||
@@ -12,7 +12,7 @@ import queue
|
||||
import bisect
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import List, Union
|
||||
from typing import List, Union, Optional
|
||||
|
||||
# For supporting multiprocessing in outer code, joblib is used
|
||||
from joblib import delayed
|
||||
@@ -34,9 +34,11 @@ from ..utils import (
|
||||
code_to_fname,
|
||||
set_log_with_config,
|
||||
time_to_slc_point,
|
||||
read_period_data,
|
||||
get_period_list,
|
||||
)
|
||||
from ..utils.paral import ParallelExt
|
||||
from .ops import Operators # pylint: disable=W0611
|
||||
from .ops import Operators # pylint: disable=W0611 # noqa: F401
|
||||
|
||||
|
||||
class ProviderBackendMixin:
|
||||
@@ -331,6 +333,51 @@ class FeatureProvider(abc.ABC):
|
||||
raise NotImplementedError("Subclass of FeatureProvider must implement `feature` method")
|
||||
|
||||
|
||||
class PITProvider(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def period_feature(
|
||||
self,
|
||||
instrument,
|
||||
field,
|
||||
start_index: int,
|
||||
end_index: int,
|
||||
cur_time: pd.Timestamp,
|
||||
period: Optional[int] = None,
|
||||
) -> pd.Series:
|
||||
"""
|
||||
get the historical periods data series between `start_index` and `end_index`
|
||||
|
||||
Parameters
|
||||
----------
|
||||
start_index: int
|
||||
start_index is a relative index to the latest period to cur_time
|
||||
|
||||
end_index: int
|
||||
end_index is a relative index to the latest period to cur_time
|
||||
in most cases, the start_index and end_index will be a non-positive values
|
||||
For example, start_index == -3 end_index == 0 and current period index is cur_idx,
|
||||
then the data between [start_index + cur_idx, end_index + cur_idx] will be retrieved.
|
||||
|
||||
period: int
|
||||
This is used for query specific period.
|
||||
The period is represented with int in Qlib. (e.g. 202001 may represent the first quarter in 2020)
|
||||
NOTE: `period` will override `start_index` and `end_index`
|
||||
|
||||
Returns
|
||||
-------
|
||||
pd.Series
|
||||
The index will be integers to indicate the periods of the data
|
||||
An typical examples will be
|
||||
TODO
|
||||
|
||||
Raises
|
||||
------
|
||||
FileNotFoundError
|
||||
This exception will be raised if the queried data do not exist.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `period_feature` method")
|
||||
|
||||
|
||||
class ExpressionProvider(abc.ABC):
|
||||
"""Expression provider class
|
||||
|
||||
@@ -583,7 +630,7 @@ class DatasetProvider(abc.ABC):
|
||||
for _processor in inst_processors:
|
||||
if _processor:
|
||||
_processor_obj = init_instance_by_config(_processor, accept_types=InstProcessor)
|
||||
data = _processor_obj(data)
|
||||
data = _processor_obj(data, instrument=inst)
|
||||
return data
|
||||
|
||||
|
||||
@@ -694,6 +741,95 @@ class LocalFeatureProvider(FeatureProvider, ProviderBackendMixin):
|
||||
return self.backend_obj(instrument=instrument, field=field, freq=freq)[start_index : end_index + 1]
|
||||
|
||||
|
||||
class LocalPITProvider(PITProvider):
|
||||
# TODO: Add PIT backend file storage
|
||||
# NOTE: This class is not multi-threading-safe!!!!
|
||||
|
||||
def period_feature(self, instrument, field, start_index, end_index, cur_time, period=None):
|
||||
if not isinstance(cur_time, pd.Timestamp):
|
||||
raise ValueError(
|
||||
f"Expected pd.Timestamp for `cur_time`, got '{cur_time}'. Advices: you can't query PIT data directly(e.g. '$$roewa_q'), you must use `P` operator to convert data to each day (e.g. 'P($$roewa_q)')"
|
||||
)
|
||||
|
||||
assert end_index <= 0 # PIT don't support querying future data
|
||||
|
||||
DATA_RECORDS = [
|
||||
("date", C.pit_record_type["date"]),
|
||||
("period", C.pit_record_type["period"]),
|
||||
("value", C.pit_record_type["value"]),
|
||||
("_next", C.pit_record_type["index"]),
|
||||
]
|
||||
VALUE_DTYPE = C.pit_record_type["value"]
|
||||
|
||||
field = str(field).lower()[2:]
|
||||
instrument = code_to_fname(instrument)
|
||||
|
||||
# {For acceleration
|
||||
# start_index, end_index, cur_index = kwargs["info"]
|
||||
# if cur_index == start_index:
|
||||
# if not hasattr(self, "all_fields"):
|
||||
# self.all_fields = []
|
||||
# self.all_fields.append(field)
|
||||
# if not hasattr(self, "period_index"):
|
||||
# self.period_index = {}
|
||||
# if field not in self.period_index:
|
||||
# self.period_index[field] = {}
|
||||
# For acceleration}
|
||||
|
||||
if not field.endswith("_q") and not field.endswith("_a"):
|
||||
raise ValueError("period field must ends with '_q' or '_a'")
|
||||
quarterly = field.endswith("_q")
|
||||
index_path = C.dpm.get_data_uri() / "financial" / instrument.lower() / f"{field}.index"
|
||||
data_path = C.dpm.get_data_uri() / "financial" / instrument.lower() / f"{field}.data"
|
||||
if not (index_path.exists() and data_path.exists()):
|
||||
raise FileNotFoundError("No file is found. Raise exception and ")
|
||||
# NOTE: The most significant performance loss is here.
|
||||
# Does the acceleration that makes the program complicated really matters?
|
||||
# - It makes parameters of the interface complicate
|
||||
# - It does not performance in the optimal way (places all the pieces together, we may achieve higher performance)
|
||||
# - If we design it carefully, we can go through for only once to get the historical evolution of the data.
|
||||
# So I decide to deprecated previous implementation and keep the logic of the program simple
|
||||
# Instead, I'll add a cache for the index file.
|
||||
data = np.fromfile(data_path, dtype=DATA_RECORDS)
|
||||
|
||||
# find all revision periods before `cur_time`
|
||||
cur_time_int = int(cur_time.year) * 10000 + int(cur_time.month) * 100 + int(cur_time.day)
|
||||
loc = np.searchsorted(data["date"], cur_time_int, side="right")
|
||||
if loc <= 0:
|
||||
return pd.Series()
|
||||
last_period = data["period"][:loc].max() # return the latest quarter
|
||||
first_period = data["period"][:loc].min()
|
||||
period_list = get_period_list(first_period, last_period, quarterly)
|
||||
if period is not None:
|
||||
# NOTE: `period` has higher priority than `start_index` & `end_index`
|
||||
if period not in period_list:
|
||||
return pd.Series()
|
||||
else:
|
||||
period_list = [period]
|
||||
else:
|
||||
period_list = period_list[max(0, len(period_list) + start_index - 1) : len(period_list) + end_index]
|
||||
value = np.full((len(period_list),), np.nan, dtype=VALUE_DTYPE)
|
||||
for i, p in enumerate(period_list):
|
||||
# last_period_index = self.period_index[field].get(period) # For acceleration
|
||||
value[i], now_period_index = read_period_data(
|
||||
index_path, data_path, p, cur_time_int, quarterly # , last_period_index # For acceleration
|
||||
)
|
||||
# self.period_index[field].update({period: now_period_index}) # For acceleration
|
||||
# NOTE: the index is period_list; So it may result in unexpected values(e.g. nan)
|
||||
# when calculation between different features and only part of its financial indicator is published
|
||||
series = pd.Series(value, index=period_list, dtype=VALUE_DTYPE)
|
||||
|
||||
# {For acceleration
|
||||
# if cur_index == end_index:
|
||||
# self.all_fields.remove(field)
|
||||
# if not len(self.all_fields):
|
||||
# del self.all_fields
|
||||
# del self.period_index
|
||||
# For acceleration}
|
||||
|
||||
return series
|
||||
|
||||
|
||||
class LocalExpressionProvider(ExpressionProvider):
|
||||
"""Local expression data provider class
|
||||
|
||||
@@ -1003,6 +1139,8 @@ class ClientDatasetProvider(DatasetProvider):
|
||||
|
||||
class BaseProvider:
|
||||
"""Local provider class
|
||||
It is a set of interface that allow users to access data.
|
||||
Because PITD is not exposed publicly to users, so it is not included in the interface.
|
||||
|
||||
To keep compatible with old qlib provider.
|
||||
"""
|
||||
@@ -1126,6 +1264,7 @@ if sys.version_info >= (3, 9):
|
||||
CalendarProviderWrapper = Annotated[CalendarProvider, Wrapper]
|
||||
InstrumentProviderWrapper = Annotated[InstrumentProvider, Wrapper]
|
||||
FeatureProviderWrapper = Annotated[FeatureProvider, Wrapper]
|
||||
PITProviderWrapper = Annotated[PITProvider, Wrapper]
|
||||
ExpressionProviderWrapper = Annotated[ExpressionProvider, Wrapper]
|
||||
DatasetProviderWrapper = Annotated[DatasetProvider, Wrapper]
|
||||
BaseProviderWrapper = Annotated[BaseProvider, Wrapper]
|
||||
@@ -1133,6 +1272,7 @@ else:
|
||||
CalendarProviderWrapper = CalendarProvider
|
||||
InstrumentProviderWrapper = InstrumentProvider
|
||||
FeatureProviderWrapper = FeatureProvider
|
||||
PITProviderWrapper = PITProvider
|
||||
ExpressionProviderWrapper = ExpressionProvider
|
||||
DatasetProviderWrapper = DatasetProvider
|
||||
BaseProviderWrapper = BaseProvider
|
||||
@@ -1140,6 +1280,7 @@ else:
|
||||
Cal: CalendarProviderWrapper = Wrapper()
|
||||
Inst: InstrumentProviderWrapper = Wrapper()
|
||||
FeatureD: FeatureProviderWrapper = Wrapper()
|
||||
PITD: PITProviderWrapper = Wrapper()
|
||||
ExpressionD: ExpressionProviderWrapper = Wrapper()
|
||||
DatasetD: DatasetProviderWrapper = Wrapper()
|
||||
D: BaseProviderWrapper = Wrapper()
|
||||
@@ -1165,6 +1306,11 @@ def register_all_wrappers(C):
|
||||
register_wrapper(FeatureD, feature_provider, "qlib.data")
|
||||
logger.debug(f"registering FeatureD {C.feature_provider}")
|
||||
|
||||
if getattr(C, "pit_provider", None) is not None:
|
||||
pit_provider = init_instance_by_config(C.pit_provider, module)
|
||||
register_wrapper(PITD, pit_provider, "qlib.data")
|
||||
logger.debug(f"registering PITD {C.pit_provider}")
|
||||
|
||||
if getattr(C, "expression_provider", None) is not None:
|
||||
# This provider is unnecessary in client provider
|
||||
_eprovider = init_instance_by_config(C.expression_provider, module)
|
||||
|
||||
@@ -171,6 +171,7 @@ class DatasetH(Dataset):
|
||||
Parameters
|
||||
----------
|
||||
slc : please refer to the docs of `prepare`
|
||||
NOTE: it may not be an instance of slice. It may be a segment of `segments` from `def prepare`
|
||||
"""
|
||||
if hasattr(self, "fetch_kwargs"):
|
||||
return self.handler.fetch(slc, **kwargs, **self.fetch_kwargs)
|
||||
@@ -199,6 +200,9 @@ class DatasetH(Dataset):
|
||||
|
||||
col_set : str
|
||||
The col_set will be passed to self.handler when fetching data.
|
||||
TODO: make it automatic:
|
||||
- select DK_I for test data
|
||||
- select DK_L for training data.
|
||||
data_key : str
|
||||
The data to fetch: DK_*
|
||||
Default is DK_I, which indicate fetching data for **inference**.
|
||||
@@ -346,7 +350,7 @@ class TSDataSampler:
|
||||
flt_data = flt_data.reindex(self.data_index).fillna(False).astype(np.bool)
|
||||
self.flt_data = flt_data.values
|
||||
self.idx_map = self.flt_idx_map(self.flt_data, self.idx_map)
|
||||
self.data_index = self.data_index[np.where(self.flt_data is True)[0]]
|
||||
self.data_index = self.data_index[np.where(self.flt_data)[0]]
|
||||
self.idx_map = self.idx_map2arr(self.idx_map)
|
||||
|
||||
self.start_idx, self.end_idx = self.data_index.slice_locs(
|
||||
@@ -609,3 +613,6 @@ class TSDatasetH(DatasetH):
|
||||
|
||||
tsds = TSDataSampler(data=data, start=start, end=end, step_len=self.step_len, dtype=dtype, flt_data=flt_data)
|
||||
return tsds
|
||||
|
||||
|
||||
__all__ = ["Optional"]
|
||||
|
||||
@@ -515,7 +515,7 @@ class DataHandlerLP(DataHandler):
|
||||
# data for learning
|
||||
# 1) assign
|
||||
if self.process_type == DataHandlerLP.PTYPE_I:
|
||||
_learn_df = self._data
|
||||
_learn_df = _shared_df
|
||||
elif self.process_type == DataHandlerLP.PTYPE_A:
|
||||
# based on `infer_df` and append the processor
|
||||
_learn_df = _infer_df
|
||||
|
||||
@@ -187,7 +187,13 @@ class Fillna(Processor):
|
||||
df.fillna(self.fill_value, inplace=True)
|
||||
else:
|
||||
cols = get_group_columns(df, self.fields_group)
|
||||
df.fillna({col: self.fill_value for col in cols}, inplace=True)
|
||||
# this implementation is extremely slow
|
||||
# df.fillna({col: self.fill_value for col in cols}, inplace=True)
|
||||
|
||||
# So we use numpy to accelerate filling values
|
||||
nan_select = np.isnan(df.values)
|
||||
nan_select[:, ~df.columns.isin(cols)] = False
|
||||
df.values[nan_select] = self.fill_value
|
||||
return df
|
||||
|
||||
|
||||
@@ -318,6 +324,20 @@ class CSRankNorm(Processor):
|
||||
The operations across different stocks are often called Cross Sectional Operation.
|
||||
|
||||
For example, CSRankNorm is an operation that grouping the data by each day and rank `across` all the stocks in each day.
|
||||
|
||||
Explanation about 3.46 & 0.5
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
x = np.random.random(10000) # for any variable
|
||||
x_rank = pd.Series(x).rank(pct=True) # if it is converted to rank, it will be a uniform distributed
|
||||
x_rank_norm = (x_rank - x_rank.mean()) / x_rank.std() # Normally, we will normalize it to make it like normal distribution
|
||||
|
||||
x_rank.mean() # accounts for 0.5
|
||||
1 / x_rank.std() # accounts for 3.46
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, fields_group=None):
|
||||
|
||||
@@ -5,7 +5,7 @@ import pandas as pd
|
||||
|
||||
class InstProcessor:
|
||||
@abc.abstractmethod
|
||||
def __call__(self, df: pd.DataFrame, *args, **kwargs):
|
||||
def __call__(self, df: pd.DataFrame, instrument, *args, **kwargs):
|
||||
"""
|
||||
process the data
|
||||
|
||||
|
||||
185
qlib/data/ops.py
185
qlib/data/ops.py
@@ -10,9 +10,7 @@ import pandas as pd
|
||||
|
||||
from typing import Union, List, Type
|
||||
from scipy.stats import percentileofscore
|
||||
|
||||
from .base import Expression, ExpressionOps, Feature
|
||||
|
||||
from .base import Expression, ExpressionOps, Feature, PFeature
|
||||
from ..log import get_module_logger
|
||||
from ..utils import get_callable_kwargs
|
||||
|
||||
@@ -24,7 +22,7 @@ except ImportError:
|
||||
"#### Do not import qlib package in the repository directory in case of importing qlib from . without compiling #####"
|
||||
)
|
||||
raise
|
||||
except ValueError as e:
|
||||
except ValueError:
|
||||
print("!!!!!!!! A error occurs when importing operators implemented based on Cython.!!!!!!!!")
|
||||
print("!!!!!!!! They will be disabled. Please Upgrade your numpy to enable them !!!!!!!!")
|
||||
# We catch this error because some platform can't upgrade there package (e.g. Kaggle)
|
||||
@@ -84,8 +82,8 @@ class NpElemOperator(ElemOperator):
|
||||
self.func = func
|
||||
super(NpElemOperator, self).__init__(feature)
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
series = self.feature.load(instrument, start_index, end_index, freq)
|
||||
def _load_internal(self, instrument, start_index, end_index, *args):
|
||||
series = self.feature.load(instrument, start_index, end_index, *args)
|
||||
return getattr(np, self.func)(series)
|
||||
|
||||
|
||||
@@ -124,11 +122,11 @@ class Sign(NpElemOperator):
|
||||
def __init__(self, feature):
|
||||
super(Sign, self).__init__(feature, "sign")
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
def _load_internal(self, instrument, start_index, end_index, *args):
|
||||
"""
|
||||
To avoid error raised by bool type input, we transform the data into float32.
|
||||
"""
|
||||
series = self.feature.load(instrument, start_index, end_index, freq)
|
||||
series = self.feature.load(instrument, start_index, end_index, *args)
|
||||
# TODO: More precision types should be configurable
|
||||
series = series.astype(np.float32)
|
||||
return getattr(np, self.func)(series)
|
||||
@@ -152,32 +150,6 @@ class Log(NpElemOperator):
|
||||
super(Log, self).__init__(feature, "log")
|
||||
|
||||
|
||||
class Power(NpElemOperator):
|
||||
"""Feature Power
|
||||
|
||||
Parameters
|
||||
----------
|
||||
feature : Expression
|
||||
feature instance
|
||||
|
||||
Returns
|
||||
----------
|
||||
Expression
|
||||
a feature instance with power
|
||||
"""
|
||||
|
||||
def __init__(self, feature, exponent):
|
||||
super(Power, self).__init__(feature, "power")
|
||||
self.exponent = exponent
|
||||
|
||||
def __str__(self):
|
||||
return "{}({},{})".format(type(self).__name__, self.feature, self.exponent)
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
series = self.feature.load(instrument, start_index, end_index, freq)
|
||||
return getattr(np, self.func)(series, self.exponent)
|
||||
|
||||
|
||||
class Mask(NpElemOperator):
|
||||
"""Feature Mask
|
||||
|
||||
@@ -201,8 +173,8 @@ class Mask(NpElemOperator):
|
||||
def __str__(self):
|
||||
return "{}({},{})".format(type(self).__name__, self.feature, self.instrument.lower())
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
return self.feature.load(self.instrument, start_index, end_index, freq)
|
||||
def _load_internal(self, instrument, start_index, end_index, *args):
|
||||
return self.feature.load(self.instrument, start_index, end_index, *args)
|
||||
|
||||
|
||||
class Not(NpElemOperator):
|
||||
@@ -252,24 +224,24 @@ class PairOperator(ExpressionOps):
|
||||
return "{}({},{})".format(type(self).__name__, self.feature_left, self.feature_right)
|
||||
|
||||
def get_longest_back_rolling(self):
|
||||
if isinstance(self.feature_left, Expression):
|
||||
if isinstance(self.feature_left, (Expression,)):
|
||||
left_br = self.feature_left.get_longest_back_rolling()
|
||||
else:
|
||||
left_br = 0
|
||||
|
||||
if isinstance(self.feature_right, Expression):
|
||||
if isinstance(self.feature_right, (Expression,)):
|
||||
right_br = self.feature_right.get_longest_back_rolling()
|
||||
else:
|
||||
right_br = 0
|
||||
return max(left_br, right_br)
|
||||
|
||||
def get_extended_window_size(self):
|
||||
if isinstance(self.feature_left, Expression):
|
||||
if isinstance(self.feature_left, (Expression,)):
|
||||
ll, lr = self.feature_left.get_extended_window_size()
|
||||
else:
|
||||
ll, lr = 0, 0
|
||||
|
||||
if isinstance(self.feature_right, Expression):
|
||||
if isinstance(self.feature_right, (Expression,)):
|
||||
rl, rr = self.feature_right.get_extended_window_size()
|
||||
else:
|
||||
rl, rr = 0, 0
|
||||
@@ -298,16 +270,16 @@ class NpPairOperator(PairOperator):
|
||||
self.func = func
|
||||
super(NpPairOperator, self).__init__(feature_left, feature_right)
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
def _load_internal(self, instrument, start_index, end_index, *args):
|
||||
assert any(
|
||||
[isinstance(self.feature_left, Expression), self.feature_right, Expression]
|
||||
[isinstance(self.feature_left, (Expression,)), self.feature_right, Expression]
|
||||
), "at least one of two inputs is Expression instance"
|
||||
if isinstance(self.feature_left, Expression):
|
||||
series_left = self.feature_left.load(instrument, start_index, end_index, freq)
|
||||
if isinstance(self.feature_left, (Expression,)):
|
||||
series_left = self.feature_left.load(instrument, start_index, end_index, *args)
|
||||
else:
|
||||
series_left = self.feature_left # numeric value
|
||||
if isinstance(self.feature_right, Expression):
|
||||
series_right = self.feature_right.load(instrument, start_index, end_index, freq)
|
||||
if isinstance(self.feature_right, (Expression,)):
|
||||
series_right = self.feature_right.load(instrument, start_index, end_index, *args)
|
||||
else:
|
||||
series_right = self.feature_right
|
||||
check_length = isinstance(series_left, (np.ndarray, pd.Series)) and isinstance(
|
||||
@@ -335,6 +307,26 @@ class NpPairOperator(PairOperator):
|
||||
return res
|
||||
|
||||
|
||||
class Power(NpPairOperator):
|
||||
"""Power Operator
|
||||
|
||||
Parameters
|
||||
----------
|
||||
feature_left : Expression
|
||||
feature instance
|
||||
feature_right : Expression
|
||||
feature instance
|
||||
|
||||
Returns
|
||||
----------
|
||||
Feature:
|
||||
The bases in feature_left raised to the exponents in feature_right
|
||||
"""
|
||||
|
||||
def __init__(self, feature_left, feature_right):
|
||||
super(Power, self).__init__(feature_left, feature_right, "power")
|
||||
|
||||
|
||||
class Add(NpPairOperator):
|
||||
"""Add Operator
|
||||
|
||||
@@ -637,48 +629,48 @@ class If(ExpressionOps):
|
||||
def __str__(self):
|
||||
return "If({},{},{})".format(self.condition, self.feature_left, self.feature_right)
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
series_cond = self.condition.load(instrument, start_index, end_index, freq)
|
||||
if isinstance(self.feature_left, Expression):
|
||||
series_left = self.feature_left.load(instrument, start_index, end_index, freq)
|
||||
def _load_internal(self, instrument, start_index, end_index, *args):
|
||||
series_cond = self.condition.load(instrument, start_index, end_index, *args)
|
||||
if isinstance(self.feature_left, (Expression,)):
|
||||
series_left = self.feature_left.load(instrument, start_index, end_index, *args)
|
||||
else:
|
||||
series_left = self.feature_left
|
||||
if isinstance(self.feature_right, Expression):
|
||||
series_right = self.feature_right.load(instrument, start_index, end_index, freq)
|
||||
if isinstance(self.feature_right, (Expression,)):
|
||||
series_right = self.feature_right.load(instrument, start_index, end_index, *args)
|
||||
else:
|
||||
series_right = self.feature_right
|
||||
series = pd.Series(np.where(series_cond, series_left, series_right), index=series_cond.index)
|
||||
return series
|
||||
|
||||
def get_longest_back_rolling(self):
|
||||
if isinstance(self.feature_left, Expression):
|
||||
if isinstance(self.feature_left, (Expression,)):
|
||||
left_br = self.feature_left.get_longest_back_rolling()
|
||||
else:
|
||||
left_br = 0
|
||||
|
||||
if isinstance(self.feature_right, Expression):
|
||||
if isinstance(self.feature_right, (Expression,)):
|
||||
right_br = self.feature_right.get_longest_back_rolling()
|
||||
else:
|
||||
right_br = 0
|
||||
|
||||
if isinstance(self.condition, Expression):
|
||||
if isinstance(self.condition, (Expression,)):
|
||||
c_br = self.condition.get_longest_back_rolling()
|
||||
else:
|
||||
c_br = 0
|
||||
return max(left_br, right_br, c_br)
|
||||
|
||||
def get_extended_window_size(self):
|
||||
if isinstance(self.feature_left, Expression):
|
||||
if isinstance(self.feature_left, (Expression,)):
|
||||
ll, lr = self.feature_left.get_extended_window_size()
|
||||
else:
|
||||
ll, lr = 0, 0
|
||||
|
||||
if isinstance(self.feature_right, Expression):
|
||||
if isinstance(self.feature_right, (Expression,)):
|
||||
rl, rr = self.feature_right.get_extended_window_size()
|
||||
else:
|
||||
rl, rr = 0, 0
|
||||
|
||||
if isinstance(self.condition, Expression):
|
||||
if isinstance(self.condition, (Expression,)):
|
||||
cl, cr = self.condition.get_extended_window_size()
|
||||
else:
|
||||
cl, cr = 0, 0
|
||||
@@ -719,8 +711,8 @@ class Rolling(ExpressionOps):
|
||||
def __str__(self):
|
||||
return "{}({},{})".format(type(self).__name__, self.feature, self.N)
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
series = self.feature.load(instrument, start_index, end_index, freq)
|
||||
def _load_internal(self, instrument, start_index, end_index, *args):
|
||||
series = self.feature.load(instrument, start_index, end_index, *args)
|
||||
# NOTE: remove all null check,
|
||||
# now it's user's responsibility to decide whether use features in null days
|
||||
# isnull = series.isnull() # NOTE: isnull = NaN, inf is not null
|
||||
@@ -777,8 +769,8 @@ class Ref(Rolling):
|
||||
def __init__(self, feature, N):
|
||||
super(Ref, self).__init__(feature, N, "ref")
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
series = self.feature.load(instrument, start_index, end_index, freq)
|
||||
def _load_internal(self, instrument, start_index, end_index, *args):
|
||||
series = self.feature.load(instrument, start_index, end_index, *args)
|
||||
# N = 0, return first day
|
||||
if series.empty:
|
||||
return series # Pandas bug, see: https://github.com/pandas-dev/pandas/issues/21049
|
||||
@@ -967,8 +959,8 @@ class IdxMax(Rolling):
|
||||
def __init__(self, feature, N):
|
||||
super(IdxMax, self).__init__(feature, N, "idxmax")
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
series = self.feature.load(instrument, start_index, end_index, freq)
|
||||
def _load_internal(self, instrument, start_index, end_index, *args):
|
||||
series = self.feature.load(instrument, start_index, end_index, *args)
|
||||
if self.N == 0:
|
||||
series = series.expanding(min_periods=1).apply(lambda x: x.argmax() + 1, raw=True)
|
||||
else:
|
||||
@@ -1015,8 +1007,8 @@ class IdxMin(Rolling):
|
||||
def __init__(self, feature, N):
|
||||
super(IdxMin, self).__init__(feature, N, "idxmin")
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
series = self.feature.load(instrument, start_index, end_index, freq)
|
||||
def _load_internal(self, instrument, start_index, end_index, *args):
|
||||
series = self.feature.load(instrument, start_index, end_index, *args)
|
||||
if self.N == 0:
|
||||
series = series.expanding(min_periods=1).apply(lambda x: x.argmin() + 1, raw=True)
|
||||
else:
|
||||
@@ -1047,8 +1039,8 @@ class Quantile(Rolling):
|
||||
def __str__(self):
|
||||
return "{}({},{},{})".format(type(self).__name__, self.feature, self.N, self.qscore)
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
series = self.feature.load(instrument, start_index, end_index, freq)
|
||||
def _load_internal(self, instrument, start_index, end_index, *args):
|
||||
series = self.feature.load(instrument, start_index, end_index, *args)
|
||||
if self.N == 0:
|
||||
series = series.expanding(min_periods=1).quantile(self.qscore)
|
||||
else:
|
||||
@@ -1095,8 +1087,8 @@ class Mad(Rolling):
|
||||
def __init__(self, feature, N):
|
||||
super(Mad, self).__init__(feature, N, "mad")
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
series = self.feature.load(instrument, start_index, end_index, freq)
|
||||
def _load_internal(self, instrument, start_index, end_index, *args):
|
||||
series = self.feature.load(instrument, start_index, end_index, *args)
|
||||
# TODO: implement in Cython
|
||||
|
||||
def mad(x):
|
||||
@@ -1129,8 +1121,8 @@ class Rank(Rolling):
|
||||
def __init__(self, feature, N):
|
||||
super(Rank, self).__init__(feature, N, "rank")
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
series = self.feature.load(instrument, start_index, end_index, freq)
|
||||
def _load_internal(self, instrument, start_index, end_index, *args):
|
||||
series = self.feature.load(instrument, start_index, end_index, *args)
|
||||
# TODO: implement in Cython
|
||||
|
||||
def rank(x):
|
||||
@@ -1187,8 +1179,8 @@ class Delta(Rolling):
|
||||
def __init__(self, feature, N):
|
||||
super(Delta, self).__init__(feature, N, "delta")
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
series = self.feature.load(instrument, start_index, end_index, freq)
|
||||
def _load_internal(self, instrument, start_index, end_index, *args):
|
||||
series = self.feature.load(instrument, start_index, end_index, *args)
|
||||
if self.N == 0:
|
||||
series = series - series.iloc[0]
|
||||
else:
|
||||
@@ -1225,8 +1217,8 @@ class Slope(Rolling):
|
||||
def __init__(self, feature, N):
|
||||
super(Slope, self).__init__(feature, N, "slope")
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
series = self.feature.load(instrument, start_index, end_index, freq)
|
||||
def _load_internal(self, instrument, start_index, end_index, *args):
|
||||
series = self.feature.load(instrument, start_index, end_index, *args)
|
||||
if self.N == 0:
|
||||
series = pd.Series(expanding_slope(series.values), index=series.index)
|
||||
else:
|
||||
@@ -1253,8 +1245,8 @@ class Rsquare(Rolling):
|
||||
def __init__(self, feature, N):
|
||||
super(Rsquare, self).__init__(feature, N, "rsquare")
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
_series = self.feature.load(instrument, start_index, end_index, freq)
|
||||
def _load_internal(self, instrument, start_index, end_index, *args):
|
||||
_series = self.feature.load(instrument, start_index, end_index, *args)
|
||||
if self.N == 0:
|
||||
series = pd.Series(expanding_rsquare(_series.values), index=_series.index)
|
||||
else:
|
||||
@@ -1282,8 +1274,8 @@ class Resi(Rolling):
|
||||
def __init__(self, feature, N):
|
||||
super(Resi, self).__init__(feature, N, "resi")
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
series = self.feature.load(instrument, start_index, end_index, freq)
|
||||
def _load_internal(self, instrument, start_index, end_index, *args):
|
||||
series = self.feature.load(instrument, start_index, end_index, *args)
|
||||
if self.N == 0:
|
||||
series = pd.Series(expanding_resi(series.values), index=series.index)
|
||||
else:
|
||||
@@ -1310,8 +1302,8 @@ class WMA(Rolling):
|
||||
def __init__(self, feature, N):
|
||||
super(WMA, self).__init__(feature, N, "wma")
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
series = self.feature.load(instrument, start_index, end_index, freq)
|
||||
def _load_internal(self, instrument, start_index, end_index, *args):
|
||||
series = self.feature.load(instrument, start_index, end_index, *args)
|
||||
# TODO: implement in Cython
|
||||
|
||||
def weighted_mean(x):
|
||||
@@ -1345,8 +1337,8 @@ class EMA(Rolling):
|
||||
def __init__(self, feature, N):
|
||||
super(EMA, self).__init__(feature, N, "ema")
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
series = self.feature.load(instrument, start_index, end_index, freq)
|
||||
def _load_internal(self, instrument, start_index, end_index, *args):
|
||||
series = self.feature.load(instrument, start_index, end_index, *args)
|
||||
|
||||
def exp_weighted_mean(x):
|
||||
a = 1 - 2 / (1 + len(x))
|
||||
@@ -1392,17 +1384,17 @@ class PairRolling(ExpressionOps):
|
||||
def __str__(self):
|
||||
return "{}({},{},{})".format(type(self).__name__, self.feature_left, self.feature_right, self.N)
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
def _load_internal(self, instrument, start_index, end_index, *args):
|
||||
assert any(
|
||||
[isinstance(self.feature_left, Expression), self.feature_right, Expression]
|
||||
), "at least one of two inputs is Expression instance"
|
||||
|
||||
if isinstance(self.feature_left, Expression):
|
||||
series_left = self.feature_left.load(instrument, start_index, end_index, freq)
|
||||
series_left = self.feature_left.load(instrument, start_index, end_index, *args)
|
||||
else:
|
||||
series_left = self.feature_left # numeric value
|
||||
if isinstance(self.feature_right, Expression):
|
||||
series_right = self.feature_right.load(instrument, start_index, end_index, freq)
|
||||
series_right = self.feature_right.load(instrument, start_index, end_index, *args)
|
||||
else:
|
||||
series_right = self.feature_right
|
||||
|
||||
@@ -1465,12 +1457,12 @@ class Corr(PairRolling):
|
||||
def __init__(self, feature_left, feature_right, N):
|
||||
super(Corr, self).__init__(feature_left, feature_right, N, "corr")
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
res: pd.Series = super(Corr, self)._load_internal(instrument, start_index, end_index, freq)
|
||||
def _load_internal(self, instrument, start_index, end_index, *args):
|
||||
res: pd.Series = super(Corr, self)._load_internal(instrument, start_index, end_index, *args)
|
||||
|
||||
# NOTE: Load uses MemCache, so calling load again will not cause performance degradation
|
||||
series_left = self.feature_left.load(instrument, start_index, end_index, freq)
|
||||
series_right = self.feature_right.load(instrument, start_index, end_index, freq)
|
||||
series_left = self.feature_left.load(instrument, start_index, end_index, *args)
|
||||
series_right = self.feature_right.load(instrument, start_index, end_index, *args)
|
||||
res.loc[
|
||||
np.isclose(series_left.rolling(self.N, min_periods=1).std(), 0, atol=2e-05)
|
||||
| np.isclose(series_right.rolling(self.N, min_periods=1).std(), 0, atol=2e-05)
|
||||
@@ -1529,8 +1521,8 @@ class TResample(ElemOperator):
|
||||
def __str__(self):
|
||||
return "{}({},{})".format(type(self).__name__, self.feature, self.freq)
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
series = self.feature.load(instrument, start_index, end_index, freq)
|
||||
def _load_internal(self, instrument, start_index, end_index, *args):
|
||||
series = self.feature.load(instrument, start_index, end_index, *args)
|
||||
|
||||
if series.empty:
|
||||
return series
|
||||
@@ -1590,6 +1582,7 @@ OpsList = [
|
||||
IdxMin,
|
||||
If,
|
||||
Feature,
|
||||
PFeature,
|
||||
] + [TResample]
|
||||
|
||||
|
||||
@@ -1622,7 +1615,7 @@ class OpsWrapper:
|
||||
else:
|
||||
_ops_class = _operator
|
||||
|
||||
if not issubclass(_ops_class, Expression):
|
||||
if not issubclass(_ops_class, (Expression,)):
|
||||
raise TypeError("operator must be subclass of ExpressionOps, not {}".format(_ops_class))
|
||||
|
||||
if _ops_class.__name__ in self._ops:
|
||||
@@ -1644,8 +1637,10 @@ def register_all_ops(C):
|
||||
"""register all operator"""
|
||||
logger = get_module_logger("ops")
|
||||
|
||||
from qlib.data.pit import P, PRef # pylint: disable=C0415
|
||||
|
||||
Operators.reset()
|
||||
Operators.register(OpsList)
|
||||
Operators.register(OpsList + [P, PRef])
|
||||
|
||||
if getattr(C, "custom_ops", None) is not None:
|
||||
Operators.register(C.custom_ops)
|
||||
|
||||
72
qlib/data/pit.py
Normal file
72
qlib/data/pit.py
Normal file
@@ -0,0 +1,72 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
"""
|
||||
Qlib follow the logic below to supporting point-in-time database
|
||||
|
||||
For each stock, the format of its data is <observe_time, feature>. Expression Engine support calculation on such format of data
|
||||
|
||||
To calculate the feature value f_t at a specific observe time t, data with format <period_time, feature> will be used.
|
||||
For example, the average earning of last 4 quarters (period_time) on 20190719 (observe_time)
|
||||
|
||||
The calculation of both <period_time, feature> and <observe_time, feature> data rely on expression engine. It consists of 2 phases.
|
||||
1) calculation <period_time, feature> at each observation time t and it will collasped into a point (just like a normal feature)
|
||||
2) concatenate all th collasped data, we will get data with format <observe_time, feature>.
|
||||
Qlib will use the operator `P` to perform the collapse.
|
||||
"""
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from qlib.data.ops import ElemOperator
|
||||
from qlib.log import get_module_logger
|
||||
from .data import Cal
|
||||
|
||||
|
||||
class P(ElemOperator):
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
|
||||
_calendar = Cal.calendar(freq=freq)
|
||||
resample_data = np.empty(end_index - start_index + 1, dtype="float32")
|
||||
|
||||
for cur_index in range(start_index, end_index + 1):
|
||||
cur_time = _calendar[cur_index]
|
||||
# To load expression accurately, more historical data are required
|
||||
start_ws, end_ws = self.feature.get_extended_window_size()
|
||||
if end_ws > 0:
|
||||
raise ValueError(
|
||||
"PIT database does not support referring to future period (e.g. expressions like `Ref('$$roewa_q', -1)` are not supported"
|
||||
)
|
||||
|
||||
# The calculated value will always the last element, so the end_offset is zero.
|
||||
try:
|
||||
s = self._load_feature(instrument, -start_ws, 0, cur_time)
|
||||
resample_data[cur_index - start_index] = s.iloc[-1] if len(s) > 0 else np.nan
|
||||
except FileNotFoundError:
|
||||
get_module_logger("base").warning(f"WARN: period data not found for {str(self)}")
|
||||
return pd.Series(dtype="float32", name=str(self))
|
||||
|
||||
resample_series = pd.Series(
|
||||
resample_data, index=pd.RangeIndex(start_index, end_index + 1), dtype="float32", name=str(self)
|
||||
)
|
||||
return resample_series
|
||||
|
||||
def _load_feature(self, instrument, start_index, end_index, cur_time):
|
||||
return self.feature.load(instrument, start_index, end_index, cur_time)
|
||||
|
||||
def get_longest_back_rolling(self):
|
||||
# The period data will collapse as a normal feature. So no extending and looking back
|
||||
return 0
|
||||
|
||||
def get_extended_window_size(self):
|
||||
# The period data will collapse as a normal feature. So no extending and looking back
|
||||
return 0, 0
|
||||
|
||||
|
||||
class PRef(P):
|
||||
def __init__(self, feature, period):
|
||||
super().__init__(feature)
|
||||
self.period = period
|
||||
|
||||
def __str__(self):
|
||||
return f"{super().__str__()}[{self.period}]"
|
||||
|
||||
def _load_feature(self, instrument, start_index, end_index, cur_time):
|
||||
return self.feature.load(instrument, start_index, end_index, cur_time, self.period)
|
||||
@@ -2,3 +2,6 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from .storage import CalendarStorage, InstrumentStorage, FeatureStorage, CalVT, InstVT, InstKT
|
||||
|
||||
|
||||
__all__ = ["CalendarStorage", "InstrumentStorage", "FeatureStorage", "CalVT", "InstVT", "InstKT"]
|
||||
|
||||
@@ -79,6 +79,7 @@ class FileCalendarStorage(FileStorageMixin, CalendarStorage):
|
||||
self.future = future
|
||||
self._provider_uri = None if provider_uri is None else C.DataPathManager.format_provider_uri(provider_uri)
|
||||
self.enable_read_cache = True # TODO: make it configurable
|
||||
self.region = C["region"]
|
||||
|
||||
@property
|
||||
def file_name(self) -> str:
|
||||
@@ -130,7 +131,9 @@ class FileCalendarStorage(FileStorageMixin, CalendarStorage):
|
||||
else:
|
||||
_calendar = self._read_calendar()
|
||||
if Freq(self._freq_file) != Freq(self.freq):
|
||||
_calendar = resam_calendar(np.array(list(map(pd.Timestamp, _calendar))), self._freq_file, self.freq)
|
||||
_calendar = resam_calendar(
|
||||
np.array(list(map(pd.Timestamp, _calendar))), self._freq_file, self.freq, self.region
|
||||
)
|
||||
return _calendar
|
||||
|
||||
def _get_storage_freq(self) -> List[str]:
|
||||
|
||||
@@ -126,12 +126,10 @@ class CalendarStorage(BaseStorage):
|
||||
@overload
|
||||
def __setitem__(self, i: int, value: CalVT) -> None:
|
||||
"""x.__setitem__(i, o) <==> (x[i] = o)"""
|
||||
...
|
||||
|
||||
@overload
|
||||
def __setitem__(self, s: slice, value: Iterable[CalVT]) -> None:
|
||||
"""x.__setitem__(s, o) <==> (x[s] = o)"""
|
||||
...
|
||||
|
||||
def __setitem__(self, i, value) -> None:
|
||||
raise NotImplementedError(
|
||||
@@ -141,12 +139,10 @@ class CalendarStorage(BaseStorage):
|
||||
@overload
|
||||
def __delitem__(self, i: int) -> None:
|
||||
"""x.__delitem__(i) <==> del x[i]"""
|
||||
...
|
||||
|
||||
@overload
|
||||
def __delitem__(self, i: slice) -> None:
|
||||
"""x.__delitem__(slice(start: int, stop: int, step: int)) <==> del x[start:stop:step]"""
|
||||
...
|
||||
|
||||
def __delitem__(self, i) -> None:
|
||||
"""
|
||||
@@ -162,12 +158,10 @@ class CalendarStorage(BaseStorage):
|
||||
@overload
|
||||
def __getitem__(self, s: slice) -> Iterable[CalVT]:
|
||||
"""x.__getitem__(slice(start: int, stop: int, step: int)) <==> x[start:stop:step]"""
|
||||
...
|
||||
|
||||
@overload
|
||||
def __getitem__(self, i: int) -> CalVT:
|
||||
"""x.__getitem__(i) <==> x[i]"""
|
||||
...
|
||||
|
||||
def __getitem__(self, i) -> CalVT:
|
||||
"""
|
||||
@@ -467,12 +461,10 @@ class FeatureStorage(BaseStorage):
|
||||
-------
|
||||
pd.Series(values, index=pd.RangeIndex(start, len(values))
|
||||
"""
|
||||
...
|
||||
|
||||
@overload
|
||||
def __getitem__(self, i: int) -> Tuple[int, float]:
|
||||
"""x.__getitem__(y) <==> x[y]"""
|
||||
...
|
||||
|
||||
def __getitem__(self, i) -> Union[Tuple[int, float], pd.Series]:
|
||||
"""x.__getitem__(y) <==> x[y]
|
||||
|
||||
@@ -61,7 +61,11 @@ def get_module_logger(module_name, level: Optional[int] = None) -> QlibLogger:
|
||||
if level is None:
|
||||
level = C.logging_level
|
||||
|
||||
module_name = "qlib.{}".format(module_name)
|
||||
if not module_name.startswith("qlib."):
|
||||
# Add a prefix of qlib. when the requested ``module_name`` doesn't start with ``qlib.``.
|
||||
# If the module_name is already qlib.xxx, we do not format here. Otherwise, it will become qlib.qlib.xxx.
|
||||
module_name = "qlib.{}".format(module_name)
|
||||
|
||||
# Get logger.
|
||||
module_logger = QlibLogger(module_name)
|
||||
module_logger.setLevel(level)
|
||||
|
||||
@@ -4,3 +4,6 @@
|
||||
import warnings
|
||||
|
||||
from .base import Model
|
||||
|
||||
|
||||
__all__ = ["Model", "warnings"]
|
||||
|
||||
@@ -3,3 +3,6 @@
|
||||
|
||||
from .task import MetaTask
|
||||
from .dataset import MetaTaskDataset
|
||||
|
||||
|
||||
__all__ = ["MetaTask", "MetaTaskDataset"]
|
||||
|
||||
@@ -5,3 +5,11 @@ from .base import RiskModel
|
||||
from .poet import POETCovEstimator
|
||||
from .shrink import ShrinkCovEstimator
|
||||
from .structured import StructuredCovEstimator
|
||||
|
||||
|
||||
__all__ = [
|
||||
"RiskModel",
|
||||
"POETCovEstimator",
|
||||
"ShrinkCovEstimator",
|
||||
"StructuredCovEstimator",
|
||||
]
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Union
|
||||
from sklearn.decomposition import PCA, FactorAnalysis
|
||||
|
||||
|
||||
43
qlib/rl/aux_info.py
Normal file
43
qlib/rl/aux_info.py
Normal file
@@ -0,0 +1,43 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Generic, TYPE_CHECKING, TypeVar
|
||||
|
||||
from qlib.typehint import final
|
||||
|
||||
from .simulator import StateType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .utils.env_wrapper import EnvWrapper
|
||||
|
||||
|
||||
__all__ = ["AuxiliaryInfoCollector"]
|
||||
|
||||
AuxInfoType = TypeVar("AuxInfoType")
|
||||
|
||||
|
||||
class AuxiliaryInfoCollector(Generic[StateType, AuxInfoType]):
|
||||
"""Override this class to collect customized auxiliary information from environment."""
|
||||
|
||||
env: EnvWrapper | None = None
|
||||
|
||||
@final
|
||||
def __call__(self, simulator_state: StateType) -> AuxInfoType:
|
||||
return self.collect(simulator_state)
|
||||
|
||||
def collect(self, simulator_state: StateType) -> AuxInfoType:
|
||||
"""Override this for customized auxiliary info.
|
||||
Usually useful in Multi-agent RL.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
simulator_state
|
||||
Retrieved with ``simulator.get_state()``.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Auxiliary information.
|
||||
"""
|
||||
raise NotImplementedError("collect is not implemented!")
|
||||
8
qlib/rl/data/__init__.py
Normal file
8
qlib/rl/data/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""Common utilities to handle ad-hoc-styled data.
|
||||
|
||||
Most of these snippets comes from research project (paper code).
|
||||
Please take caution when using them in production.
|
||||
"""
|
||||
257
qlib/rl/data/pickle_styled.py
Normal file
257
qlib/rl/data/pickle_styled.py
Normal file
@@ -0,0 +1,257 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""This module contains utilities to read financial data from pickle-styled files.
|
||||
|
||||
This is the format used in `OPD paper <https://seqml.github.io/opd/>`__. NOT the standard data format in qlib.
|
||||
|
||||
The data here are all wrapped with ``@lru_cache``, which saves the expensive IO cost to repetitively read the data.
|
||||
We also encourage users to use ``get_xxx_yyy`` rather than ``XxxYyy`` (although they are the same thing),
|
||||
because ``get_xxx_yyy`` is cache-optimized.
|
||||
|
||||
Note that these pickle files are dumped with Python 3.8. Python lower than 3.7 might not be able to load them.
|
||||
See `PEP 574 <https://peps.python.org/pep-0574/>`__ for details.
|
||||
|
||||
This file shows resemblence to qlib.backtest.high_performance_ds. We might merge those two in future.
|
||||
"""
|
||||
|
||||
# TODO: merge with qlib/backtest/high_performance_ds.py
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import lru_cache
|
||||
from typing import List, Sequence, cast
|
||||
from pathlib import Path
|
||||
|
||||
import cachetools
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from cachetools.keys import hashkey
|
||||
|
||||
from qlib.backtest.decision import OrderDir, Order
|
||||
from qlib.typehint import Literal
|
||||
|
||||
|
||||
DealPriceType = Literal["bid_or_ask", "bid_or_ask_fill", "close"]
|
||||
"""Several ad-hoc deal price.
|
||||
``bid_or_ask``: If sell, use column ``$bid0``; if buy, use column ``$ask0``.
|
||||
``bid_or_ask_fill``: Based on ``bid_or_ask``. If price is 0, use another price (``$ask0`` / ``$bid0``) instead.
|
||||
``close``: Use close price (``$close0``) as deal price.
|
||||
"""
|
||||
|
||||
|
||||
def _infer_processed_data_column_names(shape: int) -> list[str]:
|
||||
if shape == 16:
|
||||
return [
|
||||
"$open",
|
||||
"$high",
|
||||
"$low",
|
||||
"$close",
|
||||
"$vwap",
|
||||
"$bid",
|
||||
"$ask",
|
||||
"$volume",
|
||||
"$bidV",
|
||||
"$bidV1",
|
||||
"$bidV3",
|
||||
"$bidV5",
|
||||
"$askV",
|
||||
"$askV1",
|
||||
"$askV3",
|
||||
"$askV5",
|
||||
]
|
||||
if shape == 6:
|
||||
return ["$high", "$low", "$open", "$close", "$vwap", "$volume"]
|
||||
elif shape == 5:
|
||||
return ["$high", "$low", "$open", "$close", "$volume"]
|
||||
raise ValueError(f"Unrecognized data shape: {shape}")
|
||||
|
||||
|
||||
def _find_pickle(filename_without_suffix: Path) -> Path:
|
||||
suffix_list = [".pkl", ".pkl.backtest"]
|
||||
paths: List[Path] = []
|
||||
for suffix in suffix_list:
|
||||
path = filename_without_suffix.parent / (filename_without_suffix.name + suffix)
|
||||
if path.exists():
|
||||
paths.append(path)
|
||||
if not paths:
|
||||
raise FileNotFoundError(f"No file starting with '{filename_without_suffix}' found")
|
||||
if len(paths) > 1:
|
||||
raise ValueError(f"Multiple paths are found with prefix '{filename_without_suffix}': {paths}")
|
||||
return paths[0]
|
||||
|
||||
|
||||
@lru_cache(maxsize=10) # 10 * 40M = 400MB
|
||||
def _read_pickle(filename_without_suffix: Path) -> pd.DataFrame:
|
||||
return pd.read_pickle(_find_pickle(filename_without_suffix))
|
||||
|
||||
|
||||
class IntradayBacktestData:
|
||||
"""Raw market data that is often used in backtesting (thus called BacktestData)."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data_dir: Path,
|
||||
stock_id: str,
|
||||
date: pd.Timestamp,
|
||||
deal_price: DealPriceType = "close",
|
||||
order_dir: int | None = None,
|
||||
):
|
||||
backtest = _read_pickle(data_dir / stock_id)
|
||||
backtest = backtest.loc[pd.IndexSlice[stock_id, :, date]]
|
||||
|
||||
# No longer need for pandas >= 1.4
|
||||
# backtest = backtest.droplevel([0, 2])
|
||||
|
||||
self.data: pd.DataFrame = backtest
|
||||
self.deal_price_type: DealPriceType = deal_price
|
||||
self.order_dir: int | None = order_dir
|
||||
|
||||
def __repr__(self):
|
||||
with pd.option_context("memory_usage", False, "display.max_info_columns", 1, "display.large_repr", "info"):
|
||||
return f"{self.__class__.__name__}({self.data})"
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
def get_deal_price(self) -> pd.Series:
|
||||
"""Return a pandas series that can be indexed with time.
|
||||
See :attribute:`DealPriceType` for details."""
|
||||
if self.deal_price_type in ("bid_or_ask", "bid_or_ask_fill"):
|
||||
if self.order_dir is None:
|
||||
raise ValueError("Order direction cannot be none when deal_price_type is not close.")
|
||||
if self.order_dir == OrderDir.SELL:
|
||||
col = "$bid0"
|
||||
else: # BUY
|
||||
col = "$ask0"
|
||||
elif self.deal_price_type == "close":
|
||||
col = "$close0"
|
||||
else:
|
||||
raise ValueError(f"Unsupported deal_price_type: {self.deal_price_type}")
|
||||
price = self.data[col]
|
||||
|
||||
if self.deal_price_type == "bid_or_ask_fill":
|
||||
if self.order_dir == OrderDir.SELL:
|
||||
fill_col = "$ask0"
|
||||
else:
|
||||
fill_col = "$bid0"
|
||||
price = price.replace(0, np.nan).fillna(self.data[fill_col])
|
||||
|
||||
return price
|
||||
|
||||
def get_volume(self) -> pd.Series:
|
||||
"""Return a volume series that can be indexed with time."""
|
||||
return self.data["$volume0"]
|
||||
|
||||
def get_time_index(self) -> pd.DatetimeIndex:
|
||||
return cast(pd.DatetimeIndex, self.data.index)
|
||||
|
||||
|
||||
class IntradayProcessedData:
|
||||
"""Processed market data after data cleanup and feature engineering.
|
||||
|
||||
It contains both processed data for "today" and "yesterday", as some algorithms
|
||||
might use the market information of the previous day to assist decision making.
|
||||
"""
|
||||
|
||||
today: pd.DataFrame
|
||||
"""Processed data for "today".
|
||||
Number of records must be ``time_length``, and columns must be ``feature_dim``."""
|
||||
|
||||
yesterday: pd.DataFrame
|
||||
"""Processed data for "yesterday".
|
||||
Number of records must be ``time_length``, and columns must be ``feature_dim``."""
|
||||
|
||||
def __init__(self, data_dir: Path, stock_id: str, date: pd.Timestamp, feature_dim: int, time_index: pd.Index):
|
||||
proc = _read_pickle(data_dir / stock_id)
|
||||
# We have to infer the names here because,
|
||||
# unfortunately they are not included in the original data.
|
||||
cnames = _infer_processed_data_column_names(feature_dim)
|
||||
|
||||
time_length: int = len(time_index)
|
||||
|
||||
try:
|
||||
# new data format
|
||||
proc = proc.loc[pd.IndexSlice[stock_id, :, date]]
|
||||
assert len(proc) == time_length and len(proc.columns) == feature_dim * 2
|
||||
proc_today = proc[cnames]
|
||||
proc_yesterday = proc[[f"{c}_1" for c in cnames]].rename(columns=lambda c: c[:-2])
|
||||
except (IndexError, KeyError):
|
||||
# legacy data
|
||||
proc = proc.loc[pd.IndexSlice[stock_id, date]]
|
||||
assert time_length * feature_dim * 2 == len(proc)
|
||||
proc_today = proc.to_numpy()[: time_length * feature_dim].reshape((time_length, feature_dim))
|
||||
proc_yesterday = proc.to_numpy()[time_length * feature_dim :].reshape((time_length, feature_dim))
|
||||
proc_today = pd.DataFrame(proc_today, index=time_index, columns=cnames)
|
||||
proc_yesterday = pd.DataFrame(proc_yesterday, index=time_index, columns=cnames)
|
||||
|
||||
self.today: pd.DataFrame = proc_today
|
||||
self.yesterday: pd.DataFrame = proc_yesterday
|
||||
assert len(self.today.columns) == len(self.yesterday.columns) == feature_dim
|
||||
assert len(self.today) == len(self.yesterday) == time_length
|
||||
|
||||
def __repr__(self):
|
||||
with pd.option_context("memory_usage", False, "display.max_info_columns", 1, "display.large_repr", "info"):
|
||||
return f"{self.__class__.__name__}({self.today}, {self.yesterday})"
|
||||
|
||||
|
||||
@lru_cache(maxsize=100) # 100 * 50K = 5MB
|
||||
def load_intraday_backtest_data(
|
||||
data_dir: Path, stock_id: str, date: pd.Timestamp, deal_price: DealPriceType = "close", order_dir: int | None = None
|
||||
) -> IntradayBacktestData:
|
||||
return IntradayBacktestData(data_dir, stock_id, date, deal_price, order_dir)
|
||||
|
||||
|
||||
@cachetools.cached( # type: ignore
|
||||
cache=cachetools.LRUCache(100), # 100 * 50K = 5MB
|
||||
key=lambda data_dir, stock_id, date, _, __: hashkey(data_dir, stock_id, date),
|
||||
)
|
||||
def load_intraday_processed_data(
|
||||
data_dir: Path, stock_id: str, date: pd.Timestamp, feature_dim: int, time_index: pd.Index
|
||||
) -> IntradayProcessedData:
|
||||
return IntradayProcessedData(data_dir, stock_id, date, feature_dim, time_index)
|
||||
|
||||
|
||||
def load_orders(
|
||||
order_path: Path, start_time: pd.Timestamp | None = None, end_time: pd.Timestamp | None = None
|
||||
) -> Sequence[Order]:
|
||||
"""Load orders, and set start time and end time for the orders."""
|
||||
|
||||
start_time = start_time or pd.Timestamp("0:00:00")
|
||||
end_time = end_time or pd.Timestamp("23:59:59")
|
||||
|
||||
if order_path.is_file():
|
||||
order_df = pd.read_pickle(order_path)
|
||||
else:
|
||||
order_df = []
|
||||
for file in order_path.iterdir():
|
||||
order_data = pd.read_pickle(file)
|
||||
order_df.append(order_data)
|
||||
order_df = pd.concat(order_df)
|
||||
|
||||
order_df = order_df.reset_index()
|
||||
|
||||
# Legacy-style orders have "date" instead of "datetime"
|
||||
if "date" in order_df.columns:
|
||||
order_df = order_df.rename(columns={"date": "datetime"})
|
||||
|
||||
# Sometimes "date" are str rather than Timestamp
|
||||
order_df["datetime"] = pd.to_datetime(order_df["datetime"])
|
||||
|
||||
orders: List[Order] = []
|
||||
|
||||
for _, row in order_df.iterrows():
|
||||
# filter out orders with amount == 0
|
||||
if row["amount"] <= 0:
|
||||
continue
|
||||
orders.append(
|
||||
Order(
|
||||
row["instrument"],
|
||||
row["amount"],
|
||||
int(row["order_type"]),
|
||||
row["datetime"].replace(hour=start_time.hour, minute=start_time.minute, second=start_time.second),
|
||||
row["datetime"].replace(hour=end_time.hour, minute=end_time.minute, second=end_time.second),
|
||||
)
|
||||
)
|
||||
|
||||
return orders
|
||||
7
qlib/rl/entries/__init__.py
Normal file
7
qlib/rl/entries/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""Train, test, inference utilities.
|
||||
|
||||
The APIs in this directory are NOT considered final and are subject to change!
|
||||
"""
|
||||
99
qlib/rl/entries/test.py
Normal file
99
qlib/rl/entries/test.py
Normal file
@@ -0,0 +1,99 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from typing import Callable, Sequence
|
||||
|
||||
from tianshou.data import Collector
|
||||
from tianshou.policy import BasePolicy
|
||||
|
||||
from qlib.constant import INF
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.rl.simulator import InitialStateType, Simulator
|
||||
from qlib.rl.interpreter import StateInterpreter, ActionInterpreter
|
||||
from qlib.rl.reward import Reward
|
||||
from qlib.rl.utils import DataQueue, EnvWrapper, FiniteEnvType, LogCollector, LogWriter, vectorize_env
|
||||
|
||||
|
||||
_logger = get_module_logger(__name__)
|
||||
|
||||
|
||||
def backtest(
|
||||
simulator_fn: Callable[[InitialStateType], Simulator],
|
||||
state_interpreter: StateInterpreter,
|
||||
action_interpreter: ActionInterpreter,
|
||||
initial_states: Sequence[InitialStateType],
|
||||
policy: BasePolicy,
|
||||
logger: LogWriter | list[LogWriter],
|
||||
reward: Reward | None = None,
|
||||
finite_env_type: FiniteEnvType = "subproc",
|
||||
concurrency: int = 2,
|
||||
) -> None:
|
||||
"""Backtest with the parallelism provided by RL framework.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
simulator_fn
|
||||
Callable receiving initial seed, returning a simulator.
|
||||
state_interpreter
|
||||
Interprets the state of simulators.
|
||||
action_interpreter
|
||||
Interprets the policy actions.
|
||||
initial_states
|
||||
Initial states to iterate over. Every state will be run exactly once.
|
||||
policy
|
||||
Policy to test against.
|
||||
logger
|
||||
Logger to record the backtest results. Logger must be present because
|
||||
without logger, all information will be lost.
|
||||
reward
|
||||
Optional reward function. For backtest, this is for testing the rewards
|
||||
and logging them only.
|
||||
finite_env_type
|
||||
Type of finite env implementation.
|
||||
concurrency
|
||||
Parallel workers.
|
||||
"""
|
||||
|
||||
# To save bandwidth
|
||||
min_loglevel = min(lg.loglevel for lg in logger) if isinstance(logger, list) else logger.loglevel
|
||||
|
||||
def env_factory():
|
||||
# FIXME: state_interpreter and action_interpreter are stateful (having a weakref of env),
|
||||
# and could be thread unsafe.
|
||||
# I'm not sure whether it's a design flaw.
|
||||
# I'll rethink about this when designing the trainer.
|
||||
|
||||
if finite_env_type == "dummy":
|
||||
# We could only experience the "threading-unsafe" problem in dummy.
|
||||
state = copy.deepcopy(state_interpreter)
|
||||
action = copy.deepcopy(action_interpreter)
|
||||
rew = copy.deepcopy(reward)
|
||||
else:
|
||||
state, action, rew = state_interpreter, action_interpreter, reward
|
||||
|
||||
return EnvWrapper(
|
||||
simulator_fn,
|
||||
state,
|
||||
action,
|
||||
seed_iterator,
|
||||
rew,
|
||||
logger=LogCollector(min_loglevel=min_loglevel),
|
||||
)
|
||||
|
||||
with DataQueue(initial_states) as seed_iterator:
|
||||
vector_env = vectorize_env(
|
||||
env_factory,
|
||||
finite_env_type,
|
||||
concurrency,
|
||||
logger,
|
||||
)
|
||||
|
||||
policy.eval()
|
||||
|
||||
with vector_env.collector_guard():
|
||||
test_collector = Collector(policy, vector_env)
|
||||
_logger.info("All ready. Start backtest.")
|
||||
test_collector.collect(n_step=INF * len(vector_env))
|
||||
4
qlib/rl/entries/train.py
Normal file
4
qlib/rl/entries/train.py
Normal file
@@ -0,0 +1,4 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
# TBD
|
||||
@@ -1,94 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from typing import Union
|
||||
|
||||
|
||||
from ..backtest.executor import BaseExecutor
|
||||
from .interpreter import StateInterpreter, ActionInterpreter
|
||||
from ..utils import init_instance_by_config
|
||||
|
||||
|
||||
class BaseRLEnv:
|
||||
"""Base environment for reinforcement learning"""
|
||||
|
||||
def reset(self, **kwargs):
|
||||
raise NotImplementedError("reset is not implemented!")
|
||||
|
||||
def step(self, action):
|
||||
"""
|
||||
step method of rl env
|
||||
Parameters
|
||||
----------
|
||||
action :
|
||||
action from rl policy
|
||||
|
||||
Returns
|
||||
-------
|
||||
env state to rl policy
|
||||
"""
|
||||
raise NotImplementedError("step is not implemented!")
|
||||
|
||||
|
||||
class QlibRLEnv:
|
||||
"""qlib-based RL env"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
executor: BaseExecutor,
|
||||
):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
executor : BaseExecutor
|
||||
qlib multi-level/single-level executor, which can be regarded as gamecore in RL
|
||||
"""
|
||||
self.executor = executor
|
||||
|
||||
def reset(self, **kwargs):
|
||||
self.executor.reset(**kwargs)
|
||||
|
||||
|
||||
class QlibIntRLEnv(QlibRLEnv):
|
||||
"""(Qlib)-based RL (Env) with (Interpreter)"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
executor: BaseExecutor,
|
||||
state_interpreter: Union[dict, StateInterpreter],
|
||||
action_interpreter: Union[dict, ActionInterpreter],
|
||||
):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
state_interpreter : Union[dict, StateInterpreter]
|
||||
interpreter that interprets the qlib execute result into rl env state.
|
||||
|
||||
action_interpreter : Union[dict, ActionInterpreter]
|
||||
interpreter that interprets the rl agent action into qlib order list
|
||||
"""
|
||||
super(QlibIntRLEnv, self).__init__(executor=executor)
|
||||
self.state_interpreter = init_instance_by_config(state_interpreter, accept_types=StateInterpreter)
|
||||
self.action_interpreter = init_instance_by_config(action_interpreter, accept_types=ActionInterpreter)
|
||||
|
||||
def step(self, action):
|
||||
"""
|
||||
step method of rl env, it run as following step:
|
||||
- Use `action_interpreter.interpret` method to interpret the agent action into order list
|
||||
- Execute the order list with qlib executor, and get the executed result
|
||||
- Use `state_interpreter.interpret` method to interpret the executed result into env state
|
||||
|
||||
Parameters
|
||||
----------
|
||||
action :
|
||||
action from rl policy
|
||||
|
||||
Returns
|
||||
-------
|
||||
env state to rl policy
|
||||
"""
|
||||
_interpret_decision = self.action_interpreter.interpret(action=action)
|
||||
_execute_result = self.executor.execute(trade_decision=_interpret_decision)
|
||||
_interpret_state = self.state_interpreter.interpret(execute_result=_execute_result)
|
||||
return _interpret_state
|
||||
@@ -1,47 +1,150 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
class BaseInterpreter:
|
||||
"""Base Interpreter"""
|
||||
from typing import TYPE_CHECKING, TypeVar, Generic, Any
|
||||
|
||||
def interpret(self, **kwargs):
|
||||
raise NotImplementedError("interpret is not implemented!")
|
||||
import numpy as np
|
||||
|
||||
from qlib.typehint import final
|
||||
|
||||
from .simulator import StateType, ActType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .utils.env_wrapper import EnvWrapper
|
||||
|
||||
import gym
|
||||
from gym import spaces
|
||||
|
||||
ObsType = TypeVar("ObsType")
|
||||
PolicyActType = TypeVar("PolicyActType")
|
||||
|
||||
|
||||
class ActionInterpreter(BaseInterpreter):
|
||||
"""Action Interpreter that interpret rl agent action into qlib orders"""
|
||||
class Interpreter:
|
||||
"""Interpreter is a media between states produced by simulators and states needed by RL policies.
|
||||
Interpreters are two-way:
|
||||
|
||||
def interpret(self, action, **kwargs):
|
||||
"""interpret method
|
||||
1. From simulator state to policy state (aka observation), see :class:`StateInterpreter`.
|
||||
2. From policy action to action accepted by simulator, see :class:`ActionInterpreter`.
|
||||
|
||||
Inherit one of the two sub-classes to define your own interpreter.
|
||||
This super-class is only used for isinstance check.
|
||||
|
||||
Interpreters are recommended to be stateless, meaning that storing temporary information with ``self.xxx``
|
||||
in interpreter is anti-pattern. In future, we might support register some interpreter-related
|
||||
states by calling ``self.env.register_state()``, but it's not planned for first iteration.
|
||||
"""
|
||||
|
||||
|
||||
class StateInterpreter(Generic[StateType, ObsType], Interpreter):
|
||||
"""State Interpreter that interpret execution result of qlib executor into rl env state"""
|
||||
|
||||
env: EnvWrapper | None = None
|
||||
|
||||
@property
|
||||
def observation_space(self) -> gym.Space:
|
||||
raise NotImplementedError()
|
||||
|
||||
@final # no overridden
|
||||
def __call__(self, simulator_state: StateType) -> ObsType:
|
||||
obs = self.interpret(simulator_state)
|
||||
self.validate(obs)
|
||||
return obs
|
||||
|
||||
def validate(self, obs: ObsType) -> None:
|
||||
"""Validate whether an observation belongs to the pre-defined observation space."""
|
||||
_gym_space_contains(self.observation_space, obs)
|
||||
|
||||
def interpret(self, simulator_state: StateType) -> ObsType:
|
||||
"""Interpret the state of simulator.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
action :
|
||||
rl agent action
|
||||
simulator_state
|
||||
Retrieved with ``simulator.get_state()``.
|
||||
|
||||
Returns
|
||||
-------
|
||||
qlib orders
|
||||
|
||||
State needed by policy. Should conform with the state space defined in ``observation_space``.
|
||||
"""
|
||||
|
||||
raise NotImplementedError("interpret is not implemented!")
|
||||
|
||||
|
||||
class StateInterpreter(BaseInterpreter):
|
||||
"""State Interpreter that interpret execution result of qlib executor into rl env state"""
|
||||
class ActionInterpreter(Generic[StateType, PolicyActType, ActType], Interpreter):
|
||||
"""Action Interpreter that interpret rl agent action into qlib orders"""
|
||||
|
||||
def interpret(self, execute_result, **kwargs):
|
||||
"""interpret method
|
||||
env: "EnvWrapper" | None = None
|
||||
|
||||
@property
|
||||
def action_space(self) -> gym.Space:
|
||||
raise NotImplementedError()
|
||||
|
||||
@final # no overridden
|
||||
def __call__(self, simulator_state: StateType, action: PolicyActType) -> ActType:
|
||||
self.validate(action)
|
||||
obs = self.interpret(simulator_state, action)
|
||||
return obs
|
||||
|
||||
def validate(self, action: PolicyActType) -> None:
|
||||
"""Validate whether an action belongs to the pre-defined action space."""
|
||||
_gym_space_contains(self.action_space, action)
|
||||
|
||||
def interpret(self, simulator_state: StateType, action: PolicyActType) -> ActType:
|
||||
"""Convert the policy action to simulator action.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
execute_result :
|
||||
qlib execution result
|
||||
simulator_state
|
||||
Retrieved with ``simulator.get_state()``.
|
||||
action
|
||||
Raw action given by policy.
|
||||
|
||||
Returns
|
||||
----------
|
||||
rl env state
|
||||
-------
|
||||
The action needed by simulator,
|
||||
"""
|
||||
raise NotImplementedError("interpret is not implemented!")
|
||||
|
||||
|
||||
def _gym_space_contains(space: gym.Space, x: Any) -> None:
|
||||
"""Strengthened version of gym.Space.contains.
|
||||
Giving more diagnostic information on why validation fails.
|
||||
|
||||
Throw exception rather than returning true or false.
|
||||
"""
|
||||
if isinstance(space, spaces.Dict):
|
||||
if not isinstance(x, dict) or len(x) != len(space):
|
||||
raise GymSpaceValidationError("Sample must be a dict with same length as space.", space, x)
|
||||
for k, subspace in space.spaces.items():
|
||||
if k not in x:
|
||||
raise GymSpaceValidationError(f"Key {k} not found in sample.", space, x)
|
||||
try:
|
||||
_gym_space_contains(subspace, x[k])
|
||||
except GymSpaceValidationError as e:
|
||||
raise GymSpaceValidationError(f"Subspace of key {k} validation error.", space, x) from e
|
||||
|
||||
elif isinstance(space, spaces.Tuple):
|
||||
if isinstance(x, (list, np.ndarray)):
|
||||
x = tuple(x) # Promote list and ndarray to tuple for contains check
|
||||
if not isinstance(x, tuple) or len(x) != len(space):
|
||||
raise GymSpaceValidationError("Sample must be a tuple with same length as space.", space, x)
|
||||
for i, (subspace, part) in enumerate(zip(space, x)):
|
||||
try:
|
||||
_gym_space_contains(subspace, part)
|
||||
except GymSpaceValidationError as e:
|
||||
raise GymSpaceValidationError(f"Subspace of index {i} validation error.", space, x) from e
|
||||
|
||||
else:
|
||||
if not space.contains(x):
|
||||
raise GymSpaceValidationError("Validation error reported by gym.", space, x)
|
||||
|
||||
|
||||
class GymSpaceValidationError(Exception):
|
||||
def __init__(self, message: str, space: gym.Space, x: Any):
|
||||
self.message = message
|
||||
self.space = space
|
||||
self.x = x
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.message}\n Space: {self.space}\n Sample: {self.x}"
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user