mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-22 13:41:43 +08:00
Compare commits
147 Commits
v0.8.1
...
mini_proje
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a02ac95538 | ||
|
|
cc94c32db6 | ||
|
|
9a40fd3cdc | ||
|
|
c4281121e3 | ||
|
|
2de9903200 | ||
|
|
2cf842bcfe | ||
|
|
9e381493c2 | ||
|
|
a73b60d05a | ||
|
|
64979ad769 | ||
|
|
c5cf8fb9cc | ||
|
|
5d579d1a20 | ||
|
|
3c9c76b384 | ||
|
|
9d0a8f61d1 | ||
|
|
701b18af1b | ||
|
|
84ff662a26 | ||
|
|
00e40e775b | ||
|
|
45fe5e6974 | ||
|
|
366a9c33f3 | ||
|
|
982e0da715 | ||
|
|
cd5e5d5235 | ||
|
|
caea495f40 | ||
|
|
d934c8caba | ||
|
|
a139986f4e | ||
|
|
12c3de42d0 | ||
|
|
fe0f9427f2 | ||
|
|
a973e4fb66 | ||
|
|
c60366addd | ||
|
|
41447f320b | ||
|
|
e1271a83f7 | ||
|
|
30b531086c | ||
|
|
87926513cb | ||
|
|
7bfc7e1797 | ||
|
|
85e7cdcac3 | ||
|
|
08fd1d3f42 | ||
|
|
defd6758f6 | ||
|
|
61cc1a3867 | ||
|
|
655ed982cf | ||
|
|
2952c443ca | ||
|
|
7f1293ec34 | ||
|
|
73438807f9 | ||
|
|
962751c72d | ||
|
|
56cfa480dc | ||
|
|
6edd0bf298 | ||
|
|
fe155703b0 | ||
|
|
3c4f4bfd44 | ||
|
|
5200ff520a | ||
|
|
30e457119c | ||
|
|
243e516cf1 | ||
|
|
e229b567ad | ||
|
|
f129bfef5d | ||
|
|
9dd5e07819 | ||
|
|
00ed35fc1b | ||
|
|
3f53a097b0 | ||
|
|
fb230a8097 | ||
|
|
ff4724e248 | ||
|
|
73d90f7f44 | ||
|
|
b7988e6428 | ||
|
|
8efc8b92ef | ||
|
|
f2a5ecd98a | ||
|
|
705354cc28 | ||
|
|
1b5d0d4d6d | ||
|
|
f4a481945b | ||
|
|
5f18ba7970 | ||
|
|
2681c61c60 | ||
|
|
776b0c5bb4 | ||
|
|
829ad9f5e9 | ||
|
|
921c13cc90 | ||
|
|
0f519f6053 | ||
|
|
2ed806c846 | ||
|
|
d2f0bebf60 | ||
|
|
615a381038 | ||
|
|
568a88fddb | ||
|
|
058f976727 | ||
|
|
faa99f30fa | ||
|
|
837067b9e1 | ||
|
|
3a911bc09b | ||
|
|
90be21bb40 | ||
|
|
7540b1257b | ||
|
|
57f7ed9914 | ||
|
|
9e3d0249f7 | ||
|
|
2ac964c470 | ||
|
|
07f0d4f599 | ||
|
|
ea4fb33ff2 | ||
|
|
ed0c238787 | ||
|
|
80af395b3c | ||
|
|
4dc66932d5 | ||
|
|
40dd84857c | ||
|
|
74cc21fc2c | ||
|
|
ec8969a3ae | ||
|
|
528f74af09 | ||
|
|
d482726f28 | ||
|
|
cfc3e886ed | ||
|
|
60d45ad770 | ||
|
|
0e8b94a552 | ||
|
|
4bf127eba5 | ||
|
|
c149c8616c | ||
|
|
3274e16c95 | ||
|
|
d496cf7476 | ||
|
|
357ee74b6f | ||
|
|
5da5cf5175 | ||
|
|
6a946761cf | ||
|
|
76b7b5f24b | ||
|
|
d7d19feb4e | ||
|
|
bba6972a55 | ||
|
|
18af288692 | ||
|
|
ba056850cb | ||
|
|
aed5b8ebc0 | ||
|
|
79355666a9 | ||
|
|
144e1e2459 | ||
|
|
635632e4ed | ||
|
|
c5834476e2 | ||
|
|
01afd06e18 | ||
|
|
d533219738 | ||
|
|
5b5c99fe75 | ||
|
|
da48f42f3f | ||
|
|
f979dcf5e8 | ||
|
|
97aa16a078 | ||
|
|
094be9be86 | ||
|
|
d9b9386032 | ||
|
|
b86a30aae7 | ||
|
|
2c5a4691f3 | ||
|
|
54344c4426 | ||
|
|
303cdb8ce3 | ||
|
|
1a0ac1ab6d | ||
|
|
a79e446724 | ||
|
|
bdf1fb29a6 | ||
|
|
86e1265f69 | ||
|
|
628eb7fa73 | ||
|
|
2a1b512cd2 | ||
|
|
50e7901e87 | ||
|
|
3ba54cd1ab | ||
|
|
483d01f0c1 | ||
|
|
61836cba3d | ||
|
|
aeb5e40c77 | ||
|
|
116f0fa7a7 | ||
|
|
5296cce725 | ||
|
|
292fcc9e98 | ||
|
|
d3fbf066cf | ||
|
|
52ecb79e0b | ||
|
|
59c52eac0a | ||
|
|
f455305a2a | ||
|
|
a67f67db6e | ||
|
|
5c2e99aee3 | ||
|
|
2bb8a4ce0e | ||
|
|
7f274b1e4e | ||
|
|
2aee9e0145 | ||
|
|
a62e2ec4de |
2
.github/PULL_REQUEST_TEMPLATE.md
vendored
2
.github/PULL_REQUEST_TEMPLATE.md
vendored
@@ -8,7 +8,7 @@
|
||||
<!--- Why is this change required? What problem does it solve? -->
|
||||
|
||||
## How Has This Been Tested?
|
||||
<! --- Put an `x` in all the boxes that apply: --->
|
||||
<!--- Put an `x` in all the boxes that apply: --->
|
||||
- [ ] Pass the test by running: `pytest qlib/tests/test_all_pipeline.py` under upper directory of `qlib`.
|
||||
- [ ] If you are adding a new feature, test on your own test scripts.
|
||||
|
||||
|
||||
3
.github/workflows/python-publish.yml
vendored
3
.github/workflows/python-publish.yml
vendored
@@ -12,7 +12,8 @@ jobs:
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [windows-latest, macos-latest, macos-11]
|
||||
os: [windows-latest, macos-11]
|
||||
# FIXME: macos-latest will raise error now.
|
||||
# not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129
|
||||
python-version: [3.7, 3.8]
|
||||
|
||||
|
||||
81
.github/workflows/test.yml
vendored
81
.github/workflows/test.yml
vendored
@@ -33,11 +33,85 @@ jobs:
|
||||
- name: Install Qlib with pip
|
||||
run: |
|
||||
pip install numpy==1.19.5 ruamel.yaml
|
||||
pip install pyqlib --ignore-installed
|
||||
pip install pyqlib --ignore-installed
|
||||
|
||||
- name: Make html with sphinx
|
||||
run: |
|
||||
pip install -U sphinx
|
||||
pip install sphinx_rtd_theme readthedocs_sphinx_ext
|
||||
pip install --exists-action=w --no-cache-dir -r docs/requirements.txt
|
||||
cd docs
|
||||
sphinx-build -b html . build
|
||||
cd ..
|
||||
|
||||
# Check Qlib with pylint
|
||||
# TODO: These problems we will solve in the future. Important among them are: W0221, W0223, W0237, E1102
|
||||
# C0103: invalid-name
|
||||
# C0209: consider-using-f-string
|
||||
# R0402: consider-using-from-import
|
||||
# R1705: no-else-return
|
||||
# R1710: inconsistent-return-statements
|
||||
# R1725: super-with-arguments
|
||||
# R1735: use-dict-literal
|
||||
# W0102: dangerous-default-value
|
||||
# W0212: protected-access
|
||||
# W0221: arguments-differ
|
||||
# W0223: abstract-method
|
||||
# W0231: super-init-not-called
|
||||
# W0237: arguments-renamed
|
||||
# W0612: unused-variable
|
||||
# W0621: redefined-outer-name
|
||||
# W0622: redefined-builtin
|
||||
# FIXME: specify exception type
|
||||
# W0703: broad-except
|
||||
# W1309: f-string-without-interpolation
|
||||
# E1102: not-callable
|
||||
# E1136: unsubscriptable-object
|
||||
# References for parameters: https://github.com/PyCQA/pylint/issues/4577#issuecomment-1000245962
|
||||
- name: Check Qlib with pylint
|
||||
run: |
|
||||
pip install --upgrade pip
|
||||
pip install pylint
|
||||
pylint --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0201,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0612,W0621,W0622,W0703,W1309,E1102,E1136 --const-rgx='[a-z_][a-z0-9_]{2,30}$' qlib --init-hook "import astroid; astroid.context.InferenceContext.max_inferred = 500"
|
||||
|
||||
# The following flake8 error codes were ignored:
|
||||
# E501 line too long
|
||||
# Description: We have used black to limit the length of each line to 120.
|
||||
# F541 f-string is missing placeholders
|
||||
# Description: The same thing is done when using pylint for detection.
|
||||
# E266 too many leading '#' for block comment
|
||||
# Description: To make the code more readable, a lot of "#" is used.
|
||||
# This error code appears centrally in:
|
||||
# qlib/backtest/executor.py
|
||||
# qlib/data/ops.py
|
||||
# qlib/utils/__init__.py
|
||||
# E402 module level import not at top of file
|
||||
# Description: There are times when module level import is not available at the top of the file.
|
||||
# W503 line break before binary operator
|
||||
# Description: Since black formats the length of each line of code, it has to perform a line break when a line of arithmetic is too long.
|
||||
# E731 do not assign a lambda expression, use a def
|
||||
# Description: Restricts the use of lambda expressions, but at some point lambda expressions are required.
|
||||
# E203 whitespace before ':'
|
||||
# Description: If there is whitespace before ":", it cannot pass the black check.
|
||||
- name: Check Qlib with flake8
|
||||
run: |
|
||||
pip install --upgrade pip
|
||||
pip install flake8
|
||||
flake8 --ignore=E501,F541,E266,E402,W503,E731,E203 --per-file-ignores="__init__.py:F401,F403" qlib
|
||||
|
||||
# https://github.com/python/mypy/issues/10600
|
||||
- name: Check Qlib with mypy
|
||||
run: |
|
||||
pip install mypy
|
||||
mypy qlib --install-types --non-interactive || true
|
||||
mypy qlib
|
||||
|
||||
- name: Test data downloads
|
||||
run: |
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
|
||||
python scripts/get_data.py qlib_data --name qlib_data_simple --target_dir ~/.qlib/qlib_data/cn_data_simple --interval 1d --region cn
|
||||
python -c "import os; userpath=os.path.expanduser('~'); os.rename(userpath + '/.qlib/qlib_data/cn_data_simple', userpath + '/.qlib/qlib_data/cn_data')"
|
||||
azcopy copy https://qlibpublic.blob.core.windows.net/data /tmp/qlibpublic --recursive
|
||||
mv /tmp/qlibpublic/data tests/.data
|
||||
|
||||
- name: Test workflow by config (install from pip)
|
||||
run: |
|
||||
@@ -48,6 +122,7 @@ jobs:
|
||||
- name: Install Qlib from source
|
||||
run: |
|
||||
pip install --upgrade cython jupyter jupyter_contrib_nbextensions numpy scipy scikit-learn # installing without this line will cause errors on GitHub Actions, while instsalling locally won't
|
||||
pip install gym tianshou torch
|
||||
pip install -e .
|
||||
|
||||
- name: Install test dependencies
|
||||
@@ -57,10 +132,10 @@ jobs:
|
||||
|
||||
- name: Unit tests with Pytest
|
||||
run: |
|
||||
pip install -r scripts/data_collector/pit/requirements.txt
|
||||
cd tests
|
||||
python -m pytest . --durations=10
|
||||
|
||||
- name: Test workflow by config (install from source)
|
||||
run: |
|
||||
python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
|
||||
|
||||
23
.github/workflows/test_macos.yml
vendored
23
.github/workflows/test_macos.yml
vendored
@@ -34,10 +34,24 @@ jobs:
|
||||
python -m black qlib -l 120 --check --diff
|
||||
# Test Qlib installed with pip
|
||||
|
||||
- name: Check Qlib with flake8
|
||||
run: |
|
||||
pip install --upgrade pip
|
||||
pip install flake8
|
||||
flake8 --ignore=E501,F541,E266,E402,W503,E731,E203 --per-file-ignores="__init__.py:F401,F403" qlib
|
||||
|
||||
- name: Install Qlib with pip
|
||||
run: |
|
||||
python -m pip install numpy==1.19.5
|
||||
python -m pip install pyqlib --ignore-installed ruamel.yaml numpy
|
||||
- name: Make html with sphnix
|
||||
run: |
|
||||
pip install -U sphinx
|
||||
pip install sphinx_rtd_theme readthedocs_sphinx_ext
|
||||
pip install --exists-action=w --no-cache-dir -r docs/requirements.txt
|
||||
cd docs
|
||||
sphinx-build -b html . build
|
||||
cd ..
|
||||
- name: Install Lightgbm for MacOS
|
||||
run: |
|
||||
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Microsoft/qlib/main/.github/brew_install.sh)"
|
||||
@@ -49,7 +63,10 @@ jobs:
|
||||
brew install libomp.rb
|
||||
- name: Test data downloads
|
||||
run: |
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
|
||||
python scripts/get_data.py qlib_data --name qlib_data_simple --target_dir ~/.qlib/qlib_data/cn_data_simple --interval 1d --region cn
|
||||
python -c "import os; userpath=os.path.expanduser('~'); os.rename(userpath + '/.qlib/qlib_data/cn_data_simple', userpath + '/.qlib/qlib_data/cn_data')"
|
||||
azcopy copy https://qlibpublic.blob.core.windows.net/data /tmp/qlibpublic --recursive
|
||||
mv /tmp/qlibpublic/data tests/.data
|
||||
- name: Test workflow by config (install from pip)
|
||||
run: |
|
||||
python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
@@ -60,7 +77,8 @@ jobs:
|
||||
python -m pip install --upgrade cython
|
||||
python -m pip install numpy jupyter jupyter_contrib_nbextensions
|
||||
python -m pip install -U scipy scikit-learn # installing without this line will cause errors on GitHub Actions, while instsalling locally won't
|
||||
python setup.py install
|
||||
python -m pip install gym tianshou torch
|
||||
pip install -e .
|
||||
- name: Install test dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
@@ -68,6 +86,7 @@ jobs:
|
||||
python -m pip install black pytest
|
||||
- name: Unit tests with Pytest
|
||||
run: |
|
||||
pip install -r scripts/data_collector/pit/requirements.txt
|
||||
cd tests
|
||||
python -m pytest . --durations=0
|
||||
- name: Test workflow by config (install from source)
|
||||
|
||||
5
.gitignore
vendored
5
.gitignore
vendored
@@ -27,6 +27,10 @@ examples/estimator/estimator_example/
|
||||
|
||||
*.egg-info/
|
||||
|
||||
# test related
|
||||
test-output.xml
|
||||
.output
|
||||
.data
|
||||
|
||||
# special software
|
||||
mlruns/
|
||||
@@ -34,6 +38,7 @@ mlruns/
|
||||
tags
|
||||
|
||||
.pytest_cache/
|
||||
.mypy_cache/
|
||||
.vscode/
|
||||
|
||||
*.swp
|
||||
|
||||
17
.mypy.ini
Normal file
17
.mypy.ini
Normal file
@@ -0,0 +1,17 @@
|
||||
[mypy]
|
||||
exclude = (?x)(
|
||||
^qlib/backtest
|
||||
| ^qlib/contrib
|
||||
| ^qlib/data
|
||||
| ^qlib/model
|
||||
| ^qlib/strategy
|
||||
| ^qlib/tests
|
||||
| ^qlib/utils
|
||||
| ^qlib/workflow
|
||||
| ^qlib/config\.py$
|
||||
| ^qlib/log\.py$
|
||||
| ^qlib/__init__\.py$
|
||||
)
|
||||
ignore_missing_imports = true
|
||||
disallow_incomplete_defs = true
|
||||
follow_imports = skip
|
||||
12
.pre-commit-config.yaml
Normal file
12
.pre-commit-config.yaml
Normal file
@@ -0,0 +1,12 @@
|
||||
repos:
|
||||
- repo: https://github.com/psf/black
|
||||
rev: 22.1.0
|
||||
hooks:
|
||||
- id: black
|
||||
args: ["qlib", "-l 120"]
|
||||
|
||||
- repo: https://github.com/PyCQA/flake8
|
||||
rev: 4.0.1
|
||||
hooks:
|
||||
- id: flake8
|
||||
args: ["--ignore=E501,F541,E266,E402,W503,E731,E203"]
|
||||
5
.pylintrc
Normal file
5
.pylintrc
Normal file
@@ -0,0 +1,5 @@
|
||||
[TYPECHECK]
|
||||
# https://stackoverflow.com/a/53572939
|
||||
# List of members which are set dynamically and missed by Pylint inference
|
||||
# system, and so shouldn't trigger E1101 when accessed.
|
||||
generated-members=numpy.*, torch.*
|
||||
@@ -17,5 +17,5 @@ python:
|
||||
version: 3.7
|
||||
install:
|
||||
- requirements: docs/requirements.txt
|
||||
- method: setuptools
|
||||
path: .
|
||||
- method: pip
|
||||
path: .
|
||||
|
||||
134
README.md
134
README.md
@@ -11,23 +11,28 @@
|
||||
Recent released features
|
||||
| Feature | Status |
|
||||
| -- | ------ |
|
||||
| Meta-Learning-based framework & DDG-DA | [Released](https://github.com/microsoft/qlib/pull/743) on Jan 10, 2022 |
|
||||
| Planning-based portfolio optimization | [Released](https://github.com/microsoft/qlib/pull/754) on Dec 28, 2021 |
|
||||
| Release Qlib v0.8.0 | [Released](https://github.com/microsoft/qlib/releases/tag/v0.8.0) on Dec 8, 2021 |
|
||||
| ADD model | [Released](https://github.com/microsoft/qlib/pull/704) on Nov 22, 2021 |
|
||||
| ADARNN model | [Released](https://github.com/microsoft/qlib/pull/689) on Nov 14, 2021 |
|
||||
| TCN model | [Released](https://github.com/microsoft/qlib/pull/668) on Nov 4, 2021 |
|
||||
| Nested Decision Framework | [Released](https://github.com/microsoft/qlib/pull/438) on Oct 1, 2021. [Example](https://github.com/microsoft/qlib/blob/main/examples/nested_decision_execution/workflow.py) and [Doc](https://qlib.readthedocs.io/en/latest/component/highfreq.html) |
|
||||
|Temporal Routing Adaptor (TRA) | [Released](https://github.com/microsoft/qlib/pull/531) on July 30, 2021 |
|
||||
| Transformer & Localformer | [Released](https://github.com/microsoft/qlib/pull/508) on July 22, 2021 |
|
||||
| Release Qlib v0.7.0 | [Released](https://github.com/microsoft/qlib/releases/tag/v0.7.0) on July 12, 2021 |
|
||||
| TCTS Model | [Released](https://github.com/microsoft/qlib/pull/491) on July 1, 2021 |
|
||||
| Online serving and automatic model rolling | :star: [Released](https://github.com/microsoft/qlib/pull/290) on May 17, 2021 |
|
||||
| DoubleEnsemble Model | [Released](https://github.com/microsoft/qlib/pull/286) on Mar 2, 2021 |
|
||||
| High-frequency data processing example | [Released](https://github.com/microsoft/qlib/pull/257) on Feb 5, 2021 |
|
||||
| High-frequency trading example | [Part of code released](https://github.com/microsoft/qlib/pull/227) on Jan 28, 2021 |
|
||||
| High-frequency data(1min) | [Released](https://github.com/microsoft/qlib/pull/221) on Jan 27, 2021 |
|
||||
| Tabnet Model | [Released](https://github.com/microsoft/qlib/pull/205) on Jan 22, 2021 |
|
||||
| HIST and IGMTF models | :chart_with_upwards_trend: [Released](https://github.com/microsoft/qlib/pull/1040) on Apr 10, 2022 |
|
||||
| Qlib [notebook tutorial](https://github.com/microsoft/qlib/tree/main/examples/tutorial) | 📖 [Released](https://github.com/microsoft/qlib/pull/1037) on Apr 7, 2022 |
|
||||
| Ibovespa index data | :rice: [Released](https://github.com/microsoft/qlib/pull/990) on Apr 6, 2022 |
|
||||
| Point-in-Time database | :hammer: [Released](https://github.com/microsoft/qlib/pull/343) on Mar 10, 2022 |
|
||||
| Arctic Provider Backend & Orderbook data example | :hammer: [Released](https://github.com/microsoft/qlib/pull/744) on Jan 17, 2022 |
|
||||
| Meta-Learning-based framework & DDG-DA | :chart_with_upwards_trend: :hammer: [Released](https://github.com/microsoft/qlib/pull/743) on Jan 10, 2022 |
|
||||
| Planning-based portfolio optimization | :hammer: [Released](https://github.com/microsoft/qlib/pull/754) on Dec 28, 2021 |
|
||||
| Release Qlib v0.8.0 | :octocat: [Released](https://github.com/microsoft/qlib/releases/tag/v0.8.0) on Dec 8, 2021 |
|
||||
| ADD model | :chart_with_upwards_trend: [Released](https://github.com/microsoft/qlib/pull/704) on Nov 22, 2021 |
|
||||
| ADARNN model | :chart_with_upwards_trend: [Released](https://github.com/microsoft/qlib/pull/689) on Nov 14, 2021 |
|
||||
| TCN model | :chart_with_upwards_trend: [Released](https://github.com/microsoft/qlib/pull/668) on Nov 4, 2021 |
|
||||
| Nested Decision Framework | :hammer: [Released](https://github.com/microsoft/qlib/pull/438) on Oct 1, 2021. [Example](https://github.com/microsoft/qlib/blob/main/examples/nested_decision_execution/workflow.py) and [Doc](https://qlib.readthedocs.io/en/latest/component/highfreq.html) |
|
||||
| Temporal Routing Adaptor (TRA) | :chart_with_upwards_trend: [Released](https://github.com/microsoft/qlib/pull/531) on July 30, 2021 |
|
||||
| Transformer & Localformer | :chart_with_upwards_trend: [Released](https://github.com/microsoft/qlib/pull/508) on July 22, 2021 |
|
||||
| Release Qlib v0.7.0 | :octocat: [Released](https://github.com/microsoft/qlib/releases/tag/v0.7.0) on July 12, 2021 |
|
||||
| TCTS Model | :chart_with_upwards_trend: [Released](https://github.com/microsoft/qlib/pull/491) on July 1, 2021 |
|
||||
| Online serving and automatic model rolling | :hammer: [Released](https://github.com/microsoft/qlib/pull/290) on May 17, 2021 |
|
||||
| DoubleEnsemble Model | :chart_with_upwards_trend: [Released](https://github.com/microsoft/qlib/pull/286) on Mar 2, 2021 |
|
||||
| High-frequency data processing example | :hammer: [Released](https://github.com/microsoft/qlib/pull/257) on Feb 5, 2021 |
|
||||
| High-frequency trading example | :chart_with_upwards_trend: [Part of code released](https://github.com/microsoft/qlib/pull/227) on Jan 28, 2021 |
|
||||
| High-frequency data(1min) | :rice: [Released](https://github.com/microsoft/qlib/pull/221) on Jan 27, 2021 |
|
||||
| Tabnet Model | :chart_with_upwards_trend: [Released](https://github.com/microsoft/qlib/pull/205) on Jan 22, 2021 |
|
||||
|
||||
Features released before 2021 are not listed here.
|
||||
|
||||
@@ -44,35 +49,58 @@ With Qlib, users can easily try ideas to create better Quant investment strategi
|
||||
|
||||
For more details, please refer to our paper ["Qlib: An AI-oriented Quantitative Investment Platform"](https://arxiv.org/abs/2009.11189).
|
||||
|
||||
- [**Plans**](#plans)
|
||||
- [Framework of Qlib](#framework-of-qlib)
|
||||
- [Quick Start](#quick-start)
|
||||
- [Installation](#installation)
|
||||
- [Data Preparation](#data-preparation)
|
||||
- [Auto Quant Research Workflow](#auto-quant-research-workflow)
|
||||
- [Building Customized Quant Research Workflow by Code](#building-customized-quant-research-workflow-by-code)
|
||||
- [Main Challenges & Solutions in Quant Research](#main-challenges--solutions-in-quant-research)
|
||||
- [Forecasting: Finding Valuable Signals/Patterns](#forecasting-finding-valuable-signalspatterns)
|
||||
- [**Quant Model (Paper) Zoo**](#quant-model-paper-zoo)
|
||||
- [Run a Single Model](#run-a-single-model)
|
||||
- [Run Multiple Models](#run-multiple-models)
|
||||
- [Adapting to Market Dynamics](#adapting-to-market-dynamics)
|
||||
- [**Quant Dataset Zoo**](#quant-dataset-zoo)
|
||||
- [More About Qlib](#more-about-qlib)
|
||||
- [Offline Mode and Online Mode](#offline-mode-and-online-mode)
|
||||
- [Performance of Qlib Data Server](#performance-of-qlib-data-server)
|
||||
- [Related Reports](#related-reports)
|
||||
- [Contact Us](#contact-us)
|
||||
- [Contributing](#contributing)
|
||||
|
||||
<table>
|
||||
<tbody>
|
||||
<tr>
|
||||
<th>Frameworks, Tutorial, Data & DevOps</th>
|
||||
<th>Main Challenges & Solutions in Quant Research</th>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>
|
||||
<li><a href="#plans"><strong>Plans</strong></a></li>
|
||||
<li><a href="#framework-of-qlib">Framework of Qlib</a></li>
|
||||
<li><a href="#quick-start">Quick Start</a></li>
|
||||
<ul dir="auto">
|
||||
<li type="circle"><a href="#installation">Installation</a> </li>
|
||||
<li type="circle"><a href="#data-preparation">Data Preparation</a></li>
|
||||
<li type="circle"><a href="#auto-quant-research-workflow">Auto Quant Research Workflow</a></li>
|
||||
<li type="circle"><a href="#building-customized-quant-research-workflow-by-code">Building Customized Quant Research Workflow by Code</a></li></ul>
|
||||
<li><a href="#quant-dataset-zoo"><strong>Quant Dataset Zoo</strong></a></li>
|
||||
<li><a href="#more-about-qlib">More About Qlib</a></li>
|
||||
<li><a href="#offline-mode-and-online-mode">Offline Mode and Online Mode</a>
|
||||
<ul>
|
||||
<li type="circle"><a href="#performance-of-qlib-data-server">Performance of Qlib Data Server</a></li></ul>
|
||||
<li><a href="#related-reports">Related Reports</a></li>
|
||||
<li><a href="#contact-us">Contact Us</a></li>
|
||||
<li><a href="#contributing">Contributing</a></li>
|
||||
</td>
|
||||
<td valign="baseline">
|
||||
<li><a href="#main-challenges--solutions-in-quant-research">Main Challenges & Solutions in Quant Research</a>
|
||||
<ul>
|
||||
<li type="circle"><a href="#forecasting-finding-valuable-signalspatterns">Forecasting: Finding Valuable Signals/Patterns</a>
|
||||
<ul>
|
||||
<li type="disc"><a href="#quant-model-paper-zoo"><strong>Quant Model (Paper) Zoo</strong></a>
|
||||
<ul>
|
||||
<li type="circle"><a href="#run-a-single-model">Run a Single Model</a></li>
|
||||
<li type="circle"><a href="#run-multiple-models">Run Multiple Models</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
</ul>
|
||||
</li>
|
||||
<li type="circle"><a href="#adapting-to-market-dynamics">Adapting to Market Dynamics</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
|
||||
# Plans
|
||||
New features under development(order by estimated release time).
|
||||
Your feedbacks about the features are very important.
|
||||
| Feature | Status |
|
||||
| -- | ------ |
|
||||
| Point-in-Time database | Under review: https://github.com/microsoft/qlib/pull/343 |
|
||||
| Orderbook database | Under review: https://github.com/microsoft/qlib/pull/744 |
|
||||
<!-- | Feature | Status | -->
|
||||
<!-- | -- | ------ | -->
|
||||
|
||||
# Framework of Qlib
|
||||
|
||||
@@ -80,7 +108,6 @@ Your feedbacks about the features are very important.
|
||||
<img src="docs/_static/img/framework.svg" />
|
||||
</div>
|
||||
|
||||
|
||||
At the module level, Qlib is a platform that consists of the above components. The components are designed as loose-coupled modules, and each component could be used stand-alone.
|
||||
|
||||
| Name | Description |
|
||||
@@ -92,6 +119,8 @@ At the module level, Qlib is a platform that consists of the above components. T
|
||||
* The modules with hand-drawn style are under development and will be released in the future.
|
||||
* The modules with dashed borders are highly user-customizable and extendible.
|
||||
|
||||
(p.s. framework image is created with https://draw.io/)
|
||||
|
||||
|
||||
# Quick Start
|
||||
|
||||
@@ -115,6 +144,7 @@ This table demonstrates the supported Python version of `Qlib`:
|
||||
1. **Conda** is suggested for managing your Python environment.
|
||||
1. Please pay attention that installing cython in Python 3.6 will raise some error when installing ``Qlib`` from source. If users use Python 3.6 on their machines, it is recommended to *upgrade* Python to version 3.7 or use `conda`'s Python to install ``Qlib`` from source.
|
||||
1. For Python 3.9, `Qlib` supports running workflows such as training models, doing backtest and plot most of the related figures (those included in [notebook](examples/workflow_by_code.ipynb)). However, plotting for the *model performance* is not supported for now and we will fix this when the dependent packages are upgraded in the future.
|
||||
1. `Qlib`Requires `tables` package, `hdf5` in tables does not support python3.9.
|
||||
|
||||
### Install with pip
|
||||
Users can easily install ``Qlib`` by pip according to the following command.
|
||||
@@ -136,17 +166,11 @@ Also, users can install the latest dev version ``Qlib`` by the source code accor
|
||||
```
|
||||
|
||||
* Clone the repository and install ``Qlib`` as follows.
|
||||
* If you haven't installed qlib by the command ``pip install pyqlib`` before:
|
||||
```bash
|
||||
git clone https://github.com/microsoft/qlib.git && cd qlib
|
||||
python setup.py install
|
||||
```
|
||||
* If you have already installed the stable version by the command ``pip install pyqlib``:
|
||||
```bash
|
||||
git clone https://github.com/microsoft/qlib.git && cd qlib
|
||||
pip install .
|
||||
```
|
||||
**Note**: **Only** the command ``pip install .`` **can** overwrite the stable version installed by ``pip install pyqlib``, while the command ``python setup.py install`` **can't**.
|
||||
**Note**: You can install Qlib with `python setup.py install` as well. But it is not the recommanded approach. It will skip `pip` and cause obscure problems. For example, **only** the command ``pip install .`` **can** overwrite the stable version installed by ``pip install pyqlib``, while the command ``python setup.py install`` **can't**.
|
||||
|
||||
**Tips**: If you fail to install `Qlib` or run the examples in your environment, comparing your steps and the [CI workflow](.github/workflows/test.yml) may help you find the problem.
|
||||
|
||||
@@ -316,6 +340,8 @@ Here is a list of models built on `Qlib`.
|
||||
- [TCN based on pytorch (Shaojie Bai, et al. 2018)](examples/benchmarks/TCN/)
|
||||
- [ADARNN based on pytorch (YunTao Du, et al. 2021)](examples/benchmarks/ADARNN/)
|
||||
- [ADD based on pytorch (Hongshun Tang, et al.2020)](examples/benchmarks/ADD/)
|
||||
- [IGMTF based on pytorch (Wentao Xu, et al.2021)](examples/benchmarks/IGMTF/)
|
||||
- [HIST based on pytorch (Wentao Xu, et al.2021)](examples/benchmarks/HIST/)
|
||||
|
||||
Your PR of new Quant models is highly welcomed.
|
||||
|
||||
@@ -364,6 +390,8 @@ Dataset plays a very important role in Quant. Here is a list of the datasets bui
|
||||
Your PR to build new Quant dataset is highly welcomed.
|
||||
|
||||
# More About Qlib
|
||||
If you want to have a quick glance at the most frequently used components of qlib, you can try notebooks [here](examples/tutorial/).
|
||||
|
||||
The detailed documents are organized in [docs](docs/).
|
||||
[Sphinx](http://www.sphinx-doc.org) and the readthedocs theme is required to build the documentation in html formats.
|
||||
```bash
|
||||
@@ -446,9 +474,13 @@ If you don't know how to start to contribute, you can refer to the following exa
|
||||
| Docs | [Improve docs quality](https://github.com/microsoft/qlib/pull/797/files) ; [Fix a typo](https://github.com/microsoft/qlib/pull/774) |
|
||||
| Feature | Implement a [requested feature](https://github.com/microsoft/qlib/projects) like [this](https://github.com/microsoft/qlib/pull/754); [Refactor interfaces](https://github.com/microsoft/qlib/pull/539/files) |
|
||||
| Dataset | [Add a dataset](https://github.com/microsoft/qlib/pull/733) |
|
||||
| Models | [Implement a new model](https://github.com/microsoft/qlib/pull/689) |
|
||||
| Models | [Implement a new model](https://github.com/microsoft/qlib/pull/689), [some instructions to contribute models](https://github.com/microsoft/qlib/tree/main/examples/benchmarks#contributing) |
|
||||
|
||||
If you would like to become one of Qlib's maintainers to contribute more (e.g. help merge PR, triage issues), please contact us by email([qlib@microsoft.com](mailto:qlib@microsoft.com)). We are glad to help you to set the right permission.
|
||||
[Good first issues](https://github.com/microsoft/qlib/labels/good%20first%20issue) are labelled to indicate that they are easy to start your contributions.
|
||||
|
||||
You can find some impefect implementation in Qlib by `rg 'TODO|FIXME' qlib`
|
||||
|
||||
If you would like to become one of Qlib's maintainers to contribute more (e.g. help merge PR, triage issues), please contact us by email([qlib@microsoft.com](mailto:qlib@microsoft.com)). We are glad to help to upgrade your permission.
|
||||
|
||||
## Licence
|
||||
Most contributions require you to agree to a
|
||||
|
||||
136
docs/advanced/PIT.rst
Normal file
136
docs/advanced/PIT.rst
Normal file
@@ -0,0 +1,136 @@
|
||||
.. _pit:
|
||||
|
||||
===========================
|
||||
(P)oint-(I)n-(T)ime Database
|
||||
===========================
|
||||
.. currentmodule:: qlib
|
||||
|
||||
|
||||
Introduction
|
||||
------------
|
||||
Point-in-time data is a very important consideration when performing any sort of historical market analysis.
|
||||
|
||||
For example, let’s say we are backtesting a trading strategy and we are using the past five years of historical data as our input.
|
||||
Our model is assumed to trade once a day, at the market close, and we’ll say we are calculating the trading signal for 1 January 2020 in our backtest. At that point, we should only have data for 1 January 2020, 31 December 2019, 30 December 2019 etc.
|
||||
|
||||
In financial data (especially financial reports), the same piece of data may be amended for multiple times overtime. If we only use the latest version for historical backtesting, data leakage will happen.
|
||||
Point-in-time database is designed for solving this problem to make sure user get the right version of data at any historical timestamp. It will keep the performance of online trading and historical backtesting the same.
|
||||
|
||||
|
||||
|
||||
Data Preparation
|
||||
----------------
|
||||
|
||||
Qlib provides a crawler to help users to download financial data and then a converter to dump the data in Qlib format.
|
||||
Please follow `scripts/data_collector/pit/README.md <https://github.com/microsoft/qlib/tree/main/scripts/data_collector/pit/>`_ to download and convert data.
|
||||
Besides, you can find some additional usage examples there.
|
||||
|
||||
|
||||
File-based design for PIT data
|
||||
------------------------------
|
||||
|
||||
Qlib provides a file-based storage for PIT data.
|
||||
|
||||
For each feature, it contains 4 columns, i.e. date, period, value, _next.
|
||||
Each row corresponds to a statement.
|
||||
|
||||
The meaning of each feature with filename like `XXX_a.data`:
|
||||
|
||||
- `date`: the statement's date of publication.
|
||||
- `period`: the period of the statement. (e.g. it will be quarterly frequency in most of the markets)
|
||||
- If it is an annual period, it will be an integer corresponding to the year
|
||||
- If it is an quarterly periods, it will be an integer like `<year><index of quarter>`. The last two decimal digits represents the index of quarter. Others represent the year.
|
||||
- `value`: the described value
|
||||
- `_next`: the byte index of the next occurance of the field.
|
||||
|
||||
Besides the feature data, an index `XXX_a.index` is included to speed up the querying performance
|
||||
|
||||
The statements are soted by the `date` in ascending order from the beginning of the file.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# the data format from XXXX.data
|
||||
array([(20070428, 200701, 0.090219 , 4294967295),
|
||||
(20070817, 200702, 0.13933 , 4294967295),
|
||||
(20071023, 200703, 0.24586301, 4294967295),
|
||||
(20080301, 200704, 0.3479 , 80),
|
||||
(20080313, 200704, 0.395989 , 4294967295),
|
||||
(20080422, 200801, 0.100724 , 4294967295),
|
||||
(20080828, 200802, 0.24996801, 4294967295),
|
||||
(20081027, 200803, 0.33412001, 4294967295),
|
||||
(20090325, 200804, 0.39011699, 4294967295),
|
||||
(20090421, 200901, 0.102675 , 4294967295),
|
||||
(20090807, 200902, 0.230712 , 4294967295),
|
||||
(20091024, 200903, 0.30072999, 4294967295),
|
||||
(20100402, 200904, 0.33546099, 4294967295),
|
||||
(20100426, 201001, 0.083825 , 4294967295),
|
||||
(20100812, 201002, 0.200545 , 4294967295),
|
||||
(20101029, 201003, 0.260986 , 4294967295),
|
||||
(20110321, 201004, 0.30739301, 4294967295),
|
||||
(20110423, 201101, 0.097411 , 4294967295),
|
||||
(20110831, 201102, 0.24825101, 4294967295),
|
||||
(20111018, 201103, 0.318919 , 4294967295),
|
||||
(20120323, 201104, 0.4039 , 420),
|
||||
(20120411, 201104, 0.403925 , 4294967295),
|
||||
(20120426, 201201, 0.112148 , 4294967295),
|
||||
(20120810, 201202, 0.26484701, 4294967295),
|
||||
(20121026, 201203, 0.370487 , 4294967295),
|
||||
(20130329, 201204, 0.45004699, 4294967295),
|
||||
(20130418, 201301, 0.099958 , 4294967295),
|
||||
(20130831, 201302, 0.21044201, 4294967295),
|
||||
(20131016, 201303, 0.30454299, 4294967295),
|
||||
(20140325, 201304, 0.394328 , 4294967295),
|
||||
(20140425, 201401, 0.083217 , 4294967295),
|
||||
(20140829, 201402, 0.16450299, 4294967295),
|
||||
(20141030, 201403, 0.23408499, 4294967295),
|
||||
(20150421, 201404, 0.319612 , 4294967295),
|
||||
(20150421, 201501, 0.078494 , 4294967295),
|
||||
(20150828, 201502, 0.137504 , 4294967295),
|
||||
(20151023, 201503, 0.201709 , 4294967295),
|
||||
(20160324, 201504, 0.26420501, 4294967295),
|
||||
(20160421, 201601, 0.073664 , 4294967295),
|
||||
(20160827, 201602, 0.136576 , 4294967295),
|
||||
(20161029, 201603, 0.188062 , 4294967295),
|
||||
(20170415, 201604, 0.244385 , 4294967295),
|
||||
(20170425, 201701, 0.080614 , 4294967295),
|
||||
(20170728, 201702, 0.15151 , 4294967295),
|
||||
(20171026, 201703, 0.25416601, 4294967295),
|
||||
(20180328, 201704, 0.32954201, 4294967295),
|
||||
(20180428, 201801, 0.088887 , 4294967295),
|
||||
(20180802, 201802, 0.170563 , 4294967295),
|
||||
(20181029, 201803, 0.25522 , 4294967295),
|
||||
(20190329, 201804, 0.34464401, 4294967295),
|
||||
(20190425, 201901, 0.094737 , 4294967295),
|
||||
(20190713, 201902, 0. , 1040),
|
||||
(20190718, 201902, 0.175322 , 4294967295),
|
||||
(20191016, 201903, 0.25581899, 4294967295)],
|
||||
dtype=[('date', '<u4'), ('period', '<u4'), ('value', '<f8'), ('_next', '<u4')])
|
||||
# - each row contains 20 byte
|
||||
|
||||
|
||||
# The data format from XXXX.index. It consists of two parts
|
||||
# 1) the start index of the data. So the first part of the info will be like
|
||||
2007
|
||||
# 2) the remain index data will be like information below
|
||||
# - The data indicate the **byte index** of first data update of a period.
|
||||
# - e.g. Because the info at both byte 80 and 100 corresponds to 200704. The byte index of first occurance (i.e. 100) is recorded in the data.
|
||||
array([ 0, 20, 40, 60, 100,
|
||||
120, 140, 160, 180, 200,
|
||||
220, 240, 260, 280, 300,
|
||||
320, 340, 360, 380, 400,
|
||||
440, 460, 480, 500, 520,
|
||||
540, 560, 580, 600, 620,
|
||||
640, 660, 680, 700, 720,
|
||||
740, 760, 780, 800, 820,
|
||||
840, 860, 880, 900, 920,
|
||||
940, 960, 980, 1000, 1020,
|
||||
1060, 4294967295], dtype=uint32)
|
||||
|
||||
|
||||
|
||||
|
||||
Known limitations:
|
||||
|
||||
- Currently, the PIT database is designed for quarterly or annually factors, which can handle fundamental data of financial reports in most markets.
|
||||
- Qlib leverage the file name to identify the type of the data. File with name like `XXX_q.data` corresponds to quarterly data. File with name like `XXX_a.data` corresponds to annual data.
|
||||
- The caclulation of PIT is not performed in the optimal way. There is great potential to boost the performance of PIT data calcuation.
|
||||
@@ -52,7 +52,8 @@ Also, ``Qlib`` provides a high-frequency dataset. Users can run a high-frequency
|
||||
Qlib Format Dataset
|
||||
--------------------
|
||||
``Qlib`` has provided an off-the-shelf dataset in `.bin` format, users could use the script ``scripts/get_data.py`` to download the China-Stock dataset as follows.
|
||||
The price volume data look different from the actual dealling price because of they are **adjusted** (`adjusted price <https://www.investopedia.com/terms/a/adjusted_closing_price.asp>`_). And then you may find that the adjusted price may be different from different data sources. This is because different data sources may vary in the way of adjusting prices. Qlib normalize the price on first trading day of each stock to 1 when adjusting them.
|
||||
The price volume data look different from the actual dealling price because of they are **adjusted** (`adjusted price <https://www.investopedia.com/terms/a/adjusted_closing_price.asp>`_). And then you may find that the adjusted price may be different from different data sources. This is because different data sources may vary in the way of adjusting prices. Qlib normalize the price on first trading day of each stock to 1 when adjusting them.
|
||||
Users can leverage `$factor` to get the original trading price (e.g. `$close / $factor` to get the original close price).
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
@@ -436,7 +437,7 @@ Dataset
|
||||
|
||||
The ``Dataset`` module in ``Qlib`` aims to prepare data for model training and inferencing.
|
||||
|
||||
The motivation of this module is that we want to maximize the flexibility of of different models to handle data that are suitable for themselves. This module gives the model the flexibility to process their data in an unique way. For instance, models such as ``GBDT`` may work well on data that contains `nan` or `None` value, while neural networks such as ``MLP`` will break down on such data.
|
||||
The motivation of this module is that we want to maximize the flexibility of different models to handle data that are suitable for themselves. This module gives the model the flexibility to process their data in an unique way. For instance, models such as ``GBDT`` may work well on data that contains `nan` or `None` value, while neural networks such as ``MLP`` will break down on such data.
|
||||
|
||||
If user's model need process its data in a different way, user could implement his own ``Dataset`` class. If the model's
|
||||
data processing is not special, ``DatasetH`` can be used directly.
|
||||
|
||||
@@ -28,4 +28,11 @@ The frequency of trading algorithm, decision content and execution environment c
|
||||
Example
|
||||
===========================
|
||||
|
||||
An example of nested decision execution framework for high-frequency can be found `here <https://github.com/microsoft/qlib/blob/main/examples/nested_decision_execution/workflow.py>`_.
|
||||
An example of nested decision execution framework for high-frequency can be found `here <https://github.com/microsoft/qlib/blob/main/examples/nested_decision_execution/workflow.py>`_.
|
||||
|
||||
|
||||
Besides, the above examples, here are some other related work about high-frequency trading in Qlib.
|
||||
|
||||
- `Prediction with high-frequency data <https://github.com/microsoft/qlib/tree/main/examples/highfreq#benchmarks-performance-predicting-the-price-trend-in-high-frequency-data>`_
|
||||
- `Examples <https://github.com/microsoft/qlib/blob/main/examples/orderbook_data/>`_ to extract features form high-frequency data without fixed frequency.
|
||||
- `A paper <https://github.com/microsoft/qlib/tree/high-freq-execution#high-frequency-execution>`_ for high-frequency trading.
|
||||
|
||||
@@ -106,6 +106,9 @@ Example
|
||||
`SignalRecord` is the `Record Template` in ``Qlib``, please refer to `Workflow <recorder.html#record-template>`_.
|
||||
|
||||
Also, the above example has been given in ``examples/train_backtest_analyze.ipynb``.
|
||||
Technically, the meaning of the model prediction depends on the label setting designed by user.
|
||||
By default, the meaning of the score is normally the rating of the instruments by the forecasting model. The higher the score, the more profit the instruments.
|
||||
|
||||
|
||||
Custom Model
|
||||
===================
|
||||
|
||||
@@ -23,6 +23,10 @@ The `examples <https://github.com/microsoft/qlib/tree/main/examples/online_srv>`
|
||||
|
||||
**NOTE**: User should keep his data source updated to support online serving. For example, Qlib provides `a batch of scripts <https://github.com/microsoft/qlib/blob/main/scripts/data_collector/yahoo/README.md#automatic-update-of-daily-frequency-datafrom-yahoo-finance>`_ to help users update Yahoo daily data.
|
||||
|
||||
Known limitations currently
|
||||
- Currently, the daily updating prediction for the next trading day is supported. But generating orders for the next trading day is not supported due to the `limitations of public data <https://github.com/microsoft/qlib/issues/215#issuecomment-766293563>_`
|
||||
|
||||
|
||||
Online Manager
|
||||
=============
|
||||
|
||||
|
||||
@@ -143,3 +143,9 @@ Here is a simple exampke of what is done in ``PortAnaRecord``, which users can r
|
||||
print(analysis_df)
|
||||
|
||||
For more information about the APIs, please refer to `Record Template API <../reference/api.html#module-qlib.workflow.record_temp>`_.
|
||||
|
||||
|
||||
|
||||
Known Limitations
|
||||
=================
|
||||
- The Python objects are saved based on pickle, which may results in issues when the environment dumping objects and loading objects are different.
|
||||
|
||||
@@ -20,6 +20,9 @@ Introduction
|
||||
- model_performance_graph
|
||||
|
||||
|
||||
All of the accumulated profit metrics(e.g. return, max drawdown) in Qlib are calculated by summation.
|
||||
This avoids the metrics or the plots being skewed exponentially over time.
|
||||
|
||||
Graphical Reports
|
||||
===================
|
||||
|
||||
@@ -101,7 +104,7 @@ Graphical Result
|
||||
- Axis Y:
|
||||
- `ic`
|
||||
The `Pearson correlation coefficient` series between `label` and `prediction score`.
|
||||
In the above example, the `label` is formulated as `Ref($close, -1)/$close - 1`. Please refer to `Data Feature <data.html#feature>`_ for more details.
|
||||
In the above example, the `label` is formulated as `Ref($close, -2)/Ref($close, -1)-1`. Please refer to `Data Feature <data.html#feature>`_ for more details.
|
||||
|
||||
- `rank_ic`
|
||||
The `Spearman's rank correlation coefficient` series between `label` and `prediction score`.
|
||||
|
||||
@@ -24,11 +24,10 @@ BaseStrategy
|
||||
|
||||
Qlib provides a base class ``qlib.strategy.base.BaseStrategy``. All strategy classes need to inherit the base class and implement its interface.
|
||||
|
||||
- `get_risk_degree`
|
||||
Return the proportion of your total value you will use in investment. Dynamically risk_degree will result in Market timing.
|
||||
|
||||
- `generate_order_list`
|
||||
Return the order list.
|
||||
- `generate_trade_decision`
|
||||
generate_trade_decision is a key interface that generates trade decisions in each trading bar.
|
||||
The frequency to call this method depends on the executor frequency("time_per_step"="day" by default). But the trading frequency can be decided by users' implementation.
|
||||
For example, if the user wants to trading in weekly while the `time_per_step` is "day" in executor, user can return non-empty TradeDecision weekly(otherwise return empty like `this <https://github.com/microsoft/qlib/blob/main/qlib/contrib/strategy/signal_strategy.py#L132>`_ ).
|
||||
|
||||
Users can inherit `BaseStrategy` to customize their strategy class.
|
||||
|
||||
@@ -67,18 +66,24 @@ TopkDropoutStrategy
|
||||
- Adopt the ``Topk-Drop`` algorithm to calculate the target amount of each stock
|
||||
|
||||
.. note::
|
||||
``Topk-Drop`` algorithm:
|
||||
There are two parameters for the ``Topk-Drop`` algorithm:
|
||||
|
||||
- `Topk`: The number of stocks held
|
||||
- `Drop`: The number of stocks sold on each trading day
|
||||
|
||||
Currently, the number of held stocks is `Topk`.
|
||||
On each trading day, the `Drop` number of held stocks with the worst `prediction score` will be sold, and the same number of unheld stocks with the best `prediction score` will be bought.
|
||||
|
||||
In general, the number of stocks currently held is `Topk`, with the exception of being zero at the beginning period of trading.
|
||||
For each trading day, let $d$ be the number of the instruments currently held and with a rank $\gt K$ when ranked by the prediction scores from high to low.
|
||||
Then `d` number of stocks currently held with the worst `prediction score` will be sold, and the same number of unheld stocks with the best `prediction score` will be bought.
|
||||
|
||||
In general, $d=$`Drop`, especially when the pool of the candidate instruments is large, $K$ is large, and `Drop` is small.
|
||||
|
||||
In most cases, ``TopkDrop`` algorithm sells and buys `Drop` stocks every trading day, which yields a turnover rate of 2$\times$`Drop`/$K$.
|
||||
|
||||
The following images illustrate a typical scenario.
|
||||
.. image:: ../_static/img/topk_drop.png
|
||||
:alt: Topk-Drop
|
||||
|
||||
``TopkDrop`` algorithm sells `Drop` stocks every trading day, which guarantees a fixed turnover rate.
|
||||
|
||||
|
||||
- Generate the order list from the target amount
|
||||
|
||||
@@ -124,7 +129,9 @@ A prediction sample is shown as follows.
|
||||
|
||||
Normally, the prediction score is the output of the models. But some models are learned from a label with a different scale. So the scale of the prediction score may be different from your expectation(e.g. the return of instruments).
|
||||
|
||||
Qlib didn't add a step to scale the prediction score to a unified scale. Because not every trading strategy cares about the scale(e.g. TopkDropoutStrategy only cares about the order). So the strategy is responsible for rescaling the prediction score(e.g. some portfolio-optimization-based strategies may require a meaningful scale).
|
||||
Qlib didn't add a step to scale the prediction score to a unified scale due to the following reasons.
|
||||
- Because not every trading strategy cares about the scale(e.g. TopkDropoutStrategy only cares about the order). So the strategy is responsible for rescaling the prediction score(e.g. some portfolio-optimization-based strategies may require a meaningful scale).
|
||||
- The model has the flexibility to define the target, loss, and data processing. So we don't think there is a silver bullet to rescale it back directly barely based on the model's outputs. If you want to scale it back to some meaningful values(e.g. stock returns.), an intuitive solution is to create a regression model for the model's recent outputs and your recent target values.
|
||||
|
||||
Running backtest
|
||||
-----------------
|
||||
@@ -160,12 +167,9 @@ Running backtest
|
||||
start_time="2017-01-01", end_time="2020-08-01", strategy=strategy_obj
|
||||
)
|
||||
analysis = dict()
|
||||
analysis["excess_return_without_cost"] = risk_analysis(
|
||||
report_normal["return"] - report_normal["bench"], freq=analysis_freq
|
||||
)
|
||||
analysis["excess_return_with_cost"] = risk_analysis(
|
||||
report_normal["return"] - report_normal["bench"] - report_normal["cost"], freq=analysis_freq
|
||||
)
|
||||
# default frequency will be daily (i.e. "day")
|
||||
analysis["excess_return_without_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"])
|
||||
analysis["excess_return_with_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"] - report_normal["cost"])
|
||||
|
||||
analysis_df = pd.concat(analysis) # type: pd.DataFrame
|
||||
pprint(analysis_df)
|
||||
@@ -190,6 +194,14 @@ Running backtest
|
||||
qlib.init(provider_uri=<qlib data dir>)
|
||||
|
||||
CSI300_BENCH = "SH000300"
|
||||
# Benchmark is for calculating the excess return of your strategy.
|
||||
# Its data format will be like **ONE normal instrument**.
|
||||
# For example, you can query its data with the code below
|
||||
# `D.features(["SH000300"], ["$close"], start_time='2010-01-01', end_time='2017-12-31', freq='day')`
|
||||
# It is different from the argument `market`, which indicates a universe of stocks (e.g. **A SET** of stocks like csi300)
|
||||
# For example, you can query all data from a stock market with the code below.
|
||||
# ` D.features(D.instruments(market='csi300'), ["$close"], start_time='2010-01-01', end_time='2017-12-31', freq='day')`
|
||||
|
||||
FREQ = "day"
|
||||
STRATEGY_CONFIG = {
|
||||
"topk": 50,
|
||||
|
||||
@@ -124,9 +124,47 @@ Configuration File
|
||||
===================
|
||||
|
||||
Let's get into details of ``qrun`` in this section.
|
||||
|
||||
Before using ``qrun``, users need to prepare a configuration file. The following content shows how to prepare each part of the configuration file.
|
||||
|
||||
The design logic of the configuration file is very simple. It predefines fixed workflows and provide this yaml interface to users to define how to initialize each component.
|
||||
It follow the design of `init_instance_by_config <https://github.com/microsoft/qlib/blob/2aee9e0145decc3e71def70909639b5e5a6f4b58/qlib/utils/__init__.py#L264>`_ . It defines the initialization of each component of Qlib, which typically include the class and the initialization arguments.
|
||||
|
||||
For example, the following yaml and code are equivalent.
|
||||
|
||||
.. code-block:: YAML
|
||||
|
||||
model:
|
||||
class: LGBModel
|
||||
module_path: qlib.contrib.model.gbdt
|
||||
kwargs:
|
||||
loss: mse
|
||||
colsample_bytree: 0.8879
|
||||
learning_rate: 0.0421
|
||||
subsample: 0.8789
|
||||
lambda_l1: 205.6999
|
||||
lambda_l2: 580.9768
|
||||
max_depth: 8
|
||||
num_leaves: 210
|
||||
num_threads: 20
|
||||
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from qlib.contrib.model.gbdt import LGBModel
|
||||
kwargs = {
|
||||
"loss": "mse" ,
|
||||
"colsample_bytree": 0.8879,
|
||||
"learning_rate": 0.0421,
|
||||
"subsample": 0.8789,
|
||||
"lambda_l1": 205.6999,
|
||||
"lambda_l2": 580.9768,
|
||||
"max_depth": 8,
|
||||
"num_leaves": 210,
|
||||
"num_threads": 20,
|
||||
}
|
||||
LGBModel(kwargs)
|
||||
|
||||
|
||||
Qlib Init Section
|
||||
--------------------
|
||||
|
||||
@@ -195,7 +233,7 @@ The meaning of each field is as follows:
|
||||
Dataset Section
|
||||
~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
The `dataset` field describes the parameters for the ``Dataset`` module in ``Qlib`` as well those for the module ``DataHandler``. For more information about the ``Dataset`` module, please refer to `Qlib Model <../component/data.html#dataset>`_.
|
||||
The `dataset` field describes the parameters for the ``Dataset`` module in ``Qlib`` as well those for the module ``DataHandler``. For more information about the ``Dataset`` module, please refer to `Qlib Data <../component/data.html#dataset>`_.
|
||||
|
||||
The keywords arguments configuration of the ``DataHandler`` is as follows:
|
||||
|
||||
@@ -210,7 +248,7 @@ The keywords arguments configuration of the ``DataHandler`` is as follows:
|
||||
|
||||
Users can refer to the document of `DataHandler <../component/data.html#datahandler>`_ for more information about the meaning of each field in the configuration.
|
||||
|
||||
Here is the configuration for the ``Dataset`` module which will take care of data preprossing and slicing during the training and testing phase.
|
||||
Here is the configuration for the ``Dataset`` module which will take care of data preprocessing and slicing during the training and testing phase.
|
||||
|
||||
.. code-block:: YAML
|
||||
|
||||
|
||||
12
docs/conf.py
12
docs/conf.py
@@ -54,9 +54,9 @@ master_doc = "index"
|
||||
|
||||
|
||||
# General information about the project.
|
||||
project = u"QLib"
|
||||
copyright = u"Microsoft"
|
||||
author = u"Microsoft"
|
||||
project = "QLib"
|
||||
copyright = "Microsoft"
|
||||
author = "Microsoft"
|
||||
|
||||
# The version info for the project you're documenting, acts as replacement for
|
||||
# |version| and |release|, also used in various other places throughout the
|
||||
@@ -174,7 +174,7 @@ latex_elements = {
|
||||
# (source start file, target name, title,
|
||||
# author, documentclass [howto, manual, or own class]).
|
||||
latex_documents = [
|
||||
(master_doc, "qlib.tex", u"QLib Documentation", u"Microsoft", "manual"),
|
||||
(master_doc, "qlib.tex", "QLib Documentation", "Microsoft", "manual"),
|
||||
]
|
||||
|
||||
|
||||
@@ -182,7 +182,7 @@ latex_documents = [
|
||||
|
||||
# One entry per manual page. List of tuples
|
||||
# (source start file, name, description, authors, manual section).
|
||||
man_pages = [(master_doc, "qlib", u"QLib Documentation", [author], 1)]
|
||||
man_pages = [(master_doc, "qlib", "QLib Documentation", [author], 1)]
|
||||
|
||||
|
||||
# -- Options for Texinfo output -------------------------------------------
|
||||
@@ -194,7 +194,7 @@ texinfo_documents = [
|
||||
(
|
||||
master_doc,
|
||||
"QLib",
|
||||
u"QLib Documentation",
|
||||
"QLib Documentation",
|
||||
author,
|
||||
"QLib",
|
||||
"One line description of project.",
|
||||
|
||||
@@ -14,9 +14,35 @@ Continuous Integration (CI) tools help you stick to the quality standards by run
|
||||
|
||||
When you submit a PR request, you can check whether your code passes the CI tests in the "check" section at the bottom of the web page.
|
||||
|
||||
A common error is the mixed use of space and tab. You can fix the bug by inputing the following code in the command line.
|
||||
1. Qlib will check the code format with black. The PR will raise error if your code does not align to the standard of Qlib(e.g. a common error is the mixed use of space and tab).
|
||||
You can fix the bug by inputing the following code in the command line.
|
||||
|
||||
.. code-block:: python
|
||||
.. code-block:: bash
|
||||
|
||||
pip install black
|
||||
python -m black . -l 120
|
||||
|
||||
|
||||
2. Qlib will check your code style pylint. The checking command is implemented in [github action workflow](https://github.com/microsoft/qlib/blob/0e8b94a552f1c457cfa6cd2c1bb3b87ebb3fb279/.github/workflows/test.yml#L66).
|
||||
Sometime pylint's restrictions are not that reasonable. You can ignore specific errors like this
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
return -ICLoss()(pred, target, index) # pylint: disable=E1130
|
||||
|
||||
|
||||
3. Qlib will check your code style flake8. The checking command is implemented in [github action workflow](https://github.com/microsoft/qlib/blob/0e8b94a552f1c457cfa6cd2c1bb3b87ebb3fb279/.github/workflows/test.yml#L73).
|
||||
You can fix the bug by inputing the following code in the command line.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
flake8 --ignore E501,F541,E402,F401,W503,E741,E266,E203,E302,E731,E262,F523,F821,F811,F841,E713,E265,W291,E712,E722,W293 qlib
|
||||
|
||||
|
||||
4. Qlib has integrated pre-commit, which will make it easier for developers to format their code.
|
||||
Just run the following two commands, and the code will be automatically formatted using black and flake8 when the git commit command is executed.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install -e .[dev]
|
||||
pre-commit install
|
||||
@@ -53,6 +53,7 @@ Document Structure
|
||||
Online & Offline mode <advanced/server.rst>
|
||||
Serialization <advanced/serial.rst>
|
||||
Task Management <advanced/task_management.rst>
|
||||
Point-In-Time database <advanced/PIT.rst>
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 3
|
||||
|
||||
@@ -120,6 +120,32 @@ For more details about features, please refer `Feature API <../component/data.ht
|
||||
|
||||
.. note:: When calling `D.features()` at the client, use parameter `disk_cache=0` to skip dataset cache, use `disk_cache=1` to generate and use dataset cache. In addition, when calling at the server, users can use `disk_cache=2` to update the dataset cache.
|
||||
|
||||
|
||||
When you are building complicated expressions, implementing all the expressions in a single string may not be easy.
|
||||
For example, it looks quite long and complicated:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
>> from qlib.data import D
|
||||
>> data = D.features(["sh600519"], ["(($high / $close) + ($open / $close)) * (($high / $close) + ($open / $close)) / ($high / $close) + ($open / $close)"], start_time="20200101")
|
||||
|
||||
|
||||
But using string is not the only way to implement the expression. You can also implement expression by code.
|
||||
Here is an exmaple which does the same thing as above examples.
|
||||
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
>> from qlib.data.ops import *
|
||||
>> f1 = Feature("high") / Feature("close")
|
||||
>> f2 = Feature("open") / Feature("close")
|
||||
>> f3 = f1 + f2
|
||||
>> f4 = f3 * f3 / f3
|
||||
|
||||
>> data = D.features(["sh600519"], [f4], start_time="20200101")
|
||||
>> data.head()
|
||||
|
||||
|
||||
API
|
||||
====================
|
||||
To know more about how to use the Data, go to API Reference: `Data API <../reference/api.html#data>`_
|
||||
|
||||
@@ -37,7 +37,8 @@ Initialize Qlib before calling other APIs: run following code in python.
|
||||
Parameters
|
||||
-------------------
|
||||
|
||||
Besides `provider_uri` and `region`, `qlib.init` has other parameters. The following are several important parameters of `qlib.init`:
|
||||
Besides `provider_uri` and `region`, `qlib.init` has other parameters.
|
||||
The following are several important parameters of `qlib.init` (`Qlib` has a lot of config. Only part of parameters are limited here. More detailed setting can be found `here <https://github.com/microsoft/qlib/blob/main/qlib/config.py>`_):
|
||||
|
||||
- `provider_uri`
|
||||
Type: str. The URI of the Qlib data. For example, it could be the location where the data loaded by ``get_data.py`` are stored.
|
||||
@@ -48,7 +49,7 @@ Besides `provider_uri` and `region`, `qlib.init` has other parameters. The follo
|
||||
- ``qlib.constant.REG_CN``: China stock market.
|
||||
|
||||
Different modes will result in different trading limitations and costs.
|
||||
The region is just `shortcuts for defining a batch of configurations <https://github.com/microsoft/qlib/blob/main/qlib/config.py#L239>`_. Users can set the key configurations manually if the existing region setting can't meet their requirements.
|
||||
The region is just `shortcuts for defining a batch of configurations <https://github.com/microsoft/qlib/blob/528f74af099bf6156e9480bcd2bb28e453231212/qlib/config.py#L249>`_, which include minimal trading order unit (``trade_unit``), trading limitation (``limit_threshold``) , etc. It is not a necessary part and users can set the key configurations manually if the existing region setting can't meet their requirements.
|
||||
- `redis_host`
|
||||
Type: str, optional parameter(default: "127.0.0.1"), host of `redis`
|
||||
The lock and cache mechanism relies on redis.
|
||||
@@ -88,3 +89,9 @@ Besides `provider_uri` and `region`, `qlib.init` has other parameters. The follo
|
||||
"task_url": "mongodb://localhost:27017/", # your mongo url
|
||||
"task_db_name": "rolling_db", # the database name of Task Management
|
||||
})
|
||||
|
||||
- `logging_level`
|
||||
The logging level for the system.
|
||||
|
||||
- `kernels`
|
||||
The number of processes used when calculating features in Qlib's expression engine. It is very helpful to set it to 1 when you are debuggin an expression calculating exception
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
pandas==1.1.2
|
||||
numpy==1.17.4
|
||||
numpy==1.21.0
|
||||
scikit_learn==0.23.2
|
||||
torch==1.7.0
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
numpy==1.17.4
|
||||
numpy==1.21.0
|
||||
pandas==1.1.2
|
||||
scikit_learn==0.23.2
|
||||
torch==1.7.0
|
||||
|
||||
@@ -6,3 +6,4 @@
|
||||
|
||||
[https://www.ijcai.org/Proceedings/2017/0366.pdf](https://www.ijcai.org/Proceedings/2017/0366.pdf)
|
||||
|
||||
- NOTE: Current version of implementation is just a simplified version of ALSTM. It is an LSTM with attention.
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
numpy==1.17.4
|
||||
numpy==1.21.0
|
||||
pandas==1.1.2
|
||||
scikit_learn==0.23.2
|
||||
torch==1.7.0
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
pandas==1.1.2
|
||||
numpy==1.17.4
|
||||
numpy==1.21.0
|
||||
catboost==0.24.3
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
pandas==1.1.2
|
||||
numpy==1.17.4
|
||||
numpy==1.21.0
|
||||
lightgbm==3.1.0
|
||||
@@ -1,4 +1,4 @@
|
||||
pandas==1.1.2
|
||||
numpy==1.17.4
|
||||
numpy==1.21.0
|
||||
scikit_learn==0.23.2
|
||||
torch==1.7.0
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
numpy==1.17.4
|
||||
numpy==1.21.0
|
||||
pandas==1.1.2
|
||||
scikit_learn==0.23.2
|
||||
torch==1.7.0
|
||||
|
||||
3
examples/benchmarks/HIST/README.md
Normal file
3
examples/benchmarks/HIST/README.md
Normal file
@@ -0,0 +1,3 @@
|
||||
# HIST
|
||||
* Code: [https://github.com/Wentao-Xu/HIST](https://github.com/Wentao-Xu/HIST)
|
||||
* Paper: [HIST: A Graph-based Framework for Stock Trend Forecasting via Mining Concept-Oriented Shared InformationAdaRNN: Adaptive Learning and Forecasting for Time Series](https://arxiv.org/abs/2110.13716).
|
||||
BIN
examples/benchmarks/HIST/qlib_csi300_stock_index.npy
Normal file
BIN
examples/benchmarks/HIST/qlib_csi300_stock_index.npy
Normal file
Binary file not shown.
4
examples/benchmarks/HIST/requirements.txt
Normal file
4
examples/benchmarks/HIST/requirements.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
pandas==1.1.2
|
||||
numpy==1.21.0
|
||||
scikit_learn==0.23.2
|
||||
torch==1.7.0
|
||||
92
examples/benchmarks/HIST/workflow_config_hist_Alpha360.yaml
Normal file
92
examples/benchmarks/HIST/workflow_config_hist_Alpha360.yaml
Normal file
@@ -0,0 +1,92 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
start_time: 2017-01-01
|
||||
end_time: 2020-08-01
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
exchange_kwargs:
|
||||
limit_threshold: 0.095
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: HIST
|
||||
module_path: qlib.contrib.model.pytorch_hist
|
||||
kwargs:
|
||||
d_feat: 6
|
||||
hidden_size: 64
|
||||
num_layers: 2
|
||||
dropout: 0
|
||||
n_epochs: 200
|
||||
lr: 1e-4
|
||||
early_stop: 20
|
||||
metric: ic
|
||||
loss: mse
|
||||
base_model: LSTM
|
||||
model_path: "benchmarks/LSTM/model_lstm_csi300.pkl"
|
||||
stock2concept: "benchmarks/HIST/qlib_csi300_stock2concept.npy"
|
||||
stock_index: "benchmarks/HIST/qlib_csi300_stock_index.npy"
|
||||
GPU: 0
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha360
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
4
examples/benchmarks/IGMTF/README.md
Normal file
4
examples/benchmarks/IGMTF/README.md
Normal file
@@ -0,0 +1,4 @@
|
||||
# IGMTF
|
||||
* Code: [https://github.com/Wentao-Xu/IGMTF](https://github.com/Wentao-Xu/IGMTF)
|
||||
* Paper: [IGMTF: An Instance-wise Graph-based Framework for
|
||||
Multivariate Time Series Forecasting](https://arxiv.org/abs/2109.06489).
|
||||
4
examples/benchmarks/IGMTF/requirements.txt
Normal file
4
examples/benchmarks/IGMTF/requirements.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
pandas==1.1.2
|
||||
numpy==1.21.0
|
||||
scikit_learn==0.23.2
|
||||
torch==1.7.0
|
||||
@@ -0,0 +1,89 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
start_time: 2017-01-01
|
||||
end_time: 2020-08-01
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
exchange_kwargs:
|
||||
limit_threshold: 0.095
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: IGMTF
|
||||
module_path: qlib.contrib.model.pytorch_igmtf
|
||||
kwargs:
|
||||
d_feat: 6
|
||||
hidden_size: 64
|
||||
num_layers: 2
|
||||
dropout: 0
|
||||
n_epochs: 200
|
||||
lr: 1e-4
|
||||
early_stop: 20
|
||||
metric: ic
|
||||
loss: mse
|
||||
base_model: LSTM
|
||||
model_path: "benchmarks/LSTM/model_lstm_csi300.pkl"
|
||||
GPU: 0
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha360
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -1,4 +1,4 @@
|
||||
numpy==1.17.4
|
||||
numpy==1.21.0
|
||||
pandas==1.1.2
|
||||
scikit_learn==0.23.2
|
||||
torch==1.7.0
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
pandas==1.1.2
|
||||
numpy==1.17.4
|
||||
numpy==1.21.0
|
||||
lightgbm==3.1.0
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
numpy==1.17.4
|
||||
numpy==1.21.0
|
||||
pandas==1.1.2
|
||||
torch==1.2.0
|
||||
@@ -1,4 +1,4 @@
|
||||
pandas==1.1.2
|
||||
numpy==1.17.4
|
||||
numpy==1.21.0
|
||||
scikit_learn==0.23.2
|
||||
torch==1.7.0
|
||||
|
||||
@@ -63,8 +63,6 @@ task:
|
||||
module_path: qlib.contrib.model.pytorch_nn
|
||||
kwargs:
|
||||
loss: mse
|
||||
input_dim: 157
|
||||
output_dim: 1
|
||||
lr: 0.002
|
||||
lr_decay: 0.96
|
||||
lr_decay_steps: 100
|
||||
@@ -73,6 +71,8 @@ task:
|
||||
batch_size: 8192
|
||||
GPU: 0
|
||||
weight_decay: 0.0002
|
||||
pt_model_kwargs:
|
||||
input_dim: 157
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
|
||||
@@ -51,8 +51,6 @@ task:
|
||||
module_path: qlib.contrib.model.pytorch_nn
|
||||
kwargs:
|
||||
loss: mse
|
||||
input_dim: 360
|
||||
output_dim: 1
|
||||
lr: 0.002
|
||||
lr_decay: 0.96
|
||||
lr_decay_steps: 100
|
||||
@@ -60,6 +58,8 @@ task:
|
||||
max_steps: 8000
|
||||
batch_size: 4096
|
||||
GPU: 0
|
||||
pt_model_kwargs:
|
||||
input_dim: 360
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
|
||||
@@ -4,6 +4,7 @@ This page lists a batch of methods designed for alpha seeking. Each method tries
|
||||
The alpha is evaluated in two ways.
|
||||
1. The correlation between the alpha and future return.
|
||||
1. Constructing portfolio based on the alpha and evaluating the final total return.
|
||||
- The explanation of metrics can be found [here](https://qlib.readthedocs.io/en/latest/component/report.html#id4)
|
||||
|
||||
Here are the results of each benchmark model running on Qlib's `Alpha360` and `Alpha158` dataset with China's A shared-stock & CSI300 data respectively. The values of each metric are the mean and std calculated based on 20 runs with different random seeds.
|
||||
|
||||
@@ -16,6 +17,8 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
|
||||
> NOTE:
|
||||
> The backtest start from 0.8.0 is quite different from previous version. Please check out the changelog for the difference.
|
||||
|
||||
> NOTE:
|
||||
> We have very limited resources to implement and finetune the models. We tried our best effort to fairly compare these models. But some models may have greater potential than what it looks like in the table below. Your contribution is highly welcomed to explore their potential.
|
||||
|
||||
## Alpha158 dataset
|
||||
|
||||
@@ -62,7 +65,33 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
|
||||
| GATs (Petar Velickovic, et al.) | Alpha360 | 0.0476±0.00 | 0.3508±0.02 | 0.0598±0.00 | 0.4604±0.01 | 0.0824±0.02 | 1.1079±0.26 | -0.0894±0.03 |
|
||||
| TCTS(Xueqing Wu, et al.) | Alpha360 | 0.0508±0.00 | 0.3931±0.04 | 0.0599±0.00 | 0.4756±0.03 | 0.0893±0.03 | 1.2256±0.36 | -0.0857±0.02 |
|
||||
| TRA(Hengxu Lin, et al.) | Alpha360 | 0.0485±0.00 | 0.3787±0.03 | 0.0587±0.00 | 0.4756±0.03 | 0.0920±0.03 | 1.2789±0.42 | -0.0834±0.02 |
|
||||
| IGMTF(Wentao Xu, et al.) | Alpha360 | 0.0480±0.00 | 0.3589±0.02 | 0.0606±0.00 | 0.4773±0.01 | 0.0946±0.02 | 1.3509±0.25 | -0.0716±0.02 |
|
||||
| HIST(Wentao Xu, et al.) | Alpha360 | 0.0522±0.00 | 0.3530±0.01 | 0.0667±0.00 | 0.4576±0.01 | 0.0987±0.02 | 1.3726±0.27 | -0.0681±0.01 |
|
||||
|
||||
|
||||
- The selected 20 features are based on the feature importance of a lightgbm-based model.
|
||||
- The base model of DoubleEnsemble is LGBM.
|
||||
- The base model of TCTS is GRU.
|
||||
- About the datasets
|
||||
- Alpha158 is a tabular dataset. There are less spatial relationships between different features. Each feature are carefully desgined by human (a.k.a feature engineering)
|
||||
- Alpha360 contains raw price and volue data without much feature engineering. There are strong strong spatial relationships between the features in the time dimension.
|
||||
- The metrics can be categorized into two
|
||||
- Signal-based evaluation: IC, ICIR, Rank IC, Rank ICIR
|
||||
- Portfolio-based metrics: Annualized Return, Information Ratio, Max Drawdown
|
||||
|
||||
|
||||
# Contributing
|
||||
|
||||
Your contributions to new models are highly welcome!
|
||||
|
||||
If you want to contribute your new models, you can follow the steps below.
|
||||
1. Create a folder for your model
|
||||
2. The folder contains following items(you can refer to [this example](https://github.com/microsoft/qlib/tree/main/examples/benchmarks/TCTS)).
|
||||
- `requirements.txt`: required dependencies.
|
||||
- `README.md`: a brief introduction to your models
|
||||
- `workflow_config_<model name>_<dataset>.yaml`: a configuration which can read by `qrun`. You are encouraged to run your model in all datasets.
|
||||
3. You can integrate your model as a module [in this folder](https://github.com/microsoft/qlib/tree/main/qlib/contrib/model).
|
||||
4. Please updated your results in the benchmark tables, e.g. [Alpha360](#alpha158-dataset), [Alpha158](#alpha158-dataset)(the values of each metric are the mean and std calculated based on 20 runs with different random seeds, if you don't have enough computational resource, you can ask for help in the PR).
|
||||
5. Update the info in the index page in the [news list](https://github.com/microsoft/qlib#newspaper-whats-new----sparkling_heart) and [model list](https://github.com/microsoft/qlib#quant-model-paper-zoo).
|
||||
|
||||
Finally, you can send PR for review. ([here is an example](https://github.com/microsoft/qlib/pull/1040))
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
pandas==1.1.2
|
||||
numpy==1.17.4
|
||||
numpy==1.21.0
|
||||
scikit_learn==0.23.2
|
||||
torch==1.7.0
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
numpy==1.17.4
|
||||
numpy==1.21.0
|
||||
pandas==1.1.2
|
||||
scikit_learn==0.23.2
|
||||
torch==1.7.0
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
pandas==1.1.2
|
||||
numpy==1.17.4
|
||||
numpy==1.21.0
|
||||
scikit_learn==0.23.2
|
||||
torch==1.7.0
|
||||
@@ -1,5 +1,5 @@
|
||||
pandas==1.1.2
|
||||
numpy==1.17.4
|
||||
numpy==1.21.0
|
||||
scikit_learn==0.23.2
|
||||
torch==1.7.0
|
||||
seaborn
|
||||
|
||||
@@ -6,8 +6,7 @@ import torch
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from qlib.utils import init_instance_by_config
|
||||
from qlib.data.dataset import DatasetH, DataHandler
|
||||
from qlib.data.dataset import DatasetH
|
||||
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
@@ -95,7 +94,7 @@ class MTSDatasetH(DatasetH):
|
||||
shuffle=True,
|
||||
pin_memory=False,
|
||||
drop_last=False,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
assert horizon > 0, "please specify `horizon` to avoid data leakage"
|
||||
@@ -150,8 +149,15 @@ class MTSDatasetH(DatasetH):
|
||||
|
||||
def _prepare_seg(self, slc, **kwargs):
|
||||
fn = _get_date_parse_fn(self._index[0][1])
|
||||
start_date = fn(slc.start)
|
||||
end_date = fn(slc.stop)
|
||||
|
||||
if isinstance(slc, slice):
|
||||
start, stop = slc.start, slc.stop
|
||||
elif isinstance(slc, (list, tuple)):
|
||||
start, stop = slc
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
start_date = fn(start)
|
||||
end_date = fn(stop)
|
||||
obj = copy.copy(self) # shallow copy
|
||||
# NOTE: Seriable will disable copy `self._data` so we manually assign them here
|
||||
obj._data = self._data
|
||||
|
||||
@@ -130,7 +130,7 @@ class TRAModel(Model):
|
||||
|
||||
if prob is not None:
|
||||
P = sinkhorn(-L, epsilon=0.01) # sample assignment matrix
|
||||
lamb = self.lamb * (self.rho ** self.global_step)
|
||||
lamb = self.lamb * (self.rho**self.global_step)
|
||||
reg = prob.log().mul(P).sum(dim=-1).mean()
|
||||
loss = loss - lamb * reg
|
||||
|
||||
@@ -547,7 +547,7 @@ def evaluate(pred):
|
||||
score = pred.score
|
||||
label = pred.label
|
||||
diff = score - label
|
||||
MSE = (diff ** 2).mean()
|
||||
MSE = (diff**2).mean()
|
||||
MAE = (diff.abs()).mean()
|
||||
IC = score.corr(label)
|
||||
return {"MSE": MSE, "MAE": MAE, "IC": IC}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
pandas==1.1.2
|
||||
numpy==1.17.4
|
||||
numpy==1.21.0
|
||||
scikit_learn==0.23.2
|
||||
torch==1.7.0
|
||||
@@ -1,3 +1,3 @@
|
||||
numpy==1.17.4
|
||||
numpy==1.21.0
|
||||
pandas==1.1.2
|
||||
torch==1.2.0
|
||||
@@ -1,3 +1,3 @@
|
||||
numpy==1.17.4
|
||||
numpy==1.21.0
|
||||
pandas==1.1.2
|
||||
xgboost==1.2.1
|
||||
@@ -1,16 +1,19 @@
|
||||
# Introduction
|
||||
This is the implementation of `DDG-DA` based on `Meta Controller` component provided by `Qlib`.
|
||||
|
||||
## Background
|
||||
Please refer to the paper for more details: *DDG-DA: Data Distribution Generation for Predictable Concept Drift Adaptation* [[arXiv](https://arxiv.org/abs/2201.04038)]
|
||||
|
||||
|
||||
# Background
|
||||
In many real-world scenarios, we often deal with streaming data that is sequentially collected over time. Due to the non-stationary nature of the environment, the streaming data distribution may change in unpredictable ways, which is known as concept drift. To handle concept drift, previous methods first detect when/where the concept drift happens and then adapt models to fit the distribution of the latest data. However, there are still many cases that some underlying factors of environment evolution are predictable, making it possible to model the future concept drift trend of the streaming data, while such cases are not fully explored in previous work.
|
||||
|
||||
Therefore, we propose a novel method `DDG-DA`, that can effectively forecast the evolution of data distribution and improve the performance of models. Specifically, we first train a predictor to estimate the future data distribution, then leverage it to generate training samples, and finally train models on the generated data.
|
||||
|
||||
## Dataset
|
||||
# Dataset
|
||||
The data in the paper are private. So we conduct experiments on Qlib's public dataset.
|
||||
Though the dataset is different, the conclusion remains the same. By applying `DDG-DA`, users can see rising trends at the test phase both in the proxy models' ICs and the performances of the forecasting models.
|
||||
|
||||
## Run the Code
|
||||
# Run the Code
|
||||
Users can try `DDG-DA` by running the following command:
|
||||
```bash
|
||||
python workflow.py run_all
|
||||
@@ -21,7 +24,10 @@ The default forecasting models are `Linear`. Users can choose other forecasting
|
||||
python workflow.py --forecast_model="gbdt" run_all
|
||||
```
|
||||
|
||||
|
||||
## Results
|
||||
|
||||
# Results
|
||||
The results of related methods in Qlib's public dataset can be found [here](../)
|
||||
|
||||
# Requirements
|
||||
Here is the minimal hardware requirements to run the ``workflow.py`` of DDG-DA.
|
||||
* Memory: 45G
|
||||
* Disk: 4G
|
||||
|
||||
@@ -9,13 +9,10 @@ from qlib.data.dataset.handler import DataHandlerLP
|
||||
import pandas as pd
|
||||
import fire
|
||||
import sys
|
||||
from tqdm.auto import tqdm
|
||||
import yaml
|
||||
import pickle
|
||||
from qlib import auto_init
|
||||
from qlib.model.trainer import TrainerR, task_train
|
||||
from qlib.model.trainer import TrainerR
|
||||
from qlib.utils import init_instance_by_config
|
||||
from qlib.workflow.task.gen import RollingGen, task_generator
|
||||
from qlib.workflow import R
|
||||
from qlib.tests.data import GetData
|
||||
|
||||
@@ -47,9 +44,10 @@ class DDGDA:
|
||||
rb = RollingBenchmark(model_type="gbdt")
|
||||
task = rb.basic_task()
|
||||
|
||||
model = init_instance_by_config(task["model"])
|
||||
dataset = init_instance_by_config(task["dataset"])
|
||||
model.fit(dataset)
|
||||
with R.start(experiment_name="feature_importance"):
|
||||
model = init_instance_by_config(task["model"])
|
||||
dataset = init_instance_by_config(task["dataset"])
|
||||
model.fit(dataset)
|
||||
|
||||
fi = model.get_feature_importance()
|
||||
|
||||
@@ -147,6 +145,9 @@ class DDGDA:
|
||||
},
|
||||
# "record": ["qlib.workflow.record_temp.SignalRecord"]
|
||||
}
|
||||
# the proxy_forecast_model_task will be used to create meta tasks.
|
||||
# The test date of first task will be 2011-01-01. Each test segment will be about 20days
|
||||
# The tasks include all training tasks and test tasks.
|
||||
|
||||
# 2) preparing meta dataset
|
||||
kwargs = dict(
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
pandas==1.1.2
|
||||
numpy==1.17.4
|
||||
numpy==1.21.0
|
||||
lightgbm==3.1.0
|
||||
optuna==2.7.0
|
||||
optuna-dashboard==0.4.1
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
This example is about how can simulate the OnlineManager based on rolling tasks.
|
||||
This example is about how can simulate the OnlineManager based on rolling tasks.
|
||||
"""
|
||||
|
||||
from pprint import pprint
|
||||
@@ -15,6 +15,10 @@ from qlib.workflow.online.strategy import RollingStrategy
|
||||
from qlib.workflow.task.gen import RollingGen
|
||||
from qlib.workflow.task.manage import TaskManager
|
||||
from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG_ONLINE, CSI100_RECORD_XGBOOST_TASK_CONFIG_ONLINE
|
||||
import pandas as pd
|
||||
from qlib.contrib.evaluate import backtest_daily
|
||||
from qlib.contrib.evaluate import risk_analysis
|
||||
from qlib.contrib.strategy import TopkDropoutStrategy
|
||||
|
||||
|
||||
class OnlineSimulationExample:
|
||||
@@ -30,6 +34,7 @@ class OnlineSimulationExample:
|
||||
start_time="2018-09-10",
|
||||
end_time="2018-10-31",
|
||||
tasks=None,
|
||||
trainer="TrainerR",
|
||||
):
|
||||
"""
|
||||
Init OnlineManagerExample.
|
||||
@@ -60,7 +65,13 @@ class OnlineSimulationExample:
|
||||
self.rolling_gen = RollingGen(
|
||||
step=rolling_step, rtype=RollingGen.ROLL_SD, ds_extra_mod_func=None
|
||||
) # The rolling tasks generator, ds_extra_mod_func is None because we just need to simulate to 2018-10-31 and needn't change the handler end time.
|
||||
self.trainer = TrainerRM(self.exp_name, self.task_pool) # Also can be TrainerR, TrainerRM, DelayTrainerR
|
||||
if trainer == "TrainerRM":
|
||||
self.trainer = TrainerRM(self.exp_name, self.task_pool)
|
||||
elif trainer == "TrainerR":
|
||||
self.trainer = TrainerR(self.exp_name)
|
||||
else:
|
||||
# TODO: support all the trainers: TrainerR, TrainerRM, DelayTrainerR
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
self.rolling_online_manager = OnlineManager(
|
||||
RollingStrategy(exp_name, task_template=tasks, rolling_gen=self.rolling_gen),
|
||||
trainer=self.trainer,
|
||||
@@ -70,7 +81,8 @@ class OnlineSimulationExample:
|
||||
|
||||
# Reset all things to the first status, be careful to save important data
|
||||
def reset(self):
|
||||
TaskManager(self.task_pool).remove()
|
||||
if isinstance(self.trainer, TrainerRM):
|
||||
TaskManager(self.task_pool).remove()
|
||||
exp = R.get_exp(experiment_name=self.exp_name)
|
||||
for rid in exp.list_recorders():
|
||||
exp.delete_recorder(rid)
|
||||
@@ -84,7 +96,30 @@ class OnlineSimulationExample:
|
||||
print("========== collect results ==========")
|
||||
print(self.rolling_online_manager.get_collector()())
|
||||
print("========== signals ==========")
|
||||
print(self.rolling_online_manager.get_signals())
|
||||
signals = self.rolling_online_manager.get_signals()
|
||||
print(signals)
|
||||
# Backtesting
|
||||
# - the code is based on this example https://qlib.readthedocs.io/en/latest/component/strategy.html
|
||||
CSI300_BENCH = "SH000903"
|
||||
STRATEGY_CONFIG = {
|
||||
"topk": 30,
|
||||
"n_drop": 3,
|
||||
"signal": signals.to_frame("score"),
|
||||
}
|
||||
strategy_obj = TopkDropoutStrategy(**STRATEGY_CONFIG)
|
||||
report_normal, positions_normal = backtest_daily(
|
||||
start_time=signals.index.get_level_values("datetime").min(),
|
||||
end_time=signals.index.get_level_values("datetime").max(),
|
||||
strategy=strategy_obj,
|
||||
)
|
||||
analysis = dict()
|
||||
analysis["excess_return_without_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"])
|
||||
analysis["excess_return_with_cost"] = risk_analysis(
|
||||
report_normal["return"] - report_normal["bench"] - report_normal["cost"]
|
||||
)
|
||||
|
||||
analysis_df = pd.concat(analysis) # type: pd.DataFrame
|
||||
pprint(analysis_df)
|
||||
|
||||
def worker(self):
|
||||
# train tasks by other progress or machines for multiprocessing
|
||||
|
||||
52
examples/orderbook_data/README.md
Normal file
52
examples/orderbook_data/README.md
Normal file
@@ -0,0 +1,52 @@
|
||||
# Introduction
|
||||
|
||||
This example tries to demonstrate how Qlib supports data without fixed shared frequency.
|
||||
|
||||
For example,
|
||||
- Daily prices volume data are fixed-frequency data. The data comes in a fixed frequency (i.e. daily)
|
||||
- Orders are not fixed data and they may come at any time point
|
||||
|
||||
To support such non-fixed-frequency, Qlib implements an Arctic-based backend.
|
||||
Here is an example to import and query data based on this backend.
|
||||
|
||||
# Installation
|
||||
|
||||
Please refer to [the installation docs](https://docs.mongodb.com/manual/installation/) of mongodb.
|
||||
Current version of script with default value tries to connect localhost **via default port without authentication**.
|
||||
|
||||
Run following command to install necessary libraries
|
||||
```
|
||||
pip install pytest coverage
|
||||
pip install arctic # NOTE: pip may fail to resolve the right package dependency !!! Please make sure the dependency are satisfied.
|
||||
```
|
||||
|
||||
# Importing example data
|
||||
|
||||
|
||||
1. (Optional) Please follow the first part of [this section](https://github.com/microsoft/qlib#data-preparation) to **get 1min data** of Qlib.
|
||||
2. Please follow following steps to download example data
|
||||
```bash
|
||||
cd examples/orderbook_data/
|
||||
wget http://fintech.msra.cn/stock_data/downloads/highfreq_orderboook_example_data.tar.bz2
|
||||
tar xf highfreq_orderboook_example_data.tar.bz2
|
||||
```
|
||||
|
||||
3. Please import the example data to your mongo db
|
||||
```bash
|
||||
cd examples/orderbook_data/
|
||||
python create_dataset.py initialize_library # Initialization Libraries
|
||||
python create_dataset.py import_data # Initialization Libraries
|
||||
```
|
||||
|
||||
# Query Examples
|
||||
|
||||
After importing these data, you run `example.py` to create some high-frequency features.
|
||||
```bash
|
||||
cd examples/orderbook_data/
|
||||
pytest -s --disable-warnings example.py # If you want run all examples
|
||||
pytest -s --disable-warnings example.py::TestClass::test_exp_10 # If you want to run specific example
|
||||
```
|
||||
|
||||
|
||||
# Known limitations
|
||||
Expression computing between different frequencies are not supported yet
|
||||
315
examples/orderbook_data/create_dataset.py
Executable file
315
examples/orderbook_data/create_dataset.py
Executable file
@@ -0,0 +1,315 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
"""
|
||||
NOTE:
|
||||
- This scripts is a demo to import example data import Qlib
|
||||
- !!!!!!!!!!!!!!!TODO!!!!!!!!!!!!!!!!!!!:
|
||||
- Its structure is not well designed and very ugly, your contribution is welcome to make importing dataset easier
|
||||
"""
|
||||
from datetime import date, datetime as dt
|
||||
import os
|
||||
from pathlib import Path
|
||||
import random
|
||||
import shutil
|
||||
import time
|
||||
import traceback
|
||||
|
||||
from arctic import Arctic, chunkstore
|
||||
import arctic
|
||||
from arctic import Arctic, CHUNK_STORE
|
||||
from arctic.chunkstore.chunkstore import CHUNK_SIZE
|
||||
import fire
|
||||
from joblib import Parallel, delayed, parallel
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from pandas import DataFrame
|
||||
from pandas.core.indexes.datetimes import date_range
|
||||
from pymongo.mongo_client import MongoClient
|
||||
|
||||
DIRNAME = Path(__file__).absolute().resolve().parent
|
||||
|
||||
# CONFIG
|
||||
N_JOBS = -1 # leaving one kernel free
|
||||
LOG_FILE_PATH = DIRNAME / "log_file"
|
||||
DATA_PATH = DIRNAME / "raw_data"
|
||||
DATABASE_PATH = DIRNAME / "orig_data"
|
||||
DATA_INFO_PATH = DIRNAME / "data_info"
|
||||
DATA_FINISH_INFO_PATH = DIRNAME / "./data_finish_info"
|
||||
DOC_TYPE = ["Tick", "Order", "OrderQueue", "Transaction", "Day", "Minute"]
|
||||
MAX_SIZE = 3000 * 1024 * 1024 * 1024
|
||||
ALL_STOCK_PATH = DATABASE_PATH / "all.txt"
|
||||
ARCTIC_SRV = "127.0.0.1"
|
||||
|
||||
|
||||
def get_library_name(doc_type):
|
||||
if str.lower(doc_type) == str.lower("Tick"):
|
||||
return "ticks"
|
||||
else:
|
||||
return str.lower(doc_type)
|
||||
|
||||
|
||||
def is_stock(exchange_place, code):
|
||||
if exchange_place == "SH" and code[0] != "6":
|
||||
return False
|
||||
if exchange_place == "SZ" and code[0] != "0" and code[:2] != "30":
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def add_one_stock_daily_data(filepath, type, exchange_place, arc, date):
|
||||
"""
|
||||
exchange_place: "SZ" OR "SH"
|
||||
type: "tick", "orderbook", ...
|
||||
filepath: the path of csv
|
||||
arc: arclink created by a process
|
||||
"""
|
||||
code = os.path.split(filepath)[-1].split(".csv")[0]
|
||||
if exchange_place == "SH" and code[0] != "6":
|
||||
return
|
||||
if exchange_place == "SZ" and code[0] != "0" and code[:2] != "30":
|
||||
return
|
||||
|
||||
df = pd.read_csv(filepath, encoding="gbk", dtype={"code": str})
|
||||
code = os.path.split(filepath)[-1].split(".csv")[0]
|
||||
|
||||
def format_time(day, hms):
|
||||
day = str(day)
|
||||
hms = str(hms)
|
||||
if hms[0] == "1": # >=10,
|
||||
return (
|
||||
"-".join([day[0:4], day[4:6], day[6:8]]) + " " + ":".join([hms[:2], hms[2:4], hms[4:6] + "." + hms[6:]])
|
||||
)
|
||||
else:
|
||||
return (
|
||||
"-".join([day[0:4], day[4:6], day[6:8]]) + " " + ":".join([hms[:1], hms[1:3], hms[3:5] + "." + hms[5:]])
|
||||
)
|
||||
|
||||
## Discard the entire row if wrong data timestamp encoutered.
|
||||
timestamp = list(zip(list(df["date"]), list(df["time"])))
|
||||
error_index_list = []
|
||||
for index, t in enumerate(timestamp):
|
||||
try:
|
||||
pd.Timestamp(format_time(t[0], t[1]))
|
||||
except Exception:
|
||||
error_index_list.append(index) ## The row number of the error line
|
||||
|
||||
# to-do: writting to logs
|
||||
|
||||
if len(error_index_list) > 0:
|
||||
print("error: {}, {}".format(filepath, len(error_index_list)))
|
||||
|
||||
df = df.drop(error_index_list)
|
||||
timestamp = list(zip(list(df["date"]), list(df["time"]))) ## The cleaned timestamp
|
||||
# generate timestamp
|
||||
pd_timestamp = pd.DatetimeIndex(
|
||||
[pd.Timestamp(format_time(timestamp[i][0], timestamp[i][1])) for i in range(len(df["date"]))]
|
||||
)
|
||||
df = df.drop(columns=["date", "time", "name", "code", "wind_code"])
|
||||
# df = pd.DataFrame(data=df.to_dict("list"), index=pd_timestamp)
|
||||
df["date"] = pd.to_datetime(pd_timestamp)
|
||||
df.set_index("date", inplace=True)
|
||||
|
||||
if str.lower(type) == "orderqueue":
|
||||
## extract ab1~ab50
|
||||
df["ab"] = [
|
||||
",".join([str(int(row["ab" + str(i + 1)])) for i in range(0, row["ab_items"])])
|
||||
for timestamp, row in df.iterrows()
|
||||
]
|
||||
df = df.drop(columns=["ab" + str(i) for i in range(1, 51)])
|
||||
|
||||
type = get_library_name(type)
|
||||
# arc.initialize_library(type, lib_type=CHUNK_STORE)
|
||||
lib = arc[type]
|
||||
|
||||
symbol = "".join([exchange_place, code])
|
||||
if symbol in lib.list_symbols():
|
||||
print("update {0}, date={1}".format(symbol, date))
|
||||
if df.empty == True:
|
||||
return error_index_list
|
||||
lib.update(symbol, df, chunk_size="D")
|
||||
else:
|
||||
print("write {0}, date={1}".format(symbol, date))
|
||||
lib.write(symbol, df, chunk_size="D")
|
||||
return error_index_list
|
||||
|
||||
|
||||
def add_one_stock_daily_data_wrapper(filepath, type, exchange_place, index, date):
|
||||
pid = os.getpid()
|
||||
code = os.path.split(filepath)[-1].split(".csv")[0]
|
||||
arc = Arctic(ARCTIC_SRV)
|
||||
try:
|
||||
if index % 100 == 0:
|
||||
print("index = {}, filepath = {}".format(index, filepath))
|
||||
error_index_list = add_one_stock_daily_data(filepath, type, exchange_place, arc, date)
|
||||
if error_index_list is not None and len(error_index_list) > 0:
|
||||
f = open(os.path.join(LOG_FILE_PATH, "temp_timestamp_error_{0}_{1}_{2}.txt".format(pid, date, type)), "a+")
|
||||
f.write("{}, {}, {}\n".format(filepath, error_index_list, exchange_place + "_" + code))
|
||||
f.close()
|
||||
|
||||
except Exception as e:
|
||||
info = traceback.format_exc()
|
||||
print("error:" + str(e))
|
||||
f = open(os.path.join(LOG_FILE_PATH, "temp_fail_{0}_{1}_{2}.txt".format(pid, date, type)), "a+")
|
||||
f.write("fail:" + str(filepath) + "\n" + str(e) + "\n" + str(info) + "\n")
|
||||
f.close()
|
||||
|
||||
finally:
|
||||
arc.reset()
|
||||
|
||||
|
||||
def add_data(tick_date, doc_type, stock_name_dict):
|
||||
pid = os.getpid()
|
||||
|
||||
if doc_type not in DOC_TYPE:
|
||||
print("doc_type not in {}".format(DOC_TYPE))
|
||||
return
|
||||
try:
|
||||
begin_time = time.time()
|
||||
os.system(f"cp {DATABASE_PATH}/{tick_date + '_{}.tar.gz'.format(doc_type)} {DATA_PATH}/")
|
||||
|
||||
os.system(
|
||||
f"tar -xvzf {DATA_PATH}/{tick_date + '_{}.tar.gz'.format(doc_type)} -C {DATA_PATH}/ {tick_date + '_' + doc_type}/SH"
|
||||
)
|
||||
os.system(
|
||||
f"tar -xvzf {DATA_PATH}/{tick_date + '_{}.tar.gz'.format(doc_type)} -C {DATA_PATH}/ {tick_date + '_' + doc_type}/SZ"
|
||||
)
|
||||
os.system(f"chmod 777 {DATA_PATH}")
|
||||
os.system(f"chmod 777 {DATA_PATH}/{tick_date + '_' + doc_type}")
|
||||
os.system(f"chmod 777 {DATA_PATH}/{tick_date + '_' + doc_type}/SH")
|
||||
os.system(f"chmod 777 {DATA_PATH}/{tick_date + '_' + doc_type}/SZ")
|
||||
os.system(f"chmod 777 {DATA_PATH}/{tick_date + '_' + doc_type}/SH/{tick_date}")
|
||||
os.system(f"chmod 777 {DATA_PATH}/{tick_date + '_' + doc_type}/SZ/{tick_date}")
|
||||
|
||||
print("tick_date={}".format(tick_date))
|
||||
|
||||
temp_data_path_sh = os.path.join(DATA_PATH, tick_date + "_" + doc_type, "SH", tick_date)
|
||||
temp_data_path_sz = os.path.join(DATA_PATH, tick_date + "_" + doc_type, "SZ", tick_date)
|
||||
is_files_exist = {"sh": os.path.exists(temp_data_path_sh), "sz": os.path.exists(temp_data_path_sz)}
|
||||
|
||||
sz_files = (
|
||||
(
|
||||
set([i.split(".csv")[0] for i in os.listdir(temp_data_path_sz) if i[:2] == "30" or i[0] == "0"])
|
||||
& set(stock_name_dict["SZ"])
|
||||
)
|
||||
if is_files_exist["sz"]
|
||||
else set()
|
||||
)
|
||||
sz_file_nums = len(sz_files) if is_files_exist["sz"] else 0
|
||||
sh_files = (
|
||||
(
|
||||
set([i.split(".csv")[0] for i in os.listdir(temp_data_path_sh) if i[0] == "6"])
|
||||
& set(stock_name_dict["SH"])
|
||||
)
|
||||
if is_files_exist["sh"]
|
||||
else set()
|
||||
)
|
||||
sh_file_nums = len(sh_files) if is_files_exist["sh"] else 0
|
||||
print("sz_file_nums:{}, sh_file_nums:{}".format(sz_file_nums, sh_file_nums))
|
||||
|
||||
f = (DATA_INFO_PATH / "data_info_log_{}_{}".format(doc_type, tick_date)).open("w+")
|
||||
f.write("sz:{}, sh:{}, date:{}:".format(sz_file_nums, sh_file_nums, tick_date) + "\n")
|
||||
f.close()
|
||||
|
||||
if sh_file_nums > 0:
|
||||
# write is not thread-safe, update may be thread-safe
|
||||
Parallel(n_jobs=N_JOBS)(
|
||||
delayed(add_one_stock_daily_data_wrapper)(
|
||||
os.path.join(temp_data_path_sh, name + ".csv"), doc_type, "SH", index, tick_date
|
||||
)
|
||||
for index, name in enumerate(list(sh_files))
|
||||
)
|
||||
if sz_file_nums > 0:
|
||||
# write is not thread-safe, update may be thread-safe
|
||||
Parallel(n_jobs=N_JOBS)(
|
||||
delayed(add_one_stock_daily_data_wrapper)(
|
||||
os.path.join(temp_data_path_sz, name + ".csv"), doc_type, "SZ", index, tick_date
|
||||
)
|
||||
for index, name in enumerate(list(sz_files))
|
||||
)
|
||||
|
||||
os.system(f"rm -f {DATA_PATH}/{tick_date + '_{}.tar.gz'.format(doc_type)}")
|
||||
os.system(f"rm -rf {DATA_PATH}/{tick_date + '_' + doc_type}")
|
||||
total_time = time.time() - begin_time
|
||||
f = (DATA_FINISH_INFO_PATH / "data_info_finish_log_{}_{}".format(doc_type, tick_date)).open("w+")
|
||||
f.write("finish: date:{}, consume_time:{}, end_time: {}".format(tick_date, total_time, time.time()) + "\n")
|
||||
f.close()
|
||||
|
||||
except Exception as e:
|
||||
info = traceback.format_exc()
|
||||
print("date error:" + str(e))
|
||||
f = open(os.path.join(LOG_FILE_PATH, "temp_fail_{0}_{1}_{2}.txt".format(pid, tick_date, doc_type)), "a+")
|
||||
f.write("fail:" + str(tick_date) + "\n" + str(e) + "\n" + str(info) + "\n")
|
||||
f.close()
|
||||
|
||||
|
||||
class DSCreator:
|
||||
"""Dataset creator"""
|
||||
|
||||
def clear(self):
|
||||
client = MongoClient(ARCTIC_SRV)
|
||||
client.drop_database("arctic")
|
||||
|
||||
def initialize_library(self):
|
||||
arc = Arctic(ARCTIC_SRV)
|
||||
for doc_type in DOC_TYPE:
|
||||
arc.initialize_library(get_library_name(doc_type), lib_type=CHUNK_STORE)
|
||||
|
||||
def _get_empty_folder(self, fp: Path):
|
||||
fp = Path(fp)
|
||||
if fp.exists():
|
||||
shutil.rmtree(fp)
|
||||
fp.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def import_data(self, doc_type_l=["Tick", "Transaction", "Order"]):
|
||||
# clear all the old files
|
||||
for fp in LOG_FILE_PATH, DATA_INFO_PATH, DATA_FINISH_INFO_PATH, DATA_PATH:
|
||||
self._get_empty_folder(fp)
|
||||
|
||||
arc = Arctic(ARCTIC_SRV)
|
||||
for doc_type in DOC_TYPE:
|
||||
# arc.initialize_library(get_library_name(doc_type), lib_type=CHUNK_STORE)
|
||||
arc.set_quota(get_library_name(doc_type), MAX_SIZE)
|
||||
arc.reset()
|
||||
|
||||
# doc_type = 'Day'
|
||||
for doc_type in doc_type_l:
|
||||
date_list = list(set([int(path.split("_")[0]) for path in os.listdir(DATABASE_PATH) if doc_type in path]))
|
||||
date_list.sort()
|
||||
date_list = [str(date) for date in date_list]
|
||||
|
||||
f = open(ALL_STOCK_PATH, "r")
|
||||
stock_name_list = [lines.split("\t")[0] for lines in f.readlines()]
|
||||
f.close()
|
||||
stock_name_dict = {
|
||||
"SH": [stock_name[2:] for stock_name in stock_name_list if "SH" in stock_name],
|
||||
"SZ": [stock_name[2:] for stock_name in stock_name_list if "SZ" in stock_name],
|
||||
}
|
||||
|
||||
lib_name = get_library_name(doc_type)
|
||||
a = Arctic(ARCTIC_SRV)
|
||||
# a.initialize_library(lib_name, lib_type=CHUNK_STORE)
|
||||
|
||||
stock_name_exist = a[lib_name].list_symbols()
|
||||
lib = a[lib_name]
|
||||
initialize_count = 0
|
||||
for stock_name in stock_name_list:
|
||||
if stock_name not in stock_name_exist:
|
||||
initialize_count += 1
|
||||
# A placeholder for stocks
|
||||
pdf = pd.DataFrame(index=[pd.Timestamp("1900-01-01")])
|
||||
pdf.index.name = "date" # an col named date is necessary
|
||||
lib.write(stock_name, pdf)
|
||||
print("initialize count: {}".format(initialize_count))
|
||||
print("tasks: {}".format(date_list))
|
||||
a.reset()
|
||||
|
||||
# date_list = [files.split("_")[0] for files in os.listdir("./raw_data_price") if "tar" in files]
|
||||
# print(len(date_list))
|
||||
date_list = ["20201231"] # for test
|
||||
Parallel(n_jobs=min(2, len(date_list)))(
|
||||
delayed(add_data)(date, doc_type, stock_name_dict) for date in date_list
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(DSCreator)
|
||||
312
examples/orderbook_data/example.py
Normal file
312
examples/orderbook_data/example.py
Normal file
@@ -0,0 +1,312 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from arctic.arctic import Arctic
|
||||
import qlib
|
||||
from qlib.data import D
|
||||
import unittest
|
||||
|
||||
|
||||
class TestClass(unittest.TestCase):
|
||||
"""
|
||||
Useful commands
|
||||
- run all tests: pytest examples/orderbook_data/example.py
|
||||
- run a single test: pytest -s --pdb --disable-warnings examples/orderbook_data/example.py::TestClass::test_basic01
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
"""
|
||||
Configure for arctic
|
||||
"""
|
||||
provider_uri = "~/.qlib/qlib_data/yahoo_cn_1min"
|
||||
qlib.init(
|
||||
provider_uri=provider_uri,
|
||||
mem_cache_size_limit=1024**3 * 2,
|
||||
mem_cache_type="sizeof",
|
||||
kernels=1,
|
||||
expression_provider={"class": "LocalExpressionProvider", "kwargs": {"time2idx": False}},
|
||||
feature_provider={
|
||||
"class": "ArcticFeatureProvider",
|
||||
"module_path": "qlib.contrib.data.data",
|
||||
"kwargs": {"uri": "127.0.0.1"},
|
||||
},
|
||||
dataset_provider={
|
||||
"class": "LocalDatasetProvider",
|
||||
"kwargs": {
|
||||
"align_time": False, # Order book is not fixed, so it can't be align to a shared fixed frequency calendar
|
||||
},
|
||||
},
|
||||
)
|
||||
# self.stocks_list = ["SH600519"]
|
||||
self.stocks_list = ["SZ000725"]
|
||||
|
||||
def test_basic(self):
|
||||
# NOTE: this data contains a lot of zeros in $askX and $bidX
|
||||
df = D.features(
|
||||
self.stocks_list,
|
||||
fields=["$ask1", "$ask2", "$bid1", "$bid2"],
|
||||
freq="ticks",
|
||||
start_time="20201230",
|
||||
end_time="20210101",
|
||||
)
|
||||
print(df)
|
||||
|
||||
def test_basic_without_time(self):
|
||||
df = D.features(self.stocks_list, fields=["$ask1"], freq="ticks")
|
||||
print(df)
|
||||
|
||||
def test_basic01(self):
|
||||
df = D.features(
|
||||
self.stocks_list,
|
||||
fields=["TResample($ask1, '1min', 'last')"],
|
||||
freq="ticks",
|
||||
start_time="20201230",
|
||||
end_time="20210101",
|
||||
)
|
||||
print(df)
|
||||
|
||||
def test_basic02(self):
|
||||
df = D.features(
|
||||
self.stocks_list,
|
||||
fields=["$function_code"],
|
||||
freq="transaction",
|
||||
start_time="20201230",
|
||||
end_time="20210101",
|
||||
)
|
||||
print(df)
|
||||
|
||||
def test_basic03(self):
|
||||
df = D.features(
|
||||
self.stocks_list,
|
||||
fields=["$function_code"],
|
||||
freq="order",
|
||||
start_time="20201230",
|
||||
end_time="20210101",
|
||||
)
|
||||
print(df)
|
||||
|
||||
# Here are some popular expressions for high-frequency
|
||||
# 1) some shared expression
|
||||
expr_sum_buy_ask_1 = "(TResample($ask1, '1min', 'last') + TResample($bid1, '1min', 'last'))"
|
||||
total_volume = (
|
||||
"TResample("
|
||||
+ "+".join([f"${name}{i}" for i in range(1, 11) for name in ["asize", "bsize"]])
|
||||
+ ", '1min', 'sum')"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def total_func(name, method):
|
||||
return "TResample(" + "+".join([f"${name}{i}" for i in range(1, 11)]) + ",'1min', '{}')".format(method)
|
||||
|
||||
def test_exp_01(self):
|
||||
exprs = []
|
||||
names = []
|
||||
for name in ["asize", "bsize"]:
|
||||
for i in range(1, 11):
|
||||
exprs.append(f"TResample(${name}{i}, '1min', 'mean') / ({self.total_volume})")
|
||||
names.append(f"v_{name}_{i}")
|
||||
df = D.features(self.stocks_list, fields=exprs, freq="ticks")
|
||||
df.columns = names
|
||||
print(df)
|
||||
|
||||
# 2) some often used papers;
|
||||
def test_exp_02(self):
|
||||
spread_func = (
|
||||
lambda index: f"2 * TResample($ask{index} - $bid{index}, '1min', 'last') / {self.expr_sum_buy_ask_1}"
|
||||
)
|
||||
mid_func = (
|
||||
lambda index: f"2 * TResample(($ask{index} + $bid{index})/2, '1min', 'last') / {self.expr_sum_buy_ask_1}"
|
||||
)
|
||||
|
||||
exprs = []
|
||||
names = []
|
||||
for i in range(1, 11):
|
||||
exprs.extend([spread_func(i), mid_func(i)])
|
||||
names.extend([f"p_spread_{i}", f"p_mid_{i}"])
|
||||
df = D.features(self.stocks_list, fields=exprs, freq="ticks")
|
||||
df.columns = names
|
||||
print(df)
|
||||
|
||||
def test_exp_03(self):
|
||||
expr3_func1 = (
|
||||
lambda name, index_left, index_right: f"2 * TResample(Abs(${name}{index_left} - ${name}{index_right}), '1min', 'last') / {self.expr_sum_buy_ask_1}"
|
||||
)
|
||||
for name in ["ask", "bid"]:
|
||||
for i in range(1, 10):
|
||||
exprs = [expr3_func1(name, i + 1, i)]
|
||||
names = [f"p_diff_{name}_{i}_{i+1}"]
|
||||
exprs.extend([expr3_func1("ask", 10, 1), expr3_func1("bid", 1, 10)])
|
||||
names.extend(["p_diff_ask_10_1", "p_diff_bid_1_10"])
|
||||
df = D.features(self.stocks_list, fields=exprs, freq="ticks")
|
||||
df.columns = names
|
||||
print(df)
|
||||
|
||||
def test_exp_04(self):
|
||||
exprs = []
|
||||
names = []
|
||||
for name in ["asize", "bsize"]:
|
||||
exprs.append(f"(({ self.total_func(name, 'mean')}) / 10) / {self.total_volume}")
|
||||
names.append(f"v_avg_{name}")
|
||||
|
||||
df = D.features(self.stocks_list, fields=exprs, freq="ticks")
|
||||
df.columns = names
|
||||
print(df)
|
||||
|
||||
def test_exp_05(self):
|
||||
exprs = [
|
||||
f"2 * Sub({ self.total_func('ask', 'last')}, {self.total_func('bid', 'last')})/{self.expr_sum_buy_ask_1}",
|
||||
f"Sub({ self.total_func('asize', 'mean')}, {self.total_func('bsize', 'mean')})/{self.total_volume}",
|
||||
]
|
||||
names = ["p_accspread", "v_accspread"]
|
||||
|
||||
df = D.features(self.stocks_list, fields=exprs, freq="ticks")
|
||||
df.columns = names
|
||||
print(df)
|
||||
|
||||
# (p|v)_diff_(ask|bid|asize|bsize)_(time_interval)
|
||||
def test_exp_06(self):
|
||||
t = 3
|
||||
expr6_price_func = (
|
||||
lambda name, index, method: f'2 * (TResample(${name}{index}, "{t}s", "{method}") - Ref(TResample(${name}{index}, "{t}s", "{method}"), 1)) / {t}'
|
||||
)
|
||||
exprs = []
|
||||
names = []
|
||||
for i in range(1, 11):
|
||||
for name in ["bid", "ask"]:
|
||||
exprs.append(
|
||||
f"TResample({expr6_price_func(name, i, 'last')}, '1min', 'mean') / {self.expr_sum_buy_ask_1}"
|
||||
)
|
||||
names.append(f"p_diff_{name}{i}_{t}s")
|
||||
|
||||
for i in range(1, 11):
|
||||
for name in ["asize", "bsize"]:
|
||||
exprs.append(f"TResample({expr6_price_func(name, i, 'mean')}, '1min', 'mean') / {self.total_volume}")
|
||||
names.append(f"v_diff_{name}{i}_{t}s")
|
||||
|
||||
df = D.features(self.stocks_list, fields=exprs, freq="ticks")
|
||||
df.columns = names
|
||||
print(df)
|
||||
|
||||
# TODOs:
|
||||
# Following expressions may be implemented in the future
|
||||
# expr7_2 = lambda funccode, bsflag, time_interval: \
|
||||
# "TResample(TRolling(TEq(@transaction.function_code, {}) & TEq(@transaction.bs_flag ,{}), '{}s', 'sum') / \
|
||||
# TRolling(@transaction.function_code, '{}s', 'count') , '1min', 'mean')".format(ord(funccode), bsflag,time_interval,time_interval)
|
||||
# create_dataset(7, "SH600000", [expr7_2("C")] + [expr7(funccode, ordercode) for funccode in ['B','S'] for ordercode in ['0','1']])
|
||||
# create_dataset(7, ["SH600000"], [expr7_2("C", 48)] )
|
||||
|
||||
@staticmethod
|
||||
def expr7_init(funccode, ordercode, time_interval):
|
||||
# NOTE: based on on order frequency (i.e. freq="order")
|
||||
return f"Rolling(Eq($function_code, {ord(funccode)}) & Eq($order_kind ,{ord(ordercode)}), '{time_interval}s', 'sum') / Rolling($function_code, '{time_interval}s', 'count')"
|
||||
|
||||
# (la|lb|ma|mb|ca|cb)_intensity_(time_interval)
|
||||
def test_exp_07_1(self):
|
||||
# NOTE: based on transaction frequency (i.e. freq="transaction")
|
||||
expr7_3 = (
|
||||
lambda funccode, code, time_interval: f"TResample(Rolling(Eq($function_code, {ord(funccode)}) & {code}($ask_order, $bid_order) , '{time_interval}s', 'sum') / Rolling($function_code, '{time_interval}s', 'count') , '1min', 'mean')"
|
||||
)
|
||||
|
||||
exprs = [expr7_3("C", "Gt", "3"), expr7_3("C", "Lt", "3")]
|
||||
names = ["ca_intensity_3s", "cb_intensity_3s"]
|
||||
|
||||
df = D.features(self.stocks_list, fields=exprs, freq="transaction")
|
||||
df.columns = names
|
||||
print(df)
|
||||
|
||||
trans_dict = {"B": "a", "S": "b", "0": "l", "1": "m"}
|
||||
|
||||
def test_exp_07_2(self):
|
||||
# NOTE: based on on order frequency
|
||||
expr7 = (
|
||||
lambda funccode, ordercode, time_interval: f"TResample({self.expr7_init(funccode, ordercode, time_interval)}, '1min', 'mean')"
|
||||
)
|
||||
|
||||
exprs = []
|
||||
names = []
|
||||
for funccode in ["B", "S"]:
|
||||
for ordercode in ["0", "1"]:
|
||||
exprs.append(expr7(funccode, ordercode, "3"))
|
||||
names.append(self.trans_dict[ordercode] + self.trans_dict[funccode] + "_intensity_3s")
|
||||
df = D.features(self.stocks_list, fields=exprs, freq="transaction")
|
||||
df.columns = names
|
||||
print(df)
|
||||
|
||||
@staticmethod
|
||||
def expr7_3_init(funccode, code, time_interval):
|
||||
# NOTE: It depends on transaction frequency
|
||||
return f"Rolling(Eq($function_code, {ord(funccode)}) & {code}($ask_order, $bid_order) , '{time_interval}s', 'sum') / Rolling($function_code, '{time_interval}s', 'count')"
|
||||
|
||||
# (la|lb|ma|mb|ca|cb)_relative_intensity_(time_interval_small)_(time_interval_big)
|
||||
def test_exp_08_1(self):
|
||||
expr8_1 = (
|
||||
lambda funccode, ordercode, time_interval_short, time_interval_long: f"TResample(Gt({self.expr7_init(funccode, ordercode, time_interval_short)},{self.expr7_init(funccode, ordercode, time_interval_long)}), '1min', 'mean')"
|
||||
)
|
||||
|
||||
exprs = []
|
||||
names = []
|
||||
for funccode in ["B", "S"]:
|
||||
for ordercode in ["0", "1"]:
|
||||
exprs.append(expr8_1(funccode, ordercode, "10", "900"))
|
||||
names.append(self.trans_dict[ordercode] + self.trans_dict[funccode] + "_relative_intensity_10s_900s")
|
||||
|
||||
df = D.features(self.stocks_list, fields=exprs, freq="order")
|
||||
df.columns = names
|
||||
print(df)
|
||||
|
||||
def test_exp_08_2(self):
|
||||
# NOTE: It depends on transaction frequency
|
||||
expr8_2 = (
|
||||
lambda funccode, ordercode, time_interval_short, time_interval_long: f"TResample(Gt({self.expr7_3_init(funccode, ordercode, time_interval_short)},{self.expr7_3_init(funccode, ordercode, time_interval_long)}), '1min', 'mean')"
|
||||
)
|
||||
|
||||
exprs = [expr8_2("C", "Gt", "10", "900"), expr8_2("C", "Lt", "10", "900")]
|
||||
names = ["ca_relative_intensity_10s_900s", "cb_relative_intensity_10s_900s"]
|
||||
|
||||
df = D.features(self.stocks_list, fields=exprs, freq="transaction")
|
||||
df.columns = names
|
||||
print(df)
|
||||
|
||||
## v9(la|lb|ma|mb|ca|cb)_diff_intensity_(time_interval1)_(time_interval2)
|
||||
# 1) calculating the original data
|
||||
# 2) Resample data to 3s and calculate the changing rate
|
||||
# 3) Resample data to 1min
|
||||
|
||||
def test_exp_09_trans(self):
|
||||
exprs = [
|
||||
f'TResample(Div(Sub(TResample({self.expr7_3_init("C", "Gt", "3")}, "3s", "last"), Ref(TResample({self.expr7_3_init("C", "Gt", "3")}, "3s","last"), 1)), 3), "1min", "mean")',
|
||||
f'TResample(Div(Sub(TResample({self.expr7_3_init("C", "Lt", "3")}, "3s", "last"), Ref(TResample({self.expr7_3_init("C", "Lt", "3")}, "3s","last"), 1)), 3), "1min", "mean")',
|
||||
]
|
||||
names = ["ca_diff_intensity_3s_3s", "cb_diff_intensity_3s_3s"]
|
||||
df = D.features(self.stocks_list, fields=exprs, freq="transaction")
|
||||
df.columns = names
|
||||
print(df)
|
||||
|
||||
def test_exp_09_order(self):
|
||||
exprs = []
|
||||
names = []
|
||||
for funccode in ["B", "S"]:
|
||||
for ordercode in ["0", "1"]:
|
||||
exprs.append(
|
||||
f'TResample(Div(Sub(TResample({self.expr7_init(funccode, ordercode, "3")}, "3s", "last"), Ref(TResample({self.expr7_init(funccode, ordercode, "3")},"3s", "last"), 1)), 3) ,"1min", "mean")'
|
||||
)
|
||||
names.append(self.trans_dict[ordercode] + self.trans_dict[funccode] + "_diff_intensity_3s_3s")
|
||||
df = D.features(self.stocks_list, fields=exprs, freq="order")
|
||||
df.columns = names
|
||||
print(df)
|
||||
|
||||
def test_exp_10(self):
|
||||
exprs = []
|
||||
names = []
|
||||
for i in [5, 10, 30, 60]:
|
||||
exprs.append(
|
||||
f'TResample(Ref(TResample($ask1 + $bid1, "1s", "ffill"), {-i}) / TResample($ask1 + $bid1, "1s", "ffill") - 1, "1min", "mean" )'
|
||||
)
|
||||
names.append(f"lag_{i}_change_rate" for i in [5, 10, 30, 60])
|
||||
df = D.features(self.stocks_list, fields=exprs, freq="ticks")
|
||||
df.columns = names
|
||||
print(df)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -24,6 +24,7 @@ We use China stock market data for our example.
|
||||
unzip -d ~/.qlib/qlib_data/cn_data csi300_weight.zip
|
||||
rm -f csi300_weight.zip
|
||||
```
|
||||
NOTE: We don't find any public free resource to get the weight in the benchmark. To run the example, we manually create this weight data.
|
||||
|
||||
2. Prepare risk model data:
|
||||
|
||||
|
||||
@@ -186,7 +186,7 @@ def gen_and_save_md_table(metrics, dataset):
|
||||
# read yaml, remove seed kwargs of model, and then save file in the temp_dir
|
||||
def gen_yaml_file_without_seed_kwargs(yaml_path, temp_dir):
|
||||
with open(yaml_path, "r") as fp:
|
||||
config = yaml.load(fp)
|
||||
config = yaml.safe_load(fp)
|
||||
try:
|
||||
del config["task"]["model"]["kwargs"]["seed"]
|
||||
except KeyError:
|
||||
|
||||
1218
examples/tutorial/detailed_workflow.ipynb
Normal file
1218
examples/tutorial/detailed_workflow.ipynb
Normal file
File diff suppressed because it is too large
Load Diff
@@ -256,7 +256,6 @@
|
||||
"recorder = R.get_recorder(recorder_id=ba_rid, experiment_name=\"backtest_analysis\")\n",
|
||||
"print(recorder)\n",
|
||||
"pred_df = recorder.load_object(\"pred.pkl\")\n",
|
||||
"pred_df_dates = pred_df.index.get_level_values(level='datetime')\n",
|
||||
"report_normal_df = recorder.load_object(\"portfolio_analysis/report_normal_1day.pkl\")\n",
|
||||
"positions = recorder.load_object(\"portfolio_analysis/positions_normal_1day.pkl\")\n",
|
||||
"analysis_df = recorder.load_object(\"portfolio_analysis/port_analysis_1day.pkl\")"
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# Licensed under the MIT License.
|
||||
from pathlib import Path
|
||||
|
||||
__version__ = "0.8.1"
|
||||
__version__ = "0.8.5.99"
|
||||
__version__bak = __version__ # This version is backup for QlibConfig.reset_qlib_version
|
||||
import os
|
||||
from typing import Union
|
||||
@@ -31,8 +31,8 @@ def init(default_conf="client", **kwargs):
|
||||
When using the recorder, skip_if_reg can set to True to avoid loss of recorder.
|
||||
|
||||
"""
|
||||
from .config import C
|
||||
from .data.cache import H
|
||||
from .config import C # pylint: disable=C0415
|
||||
from .data.cache import H # pylint: disable=C0415
|
||||
|
||||
# FIXME: this logger ignored the level in config
|
||||
logger = get_module_logger("Initialization", level=logging.INFO)
|
||||
@@ -63,7 +63,7 @@ def init(default_conf="client", **kwargs):
|
||||
else:
|
||||
logger.warning(f"auto_path is False, please make sure {mount_path} is mounted")
|
||||
elif uri_type == C.NFS_URI:
|
||||
_mount_nfs_uri(provider_uri, mount_path, C["auto_mount"])
|
||||
_mount_nfs_uri(provider_uri, C.dpm.get_data_uri(_freq), C["auto_mount"])
|
||||
else:
|
||||
raise NotImplementedError(f"This type of URI is not supported")
|
||||
|
||||
@@ -86,7 +86,7 @@ def _mount_nfs_uri(provider_uri, mount_path, auto_mount: bool = False):
|
||||
mount_command = "sudo mount.nfs %s %s" % (provider_uri, mount_path)
|
||||
# If the provider uri looks like this 172.23.233.89//data/csdesign'
|
||||
# It will be a nfs path. The client provider will be used
|
||||
if not auto_mount:
|
||||
if not auto_mount: # pylint: disable=R1702
|
||||
if not Path(mount_path).exists():
|
||||
raise FileNotFoundError(
|
||||
f"Invalid mount path: {mount_path}! Please mount manually: {mount_command} or Set init parameter `auto_mount=True`"
|
||||
@@ -96,7 +96,7 @@ def _mount_nfs_uri(provider_uri, mount_path, auto_mount: bool = False):
|
||||
sys_type = platform.system()
|
||||
if "win" in sys_type.lower():
|
||||
# system: window
|
||||
exec_result = os.popen("mount -o anon %s %s" % (provider_uri, mount_path + ":"))
|
||||
exec_result = os.popen(f"mount -o anon {provider_uri} {mount_path}")
|
||||
result = exec_result.read()
|
||||
if "85" in result:
|
||||
LOG.warning(f"{provider_uri} on Windows:{mount_path} is already mounted")
|
||||
@@ -140,8 +140,10 @@ def _mount_nfs_uri(provider_uri, mount_path, auto_mount: bool = False):
|
||||
if not _is_mount:
|
||||
try:
|
||||
Path(mount_path).mkdir(parents=True, exist_ok=True)
|
||||
except Exception:
|
||||
raise OSError(f"Failed to create directory {mount_path}, please create {mount_path} manually!")
|
||||
except Exception as e:
|
||||
raise OSError(
|
||||
f"Failed to create directory {mount_path}, please create {mount_path} manually!"
|
||||
) from e
|
||||
|
||||
# check nfs-common
|
||||
command_res = os.popen("dpkg -l | grep nfs-common")
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from __future__ import annotations
|
||||
import copy
|
||||
from typing import List, Tuple, Union, TYPE_CHECKING
|
||||
@@ -171,8 +172,8 @@ def get_strategy_executor(
|
||||
# NOTE:
|
||||
# - for avoiding recursive import
|
||||
# - typing annotations is not reliable
|
||||
from ..strategy.base import BaseStrategy
|
||||
from .executor import BaseExecutor
|
||||
from ..strategy.base import BaseStrategy # pylint: disable=C0415
|
||||
from .executor import BaseExecutor # pylint: disable=C0415
|
||||
|
||||
trade_account = create_account_instance(
|
||||
start_time=start_time, end_time=end_time, benchmark=benchmark, account=account, pos_type=pos_type
|
||||
@@ -323,3 +324,6 @@ def format_decisions(
|
||||
last_dec_idx = i
|
||||
res[1].append((decisions[last_dec_idx], format_decisions(decisions[last_dec_idx + 1 :])))
|
||||
return res
|
||||
|
||||
|
||||
__all__ = ["Order"]
|
||||
|
||||
@@ -2,11 +2,11 @@
|
||||
# Licensed under the MIT License.
|
||||
from __future__ import annotations
|
||||
import copy
|
||||
from typing import Dict, List, Tuple, TYPE_CHECKING
|
||||
from typing import Dict, List, Tuple
|
||||
from qlib.utils import init_instance_by_config
|
||||
import pandas as pd
|
||||
|
||||
from .position import BasePosition, InfPosition, Position
|
||||
from .position import BasePosition
|
||||
from .report import PortfolioMetrics, Indicator
|
||||
from .decision import BaseTradeDecision, Order
|
||||
from .exchange import Exchange
|
||||
|
||||
@@ -7,19 +7,18 @@ from qlib.data.data import Cal
|
||||
from qlib.utils.time import concat_date_time, epsilon_change
|
||||
from qlib.log import get_module_logger
|
||||
|
||||
from typing import ClassVar, Optional, Union, List, Tuple
|
||||
|
||||
# try to fix circular imports when enabling type hints
|
||||
from typing import Callable, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from qlib.strategy.base import BaseStrategy
|
||||
from qlib.backtest.exchange import Exchange
|
||||
from qlib.backtest.utils import TradeCalendarManager
|
||||
import warnings
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from dataclasses import dataclass, field
|
||||
from typing import ClassVar, Optional, Union, List, Set, Tuple
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
class OrderDir(IntEnum):
|
||||
@@ -418,7 +417,7 @@ class BaseTradeDecision:
|
||||
return kwargs["default_value"]
|
||||
else:
|
||||
# Default to get full index
|
||||
raise NotImplementedError(f"The decision didn't provide an index range")
|
||||
raise NotImplementedError(f"The decision didn't provide an index range") from NotImplementedError
|
||||
|
||||
# clip index
|
||||
if getattr(self, "total_step", None) is not None:
|
||||
|
||||
@@ -3,13 +3,13 @@
|
||||
from __future__ import annotations
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .account import Account
|
||||
|
||||
from qlib.backtest.position import BasePosition, Position
|
||||
import random
|
||||
from typing import List, Tuple, Union
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
@@ -18,7 +18,7 @@ from ..config import C
|
||||
from ..constant import REG_CN
|
||||
from ..log import get_module_logger
|
||||
from .decision import Order, OrderDir, OrderHelper
|
||||
from .high_performance_ds import BaseQuote, PandasQuote, NumpyQuote
|
||||
from .high_performance_ds import BaseQuote, NumpyQuote
|
||||
|
||||
|
||||
class Exchange:
|
||||
|
||||
@@ -1,22 +1,18 @@
|
||||
from abc import abstractclassmethod, abstractmethod
|
||||
from abc import abstractmethod
|
||||
import copy
|
||||
from qlib.backtest.position import BasePosition
|
||||
from qlib.log import get_module_logger
|
||||
from types import GeneratorType
|
||||
from qlib.backtest.account import Account
|
||||
import warnings
|
||||
import pandas as pd
|
||||
from typing import List, Tuple, Union
|
||||
from collections import defaultdict
|
||||
|
||||
from qlib.backtest.report import Indicator
|
||||
|
||||
from .decision import EmptyTradeDecision, Order, BaseTradeDecision
|
||||
from .decision import Order, BaseTradeDecision
|
||||
from .exchange import Exchange
|
||||
from .utils import TradeCalendarManager, CommonInfrastructure, LevelInfrastructure, get_start_end_idx
|
||||
|
||||
from ..utils import init_instance_by_config
|
||||
from ..utils.time import Freq
|
||||
from ..strategy.base import BaseStrategy
|
||||
|
||||
|
||||
@@ -193,7 +189,8 @@ class BaseExecutor:
|
||||
pass
|
||||
return return_value.get("execute_result")
|
||||
|
||||
@abstractclassmethod
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def _collect_data(cls, trade_decision: BaseTradeDecision, level: int = 0) -> Tuple[List[object], dict]:
|
||||
"""
|
||||
Please refer to the doc of collect_data
|
||||
@@ -245,7 +242,7 @@ class BaseExecutor:
|
||||
if self.track_data:
|
||||
yield trade_decision
|
||||
|
||||
atomic = not issubclass(self.__class__, NestedExecutor) # issubclass(A, A) is True
|
||||
atomic = not issubclass(self.__class__, NestedExecutor) # issubclass(A, A) is True
|
||||
|
||||
if atomic and trade_decision.get_range_limit(default_value=None) is not None:
|
||||
raise ValueError("atomic executor doesn't support specify `range_limit`")
|
||||
@@ -453,7 +450,6 @@ class NestedExecutor(BaseExecutor):
|
||||
inner_exe_res :
|
||||
the execution result of inner task
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_all_executors(self):
|
||||
"""get all executors, including self and inner_executor.get_all_executors()"""
|
||||
|
||||
@@ -2,8 +2,6 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
import copy
|
||||
import pathlib
|
||||
from typing import Dict, List, Union
|
||||
|
||||
import pandas as pd
|
||||
@@ -362,7 +360,9 @@ class Position(BasePosition):
|
||||
# check if to delete
|
||||
if self.position[stock_id]["amount"] < -1e-5:
|
||||
raise ValueError(
|
||||
"only have {} {}, require {}".format(self.position[stock_id]["amount"], stock_id, trade_amount)
|
||||
"only have {} {}, require {}".format(
|
||||
self.position[stock_id]["amount"] + trade_amount, stock_id, trade_amount
|
||||
)
|
||||
)
|
||||
|
||||
new_cash = trade_val - cost
|
||||
@@ -538,7 +538,7 @@ class InfPosition(BasePosition):
|
||||
def get_stock_amount_dict(self) -> Dict:
|
||||
raise NotImplementedError(f"InfPosition doesn't support get_stock_amount_dict")
|
||||
|
||||
def get_stock_weight_dict(self, only_stock: bool) -> Dict:
|
||||
def get_stock_weight_dict(self, only_stock: bool = False) -> Dict:
|
||||
raise NotImplementedError(f"InfPosition doesn't support get_stock_weight_dict")
|
||||
|
||||
def add_count_all(self, bar):
|
||||
|
||||
@@ -10,11 +10,8 @@ import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from qlib.backtest.exchange import Exchange
|
||||
from .decision import IdxTradeRange
|
||||
from qlib.backtest.decision import BaseTradeDecision, Order, OrderDir
|
||||
from qlib.backtest.utils import TradeCalendarManager
|
||||
from .high_performance_ds import BaseOrderIndicator, PandasOrderIndicator, NumpyOrderIndicator, SingleMetric
|
||||
from ..data import D
|
||||
from .high_performance_ds import BaseOrderIndicator, NumpyOrderIndicator, SingleMetric
|
||||
from ..tests.config import CSI300_BENCH
|
||||
from ..utils.resam import get_higher_eq_freq_feature, resam_ts_data
|
||||
import qlib.utils.index_data as idd
|
||||
|
||||
@@ -28,7 +28,6 @@ class Signal(metaclass=abc.ABCMeta):
|
||||
Union[pd.Series, pd.DataFrame, None]:
|
||||
returns None if no signal in the specific day
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class SignalWCache(Signal):
|
||||
|
||||
103
qlib/config.py
103
qlib/config.py
@@ -19,10 +19,10 @@ import logging
|
||||
import platform
|
||||
import multiprocessing
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
from typing import Callable, Optional, Union
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from qlib.constant import REG_CN, REG_US
|
||||
from qlib.constant import REG_CN, REG_US, REG_TW
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from qlib.utils.time import Freq
|
||||
@@ -40,7 +40,7 @@ class Config:
|
||||
if attr in self.__dict__["_config"]:
|
||||
return self.__dict__["_config"][attr]
|
||||
|
||||
raise AttributeError(f"No such {attr} in self._config")
|
||||
raise AttributeError(f"No such `{attr}` in self._config")
|
||||
|
||||
def get(self, key, default=None):
|
||||
return self.__dict__["_config"].get(key, default)
|
||||
@@ -92,6 +92,7 @@ _default_config = {
|
||||
"calendar_provider": "LocalCalendarProvider",
|
||||
"instrument_provider": "LocalInstrumentProvider",
|
||||
"feature_provider": "LocalFeatureProvider",
|
||||
"pit_provider": "LocalPITProvider",
|
||||
"expression_provider": "LocalExpressionProvider",
|
||||
"dataset_provider": "LocalDatasetProvider",
|
||||
"provider": "LocalProvider",
|
||||
@@ -108,10 +109,11 @@ _default_config = {
|
||||
"provider_uri": "",
|
||||
# cache
|
||||
"expression_cache": None,
|
||||
"dataset_cache": None,
|
||||
"calendar_cache": None,
|
||||
# for simple dataset cache
|
||||
"local_cache_path": None,
|
||||
# kernels can be a fixed value or a callable function lie `def (freq: str) -> int`
|
||||
# If the kernels are arctic_kernels, `min(NUM_USABLE_CPU, 30)` may be a good value
|
||||
"kernels": NUM_USABLE_CPU,
|
||||
# pickle.dump protocol version
|
||||
"dump_protocol_version": PROTOCOL_VERSION,
|
||||
@@ -121,11 +123,10 @@ _default_config = {
|
||||
"joblib_backend": "multiprocessing",
|
||||
"default_disk_cache": 1, # 0:skip/1:use
|
||||
"mem_cache_size_limit": 500,
|
||||
"mem_cache_limit_type": "length",
|
||||
# memory cache expire second, only in used 'DatasetURICache' and 'client D.calendar'
|
||||
# default 1 hour
|
||||
"mem_cache_expire": 60 * 60,
|
||||
# memory cache space limit, default 5GB, only in used client
|
||||
"mem_cache_space_limit": 1024 * 1024 * 1024 * 5,
|
||||
# cache dir name
|
||||
"dataset_cache_dir_name": "dataset_cache",
|
||||
"features_cache_dir_name": "features_cache",
|
||||
@@ -170,6 +171,18 @@ _default_config = {
|
||||
"default_exp_name": "Experiment",
|
||||
},
|
||||
},
|
||||
"pit_record_type": {
|
||||
"date": "I", # uint32
|
||||
"period": "I", # uint32
|
||||
"value": "d", # float64
|
||||
"index": "I", # uint32
|
||||
},
|
||||
"pit_record_nan": {
|
||||
"date": 0,
|
||||
"period": 0,
|
||||
"value": float("NAN"),
|
||||
"index": 0xFFFFFFFF,
|
||||
},
|
||||
# Default config for MongoDB
|
||||
"mongo": {
|
||||
"task_url": "mongodb://localhost:27017/",
|
||||
@@ -183,20 +196,12 @@ _default_config = {
|
||||
|
||||
MODE_CONF = {
|
||||
"server": {
|
||||
# data provider config
|
||||
"calendar_provider": "LocalCalendarProvider",
|
||||
"instrument_provider": "LocalInstrumentProvider",
|
||||
"feature_provider": "LocalFeatureProvider",
|
||||
"expression_provider": "LocalExpressionProvider",
|
||||
"dataset_provider": "LocalDatasetProvider",
|
||||
"provider": "LocalProvider",
|
||||
# config it in qlib.init()
|
||||
"provider_uri": "",
|
||||
# redis
|
||||
"redis_host": "127.0.0.1",
|
||||
"redis_port": 6379,
|
||||
"redis_task_db": 1,
|
||||
"kernels": NUM_USABLE_CPU,
|
||||
# cache
|
||||
"expression_cache": DISK_EXPRESSION_CACHE,
|
||||
"dataset_cache": DISK_DATASET_CACHE,
|
||||
@@ -204,24 +209,15 @@ MODE_CONF = {
|
||||
"mount_path": None,
|
||||
},
|
||||
"client": {
|
||||
# data provider config
|
||||
"calendar_provider": "LocalCalendarProvider",
|
||||
"instrument_provider": "LocalInstrumentProvider",
|
||||
"feature_provider": "LocalFeatureProvider",
|
||||
"expression_provider": "LocalExpressionProvider",
|
||||
"dataset_provider": "LocalDatasetProvider",
|
||||
"provider": "LocalProvider",
|
||||
# config it in user's own code
|
||||
"provider_uri": "~/.qlib/qlib_data/cn_data",
|
||||
# cache
|
||||
# Using parameter 'remote' to announce the client is using server_cache, and the writing access will be disabled.
|
||||
"expression_cache": DISK_EXPRESSION_CACHE,
|
||||
"dataset_cache": DISK_DATASET_CACHE,
|
||||
# Disable cache by default. Avoid introduce advanced features for beginners
|
||||
"dataset_cache": None,
|
||||
# SimpleDatasetCache directory
|
||||
"local_cache_path": Path("~/.cache/qlib_simple_cache").expanduser().resolve(),
|
||||
"calendar_cache": None,
|
||||
# client config
|
||||
"kernels": NUM_USABLE_CPU,
|
||||
"mount_path": None,
|
||||
"auto_mount": False, # The nfs is already mounted on our server[auto_mount: False].
|
||||
# The nfs should be auto-mounted by qlib on other
|
||||
@@ -255,6 +251,11 @@ _default_region_config = {
|
||||
"limit_threshold": None,
|
||||
"deal_price": "close",
|
||||
},
|
||||
REG_TW: {
|
||||
"trade_unit": 1000,
|
||||
"limit_threshold": 0.1,
|
||||
"deal_price": "close",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -269,11 +270,19 @@ class QlibConfig(Config):
|
||||
self._registered = False
|
||||
|
||||
class DataPathManager:
|
||||
def __init__(
|
||||
self,
|
||||
provider_uri: Union[str, Path, dict],
|
||||
mount_path: Union[str, Path, dict],
|
||||
):
|
||||
"""
|
||||
Motivation:
|
||||
- get the right path (e.g. data uri) for accessing data based on given information(e.g. provider_uri, mount_path and frequency)
|
||||
- some helper functions to process uri.
|
||||
"""
|
||||
|
||||
def __init__(self, provider_uri: Union[str, Path, dict], mount_path: Union[str, Path, dict]):
|
||||
|
||||
"""
|
||||
The relation of `provider_uri` and `mount_path`
|
||||
- `mount_path` is used only if provider_uri is an NFS path
|
||||
- otherwise, provider_uri will be used for accessing data
|
||||
"""
|
||||
self.provider_uri = provider_uri
|
||||
self.mount_path = mount_path
|
||||
|
||||
@@ -304,6 +313,9 @@ class QlibConfig(Config):
|
||||
return QlibConfig.LOCAL_URI
|
||||
|
||||
def get_data_uri(self, freq: Optional[Union[str, Freq]] = None) -> Path:
|
||||
"""
|
||||
please refer DataPathManager's __init__ and class doc
|
||||
"""
|
||||
if freq is not None:
|
||||
freq = str(freq) # converting Freq to string
|
||||
if freq is None or freq not in self.provider_uri:
|
||||
@@ -314,7 +326,8 @@ class QlibConfig(Config):
|
||||
elif self.get_uri_type(_provider_uri) == QlibConfig.NFS_URI:
|
||||
if "win" in platform.system().lower():
|
||||
# windows, mount_path is the drive
|
||||
return Path(f"{self.mount_path[freq]}:\\")
|
||||
_path = str(self.mount_path[freq])
|
||||
return Path(f"{_path}:\\") if ":" not in _path else Path(_path)
|
||||
return Path(self.mount_path[freq])
|
||||
else:
|
||||
raise NotImplementedError(f"This type of uri is not supported")
|
||||
@@ -351,9 +364,7 @@ class QlibConfig(Config):
|
||||
for _freq in _provider_uri.keys():
|
||||
# mount_path
|
||||
_mount_path[_freq] = (
|
||||
_mount_path[_freq]
|
||||
if _mount_path[_freq] is None
|
||||
else str(Path(_mount_path[_freq]).expanduser().resolve())
|
||||
_mount_path[_freq] if _mount_path[_freq] is None else str(Path(_mount_path[_freq]).expanduser())
|
||||
)
|
||||
self["provider_uri"] = _provider_uri
|
||||
self["mount_path"] = _mount_path
|
||||
@@ -376,13 +387,11 @@ class QlibConfig(Config):
|
||||
default_conf : str
|
||||
the default config template chosen by user: "server", "client"
|
||||
"""
|
||||
from .utils import set_log_with_config, get_module_logger, can_use_cache
|
||||
from .utils import set_log_with_config, get_module_logger, can_use_cache # pylint: disable=C0415
|
||||
|
||||
self.reset()
|
||||
|
||||
_logging_config = self.logging_config
|
||||
if "logging_config" in kwargs:
|
||||
_logging_config = kwargs["logging_config"]
|
||||
_logging_config = kwargs.get("logging_config", self.logging_config)
|
||||
|
||||
# set global config
|
||||
if _logging_config:
|
||||
@@ -421,11 +430,11 @@ class QlibConfig(Config):
|
||||
)
|
||||
|
||||
def register(self):
|
||||
from .utils import init_instance_by_config
|
||||
from .data.ops import register_all_ops
|
||||
from .data.data import register_all_wrappers
|
||||
from .workflow import R, QlibRecorder
|
||||
from .workflow.utils import experiment_exit_handler
|
||||
from .utils import init_instance_by_config # pylint: disable=C0415
|
||||
from .data.ops import register_all_ops # pylint: disable=C0415
|
||||
from .data.data import register_all_wrappers # pylint: disable=C0415
|
||||
from .workflow import R, QlibRecorder # pylint: disable=C0415
|
||||
from .workflow.utils import experiment_exit_handler # pylint: disable=C0415
|
||||
|
||||
register_all_ops(self)
|
||||
register_all_wrappers(self)
|
||||
@@ -442,7 +451,7 @@ class QlibConfig(Config):
|
||||
self._registered = True
|
||||
|
||||
def reset_qlib_version(self):
|
||||
import qlib
|
||||
import qlib # pylint: disable=C0415
|
||||
|
||||
reset_version = self.get("qlib_reset_version", None)
|
||||
if reset_version is not None:
|
||||
@@ -452,6 +461,12 @@ class QlibConfig(Config):
|
||||
# Due to a bug? that converting __version__ to _QlibConfig__version__bak
|
||||
# Using __version__bak instead of __version__
|
||||
|
||||
def get_kernels(self, freq: str):
|
||||
"""get number of processors given frequency"""
|
||||
if isinstance(self["kernels"], Callable):
|
||||
return self["kernels"](freq)
|
||||
return self["kernels"]
|
||||
|
||||
@property
|
||||
def registered(self):
|
||||
return self._registered
|
||||
|
||||
@@ -4,6 +4,10 @@
|
||||
# REGION CONST
|
||||
REG_CN = "cn"
|
||||
REG_US = "us"
|
||||
REG_TW = "tw"
|
||||
|
||||
# Epsilon for avoiding division by zero.
|
||||
EPS = 1e-12
|
||||
|
||||
# Infinity in integer
|
||||
INF = 10**18
|
||||
|
||||
55
qlib/contrib/data/data.py
Normal file
55
qlib/contrib/data/data.py
Normal file
@@ -0,0 +1,55 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
# We remove arctic from core framework of Qlib to contrib due to
|
||||
# - Arctic has very strict limitation on pandas and numpy version
|
||||
# - https://github.com/man-group/arctic/pull/908
|
||||
# - pip fail to computing the right version number!!!!
|
||||
# - Maybe we can solve this problem by poetry
|
||||
|
||||
# FIXME: So if you want to use arctic-based provider, please install arctic manually
|
||||
# `pip install arctic` may not be enough.
|
||||
from arctic import Arctic
|
||||
import pandas as pd
|
||||
import pymongo
|
||||
|
||||
from qlib.data.data import FeatureProvider
|
||||
|
||||
|
||||
class ArcticFeatureProvider(FeatureProvider):
|
||||
def __init__(
|
||||
self, uri="127.0.0.1", retry_time=0, market_transaction_time_list=[("09:15", "11:30"), ("13:00", "15:00")]
|
||||
):
|
||||
super().__init__()
|
||||
self.uri = uri
|
||||
# TODO:
|
||||
# retry connecting if error occurs
|
||||
# does it real matters?
|
||||
self.retry_time = retry_time
|
||||
# NOTE: this is especially important for TResample operator
|
||||
self.market_transaction_time_list = market_transaction_time_list
|
||||
|
||||
def feature(self, instrument, field, start_index, end_index, freq):
|
||||
field = str(field)[1:]
|
||||
with pymongo.MongoClient(self.uri) as client:
|
||||
# TODO: this will result in frequently connecting the server and performance issue
|
||||
arctic = Arctic(client)
|
||||
|
||||
if freq not in arctic.list_libraries():
|
||||
raise ValueError("lib {} not in arctic".format(freq))
|
||||
|
||||
if instrument not in arctic[freq].list_symbols():
|
||||
# instruments does not exist
|
||||
return pd.Series()
|
||||
else:
|
||||
df = arctic[freq].read(instrument, columns=[field], chunk_range=(start_index, end_index))
|
||||
s = df[field]
|
||||
|
||||
if not s.empty:
|
||||
s = pd.concat(
|
||||
[
|
||||
s.between_time(time_tuple[0], time_tuple[1])
|
||||
for time_tuple in self.market_transaction_time_list
|
||||
]
|
||||
)
|
||||
return s
|
||||
@@ -7,8 +7,7 @@ import warnings
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from qlib.utils import init_instance_by_config
|
||||
from qlib.data.dataset import DatasetH, DataHandler
|
||||
from qlib.data.dataset import DatasetH
|
||||
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
@@ -16,7 +15,7 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
def _to_tensor(x):
|
||||
if not isinstance(x, torch.Tensor):
|
||||
return torch.tensor(x, dtype=torch.float, device=device)
|
||||
return torch.tensor(x, dtype=torch.float, device=device) # pylint: disable=E1101
|
||||
return x
|
||||
|
||||
|
||||
|
||||
@@ -5,9 +5,7 @@ from ...data.dataset.handler import DataHandlerLP
|
||||
from ...data.dataset.processor import Processor
|
||||
from ...utils import get_callable_kwargs
|
||||
from ...data.dataset import processor as processor_module
|
||||
from ...log import TimeInspector
|
||||
from inspect import getfullargspec
|
||||
import copy
|
||||
|
||||
|
||||
def check_transform_proc(proc_l, fit_start_time, fit_end_time):
|
||||
|
||||
164
qlib/contrib/data/highfreq_handler.py
Normal file
164
qlib/contrib/data/highfreq_handler.py
Normal file
@@ -0,0 +1,164 @@
|
||||
from qlib.data.dataset.handler import DataHandler, DataHandlerLP
|
||||
|
||||
EPSILON = 1e-4
|
||||
|
||||
|
||||
class HighFreqHandler(DataHandlerLP):
|
||||
def __init__(
|
||||
self,
|
||||
instruments="csi300",
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
infer_processors=[],
|
||||
learn_processors=[],
|
||||
fit_start_time=None,
|
||||
fit_end_time=None,
|
||||
drop_raw=True,
|
||||
):
|
||||
def check_transform_proc(proc_l):
|
||||
new_l = []
|
||||
for p in proc_l:
|
||||
p["kwargs"].update(
|
||||
{
|
||||
"fit_start_time": fit_start_time,
|
||||
"fit_end_time": fit_end_time,
|
||||
}
|
||||
)
|
||||
new_l.append(p)
|
||||
return new_l
|
||||
|
||||
infer_processors = check_transform_proc(infer_processors)
|
||||
learn_processors = check_transform_proc(learn_processors)
|
||||
|
||||
data_loader = {
|
||||
"class": "QlibDataLoader",
|
||||
"kwargs": {
|
||||
"config": self.get_feature_config(),
|
||||
"swap_level": False,
|
||||
"freq": "1min",
|
||||
},
|
||||
}
|
||||
super().__init__(
|
||||
instruments=instruments,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
data_loader=data_loader,
|
||||
infer_processors=infer_processors,
|
||||
learn_processors=learn_processors,
|
||||
drop_raw=drop_raw,
|
||||
)
|
||||
|
||||
def get_feature_config(self):
|
||||
fields = []
|
||||
names = []
|
||||
|
||||
template_if = "If(IsNull({1}), {0}, {1})"
|
||||
template_paused = "Select(Gt($hx_paused_num, 1.001), {0})"
|
||||
|
||||
def get_normalized_price_feature(price_field, shift=0):
|
||||
# norm with the close price of 237th minute of yesterday.
|
||||
if shift == 0:
|
||||
template_norm = "{0}/DayLast(Ref({1}, 243))"
|
||||
else:
|
||||
template_norm = "Ref({0}, " + str(shift) + ")/DayLast(Ref({1}, 243))"
|
||||
|
||||
template_fillnan = "FFillNan({0})"
|
||||
# calculate -> ffill -> remove paused
|
||||
feature_ops = template_paused.format(
|
||||
template_fillnan.format(
|
||||
template_norm.format(template_if.format("$close", price_field), template_fillnan.format("$close"))
|
||||
)
|
||||
)
|
||||
return feature_ops
|
||||
|
||||
fields += [get_normalized_price_feature("$open", 0)]
|
||||
fields += [get_normalized_price_feature("$high", 0)]
|
||||
fields += [get_normalized_price_feature("$low", 0)]
|
||||
fields += [get_normalized_price_feature("$close", 0)]
|
||||
fields += [get_normalized_price_feature("$vwap", 0)]
|
||||
names += ["$open", "$high", "$low", "$close", "$vwap"]
|
||||
|
||||
fields += [get_normalized_price_feature("$open", 240)]
|
||||
fields += [get_normalized_price_feature("$high", 240)]
|
||||
fields += [get_normalized_price_feature("$low", 240)]
|
||||
fields += [get_normalized_price_feature("$close", 240)]
|
||||
fields += [get_normalized_price_feature("$vwap", 240)]
|
||||
names += ["$open_1", "$high_1", "$low_1", "$close_1", "$vwap_1"]
|
||||
|
||||
# calculate and fill nan with 0
|
||||
template_gzero = "If(Ge({0}, 0), {0}, 0)"
|
||||
fields += [
|
||||
template_gzero.format(
|
||||
template_paused.format(
|
||||
"If(IsNull({0}), 0, {0})".format("{0}/Ref(DayLast(Mean({0}, 7200)), 240)".format("$volume"))
|
||||
)
|
||||
)
|
||||
]
|
||||
names += ["$volume"]
|
||||
|
||||
fields += [
|
||||
template_gzero.format(
|
||||
template_paused.format(
|
||||
"If(IsNull({0}), 0, {0})".format(
|
||||
"Ref({0}, 240)/Ref(DayLast(Mean({0}, 7200)), 240)".format("$volume")
|
||||
)
|
||||
)
|
||||
)
|
||||
]
|
||||
names += ["$volume_1"]
|
||||
|
||||
return fields, names
|
||||
|
||||
|
||||
class HighFreqBacktestHandler(DataHandler):
|
||||
def __init__(
|
||||
self,
|
||||
instruments="csi300",
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
):
|
||||
data_loader = {
|
||||
"class": "QlibDataLoader",
|
||||
"kwargs": {
|
||||
"config": self.get_feature_config(),
|
||||
"swap_level": False,
|
||||
"freq": "1min",
|
||||
},
|
||||
}
|
||||
super().__init__(
|
||||
instruments=instruments,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
data_loader=data_loader,
|
||||
)
|
||||
|
||||
def get_feature_config(self):
|
||||
fields = []
|
||||
names = []
|
||||
|
||||
template_if = "If(IsNull({1}), {0}, {1})"
|
||||
template_paused = "Select(Gt($hx_paused_num, 1.001), {0})"
|
||||
# template_paused = "{0}"
|
||||
template_fillnan = "FFillNan({0})"
|
||||
fields += [
|
||||
template_fillnan.format(template_paused.format("$close")),
|
||||
]
|
||||
names += ["$close0"]
|
||||
|
||||
fields += [
|
||||
template_paused.format(
|
||||
template_if.format(
|
||||
template_fillnan.format("$close"),
|
||||
"$vwap",
|
||||
)
|
||||
)
|
||||
]
|
||||
names += ["$vwap0"]
|
||||
|
||||
fields += [template_paused.format("If(IsNull({0}), 0, {0})".format("$volume"))]
|
||||
names += ["$volume0"]
|
||||
|
||||
fields += [template_paused.format("If(IsNull({0}), 0, {0})".format("$factor"))]
|
||||
names += ["$factor0"]
|
||||
|
||||
return fields, names
|
||||
81
qlib/contrib/data/highfreq_processor.py
Normal file
81
qlib/contrib/data/highfreq_processor.py
Normal file
@@ -0,0 +1,81 @@
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from qlib.data.dataset.processor import Processor
|
||||
from qlib.data.dataset.utils import fetch_df_by_index
|
||||
from typing import Dict
|
||||
|
||||
|
||||
class HighFreqTrans(Processor):
|
||||
def __init__(self, dtype: str = "bool"):
|
||||
self.dtype = dtype
|
||||
|
||||
def fit(self, df_features):
|
||||
pass
|
||||
|
||||
def __call__(self, df_features):
|
||||
if self.dtype == "bool":
|
||||
return df_features.astype(np.int8)
|
||||
else:
|
||||
return df_features.astype(np.float32)
|
||||
|
||||
|
||||
class HighFreqNorm(Processor):
|
||||
def __init__(
|
||||
self,
|
||||
fit_start_time: pd.Timestamp,
|
||||
fit_end_time: pd.Timestamp,
|
||||
feature_save_dir: str,
|
||||
norm_groups: Dict[str, int],
|
||||
):
|
||||
|
||||
self.fit_start_time = fit_start_time
|
||||
self.fit_end_time = fit_end_time
|
||||
self.feature_save_dir = feature_save_dir
|
||||
self.norm_groups = norm_groups
|
||||
|
||||
def fit(self, df_features) -> None:
|
||||
if os.path.exists(self.feature_save_dir) and len(os.listdir(self.feature_save_dir)) != 0:
|
||||
return
|
||||
os.makedirs(self.feature_save_dir)
|
||||
fetch_df = fetch_df_by_index(df_features, slice(self.fit_start_time, self.fit_end_time), level="datetime")
|
||||
del df_features
|
||||
index = 0
|
||||
names = {}
|
||||
for name, dim in self.norm_groups.items():
|
||||
names[name] = slice(index, index + dim)
|
||||
index += dim
|
||||
for name, name_val in names.items():
|
||||
df_values = fetch_df.iloc(axis=1)[name_val].values
|
||||
if name.endswith("volume"):
|
||||
df_values = np.log1p(df_values)
|
||||
self.feature_mean = np.nanmean(df_values)
|
||||
np.save(self.feature_save_dir + name + "_mean.npy", self.feature_mean)
|
||||
df_values = df_values - self.feature_mean
|
||||
self.feature_std = np.nanstd(np.absolute(df_values))
|
||||
np.save(self.feature_save_dir + name + "_std.npy", self.feature_std)
|
||||
df_values = df_values / self.feature_std
|
||||
np.save(self.feature_save_dir + name + "_vmax.npy", np.nanmax(df_values))
|
||||
np.save(self.feature_save_dir + name + "_vmin.npy", np.nanmin(df_values))
|
||||
return
|
||||
|
||||
def __call__(self, df_features):
|
||||
if "date" in df_features:
|
||||
df_features.droplevel("date", inplace=True)
|
||||
df_values = df_features.values
|
||||
index = 0
|
||||
names = {}
|
||||
for name, dim in self.norm_groups.items():
|
||||
names[name] = slice(index, index + dim)
|
||||
index += dim
|
||||
for name, name_val in names.items():
|
||||
feature_mean = np.load(self.feature_save_dir + name + "_mean.npy")
|
||||
feature_std = np.load(self.feature_save_dir + name + "_std.npy")
|
||||
|
||||
if name.endswith("volume"):
|
||||
df_values[:, name_val] = np.log1p(df_values[:, name_val])
|
||||
df_values[:, name_val] -= feature_mean
|
||||
df_values[:, name_val] /= feature_std
|
||||
df_features = pd.DataFrame(data=df_values, index=df_features.index, columns=df_features.columns)
|
||||
return df_features.fillna(0)
|
||||
301
qlib/contrib/data/highfreq_provider.py
Normal file
301
qlib/contrib/data/highfreq_provider.py
Normal file
@@ -0,0 +1,301 @@
|
||||
import os
|
||||
import time
|
||||
import datetime
|
||||
from typing import Optional
|
||||
|
||||
import qlib
|
||||
from qlib.data import D
|
||||
from qlib.config import REG_CN
|
||||
from qlib.utils import init_instance_by_config
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
from qlib.data.data import Cal
|
||||
from qlib.contrib.ops.high_freq import get_calendar_day, DayLast, FFillNan, BFillNan, Date, Select, IsNull, IsInf, Cut
|
||||
import pickle as pkl
|
||||
from joblib import Parallel, delayed
|
||||
from utilsd.logging import print_log
|
||||
|
||||
|
||||
class HighFreqProvider:
|
||||
def __init__(
|
||||
self,
|
||||
start_time: str,
|
||||
end_time: str,
|
||||
train_end_time: str,
|
||||
valid_start_time: str,
|
||||
valid_end_time: str,
|
||||
test_start_time: str,
|
||||
qlib_conf: dict,
|
||||
feature_conf: dict,
|
||||
label_conf: Optional[dict] = None,
|
||||
backtest_conf: dict = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
self.start_time = start_time
|
||||
self.end_time = end_time
|
||||
self.test_start_time = test_start_time
|
||||
self.train_end_time = train_end_time
|
||||
self.valid_start_time = valid_start_time
|
||||
self.valid_end_time = valid_end_time
|
||||
self._init_qlib(qlib_conf)
|
||||
self.feature_conf = feature_conf
|
||||
self.label_conf = label_conf
|
||||
self.backtest_conf = backtest_conf
|
||||
self.qlib_conf = qlib_conf
|
||||
|
||||
def get_pre_datasets(self):
|
||||
"""Generate the training, validation and test datasets for prediction
|
||||
|
||||
Returns:
|
||||
Tuple[BaseDataset, BaseDataset, BaseDataset]: The training and test datasets
|
||||
"""
|
||||
|
||||
dict_feature_path = self.feature_conf["path"]
|
||||
train_feature_path = dict_feature_path[:-4] + "_train.pkl"
|
||||
valid_feature_path = dict_feature_path[:-4] + "_valid.pkl"
|
||||
test_feature_path = dict_feature_path[:-4] + "_test.pkl"
|
||||
|
||||
dict_label_path = self.label_conf["path"]
|
||||
train_label_path = dict_label_path[:-4] + "_train.pkl"
|
||||
valid_label_path = dict_label_path[:-4] + "_valid.pkl"
|
||||
test_label_path = dict_label_path[:-4] + "_test.pkl"
|
||||
|
||||
if (
|
||||
not os.path.isfile(train_feature_path)
|
||||
or not os.path.isfile(valid_feature_path)
|
||||
or not os.path.isfile(test_feature_path)
|
||||
):
|
||||
xtrain, xvalid, xtest = self._gen_data(self.feature_conf)
|
||||
xtrain.to_pickle(train_feature_path)
|
||||
xvalid.to_pickle(valid_feature_path)
|
||||
xtest.to_pickle(test_feature_path)
|
||||
del xtrain, xvalid, xtest
|
||||
|
||||
if (
|
||||
not os.path.isfile(train_label_path)
|
||||
or not os.path.isfile(valid_label_path)
|
||||
or not os.path.isfile(test_label_path)
|
||||
):
|
||||
ytrain, yvalid, ytest = self._gen_data(self.label_conf)
|
||||
ytrain.to_pickle(train_label_path)
|
||||
yvalid.to_pickle(valid_label_path)
|
||||
ytest.to_pickle(test_label_path)
|
||||
del ytrain, yvalid, ytest
|
||||
|
||||
feature = {
|
||||
"train": train_feature_path,
|
||||
"valid": valid_feature_path,
|
||||
"test": test_feature_path,
|
||||
}
|
||||
|
||||
label = {
|
||||
"train": train_label_path,
|
||||
"valid": valid_label_path,
|
||||
"test": test_label_path,
|
||||
}
|
||||
|
||||
return feature, label
|
||||
|
||||
def get_backtest(self, **kwargs) -> None:
|
||||
self._gen_data(self.backtest_conf)
|
||||
|
||||
def _init_qlib(self, qlib_conf):
|
||||
"""initialize qlib"""
|
||||
|
||||
qlib.init(
|
||||
region=REG_CN,
|
||||
auto_mount=False,
|
||||
custom_ops=[DayLast, FFillNan, BFillNan, Date, Select, IsNull, IsInf, Cut],
|
||||
expression_cache=None,
|
||||
**qlib_conf,
|
||||
)
|
||||
|
||||
def _prepare_calender_cache(self):
|
||||
"""preload the calendar for cache"""
|
||||
|
||||
# This code used the copy-on-write feature of Linux
|
||||
# to avoid calculating the calendar multiple times in the subprocess.
|
||||
# This code may accelerate, but may be not useful on Windows and Mac Os
|
||||
Cal.calendar(freq="1min")
|
||||
get_calendar_day(freq="1min")
|
||||
|
||||
def _gen_dataframe(self, config, datasets=["train", "valid", "test"]):
|
||||
try:
|
||||
path = config.pop("path")
|
||||
except KeyError as e:
|
||||
raise ValueError("Must specify the path to save the dataset.") from e
|
||||
if os.path.isfile(path):
|
||||
start = time.time()
|
||||
print_log("Dataset exists, load from disk.", __name__)
|
||||
|
||||
# res = dataset.prepare(['train', 'valid', 'test'])
|
||||
with open(path, "rb") as f:
|
||||
data = pkl.load(f)
|
||||
if isinstance(data, dict):
|
||||
res = [data[i] for i in datasets]
|
||||
else:
|
||||
res = data.prepare(datasets)
|
||||
print_log(f"Data loaded, time cost: {time.time() - start:.2f}", __name__)
|
||||
else:
|
||||
if not os.path.exists(os.path.dirname(path)):
|
||||
os.makedirs(os.path.dirname(path))
|
||||
print_log("Generating dataset", __name__)
|
||||
start_time = time.time()
|
||||
self._prepare_calender_cache()
|
||||
dataset = init_instance_by_config(config)
|
||||
trainset, validset, testset = dataset.prepare(["train", "valid", "test"])
|
||||
data = {
|
||||
"train": trainset,
|
||||
"valid": validset,
|
||||
"test": testset,
|
||||
}
|
||||
with open(path, "wb") as f:
|
||||
pkl.dump(data, f)
|
||||
with open(path[:-4] + "train.pkl", "wb") as f:
|
||||
pkl.dump(trainset, f)
|
||||
with open(path[:-4] + "valid.pkl", "wb") as f:
|
||||
pkl.dump(validset, f)
|
||||
with open(path[:-4] + "test.pkl", "wb") as f:
|
||||
pkl.dump(testset, f)
|
||||
res = [data[i] for i in datasets]
|
||||
print_log(f"Data generated, time cost: {(time.time() - start_time):.2f}", __name__)
|
||||
return res
|
||||
|
||||
def _gen_data(self, config, datasets=["train", "valid", "test"]):
|
||||
try:
|
||||
path = config.pop("path")
|
||||
except KeyError as e:
|
||||
raise ValueError("Must specify the path to save the dataset.") from e
|
||||
if os.path.isfile(path):
|
||||
start = time.time()
|
||||
print_log("Dataset exists, load from disk.", __name__)
|
||||
|
||||
# res = dataset.prepare(['train', 'valid', 'test'])
|
||||
with open(path, "rb") as f:
|
||||
data = pkl.load(f)
|
||||
if isinstance(data, dict):
|
||||
res = [data[i] for i in datasets]
|
||||
else:
|
||||
res = data.prepare(datasets)
|
||||
print_log(f"Data loaded, time cost: {time.time() - start:.2f}", __name__)
|
||||
else:
|
||||
if not os.path.exists(os.path.dirname(path)):
|
||||
os.makedirs(os.path.dirname(path))
|
||||
print_log("Generating dataset", __name__)
|
||||
start_time = time.time()
|
||||
self._prepare_calender_cache()
|
||||
dataset = init_instance_by_config(config)
|
||||
dataset.config(dump_all=True, recursive=True)
|
||||
dataset.to_pickle(path)
|
||||
res = dataset.prepare(datasets)
|
||||
print_log(f"Data generated, time cost: {(time.time() - start_time):.2f}", __name__)
|
||||
return res
|
||||
|
||||
def _gen_dataset(self, config):
|
||||
try:
|
||||
path = config.pop("path")
|
||||
except KeyError as e:
|
||||
raise ValueError("Must specify the path to save the dataset.") from e
|
||||
if os.path.isfile(path):
|
||||
start = time.time()
|
||||
print_log("Dataset exists, load from disk.", __name__)
|
||||
|
||||
with open(path, "rb") as f:
|
||||
dataset = pkl.load(f)
|
||||
print_log(f"Data loaded, time cost: {time.time() - start:.2f}", __name__)
|
||||
else:
|
||||
start = time.time()
|
||||
if not os.path.exists(os.path.dirname(path)):
|
||||
os.makedirs(os.path.dirname(path))
|
||||
print_log("Generating dataset", __name__)
|
||||
self._prepare_calender_cache()
|
||||
dataset = init_instance_by_config(config)
|
||||
print_log(f"Dataset init, time cost: {time.time() - start:.2f}", __name__)
|
||||
dataset.prepare(["train", "valid", "test"])
|
||||
print_log(f"Dataset prepared, time cost: {time.time() - start:.2f}", __name__)
|
||||
dataset.config(dump_all=True, recursive=True)
|
||||
dataset.to_pickle(path)
|
||||
return dataset
|
||||
|
||||
def _gen_day_dataset(self, config, conf_type):
|
||||
try:
|
||||
path = config.pop("path")
|
||||
except KeyError as e:
|
||||
raise ValueError("Must specify the path to save the dataset.") from e
|
||||
|
||||
if os.path.isfile(path + "tmp_dataset.pkl"):
|
||||
start = time.time()
|
||||
print_log("Dataset exists, load from disk.", __name__)
|
||||
else:
|
||||
start = time.time()
|
||||
if not os.path.exists(os.path.dirname(path)):
|
||||
os.makedirs(os.path.dirname(path))
|
||||
print_log("Generating dataset", __name__)
|
||||
self._prepare_calender_cache()
|
||||
dataset = init_instance_by_config(config)
|
||||
print_log(f"Dataset init, time cost: {time.time() - start:.2f}", __name__)
|
||||
dataset.config(dump_all=False, recursive=True)
|
||||
dataset.to_pickle(path + "tmp_dataset.pkl")
|
||||
|
||||
with open(path + "tmp_dataset.pkl", "rb") as f:
|
||||
new_dataset = pkl.load(f)
|
||||
|
||||
time_list = D.calendar(start_time=self.start_time, end_time=self.end_time, freq="1min")[::240]
|
||||
|
||||
def generate_dataset(times):
|
||||
if os.path.isfile(path + times.strftime("%Y-%m-%d") + ".pkl"):
|
||||
print("exist " + times.strftime("%Y-%m-%d"))
|
||||
return
|
||||
self._init_qlib(self.qlib_conf)
|
||||
end_times = times + datetime.timedelta(days=1)
|
||||
new_dataset.handler.config(**{"start_time": times, "end_time": end_times})
|
||||
if conf_type == "backtest":
|
||||
new_dataset.handler.setup_data()
|
||||
else:
|
||||
new_dataset.handler.setup_data(init_type=DataHandlerLP.IT_LS)
|
||||
new_dataset.config(dump_all=True, recursive=True)
|
||||
new_dataset.to_pickle(path + times.strftime("%Y-%m-%d") + ".pkl")
|
||||
|
||||
Parallel(n_jobs=8)(delayed(generate_dataset)(times) for times in time_list)
|
||||
|
||||
def _gen_stock_dataset(self, config, conf_type):
|
||||
try:
|
||||
path = config.pop("path")
|
||||
except KeyError as e:
|
||||
raise ValueError("Must specify the path to save the dataset.") from e
|
||||
|
||||
if os.path.isfile(path + "tmp_dataset.pkl"):
|
||||
start = time.time()
|
||||
print_log("Dataset exists, load from disk.", __name__)
|
||||
else:
|
||||
start = time.time()
|
||||
if not os.path.exists(os.path.dirname(path)):
|
||||
os.makedirs(os.path.dirname(path))
|
||||
print_log("Generating dataset", __name__)
|
||||
self._prepare_calender_cache()
|
||||
dataset = init_instance_by_config(config)
|
||||
print_log(f"Dataset init, time cost: {time.time() - start:.2f}", __name__)
|
||||
dataset.config(dump_all=False, recursive=True)
|
||||
dataset.to_pickle(path + "tmp_dataset.pkl")
|
||||
|
||||
with open(path + "tmp_dataset.pkl", "rb") as f:
|
||||
new_dataset = pkl.load(f)
|
||||
|
||||
instruments = D.instruments(market="all")
|
||||
stock_list = D.list_instruments(
|
||||
instruments=instruments, start_time=self.start_time, end_time=self.end_time, freq="1min", as_list=True
|
||||
)
|
||||
|
||||
def generate_dataset(stock):
|
||||
if os.path.isfile(path + stock + ".pkl"):
|
||||
print("exist " + stock)
|
||||
return
|
||||
self._init_qlib(self.qlib_conf)
|
||||
new_dataset.handler.config(**{"instruments": [stock]})
|
||||
if conf_type == "backtest":
|
||||
new_dataset.handler.setup_data()
|
||||
else:
|
||||
new_dataset.handler.setup_data(init_type=DataHandlerLP.IT_LS)
|
||||
new_dataset.config(dump_all=True, recursive=True)
|
||||
new_dataset.to_pickle(path + stock + ".pkl")
|
||||
|
||||
Parallel(n_jobs=32)(delayed(generate_dataset)(stock) for stock in stock_list)
|
||||
@@ -1,9 +1,6 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import copy
|
||||
|
||||
from ...log import TimeInspector
|
||||
from ...utils.serial import Serializable
|
||||
from ...data.dataset.processor import Processor, get_group_columns
|
||||
|
||||
|
||||
@@ -62,10 +59,10 @@ class ConfigSectionProcessor(Processor):
|
||||
|
||||
# Features
|
||||
cols = df_focus.columns[df_focus.columns.str.contains("^KLEN|^KLOW|^KUP")]
|
||||
df_focus[cols] = df_focus[cols].apply(lambda x: x ** 0.25).groupby(level="datetime").apply(_feature_norm)
|
||||
df_focus[cols] = df_focus[cols].apply(lambda x: x**0.25).groupby(level="datetime").apply(_feature_norm)
|
||||
|
||||
cols = df_focus.columns[df_focus.columns.str.contains("^KLOW2|^KUP2")]
|
||||
df_focus[cols] = df_focus[cols].apply(lambda x: x ** 0.5).groupby(level="datetime").apply(_feature_norm)
|
||||
df_focus[cols] = df_focus[cols].apply(lambda x: x**0.5).groupby(level="datetime").apply(_feature_norm)
|
||||
|
||||
_cols = [
|
||||
"KMID",
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
import pandas as pd
|
||||
from typing import Dict, Iterable
|
||||
from typing import Dict, Iterable, Union
|
||||
|
||||
|
||||
def align_index(df_dict, join):
|
||||
@@ -24,6 +24,10 @@ class SepDataFrame:
|
||||
SepDataFrame tries to act like a DataFrame whose column with multiindex
|
||||
"""
|
||||
|
||||
# TODO:
|
||||
# SepDataFrame try to behave like pandas dataframe, but it is still not them same
|
||||
# Contributions are welcome to make it more complete.
|
||||
|
||||
def __init__(self, df_dict: Dict[str, pd.DataFrame], join: str, skip_align=False):
|
||||
"""
|
||||
initialize the data based on the dataframe dictionary
|
||||
@@ -77,14 +81,37 @@ class SepDataFrame:
|
||||
|
||||
def _update_join(self):
|
||||
if self.join not in self:
|
||||
self.join = next(iter(self._df_dict.keys()))
|
||||
if len(self._df_dict) > 0:
|
||||
self.join = next(iter(self._df_dict.keys()))
|
||||
else:
|
||||
# NOTE: this will change the behavior of previous reindex when all the keys are empty
|
||||
self.join = None
|
||||
|
||||
def __getitem__(self, item):
|
||||
# TODO: behave more like pandas when multiindex
|
||||
return self._df_dict[item]
|
||||
|
||||
def __setitem__(self, item: str, df: pd.DataFrame):
|
||||
def __setitem__(self, item: str, df: Union[pd.DataFrame, pd.Series]):
|
||||
# TODO: consider the join behavior
|
||||
self._df_dict[item] = df
|
||||
if not isinstance(item, tuple):
|
||||
self._df_dict[item] = df
|
||||
else:
|
||||
# NOTE: corner case of MultiIndex
|
||||
_df_dict_key, *col_name = item
|
||||
col_name = tuple(col_name)
|
||||
if _df_dict_key in self._df_dict:
|
||||
if len(col_name) == 1:
|
||||
col_name = col_name[0]
|
||||
self._df_dict[_df_dict_key][col_name] = df
|
||||
else:
|
||||
if isinstance(df, pd.Series):
|
||||
if len(col_name) == 1:
|
||||
col_name = col_name[0]
|
||||
self._df_dict[_df_dict_key] = df.to_frame(col_name)
|
||||
else:
|
||||
df_copy = df.copy() # avoid changing df
|
||||
df_copy.columns = pd.MultiIndex.from_tuples([(*col_name, *idx) for idx in df.columns.to_list()])
|
||||
self._df_dict[_df_dict_key] = df_copy
|
||||
|
||||
def __delitem__(self, item: str):
|
||||
del self._df_dict[item]
|
||||
@@ -164,14 +191,14 @@ import builtins
|
||||
|
||||
|
||||
def _isinstance(instance, cls):
|
||||
if isinstance_orig(instance, SepDataFrame): # pylint: disable=E0602
|
||||
if isinstance_orig(instance, SepDataFrame): # pylint: disable=E0602 # noqa: F821
|
||||
if isinstance(cls, Iterable):
|
||||
for c in cls:
|
||||
if c is pd.DataFrame:
|
||||
return True
|
||||
elif cls is pd.DataFrame:
|
||||
return True
|
||||
return isinstance_orig(instance, cls) # pylint: disable=E0602
|
||||
return isinstance_orig(instance, cls) # pylint: disable=E0602 # noqa: F821
|
||||
|
||||
|
||||
builtins.isinstance_orig = builtins.isinstance
|
||||
|
||||
@@ -4,8 +4,10 @@ Here is a batch of evaluation functions.
|
||||
The interface should be redesigned carefully in the future.
|
||||
"""
|
||||
import pandas as pd
|
||||
|
||||
from typing import Tuple
|
||||
from qlib import get_module_logger
|
||||
from qlib.utils.paral import complex_parallel, DelayedDict
|
||||
from joblib import Parallel, delayed
|
||||
|
||||
|
||||
def calc_long_short_prec(
|
||||
@@ -61,32 +63,6 @@ def calc_long_short_prec(
|
||||
return (l_dom.groupby(date_col).sum() / l_c), (s_dom.groupby(date_col).sum() / s_c)
|
||||
|
||||
|
||||
def calc_ic(pred: pd.Series, label: pd.Series, date_col="datetime", dropna=False) -> Tuple[pd.Series, pd.Series]:
|
||||
"""calc_ic.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
pred :
|
||||
pred
|
||||
label :
|
||||
label
|
||||
date_col :
|
||||
date_col
|
||||
|
||||
Returns
|
||||
-------
|
||||
(pd.Series, pd.Series)
|
||||
ic and rank ic
|
||||
"""
|
||||
df = pd.DataFrame({"pred": pred, "label": label})
|
||||
ic = df.groupby(date_col).apply(lambda df: df["pred"].corr(df["label"]))
|
||||
ric = df.groupby(date_col).apply(lambda df: df["pred"].corr(df["label"], method="spearman"))
|
||||
if dropna:
|
||||
return ic.dropna(), ric.dropna()
|
||||
else:
|
||||
return ic, ric
|
||||
|
||||
|
||||
def calc_long_short_return(
|
||||
pred: pd.Series,
|
||||
label: pd.Series,
|
||||
@@ -127,3 +103,105 @@ def calc_long_short_return(
|
||||
r_short = group.apply(lambda x: x.nsmallest(N(x), columns="pred").label.mean())
|
||||
r_avg = group.label.mean()
|
||||
return (r_long - r_short) / 2, r_avg
|
||||
|
||||
|
||||
def pred_autocorr(pred: pd.Series, lag=1, inst_col="instrument", date_col="datetime"):
|
||||
"""pred_autocorr.
|
||||
|
||||
Limitation:
|
||||
- If the datetime is not sequential densely, the correlation will be calulated based on adjacent dates. (some users may expected NaN)
|
||||
|
||||
:param pred: pd.Series with following format
|
||||
instrument datetime
|
||||
SH600000 2016-01-04 -0.000403
|
||||
2016-01-05 -0.000753
|
||||
2016-01-06 -0.021801
|
||||
2016-01-07 -0.065230
|
||||
2016-01-08 -0.062465
|
||||
:type pred: pd.Series
|
||||
:param lag:
|
||||
"""
|
||||
if isinstance(pred, pd.DataFrame):
|
||||
pred = pred.iloc[:, 0]
|
||||
get_module_logger("pred_autocorr").warning(f"Only the first column in {pred.columns} of `pred` is kept")
|
||||
pred_ustk = pred.sort_index().unstack(inst_col)
|
||||
corr_s = {}
|
||||
for (idx, cur), (_, prev) in zip(pred_ustk.iterrows(), pred_ustk.shift(lag).iterrows()):
|
||||
corr_s[idx] = cur.corr(prev)
|
||||
corr_s = pd.Series(corr_s).sort_index()
|
||||
return corr_s
|
||||
|
||||
|
||||
def pred_autocorr_all(pred_dict, n_jobs=-1, **kwargs):
|
||||
"""
|
||||
calculate auto correlation for pred_dict
|
||||
|
||||
Parameters
|
||||
----------
|
||||
pred_dict : dict
|
||||
A dict like {<method_name>: <prediction>}
|
||||
kwargs :
|
||||
all these arguments will be passed into pred_autocorr
|
||||
"""
|
||||
ac_dict = {}
|
||||
for k, pred in pred_dict.items():
|
||||
ac_dict[k] = delayed(pred_autocorr)(pred, **kwargs)
|
||||
return complex_parallel(Parallel(n_jobs=n_jobs, verbose=10), ac_dict)
|
||||
|
||||
|
||||
def calc_ic(pred: pd.Series, label: pd.Series, date_col="datetime", dropna=False) -> (pd.Series, pd.Series):
|
||||
"""calc_ic.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
pred :
|
||||
pred
|
||||
label :
|
||||
label
|
||||
date_col :
|
||||
date_col
|
||||
|
||||
Returns
|
||||
-------
|
||||
(pd.Series, pd.Series)
|
||||
ic and rank ic
|
||||
"""
|
||||
df = pd.DataFrame({"pred": pred, "label": label})
|
||||
ic = df.groupby(date_col).apply(lambda df: df["pred"].corr(df["label"]))
|
||||
ric = df.groupby(date_col).apply(lambda df: df["pred"].corr(df["label"], method="spearman"))
|
||||
if dropna:
|
||||
return ic.dropna(), ric.dropna()
|
||||
else:
|
||||
return ic, ric
|
||||
|
||||
|
||||
def calc_all_ic(pred_dict_all, label, date_col="datetime", dropna=False, n_jobs=-1):
|
||||
"""calc_all_ic.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
pred_dict_all :
|
||||
A dict like {<method_name>: <prediction>}
|
||||
label:
|
||||
A pd.Series of label values
|
||||
|
||||
Returns
|
||||
-------
|
||||
{'Q2+IND_z': {'ic': <ic series like>
|
||||
2016-01-04 -0.057407
|
||||
...
|
||||
2020-05-28 0.183470
|
||||
2020-05-29 0.171393
|
||||
'ric': <rank ic series like>
|
||||
2016-01-04 -0.040888
|
||||
...
|
||||
2020-05-28 0.236665
|
||||
2020-05-29 0.183886
|
||||
}
|
||||
...}
|
||||
"""
|
||||
pred_all_ics = {}
|
||||
for k, pred in pred_dict_all.items():
|
||||
pred_all_ics[k] = DelayedDict(["ic", "ric"], delayed(calc_ic)(pred, label, date_col=date_col, dropna=dropna))
|
||||
pred_all_ics = complex_parallel(Parallel(n_jobs=n_jobs, verbose=10), pred_all_ics)
|
||||
return pred_all_ics
|
||||
|
||||
@@ -5,12 +5,10 @@
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import copy
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from scipy.stats import spearmanr, pearsonr
|
||||
|
||||
|
||||
from ..data import D
|
||||
|
||||
from collections import OrderedDict
|
||||
@@ -243,4 +241,4 @@ def get_rank_ic(a, b):
|
||||
|
||||
|
||||
def get_normal_ic(a, b):
|
||||
return pearsonr(a, b).correlation
|
||||
return pearsonr(a, b)[0]
|
||||
|
||||
@@ -2,3 +2,6 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from .data_selection import MetaTaskDS, MetaDatasetDS, MetaModelDS
|
||||
|
||||
|
||||
__all__ = ["MetaTaskDS", "MetaDatasetDS", "MetaModelDS"]
|
||||
|
||||
@@ -3,3 +3,6 @@
|
||||
|
||||
from .dataset import MetaDatasetDS, MetaTaskDS
|
||||
from .model import MetaModelDS
|
||||
|
||||
|
||||
__all__ = ["MetaDatasetDS", "MetaTaskDS", "MetaModelDS"]
|
||||
|
||||
@@ -1,24 +1,23 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
from copy import deepcopy
|
||||
from qlib.data.dataset.utils import init_task_handler
|
||||
from qlib.utils.data import deepcopy_basic_type
|
||||
from qlib.contrib.torch import data_to_tensor
|
||||
from qlib.workflow.task.utils import TimeAdjuster
|
||||
from qlib.model.meta.task import MetaTask
|
||||
from typing import Dict, List, Union, Text, Tuple
|
||||
from qlib.data.dataset.handler import DataHandler
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.utils import auto_filter_kwargs, get_date_by_shift, init_instance_by_config
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.task.gen import RollingGen, task_generator
|
||||
from joblib import Parallel, delayed
|
||||
from qlib.model.meta.dataset import MetaTaskDataset
|
||||
from qlib.model.trainer import task_train, TrainerR
|
||||
from qlib.data.dataset import DatasetH
|
||||
from tqdm.auto import tqdm
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from copy import deepcopy
|
||||
from joblib import Parallel, delayed # pylint: disable=E0401
|
||||
from typing import Dict, List, Union, Text, Tuple
|
||||
from qlib.data.dataset.utils import init_task_handler
|
||||
from qlib.data.dataset import DatasetH
|
||||
from qlib.contrib.torch import data_to_tensor
|
||||
from qlib.model.meta.task import MetaTask
|
||||
from qlib.model.meta.dataset import MetaTaskDataset
|
||||
from qlib.model.trainer import TrainerR
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.utils import auto_filter_kwargs, get_date_by_shift, init_instance_by_config
|
||||
from qlib.utils.data import deepcopy_basic_type
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.task.gen import RollingGen, task_generator
|
||||
from qlib.workflow.task.utils import TimeAdjuster
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
|
||||
class InternalData:
|
||||
|
||||
@@ -1,28 +1,25 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from qlib.log import get_module_logger
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from qlib.model.meta.task import MetaTask
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch import optim
|
||||
from tqdm.auto import tqdm
|
||||
import collections
|
||||
import copy
|
||||
from typing import Union, List, Tuple, Dict
|
||||
from typing import Union, List
|
||||
|
||||
from ....data.dataset.weight import Reweighter
|
||||
from ....model.meta.dataset import MetaTaskDataset
|
||||
from ....model.meta.model import MetaModel, MetaTaskModel
|
||||
from ....model.meta.model import MetaTaskModel
|
||||
from ....workflow import R
|
||||
|
||||
from .utils import ICLoss
|
||||
from .dataset import MetaDatasetDS
|
||||
from qlib.contrib.meta.data_selection.net import PredNet
|
||||
from qlib.data.dataset.weight import Reweighter
|
||||
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.model.meta.task import MetaTask
|
||||
from qlib.data.dataset.weight import Reweighter
|
||||
from qlib.contrib.meta.data_selection.net import PredNet
|
||||
|
||||
logger = get_module_logger("data selection")
|
||||
|
||||
@@ -100,7 +97,6 @@ class MetaModelDS(MetaTaskModel):
|
||||
|
||||
if phase == "train":
|
||||
opt.zero_grad()
|
||||
norm_loss = nn.MSELoss()
|
||||
loss.backward()
|
||||
opt.step()
|
||||
elif phase == "test":
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
@@ -1,16 +1,17 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
from qlib.contrib.torch import data_to_tensor
|
||||
|
||||
|
||||
class ICLoss(nn.Module):
|
||||
def forward(self, pred, y, idx, skip_size=50):
|
||||
"""forward.
|
||||
FIXME:
|
||||
- Some times it will be a slightly different from the result from `pandas.corr()`
|
||||
- It may be caused by the precision problem of model;
|
||||
|
||||
:param pred:
|
||||
:param y:
|
||||
|
||||
@@ -10,17 +10,19 @@ try:
|
||||
from .gbdt import LGBModel
|
||||
except ModuleNotFoundError:
|
||||
DEnsembleModel, LGBModel = None, None
|
||||
print("Please install necessary libs for DEnsembleModel and LGBModel, such as lightgbm.")
|
||||
print(
|
||||
"ModuleNotFoundError. DEnsembleModel and LGBModel are skipped. (optional: maybe installing lightgbm can fix it.)"
|
||||
)
|
||||
try:
|
||||
from .xgboost import XGBModel
|
||||
except ModuleNotFoundError:
|
||||
XGBModel = None
|
||||
print("Please install necessary libs for XGBModel, such as xgboost.")
|
||||
print("ModuleNotFoundError. XGBModel is skipped(optional: maybe installing xgboost can fix it).")
|
||||
try:
|
||||
from .linear import LinearModel
|
||||
except ModuleNotFoundError:
|
||||
LinearModel = None
|
||||
print("Please install necessary libs for LinearModel, such as scipy and sklearn.")
|
||||
print("ModuleNotFoundError. LinearModel is skipped(optional: maybe installing scipy and sklearn can fix it).")
|
||||
# import pytorch models
|
||||
try:
|
||||
from .pytorch_alstm import ALSTM
|
||||
@@ -36,6 +38,6 @@ try:
|
||||
pytorch_classes = (ALSTM, GATs, GRU, LSTM, DNNModelPytorch, TabnetModel, SFM_Model, TCN, ADD)
|
||||
except ModuleNotFoundError:
|
||||
pytorch_classes = ()
|
||||
print("Please install necessary libs for PyTorch models.")
|
||||
print("ModuleNotFoundError. PyTorch models are skipped (optional: maybe installing pytorch can fix it).")
|
||||
|
||||
all_model_classes = (CatBoostModel, DEnsembleModel, LGBModel, XGBModel, LinearModel) + pytorch_classes
|
||||
|
||||
@@ -160,7 +160,7 @@ class DEnsembleModel(Model, FeatureInt):
|
||||
h_avg = h.groupby("bins")["h_value"].mean()
|
||||
weights = pd.Series(np.zeros(N, dtype=float))
|
||||
for i_b, b in enumerate(h_avg.index):
|
||||
weights[h["bins"] == b] = 1.0 / (self.decay ** k_th * h_avg[i_b] + 0.1)
|
||||
weights[h["bins"] == b] = 1.0 / (self.decay**k_th * h_avg[i_b] + 0.1)
|
||||
return weights
|
||||
|
||||
def feature_selection(self, df_train, loss_values):
|
||||
@@ -249,7 +249,7 @@ class DEnsembleModel(Model, FeatureInt):
|
||||
return pred
|
||||
|
||||
def predict_sub(self, submodel, df_data, features):
|
||||
x_data, y_data = df_data["feature"].loc[:, features], df_data["label"]
|
||||
x_data = df_data["feature"].loc[:, features]
|
||||
pred_sub = pd.Series(submodel.predict(x_data.values), index=x_data.index)
|
||||
return pred_sub
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
from ...model.interpret.base import LightGBMFInt
|
||||
from ...data.dataset.weight import Reweighter
|
||||
from qlib.workflow import R
|
||||
|
||||
|
||||
class LGBModel(ModelFT, LightGBMFInt):
|
||||
@@ -59,27 +60,34 @@ class LGBModel(ModelFT, LightGBMFInt):
|
||||
num_boost_round=None,
|
||||
early_stopping_rounds=None,
|
||||
verbose_eval=20,
|
||||
evals_result=dict(),
|
||||
evals_result=None,
|
||||
reweighter=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
if evals_result is None:
|
||||
evals_result = {} # in case of unsafety of Python default values
|
||||
ds_l = self._prepare_data(dataset, reweighter)
|
||||
ds, names = list(zip(*ds_l))
|
||||
early_stopping_callback = lgb.early_stopping(
|
||||
self.early_stopping_rounds if early_stopping_rounds is None else early_stopping_rounds
|
||||
)
|
||||
# NOTE: if you encounter error here. Please upgrade your lightgbm
|
||||
verbose_eval_callback = lgb.log_evaluation(period=verbose_eval)
|
||||
evals_result_callback = lgb.record_evaluation(evals_result)
|
||||
self.model = lgb.train(
|
||||
self.params,
|
||||
ds[0], # training dataset
|
||||
num_boost_round=self.num_boost_round if num_boost_round is None else num_boost_round,
|
||||
valid_sets=ds,
|
||||
valid_names=names,
|
||||
early_stopping_rounds=(
|
||||
self.early_stopping_rounds if early_stopping_rounds is None else early_stopping_rounds
|
||||
),
|
||||
verbose_eval=verbose_eval,
|
||||
evals_result=evals_result,
|
||||
**kwargs
|
||||
callbacks=[early_stopping_callback, verbose_eval_callback, evals_result_callback],
|
||||
**kwargs,
|
||||
)
|
||||
for k in names:
|
||||
evals_result[k] = list(evals_result[k].values())[0]
|
||||
for key, val in evals_result[k].items():
|
||||
name = f"{key}.{k}"
|
||||
for epoch, m in enumerate(val):
|
||||
R.log_metrics(**{name.replace("@", "_"): m}, step=epoch)
|
||||
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
if self.model is None:
|
||||
@@ -101,9 +109,10 @@ class LGBModel(ModelFT, LightGBMFInt):
|
||||
verbose level
|
||||
"""
|
||||
# Based on existing model and finetune by train more rounds
|
||||
dtrain, _ = self._prepare_data(dataset, reweighter)
|
||||
dtrain, _ = self._prepare_data(dataset, reweighter) # pylint: disable=W0632
|
||||
if dtrain.empty:
|
||||
raise ValueError("Empty data from dataset, please check your dataset config.")
|
||||
verbose_eval_callback = lgb.log_evaluation(period=verbose_eval)
|
||||
self.model = lgb.train(
|
||||
self.params,
|
||||
dtrain,
|
||||
@@ -111,5 +120,5 @@ class LGBModel(ModelFT, LightGBMFInt):
|
||||
init_model=self.model,
|
||||
valid_sets=[dtrain],
|
||||
valid_names=["train"],
|
||||
verbose_eval=verbose_eval,
|
||||
callbacks=[verbose_eval_callback],
|
||||
)
|
||||
|
||||
@@ -58,7 +58,7 @@ class HFLGBModel(ModelFT, LightGBMFInt):
|
||||
"""
|
||||
Test the signal in high frequency test set
|
||||
"""
|
||||
if self.model == None:
|
||||
if self.model is None:
|
||||
raise ValueError("Model hasn't been trained yet")
|
||||
df_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)
|
||||
df_test.dropna(inplace=True)
|
||||
@@ -110,20 +110,21 @@ class HFLGBModel(ModelFT, LightGBMFInt):
|
||||
num_boost_round=1000,
|
||||
early_stopping_rounds=50,
|
||||
verbose_eval=20,
|
||||
evals_result=dict(),
|
||||
**kwargs
|
||||
evals_result=None,
|
||||
):
|
||||
if evals_result is None:
|
||||
evals_result = dict()
|
||||
dtrain, dvalid = self._prepare_data(dataset)
|
||||
early_stopping_callback = lgb.early_stopping(early_stopping_rounds)
|
||||
verbose_eval_callback = lgb.log_evaluation(period=verbose_eval)
|
||||
evals_result_callback = lgb.record_evaluation(evals_result)
|
||||
self.model = lgb.train(
|
||||
self.params,
|
||||
dtrain,
|
||||
num_boost_round=num_boost_round,
|
||||
valid_sets=[dtrain, dvalid],
|
||||
valid_names=["train", "valid"],
|
||||
early_stopping_rounds=early_stopping_rounds,
|
||||
verbose_eval=verbose_eval,
|
||||
evals_result=evals_result,
|
||||
**kwargs
|
||||
callbacks=[early_stopping_callback, verbose_eval_callback, evals_result_callback],
|
||||
)
|
||||
evals_result["train"] = list(evals_result["train"].values())[0]
|
||||
evals_result["valid"] = list(evals_result["valid"].values())[0]
|
||||
@@ -149,6 +150,7 @@ class HFLGBModel(ModelFT, LightGBMFInt):
|
||||
"""
|
||||
# Based on existing model and finetune by train more rounds
|
||||
dtrain, _ = self._prepare_data(dataset)
|
||||
verbose_eval_callback = lgb.log_evaluation(period=verbose_eval)
|
||||
self.model = lgb.train(
|
||||
self.params,
|
||||
dtrain,
|
||||
@@ -156,5 +158,5 @@ class HFLGBModel(ModelFT, LightGBMFInt):
|
||||
init_model=self.model,
|
||||
valid_sets=[dtrain],
|
||||
valid_names=["train"],
|
||||
verbose_eval=verbose_eval,
|
||||
callbacks=[verbose_eval_callback],
|
||||
)
|
||||
|
||||
@@ -1,12 +1,10 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
import os
|
||||
from pdb import set_trace
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
|
||||
import copy
|
||||
from typing import Text, Union
|
||||
|
||||
import math
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
@@ -146,7 +144,7 @@ class ADARNN(Model):
|
||||
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
|
||||
|
||||
self.fitted = False
|
||||
self.model.cuda()
|
||||
self.model.to(self.device)
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
@@ -155,7 +153,7 @@ class ADARNN(Model):
|
||||
def train_AdaRNN(self, train_loader_list, epoch, dist_old=None, weight_mat=None):
|
||||
self.model.train()
|
||||
criterion = nn.MSELoss()
|
||||
dist_mat = torch.zeros(self.num_layers, self.len_seq).cuda()
|
||||
dist_mat = torch.zeros(self.num_layers, self.len_seq).to(self.device)
|
||||
len_loader = np.inf
|
||||
for loader in train_loader_list:
|
||||
if len(loader) < len_loader:
|
||||
@@ -167,7 +165,7 @@ class ADARNN(Model):
|
||||
list_label = []
|
||||
for data in data_all:
|
||||
# feature :[36, 24, 6]
|
||||
feature, label_reg = data[0].cuda().float(), data[1].cuda().float()
|
||||
feature, label_reg = data[0].to(self.device).float(), data[1].to(self.device).float()
|
||||
list_feat.append(feature)
|
||||
list_label.append(label_reg)
|
||||
flag = False
|
||||
@@ -181,12 +179,12 @@ class ADARNN(Model):
|
||||
if flag:
|
||||
continue
|
||||
|
||||
total_loss = torch.zeros(1).cuda()
|
||||
for i in range(len(index)):
|
||||
feature_s = list_feat[index[i][0]]
|
||||
feature_t = list_feat[index[i][1]]
|
||||
label_reg_s = list_label[index[i][0]]
|
||||
label_reg_t = list_label[index[i][1]]
|
||||
total_loss = torch.zeros(1).to(self.device)
|
||||
for i, n in enumerate(index):
|
||||
feature_s = list_feat[n[0]]
|
||||
feature_t = list_feat[n[1]]
|
||||
label_reg_s = list_label[n[0]]
|
||||
label_reg_t = list_label[n[1]]
|
||||
feature_all = torch.cat((feature_s, feature_t), 0)
|
||||
|
||||
if epoch < self.pre_epoch:
|
||||
@@ -327,7 +325,7 @@ class ADARNN(Model):
|
||||
else:
|
||||
end = begin + self.batch_size
|
||||
|
||||
x_batch = torch.from_numpy(x_values[begin:end]).float().cuda()
|
||||
x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
pred = self.model.predict(x_batch).detach().cpu().numpy()
|
||||
@@ -337,7 +335,7 @@ class ADARNN(Model):
|
||||
return pd.Series(np.concatenate(preds), index=index)
|
||||
|
||||
def transform_type(self, init_weight):
|
||||
weight = torch.ones(self.num_layers, self.len_seq).cuda()
|
||||
weight = torch.ones(self.num_layers, self.len_seq).to(self.device)
|
||||
for i in range(self.num_layers):
|
||||
for j in range(self.len_seq):
|
||||
weight[i, j] = init_weight[i][j].item()
|
||||
@@ -391,6 +389,7 @@ class AdaRNN(nn.Module):
|
||||
len_seq=9,
|
||||
model_type="AdaRNN",
|
||||
trans_loss="mmd",
|
||||
GPU=0,
|
||||
):
|
||||
super(AdaRNN, self).__init__()
|
||||
self.use_bottleneck = use_bottleneck
|
||||
@@ -401,6 +400,7 @@ class AdaRNN(nn.Module):
|
||||
self.model_type = model_type
|
||||
self.trans_loss = trans_loss
|
||||
self.len_seq = len_seq
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
in_size = self.n_input
|
||||
|
||||
features = nn.ModuleList()
|
||||
@@ -410,7 +410,7 @@ class AdaRNN(nn.Module):
|
||||
in_size = hidden
|
||||
self.features = nn.Sequential(*features)
|
||||
|
||||
if use_bottleneck == True: # finance
|
||||
if use_bottleneck is True: # finance
|
||||
self.bottleneck = nn.Sequential(
|
||||
nn.Linear(n_hiddens[-1], bottleneck_width),
|
||||
nn.Linear(bottleneck_width, bottleneck_width),
|
||||
@@ -449,7 +449,7 @@ class AdaRNN(nn.Module):
|
||||
def forward_pre_train(self, x, len_win=0):
|
||||
out = self.gru_features(x)
|
||||
fea = out[0] # [2N,L,H]
|
||||
if self.use_bottleneck == True:
|
||||
if self.use_bottleneck is True:
|
||||
fea_bottleneck = self.bottleneck(fea[:, -1, :])
|
||||
fc_out = self.fc(fea_bottleneck).squeeze()
|
||||
else:
|
||||
@@ -457,9 +457,9 @@ class AdaRNN(nn.Module):
|
||||
|
||||
out_list_all, out_weight_list = out[1], out[2]
|
||||
out_list_s, out_list_t = self.get_features(out_list_all)
|
||||
loss_transfer = torch.zeros((1,)).cuda()
|
||||
for i in range(len(out_list_s)):
|
||||
criterion_transder = TransferLoss(loss_type=self.trans_loss, input_dim=out_list_s[i].shape[2])
|
||||
loss_transfer = torch.zeros((1,)).to(self.device)
|
||||
for i, n in enumerate(out_list_s):
|
||||
criterion_transder = TransferLoss(loss_type=self.trans_loss, input_dim=n.shape[2])
|
||||
h_start = 0
|
||||
for j in range(h_start, self.len_seq, 1):
|
||||
i_start = j - len_win if j - len_win >= 0 else 0
|
||||
@@ -471,7 +471,7 @@ class AdaRNN(nn.Module):
|
||||
else 1 / (self.len_seq - h_start) * (2 * len_win + 1)
|
||||
)
|
||||
loss_transfer = loss_transfer + weight * criterion_transder.compute(
|
||||
out_list_s[i][:, j, :], out_list_t[i][:, k, :]
|
||||
n[:, j, :], out_list_t[i][:, k, :]
|
||||
)
|
||||
return fc_out, loss_transfer, out_weight_list
|
||||
|
||||
@@ -484,7 +484,7 @@ class AdaRNN(nn.Module):
|
||||
out, _ = self.features[i](x_input.float())
|
||||
x_input = out
|
||||
out_lis.append(out)
|
||||
if self.model_type == "AdaRNN" and predict == False:
|
||||
if self.model_type == "AdaRNN" and predict is False:
|
||||
out_gate = self.process_gate_weight(x_input, i)
|
||||
out_weight_list.append(out_gate)
|
||||
return out, out_lis, out_weight_list
|
||||
@@ -518,16 +518,16 @@ class AdaRNN(nn.Module):
|
||||
|
||||
out_list_all = out[1]
|
||||
out_list_s, out_list_t = self.get_features(out_list_all)
|
||||
loss_transfer = torch.zeros((1,)).cuda()
|
||||
loss_transfer = torch.zeros((1,)).to(self.device)
|
||||
if weight_mat is None:
|
||||
weight = (1.0 / self.len_seq * torch.ones(self.num_layers, self.len_seq)).cuda()
|
||||
weight = (1.0 / self.len_seq * torch.ones(self.num_layers, self.len_seq)).to(self.device)
|
||||
else:
|
||||
weight = weight_mat
|
||||
dist_mat = torch.zeros(self.num_layers, self.len_seq).cuda()
|
||||
for i in range(len(out_list_s)):
|
||||
criterion_transder = TransferLoss(loss_type=self.trans_loss, input_dim=out_list_s[i].shape[2])
|
||||
dist_mat = torch.zeros(self.num_layers, self.len_seq).to(self.device)
|
||||
for i, n in enumerate(out_list_s):
|
||||
criterion_transder = TransferLoss(loss_type=self.trans_loss, input_dim=n.shape[2])
|
||||
for j in range(self.len_seq):
|
||||
loss_trans = criterion_transder.compute(out_list_s[i][:, j, :], out_list_t[i][:, j, :])
|
||||
loss_trans = criterion_transder.compute(n[:, j, :], out_list_t[i][:, j, :])
|
||||
loss_transfer = loss_transfer + weight[i, j] * loss_trans
|
||||
dist_mat[i, j] = loss_trans
|
||||
return fc_out, loss_transfer, dist_mat, weight
|
||||
@@ -546,7 +546,7 @@ class AdaRNN(nn.Module):
|
||||
def predict(self, x):
|
||||
out = self.gru_features(x, predict=True)
|
||||
fea = out[0]
|
||||
if self.use_bottleneck == True:
|
||||
if self.use_bottleneck is True:
|
||||
fea_bottleneck = self.bottleneck(fea[:, -1, :])
|
||||
fc_out = self.fc(fea_bottleneck).squeeze()
|
||||
else:
|
||||
@@ -555,12 +555,13 @@ class AdaRNN(nn.Module):
|
||||
|
||||
|
||||
class TransferLoss:
|
||||
def __init__(self, loss_type="cosine", input_dim=512):
|
||||
def __init__(self, loss_type="cosine", input_dim=512, GPU=0):
|
||||
"""
|
||||
Supported loss_type: mmd(mmd_lin), mmd_rbf, coral, cosine, kl, js, mine, adv
|
||||
"""
|
||||
self.loss_type = loss_type
|
||||
self.input_dim = input_dim
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
|
||||
def compute(self, X, Y):
|
||||
"""Compute adaptation loss
|
||||
@@ -572,22 +573,22 @@ class TransferLoss:
|
||||
Returns:
|
||||
[tensor] -- transfer loss
|
||||
"""
|
||||
if self.loss_type == "mmd_lin" or self.loss_type == "mmd":
|
||||
if self.loss_type in ("mmd_lin", "mmd"):
|
||||
mmdloss = MMD_loss(kernel_type="linear")
|
||||
loss = mmdloss(X, Y)
|
||||
elif self.loss_type == "coral":
|
||||
loss = CORAL(X, Y)
|
||||
elif self.loss_type == "cosine" or self.loss_type == "cos":
|
||||
loss = CORAL(X, Y, self.device)
|
||||
elif self.loss_type in ("cosine", "cos"):
|
||||
loss = 1 - cosine(X, Y)
|
||||
elif self.loss_type == "kl":
|
||||
loss = kl_div(X, Y)
|
||||
elif self.loss_type == "js":
|
||||
loss = js(X, Y)
|
||||
elif self.loss_type == "mine":
|
||||
mine_model = Mine_estimator(input_dim=self.input_dim, hidden_dim=60).cuda()
|
||||
mine_model = Mine_estimator(input_dim=self.input_dim, hidden_dim=60).to(self.device)
|
||||
loss = mine_model(X, Y)
|
||||
elif self.loss_type == "adv":
|
||||
loss = adv(X, Y, input_dim=self.input_dim, hidden_dim=32)
|
||||
loss = adv(X, Y, self.device, input_dim=self.input_dim, hidden_dim=32)
|
||||
elif self.loss_type == "mmd_rbf":
|
||||
mmdloss = MMD_loss(kernel_type="rbf")
|
||||
loss = mmdloss(X, Y)
|
||||
@@ -632,12 +633,12 @@ class Discriminator(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
def adv(source, target, input_dim=256, hidden_dim=512):
|
||||
def adv(source, target, device, input_dim=256, hidden_dim=512):
|
||||
domain_loss = nn.BCELoss()
|
||||
# !!! Pay attention to .cuda !!!
|
||||
adv_net = Discriminator(input_dim, hidden_dim).cuda()
|
||||
domain_src = torch.ones(len(source)).cuda()
|
||||
domain_tar = torch.zeros(len(target)).cuda()
|
||||
adv_net = Discriminator(input_dim, hidden_dim).to(device)
|
||||
domain_src = torch.ones(len(source)).to(device)
|
||||
domain_tar = torch.zeros(len(target)).to(device)
|
||||
domain_src, domain_tar = domain_src.view(domain_src.shape[0], 1), domain_tar.view(domain_tar.shape[0], 1)
|
||||
reverse_src = ReverseLayerF.apply(source, 1)
|
||||
reverse_tar = ReverseLayerF.apply(target, 1)
|
||||
@@ -648,16 +649,16 @@ def adv(source, target, input_dim=256, hidden_dim=512):
|
||||
return loss
|
||||
|
||||
|
||||
def CORAL(source, target):
|
||||
def CORAL(source, target, device):
|
||||
d = source.size(1)
|
||||
ns, nt = source.size(0), target.size(0)
|
||||
|
||||
# source covariance
|
||||
tmp_s = torch.ones((1, ns)).cuda() @ source
|
||||
tmp_s = torch.ones((1, ns)).to(device) @ source
|
||||
cs = (source.t() @ source - (tmp_s.t() @ tmp_s) / ns) / (ns - 1)
|
||||
|
||||
# target covariance
|
||||
tmp_t = torch.ones((1, nt)).cuda() @ target
|
||||
tmp_t = torch.ones((1, nt)).to(device) @ target
|
||||
ct = (target.t() @ target - (tmp_t.t() @ tmp_t) / nt) / (nt - 1)
|
||||
|
||||
# frobenius norm
|
||||
@@ -684,9 +685,9 @@ class MMD_loss(nn.Module):
|
||||
if fix_sigma:
|
||||
bandwidth = fix_sigma
|
||||
else:
|
||||
bandwidth = torch.sum(L2_distance.data) / (n_samples ** 2 - n_samples)
|
||||
bandwidth = torch.sum(L2_distance.data) / (n_samples**2 - n_samples)
|
||||
bandwidth /= kernel_mul ** (kernel_num // 2)
|
||||
bandwidth_list = [bandwidth * (kernel_mul ** i) for i in range(kernel_num)]
|
||||
bandwidth_list = [bandwidth * (kernel_mul**i) for i in range(kernel_num)]
|
||||
kernel_val = [torch.exp(-L2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list]
|
||||
return sum(kernel_val)
|
||||
|
||||
|
||||
@@ -20,7 +20,6 @@ from qlib.contrib.model.pytorch_lstm import LSTMModel
|
||||
from qlib.contrib.model.pytorch_utils import count_parameters
|
||||
from qlib.data.dataset import DatasetH
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
from qlib.data.dataset.processor import CSRankNorm
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.model.base import Model
|
||||
from qlib.utils import get_or_create_path
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Text, Union
|
||||
@@ -150,7 +149,7 @@ class ALSTM(Model):
|
||||
|
||||
mask = torch.isfinite(label)
|
||||
|
||||
if self.metric == "" or self.metric == "loss":
|
||||
if self.metric in ("", "loss"):
|
||||
return -self.loss_fn(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown metric `%s`" % self.metric)
|
||||
@@ -312,8 +311,8 @@ class ALSTMModel(nn.Module):
|
||||
def _build_model(self):
|
||||
try:
|
||||
klass = getattr(nn, self.rnn_type.upper())
|
||||
except:
|
||||
raise ValueError("unknown rnn_type `%s`" % self.rnn_type)
|
||||
except Exception as e:
|
||||
raise ValueError("unknown rnn_type `%s`" % self.rnn_type) from e
|
||||
self.net = nn.Sequential()
|
||||
self.net.add_module("fc_in", nn.Linear(in_features=self.input_size, out_features=self.hid_size))
|
||||
self.net.add_module("act", nn.Tanh())
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user