mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 14:01:28 +08:00

Compare commits

...

37 Commits

Author SHA1 Message Date
Linlang
5a84aaf1dc Update version 2024-12-23 14:28:09 +08:00
Linlang
afbb178e24 Update publish (#1871)
* update publish

* reformat with black
2024-12-23 13:22:24 +08:00
Linlang
a0cef033cb update python version (#1868)
* update python version

* fix: Correct selector handling and add time filtering in storage.py

* fix: convert index and columns to list in repr methods

* feat: Add Makefile for managing project prerequisites

* feat: Add Cython extensions for rolling and expanding operations

* resolve install error

* fix lint error

* fix lint error

* fix lint error

* fix lint error

* fix lint error

* update build package

* update makefile

* update ci yaml

* fix docs build error

* fix ubuntu install error

* fix docs build error

* fix install error

* fix install error

* fix install error

* fix install error

* fix pylint error

* fix pylint error

* fix pylint error

* fix pylint error

* fix pylint error E1123

* fix pylint error R0917

* fix pytest error

* fix pytest error

* fix pytest error

* update code

* update code

* fix ci error

* fix pylint error

* fix black error

* fix pytest error

* fix CI error

* fix CI error

* add python version to CI

* add python version to CI

* add python version to CI

* fix pylint error

* fix pytest general nn error

* fix CI error

* optimize code

* add comments

* Extended macos version

* remove build package

---------

Co-authored-by: Young <afe.young@gmail.com>
2024-12-17 11:30:06 +08:00
you-n-g
7acb4f3484 Fix Async Call (#1869) 2024-12-16 18:32:46 +08:00
YQ Tsui
431f574967 fix duplicate log (#1661)
* fix duplicate log

* fix unit test

* fix log

* fix_duplicate_log

* fix_duplicate_log

* add comments

---------

Co-authored-by: Linlang <Lv.Linlang@hotmail.com>
2024-12-09 15:45:31 +08:00
you-n-g
b604fe56b3 Update README.md 2024-12-05 10:21:37 +08:00
Linlang
af4b8772d2 Saurabh12571257/main (#1866)
* Update README.md

* test macos ci

* test macos ci

* test macos ci

* fix ci error

* fix ci error

---------

Co-authored-by: saurabh dave <87791567+saurabh12571257@users.noreply.github.com>
2024-12-04 16:23:21 +08:00
Di
18fcdf1521 Update requirements.txt (#1829)
Update urllib3 dependency according to https://github.com/advisories/GHSA-34jh-p97f-mpxf
2024-12-04 12:10:05 +08:00
Linlang
f2caf452e9 add dockerfile (#1817)
* add dockerfile

* add execute script

* add docs

* optimize docs

* optimize dockerfile

* optimize docs

* optimize dockerfile

* update code & update README

* doc build error

* update docs

* update code
2024-11-13 11:41:06 +08:00
Xu Yang
ca9f1861a4 Update README.md to show rdagent in qlib front page (#1848)
* update readme

* Update README.md

add english and chinese link to rdagent

* add the logo of rdagent to readme

add the logo of rdagent to readme

* adjust the height of the logo

* improve some words in readme

* add a line
2024-09-12 23:44:27 +08:00
Another
b45b006ef2 Update README.md (#1839)
Update data example to 20240809
2024-08-30 17:01:55 +08:00
Linlang
82cf438401 fix break img (#1842) 2024-08-14 14:59:28 +08:00
you-n-g
9e635168c0 Update README.md 2024-08-09 20:23:13 +08:00
you-n-g
b7ace1a622 🔥LLM-driven Auto Quant Factory🔥 (#1840)
* Update README.md

* Update README.md
2024-08-09 20:14:58 +08:00
cyncyw
c9ed050ef0 Ptnn4both datatypes and alignment tests (#1827)
* Init model for both dataset

* Remove some deprecated code

* Add model template;

* We must align with previous results

* We choose another mode as the initial version

* Almost success to run GRU

* Successfully run training

* Passed general_nn test

* gru test

* Alignment test passed

* comment

* fix readme & minor errors

* general nn updates & benchmarks

* Update examples/benchmarks/GeneralPtNN/workflow_config_gru2mlp.yaml

---------

Co-authored-by: Young <afe.young@gmail.com>
Co-authored-by: you-n-g <you-n-g@users.noreply.github.com>
2024-07-11 17:59:18 +08:00
Linlang
2c33332dd6 More dataloader example (#1823)
* More dataloader example

* optimize code

* optimize code

* optimize code

* optimize code

* optimize code

* optimize code

* fix pylint error

* fix CI error

* fix CI error

* Comments

* fix error type

---------

Co-authored-by: Young <afe.young@gmail.com>
2024-07-10 14:48:44 +08:00
you-n-g
a7d5a9b500 Nested data loader (#1822)
* nested data loader

* Amend

* add data loader test

* fix pylint error

* fix pytest error

* fix pytest error

* delete comments

* Update qlib/contrib/data/handler.py

---------

Co-authored-by: Linlang <Lv.Linlang@hotmail.com>
2024-07-05 15:44:16 +08:00
you-n-g
5190332c7e Add some misc features. (#1816)
* Normal mod

* Black linting

* Linting
2024-06-26 18:34:00 +08:00
cyncyw
cde80206e4 Update index_data.py for datatype conversion and alignment (#1813)
* Update index_data.py for data conversion and alignment

* Update qlib/utils/index_data.py

* Update qlib/utils/index_data.py

* fix linting

---------

Co-authored-by: taozhiwang <taozhiwa@gmail.com>
Co-authored-by: you-n-g <you-n-g@users.noreply.github.com>
2024-06-24 15:34:48 +08:00
cyncyw
a339fc11d1 add a note for code standard (#1814)
* add a note for code standard

* handle both cases

---------

Co-authored-by: taozhiwang <taozhiwa@gmail.com>
2024-06-24 15:33:45 +08:00
Linlang
33482047dc change weight data download url (#1812) 2024-06-21 13:05:53 +08:00
Fivele-Li
47bd13295b Fix Yahoo daily data format inconsistent (#1517)
* Fix FutureWarning: Passing unit-less datetime64 dtype to .astype is deprecated and will raise in a future version. Pass 'datetime64[ns]' instead

* align index format while end date contains current day data

* fix black

* fix black

* optimize code

* optimize code

* optimize code

* fix ci error

* check ci error

* fix ci error

* check ci error

* check ci error

* check ci error

* check ci error

* check ci error

* check ci error

* fix ci error

* fix ci error

* fix ci error

* fix ci error

* fix ci error

---------

Co-authored-by: Cadenza-Li <362237642@qq.com>
Co-authored-by: Linlang <Lv.Linlang@hotmail.com>
2024-06-21 11:22:23 +08:00
陈屹华
ebc0ca893e Fix TSDataSampler Slicing Bug #1716 (#1803)
* Fix TSDataSampler Slicing Bug #1716

* Fix TSDataSampler Slicing Bug #1716

* Fix TSDataSampler Slicing Bug #1716

* Fix TSDataSampler Slicing Bug #1716 with Simplified Implementation

* Refactor: Fix CI errors by addressing pylint formatting issues

* Refactor: Remove extraneous whitespace for improved code formatting with Black
2024-06-21 09:25:23 +08:00
Lee Yuntong
3a348aec9f Fix typo (#1811)
Co-authored-by: LeeYuntong <nukuihayu@outlook.com>
2024-06-20 18:12:07 +08:00
Lee Yuntong
37b908792b Fix typo (#1809)
Co-authored-by: LeeYuntong <nukuihayu@outlook.com>
2024-06-19 17:31:57 +08:00
raikiriww
73ec0f4003 Add "mse" metric option to ALSTM.metric_fn (#1810) 2024-06-19 17:31:47 +08:00
Linlang
155c17f8ff fix logo display error (#1804) 2024-06-06 13:39:49 +08:00
Yang
41b94059aa fix panic during normalizing the invalid data (#1698)
* fix panic during normalizing the invalid data

* fix yaml load

* change error to warning

* change error code

* optimize code

---------

Co-authored-by: Linlang <Lv.Linlang@hotmail.com>
2024-06-02 06:54:39 +08:00
block-gpt
7db83d84b7 Update utils.py for typo (#1751)
Fix typo

Co-authored-by: Linlang <Lv.Linlang@hotmail.com>
2024-06-01 19:33:23 +08:00
Hao Zhao
35e0fdd1c0 fix the bug that the HS_SYMBOLS_URL is 404 (#1758)
* fix the bug that the HS_SYMBOLS_URL is 404

* fix bug

* format with black

* fix pylint error

* change error code

* fix ci error

* fix ci error

* optimize code

* optimize code

* add comments

---------

Co-authored-by: Linlang <Lv.Linlang@hotmail.com>
2024-06-01 08:07:34 +08:00
you-n-g
598017f634 Update Dev in README.md (#1800) 2024-05-29 17:44:18 +08:00
igeni
907c888c23 changed concat of strings to f-strings and redundant type conversion was removed (#1767)
Co-authored-by: Linlang <Lv.Linlang@hotmail.com>
2024-05-28 12:13:12 +08:00
Linlang
02fe6b6974 bump version 2024-05-24 16:38:48 +08:00
Linlang
b892b21045 update version 2024-05-24 15:14:49 +08:00
Linlang
155f80323c fix get data error (#1793)
* fix get data error

* fix get v0 data error

* optimize get_data code

* fix pylint error

* add comments
2024-05-24 12:59:50 +08:00
you-n-g
63021018d6 Update README.md's dataset 2024-05-21 08:15:18 +08:00
Linlang
f79a0eeaff fix docs (#1788)
Co-authored-by: Linlang Lv (iSoftStone Information) <v-lvlinlang@microsoft.com>
2024-05-21 04:23:55 +08:00
112 changed files with 2449 additions and 872 deletions

8
.dockerignore Normal file
View File

@@ -0,0 +1,8 @@
__pycache__
*.pyc
*.pyo
*.pyd
.Python
.env
.git

View File

@@ -12,70 +12,54 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [windows-latest, macos-11]
# FIXME: macos-latest will raise error now.
# not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129
python-version: [3.7, 3.8]
os: [windows-latest, macos-13, macos-latest]
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
exclude:
- os: macos-13
python-version: "3.11"
- os: macos-13
python-version: "3.12"
steps:
- uses: actions/checkout@v2
# This is because on macos systems you can install pyqlib using
# `pip install pyqlib` installs, it does not recognize the
# `pyqlib-<version>-cp38-cp38-macosx_11_0_x86_64.whl` and `pyqlib-<veresion>-cp38-cp37m-macosx_11_0_x86_64.whl`.
# So we limit the version of python, in order to generate a version of qlib that is usable for macos: `pyqlib-<veresion>-cp38-cp37m
# `pyqlib-<version>-cp38-cp38-macosx_10_15_x86_64.whl` and `pyqlib-<veresion>-cp38-cp37m-macosx_10_15_x86_64.whl`.
# Python 3.7.16, 3.8.16 can build macosx_10_15. But Python 3.7.17, 3.8.17 can build macosx_11_0
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
if: matrix.os == 'macos-11' && matrix.python-version == '3.7'
uses: actions/setup-python@v2
with:
python-version: "3.7.16"
- name: Set up Python ${{ matrix.python-version }}
if: matrix.os == 'macos-11' && matrix.python-version == '3.8'
uses: actions/setup-python@v2
with:
python-version: "3.8.16"
- name: Set up Python ${{ matrix.python-version }}
if: matrix.os != 'macos-11'
uses: actions/setup-python@v2
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install setuptools wheel twine
make dev
- name: Build wheel on ${{ matrix.os }}
run: |
pip install numpy
pip install cython
python setup.py bdist_wheel
- name: Build and publish
make build
- name: Upload to PyPi
env:
TWINE_USERNAME: __token__
TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }}
run: |
twine upload dist/*
twine check dist/*.whl
twine upload dist/*.whl --verbose
deploy_with_manylinux:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Build wheel on Linux
uses: RalfG/python-wheels-manylinux-build@v0.3.1-manylinux2010_x86_64
uses: RalfG/python-wheels-manylinux-build@v0.7.1-manylinux2014_x86_64
with:
# not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129
python-versions: 'cp37-cp37m cp38-cp38'
python-versions: 'cp38-cp38 cp39-cp39 cp310-cp310 cp311-cp311 cp312-cp312'
build-requirements: 'numpy cython'
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: 3.7
- name: Install dependencies
run: |
pip install twine
- name: Build and publish
python -m pip install twine
- name: Upload to PyPi
env:
TWINE_USERNAME: __token__
TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }}
run: |
twine upload dist/pyqlib-*-manylinux*.whl
twine check dist/pyqlib-*-manylinux*.whl
twine upload dist/pyqlib-*-manylinux*.whl --verbose
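
The updated publish matrix combines three runners with five Python versions and then excludes two macos-13 combinations. The following is a minimal Python sketch of that expansion, assuming GitHub Actions drops every combination that matches an `exclude` entry. The helper name `expand_matrix` is hypothetical and is not part of the workflow; the values are copied from the matrix above.

```python
from itertools import product

# Values copied from the updated publish workflow matrix.
OS_LIST = ["windows-latest", "macos-13", "macos-latest"]
PYTHON_VERSIONS = ["3.8", "3.9", "3.10", "3.11", "3.12"]
EXCLUDE = [
    {"os": "macos-13", "python-version": "3.11"},
    {"os": "macos-13", "python-version": "3.12"},
]


def expand_matrix(os_list, python_versions, exclude):
    """Return the (os, python-version) jobs left after applying exclusions.

    Hypothetical helper for illustration only.
    """
    excluded = {(entry["os"], entry["python-version"]) for entry in exclude}
    return [combo for combo in product(os_list, python_versions) if combo not in excluded]


jobs = expand_matrix(OS_LIST, PYTHON_VERSIONS, EXCLUDE)
print(len(jobs))  # 3 * 5 - 2 = 13 wheel-building jobs
```

Under these assumptions the Windows and macOS publish job fans out into 13 wheel builds, with macos-13 covering only Python 3.8 through 3.10.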

View File

@@ -13,28 +13,17 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
# Since macos-latest changed from 12.7.4 to 14.4.1,
# the minimum python version that matches a 14.4.1 version of macos is 3.10,
# so we limit the macos version to macos-12.
os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-11, macos-12]
# not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129
python-version: [3.7, 3.8]
os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-13, macos-14, macos-15]
# In github action, using python 3.7, pip install will not match the latest version of the package.
# Also, python 3.7 is no longer supported from macos-14, and will be phased out from macos-13 in the near future.
# All things considered, we have removed python 3.7.
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
steps:
- name: Test qlib from pip
uses: actions/checkout@v3
# Since version 3.7 of python for MacOS is installed in CI, version 3.7.17, this version causes "_bz not found error".
# So we make the version number of python 3.7 for MacOS more specific.
# refs: https://github.com/actions/setup-python/issues/682
- name: Set up Python ${{ matrix.python-version }}
if: (matrix.os == 'macos-latest' && matrix.python-version == '3.7') || (matrix.os == 'macos-11' && matrix.python-version == '3.7')
uses: actions/setup-python@v4
with:
python-version: "3.7.16"
- name: Set up Python ${{ matrix.python-version }}
if: (matrix.os != 'macos-latest' || matrix.python-version != '3.7') && (matrix.os != 'macos-11' || matrix.python-version != '3.7')
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
@@ -45,10 +34,10 @@ jobs:
- name: Qlib installation test
run: |
python -m pip install pyqlib
python -m pip install --index-url https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple/ pyqlib==0.9.5.80
- name: Install Lightgbm for MacOS
if: ${{ matrix.os == 'macos-11' || matrix.os == 'macos-latest' }}
if: ${{ matrix.os == 'macos-13' || matrix.os == 'macos-14' || matrix.os == 'macos-15' }}
run: |
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Microsoft/qlib/main/.github/brew_install.sh)"
HOMEBREW_NO_AUTO_UPDATE=1 brew install lightgbm

View File

@@ -14,28 +14,17 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
# Since macos-latest changed from 12.7.4 to 14.4.1,
# the minimum python version that matches a 14.4.1 version of macos is 3.10,
# so we limit the macos version to macos-12.
os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-11, macos-12]
# not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129
python-version: [3.7, 3.8]
os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-13, macos-14, macos-15]
# In github action, using python 3.7, pip install will not match the latest version of the package.
# Also, python 3.7 is no longer supported from macos-14, and will be phased out from macos-13 in the near future.
# All things considered, we have removed python 3.7.
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
steps:
- name: Test qlib from source
uses: actions/checkout@v3
# Since version 3.7 of python for MacOS is installed in CI, version 3.7.17, this version causes "_bz not found error".
# So we make the version number of python 3.7 for MacOS more specific.
# refs: https://github.com/actions/setup-python/issues/682
- name: Set up Python ${{ matrix.python-version }}
if: (matrix.os == 'macos-latest' && matrix.python-version == '3.7') || (matrix.os == 'macos-11' && matrix.python-version == '3.7')
uses: actions/setup-python@v4
with:
python-version: "3.7.16"
- name: Set up Python ${{ matrix.python-version }}
if: (matrix.os != 'macos-latest' || matrix.python-version != '3.7') && (matrix.os != 'macos-11' || matrix.python-version != '3.7')
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
@@ -45,7 +34,7 @@ jobs:
python -m pip install --upgrade pip
- name: Installing pytorch for macos
if: ${{ matrix.os == 'macos-11' || matrix.os == 'macos-latest' }}
if: ${{ matrix.os == 'macos-13' || matrix.os == 'macos-14' || matrix.os == 'macos-15' }}
run: |
python -m pip install torch torchvision torchaudio
@@ -61,85 +50,33 @@ jobs:
- name: Set up Python tools
run: |
python -m pip install --upgrade cython
python -m pip install -e .[dev]
make dev
- name: Lint with Black
# Python 3.7 will use a black with low level. So we use python with higher version for black check
if: (matrix.python-version != '3.7')
run: |
pip install -U black # follow the latest version of black, previous Qlib dependency will downgrade black
black . -l 120 --check --diff
make black
- name: Make html with sphinx
# Since read the docs builds on ubuntu 22.04, we only need to test that the build passes on ubuntu 22.04.
if: ${{ matrix.os == 'ubuntu-22.04' }}
run: |
cd docs
sphinx-build -W --keep-going -b html . _build
cd ..
make docs-gen
# Check Qlib with pylint
# TODO: These problems we will solve in the future. Important among them are: W0221, W0223, W0237, E1102
# C0103: invalid-name
# C0209: consider-using-f-string
# R0402: consider-using-from-import
# R1705: no-else-return
# R1710: inconsistent-return-statements
# R1725: super-with-arguments
# R1735: use-dict-literal
# W0102: dangerous-default-value
# W0212: protected-access
# W0221: arguments-differ
# W0223: abstract-method
# W0231: super-init-not-called
# W0237: arguments-renamed
# W0612: unused-variable
# W0621: redefined-outer-name
# W0622: redefined-builtin
# FIXME: specify exception type
# W0703: broad-except
# W1309: f-string-without-interpolation
# E1102: not-callable
# E1136: unsubscriptable-object
# References for parameters: https://github.com/PyCQA/pylint/issues/4577#issuecomment-1000245962
# We use sys.setrecursionlimit(2000) to make the recursion depth larger to ensure that pylint works properly (the default recursion depth is 1000).
- name: Check Qlib with pylint
run: |
pylint --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0612,W0621,W0622,W0703,W1309,E1102,E1136 --const-rgx='[a-z_][a-z0-9_]{2,30}$' qlib --init-hook "import astroid; astroid.context.InferenceContext.max_inferred = 500; import sys; sys.setrecursionlimit(2000)"
pylint --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0246,W0612,W0621,W0622,W0703,W1309,E1102,E1136 --const-rgx='[a-z_][a-z0-9_]{2,30}$' scripts --init-hook "import astroid; astroid.context.InferenceContext.max_inferred = 500; import sys; sys.setrecursionlimit(2000)"
make pylint
# The following flake8 error codes were ignored:
# E501 line too long
# Description: We have used black to limit the length of each line to 120.
# F541 f-string is missing placeholders
# Description: The same thing is done when using pylint for detection.
# E266 too many leading '#' for block comment
# Description: To make the code more readable, a lot of "#" is used.
# This error code appears centrally in:
# qlib/backtest/executor.py
# qlib/data/ops.py
# qlib/utils/__init__.py
# E402 module level import not at top of file
# Description: There are times when module level import is not available at the top of the file.
# W503 line break before binary operator
# Description: Since black formats the length of each line of code, it has to perform a line break when a line of arithmetic is too long.
# E731 do not assign a lambda expression, use a def
# Description: Restricts the use of lambda expressions, but at some point lambda expressions are required.
# E203 whitespace before ':'
# Description: If there is whitespace before ":", it cannot pass the black check.
- name: Check Qlib with flake8
run: |
flake8 --ignore=E501,F541,E266,E402,W503,E731,E203 --per-file-ignores="__init__.py:F401,F403" qlib
make flake8
# https://github.com/python/mypy/issues/10600
- name: Check Qlib with mypy
run: |
mypy qlib --install-types --non-interactive || true
mypy qlib --verbose
make mypy
- name: Check Qlib ipynb with nbqa
run: |
nbqa black . -l 120 --check --diff
nbqa pylint . --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0612,W0621,W0622,W0703,W1309,E1102,E1136,W0719,W0104,W0404,C0412,W0611,C0410 --const-rgx='[a-z_][a-z0-9_]{2,30}$'
make nbqa
- name: Test data downloads
run: |
@@ -147,7 +84,7 @@ jobs:
python scripts/get_data.py download_data --file_name rl_data.zip --target_dir tests/.data/rl
- name: Install Lightgbm for MacOS
if: ${{ matrix.os == 'macos-11' || matrix.os == 'macos-latest' }}
if: ${{ matrix.os == 'macos-13' || matrix.os == 'macos-14' || matrix.os == 'macos-15' }}
run: |
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Microsoft/qlib/main/.github/brew_install.sh)"
HOMEBREW_NO_AUTO_UPDATE=1 brew install lightgbm
@@ -157,11 +94,9 @@ jobs:
brew unlink libomp
brew install libomp.rb
# Run after data downloads
- name: Check Qlib ipynb with nbconvert
run: |
# add more ipynb files in future
jupyter nbconvert --to notebook --execute examples/workflow_by_code.ipynb
make nbconvert
- name: Test workflow by config (install from source)
run: |

View File

@@ -14,44 +14,31 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
# Since macos-latest changed from 12.7.4 to 14.4.1,
# the minimum python version that matches a 14.4.1 version of macos is 3.10,
# so we limit the macos version to macos-12.
os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-11, macos-12]
# not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129
python-version: [3.7, 3.8]
os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-13, macos-14, macos-15]
# In github action, using python 3.7, pip install will not match the latest version of the package.
# Also, python 3.7 is no longer supported from macos-14, and will be phased out from macos-13 in the near future.
# All things considered, we have removed python 3.7.
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
steps:
- name: Test qlib from source slow
uses: actions/checkout@v3
# Since version 3.7 of python for MacOS is installed in CI, version 3.7.17, this version causes "_bz not found error".
# So we make the version number of python 3.7 for MacOS more specific.
# refs: https://github.com/actions/setup-python/issues/682
- name: Set up Python ${{ matrix.python-version }}
if: (matrix.os == 'macos-latest' && matrix.python-version == '3.7') || (matrix.os == 'macos-11' && matrix.python-version == '3.7')
uses: actions/setup-python@v4
with:
python-version: "3.7.16"
- name: Set up Python ${{ matrix.python-version }}
if: (matrix.os != 'macos-latest' || matrix.python-version != '3.7') && (matrix.os != 'macos-11' || matrix.python-version != '3.7')
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Set up Python tools
run: |
python -m pip install --upgrade pip
pip install --upgrade cython numpy
pip install -e .[dev]
make dev
- name: Downloads dependencies data
run: |
python scripts/get_data.py qlib_data --name qlib_data_simple --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
- name: Install Lightgbm for MacOS
if: ${{ matrix.os == 'macos-11' || matrix.os == 'macos-latest' }}
if: ${{ matrix.os == 'macos-13' || matrix.os == 'macos-14' || matrix.os == 'macos-15' }}
run: |
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Microsoft/qlib/main/.github/brew_install.sh)"
HOMEBREW_NO_AUTO_UPDATE=1 brew install lightgbm

3
.gitignore vendored
View File

@@ -48,4 +48,5 @@ tags
*.swp
./pretrain
.idea/
.idea/
.aider*

View File

@@ -9,7 +9,7 @@ version: 2
build:
os: ubuntu-22.04
tools:
python: "3.7"
python: "3.8"
# Build documentation in the docs/ directory with Sphinx
sphinx:

31
Dockerfile Normal file
View File

@@ -0,0 +1,31 @@
FROM continuumio/miniconda3:latest
WORKDIR /qlib
COPY . .
RUN apt-get update && \
apt-get install -y build-essential
RUN conda create --name qlib_env python=3.8 -y
RUN echo "conda activate qlib_env" >> ~/.bashrc
ENV PATH /opt/conda/envs/qlib_env/bin:$PATH
RUN python -m pip install --upgrade pip
RUN python -m pip install numpy==1.23.5
RUN python -m pip install pandas==1.5.3
RUN python -m pip install importlib-metadata==5.2.0
RUN python -m pip install "cloudpickle<3"
RUN python -m pip install scikit-learn==1.3.2
RUN python -m pip install cython packaging tables matplotlib statsmodels
RUN python -m pip install pybind11 cvxpy
ARG IS_STABLE="yes"
RUN if [ "$IS_STABLE" = "yes" ]; then \
python -m pip install pyqlib; \
else \
python setup.py install; \
fi
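
The Dockerfile above chooses between installing the released `pyqlib` package and installing from the copied source tree based on the `IS_STABLE` build argument. The sketch below shows one way to assemble the matching `docker build` command line in Python. The helper `docker_build_command` and the image tag `qlib_image` are assumptions made for this example; only `--build-arg` and `-t` are standard `docker build` options.

```python
import shlex


def docker_build_command(stable: bool, tag: str = "qlib_image") -> list:
    """Build the argv for `docker build`, toggling the Dockerfile's IS_STABLE arg.

    `tag` is a hypothetical image name chosen for this example.
    """
    is_stable = "yes" if stable else "no"
    return ["docker", "build", "--build-arg", f"IS_STABLE={is_stable}", "-t", tag, "."]


# Stable build: the Dockerfile runs `python -m pip install pyqlib`.
print(shlex.join(docker_build_command(stable=True)))
# Source build: the Dockerfile runs `python setup.py install` on the copied checkout.
print(shlex.join(docker_build_command(stable=False)))
```

Because `IS_STABLE` defaults to `"yes"`, omitting the build argument gives the same result as the first command.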

View File

@@ -1 +1,6 @@
include qlib/VERSION.txt
exclude tests/*
include qlib/*
include qlib/*/*
include qlib/*/*/*
include qlib/*/*/*/*
include qlib/*/*/*/*/*

195
Makefile Normal file
View File

@@ -0,0 +1,195 @@
.PHONY: clean deepclean prerequisite dependencies lightgbm rl develop lint docs package test analysis all install dev black pylint flake8 mypy nbqa nbconvert lint build upload docs-gen
# You can modify the shell according to your terminal.
SHELL := /bin/bash
########################################################################################
# Variables
########################################################################################
# Documentation target directory, will be adapted to specific folder for readthedocs.
PUBLIC_DIR := $(shell [ "$$READTHEDOCS" = "True" ] && echo "$$READTHEDOCS_OUTPUT/html" || echo "public")
SO_DIR := qlib/data/_libs
SO_FILES := $(wildcard $(SO_DIR)/*.so)
########################################################################################
# Development Environment Management
########################################################################################
# Remove common intermediate files.
clean:
-rm -rf \
$(PUBLIC_DIR) \
qlib/data/_libs/*.cpp \
qlib/data/_libs/*.so \
mlruns \
public \
build \
.coverage \
.mypy_cache \
.pytest_cache \
.ruff_cache \
Pipfile* \
coverage.xml \
dist \
release-notes.md
find . -name '*.egg-info' -print0 | xargs -0 rm -rf
find . -name '*.pyc' -print0 | xargs -0 rm -f
find . -name '*.swp' -print0 | xargs -0 rm -f
find . -name '.DS_Store' -print0 | xargs -0 rm -f
find . -name '__pycache__' -print0 | xargs -0 rm -rf
# Remove pre-commit hook, virtual environment alongside intermediate files.
deepclean: clean
if command -v pre-commit > /dev/null 2>&1; then pre-commit uninstall --hook-type pre-push; fi
if command -v pipenv >/dev/null 2>&1 && pipenv --venv >/dev/null 2>&1; then pipenv --rm; fi
# Prerequisite section
# This target compiles two Cython modules, rolling and expanding, using setuptools and Cython,
# and builds them as binary extension modules that can be imported directly into Python.
# Since pyproject.toml can't do that, we compile them here.
prerequisite:
@if [ -n "$(SO_FILES)" ]; then \
echo "Shared library files exist, skipping build."; \
else \
echo "No shared library files found, building..."; \
pip install --upgrade setuptools wheel; \
python -m pip install cython numpy; \
python -c "from setuptools import setup, Extension; from Cython.Build import cythonize; import numpy; extensions = [Extension('qlib.data._libs.rolling', ['qlib/data/_libs/rolling.pyx'], language='c++', include_dirs=[numpy.get_include()]), Extension('qlib.data._libs.expanding', ['qlib/data/_libs/expanding.pyx'], language='c++', include_dirs=[numpy.get_include()])]; setup(ext_modules=cythonize(extensions, language_level='3'), script_args=['build_ext', '--inplace'])"; \
fi
# Install the package in editable mode.
dependencies:
python -m pip install -e .
lightgbm:
python -m pip install lightgbm --prefer-binary
rl:
python -m pip install -e .[rl]
develop:
python -m pip install -e .[dev]
lint:
python -m pip install -e .[lint]
docs:
python -m pip install -e .[docs]
package:
python -m pip install -e .[package]
test:
python -m pip install -e .[test]
analysis:
python -m pip install -e .[analysis]
all:
python -m pip install -e .[dev,lint,docs,package,test,analysis,rl]
install: prerequisite dependencies
dev: prerequisite all
########################################################################################
# Lint and pre-commit
########################################################################################
# Check lint with black.
black:
black . -l 120 --check --diff
# Check code folder with pylint.
# TODO: These problems we will solve in the future. Important among them are: W0221, W0223, W0237, E1102
# C0103: invalid-name
# C0209: consider-using-f-string
# R0402: consider-using-from-import
# R1705: no-else-return
# R1710: inconsistent-return-statements
# R1725: super-with-arguments
# R1735: use-dict-literal
# W0102: dangerous-default-value
# W0212: protected-access
# W0221: arguments-differ
# W0223: abstract-method
# W0231: super-init-not-called
# W0237: arguments-renamed
# W0612: unused-variable
# W0621: redefined-outer-name
# W0622: redefined-builtin
# FIXME: specify exception type
# W0703: broad-except
# W1309: f-string-without-interpolation
# E1102: not-callable
# E1136: unsubscriptable-object
# W4904: deprecated-class
# R0917: too-many-positional-arguments
# E1123: unexpected-keyword-arg
# References for disable error: https://pylint.pycqa.org/en/latest/user_guide/messages/messages_overview.html
# We use sys.setrecursionlimit(2000) to make the recursion depth larger to ensure that pylint works properly (the default recursion depth is 1000).
# References for parameters: https://github.com/PyCQA/pylint/issues/4577#issuecomment-1000245962
pylint:
pylint --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R0917,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,W4904,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1730,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0612,W0621,W0622,W0703,W1309,E1102,E1136 --const-rgx='[a-z_][a-z0-9_]{2,30}' qlib --init-hook="import astroid; astroid.context.InferenceContext.max_inferred = 500; import sys; sys.setrecursionlimit(2000)"
pylint --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R0917,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,E1123,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0246,W0612,W0621,W0622,W0703,W1309,E1102,E1136 --const-rgx='[a-z_][a-z0-9_]{2,30}' scripts --init-hook="import astroid; astroid.context.InferenceContext.max_inferred = 500; import sys; sys.setrecursionlimit(2000)"
# Check code with flake8.
# The following flake8 error codes were ignored:
# E501 line too long
# Description: We have used black to limit the length of each line to 120.
# F541 f-string is missing placeholders
# Description: The same thing is done when using pylint for detection.
# E266 too many leading '#' for block comment
# Description: To make the code more readable, a lot of "#" is used.
# This error code appears centrally in:
# qlib/backtest/executor.py
# qlib/data/ops.py
# qlib/utils/__init__.py
# E402 module level import not at top of file
# Description: There are times when module level import is not available at the top of the file.
# W503 line break before binary operator
# Description: Since black formats the length of each line of code, it has to perform a line break when a line of arithmetic is too long.
# E731 do not assign a lambda expression, use a def
# Description: Restricts the use of lambda expressions, but at some point lambda expressions are required.
# E203 whitespace before ':'
# Description: If there is whitespace before ":", it cannot pass the black check.
flake8:
flake8 --ignore=E501,F541,E266,E402,W503,E731,E203 --per-file-ignores="__init__.py:F401,F403" qlib
# Check code with mypy.
# https://github.com/python/mypy/issues/10600
mypy:
mypy qlib --install-types --non-interactive
mypy qlib --verbose
# Check ipynb with nbqa.
nbqa:
nbqa black . -l 120 --check --diff
nbqa pylint . --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0612,W0621,W0622,W0703,W1309,E1102,E1136,W0719,W0104,W0404,C0412,W0611,C0410 --const-rgx='[a-z_][a-z0-9_]{2,30}'
# Check ipynb with nbconvert.(Run after data downloads)
# TODO: Add more ipynb files in future
nbconvert:
jupyter nbconvert --to notebook --execute examples/workflow_by_code.ipynb
lint: black pylint flake8 mypy nbqa
########################################################################################
# Package
########################################################################################
# Build the package.
build:
python -m build --wheel
# Upload the package.
upload:
python -m twine upload dist/*
########################################################################################
# Documentation
########################################################################################
docs-gen:
python -m sphinx.cmd.build -W docs $(PUBLIC_DIR)

View File

@@ -8,9 +8,30 @@
[![Join the chat at https://gitter.im/Microsoft/qlib](https://badges.gitter.im/Microsoft/qlib.svg)](https://gitter.im/Microsoft/qlib?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
## :newspaper: **What's NEW!** &nbsp; :sparkling_heart:
Recent released features
### Introducing <a href="https://github.com/microsoft/RD-Agent"><img src="docs/_static/img/rdagent_logo.png" alt="RD_Agent" style="height: 2em"></a>: LLM-Based Autonomous Evolving Agents for Industrial Data-Driven R&D
We are excited to announce the release of **RD-Agent**📢, a powerful tool that supports automated factor mining and model optimization in quant investment R&D.
RD-Agent is now available on [GitHub](https://github.com/microsoft/RD-Agent), and we welcome your star🌟!
To learn more, please visit our [Demo page](https://rdagent.azurewebsites.net/). Here, you will find demo videos in both English and Chinese to help you better understand the scenario and usage of RD-Agent.
We have prepared several demo videos for you:
| Scenario | Demo video (English) | Demo video (中文) |
| -- | ------ | ------ |
| Quant Factor Mining | [Link](https://rdagent.azurewebsites.net/factor_loop?lang=en) | [Link](https://rdagent.azurewebsites.net/factor_loop?lang=zh) |
| Quant Factor Mining from reports | [Link](https://rdagent.azurewebsites.net/report_factor?lang=en) | [Link](https://rdagent.azurewebsites.net/report_factor?lang=zh) |
| Quant Model Optimization | [Link](https://rdagent.azurewebsites.net/model_loop?lang=en) | [Link](https://rdagent.azurewebsites.net/model_loop?lang=zh) |
***
| Feature | Status |
| -- | ------ |
| BPQP for End-to-end learning | 📈Coming soon!([Under review](https://github.com/microsoft/qlib/pull/1863)) |
| 🔥LLM-driven Auto Quant Factory🔥 | 🚀 Released in [RD-Agent](https://github.com/microsoft/RD-Agent) on Aug 8, 2024 |
| KRNN and Sandwich models | :chart_with_upwards_trend: [Released](https://github.com/microsoft/qlib/pull/1414/) on May 26, 2023 |
| Release Qlib v0.9.0 | :octocat: [Released](https://github.com/microsoft/qlib/releases/tag/v0.9.0) on Dec 9, 2022 |
| RL Learning Framework | :hammer: :chart_with_upwards_trend: Released on Nov 10, 2022. [#1332](https://github.com/microsoft/qlib/pull/1332), [#1322](https://github.com/microsoft/qlib/pull/1322), [#1316](https://github.com/microsoft/qlib/pull/1316),[#1299](https://github.com/microsoft/qlib/pull/1299),[#1263](https://github.com/microsoft/qlib/pull/1263), [#1244](https://github.com/microsoft/qlib/pull/1244), [#1169](https://github.com/microsoft/qlib/pull/1169), [#1125](https://github.com/microsoft/qlib/pull/1125), [#1076](https://github.com/microsoft/qlib/pull/1076)|
@@ -40,7 +61,7 @@ Recent released features
Features released before 2021 are not listed here.
<p align="center">
<img src="http://fintech.msra.cn/images_v070/logo/1.png" />
<img src="docs/_static/img/logo/1.png" />
</p>
Qlib is an open-source, AI-oriented quantitative investment platform that aims to realize the potential, empower research, and create value using AI technologies in quantitative investment, from exploring ideas to implementing productions. Qlib supports diverse machine learning modeling paradigms, including supervised learning, market dynamics modeling, and reinforcement learning.
@@ -132,11 +153,11 @@ Here is a quick **[demo](https://terminalizer.com/view/3f24561a4470)** shows how
## Installation
This table demonstrates the supported Python version of `Qlib`:
| | install with pip | install from source | plot |
| ------------- |:---------------------:|:--------------------:|:----:|
| | install with pip | install from source | plot |
| ------------- |:---------------------:|:--------------------:|:------------------:|
| Python 3.7 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| Python 3.8 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| Python 3.9 | :x: | :heavy_check_mark: | :x: |
| Python 3.9 | :x: | :heavy_check_mark: | :x: |
**Note**:
1. **Conda** is suggested for managing your Python environment. In some cases, using Python outside of a `conda` environment may result in missing header files, causing the installation failure of certain packages.
@@ -166,7 +187,7 @@ Also, users can install the latest dev version ``Qlib`` by the source code accor
* Clone the repository and install ``Qlib`` as follows.
```bash
git clone https://github.com/microsoft/qlib.git && cd qlib
pip install .
pip install . # `pip install -e .[dev]` is recommended for development. check details in docs/developer/code_standard_and_dev_guide.rst
```
**Note**: You can install Qlib with `python setup.py install` as well. But it is not the recommended approach. It will skip `pip` and cause obscure problems. For example, **only** the command ``pip install .`` **can** overwrite the stable version installed by ``pip install pyqlib``, while the command ``python setup.py install`` **can't**.
@@ -175,6 +196,20 @@ Also, users can install the latest dev version ``Qlib`` by the source code accor
**Tips for Mac**: If you are using Mac with M1, you might encounter issues in building the wheel for LightGBM, which is due to missing dependencies from OpenMP. To solve the problem, install openmp first with ``brew install libomp`` and then run ``pip install .`` to build it successfully.
## Data Preparation
❗ Due to a more restrictive data security policy, the official dataset is temporarily disabled. You can try [this data source](https://github.com/chenditc/investment_data/releases) contributed by the community.
Here is an example to download the data updated on 20240809.
```bash
wget https://github.com/chenditc/investment_data/releases/download/2024-08-09/qlib_bin.tar.gz
mkdir -p ~/.qlib/qlib_data/cn_data
tar -zxvf qlib_bin.tar.gz -C ~/.qlib/qlib_data/cn_data --strip-components=1
rm -f qlib_bin.tar.gz
```
The official dataset below will be restored in the near future.
----
Load and prepare data by running the following code:
### Get with module
@@ -258,6 +293,38 @@ We recommend users to prepare their own data if they have a high-quality dataset
```
-->
## Docker images
1. Pull a Docker image from the Docker Hub repository
```bash
docker pull pyqlib/qlib_image_stable:stable
```
2. Start a new Docker container
```bash
docker run -it --name <container name> -v <Mounted local directory>:/app pyqlib/qlib_image_stable:stable
```
3. At this point you are in the docker environment and can run the qlib scripts. An example:
```bash
>>> python scripts/get_data.py qlib_data --name qlib_data_simple --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
>>> python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
```
4. Exit the container
```bash
>>> exit
```
5. Restart the container
```bash
docker start -i -a <container name>
```
6. Stop the container
```bash
docker stop <container name>
```
7. Delete the container
```bash
docker rm <container name>
```
8. If you want to know more information, please refer to the [documentation](https://qlib.readthedocs.io/en/latest/developer/how_to_build_image.html).
## Auto Quant Research Workflow
Qlib provides a tool named `qrun` to run the whole workflow automatically (including building dataset, training models, backtest and evaluation). You can start an auto quant research workflow and have a graphical reports analysis according to the following steps:
@@ -291,22 +358,22 @@ Qlib provides a tool named `qrun` to run the whole workflow automatically (inclu
```
Here are detailed documents for `qrun` and [workflow](https://qlib.readthedocs.io/en/latest/component/workflow.html).
2. Graphical Reports Analysis: Run `examples/workflow_by_code.ipynb` with `jupyter notebook` to get graphical reports
2. Graphical Reports Analysis: First, run `python -m pip install .[analysis]` to install the required dependencies. Then run `examples/workflow_by_code.ipynb` with `jupyter notebook` to get graphical reports.
- Forecasting signal (model prediction) analysis
- Cumulative Return of groups
![Cumulative Return](http://fintech.msra.cn/images_v070/analysis/analysis_model_cumulative_return.png?v=0.1)
![Cumulative Return](https://github.com/microsoft/qlib/blob/main/docs/_static/img/analysis/analysis_model_cumulative_return.png)
- Return distribution
![long_short](http://fintech.msra.cn/images_v070/analysis/analysis_model_long_short.png?v=0.1)
![long_short](https://github.com/microsoft/qlib/blob/main/docs/_static/img/analysis/analysis_model_long_short.png)
- Information Coefficient (IC)
![Information Coefficient](http://fintech.msra.cn/images_v070/analysis/analysis_model_IC.png?v=0.1)
![Monthly IC](http://fintech.msra.cn/images_v070/analysis/analysis_model_monthly_IC.png?v=0.1)
![IC](http://fintech.msra.cn/images_v070/analysis/analysis_model_NDQ.png?v=0.1)
![Information Coefficient](https://github.com/microsoft/qlib/blob/main/docs/_static/img/analysis/analysis_model_IC.png)
![Monthly IC](https://github.com/microsoft/qlib/blob/main/docs/_static/img/analysis/analysis_model_monthly_IC.png)
![IC](https://github.com/microsoft/qlib/blob/main/docs/_static/img/analysis/analysis_model_NDQ.png)
- Auto Correlation of forecasting signal (model prediction)
![Auto Correlation](http://fintech.msra.cn/images_v070/analysis/analysis_model_auto_correlation.png?v=0.1)
![Auto Correlation](https://github.com/microsoft/qlib/blob/main/docs/_static/img/analysis/analysis_model_auto_correlation.png)
- Portfolio analysis
- Backtest return
![Report](http://fintech.msra.cn/images_v070/analysis/report.png?v=0.1)
![Report](https://github.com/microsoft/qlib/blob/main/docs/_static/img/analysis/report.png)
<!--
- Score IC
![Score IC](docs/_static/img/score_ic.png)
@@ -485,7 +552,7 @@ Qlib data are stored in a compact format, which is efficient to be combined into
Join IM discussion groups:
|[Gitter](https://gitter.im/Microsoft/qlib)|
|----|
|![image](http://fintech.msra.cn/images_v070/qrcode/gitter_qr.png)|
|![image](https://github.com/microsoft/qlib/blob/main/docs/_static/img/qrcode/gitter_qr.png)|
# Contributing
We appreciate all contributions and thank all the contributors!

31
build_docker_image.sh Normal file
View File

@@ -0,0 +1,31 @@
#!/bin/bash
docker_user="your_dockerhub_username"
read -p "Do you want to build the nightly version of the qlib image? (default is stable) (yes/no): " answer;
answer=$(echo "$answer" | tr '[:upper:]' '[:lower:]')
if [ "$answer" = "yes" ]; then
# Build the nightly version of the qlib image
docker build --build-arg IS_STABLE=no -t qlib_image -f ./Dockerfile .
image_tag="nightly"
else
# Build the stable version of the qlib image
docker build -t qlib_image -f ./Dockerfile .
image_tag="stable"
fi
read -p "Is it uploaded to docker hub? (default is no) (yes/no): " answer;
answer=$(echo "$answer" | tr '[:upper:]' '[:lower:]')
if [ "$answer" = "yes" ]; then
# Log in to Docker Hub
# If you are a new docker hub user, please verify your email address before proceeding with this step.
docker login
# Tag the Docker image
docker tag qlib_image "$docker_user/qlib_image:$image_tag"
# Push the Docker image to Docker Hub
docker push "$docker_user/qlib_image:$image_tag"
else
echo "Not uploaded to docker hub."
fi
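
The branching in the script above can be sketched in Python as a pure function that only returns the commands it would run, which makes the two prompts easy to reason about. This is a minimal illustration, not part of the repository: `plan_docker_commands` is a hypothetical helper name, and the commands are copied from the shell script.

```python
from typing import List


def plan_docker_commands(
    build_nightly: str,
    upload: str,
    docker_user: str = "your_dockerhub_username",
) -> List[List[str]]:
    """Return the docker commands the script would run for the given answers.

    Hypothetical helper for illustration only; nothing is executed here.
    """
    commands: List[List[str]] = []

    # Any answer other than "yes" (case-insensitive) builds the stable image.
    if build_nightly.lower() == "yes":
        commands.append(
            ["docker", "build", "--build-arg", "IS_STABLE=no", "-t", "qlib_image", "-f", "./Dockerfile", "."]
        )
        image_tag = "nightly"
    else:
        commands.append(["docker", "build", "-t", "qlib_image", "-f", "./Dockerfile", "."])
        image_tag = "stable"

    # Uploading is opt-in: log in, tag the local image, then push it.
    if upload.lower() == "yes":
        remote = f"{docker_user}/qlib_image:{image_tag}"
        commands.append(["docker", "login"])
        commands.append(["docker", "tag", "qlib_image", remote])
        commands.append(["docker", "push", remote])

    return commands
```

Answering "yes" to both prompts with `docker_user="alice"` yields a nightly build followed by login, tag, and push of `alice/qlib_image:nightly`; any other answers fall back to a local stable build.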

docs/_static/img/rdagent_logo.png vendored Normal file

Binary file not shown (size: 94 KiB).

View File

@@ -86,7 +86,7 @@ Example
},
}
# model initiaiton
# model initialization
model = init_instance_by_config(task["model"])
dataset = init_instance_by_config(task["dataset"])

View File

@@ -123,7 +123,6 @@ html_logo = "_static/img/logo/1.png"
html_theme_options = {
"logo_only": True,
"collapse_navigation": False,
"display_version": False,
"navigation_depth": 4,
}

View File

@@ -60,4 +60,4 @@ The `[dev]` option will help you to install some related packages when developin
.. code-block:: bash
pip install -e .[dev]
pip install -e ".[dev]"

View File

@@ -0,0 +1,81 @@
.. _docker_image:
==================
Build Docker Image
==================
Dockerfile
==========
There is a **Dockerfile** file in the root directory of the project from which you can build the docker image. There are two build methods in Dockerfile to choose from.
When executing the build command, use ``--build-arg IS_STABLE=<yes|no>`` to control the image version. ``IS_STABLE`` defaults to ``yes``, which builds the ``stable`` version of the qlib image.
1. For the ``stable`` version, use ``pip install pyqlib`` to build the qlib image.
.. code-block:: bash
docker build --build-arg IS_STABLE=yes -t <image name> -f ./Dockerfile .
.. code-block:: bash
docker build -t <image name> -f ./Dockerfile .
2. For the ``nightly`` version, use current source code to build the qlib image.
.. code-block:: bash
docker build --build-arg IS_STABLE=no -t <image name> -f ./Dockerfile .
Auto build of qlib images
=========================
1. There is a **build_docker_image.sh** file in the root directory of the project, which can be used to automatically build docker images and, optionally, upload them to your docker hub repository (uploading requires configuration, see step 2).
.. code-block:: bash
sh build_docker_image.sh
>>> Do you want to build the nightly version of the qlib image? (default is stable) (yes/no):
>>> Is it uploaded to docker hub? (default is no) (yes/no):
2. If you want to upload the built image to your docker hub repository, you need to edit your **build_docker_image.sh** file first, fill in ``docker_user`` in the file, and then execute this file.
How to use qlib images
======================
1. Start a new Docker container
.. code-block:: bash
docker run -it --name <container name> -v <Mounted local directory>:/app <image name>
2. At this point you are in the docker environment and can run the qlib scripts. An example:
.. code-block:: bash
>>> python scripts/get_data.py qlib_data --name qlib_data_simple --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
>>> python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
3. Exit the container
.. code-block:: bash
>>> exit
4. Restart the container
.. code-block:: bash
docker start -i -a <container name>
5. Stop the container
.. code-block:: bash
docker stop <container name>
6. Delete the container
.. code-block:: bash
docker rm <container name>
7. For more information on using docker see the `docker documentation <https://docs.docker.com/reference/cli/docker/>`_.

View File

@@ -61,6 +61,7 @@ Document Structure
:caption: FOR DEVELOPERS:
Code Standard & Development Guidance <developer/code_standard_and_dev_guide.rst>
How to build image <developer/how_to_build_image.rst>
.. toctree::
:maxdepth: 3

View File

@@ -5,3 +5,4 @@ scipy
scikit-learn
pandas
tianshou
sphinx_rtd_theme

View File

@@ -0,0 +1,19 @@
# Introduction
What is GeneralPtNN
- It fixes the previous design, which failed to support both time-series and tabular data.
- Now you can just replace the PyTorch model structure to run a NN model.
We provide an example to demonstrate the effectiveness of the current design.
- `workflow_config_gru.yaml` aligns with the previous results of [GRU(Kyunghyun Cho, et al.)](../README.md#Alpha158-dataset)
- `workflow_config_gru2mlp.yaml` demonstrates that a config can be converted from time-series to tabular data with minimal changes
- You only have to change the net & dataset class to make the conversion.
- `workflow_config_mlp.yaml` achieves functionality similar to [MLP](../README.md#Alpha158-dataset)
# TODO
- We will align existing models to current design.
- The result of `workflow_config_mlp.yaml` differs from the result of [MLP](../README.md#Alpha158-dataset) because GeneralPtNN uses a different stopping method from previous implementations. Specifically, GeneralPtNN controls training by the number of epochs, whereas previous methods were controlled by `max_steps`.
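
The time-series-to-tabular conversion described in this README can be sketched as a small dictionary transformation. This is only an illustration under assumptions: `to_tabular` is a hypothetical helper, the dictionaries hold just the model and dataset keys that differ (handler, segments, and optimizer settings are omitted), and the values are copied from the GRU and MLP-style configs added in this change set.

```python
from copy import deepcopy

# Subset of the task section of the GRU (time-series) config.
GRU_TASK = {
    "model": {
        "class": "GeneralPTNN",
        "module_path": "qlib.contrib.model.pytorch_general_nn",
        "kwargs": {
            "pt_model_uri": "qlib.contrib.model.pytorch_gru_ts.GRUModel",
            "pt_model_kwargs": {"d_feat": 20, "hidden_size": 64, "num_layers": 2, "dropout": 0.0},
        },
    },
    "dataset": {
        "class": "TSDatasetH",
        "module_path": "qlib.data.dataset",
        "kwargs": {"step_len": 20},
    },
}


def to_tabular(task: dict) -> dict:
    """Return a copy of `task` that uses an MLP net on a tabular dataset.

    Hypothetical helper: it swaps the net and the dataset class and leaves
    the GeneralPTNN wrapper itself untouched.
    """
    new_task = deepcopy(task)
    model_kwargs = new_task["model"]["kwargs"]
    model_kwargs["pt_model_uri"] = "qlib.contrib.model.pytorch_nn.Net"
    model_kwargs["pt_model_kwargs"] = {"input_dim": 20, "layers": [20]}
    new_task["dataset"]["class"] = "DatasetH"
    # `step_len` only applies to the time-series dataset.
    new_task["dataset"]["kwargs"].pop("step_len", None)
    return new_task


mlp_task = to_tabular(GRU_TASK)
```

The model wrapper class stays `GeneralPTNN` in both cases, which is the point of the design: only the network URI, its keyword arguments, and the dataset class change.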

View File

@@ -0,0 +1,100 @@
qlib_init:
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
market: &market csi300
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
start_time: 2008-01-01
end_time: 2020-08-01
fit_start_time: 2008-01-01
fit_end_time: 2014-12-31
instruments: *market
infer_processors:
- class: FilterCol
kwargs:
fields_group: feature
col_list: ["RESI5", "WVMA5", "RSQR5", "KLEN", "RSQR10", "CORR5", "CORD5", "CORR10",
"ROC60", "RESI10", "VSTD5", "RSQR60", "CORR60", "WVMA60", "STD5",
"RSQR20", "CORD60", "CORD10", "CORR20", "KLOW"
]
- class: RobustZScoreNorm
kwargs:
fields_group: feature
clip_outlier: true
- class: Fillna
kwargs:
fields_group: feature
learn_processors:
- class: DropnaLabel
- class: CSRankNorm
kwargs:
fields_group: label
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
signal: <PRED>
topk: 50
n_drop: 5
backtest:
start_time: 2017-01-01
end_time: 2020-08-01
account: 100000000
benchmark: *benchmark
exchange_kwargs:
limit_threshold: 0.095
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
task:
model:
class: GeneralPTNN
module_path: qlib.contrib.model.pytorch_general_nn
kwargs:
n_epochs: 200
lr: 2e-4
early_stop: 10
batch_size: 800
metric: loss
loss: mse
n_jobs: 20
GPU: 0
pt_model_uri: "qlib.contrib.model.pytorch_gru_ts.GRUModel"
pt_model_kwargs: {
"d_feat": 20,
"hidden_size": 64,
"num_layers": 2,
"dropout": 0.,
}
dataset:
class: TSDatasetH
module_path: qlib.data.dataset
kwargs:
handler:
class: Alpha158
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:
train: [2008-01-01, 2014-12-31]
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
step_len: 20
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs:
model: <MODEL>
dataset: <DATASET>
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
ana_long_short: False
ann_scaler: 252
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config

View File

@@ -0,0 +1,93 @@
qlib_init:
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
market: &market csi300
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
start_time: 2008-01-01
end_time: 2020-08-01
fit_start_time: 2008-01-01
fit_end_time: 2014-12-31
instruments: *market
infer_processors:
- class: FilterCol
kwargs:
fields_group: feature
col_list: ["RESI5", "WVMA5", "RSQR5", "KLEN", "RSQR10", "CORR5", "CORD5", "CORR10",
"ROC60", "RESI10", "VSTD5", "RSQR60", "CORR60", "WVMA60", "STD5",
"RSQR20", "CORD60", "CORD10", "CORR20", "KLOW"
]
- class: RobustZScoreNorm
kwargs:
fields_group: feature
clip_outlier: true
- class: Fillna
kwargs:
fields_group: feature
learn_processors:
- class: DropnaLabel
- class: CSRankNorm
kwargs:
fields_group: label
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
signal: <PRED>
topk: 50
n_drop: 5
backtest:
start_time: 2017-01-01
end_time: 2020-08-01
account: 100000000
benchmark: *benchmark
exchange_kwargs:
limit_threshold: 0.095
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
task:
model:
class: GeneralPTNN
module_path: qlib.contrib.model.pytorch_general_nn
kwargs:
lr: 1e-3
n_epochs: 1
batch_size: 800
loss: mse
optimizer: adam
pt_model_uri: "qlib.contrib.model.pytorch_nn.Net"
pt_model_kwargs:
input_dim: 20
layers: [20,]
dataset:
class: DatasetH
module_path: qlib.data.dataset
kwargs:
handler:
class: Alpha158
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:
train: [2008-01-01, 2014-12-31]
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs:
model: <MODEL>
dataset: <DATASET>
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
ana_long_short: False
ann_scaler: 252
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config

View File

@@ -0,0 +1,98 @@
qlib_init:
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
market: &market csi300
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
start_time: 2008-01-01
end_time: 2020-08-01
fit_start_time: 2008-01-01
fit_end_time: 2014-12-31
instruments: *market
infer_processors: [
{
"class" : "DropCol",
"kwargs":{"col_list": ["VWAP0"]}
},
{
"class" : "CSZFillna",
"kwargs":{"fields_group": "feature"}
}
]
learn_processors: [
{
"class" : "DropCol",
"kwargs":{"col_list": ["VWAP0"]}
},
{
"class" : "DropnaProcessor",
"kwargs":{"fields_group": "feature"}
},
"DropnaLabel",
{
"class": "CSZScoreNorm",
"kwargs": {"fields_group": "label"}
}
]
process_type: "independent"
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
signal: <PRED>
topk: 50
n_drop: 5
backtest:
start_time: 2017-01-01
end_time: 2020-08-01
account: 100000000
benchmark: *benchmark
exchange_kwargs:
limit_threshold: 0.095
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
task:
model:
class: GeneralPTNN
module_path: qlib.contrib.model.pytorch_general_nn
kwargs:
# FIXME: wrong parameters.
lr: 2e-3
batch_size: 8192
loss: mse
weight_decay: 0.0002
optimizer: adam
pt_model_uri: "qlib.contrib.model.pytorch_nn.Net"
pt_model_kwargs:
input_dim: 157
dataset:
class: DatasetH
module_path: qlib.data.dataset
kwargs:
handler:
class: Alpha158
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:
train: [2008-01-01, 2014-12-31]
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs:
model: <MODEL>
dataset: <DATASET>
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
ana_long_short: False
ann_scaler: 252
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config

View File

@@ -1,14 +1,15 @@
import argparse
import qlib
import ruamel.yaml as yaml
from ruamel.yaml import YAML
from qlib.utils import init_instance_by_config
def main(seed, config_file="configs/config_alstm.yaml"):
# set random seed
with open(config_file) as f:
config = yaml.safe_load(f)
yaml = YAML(typ="safe", pure=True)
config = yaml.load(f)
# seed_suffix = "/seed1000" if "init" in config_file else f"/seed{seed}"
seed_suffix = ""

View File

@@ -1,5 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
from pathlib import Path
from typing import Union
@@ -35,6 +36,10 @@ class DDGDABench(DDGDA):
if __name__ == "__main__":
GetData().qlib_data(exists_skip=True)
auto_init()
kwargs = {}
if os.environ.get("PROVIDER_URI", "") == "":
GetData().qlib_data(exists_skip=True)
else:
kwargs["provider_uri"] = os.environ["PROVIDER_URI"]
auto_init(**kwargs)
fire.Fire(DDGDABench)
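
The new `__main__` block chooses between downloading the default data and using a provider supplied through the `PROVIDER_URI` environment variable. A minimal sketch of that decision, isolated as a pure function so it can be checked without qlib installed; `resolve_init_kwargs` is a hypothetical name and not part of qlib:

```python
import os
from typing import Mapping, Tuple


def resolve_init_kwargs(environ: Mapping[str, str] = os.environ) -> Tuple[dict, bool]:
    """Return (kwargs for auto_init, whether the default data should be downloaded).

    Hypothetical helper mirroring the branch in the `__main__` block above.
    """
    kwargs: dict = {}
    provider_uri = environ.get("PROVIDER_URI", "")
    if provider_uri == "":
        # Unset or empty: fall back to downloading the default dataset.
        return kwargs, True
    # Otherwise pass the user-supplied location through to initialization.
    kwargs["provider_uri"] = provider_uri
    return kwargs, False
```

An unset variable and an empty string are treated the same way, so `PROVIDER_URI= python script.py` still triggers the default download.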

View File

@@ -1,5 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
from pathlib import Path
from typing import Union
@@ -31,6 +32,10 @@ class RollingBenchmark(Rolling):
if __name__ == "__main__":
GetData().qlib_data(exists_skip=True)
auto_init()
kwargs = {}
if os.environ.get("PROVIDER_URI", "") == "":
GetData().qlib_data(exists_skip=True)
else:
kwargs["provider_uri"] = os.environ["PROVIDER_URI"]
auto_init(**kwargs)
fire.Fire(RollingBenchmark)

View File

@@ -9,8 +9,8 @@ from copy import deepcopy
from pathlib import Path
import pickle
from pprint import pprint
from ruamel.yaml import YAML
import subprocess
import yaml
from qlib.log import TimeInspector
from qlib import init
@@ -30,7 +30,8 @@ if __name__ == "__main__":
subprocess.run(f"qrun {config_path}", shell=True)
# 2) dump handler
task_config = yaml.safe_load(config_path.open())
yaml = YAML(typ="safe", pure=True)
task_config = yaml.load(config_path.open())
hd_conf = task_config["task"]["dataset"]["kwargs"]["handler"]
pprint(hd_conf)
hd: DataHandlerLP = init_instance_by_config(hd_conf)

View File

@@ -9,10 +9,9 @@ from copy import deepcopy
from pathlib import Path
import pickle
from pprint import pprint
from ruamel.yaml import YAML
import subprocess
import yaml
from qlib import init
from qlib.data.dataset.handler import DataHandlerLP
from qlib.log import TimeInspector
@@ -29,7 +28,8 @@ if __name__ == "__main__":
exp_name = "data_mem_reuse_demo"
config_path = DIRNAME.parent / "benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml"
task_config = yaml.safe_load(config_path.open())
yaml = YAML(typ="safe", pure=True)
task_config = yaml.load(config_path.open())
# 1) without using processed data in memory
with TimeInspector.logt("The original time without reusing processed data in memory:"):

View File

@@ -16,7 +16,7 @@ Current version of script with default value tries to connect localhost **via de
Run the following commands to install the necessary libraries
```
pip install pytest coverage
pip install pytest coverage gdown
pip install arctic # NOTE: pip may fail to resolve the right package dependency !!! Please make sure the dependency are satisfied.
```
@@ -27,7 +27,8 @@ pip install arctic # NOTE: pip may fail to resolve the right package dependency
2. Please follow these steps to download the example data
```bash
cd examples/orderbook_data/
python ../../scripts/get_data.py download_data --target_dir . --file_name highfreq_orderbook_example_data.zip
gdown https://drive.google.com/uc?id=15nZF7tFT_eKVZAcMFL1qPS4jGyJflH7e # Proxies may be necessary here.
python ../../scripts/get_data.py _unzip --file_path highfreq_orderbook_example_data.zip --target_dir .
```
3. Please import the example data to your mongo db

View File

@@ -20,7 +20,7 @@ We use China stock market data for our example.
1. Prepare CSI300 weight:
```bash
wget http://fintech.msra.cn/stock_data/downloads/csi300_weight.zip
wget https://github.com/SunsetWolf/qlib_dataset/releases/download/v0/csi300_weight.zip
unzip -d ~/.qlib/qlib_data/cn_data csi300_weight.zip
rm -f csi300_weight.zip
```

View File

@@ -6,7 +6,6 @@ import sys
import fire
import time
import glob
import yaml
import shutil
import signal
import inspect
@@ -15,6 +14,7 @@ import functools
import statistics
import subprocess
from datetime import datetime
from ruamel.yaml import YAML
from pathlib import Path
from operator import xor
from pprint import pprint
@@ -188,7 +188,8 @@ def gen_and_save_md_table(metrics, dataset):
# read yaml, remove seed kwargs of model, and then save file in the temp_dir
def gen_yaml_file_without_seed_kwargs(yaml_path, temp_dir):
with open(yaml_path, "r") as fp:
config = yaml.safe_load(fp)
yaml = YAML(typ="safe", pure=True)
config = yaml.load(fp)
try:
del config["task"]["model"]["kwargs"]["seed"]
except KeyError:

View File

@@ -161,7 +161,7 @@
" },\n",
"}\n",
"\n",
"# model initiaiton\n",
"# model initialization\n",
"model = init_instance_by_config(task[\"model\"])\n",
"dataset = init_instance_by_config(task[\"dataset\"])\n",
"\n",

View File

@@ -1,2 +1,93 @@
[build-system]
requires = ["setuptools", "numpy", "Cython"]
requires = ["setuptools", "cython", "numpy>=1.24.0"]
build-backend = "setuptools.build_meta"
[project]
classifiers = [
"Operating System :: POSIX :: Linux",
"Operating System :: Microsoft :: Windows",
"Operating System :: MacOS",
"License :: OSI Approved :: MIT License",
"Development Status :: 3 - Alpha",
"Programming Language :: Python",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
]
name = "pyqlib"
dynamic = ["version"]
description = "A Quantitative-research Platform"
requires-python = ">=3.8.0"
readme = {file = "README.md", content-type = "text/markdown"}
dependencies = [
"pyyaml",
"numpy",
"pandas",
"mlflow",
"filelock>=3.16.0",
"redis",
"dill",
"fire",
"ruamel.yaml>=0.17.38",
"python-redis-lock",
"tqdm",
"pymongo",
"loguru",
"lightgbm",
"gym",
"cvxpy",
"joblib",
"matplotlib",
"jupyter",
"nbconvert",
]
[project.optional-dependencies]
dev = [
"pytest",
"statsmodels",
]
# On macos-13 system, when using python version greater than or equal to 3.10,
# pytorch can't fully support Numpy version above 2.0, so, when you want to install torch,
# it will limit the version of Numpy less than 2.0.
rl = [
"tianshou<=0.4.10",
"torch",
"numpy<2.0.0",
]
lint = [
"black",
"pylint",
"mypy<1.5.0",
"flake8",
"nbqa",
]
docs = [
"sphinx",
"sphinx_rtd_theme",
"readthedocs_sphinx_ext",
]
package = [
"twine",
"build",
]
# test_pit dependency packages
test = [
"yahooquery",
"baostock",
]
analysis = [
"plotly",
]
[tool.setuptools]
packages = [
"qlib",
]
[project.scripts]
qrun = "qlib.workflow.cli:run"

View File

@@ -2,11 +2,11 @@
# Licensed under the MIT License.
from pathlib import Path
__version__ = "0.9.4.99"
__version__ = "0.9.6"
__version__bak = __version__ # This version is backup for QlibConfig.reset_qlib_version
import os
from typing import Union
import yaml
from ruamel.yaml import YAML
import logging
import platform
import subprocess
@@ -176,7 +176,8 @@ def init_from_yaml_conf(conf_path, **kwargs):
config = {}
else:
with open(conf_path) as f:
config = yaml.safe_load(f)
yaml = YAML(typ="safe", pure=True)
config = yaml.load(f)
config.update(kwargs)
default_conf = config.pop("default_conf", "client")
init(default_conf, **config)
@@ -272,7 +273,8 @@ def auto_init(**kwargs):
logger = get_module_logger("Initialization")
conf_pp = pp / "config.yaml"
with conf_pp.open() as f:
conf = yaml.safe_load(f)
yaml = YAML(typ="safe", pure=True)
conf = yaml.load(f)
conf_type = conf.get("conf_type", "origin")
if conf_type == "origin":

View File

@@ -278,7 +278,7 @@ class BaseSingleMetric:
raise NotImplementedError(f"Please implement the `empty` method")
def add(self, other: BaseSingleMetric, fill_value: float = None) -> BaseSingleMetric:
"""Replace np.NaN with fill_value in two metrics and add them."""
"""Replace np.nan with fill_value in two metrics and add them."""
raise NotImplementedError(f"Please implement the `add` method")
@@ -412,7 +412,7 @@ class BaseOrderIndicator:
metrics : Union[str, List[str]]
all metrics that need to be summed.
fill_value : float, optional
fill np.NaN with value. By default None.
fill np.nan with value. By default None.
"""
raise NotImplementedError(f"Please implement the 'sum_all_indicators' method")

View File

@@ -325,9 +325,9 @@ class Indicator:
def _update_order_fulfill_rate(self) -> None:
def func(deal_amount, amount):
# deal_amount is np.NaN or None when there is no inner decision. So full fill rate is 0.
# deal_amount is np.nan or None when there is no inner decision. So full fill rate is 0.
tmp_deal_amount = deal_amount.reindex(amount.index, 0)
tmp_deal_amount = tmp_deal_amount.replace({np.NaN: 0})
tmp_deal_amount = tmp_deal_amount.replace({np.nan: 0})
return tmp_deal_amount / amount
self.order_indicator.transfer(func, "ffr")
@@ -354,8 +354,8 @@ class Indicator:
)
def func(trade_price, deal_amount):
# trade_price is np.NaN instead of inf when deal_amount is zero.
tmp_deal_amount = deal_amount.replace({0: np.NaN})
# trade_price is np.nan instead of inf when deal_amount is zero.
tmp_deal_amount = deal_amount.replace({0: np.nan})
return trade_price / tmp_deal_amount
self.order_indicator.transfer(func, "trade_price")
@@ -425,7 +425,7 @@ class Indicator:
assert isinstance(price_s, idd.SingleData)
price_s = price_s.loc[(price_s > 1e-08).data.astype(bool)]
# NOTE ~(price_s < 1e-08) is different from price_s >= 1e-8
# ~(np.NaN < 1e-8) -> ~(False) -> True
# ~(np.nan < 1e-8) -> ~(False) -> True
assert isinstance(price_s, idd.SingleData)
if agg == "vwap":

View File

@@ -173,7 +173,11 @@ _default_config = {
"filters": ["field_not_found"],
}
},
"loggers": {"qlib": {"level": logging.DEBUG, "handlers": ["console"]}},
# Normally this should be set to `False` to avoid duplicated logging [1].
# However, due to a bug in pytest, log messages must propagate to the root logger to be captured by `caplog` [2].
# [1] https://github.com/microsoft/qlib/pull/1661
# [2] https://github.com/pytest-dev/pytest/issues/3697
"loggers": {"qlib": {"level": logging.DEBUG, "handlers": ["console"], "propagate": False}},
# To let qlib work with other packages, we shouldn't disable existing loggers.
# Note that this parameter defaults to True according to the documentation of logging.
"disable_existing_loggers": False,
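
The effect of the `propagate` flag discussed in the comments above can be reproduced with the standard library alone. The following is a minimal sketch, not qlib's configuration: the logger name `demo_pkg` and the helper `count_emissions` are hypothetical, and the only assumption is that both the root logger and the package logger carry their own handler, which is the situation that produces duplicated log lines.

```python
import io
import logging


def count_emissions(propagate: bool) -> int:
    """Count how often one message is written when both the root logger
    and a child logger own a handler (hypothetical helper)."""
    stream = io.StringIO()
    root = logging.getLogger()
    child = logging.getLogger("demo_pkg")  # hypothetical logger name

    root_handler = logging.StreamHandler(stream)
    child_handler = logging.StreamHandler(stream)
    root.addHandler(root_handler)
    child.addHandler(child_handler)
    child.setLevel(logging.INFO)
    child.propagate = propagate
    try:
        child.info("hello")
    finally:
        # Restore global logging state so the sketch has no side effects.
        root.removeHandler(root_handler)
        child.removeHandler(child_handler)
        child.propagate = True
    return stream.getvalue().count("hello")


duplicated = count_emissions(propagate=True)   # child handler + root handler
single = count_emissions(propagate=False)      # child handler only
```

With propagation enabled, the record is handled by the child's handler and then again by the root's handler; disabling it leaves only the child's handler, at the cost that tools which listen on the root logger no longer see the record.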

View File

@@ -1,6 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from qlib.contrib.data.loader import Alpha158DL, Alpha360DL
from ...data.dataset.handler import DataHandlerLP
from ...data.dataset.processor import Processor
from ...utils import get_callable_kwargs
@@ -57,7 +58,7 @@ class Alpha360(DataHandlerLP):
fit_end_time=None,
filter_pipe=None,
inst_processors=None,
**kwargs
**kwargs,
):
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)
@@ -66,7 +67,7 @@ class Alpha360(DataHandlerLP):
"class": "QlibDataLoader",
"kwargs": {
"config": {
"feature": self.get_feature_config(),
"feature": Alpha360DL.get_feature_config(),
"label": kwargs.pop("label", self.get_label_config()),
},
"filter_pipe": filter_pipe,
@@ -82,57 +83,12 @@ class Alpha360(DataHandlerLP):
data_loader=data_loader,
learn_processors=learn_processors,
infer_processors=infer_processors,
**kwargs
**kwargs,
)
def get_label_config(self):
return ["Ref($close, -2)/Ref($close, -1) - 1"], ["LABEL0"]
@staticmethod
def get_feature_config():
# NOTE:
# Alpha360 tries to provide a dataset with original price data
# the original price data includes the prices and volume in the last 60 days.
# To make it easier to learn models from this dataset, all the prices and volume
# are normalized by the latest price and volume data ( dividing by $close, $volume)
# So the latest normalized $close will be 1 (with name CLOSE0), the latest normalized $volume will be 1 (with name VOLUME0)
# If further normalization are executed (e.g. centralization), CLOSE0 and VOLUME0 will be 0.
fields = []
names = []
for i in range(59, 0, -1):
fields += ["Ref($close, %d)/$close" % i]
names += ["CLOSE%d" % i]
fields += ["$close/$close"]
names += ["CLOSE0"]
for i in range(59, 0, -1):
fields += ["Ref($open, %d)/$close" % i]
names += ["OPEN%d" % i]
fields += ["$open/$close"]
names += ["OPEN0"]
for i in range(59, 0, -1):
fields += ["Ref($high, %d)/$close" % i]
names += ["HIGH%d" % i]
fields += ["$high/$close"]
names += ["HIGH0"]
for i in range(59, 0, -1):
fields += ["Ref($low, %d)/$close" % i]
names += ["LOW%d" % i]
fields += ["$low/$close"]
names += ["LOW0"]
for i in range(59, 0, -1):
fields += ["Ref($vwap, %d)/$close" % i]
names += ["VWAP%d" % i]
fields += ["$vwap/$close"]
names += ["VWAP0"]
for i in range(59, 0, -1):
fields += ["Ref($volume, %d)/($volume+1e-12)" % i]
names += ["VOLUME%d" % i]
fields += ["$volume/($volume+1e-12)"]
names += ["VOLUME0"]
return fields, names
class Alpha360vwap(Alpha360):
def get_label_config(self):
@@ -153,7 +109,7 @@ class Alpha158(DataHandlerLP):
process_type=DataHandlerLP.PTYPE_A,
filter_pipe=None,
inst_processors=None,
**kwargs
**kwargs,
):
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)
@@ -178,7 +134,7 @@ class Alpha158(DataHandlerLP):
infer_processors=infer_processors,
learn_processors=learn_processors,
process_type=process_type,
**kwargs
**kwargs,
)
def get_feature_config(self):
@@ -190,242 +146,11 @@ class Alpha158(DataHandlerLP):
},
"rolling": {},
}
return self.parse_config_to_fields(conf)
return Alpha158DL.get_feature_config(conf)
def get_label_config(self):
return ["Ref($close, -2)/Ref($close, -1) - 1"], ["LABEL0"]
@staticmethod
def parse_config_to_fields(config):
"""create factors from config
config = {
'kbar': {}, # whether to use some hard-code kbar features
'price': { # whether to use raw price features
'windows': [0, 1, 2, 3, 4], # use price at n days ago
'feature': ['OPEN', 'HIGH', 'LOW'] # which price field to use
},
'volume': { # whether to use raw volume features
'windows': [0, 1, 2, 3, 4], # use volume at n days ago
},
'rolling': { # whether to use rolling operator based features
'windows': [5, 10, 20, 30, 60], # rolling windows size
'include': ['ROC', 'MA', 'STD'], # rolling operator to use
#if include is None we will use default operators
'exclude': ['RANK'], # rolling operator not to use
}
}
"""
fields = []
names = []
if "kbar" in config:
fields += [
"($close-$open)/$open",
"($high-$low)/$open",
"($close-$open)/($high-$low+1e-12)",
"($high-Greater($open, $close))/$open",
"($high-Greater($open, $close))/($high-$low+1e-12)",
"(Less($open, $close)-$low)/$open",
"(Less($open, $close)-$low)/($high-$low+1e-12)",
"(2*$close-$high-$low)/$open",
"(2*$close-$high-$low)/($high-$low+1e-12)",
]
names += [
"KMID",
"KLEN",
"KMID2",
"KUP",
"KUP2",
"KLOW",
"KLOW2",
"KSFT",
"KSFT2",
]
if "price" in config:
windows = config["price"].get("windows", range(5))
feature = config["price"].get("feature", ["OPEN", "HIGH", "LOW", "CLOSE", "VWAP"])
for field in feature:
field = field.lower()
fields += ["Ref($%s, %d)/$close" % (field, d) if d != 0 else "$%s/$close" % field for d in windows]
names += [field.upper() + str(d) for d in windows]
if "volume" in config:
windows = config["volume"].get("windows", range(5))
fields += ["Ref($volume, %d)/($volume+1e-12)" % d if d != 0 else "$volume/($volume+1e-12)" for d in windows]
names += ["VOLUME" + str(d) for d in windows]
if "rolling" in config:
windows = config["rolling"].get("windows", [5, 10, 20, 30, 60])
include = config["rolling"].get("include", None)
exclude = config["rolling"].get("exclude", [])
# `exclude` in dataset config unnecessary filed
# `include` in dataset config necessary field
def use(x):
return x not in exclude and (include is None or x in include)
# Some factor ref: https://guorn.com/static/upload/file/3/134065454575605.pdf
if use("ROC"):
# https://www.investopedia.com/terms/r/rateofchange.asp
# Rate of change, the price change in the past d days, divided by latest close price to remove unit
fields += ["Ref($close, %d)/$close" % d for d in windows]
names += ["ROC%d" % d for d in windows]
if use("MA"):
# https://www.investopedia.com/ask/answers/071414/whats-difference-between-moving-average-and-weighted-moving-average.asp
# Simple Moving Average, the simple moving average in the past d days, divided by latest close price to remove unit
fields += ["Mean($close, %d)/$close" % d for d in windows]
names += ["MA%d" % d for d in windows]
if use("STD"):
# The standard diviation of close price for the past d days, divided by latest close price to remove unit
fields += ["Std($close, %d)/$close" % d for d in windows]
names += ["STD%d" % d for d in windows]
if use("BETA"):
# The rate of close price change in the past d days, divided by latest close price to remove unit
# For example, price increase 10 dollar per day in the past d days, then Slope will be 10.
fields += ["Slope($close, %d)/$close" % d for d in windows]
names += ["BETA%d" % d for d in windows]
if use("RSQR"):
# The R-sqaure value of linear regression for the past d days, represent the trend linear
fields += ["Rsquare($close, %d)" % d for d in windows]
names += ["RSQR%d" % d for d in windows]
if use("RESI"):
# The redisdual for linear regression for the past d days, represent the trend linearity for past d days.
fields += ["Resi($close, %d)/$close" % d for d in windows]
names += ["RESI%d" % d for d in windows]
if use("MAX"):
# The max price for past d days, divided by latest close price to remove unit
fields += ["Max($high, %d)/$close" % d for d in windows]
names += ["MAX%d" % d for d in windows]
if use("LOW"):
# The low price for past d days, divided by latest close price to remove unit
fields += ["Min($low, %d)/$close" % d for d in windows]
names += ["MIN%d" % d for d in windows]
if use("QTLU"):
# The 80% quantile of past d day's close price, divided by latest close price to remove unit
# Used with MIN and MAX
fields += ["Quantile($close, %d, 0.8)/$close" % d for d in windows]
names += ["QTLU%d" % d for d in windows]
if use("QTLD"):
# The 20% quantile of past d day's close price, divided by latest close price to remove unit
fields += ["Quantile($close, %d, 0.2)/$close" % d for d in windows]
names += ["QTLD%d" % d for d in windows]
if use("RANK"):
# Get the percentile of current close price in past d day's close price.
# Represent the current price level comparing to past N days, add additional information to moving average.
fields += ["Rank($close, %d)" % d for d in windows]
names += ["RANK%d" % d for d in windows]
if use("RSV"):
# Represent the price position between upper and lower resistent price for past d days.
fields += ["($close-Min($low, %d))/(Max($high, %d)-Min($low, %d)+1e-12)" % (d, d, d) for d in windows]
names += ["RSV%d" % d for d in windows]
if use("IMAX"):
# The number of days between current date and previous highest price date.
# Part of Aroon Indicator https://www.investopedia.com/terms/a/aroon.asp
# The indicator measures the time between highs and the time between lows over a time period.
# The idea is that strong uptrends will regularly see new highs, and strong downtrends will regularly see new lows.
fields += ["IdxMax($high, %d)/%d" % (d, d) for d in windows]
names += ["IMAX%d" % d for d in windows]
if use("IMIN"):
# The number of days between current date and previous lowest price date.
# Part of Aroon Indicator https://www.investopedia.com/terms/a/aroon.asp
# The indicator measures the time between highs and the time between lows over a time period.
# The idea is that strong uptrends will regularly see new highs, and strong downtrends will regularly see new lows.
fields += ["IdxMin($low, %d)/%d" % (d, d) for d in windows]
names += ["IMIN%d" % d for d in windows]
if use("IMXD"):
# The time period between previous lowest-price date occur after highest price date.
# Large value suggest downward momemtum.
fields += ["(IdxMax($high, %d)-IdxMin($low, %d))/%d" % (d, d, d) for d in windows]
names += ["IMXD%d" % d for d in windows]
if use("CORR"):
# The correlation between absolute close price and log scaled trading volume
fields += ["Corr($close, Log($volume+1), %d)" % d for d in windows]
names += ["CORR%d" % d for d in windows]
if use("CORD"):
# The correlation between price change ratio and volume change ratio
fields += ["Corr($close/Ref($close,1), Log($volume/Ref($volume, 1)+1), %d)" % d for d in windows]
names += ["CORD%d" % d for d in windows]
if use("CNTP"):
# The percentage of days in past d days that price go up.
fields += ["Mean($close>Ref($close, 1), %d)" % d for d in windows]
names += ["CNTP%d" % d for d in windows]
if use("CNTN"):
# The percentage of days in past d days that price go down.
fields += ["Mean($close<Ref($close, 1), %d)" % d for d in windows]
names += ["CNTN%d" % d for d in windows]
if use("CNTD"):
# The diff between past up day and past down day
fields += ["Mean($close>Ref($close, 1), %d)-Mean($close<Ref($close, 1), %d)" % (d, d) for d in windows]
names += ["CNTD%d" % d for d in windows]
if use("SUMP"):
# The total gain / the absolute total price changed
# Similar to RSI indicator. https://www.investopedia.com/terms/r/rsi.asp
fields += [
"Sum(Greater($close-Ref($close, 1), 0), %d)/(Sum(Abs($close-Ref($close, 1)), %d)+1e-12)" % (d, d)
for d in windows
]
names += ["SUMP%d" % d for d in windows]
if use("SUMN"):
# The total lose / the absolute total price changed
# Can be derived from SUMP by SUMN = 1 - SUMP
# Similar to RSI indicator. https://www.investopedia.com/terms/r/rsi.asp
fields += [
"Sum(Greater(Ref($close, 1)-$close, 0), %d)/(Sum(Abs($close-Ref($close, 1)), %d)+1e-12)" % (d, d)
for d in windows
]
names += ["SUMN%d" % d for d in windows]
if use("SUMD"):
# The diff ratio between total gain and total lose
# Similar to RSI indicator. https://www.investopedia.com/terms/r/rsi.asp
fields += [
"(Sum(Greater($close-Ref($close, 1), 0), %d)-Sum(Greater(Ref($close, 1)-$close, 0), %d))"
"/(Sum(Abs($close-Ref($close, 1)), %d)+1e-12)" % (d, d, d)
for d in windows
]
names += ["SUMD%d" % d for d in windows]
if use("VMA"):
# Simple Volume Moving average: https://www.barchart.com/education/technical-indicators/volume_moving_average
fields += ["Mean($volume, %d)/($volume+1e-12)" % d for d in windows]
names += ["VMA%d" % d for d in windows]
if use("VSTD"):
# The standard deviation for volume in past d days.
fields += ["Std($volume, %d)/($volume+1e-12)" % d for d in windows]
names += ["VSTD%d" % d for d in windows]
if use("WVMA"):
# The volume weighted price change volatility
fields += [
"Std(Abs($close/Ref($close, 1)-1)*$volume, %d)/(Mean(Abs($close/Ref($close, 1)-1)*$volume, %d)+1e-12)"
% (d, d)
for d in windows
]
names += ["WVMA%d" % d for d in windows]
if use("VSUMP"):
# The total volume increase / the absolute total volume changed
fields += [
"Sum(Greater($volume-Ref($volume, 1), 0), %d)/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)"
% (d, d)
for d in windows
]
names += ["VSUMP%d" % d for d in windows]
if use("VSUMN"):
# The total volume increase / the absolute total volume changed
# Can be derived from VSUMP by VSUMN = 1 - VSUMP
fields += [
"Sum(Greater(Ref($volume, 1)-$volume, 0), %d)/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)"
% (d, d)
for d in windows
]
names += ["VSUMN%d" % d for d in windows]
if use("VSUMD"):
# The diff ratio between total volume increase and total volume decrease
# RSI indicator for volume
fields += [
"(Sum(Greater($volume-Ref($volume, 1), 0), %d)-Sum(Greater(Ref($volume, 1)-$volume, 0), %d))"
"/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)" % (d, d, d)
for d in windows
]
names += ["VSUMD%d" % d for d in windows]
return fields, names
class Alpha158vwap(Alpha158):
def get_label_config(self):

310
qlib/contrib/data/loader.py Normal file
View File

@@ -0,0 +1,310 @@
from qlib.data.dataset.loader import QlibDataLoader
class Alpha360DL(QlibDataLoader):
"""Dataloader to get Alpha360"""
def __init__(self, config=None, **kwargs):
_config = {
"feature": self.get_feature_config(),
}
if config is not None:
_config.update(config)
super().__init__(config=_config, **kwargs)
@staticmethod
def get_feature_config():
# NOTE:
# Alpha360 tries to provide a dataset with original price data.
# The original price data includes the prices and volume in the last 60 days.
# To make it easier to learn models from this dataset, all the prices and volumes
# are normalized by the latest price and volume data (dividing by $close, $volume).
# So the latest normalized $close will be 1 (with name CLOSE0), and the latest normalized $volume will be 1 (with name VOLUME0).
# If further normalization is executed (e.g. centralization), CLOSE0 and VOLUME0 will be 0.
fields = []
names = []
for i in range(59, 0, -1):
fields += ["Ref($close, %d)/$close" % i]
names += ["CLOSE%d" % i]
fields += ["$close/$close"]
names += ["CLOSE0"]
for i in range(59, 0, -1):
fields += ["Ref($open, %d)/$close" % i]
names += ["OPEN%d" % i]
fields += ["$open/$close"]
names += ["OPEN0"]
for i in range(59, 0, -1):
fields += ["Ref($high, %d)/$close" % i]
names += ["HIGH%d" % i]
fields += ["$high/$close"]
names += ["HIGH0"]
for i in range(59, 0, -1):
fields += ["Ref($low, %d)/$close" % i]
names += ["LOW%d" % i]
fields += ["$low/$close"]
names += ["LOW0"]
for i in range(59, 0, -1):
fields += ["Ref($vwap, %d)/$close" % i]
names += ["VWAP%d" % i]
fields += ["$vwap/$close"]
names += ["VWAP0"]
for i in range(59, 0, -1):
fields += ["Ref($volume, %d)/($volume+1e-12)" % i]
names += ["VOLUME%d" % i]
fields += ["$volume/($volume+1e-12)"]
names += ["VOLUME0"]
return fields, names
class Alpha158DL(QlibDataLoader):
"""Dataloader to get Alpha158"""
def __init__(self, config=None, **kwargs):
_config = {
"feature": self.get_feature_config(),
}
if config is not None:
_config.update(config)
super().__init__(config=_config, **kwargs)
@staticmethod
def get_feature_config(
config={
"kbar": {},
"price": {
"windows": [0],
"feature": ["OPEN", "HIGH", "LOW", "VWAP"],
},
"rolling": {},
}
):
"""create factors from config
config = {
'kbar': {}, # whether to use some hard-code kbar features
'price': { # whether to use raw price features
'windows': [0, 1, 2, 3, 4], # use price at n days ago
'feature': ['OPEN', 'HIGH', 'LOW'] # which price field to use
},
'volume': { # whether to use raw volume features
'windows': [0, 1, 2, 3, 4], # use volume at n days ago
},
'rolling': { # whether to use rolling operator based features
'windows': [5, 10, 20, 30, 60], # rolling windows size
'include': ['ROC', 'MA', 'STD'], # rolling operator to use
#if include is None we will use default operators
'exclude': ['RANK'], # rolling operator not to use
}
}
"""
fields = []
names = []
if "kbar" in config:
fields += [
"($close-$open)/$open",
"($high-$low)/$open",
"($close-$open)/($high-$low+1e-12)",
"($high-Greater($open, $close))/$open",
"($high-Greater($open, $close))/($high-$low+1e-12)",
"(Less($open, $close)-$low)/$open",
"(Less($open, $close)-$low)/($high-$low+1e-12)",
"(2*$close-$high-$low)/$open",
"(2*$close-$high-$low)/($high-$low+1e-12)",
]
names += [
"KMID",
"KLEN",
"KMID2",
"KUP",
"KUP2",
"KLOW",
"KLOW2",
"KSFT",
"KSFT2",
]
if "price" in config:
windows = config["price"].get("windows", range(5))
feature = config["price"].get("feature", ["OPEN", "HIGH", "LOW", "CLOSE", "VWAP"])
for field in feature:
field = field.lower()
fields += ["Ref($%s, %d)/$close" % (field, d) if d != 0 else "$%s/$close" % field for d in windows]
names += [field.upper() + str(d) for d in windows]
if "volume" in config:
windows = config["volume"].get("windows", range(5))
fields += ["Ref($volume, %d)/($volume+1e-12)" % d if d != 0 else "$volume/($volume+1e-12)" for d in windows]
names += ["VOLUME" + str(d) for d in windows]
if "rolling" in config:
windows = config["rolling"].get("windows", [5, 10, 20, 30, 60])
include = config["rolling"].get("include", None)
exclude = config["rolling"].get("exclude", [])
# `exclude` lists the unnecessary fields in the dataset config
# `include` lists the necessary fields in the dataset config
def use(x):
return x not in exclude and (include is None or x in include)
# Some factor ref: https://guorn.com/static/upload/file/3/134065454575605.pdf
if use("ROC"):
# https://www.investopedia.com/terms/r/rateofchange.asp
# Rate of change, the price change in the past d days, divided by latest close price to remove unit
fields += ["Ref($close, %d)/$close" % d for d in windows]
names += ["ROC%d" % d for d in windows]
if use("MA"):
# https://www.investopedia.com/ask/answers/071414/whats-difference-between-moving-average-and-weighted-moving-average.asp
# Simple Moving Average, the simple moving average in the past d days, divided by latest close price to remove unit
fields += ["Mean($close, %d)/$close" % d for d in windows]
names += ["MA%d" % d for d in windows]
if use("STD"):
# The standard deviation of close price for the past d days, divided by latest close price to remove unit
fields += ["Std($close, %d)/$close" % d for d in windows]
names += ["STD%d" % d for d in windows]
if use("BETA"):
# The rate of close price change in the past d days, divided by latest close price to remove unit
# For example, if the price increases by 10 dollars per day over the past d days, then Slope will be 10.
fields += ["Slope($close, %d)/$close" % d for d in windows]
names += ["BETA%d" % d for d in windows]
if use("RSQR"):
# The R-square value of the linear regression for the past d days, representing how linear the trend is
fields += ["Rsquare($close, %d)" % d for d in windows]
names += ["RSQR%d" % d for d in windows]
if use("RESI"):
# The residual of the linear regression for the past d days, representing the trend linearity for the past d days.
fields += ["Resi($close, %d)/$close" % d for d in windows]
names += ["RESI%d" % d for d in windows]
if use("MAX"):
# The max price for past d days, divided by latest close price to remove unit
fields += ["Max($high, %d)/$close" % d for d in windows]
names += ["MAX%d" % d for d in windows]
if use("LOW"):
# The low price for past d days, divided by latest close price to remove unit
fields += ["Min($low, %d)/$close" % d for d in windows]
names += ["MIN%d" % d for d in windows]
if use("QTLU"):
# The 80% quantile of past d day's close price, divided by latest close price to remove unit
# Used with MIN and MAX
fields += ["Quantile($close, %d, 0.8)/$close" % d for d in windows]
names += ["QTLU%d" % d for d in windows]
if use("QTLD"):
# The 20% quantile of past d day's close price, divided by latest close price to remove unit
fields += ["Quantile($close, %d, 0.2)/$close" % d for d in windows]
names += ["QTLD%d" % d for d in windows]
if use("RANK"):
# Get the percentile of the current close price among the past d days' close prices.
# Represents the current price level compared to the past d days, adding information beyond the moving average.
fields += ["Rank($close, %d)" % d for d in windows]
names += ["RANK%d" % d for d in windows]
if use("RSV"):
# Represents the price position between the upper and lower resistance prices for the past d days.
fields += ["($close-Min($low, %d))/(Max($high, %d)-Min($low, %d)+1e-12)" % (d, d, d) for d in windows]
names += ["RSV%d" % d for d in windows]
if use("IMAX"):
# The number of days between current date and previous highest price date.
# Part of Aroon Indicator https://www.investopedia.com/terms/a/aroon.asp
# The indicator measures the time between highs and the time between lows over a time period.
# The idea is that strong uptrends will regularly see new highs, and strong downtrends will regularly see new lows.
fields += ["IdxMax($high, %d)/%d" % (d, d) for d in windows]
names += ["IMAX%d" % d for d in windows]
if use("IMIN"):
# The number of days between current date and previous lowest price date.
# Part of Aroon Indicator https://www.investopedia.com/terms/a/aroon.asp
# The indicator measures the time between highs and the time between lows over a time period.
# The idea is that strong uptrends will regularly see new highs, and strong downtrends will regularly see new lows.
fields += ["IdxMin($low, %d)/%d" % (d, d) for d in windows]
names += ["IMIN%d" % d for d in windows]
if use("IMXD"):
# The time period by which the previous lowest-price date occurs after the highest-price date.
# A large value suggests downward momentum.
fields += ["(IdxMax($high, %d)-IdxMin($low, %d))/%d" % (d, d, d) for d in windows]
names += ["IMXD%d" % d for d in windows]
if use("CORR"):
# The correlation between absolute close price and log scaled trading volume
fields += ["Corr($close, Log($volume+1), %d)" % d for d in windows]
names += ["CORR%d" % d for d in windows]
if use("CORD"):
# The correlation between price change ratio and volume change ratio
fields += ["Corr($close/Ref($close,1), Log($volume/Ref($volume, 1)+1), %d)" % d for d in windows]
names += ["CORD%d" % d for d in windows]
if use("CNTP"):
# The percentage of days in the past d days on which the price goes up.
fields += ["Mean($close>Ref($close, 1), %d)" % d for d in windows]
names += ["CNTP%d" % d for d in windows]
if use("CNTN"):
# The percentage of days in the past d days on which the price goes down.
fields += ["Mean($close<Ref($close, 1), %d)" % d for d in windows]
names += ["CNTN%d" % d for d in windows]
if use("CNTD"):
# The diff between past up day and past down day
fields += ["Mean($close>Ref($close, 1), %d)-Mean($close<Ref($close, 1), %d)" % (d, d) for d in windows]
names += ["CNTD%d" % d for d in windows]
if use("SUMP"):
# The total gain / the absolute total price changed
# Similar to RSI indicator. https://www.investopedia.com/terms/r/rsi.asp
fields += [
"Sum(Greater($close-Ref($close, 1), 0), %d)/(Sum(Abs($close-Ref($close, 1)), %d)+1e-12)" % (d, d)
for d in windows
]
names += ["SUMP%d" % d for d in windows]
if use("SUMN"):
# The total loss / the absolute total price change
# Can be derived from SUMP by SUMN = 1 - SUMP
# Similar to RSI indicator. https://www.investopedia.com/terms/r/rsi.asp
fields += [
"Sum(Greater(Ref($close, 1)-$close, 0), %d)/(Sum(Abs($close-Ref($close, 1)), %d)+1e-12)" % (d, d)
for d in windows
]
names += ["SUMN%d" % d for d in windows]
if use("SUMD"):
# The diff ratio between total gain and total loss
# Similar to RSI indicator. https://www.investopedia.com/terms/r/rsi.asp
fields += [
"(Sum(Greater($close-Ref($close, 1), 0), %d)-Sum(Greater(Ref($close, 1)-$close, 0), %d))"
"/(Sum(Abs($close-Ref($close, 1)), %d)+1e-12)" % (d, d, d)
for d in windows
]
names += ["SUMD%d" % d for d in windows]
if use("VMA"):
# Simple Volume Moving average: https://www.barchart.com/education/technical-indicators/volume_moving_average
fields += ["Mean($volume, %d)/($volume+1e-12)" % d for d in windows]
names += ["VMA%d" % d for d in windows]
if use("VSTD"):
# The standard deviation for volume in past d days.
fields += ["Std($volume, %d)/($volume+1e-12)" % d for d in windows]
names += ["VSTD%d" % d for d in windows]
if use("WVMA"):
# The volume weighted price change volatility
fields += [
"Std(Abs($close/Ref($close, 1)-1)*$volume, %d)/(Mean(Abs($close/Ref($close, 1)-1)*$volume, %d)+1e-12)"
% (d, d)
for d in windows
]
names += ["WVMA%d" % d for d in windows]
if use("VSUMP"):
# The total volume increase / the absolute total volume changed
fields += [
"Sum(Greater($volume-Ref($volume, 1), 0), %d)/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)"
% (d, d)
for d in windows
]
names += ["VSUMP%d" % d for d in windows]
if use("VSUMN"):
# The total volume decrease / the absolute total volume change
# Can be derived from VSUMP by VSUMN = 1 - VSUMP
fields += [
"Sum(Greater(Ref($volume, 1)-$volume, 0), %d)/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)"
% (d, d)
for d in windows
]
names += ["VSUMN%d" % d for d in windows]
if use("VSUMD"):
# The diff ratio between total volume increase and total volume decrease
# RSI indicator for volume
fields += [
"(Sum(Greater($volume-Ref($volume, 1), 0), %d)-Sum(Greater(Ref($volume, 1)-$volume, 0), %d))"
"/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)" % (d, d, d)
for d in windows
]
names += ["VSUMD%d" % d for d in windows]
return fields, names
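
The `include`/`exclude` selection used by `Alpha158DL.get_feature_config` for the rolling operators can be illustrated in isolation. The following is a simplified sketch under stated assumptions, not the loader's implementation: `build_rolling_fields` and `TEMPLATES` are hypothetical names, only three of the operators are listed, and a `{d}` placeholder replaces the `%d` formatting of the original so that a template may mention the window more than once.

```python
# Hypothetical subset of the rolling operators; expressions copied from the loader.
TEMPLATES = {
    "ROC": "Ref($close, {d})/$close",
    "MA": "Mean($close, {d})/$close",
    "STD": "Std($close, {d})/$close",
}


def build_rolling_fields(templates, windows, include=None, exclude=()):
    """Expand operator templates over the windows, honouring include/exclude."""

    def use(op):
        # Same rule as the loader: excluded operators are dropped, and when
        # `include` is given only the listed operators are kept.
        return op not in exclude and (include is None or op in include)

    fields, names = [], []
    for op, template in templates.items():
        if not use(op):
            continue
        fields += [template.replace("{d}", str(d)) for d in windows]
        names += ["%s%d" % (op, d) for d in windows]
    return fields, names


fields, names = build_rolling_fields(TEMPLATES, [5, 10], exclude=["STD"])
```

Each selected operator contributes one expression per window, so the number of generated features is the number of selected operators times the number of windows.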

View File

@@ -243,7 +243,7 @@ class MetaDatasetDS(MetaTaskDataset):
trunc_days: int = None,
rolling_ext_days: int = 0,
exp_name: Union[str, InternalData],
segments: Union[Dict[Text, Tuple], float],
segments: Union[Dict[Text, Tuple], float, str],
hist_step_n: int = 10,
task_mode: str = MetaTask.PROC_MODE_FULL,
fill_method: str = "max",
@@ -271,12 +271,16 @@ class MetaDatasetDS(MetaTaskDataset):
- str: the name of the experiment to store the performance of data
- InternalData: a prepared internal data
segments: Union[Dict[Text, Tuple], float, str]
the segments to divide data
both left and right
if the segment is a Dict
the segments to divide data
both left and right are included
if segments is a float:
the float represents the percentage of data for training
if segments is a string:
it will try its best to put its data in training and ensure that the date `segments` is in the test set
hist_step_n: int
length of historical steps for the meta infomation
Number of steps of the data similarity information
task_mode : str
Please refer to the docs of MetaTask
"""
@@ -383,10 +387,30 @@ class MetaDatasetDS(MetaTaskDataset):
if isinstance(self.segments, float):
train_task_n = int(len(self.meta_task_l) * self.segments)
if segment == "train":
return self.meta_task_l[:train_task_n]
train_tasks = self.meta_task_l[:train_task_n]
get_module_logger("MetaDatasetDS").info(f"The first train meta task: {train_tasks[0]}")
return train_tasks
elif segment == "test":
return self.meta_task_l[train_task_n:]
test_tasks = self.meta_task_l[train_task_n:]
get_module_logger("MetaDatasetDS").info(f"The first test meta task: {test_tasks[0]}")
return test_tasks
else:
raise NotImplementedError(f"This type of input is not supported")
elif isinstance(self.segments, str):
train_tasks = []
test_tasks = []
for t in self.meta_task_l:
test_end = t.task["dataset"]["kwargs"]["segments"]["test"][1]
if test_end is None or pd.Timestamp(test_end) < pd.Timestamp(self.segments):
train_tasks.append(t)
else:
test_tasks.append(t)
get_module_logger("MetaDatasetDS").info(f"The first train meta task: {train_tasks[0]}")
get_module_logger("MetaDatasetDS").info(f"The first test meta task: {test_tasks[0]}")
if segment == "train":
return train_tasks
elif segment == "test":
return test_tasks
raise NotImplementedError(f"This type of input is not supported")
else:
raise NotImplementedError(f"This type of input is not supported")
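
The string form of `segments` splits the meta tasks by date rather than by proportion. The following is a minimal sketch of that rule using only the standard library. It assumes each task is reduced to a plain dict with a hypothetical `test_end` key holding an ISO date string or `None`, whereas the code above reads the value from `t.task["dataset"]["kwargs"]["segments"]["test"][1]` and compares `pd.Timestamp` objects.

```python
from datetime import date


def split_tasks_by_date(tasks, split_date):
    """Put tasks whose test segment ends before `split_date` into train,
    and all remaining tasks into test (hypothetical helper)."""
    boundary = date.fromisoformat(split_date)
    train_tasks, test_tasks = [], []
    for task in tasks:
        test_end = task["test_end"]  # hypothetical key, see lead-in
        # As in the hunk above, an open-ended test segment (None) goes to train.
        if test_end is None or date.fromisoformat(test_end) < boundary:
            train_tasks.append(task)
        else:
            test_tasks.append(task)
    return train_tasks, test_tasks


tasks = [
    {"name": "a", "test_end": "2020-01-31"},
    {"name": "b", "test_end": "2020-06-30"},
    {"name": "c", "test_end": None},
]
train_tasks, test_tasks = split_tasks_by_date(tasks, "2020-03-01")
```

Because the comparison is strict, a task whose test segment ends exactly on the split date lands in the test set, which matches the docstring's promise that the date `segments` is in the test set.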

View File

@@ -53,7 +53,12 @@ class MetaModelDS(MetaTaskModel):
max_epoch=100,
seed=43,
alpha=0.0,
loss_skip_thresh=50,
):
"""
loss_skip_thresh: int
Days with fewer samples than this threshold are skipped when calculating the loss.
"""
self.step = step
self.hist_step_n = hist_step_n
self.clip_method = clip_method
@@ -63,6 +68,7 @@ class MetaModelDS(MetaTaskModel):
self.max_epoch = max_epoch
self.fitted = False
self.alpha = alpha
self.loss_skip_thresh = loss_skip_thresh
torch.manual_seed(seed)
def run_epoch(self, phase, task_list, epoch, opt, loss_l, ignore_weight=False):
@@ -88,12 +94,14 @@ class MetaModelDS(MetaTaskModel):
criterion = nn.MSELoss()
loss = criterion(pred, meta_input["y_test"])
elif self.criterion == "ic_loss":
criterion = ICLoss()
criterion = ICLoss(self.loss_skip_thresh)
try:
loss = criterion(pred, meta_input["y_test"], meta_input["test_idx"], skip_size=50)
loss = criterion(pred, meta_input["y_test"], meta_input["test_idx"])
except ValueError as e:
get_module_logger("MetaModelDS").warning(f"Exception `{e}` when calculating IC loss")
continue
else:
raise ValueError(f"Unknown criterion: {self.criterion}")
assert not np.isnan(loss.detach().item()), "NaN loss!"

View File

@@ -10,7 +10,11 @@ from qlib.log import get_module_logger
class ICLoss(nn.Module):
def forward(self, pred, y, idx, skip_size=50):
def __init__(self, skip_size=50):
super().__init__()
self.skip_size = skip_size
def forward(self, pred, y, idx):
"""forward.
FIXME:
- Some times it will be a slightly different from the result from `pandas.corr()`
@@ -33,7 +37,7 @@ class ICLoss(nn.Module):
skip_n = 0
for start_i, end_i in zip(diff_point, diff_point[1:]):
pred_focus = pred[start_i:end_i] # TODO: just for fake
if pred_focus.shape[0] < skip_size:
if pred_focus.shape[0] < self.skip_size:
# skip some days which have very small amount of stock.
skip_n += 1
continue
@@ -50,6 +54,7 @@ class ICLoss(nn.Module):
)
ic_all += ic_day
if len(diff_point) - 1 - skip_n <= 0:
__import__("ipdb").set_trace()
raise ValueError("No enough data for calculating IC")
if skip_n > 0:
get_module_logger("ICLoss").info(

View File

@@ -33,7 +33,7 @@ class CatBoostModel(Model, FeatureInt):
verbose_eval=20,
evals_result=dict(),
reweighter=None,
**kwargs
**kwargs,
):
df_train, df_valid = dataset.prepare(
["train", "valid"],

View File

@@ -31,7 +31,7 @@ class DEnsembleModel(Model, FeatureInt):
sub_weights=None,
epochs=100,
early_stopping_rounds=None,
**kwargs
**kwargs,
):
self.base_model = base_model # "gbm" or "mlp", specifically, we use lgbm for "gbm"
self.num_models = num_models # the number of sub-models

@@ -63,6 +63,7 @@ class LinearModel(Model):
df_train = pd.concat([df_train, df_valid])
except KeyError:
get_module_logger("LinearModel").info("include_valid=True, but valid does not exist")
df_train = df_train.dropna()
if df_train.empty:
raise ValueError("Empty data from dataset, please check your dataset config.")
if reweighter is not None:

@@ -56,7 +56,7 @@ class ADARNN(Model):
n_splits=2,
GPU=0,
seed=None,
**_
**_,
):
# Set logger.
self.logger = get_module_logger("ADARNN")
@@ -154,10 +154,7 @@ class ADARNN(Model):
self.model.train()
criterion = nn.MSELoss()
dist_mat = torch.zeros(self.num_layers, self.len_seq).to(self.device)
len_loader = np.inf
for loader in train_loader_list:
if len(loader) < len_loader:
len_loader = len(loader)
out_weight_list = None
for data_all in zip(*train_loader_list):
# for data_all in zip(*train_loader_list):
self.train_optimizer.zero_grad()
@@ -571,6 +568,7 @@ class TransferLoss:
Returns:
[tensor] -- transfer loss
"""
loss = None
if self.loss_type in ("mmd_lin", "mmd"):
mmdloss = MMD_loss(kernel_type="linear")
loss = mmdloss(X, Y)
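
The ADARNN hunk above drops the loop that computed `len_loader`, the length of the shortest loader. One reading of why this is safe for the iteration itself: `zip` already stops at the shortest iterable. The sketch below only illustrates that built-in behavior; the `range` objects are stand-ins for the real data loaders.

```python
# Three stand-in "loaders" of different lengths.
loaders = [range(5), range(3), range(4)]

# zip yields one tuple per step and stops when the shortest input is exhausted.
steps = list(zip(*loaders))

shortest = min(len(loader) for loader in loaders)
print(len(steps), shortest)
```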

@@ -63,7 +63,7 @@ class ADD(Model):
mu=0.05,
GPU=0,
seed=None,
**kwargs
**kwargs,
):
# Set logger.
self.logger = get_module_logger("ADD")

@@ -52,7 +52,7 @@ class ALSTM(Model):
optimizer="adam",
GPU=0,
seed=None,
**kwargs
**kwargs,
):
# Set logger.
self.logger = get_module_logger("ALSTM")

@@ -56,7 +56,7 @@ class ALSTM(Model):
n_jobs=10,
GPU=0,
seed=None,
**kwargs
**kwargs,
):
# Set logger.
self.logger = get_module_logger("ALSTM")
@@ -160,6 +160,10 @@ class ALSTM(Model):
if self.metric in ("", "loss"):
return -self.loss_fn(pred[mask], label[mask])
elif self.metric == "mse":
mask = ~torch.isnan(label)
weight = torch.ones_like(label)
return -self.mse(pred[mask], label[mask], weight[mask])
raise ValueError("unknown metric `%s`" % self.metric)

@@ -56,7 +56,7 @@ class GATs(Model):
optimizer="adam",
GPU=0,
seed=None,
**kwargs
**kwargs,
):
# Set logger.
self.logger = get_module_logger("GATs")

@@ -73,7 +73,7 @@ class GATs(Model):
GPU=0,
n_jobs=10,
seed=None,
**kwargs
**kwargs,
):
# Set logger.
self.logger = get_module_logger("GATs")

@@ -0,0 +1,358 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import division
from __future__ import print_function
from torch.utils.data import DataLoader
import numpy as np
import pandas as pd
from typing import Union
import copy
import torch
import torch.optim as optim
from qlib.data.dataset.weight import Reweighter
from .pytorch_utils import count_parameters
from ...model.base import Model
from ...data.dataset import DatasetH, TSDatasetH
from ...data.dataset.handler import DataHandlerLP
from ...utils import (
init_instance_by_config,
get_or_create_path,
)
from ...log import get_module_logger
from ...model.utils import ConcatDataset
class GeneralPTNN(Model):
"""
Motivation:
We want to provide a Qlib General Pytorch Model Adaptor
You can reuse it for all kinds of Pytorch models.
It should include the training and predict process
Parameters
----------
d_feat : int
input dimension for each time step
metric: str
the evaluation metric used in early stop
optimizer : str
optimizer name
GPU : int
the GPU ID used for training; a negative value selects the CPU
"""
def __init__(
self,
n_epochs=200,
lr=0.001,
metric="",
batch_size=2000,
early_stop=20,
loss="mse",
weight_decay=0.0,
optimizer="adam",
n_jobs=10,
GPU=0,
seed=None,
pt_model_uri="qlib.contrib.model.pytorch_gru_ts.GRUModel",
pt_model_kwargs={
"d_feat": 6,
"hidden_size": 64,
"num_layers": 2,
"dropout": 0.0,
},
):
# Set logger.
self.logger = get_module_logger("GeneralPTNN")
self.logger.info("GeneralPTNN pytorch version...")
# set hyper-parameters.
self.n_epochs = n_epochs
self.lr = lr
self.metric = metric
self.batch_size = batch_size
self.early_stop = early_stop
self.optimizer = optimizer.lower()
self.loss = loss
self.weight_decay = weight_decay
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
self.n_jobs = n_jobs
self.seed = seed
self.pt_model_uri, self.pt_model_kwargs = pt_model_uri, pt_model_kwargs
self.dnn_model = init_instance_by_config({"class": pt_model_uri, "kwargs": pt_model_kwargs})
self.logger.info(
"GeneralPTNN parameters setting:"
"\nn_epochs : {}"
"\nlr : {}"
"\nmetric : {}"
"\nbatch_size : {}"
"\nearly_stop : {}"
"\noptimizer : {}"
"\nloss_type : {}"
"\ndevice : {}"
"\nn_jobs : {}"
"\nuse_GPU : {}"
"\nweight_decay : {}"
"\nseed : {}"
"\npt_model_uri: {}"
"\npt_model_kwargs: {}".format(
n_epochs,
lr,
metric,
batch_size,
early_stop,
optimizer.lower(),
loss,
self.device,
n_jobs,
self.use_gpu,
weight_decay,
seed,
pt_model_uri,
pt_model_kwargs,
)
)
if self.seed is not None:
np.random.seed(self.seed)
torch.manual_seed(self.seed)
self.logger.info("model:\n{:}".format(self.dnn_model))
self.logger.info("model size: {:.4f} MB".format(count_parameters(self.dnn_model)))
if optimizer.lower() == "adam":
self.train_optimizer = optim.Adam(self.dnn_model.parameters(), lr=self.lr, weight_decay=weight_decay)
elif optimizer.lower() == "gd":
self.train_optimizer = optim.SGD(self.dnn_model.parameters(), lr=self.lr, weight_decay=weight_decay)
else:
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
self.fitted = False
self.dnn_model.to(self.device)
@property
def use_gpu(self):
return self.device != torch.device("cpu")
def mse(self, pred, label, weight):
loss = weight * (pred - label) ** 2
return torch.mean(loss)
def loss_fn(self, pred, label, weight=None):
mask = ~torch.isnan(label)
if weight is None:
weight = torch.ones_like(label)
if self.loss == "mse":
return self.mse(pred[mask], label[mask], weight[mask])
raise ValueError("unknown loss `%s`" % self.loss)
def metric_fn(self, pred, label):
mask = torch.isfinite(label)
if self.metric in ("", "loss"):
return -self.loss_fn(pred[mask], label[mask])
raise ValueError("unknown metric `%s`" % self.metric)
def _get_fl(self, data: torch.Tensor):
"""
get feature and label from data
- Handle the different data shape of time series and tabular data
Parameters
----------
data : torch.Tensor
input data, which may be 3-dimensional or 2-dimensional
- 3dim: [batch_size, time_step, feature_dim]
- 2dim: [batch_size, feature_dim]
Returns
-------
Tuple[torch.Tensor, torch.Tensor]
"""
if data.dim() == 3:
# it is a time series dataset
feature = data[:, :, 0:-1].to(self.device)
label = data[:, -1, -1].to(self.device)
elif data.dim() == 2:
# it is a tabular dataset
feature = data[:, 0:-1].to(self.device)
label = data[:, -1].to(self.device)
else:
raise ValueError("Unsupported data shape.")
return feature, label
def train_epoch(self, data_loader):
self.dnn_model.train()
for data, weight in data_loader:
feature, label = self._get_fl(data)
pred = self.dnn_model(feature.float())
loss = self.loss_fn(pred, label, weight.to(self.device))
self.train_optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_value_(self.dnn_model.parameters(), 3.0)
self.train_optimizer.step()
def test_epoch(self, data_loader):
self.dnn_model.eval()
scores = []
losses = []
for data, weight in data_loader:
feature, label = self._get_fl(data)
with torch.no_grad():
pred = self.dnn_model(feature.float())
loss = self.loss_fn(pred, label, weight.to(self.device))
losses.append(loss.item())
score = self.metric_fn(pred, label)
scores.append(score.item())
return np.mean(losses), np.mean(scores)
def fit(
self,
dataset: Union[DatasetH, TSDatasetH],
evals_result=dict(),
save_path=None,
reweighter=None,
):
ists = isinstance(dataset, TSDatasetH) # is this time series dataset
dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
dl_valid = dataset.prepare("valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
if dl_train.empty or dl_valid.empty:
raise ValueError("Empty data from dataset, please check your dataset config.")
if reweighter is None:
wl_train = np.ones(len(dl_train))
wl_valid = np.ones(len(dl_valid))
elif isinstance(reweighter, Reweighter):
wl_train = reweighter.reweight(dl_train)
wl_valid = reweighter.reweight(dl_valid)
else:
raise ValueError("Unsupported reweighter type.")
# Preprocess for data. To align to Dataset Interface for DataLoader
if ists:
dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader
dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader
else:
# If it is tabular data, we convert the DataFrame to a NumPy array so that it can be indexed by DataLoader
dl_train = dl_train.values
dl_valid = dl_valid.values
train_loader = DataLoader(
ConcatDataset(dl_train, wl_train),
batch_size=self.batch_size,
shuffle=True,
num_workers=self.n_jobs,
drop_last=True,
)
valid_loader = DataLoader(
ConcatDataset(dl_valid, wl_valid),
batch_size=self.batch_size,
shuffle=False,
num_workers=self.n_jobs,
drop_last=True,
)
del dl_train, dl_valid, wl_train, wl_valid
save_path = get_or_create_path(save_path)
stop_steps = 0
train_loss = 0
best_score = -np.inf
best_epoch = 0
evals_result["train"] = []
evals_result["valid"] = []
# train
self.logger.info("training...")
self.fitted = True
for step in range(self.n_epochs):
self.logger.info("Epoch%d:", step)
self.logger.info("training...")
self.train_epoch(train_loader)
self.logger.info("evaluating...")
train_loss, train_score = self.test_epoch(train_loader)
val_loss, val_score = self.test_epoch(valid_loader)
self.logger.info("train %.6f, valid %.6f" % (train_score, val_score))
evals_result["train"].append(train_score)
evals_result["valid"].append(val_score)
if step == 0:
best_param = copy.deepcopy(self.dnn_model.state_dict())
if val_score > best_score:
best_score = val_score
stop_steps = 0
best_epoch = step
best_param = copy.deepcopy(self.dnn_model.state_dict())
else:
stop_steps += 1
if stop_steps >= self.early_stop:
self.logger.info("early stop")
break
self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
self.dnn_model.load_state_dict(best_param)
torch.save(best_param, save_path)
if self.use_gpu:
torch.cuda.empty_cache()
def predict(
self,
dataset: Union[DatasetH, TSDatasetH],
batch_size=None,
n_jobs=None,
):
if not self.fitted:
raise ValueError("model is not fitted yet!")
dl_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)
if isinstance(dataset, TSDatasetH):
dl_test.config(fillna_type="ffill+bfill") # process nan brought by dataloader
index = dl_test.get_index()
else:
# If it is tabular data, we convert the DataFrame to a NumPy array so that it can be indexed by DataLoader
index = dl_test.index
dl_test = dl_test.values
test_loader = DataLoader(dl_test, batch_size=self.batch_size, num_workers=self.n_jobs)
self.dnn_model.eval()
preds = []
for data in test_loader:
feature, _ = self._get_fl(data)
feature = feature.to(self.device)
with torch.no_grad():
pred = self.dnn_model(feature.float()).detach().cpu().numpy()
preds.append(pred)
preds_concat = np.concatenate(preds)
if preds_concat.ndim != 1:
preds_concat = preds_concat.ravel()
return pd.Series(preds_concat, index=index)
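
`GeneralPTNN._get_fl` above is what lets one adaptor serve both time-series and tabular datasets: in either layout the last column is the label, and for a time series only the label of the final time step is used. A dependency-free sketch of the same slicing, using nested lists in place of tensors; `split_feature_label` is a hypothetical helper, not part of qlib:

```python
def split_feature_label(batch):
    """Split the trailing label column off a batch (nested lists stand in for tensors)."""
    first = batch[0]
    if isinstance(first[0], (list, tuple)):
        # Time series: [batch_size, time_step, feature_dim + 1]
        feature = [[row[:-1] for row in sample] for sample in batch]
        # Only the label of the last time step is kept, like data[:, -1, -1].
        label = [sample[-1][-1] for sample in batch]
    else:
        # Tabular: [batch_size, feature_dim + 1]
        feature = [row[:-1] for row in batch]
        label = [row[-1] for row in batch]
    return feature, label


ts_batch = [[[1, 2, 9], [3, 4, 8]], [[5, 6, 7], [7, 8, 6]]]
tab_batch = [[1, 2, 3], [4, 5, 6]]
ts_feature, ts_label = split_feature_label(ts_batch)
tab_feature, tab_label = split_feature_label(tab_batch)
```

The real method distinguishes the two cases with `data.dim()` and raises `ValueError` for any other shape; the list-based check here is only an approximation of that dispatch.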

@@ -1,25 +1,25 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import division
from __future__ import print_function
import copy
from typing import Text, Union
import numpy as np
import pandas as pd
from typing import Text, Union
import copy
from ...utils import get_or_create_path
from ...log import get_module_logger
import torch
import torch.nn as nn
import torch.optim as optim
from .pytorch_utils import count_parameters
from ...model.base import Model
from qlib.workflow import R
from ...data.dataset import DatasetH
from ...data.dataset.handler import DataHandlerLP
from ...log import get_module_logger
from ...model.base import Model
from ...utils import get_or_create_path
from .pytorch_utils import count_parameters
class GRU(Model):
@@ -52,7 +52,7 @@ class GRU(Model):
optimizer="adam",
GPU=0,
seed=None,
**kwargs
**kwargs,
):
# Set logger.
self.logger = get_module_logger("GRU")
@@ -212,16 +212,31 @@ class GRU(Model):
evals_result=dict(),
save_path=None,
):
df_train, df_valid, df_test = dataset.prepare(
["train", "valid", "test"],
col_set=["feature", "label"],
data_key=DataHandlerLP.DK_L,
)
if df_train.empty or df_valid.empty:
raise ValueError("Empty data from dataset, please check your dataset config.")
# prepare training and validation data
dfs = {
k: dataset.prepare(
k,
col_set=["feature", "label"],
data_key=DataHandlerLP.DK_L,
)
for k in ["train", "valid"]
if k in dataset.segments
}
df_train, df_valid = dfs.get("train", pd.DataFrame()), dfs.get("valid", pd.DataFrame())
# check if training data is empty
if df_train.empty:
raise ValueError("Empty training data from dataset, please check your dataset config.")
df_train = df_train.dropna()
x_train, y_train = df_train["feature"], df_train["label"]
x_valid, y_valid = df_valid["feature"], df_valid["label"]
# check if validation data is provided
if not df_valid.empty:
df_valid = df_valid.dropna()
x_valid, y_valid = df_valid["feature"], df_valid["label"]
else:
x_valid, y_valid = None, None
save_path = get_or_create_path(save_path)
stop_steps = 0
@@ -235,32 +250,42 @@ class GRU(Model):
self.logger.info("training...")
self.fitted = True
best_param = copy.deepcopy(self.gru_model.state_dict())
for step in range(self.n_epochs):
self.logger.info("Epoch%d:", step)
self.logger.info("training...")
self.train_epoch(x_train, y_train)
self.logger.info("evaluating...")
train_loss, train_score = self.test_epoch(x_train, y_train)
val_loss, val_score = self.test_epoch(x_valid, y_valid)
self.logger.info("train %.6f, valid %.6f" % (train_score, val_score))
evals_result["train"].append(train_score)
evals_result["valid"].append(val_score)
if val_score > best_score:
best_score = val_score
stop_steps = 0
best_epoch = step
best_param = copy.deepcopy(self.gru_model.state_dict())
else:
stop_steps += 1
if stop_steps >= self.early_stop:
self.logger.info("early stop")
break
# evaluate on validation data if provided
if x_valid is not None and y_valid is not None:
val_loss, val_score = self.test_epoch(x_valid, y_valid)
self.logger.info("train %.6f, valid %.6f" % (train_score, val_score))
evals_result["valid"].append(val_score)
if val_score > best_score:
best_score = val_score
stop_steps = 0
best_epoch = step
best_param = copy.deepcopy(self.gru_model.state_dict())
else:
stop_steps += 1
if stop_steps >= self.early_stop:
self.logger.info("early stop")
break
self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
self.gru_model.load_state_dict(best_param)
torch.save(best_param, save_path)
# Logging
rec = R.get_recorder()
for k, v_l in evals_result.items():
for i, v in enumerate(v_l):
rec.log_metrics(step=i, **{k: v})
if self.use_gpu:
torch.cuda.empty_cache()
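
The GRU hunks above make the validation segment optional: only segments present in `dataset.segments` are prepared, an empty training set is an error, and a missing validation set is tolerated. A small sketch of that control flow, with plain lists in place of DataFrames; `prepare_splits` and the `loader` callback are hypothetical names:

```python
def prepare_splits(segments, loader):
    """Load only the configured segments; validation may be absent."""
    data = {k: loader(k) for k in ["train", "valid"] if k in segments}
    train = data.get("train", [])
    valid = data.get("valid", [])
    if not train:
        raise ValueError("Empty training data from dataset, please check your dataset config.")
    # None signals "no validation data", so the caller can skip evaluation.
    return train, (valid if valid else None)


# A dataset configured with train and test only: validation comes back as None.
train, valid = prepare_splits({"train": None, "test": None}, lambda name: [1, 2, 3])
```

A caller following the diff would then run the validation and early-stopping logic only when the second return value is not `None`.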

@@ -54,7 +54,7 @@ class GRU(Model):
n_jobs=10,
GPU=0,
seed=None,
**kwargs
**kwargs,
):
# Set logger.
self.logger = get_module_logger("GRU")

@@ -59,7 +59,7 @@ class HIST(Model):
optimizer="adam",
GPU=0,
seed=None,
**kwargs
**kwargs,
):
# Set logger.
self.logger = get_module_logger("HIST")
@@ -256,7 +256,7 @@ class HIST(Model):
raise ValueError("Empty data from dataset, please check your dataset config.")
if not os.path.exists(self.stock2concept):
url = "http://fintech.msra.cn/stock_data/downloads/qlib_csi300_stock2concept.npy"
url = "https://github.com/SunsetWolf/qlib_dataset/releases/download/v0/qlib_csi300_stock2concept.npy"
urllib.request.urlretrieve(url, self.stock2concept)
stock_index = np.load(self.stock_index, allow_pickle=True).item()

@@ -55,7 +55,7 @@ class IGMTF(Model):
optimizer="adam",
GPU=0,
seed=None,
**kwargs
**kwargs,
):
# Set logger.
self.logger = get_module_logger("IGMTF")

@@ -255,7 +255,7 @@ class KRNN(Model):
optimizer="adam",
GPU=0,
seed=None,
**kwargs
**kwargs,
):
# Set logger.
self.logger = get_module_logger("KRNN")

@@ -44,7 +44,7 @@ class LocalformerModel(Model):
n_jobs=10,
GPU=0,
seed=None,
**kwargs
**kwargs,
):
# set hyper-parameters.
self.d_model = d_model

@@ -42,7 +42,7 @@ class LocalformerModel(Model):
n_jobs=10,
GPU=0,
seed=None,
**kwargs
**kwargs,
):
# set hyper-parameters.
self.d_model = d_model

@@ -51,7 +51,7 @@ class LSTM(Model):
optimizer="adam",
GPU=0,
seed=None,
**kwargs
**kwargs,
):
# Set logger.
self.logger = get_module_logger("LSTM")

@@ -53,7 +53,7 @@ class LSTM(Model):
n_jobs=10,
GPU=0,
seed=None,
**kwargs
**kwargs,
):
# Set logger.
self.logger = get_module_logger("LSTM")

@@ -35,7 +35,7 @@ class SandwichModel(nn.Module):
rnn_layers,
dropout,
device,
**params
**params,
):
"""Build a Sandwich model
@@ -129,7 +129,7 @@ class Sandwich(Model):
optimizer="adam",
GPU=0,
seed=None,
**kwargs
**kwargs,
):
# Set logger.
self.logger = get_module_logger("Sandwich")

@@ -212,7 +212,7 @@ class SFM(Model):
optimizer="gd",
GPU=0,
seed=None,
**kwargs
**kwargs,
):
# Set logger.
self.logger = get_module_logger("SFM")

@@ -56,7 +56,7 @@ class TCN(Model):
optimizer="adam",
GPU=0,
seed=None,
**kwargs
**kwargs,
):
# Set logger.
self.logger = get_module_logger("TCN")

@@ -54,7 +54,7 @@ class TCN(Model):
n_jobs=10,
GPU=0,
seed=None,
**kwargs
**kwargs,
):
# Set logger.
self.logger = get_module_logger("TCN")

@@ -58,7 +58,7 @@ class TCTS(Model):
mode="soft",
seed=None,
lowest_valid_performance=0.993,
**kwargs
**kwargs,
):
# Set logger.
self.logger = get_module_logger("TCTS")

@@ -43,7 +43,7 @@ class TransformerModel(Model):
n_jobs=10,
GPU=0,
seed=None,
**kwargs
**kwargs,
):
# set hyper-parameters.
self.d_model = d_model

@@ -41,7 +41,7 @@ class TransformerModel(Model):
n_jobs=10,
GPU=0,
seed=None,
**kwargs
**kwargs,
):
# set hyper-parameters.
self.d_model = d_model

@@ -28,7 +28,7 @@ class XGBModel(Model, FeatureInt):
verbose_eval=20,
evals_result=dict(),
reweighter=None,
**kwargs
**kwargs,
):
df_train, df_valid = dataset.prepare(
["train", "valid"],
@@ -63,7 +63,7 @@ class XGBModel(Model, FeatureInt):
early_stopping_rounds=early_stopping_rounds,
verbose_eval=verbose_eval,
evals_result=evals_result,
**kwargs
**kwargs,
)
evals_result["train"] = list(evals_result["train"].values())[0]
evals_result["valid"] = list(evals_result["valid"].values())[0]

@@ -4,10 +4,10 @@
# pylint: skip-file
# flake8: noqa
import yaml
import pathlib
import pandas as pd
import shutil
from ruamel.yaml import YAML
from ...backtest.account import Account
from .user import User
from .utils import load_instance, save_instance
@@ -110,7 +110,8 @@ class UserManager:
raise ValueError("User data for {} already exists".format(user_id))
with config_file.open("r") as fp:
config = yaml.safe_load(fp)
yaml = YAML(typ="safe", pure=True)
config = yaml.load(fp)
# load model
model = init_instance_by_config(config["model"])

@@ -6,8 +6,8 @@
import pathlib
import pickle
import yaml
import pandas as pd
from ruamel.yaml import YAML
from ...data import D
from ...config import C
from ...log import get_module_logger
@@ -91,7 +91,8 @@ def prepare(um, today, user_id, exchange_config=None):
dates.append(get_next_trading_date(dates[-1], future=True))
if exchange_config:
with pathlib.Path(exchange_config).open("r") as fp:
exchange_paras = yaml.safe_load(fp)
yaml = YAML(typ="safe", pure=True)
exchange_paras = yaml.load(fp)
else:
exchange_paras = {}
trade_exchange = Exchange(trade_dates=dates, **exchange_paras)

@@ -1,5 +1,17 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
Here we have a comprehensive set of analysis classes.
Here is an example.
.. code-block:: python
from qlib.contrib.report.data.ana import FeaMeanStd
fa = FeaMeanStd(ret_df)
fa.plot_all(wspace=0.3, sub_figsize=(12, 3), col_n=5)
"""
import pandas as pd
import numpy as np
from qlib.contrib.report.data.base import FeaAnalyser
@@ -152,6 +164,7 @@ class FeaSkewTurt(NumFeaAnalyser):
self._kurt[col].plot(ax=right_ax, label="kurt", color="green")
right_ax.set_xlabel("")
right_ax.set_ylabel("kurt")
right_ax.grid(None) # set the grid to None to avoid two overlapping layers of grid lines
h1, l1 = ax.get_legend_handles_labels()
h2, l2 = right_ax.get_legend_handles_labels()
@@ -171,12 +184,15 @@ class FeaMeanStd(NumFeaAnalyser):
ax.set_xlabel("")
ax.set_ylabel("mean")
ax.legend()
ax.tick_params(axis="x", rotation=90)
right_ax = ax.twinx()
self._std[col].plot(ax=right_ax, label="std", color="green")
right_ax.set_xlabel("")
right_ax.set_ylabel("std")
right_ax.tick_params(axis="x", rotation=90)
right_ax.grid(None) # set the grid to None to avoid two overlapping layers of grid lines
h1, l1 = ax.get_legend_handles_labels()
h2, l2 = right_ax.get_legend_handles_labels()

@@ -14,6 +14,24 @@ from qlib.contrib.report.utils import sub_fig_generator
class FeaAnalyser:
def __init__(self, dataset: pd.DataFrame):
"""
Parameters
----------
dataset : pd.DataFrame
We often have multiple columns for dataset. Each column corresponds to one sub figure.
There will be a datetime column in the index levels.
Aggregation will be used for more summarized metrics over time.
Here is an example of data:
.. code-block::
return
datetime instrument
2007-02-06 equity_tpx 0.010087
equity_spx 0.000786
"""
self._dataset = dataset
with TimeInspector.logt("calc_stat_values"):
self.calc_stat_values()

@@ -176,7 +176,7 @@ class HeatmapGraph(BaseGraph):
x=self._df.columns,
y=self._df.index,
z=self._df.values.tolist(),
**self._graph_kwargs
**self._graph_kwargs,
)
]
return _data
@@ -213,7 +213,7 @@ class SubplotsGraph:
sub_graph_layout: dict = None,
sub_graph_data: list = None,
subplots_kwargs: dict = None,
**kwargs
**kwargs,
):
"""
@@ -355,7 +355,7 @@ class SubplotsGraph:
df=self._df.loc[:, [column_name]],
name_dict={column_name: temp_name},
graph_kwargs=_graph_kwargs,
)
),
)
else:
raise TypeError()

@@ -4,7 +4,7 @@ import matplotlib.pyplot as plt
import pandas as pd
def sub_fig_generator(sub_fs=(3, 3), col_n=10, row_n=1, wspace=None, hspace=None, sharex=False, sharey=False):
def sub_fig_generator(sub_figsize=(3, 3), col_n=10, row_n=1, wspace=None, hspace=None, sharex=False, sharey=False):
"""sub_fig_generator.
it will return a generator, each row contains <col_n> sub graph
@@ -13,7 +13,7 @@ def sub_fig_generator(sub_fs=(3, 3), col_n=10, row_n=1, wspace=None, hspace=None
Parameters
----------
sub_fs :
sub_figsize :
the figure size of each subgraph in <col_n> * <row_n> subgraphs
col_n :
the number of subgraphs in each row; a new graph is generated after every <col_n> subgraphs.
@@ -33,7 +33,7 @@ def sub_fig_generator(sub_fs=(3, 3), col_n=10, row_n=1, wspace=None, hspace=None
while True:
fig, axes = plt.subplots(
row_n, col_n, figsize=(sub_fs[0] * col_n, sub_fs[1] * row_n), sharex=sharex, sharey=sharey
row_n, col_n, figsize=(sub_figsize[0] * col_n, sub_figsize[1] * row_n), sharex=sharex, sharey=sharey
)
plt.subplots_adjust(wspace=wspace, hspace=hspace)
axes = axes.reshape(row_n, col_n)

@@ -2,11 +2,11 @@
# Licensed under the MIT License.
from copy import deepcopy
from pathlib import Path
from ruamel.yaml import YAML
from typing import List, Optional, Union
import fire
import pandas as pd
import yaml
from qlib import auto_init
from qlib.log import get_module_logger
@@ -73,8 +73,8 @@ class Rolling:
The horizon of the prediction target.
This is used to override the prediction horizon of the file.
h_path : Optional[str]
the dumped data handler;
It may come from other data source. It will override the data handler in the config.
Another data source that has been dumped as a handler. It will override the data handler section in the config.
If it is not given, it will create a customized cache for the handler when `enable_handler_cache=True`
test_end : Optional[str]
the test end for the data. It is typically used together with the handler
You can do the same thing with task_ext_conf in a more complicated way
@@ -117,9 +117,10 @@ class Rolling:
def _raw_conf(self) -> dict:
with self.conf_path.open("r") as f:
return yaml.safe_load(f)
yaml = YAML(typ="safe", pure=True)
return yaml.load(f)
def _replace_hanler_with_cache(self, task: dict):
def _replace_handler_with_cache(self, task: dict):
"""
Due to the data processing part in original rolling is slow. So we have to
This class tries to add more feature
@@ -159,13 +160,20 @@ class Rolling:
# - get horizon automatically from the expression!!!!
raise NotImplementedError(f"This type of input is not supported")
else:
self.logger.info("The prediction horizon is overrided")
task["dataset"]["kwargs"]["handler"]["kwargs"]["label"] = [
"Ref($close, -{}) / Ref($close, -1) - 1".format(self.horizon + 1)
]
if enable_handler_cache and self.h_path is not None:
self.logger.info("Fail to override the horizon due to data handler cache")
else:
self.logger.info("The prediction horizon is overrided")
if isinstance(task["dataset"]["kwargs"]["handler"], dict):
task["dataset"]["kwargs"]["handler"]["kwargs"]["label"] = [
"Ref($close, -{}) / Ref($close, -1) - 1".format(self.horizon + 1)
]
else:
self.logger.warning("Try to automatically configure the label but failed.")
if enable_handler_cache:
task = self._replace_hanler_with_cache(task)
if self.h_path is not None or enable_handler_cache:
# if we already have provided data source or we want to create one
task = self._replace_handler_with_cache(task)
task = self._update_start_end_time(task)
if self.task_ext_conf is not None:
@@ -173,6 +181,16 @@ class Rolling:
self.logger.info(task)
return task
def run_basic_task(self):
"""
Run the basic task without rolling.
This is for fast testing for model tunning.
"""
task = self.basic_task()
print(task)
trainer = TrainerR(experiment_name=self.exp_name)
trainer([task])
def get_task_list(self) -> List[dict]:
"""return a batch of tasks for rolling."""
task = self.basic_task()
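
The horizon override in `basic_task` above builds its label from the template `"Ref($close, -{}) / Ref($close, -1) - 1"` with `horizon + 1` substituted. A short sketch of what that string looks like and what it computes, assuming that `Ref` with a negative offset looks forward in time; `horizon_label` and `forward_return` are hypothetical helpers, and the price list is made up for illustration:

```python
def horizon_label(horizon):
    """Render the label expression used when the horizon is overridden."""
    return "Ref($close, -{}) / Ref($close, -1) - 1".format(horizon + 1)


def forward_return(close, t, horizon):
    """Return from the close at t+1 to the close at t+1+horizon."""
    return close[t + horizon + 1] / close[t + 1] - 1


close = [10.0, 11.0, 12.0, 13.0, 14.3]
expr = horizon_label(2)
ret = forward_return(close, t=0, horizon=2)
```

Under that assumption the label skips the current bar: the position is entered at the next close and held for `horizon` steps.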

@@ -80,6 +80,11 @@ class DDGDA(Rolling):
sim_task_model: UTIL_MODEL_TYPE = "gbdt",
meta_1st_train_end: Optional[str] = None,
alpha: float = 0.01,
loss_skip_thresh: int = 50,
fea_imp_n: Optional[int] = 30,
meta_data_proc: Optional[str] = "V01",
segments: Union[float, str] = 0.62,
hist_step_n: int = 30,
working_dir: Optional[Union[str, Path]] = None,
**kwargs,
):
@@ -94,6 +99,15 @@ class DDGDA(Rolling):
alpha: float
Setting the L2 regularization for ridge
The `alpha` is only passed to MetaModelDS (it is not passed to sim_task_model currently..)
loss_skip_thresh: int
The threshold for skipping the loss calculation on a given day: if the number of items on that day is below it, the loss for that day is skipped.
meta_data_proc : Optional[str]
How we process the meta dataset for learning meta model.
segments : Union[float, str]
if segments is a float:
The ratio of training data in the meta task dataset
if segments is a string:
it will try its best to put its data in training and ensure that the date `segments` is in the test set
"""
# NOTE:
# the horizon must match the meaning in the base task template
@@ -104,14 +118,22 @@ class DDGDA(Rolling):
super().__init__(**kwargs)
self.working_dir = self.conf_path.parent if working_dir is None else Path(working_dir)
self.proxy_hd = self.working_dir / "handler_proxy.pkl"
self.fea_imp_n = fea_imp_n
self.meta_data_proc = meta_data_proc
self.loss_skip_thresh = loss_skip_thresh
self.segments = segments
self.hist_step_n = hist_step_n
def _adjust_task(self, task: dict, astype: UTIL_MODEL_TYPE):
"""
Some tasks are used for special purposes.
Based on the original task, we need to do some extra things.
For example:
- GBDT for calculating feature importance
- Linear or GBDT for calculating similarity
- Dataset (well processed) that is aligned to Linear for meta learning
So we may need to change the dataset and model for the special purpose while the other settings remain the same.
"""
# NOTE: here is just for aligning with previous implementation
# It is not necessary for the current implementation
@@ -119,12 +141,16 @@ class DDGDA(Rolling):
if astype == "gbdt":
task["model"] = LGBM_MODEL
if isinstance(handler, dict):
# We don't need preprocessing when using GBDT model
for k in ["infer_processors", "learn_processors"]:
if k in handler.setdefault("kwargs", {}):
handler["kwargs"].pop(k)
elif astype == "linear":
task["model"] = LINEAR_MODEL
handler["kwargs"].update(PROC_ARGS)
if isinstance(handler, dict):
handler["kwargs"].update(PROC_ARGS)
else:
self.logger.warning("The handler can't be adjusted.")
else:
raise ValueError(f"astype not supported: {astype}")
return task
@@ -155,12 +181,15 @@ class DDGDA(Rolling):
The meta model will be trained upon the proxy forecasting model.
This dataset is for the proxy forecasting model.
"""
topk = 30
fi = self._get_feature_importance()
col_selected = fi.nlargest(topk)
# NOTE: adjusting to `self.sim_task_model` just for aligning with previous implementation.
# In previous version. The data for proxy model is using sim_task_model's way for processing
task = self._adjust_task(self.basic_task(enable_handler_cache=False), self.sim_task_model)
task = replace_task_handler_with_cache(task, self.working_dir)
# if self.meta_data_proc is not None:
# else:
# # Otherwise, we don't need further processing
# task = self.basic_task()
dataset = init_instance_by_config(task["dataset"])
prep_ds = dataset.prepare(slice(None), col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
@@ -168,12 +197,18 @@ class DDGDA(Rolling):
feature_df = prep_ds["feature"]
label_df = prep_ds["label"]
feature_selected = feature_df.loc[:, col_selected.index]
if self.fea_imp_n is not None:
fi = self._get_feature_importance()
col_selected = fi.nlargest(self.fea_imp_n)
feature_selected = feature_df.loc[:, col_selected.index]
else:
feature_selected = feature_df
feature_selected = feature_selected.groupby("datetime", group_keys=False).apply(
lambda df: (df - df.mean()).div(df.std())
)
feature_selected = feature_selected.fillna(0.0)
if self.meta_data_proc == "V01":
feature_selected = feature_selected.groupby("datetime", group_keys=False).apply(
lambda df: (df - df.mean()).div(df.std())
)
feature_selected = feature_selected.fillna(0.0)
df_all = {
"label": label_df.reindex(feature_selected.index),
@@ -223,7 +258,10 @@ class DDGDA(Rolling):
# 1) leverage the simplified proxy forecasting model to train meta model.
# - Only the dataset part is important, in current version of meta model will integrate the
# the train_start for training meta model does not necessarily align with final rolling
# NOTE:
# - The train_start for training the meta model does not necessarily align with the final rolling
# But please choose the time carefully so that the final rolling tasks do not leak into the training data.
# - The test_start is automatically set to the day after train_end. Validation is ignored.
train_start = "2008-01-01" if self.train_start is None else self.train_start
train_end = "2010-12-31" if self.meta_1st_train_end is None else self.meta_1st_train_end
test_start = (pd.Timestamp(train_end) + pd.Timedelta(days=1)).strftime("%Y-%m-%d")
@@ -249,9 +287,9 @@ class DDGDA(Rolling):
kwargs = dict(
task_tpl=proxy_forecast_model_task,
step=self.step,
segments=0.62, # keep test period consistent with the dataset yaml
segments=self.segments, # keep test period consistent with the dataset yaml
trunc_days=1 + self.horizon,
hist_step_n=30,
hist_step_n=self.hist_step_n,
fill_method=fill_method,
rolling_ext_days=0,
)
@@ -268,7 +306,13 @@ class DDGDA(Rolling):
with R.start(experiment_name=self.meta_exp_name):
R.log_params(**kwargs)
mm = MetaModelDS(
step=self.step, hist_step_n=kwargs["hist_step_n"], lr=0.001, max_epoch=30, seed=43, alpha=self.alpha
step=self.step,
hist_step_n=kwargs["hist_step_n"],
lr=0.001,
max_epoch=30,
seed=43,
alpha=self.alpha,
loss_skip_thresh=self.loss_skip_thresh,
)
mm.fit(md)
R.save_objects(model=mm)
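
When `meta_data_proc == "V01"`, the DDGDA hunk above normalizes the selected features cross-sectionally: within each `datetime` group it subtracts the mean, divides by the standard deviation, and then fills missing values with `0.0`. A standard-library sketch of that step for a single feature column; `zscore_by_day` is a hypothetical helper, and it assumes the sample standard deviation (the pandas default for `std()`):

```python
from statistics import mean, stdev


def zscore_by_day(rows):
    """rows: list of (day, value). Returns normalized values in the original order."""
    by_day = {}
    for day, value in rows:
        by_day.setdefault(day, []).append(value)

    stats = {}
    for day, values in by_day.items():
        # A single-sample day has no sample standard deviation.
        sd = stdev(values) if len(values) > 1 else None
        stats[day] = (mean(values), sd)

    out = []
    for day, value in rows:
        mu, sd = stats[day]
        if sd is None or sd == 0:
            # Mirrors fillna(0.0): undefined z-scores become 0.
            out.append(0.0)
        else:
            out.append((value - mu) / sd)
    return out


rows = [("d1", 1.0), ("d1", 2.0), ("d1", 3.0), ("d2", 5.0)]
normalized = zscore_by_day(rows)
```

Setting `meta_data_proc` to anything other than `"V01"` skips this step in the diff, leaving the features as prepared by the handler.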

@@ -4,9 +4,9 @@
# pylint: skip-file
# flake8: noqa
import yaml
import copy
import os
from ruamel.yaml import YAML
class TunerConfigManager:
@@ -16,7 +16,8 @@ class TunerConfigManager:
self.config_path = config_path
with open(config_path) as fp:
config = yaml.safe_load(fp)
yaml = YAML(typ="safe", pure=True)
config = yaml.load(fp)
self.config = copy.deepcopy(config)
self.pipeline_ex_config = PipelineExperimentConfig(config.get("experiment", dict()), self)


@@ -35,7 +35,7 @@ class Client:
def connect_server(self):
"""Connect to server."""
try:
self.sio.connect("ws://" + self.server_host + ":" + str(self.server_port))
self.sio.connect(f"ws://{self.server_host}:{self.server_port}")
except socketio.exceptions.ConnectionError:
self.logger.error("Cannot connect to server - check your network or server status")


@@ -616,7 +616,7 @@ class DatasetProvider(abc.ABC):
data = pd.DataFrame(obj)
if not data.empty and not np.issubdtype(data.index.dtype, np.dtype("M")):
# If the underlaying provides the data not in datatime formmat, we'll convert it into datetime format
# If the underlying provider returns data whose index is not in datetime format, we'll convert it into datetime format
_calendar = Cal.calendar(freq=freq)
data.index = _calendar[data.index.values.astype(int)]
data.index.names = ["datetime"]


@@ -403,7 +403,7 @@ class TSDataSampler:
np.full((1, self.data_arr.shape[1]), np.nan, dtype=self.data_arr.dtype),
axis=0,
)
self.nan_idx = -1 # The last line is all NaN
self.nan_idx = len(self.data_arr) - 1 # The last line is all NaN; setting it to -1 can cause bug #1716
# the data type will be changed
# The index of usable data is between start_idx and end_idx


@@ -7,7 +7,7 @@ from pathlib import Path
import warnings
import pandas as pd
from typing import Tuple, Union, List
from typing import Tuple, Union, List, Dict
from qlib.data import D
from qlib.utils import load_dataset, init_instance_by_config, time_to_slc_point
@@ -41,6 +41,7 @@ class DataLoader(abc.ABC):
----------
instruments : str or dict
it can either be the market name or the config file of instruments generated by InstrumentProvider.
If the value of instruments is None, it means that no filtering is done.
start_time : str
start of the time range.
end_time : str
@@ -50,6 +51,11 @@ class DataLoader(abc.ABC):
-------
pd.DataFrame:
data loaded from the underlying source
Raises
-----
KeyError:
if the instruments filter is not supported, raise KeyError
"""
@@ -247,10 +253,14 @@ class StaticDataLoader(DataLoader, Serializable):
def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame:
self._maybe_load_raw_data()
# 1) Filter by instruments
if instruments is None:
df = self._data
else:
df = self._data.loc(axis=0)[:, instruments]
# 2) Filter by Datetime
if start_time is None and end_time is None:
return df # NOTE: avoid copy by loc
# pd.Timestamp(None) == NaT, use NaT as index can not fetch correct thing, so do not change None.
@@ -275,6 +285,61 @@ class StaticDataLoader(DataLoader, Serializable):
self._data = self._config
class NestedDataLoader(DataLoader):
"""
When there are multiple DataLoaders, this class can be used to combine them.
"""
def __init__(self, dataloader_l: List[Dict], join="left") -> None:
"""
Parameters
----------
dataloader_l : list[dict]
A list of dataloaders, for example
.. code-block:: python
nd = NestedDataLoader(
dataloader_l=[
{
"class": "qlib.contrib.data.loader.Alpha158DL",
}, {
"class": "qlib.contrib.data.loader.Alpha360DL",
"kwargs": {
"config": {
"label": ( ["Ref($close, -2)/Ref($close, -1) - 1"], ["LABEL0"])
}
}
}
]
)
join :
it is passed to pd.merge (as `how`) when merging the loaded DataFrames.
"""
super().__init__()
self.data_loader_l = [
(dl if isinstance(dl, DataLoader) else init_instance_by_config(dl)) for dl in dataloader_l
]
self.join = join
def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame:
df_full = None
for dl in self.data_loader_l:
try:
df_current = dl.load(instruments, start_time, end_time)
except KeyError:
warnings.warn(
"If the value of `instruments` cannot be processed, it will set instruments to None to get all the data."
)
df_current = dl.load(instruments=None, start_time=start_time, end_time=end_time)
if df_full is None:
df_full = df_current
else:
df_full = pd.merge(df_full, df_current, left_index=True, right_index=True, how=self.join)
return df_full.sort_index(axis=1)
class DataLoaderDH(DataLoader):
"""DataLoaderDH
DataLoader based on (D)ata (H)andler


@@ -104,15 +104,24 @@ class HashingStockStorage(BaseHandlerStorage):
"""
stock_selector = slice(None)
time_selector = slice(None) # by default not filter by time.
if level is None:
# For directly applying.
if isinstance(selector, tuple) and self.stock_level < len(selector):
# full selector format
stock_selector = selector[self.stock_level]
time_selector = selector[1 - self.stock_level]
elif isinstance(selector, (list, str)) and self.stock_level == 0:
# only stock selector
stock_selector = selector
elif level in ("instrument", self.stock_level):
if isinstance(selector, tuple):
# NOTE: How could the stock level selector be a tuple?
stock_selector = selector[0]
raise TypeError(
"I forget why would this case appear. But I think it does not make sense. So we raise a error for that case."
)
elif isinstance(selector, (list, str)):
stock_selector = selector
@@ -120,7 +129,7 @@ class HashingStockStorage(BaseHandlerStorage):
raise TypeError(f"stock selector must be type str|list, or slice(None), rather than {stock_selector}")
if stock_selector == slice(None):
return self.hash_df
return self.hash_df, time_selector
if isinstance(stock_selector, str):
stock_selector = [stock_selector]
@@ -129,7 +138,7 @@ class HashingStockStorage(BaseHandlerStorage):
for each_stock in sorted(stock_selector):
if each_stock in self.hash_df:
select_dict[each_stock] = self.hash_df[each_stock]
return select_dict
return select_dict, time_selector
def fetch(
self,
@@ -138,10 +147,13 @@ class HashingStockStorage(BaseHandlerStorage):
col_set: Union[str, List[str]] = DataHandler.CS_ALL,
fetch_orig: bool = True,
) -> pd.DataFrame:
fetch_stock_df_list = list(self._fetch_hash_df_by_stock(selector=selector, level=level).values())
fetch_stock_df_list, time_selector = self._fetch_hash_df_by_stock(selector=selector, level=level)
fetch_stock_df_list = list(fetch_stock_df_list.values())
for _index, stock_df in enumerate(fetch_stock_df_list):
fetch_col_df = fetch_df_by_col(df=stock_df, col_set=col_set)
fetch_index_df = fetch_df_by_index(df=fetch_col_df, selector=selector, level=level, fetch_orig=fetch_orig)
fetch_index_df = fetch_df_by_index(
df=fetch_col_df, selector=time_selector, level="datetime", fetch_orig=fetch_orig
)
fetch_stock_df_list[_index] = fetch_index_df
if len(fetch_stock_df_list) == 0:
index_names = ("instrument", "datetime") if self.stock_level == 0 else ("datetime", "instrument")


@@ -9,7 +9,7 @@ if TYPE_CHECKING:
from qlib.data.dataset import DataHandler
def get_level_index(df: pd.DataFrame, level=Union[str, int]) -> int:
def get_level_index(df: pd.DataFrame, level: Union[str, int]) -> int:
"""
get the level index of `df` given `level`
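
The one-character change in the signature above matters because `level=Union[str, int]` declares a default value, not a type annotation. A minimal sketch, using two hypothetical stand-in functions (`with_default` and `with_annotation` are not part of qlib), makes the difference visible through `inspect`:

```python
import inspect
from typing import Union


def with_default(df, level=Union[str, int]):
    """Old form: `Union[str, int]` is the *default value* of `level`."""
    return level


def with_annotation(df, level: Union[str, int]):
    """New form: `Union[str, int]` is the *annotation*; `level` has no default."""
    return level


old_param = inspect.signature(with_default).parameters["level"]
new_param = inspect.signature(with_annotation).parameters["level"]

# Old form: no annotation at all, and the typing object silently becomes the default.
old_has_annotation = old_param.annotation is not inspect.Parameter.empty
old_default_is_union = old_param.default == Union[str, int]

# New form: the parameter is annotated and must be supplied by the caller.
new_has_annotation = new_param.annotation is not inspect.Parameter.empty
new_has_default = new_param.default is not inspect.Parameter.empty

print(old_has_annotation, old_default_is_union, new_has_annotation, new_has_default)
```

With the old form, calling the function without `level` would pass the `typing` object itself as the level, which is unlikely to be what any caller intended.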


@@ -164,6 +164,7 @@ class SeriesDFilter(BaseDFilter):
timestamp = []
_lbool = None
_ltime = None
_cur_start = None
for _ts, _bool in timestamp_series.items():
# there is likely to be NAN when the filter series don't have the
# bool value, so we just change the NAN into False


@@ -51,3 +51,6 @@ class MetaTask:
Return the **processed** meta_info
"""
return self.meta_info
def __repr__(self):
return f"MetaTask(task={self.task}, meta_info={self.meta_info})"


@@ -41,7 +41,7 @@ def _log_task_info(task_config: dict):
def _exe_task(task_config: dict):
rec = R.get_recorder()
# model & dataset initiation
# model & dataset initialization
model: Model = init_instance_by_config(task_config["model"], accept_types=Model)
dataset: Dataset = init_instance_by_config(task_config["dataset"], accept_types=Dataset)
reweighter: Reweighter = task_config.get("reweighter", None)


@@ -7,8 +7,7 @@ import shutil
import sys
import tempfile
from importlib import import_module
import yaml
from ruamel.yaml import YAML
DELETE_KEY = "_delete_"
@@ -57,7 +56,8 @@ def parse_backtest_config(path: str) -> dict:
del sys.modules[tmp_module_name]
else:
with open(tmp_config_file.name) as input_stream:
config = yaml.safe_load(input_stream)
yaml = YAML(typ="safe", pure=True)
config = yaml.load(input_stream)
if "_base_" in config:
base_file_name = config.pop("_base_")


@@ -8,12 +8,12 @@ import random
import sys
import warnings
from pathlib import Path
from ruamel.yaml import YAML
from typing import cast, List, Optional
import numpy as np
import pandas as pd
import torch
import yaml
from qlib.backtest import Order
from qlib.backtest.decision import OrderDir
from qlib.constant import ONE_MIN
@@ -263,6 +263,7 @@ if __name__ == "__main__":
args = parser.parse_args()
with open(args.config_path, "r") as input_stream:
config = yaml.safe_load(input_stream)
yaml = YAML(typ="safe", pure=True)
config = yaml.load(input_stream)
main(config, run_training=not args.no_training, run_backtest=args.run_backtest)


@@ -12,15 +12,11 @@ import datetime
from tqdm import tqdm
from pathlib import Path
from loguru import logger
from cryptography.fernet import Fernet
from qlib.utils import exists_qlib_data
class GetData:
REMOTE_URL = "https://qlibpublic.blob.core.windows.net/data/default/stock_data"
# "?" is not included in the token.
TOKEN = b"gAAAAABkmDhojHc0VSCDdNK1MqmRzNLeDFXe5hy8obHpa6SDQh4de6nW5gtzuD-fa6O_WZb0yyqYOL7ndOfJX_751W3xN5YB4-n-P22jK-t6ucoZqhT70KPD0Lf0_P328QPJVZ1gDnjIdjhi2YLOcP4BFTHLNYO0mvzszR8TKm9iT5AKRvuysWnpi8bbYwGU9zAcJK3x9EPL43hOGtxliFHcPNGMBoJW4g_ercdhi0-Qgv5_JLsV-29_MV-_AhuaYvJuN2dEywBy"
KEY = "EYcA8cgorA8X9OhyMwVfuFxn_1W3jGk6jCbs3L2oPoA="
REMOTE_URL = "https://github.com/SunsetWolf/qlib_dataset/releases/download"
def __init__(self, delete_zip_file=False):
"""
@@ -33,9 +29,45 @@ class GetData:
self.delete_zip_file = delete_zip_file
def merge_remote_url(self, file_name: str):
fernet = Fernet(self.KEY)
token = fernet.decrypt(self.TOKEN).decode()
return f"{self.REMOTE_URL}/{file_name}?{token}"
"""
Generate download links.
Parameters
----------
file_name: str
The name of the file to be downloaded.
The file name can be accompanied by a version number, (e.g.: v2/qlib_data_simple_cn_1d_latest.zip),
if no version number is attached, it will be downloaded from v0 by default.
"""
return f"{self.REMOTE_URL}/{file_name}" if "/" in file_name else f"{self.REMOTE_URL}/v0/{file_name}"
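
The URL rule described in the docstring above can be exercised in isolation. The sketch below copies the expression from the diff into a free function (a stand-in for the `GetData` method, so no qlib import is needed); the second file name is purely illustrative.

```python
REMOTE_URL = "https://github.com/SunsetWolf/qlib_dataset/releases/download"


def merge_remote_url(file_name: str) -> str:
    """Build the download link; default to the v0 release when no version prefix is given."""
    if "/" in file_name:
        # The caller already supplied "<version>/<file>".
        return f"{REMOTE_URL}/{file_name}"
    return f"{REMOTE_URL}/v0/{file_name}"


versioned = merge_remote_url("v2/qlib_data_simple_cn_1d_latest.zip")
unversioned = merge_remote_url("some_file.zip")  # hypothetical file name

print(versioned)
print(unversioned)
```

Note that the check is only for the presence of a `/`, so any path-like file name is treated as already carrying its version.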
def download(self, url: str, target_path: [Path, str]):
"""
Download a file from the specified url.
Parameters
----------
url: str
The url of the data.
target_path: str
The location where the data is saved, including the file name.
"""
file_name = str(target_path).rsplit("/", maxsplit=1)[-1]
resp = requests.get(url, stream=True, timeout=60)
resp.raise_for_status()
if resp.status_code != 200:
raise requests.exceptions.HTTPError()
chunk_size = 1024
logger.warning(
f"The data for the example is collected from Yahoo Finance. Please be aware that the quality of the data might not be perfect. (You can refer to the original data source: https://finance.yahoo.com/lookup.)"
)
logger.info(f"{os.path.basename(file_name)} downloading......")
with tqdm(total=int(resp.headers.get("Content-Length", 0))) as p_bar:
with target_path.open("wb") as fp:
for chunk in resp.iter_content(chunk_size=chunk_size):
fp.write(chunk)
p_bar.update(chunk_size)
def download_data(self, file_name: str, target_dir: [Path, str], delete_old: bool = True):
"""
@@ -70,21 +102,7 @@ class GetData:
target_path = target_dir.joinpath(_target_file_name)
url = self.merge_remote_url(file_name)
resp = requests.get(url, stream=True, timeout=60)
resp.raise_for_status()
if resp.status_code != 200:
raise requests.exceptions.HTTPError()
chunk_size = 1024
logger.warning(
f"The data for the example is collected from Yahoo Finance. Please be aware that the quality of the data might not be perfect. (You can refer to the original data source: https://finance.yahoo.com/lookup.)"
)
logger.info(f"{os.path.basename(file_name)} downloading......")
with tqdm(total=int(resp.headers.get("Content-Length", 0))) as p_bar:
with target_path.open("wb") as fp:
for chunk in resp.iter_content(chunk_size=chunk_size):
fp.write(chunk)
p_bar.update(chunk_size)
self.download(url=url, target_path=target_path)
self._unzip(target_path, target_dir, delete_old)
if self.delete_zip_file:
@@ -99,7 +117,9 @@ class GetData:
return status
@staticmethod
def _unzip(file_path: Path, target_dir: Path, delete_old: bool = True):
def _unzip(file_path: [Path, str], target_dir: [Path, str], delete_old: bool = True):
file_path = Path(file_path)
target_dir = Path(target_dir)
if delete_old:
logger.warning(
f"will delete the old qlib data directory(features, instruments, calendars, features_cache, dataset_cache): {target_dir}"


@@ -10,7 +10,6 @@ import os
import re
import copy
import json
import yaml
import redis
import bisect
import struct
@@ -25,6 +24,7 @@ import pandas as pd
from pathlib import Path
from typing import List, Union, Optional, Callable
from packaging import version
from ruamel.yaml import YAML
from .file import (
get_or_create_path,
save_multiple_parts_file,
@@ -244,12 +244,13 @@ def parse_config(config):
if not isinstance(config, str):
return config
# Check whether config is file
yaml = YAML(typ="safe", pure=True)
if os.path.exists(config):
with open(config, "r") as f:
return yaml.safe_load(f)
return yaml.load(f)
# Check whether the str can be parsed
try:
return yaml.safe_load(config)
return yaml.load(config)
except BaseException as base_exp:
raise ValueError("cannot parse config!") from base_exp
@@ -799,6 +800,7 @@ def fill_placeholder(config: dict, config_extend: dict):
)
return value
item_keys = None
while top < tail:
now_item = item_queue[top]
top += 1


@@ -44,7 +44,7 @@ def concat(data_list: Union[SingleData], axis=0) -> MultiData:
all_index_map = dict(zip(all_index, range(len(all_index))))
# concat all
tmp_data = np.full((len(all_index), len(data_list)), np.NaN)
tmp_data = np.full((len(all_index), len(data_list)), np.nan)
for data_id, index_data in enumerate(data_list):
assert isinstance(index_data, SingleData)
now_data_map = [all_index_map[index] for index in index_data.index]
@@ -64,7 +64,7 @@ def sum_by_index(data_list: Union[SingleData], new_index: list, fill_value=0) ->
new_index : list
the new_index of new SingleData.
fill_value : float
fill the missing values or replace np.NaN.
fill the missing values or replace np.nan.
Returns
-------
@@ -108,6 +108,12 @@ class Index:
self.index_map = self.idx_list = np.arange(idx_list)
self._is_sorted = True
else:
# Check if all elements in idx_list are of the same type
if not all(isinstance(x, type(idx_list[0])) for x in idx_list):
raise TypeError("All elements in idx_list must be of the same type")
# Check if all elements in idx_list are of the same datetime64 precision
if isinstance(idx_list[0], np.datetime64) and not all(x.dtype == idx_list[0].dtype for x in idx_list):
raise TypeError("All elements in idx_list must be of the same datetime64 precision")
self.idx_list = np.array(idx_list)
# NOTE: only the first appearance is indexed
self.index_map = dict(zip(self.idx_list, range(len(self))))
@@ -131,7 +137,12 @@ class Index:
if self.idx_list.dtype.type is np.datetime64:
if isinstance(item, pd.Timestamp):
# This happens often when creating index based on pandas.DatetimeIndex and query with pd.Timestamp
return item.to_numpy()
return item.to_numpy().astype(self.idx_list.dtype)
elif isinstance(item, np.datetime64):
# This happens often when creating index based on np.datetime64 and query with another precision
return item.astype(self.idx_list.dtype)
# NOTE: It is hard to consider every case at first.
# We just try to cover part of cases to make it more user-friendly
return item
def index(self, item) -> int:
@@ -433,7 +444,7 @@ class IndexData(metaclass=index_data_ops_creator):
return self.__class__(~self.data.astype(bool), *self.indices)
def abs(self):
"""get the abs of data except np.NaN."""
"""get the abs of data except np.nan."""
tmp_data = np.absolute(self.data)
return self.__class__(tmp_data, *self.indices)
@@ -555,8 +566,8 @@ class SingleData(IndexData):
f"The indexes of self and other do not meet the requirements of the four arithmetic operations"
)
def reindex(self, index: Index, fill_value=np.NaN) -> SingleData:
"""reindex data and fill the missing value with np.NaN.
def reindex(self, index: Index, fill_value=np.nan) -> SingleData:
"""reindex data and fill the missing value with np.nan.
Parameters
----------
@@ -604,7 +615,7 @@ class SingleData(IndexData):
return pd.Series(self.data, index=self.index)
def __repr__(self) -> str:
return str(pd.Series(self.data, index=self.index))
return str(pd.Series(self.data, index=self.index.tolist()))
class MultiData(IndexData):
@@ -640,4 +651,4 @@ class MultiData(IndexData):
)
def __repr__(self) -> str:
return str(pd.DataFrame(self.data, index=self.index, columns=self.columns))
return str(pd.DataFrame(self.data, index=self.index.tolist(), columns=self.columns.tolist()))


@@ -161,7 +161,13 @@ def init_instance_by_config(
# path like 'file:///<path to pickle file>/obj.pkl'
pr = urlparse(config)
if pr.scheme == "file":
pr_path = os.path.join(pr.netloc, pr.path) if bool(pr.path) else pr.netloc
# To enable relative path like file://data/a/b/c.pkl. pr.netloc will be data
path = pr.path
if pr.netloc != "":
path = path.lstrip("/")
pr_path = os.path.join(pr.netloc, path) if bool(pr.path) else pr.netloc
with open(os.path.normpath(pr_path), "rb") as f:
return pickle.load(f)
else:
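
Why the `lstrip("/")` is needed for relative `file://` URIs can be seen with the standard library alone. The sketch below uses `posixpath` instead of `os.path` so that the result does not depend on the platform it runs on; the URI is the one from the comment in the diff.

```python
import posixpath
from urllib.parse import urlparse

pr = urlparse("file://data/a/b/c.pkl")
# urlparse treats "data" as the network location and the rest as an absolute path.
netloc, path = pr.netloc, pr.path

# Old behaviour: joining with an absolute second component discards the first one.
old_path = posixpath.join(netloc, path)

# New behaviour: strip the leading "/" when a netloc is present, keeping the path relative.
new_path = posixpath.join(netloc, path.lstrip("/") if netloc != "" else path)

# An absolute URI has an empty netloc and is left untouched.
abs_pr = urlparse("file:///abs/obj.pkl")
abs_path = posixpath.join(abs_pr.netloc, abs_pr.path)

print(netloc, path, old_path, new_path, abs_path)
```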


@@ -1,6 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import threading
from functools import partial
from threading import Thread
from typing import Callable, Text, Union
@@ -9,7 +10,7 @@ from joblib import Parallel, delayed
from joblib._parallel_backends import MultiprocessingBackend
import pandas as pd
from queue import Queue
from queue import Empty, Queue
import concurrent
from qlib.config import C, QlibConfig
@@ -85,7 +86,17 @@ class AsyncCaller:
def run(self):
while True:
data = self._q.get()
# NOTE:
# atexit will only trigger when all the threads have ended, so it may result in a deadlock.
# So the child thread should actively watch the status of the main thread to stop itself.
main_thread = threading.main_thread()
if not main_thread.is_alive():
break
try:
data = self._q.get(timeout=1)
except Empty:
# NOTE: avoid deadlock. make checking main thread possible
continue
if data == self.STOP_MARK:
break
data()
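
The polling pattern introduced in this hunk can be sketched outside qlib as follows. This is a simplified stand-in, not the `AsyncCaller` class itself: the `worker` function, the `STOP_MARK` value, and the 0.1 second timeout are assumptions chosen to keep the example short.

```python
import threading
from queue import Empty, Queue

STOP_MARK = "__STOP"  # hypothetical sentinel value


def worker(q: Queue, results: list) -> None:
    while True:
        # Exit on our own if the main thread is gone, instead of blocking forever on q.get().
        if not threading.main_thread().is_alive():
            break
        try:
            # A bounded wait lets the loop come back and re-check the main thread.
            task = q.get(timeout=0.1)
        except Empty:
            continue
        if task == STOP_MARK:
            break
        results.append(task())


q: Queue = Queue()
results: list = []
t = threading.Thread(target=worker, args=(q, results))
t.start()

q.put(lambda: 1 + 1)
q.put(lambda: "done")
q.put(STOP_MARK)
t.join(timeout=5)

print(results, t.is_alive())
```

The trade-off is a small amount of idle polling in exchange for a worker that can always notice that it should stop.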


@@ -1,18 +1,20 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import logging
import sys
import os
from pathlib import Path
import sys
import fire
from jinja2 import Template, meta
from ruamel.yaml import YAML
import qlib
import fire
import ruamel.yaml as yaml
from qlib.config import C
from qlib.model.trainer import task_train
from qlib.utils.data import update_config
from qlib.log import get_module_logger
from qlib.model.trainer import task_train
from qlib.utils import set_log_with_config
from qlib.utils.data import update_config
set_log_with_config(C.logging_config)
logger = get_module_logger("qrun", logging.INFO)
@@ -47,6 +49,39 @@ def sys_config(config, config_path):
sys.path.append(str(Path(config_path).parent.resolve().absolute() / p))
def render_template(config_path: str) -> str:
"""
render the template based on the environment
Parameters
----------
config_path : str
configuration path
Returns
-------
str
the rendered content
"""
with open(config_path, "r") as f:
config = f.read()
# Set up the Jinja2 environment
template = Template(config)
# Parse the template to find undeclared variables
env = template.environment
parsed_content = env.parse(config)
variables = meta.find_undeclared_variables(parsed_content)
# Get context from os.environ according to the variables
context = {var: os.getenv(var, "") for var in variables if var in os.environ}
logger.info(f"Render the template with the context: {context}")
# Render the template with the context
rendered_content = template.render(context)
return rendered_content
# workflow handler function
def workflow(config_path, experiment_name="workflow", uri_folder="mlruns"):
"""
@@ -67,8 +102,10 @@ def workflow(config_path, experiment_name="workflow", uri_folder="mlruns"):
market: csi300
"""
with open(config_path) as fp:
config = yaml.safe_load(fp)
# Render the template
rendered_yaml = render_template(config_path)
yaml = YAML(typ="safe", pure=True)
config = yaml.load(rendered_yaml)
base_config_path = config.get("BASE_CONFIG_PATH", None)
if base_config_path:
@@ -90,7 +127,8 @@ def workflow(config_path, experiment_name="workflow", uri_folder="mlruns"):
raise FileNotFoundError(f"Can't find the BASE_CONFIG file: {base_config_path}")
with open(path) as fp:
base_config = yaml.safe_load(fp)
yaml = YAML(typ="safe", pure=True)
base_config = yaml.load(fp)
logger.info(f"Load BASE_CONFIG_PATH succeed: {path.resolve()}")
config = update_config(base_config, config)


@@ -8,6 +8,7 @@ from mlflow.exceptions import MlflowException, RESOURCE_ALREADY_EXISTS, ErrorCod
from mlflow.entities import ViewType
import os
from typing import Optional, Text
from pathlib import Path
from .exp import MLflowExperiment, Experiment
from ..config import C
@@ -233,7 +234,7 @@ class ExpManager:
# So we supported it in the interface wrapper
pr = urlparse(self.uri)
if pr.scheme == "file":
with FileLock(os.path.join(pr.netloc, pr.path, "filelock")): # pylint: disable=E0110
with FileLock(Path(os.path.join(pr.netloc, pr.path.lstrip("/"), "filelock"))): # pylint: disable=E0110
return self.create_exp(experiment_name), True
# NOTE: for other schemes like http, we double check to avoid create exp conflicts
try:
@@ -421,7 +422,11 @@ class MLflowExpManager(ExpManager):
def list_experiments(self):
# retrieve all the existing experiments
exps = self.client.list_experiments(view_type=ViewType.ACTIVE_ONLY)
mlflow_version = int(mlflow.__version__.split(".", maxsplit=1)[0])
if mlflow_version >= 2:
exps = self.client.search_experiments(view_type=ViewType.ACTIVE_ONLY)
else:
exps = self.client.list_experiments(view_type=ViewType.ACTIVE_ONLY) # pylint: disable=E1101
experiments = dict()
for exp in exps:
experiment = MLflowExperiment(exp.experiment_id, exp.name, self.uri)
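
The version gate above only looks at the major component of the version string. A small sketch of that parsing step, with hypothetical helper names (`major_version`, `pick_listing_method`) and made-up version strings, shows which client method each version would select:

```python
def major_version(version: str) -> int:
    """Return the leading integer of a dotted version string, e.g. "2.9.2" -> 2."""
    return int(version.split(".", maxsplit=1)[0])


def pick_listing_method(version: str) -> str:
    """Name of the client method the diff would call for a given mlflow version."""
    return "search_experiments" if major_version(version) >= 2 else "list_experiments"


for v in ("1.30.0", "2.0.1", "2.9.2"):
    print(v, pick_listing_method(v))
```

Because only the text before the first dot is converted, suffixes such as `rc1` in later components do not affect the check.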


@@ -9,6 +9,7 @@ import shutil
import pickle
import tempfile
import subprocess
import platform
from pathlib import Path
from datetime import datetime
@@ -316,7 +317,10 @@ class MLflowRecorder(Recorder):
This function will return the directory path of this recorder.
"""
if self.artifact_uri is not None:
local_dir_path = Path(self.artifact_uri.lstrip("file:")) / ".."
if platform.system() == "Windows":
local_dir_path = Path(self.artifact_uri.lstrip("file:").lstrip("/")).parent
else:
local_dir_path = Path(self.artifact_uri.lstrip("file:")).parent
local_dir_path = str(local_dir_path.resolve())
if os.path.isdir(local_dir_path):
return local_dir_path
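
The platform branch above exists because stripping `file:` from a Windows artifact URI leaves leading slashes in front of the drive letter. The sketch below reproduces the two string operations with `PurePosixPath` and `PureWindowsPath`, so it runs the same way on any OS; both URIs are invented examples. Keep in mind that `str.lstrip("file:")` removes any leading run of the characters `f`, `i`, `l`, `e`, `:` rather than a literal prefix, and only stops here because the next character is `/`.

```python
from pathlib import PurePosixPath, PureWindowsPath

posix_uri = "file:///home/user/mlruns/1/abc/artifacts"  # hypothetical artifact URI
win_uri = "file:///C:/mlruns/1/abc/artifacts"  # hypothetical artifact URI

# POSIX branch: "///home/..." collapses to "/home/..." when parsed as a path.
posix_dir = PurePosixPath(posix_uri.lstrip("file:")).parent

# Windows branch: the extra lstrip("/") exposes the drive letter "C:".
win_dir = PureWindowsPath(win_uri.lstrip("file:").lstrip("/")).parent

print(posix_dir)
print(win_dir.as_posix())
```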


@@ -242,7 +242,7 @@ class TimeAdjuster:
def shift(self, seg: tuple, step: int, rtype=SHIFT_SD) -> tuple:
"""
Shift the datatime of segment
Shift the datetime of segment
If there are None (which indicates unbounded index) in the segment, this method will return None.


@@ -301,6 +301,7 @@ class Normalize:
na_values={col: symbol_na if col == self._symbol_field_name else default_na for col in columns},
)
# NOTE: It has been reported that there may be some problems here, and the specific issues will be dealt with when they are identified.
df = self._normalize_obj.normalize(df)
if df is not None and not df.empty:
if self._end_date is not None:


@@ -28,7 +28,7 @@ termcolor==1.1.0
tqdm==4.63.0
trio==0.20.0
trio-websocket==0.9.2
urllib3==1.26.8
urllib3==1.26.19
wget==3.2
wsproto==1.1.0
yahooquery==2.2.15

Some files were not shown because too many files have changed in this diff.