mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 14:01:28 +08:00
Compare commits
44 Commits
bump_versi
...
v0.9.6
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5a84aaf1dc | ||
|
|
afbb178e24 | ||
|
|
a0cef033cb | ||
|
|
7acb4f3484 | ||
|
|
431f574967 | ||
|
|
b604fe56b3 | ||
|
|
af4b8772d2 | ||
|
|
18fcdf1521 | ||
|
|
f2caf452e9 | ||
|
|
ca9f1861a4 | ||
|
|
b45b006ef2 | ||
|
|
82cf438401 | ||
|
|
9e635168c0 | ||
|
|
b7ace1a622 | ||
|
|
c9ed050ef0 | ||
|
|
2c33332dd6 | ||
|
|
a7d5a9b500 | ||
|
|
5190332c7e | ||
|
|
cde80206e4 | ||
|
|
a339fc11d1 | ||
|
|
33482047dc | ||
|
|
47bd13295b | ||
|
|
ebc0ca893e | ||
|
|
3a348aec9f | ||
|
|
37b908792b | ||
|
|
73ec0f4003 | ||
|
|
155c17f8ff | ||
|
|
41b94059aa | ||
|
|
7db83d84b7 | ||
|
|
35e0fdd1c0 | ||
|
|
598017f634 | ||
|
|
907c888c23 | ||
|
|
02fe6b6974 | ||
|
|
b892b21045 | ||
|
|
155f80323c | ||
|
|
63021018d6 | ||
|
|
f79a0eeaff | ||
|
|
8a087d0db9 | ||
|
|
2ae4be426a | ||
|
|
6ed83f7c04 | ||
|
|
917e3a725e | ||
|
|
b1e0e77c97 | ||
|
|
ea245f5435 | ||
|
|
3779b5186a |
8
.dockerignore
Normal file
8
.dockerignore
Normal file
@@ -0,0 +1,8 @@
|
||||
__pycache__
|
||||
*.pyc
|
||||
*.pyo
|
||||
*.pyd
|
||||
.Python
|
||||
.env
|
||||
.git
|
||||
|
||||
66
.github/workflows/python-publish.yml
vendored
66
.github/workflows/python-publish.yml
vendored
@@ -12,70 +12,54 @@ jobs:
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [windows-latest, macos-11]
|
||||
# FIXME: macos-latest will raise error now.
|
||||
# not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129
|
||||
python-version: [3.7, 3.8]
|
||||
os: [windows-latest, macos-13, macos-latest]
|
||||
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
|
||||
exclude:
|
||||
- os: macos-13
|
||||
python-version: "3.11"
|
||||
- os: macos-13
|
||||
python-version: "3.12"
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
# This is because on macos systems you can install pyqlib using
|
||||
# `pip install pyqlib` installs, it does not recognize the
|
||||
# `pyqlib-<version>-cp38-cp38-macosx_11_0_x86_64.whl` and `pyqlib-<veresion>-cp38-cp37m-macosx_11_0_x86_64.whl`.
|
||||
# So we limit the version of python, in order to generate a version of qlib that is usable for macos: `pyqlib-<veresion>-cp38-cp37m
|
||||
# `pyqlib-<version>-cp38-cp38-macosx_10_15_x86_64.whl` and `pyqlib-<veresion>-cp38-cp37m-macosx_10_15_x86_64.whl`.
|
||||
# Python 3.7.16, 3.8.16 can build macosx_10_15. But Python 3.7.17, 3.8.17 can build macosx_11_0
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
if: matrix.os == 'macos-11' && matrix.python-version == '3.7'
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: "3.7.16"
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
if: matrix.os == 'macos-11' && matrix.python-version == '3.8'
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: "3.8.16"
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
if: matrix.os != 'macos-11'
|
||||
uses: actions/setup-python@v2
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install setuptools wheel twine
|
||||
make dev
|
||||
- name: Build wheel on ${{ matrix.os }}
|
||||
run: |
|
||||
pip install numpy
|
||||
pip install cython
|
||||
python setup.py bdist_wheel
|
||||
- name: Build and publish
|
||||
make build
|
||||
- name: Upload to PyPi
|
||||
env:
|
||||
TWINE_USERNAME: __token__
|
||||
TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }}
|
||||
run: |
|
||||
twine upload dist/*
|
||||
twine check dist/*.whl
|
||||
twine upload dist/*.whl --verbose
|
||||
|
||||
deploy_with_manylinux:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Build wheel on Linux
|
||||
uses: RalfG/python-wheels-manylinux-build@v0.3.1-manylinux2010_x86_64
|
||||
uses: RalfG/python-wheels-manylinux-build@v0.7.1-manylinux2014_x86_64
|
||||
with:
|
||||
# not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129
|
||||
python-versions: 'cp37-cp37m cp38-cp38'
|
||||
python-versions: 'cp38-cp38 cp39-cp39 cp310-cp310 cp311-cp311 cp312-cp312'
|
||||
build-requirements: 'numpy cython'
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: 3.7
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip install twine
|
||||
- name: Build and publish
|
||||
python -m pip install twine
|
||||
- name: Upload to PyPi
|
||||
env:
|
||||
TWINE_USERNAME: __token__
|
||||
TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }}
|
||||
run: |
|
||||
twine upload dist/pyqlib-*-manylinux*.whl
|
||||
twine check dist/pyqlib-*-manylinux*.whl
|
||||
twine upload dist/pyqlib-*-manylinux*.whl --verbose
|
||||
|
||||
25
.github/workflows/test_qlib_from_pip.yml
vendored
25
.github/workflows/test_qlib_from_pip.yml
vendored
@@ -13,28 +13,17 @@ jobs:
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
# Since macos-latest changed from 12.7.4 to 14.4.1,
|
||||
# the minimum python version that matches a 14.4.1 version of macos is 3.10,
|
||||
# so we limit the macos version to macos-12.
|
||||
os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-11, macos-12]
|
||||
# not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129
|
||||
python-version: [3.7, 3.8]
|
||||
os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-13, macos-14, macos-15]
|
||||
# In github action, using python 3.7, pip install will not match the latest version of the package.
|
||||
# Also, python 3.7 is no longer supported from macos-14, and will be phased out from macos-13 in the near future.
|
||||
# All things considered, we have removed python 3.7.
|
||||
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
|
||||
|
||||
steps:
|
||||
- name: Test qlib from pip
|
||||
uses: actions/checkout@v3
|
||||
|
||||
# Since version 3.7 of python for MacOS is installed in CI, version 3.7.17, this version causes "_bz not found error".
|
||||
# So we make the version number of python 3.7 for MacOS more specific.
|
||||
# refs: https://github.com/actions/setup-python/issues/682
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
if: (matrix.os == 'macos-latest' && matrix.python-version == '3.7') || (matrix.os == 'macos-11' && matrix.python-version == '3.7')
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.7.16"
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
if: (matrix.os != 'macos-latest' || matrix.python-version != '3.7') && (matrix.os != 'macos-11' || matrix.python-version != '3.7')
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
@@ -45,10 +34,10 @@ jobs:
|
||||
|
||||
- name: Qlib installation test
|
||||
run: |
|
||||
python -m pip install pyqlib
|
||||
python -m pip install --index-url https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple/ pyqlib==0.9.5.80
|
||||
|
||||
- name: Install Lightgbm for MacOS
|
||||
if: ${{ matrix.os == 'macos-11' || matrix.os == 'macos-latest' }}
|
||||
if: ${{ matrix.os == 'macos-13' || matrix.os == 'macos-14' || matrix.os == 'macos-15' }}
|
||||
run: |
|
||||
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Microsoft/qlib/main/.github/brew_install.sh)"
|
||||
HOMEBREW_NO_AUTO_UPDATE=1 brew install lightgbm
|
||||
|
||||
99
.github/workflows/test_qlib_from_source.yml
vendored
99
.github/workflows/test_qlib_from_source.yml
vendored
@@ -14,28 +14,17 @@ jobs:
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
# Since macos-latest changed from 12.7.4 to 14.4.1,
|
||||
# the minimum python version that matches a 14.4.1 version of macos is 3.10,
|
||||
# so we limit the macos version to macos-12.
|
||||
os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-11, macos-12]
|
||||
# not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129
|
||||
python-version: [3.7, 3.8]
|
||||
os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-13, macos-14, macos-15]
|
||||
# In github action, using python 3.7, pip install will not match the latest version of the package.
|
||||
# Also, python 3.7 is no longer supported from macos-14, and will be phased out from macos-13 in the near future.
|
||||
# All things considered, we have removed python 3.7.
|
||||
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
|
||||
|
||||
steps:
|
||||
- name: Test qlib from source
|
||||
uses: actions/checkout@v3
|
||||
|
||||
# Since version 3.7 of python for MacOS is installed in CI, version 3.7.17, this version causes "_bz not found error".
|
||||
# So we make the version number of python 3.7 for MacOS more specific.
|
||||
# refs: https://github.com/actions/setup-python/issues/682
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
if: (matrix.os == 'macos-latest' && matrix.python-version == '3.7') || (matrix.os == 'macos-11' && matrix.python-version == '3.7')
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.7.16"
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
if: (matrix.os != 'macos-latest' || matrix.python-version != '3.7') && (matrix.os != 'macos-11' || matrix.python-version != '3.7')
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
@@ -45,7 +34,7 @@ jobs:
|
||||
python -m pip install --upgrade pip
|
||||
|
||||
- name: Installing pytorch for macos
|
||||
if: ${{ matrix.os == 'macos-11' || matrix.os == 'macos-latest' }}
|
||||
if: ${{ matrix.os == 'macos-13' || matrix.os == 'macos-14' || matrix.os == 'macos-15' }}
|
||||
run: |
|
||||
python -m pip install torch torchvision torchaudio
|
||||
|
||||
@@ -61,85 +50,33 @@ jobs:
|
||||
|
||||
- name: Set up Python tools
|
||||
run: |
|
||||
python -m pip install --upgrade cython
|
||||
python -m pip install -e .[dev]
|
||||
make dev
|
||||
|
||||
- name: Lint with Black
|
||||
# Python 3.7 will use a black with low level. So we use python with higher version for black check
|
||||
if: (matrix.python-version != '3.7')
|
||||
run: |
|
||||
pip install -U black # follow the latest version of black, previous Qlib dependency will downgrade black
|
||||
black . -l 120 --check --diff
|
||||
make black
|
||||
|
||||
- name: Make html with sphinx
|
||||
# Since read the docs builds on ubuntu 22.04, we only need to test that the build passes on ubuntu 22.04.
|
||||
if: ${{ matrix.os == 'ubuntu-22.04' }}
|
||||
run: |
|
||||
cd docs
|
||||
sphinx-build -W --keep-going -b html . _build
|
||||
cd ..
|
||||
make docs-gen
|
||||
|
||||
# Check Qlib with pylint
|
||||
# TODO: These problems we will solve in the future. Important among them are: W0221, W0223, W0237, E1102
|
||||
# C0103: invalid-name
|
||||
# C0209: consider-using-f-string
|
||||
# R0402: consider-using-from-import
|
||||
# R1705: no-else-return
|
||||
# R1710: inconsistent-return-statements
|
||||
# R1725: super-with-arguments
|
||||
# R1735: use-dict-literal
|
||||
# W0102: dangerous-default-value
|
||||
# W0212: protected-access
|
||||
# W0221: arguments-differ
|
||||
# W0223: abstract-method
|
||||
# W0231: super-init-not-called
|
||||
# W0237: arguments-renamed
|
||||
# W0612: unused-variable
|
||||
# W0621: redefined-outer-name
|
||||
# W0622: redefined-builtin
|
||||
# FIXME: specify exception type
|
||||
# W0703: broad-except
|
||||
# W1309: f-string-without-interpolation
|
||||
# E1102: not-callable
|
||||
# E1136: unsubscriptable-object
|
||||
# References for parameters: https://github.com/PyCQA/pylint/issues/4577#issuecomment-1000245962
|
||||
# We use sys.setrecursionlimit(2000) to make the recursion depth larger to ensure that pylint works properly (the default recursion depth is 1000).
|
||||
- name: Check Qlib with pylint
|
||||
run: |
|
||||
pylint --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0612,W0621,W0622,W0703,W1309,E1102,E1136 --const-rgx='[a-z_][a-z0-9_]{2,30}$' qlib --init-hook "import astroid; astroid.context.InferenceContext.max_inferred = 500; import sys; sys.setrecursionlimit(2000)"
|
||||
pylint --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0246,W0612,W0621,W0622,W0703,W1309,E1102,E1136 --const-rgx='[a-z_][a-z0-9_]{2,30}$' scripts --init-hook "import astroid; astroid.context.InferenceContext.max_inferred = 500; import sys; sys.setrecursionlimit(2000)"
|
||||
make pylint
|
||||
|
||||
# The following flake8 error codes were ignored:
|
||||
# E501 line too long
|
||||
# Description: We have used black to limit the length of each line to 120.
|
||||
# F541 f-string is missing placeholders
|
||||
# Description: The same thing is done when using pylint for detection.
|
||||
# E266 too many leading '#' for block comment
|
||||
# Description: To make the code more readable, a lot of "#" is used.
|
||||
# This error code appears centrally in:
|
||||
# qlib/backtest/executor.py
|
||||
# qlib/data/ops.py
|
||||
# qlib/utils/__init__.py
|
||||
# E402 module level import not at top of file
|
||||
# Description: There are times when module level import is not available at the top of the file.
|
||||
# W503 line break before binary operator
|
||||
# Description: Since black formats the length of each line of code, it has to perform a line break when a line of arithmetic is too long.
|
||||
# E731 do not assign a lambda expression, use a def
|
||||
# Description: Restricts the use of lambda expressions, but at some point lambda expressions are required.
|
||||
# E203 whitespace before ':'
|
||||
# Description: If there is whitespace before ":", it cannot pass the black check.
|
||||
- name: Check Qlib with flake8
|
||||
run: |
|
||||
flake8 --ignore=E501,F541,E266,E402,W503,E731,E203 --per-file-ignores="__init__.py:F401,F403" qlib
|
||||
make flake8
|
||||
|
||||
# https://github.com/python/mypy/issues/10600
|
||||
- name: Check Qlib with mypy
|
||||
run: |
|
||||
mypy qlib --install-types --non-interactive || true
|
||||
mypy qlib --verbose
|
||||
make mypy
|
||||
|
||||
- name: Check Qlib ipynb with nbqa
|
||||
run: |
|
||||
nbqa black . -l 120 --check --diff
|
||||
nbqa pylint . --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0612,W0621,W0622,W0703,W1309,E1102,E1136,W0719,W0104,W0404,C0412,W0611,C0410 --const-rgx='[a-z_][a-z0-9_]{2,30}$'
|
||||
make nbqa
|
||||
|
||||
- name: Test data downloads
|
||||
run: |
|
||||
@@ -147,7 +84,7 @@ jobs:
|
||||
python scripts/get_data.py download_data --file_name rl_data.zip --target_dir tests/.data/rl
|
||||
|
||||
- name: Install Lightgbm for MacOS
|
||||
if: ${{ matrix.os == 'macos-11' || matrix.os == 'macos-latest' }}
|
||||
if: ${{ matrix.os == 'macos-13' || matrix.os == 'macos-14' || matrix.os == 'macos-15' }}
|
||||
run: |
|
||||
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Microsoft/qlib/main/.github/brew_install.sh)"
|
||||
HOMEBREW_NO_AUTO_UPDATE=1 brew install lightgbm
|
||||
@@ -157,11 +94,9 @@ jobs:
|
||||
brew unlink libomp
|
||||
brew install libomp.rb
|
||||
|
||||
# Run after data downloads
|
||||
- name: Check Qlib ipynb with nbconvert
|
||||
run: |
|
||||
# add more ipynb files in future
|
||||
jupyter nbconvert --to notebook --execute examples/workflow_by_code.ipynb
|
||||
make nbconvert
|
||||
|
||||
- name: Test workflow by config (install from source)
|
||||
run: |
|
||||
|
||||
27
.github/workflows/test_qlib_from_source_slow.yml
vendored
27
.github/workflows/test_qlib_from_source_slow.yml
vendored
@@ -14,44 +14,31 @@ jobs:
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
# Since macos-latest changed from 12.7.4 to 14.4.1,
|
||||
# the minimum python version that matches a 14.4.1 version of macos is 3.10,
|
||||
# so we limit the macos version to macos-12.
|
||||
os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-11, macos-12]
|
||||
# not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129
|
||||
python-version: [3.7, 3.8]
|
||||
os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-13, macos-14, macos-15]
|
||||
# In github action, using python 3.7, pip install will not match the latest version of the package.
|
||||
# Also, python 3.7 is no longer supported from macos-14, and will be phased out from macos-13 in the near future.
|
||||
# All things considered, we have removed python 3.7.
|
||||
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
|
||||
|
||||
steps:
|
||||
- name: Test qlib from source slow
|
||||
uses: actions/checkout@v3
|
||||
|
||||
# Since version 3.7 of python for MacOS is installed in CI, version 3.7.17, this version causes "_bz not found error".
|
||||
# So we make the version number of python 3.7 for MacOS more specific.
|
||||
# refs: https://github.com/actions/setup-python/issues/682
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
if: (matrix.os == 'macos-latest' && matrix.python-version == '3.7') || (matrix.os == 'macos-11' && matrix.python-version == '3.7')
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.7.16"
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
if: (matrix.os != 'macos-latest' || matrix.python-version != '3.7') && (matrix.os != 'macos-11' || matrix.python-version != '3.7')
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Set up Python tools
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install --upgrade cython numpy
|
||||
pip install -e .[dev]
|
||||
make dev
|
||||
|
||||
- name: Downloads dependencies data
|
||||
run: |
|
||||
python scripts/get_data.py qlib_data --name qlib_data_simple --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
|
||||
|
||||
- name: Install Lightgbm for MacOS
|
||||
if: ${{ matrix.os == 'macos-11' || matrix.os == 'macos-latest' }}
|
||||
if: ${{ matrix.os == 'macos-13' || matrix.os == 'macos-14' || matrix.os == 'macos-15' }}
|
||||
run: |
|
||||
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Microsoft/qlib/main/.github/brew_install.sh)"
|
||||
HOMEBREW_NO_AUTO_UPDATE=1 brew install lightgbm
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -49,3 +49,4 @@ tags
|
||||
|
||||
./pretrain
|
||||
.idea/
|
||||
.aider*
|
||||
|
||||
@@ -5,6 +5,12 @@
|
||||
# Required
|
||||
version: 2
|
||||
|
||||
# Set the version of Python and other tools you might need
|
||||
build:
|
||||
os: ubuntu-22.04
|
||||
tools:
|
||||
python: "3.8"
|
||||
|
||||
# Build documentation in the docs/ directory with Sphinx
|
||||
sphinx:
|
||||
configuration: docs/conf.py
|
||||
@@ -14,7 +20,6 @@ formats: all
|
||||
|
||||
# Optionally set the version of Python and requirements required to build your docs
|
||||
python:
|
||||
version: 3.7
|
||||
install:
|
||||
- requirements: docs/requirements.txt
|
||||
- method: pip
|
||||
31
Dockerfile
Normal file
31
Dockerfile
Normal file
@@ -0,0 +1,31 @@
|
||||
FROM continuumio/miniconda3:latest
|
||||
|
||||
WORKDIR /qlib
|
||||
|
||||
COPY . .
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y build-essential
|
||||
|
||||
RUN conda create --name qlib_env python=3.8 -y
|
||||
RUN echo "conda activate qlib_env" >> ~/.bashrc
|
||||
ENV PATH /opt/conda/envs/qlib_env/bin:$PATH
|
||||
|
||||
RUN python -m pip install --upgrade pip
|
||||
|
||||
RUN python -m pip install numpy==1.23.5
|
||||
RUN python -m pip install pandas==1.5.3
|
||||
RUN python -m pip install importlib-metadata==5.2.0
|
||||
RUN python -m pip install "cloudpickle<3"
|
||||
RUN python -m pip install scikit-learn==1.3.2
|
||||
|
||||
RUN python -m pip install cython packaging tables matplotlib statsmodels
|
||||
RUN python -m pip install pybind11 cvxpy
|
||||
|
||||
ARG IS_STABLE="yes"
|
||||
|
||||
RUN if [ "$IS_STABLE" = "yes" ]; then \
|
||||
python -m pip install pyqlib; \
|
||||
else \
|
||||
python setup.py install; \
|
||||
fi
|
||||
@@ -1 +1,6 @@
|
||||
include qlib/VERSION.txt
|
||||
exclude tests/*
|
||||
include qlib/*
|
||||
include qlib/*/*
|
||||
include qlib/*/*/*
|
||||
include qlib/*/*/*/*
|
||||
include qlib/*/*/*/*/*
|
||||
|
||||
195
Makefile
Normal file
195
Makefile
Normal file
@@ -0,0 +1,195 @@
|
||||
.PHONY: clean deepclean prerequisite dependencies lightgbm rl develop lint docs package test analysis all install dev black pylint flake8 mypy nbqa nbconvert lint build upload docs-gen
|
||||
#You can modify it according to your terminal
|
||||
SHELL := /bin/bash
|
||||
|
||||
########################################################################################
|
||||
# Variables
|
||||
########################################################################################
|
||||
|
||||
# Documentation target directory, will be adapted to specific folder for readthedocs.
|
||||
PUBLIC_DIR := $(shell [ "$$READTHEDOCS" = "True" ] && echo "$$READTHEDOCS_OUTPUT/html" || echo "public")
|
||||
|
||||
SO_DIR := qlib/data/_libs
|
||||
SO_FILES := $(wildcard $(SO_DIR)/*.so)
|
||||
|
||||
########################################################################################
|
||||
# Development Environment Management
|
||||
########################################################################################
|
||||
# Remove common intermediate files.
|
||||
clean:
|
||||
-rm -rf \
|
||||
$(PUBLIC_DIR) \
|
||||
qlib/data/_libs/*.cpp \
|
||||
qlib/data/_libs/*.so \
|
||||
mlruns \
|
||||
public \
|
||||
build \
|
||||
.coverage \
|
||||
.mypy_cache \
|
||||
.pytest_cache \
|
||||
.ruff_cache \
|
||||
Pipfile* \
|
||||
coverage.xml \
|
||||
dist \
|
||||
release-notes.md
|
||||
|
||||
find . -name '*.egg-info' -print0 | xargs -0 rm -rf
|
||||
find . -name '*.pyc' -print0 | xargs -0 rm -f
|
||||
find . -name '*.swp' -print0 | xargs -0 rm -f
|
||||
find . -name '.DS_Store' -print0 | xargs -0 rm -f
|
||||
find . -name '__pycache__' -print0 | xargs -0 rm -rf
|
||||
|
||||
# Remove pre-commit hook, virtual environment alongside itermediate files.
|
||||
deepclean: clean
|
||||
if command -v pre-commit > /dev/null 2>&1; then pre-commit uninstall --hook-type pre-push; fi
|
||||
if command -v pipenv >/dev/null 2>&1 && pipenv --venv >/dev/null 2>&1; then pipenv --rm; fi
|
||||
|
||||
# Prerequisite section
|
||||
# What this code does is compile two Cython modules, rolling and expanding, using setuptools and Cython,
|
||||
# and builds them as binary expansion modules that can be imported directly into Python.
|
||||
# Since pyproject.toml can't do that, we compile it here.
|
||||
prerequisite:
|
||||
@if [ -n "$(SO_FILES)" ]; then \
|
||||
echo "Shared library files exist, skipping build."; \
|
||||
else \
|
||||
echo "No shared library files found, building..."; \
|
||||
pip install --upgrade setuptools wheel; \
|
||||
python -m pip install cython numpy; \
|
||||
python -c "from setuptools import setup, Extension; from Cython.Build import cythonize; import numpy; extensions = [Extension('qlib.data._libs.rolling', ['qlib/data/_libs/rolling.pyx'], language='c++', include_dirs=[numpy.get_include()]), Extension('qlib.data._libs.expanding', ['qlib/data/_libs/expanding.pyx'], language='c++', include_dirs=[numpy.get_include()])]; setup(ext_modules=cythonize(extensions, language_level='3'), script_args=['build_ext', '--inplace'])"; \
|
||||
fi
|
||||
|
||||
# Install the package in editable mode.
|
||||
dependencies:
|
||||
python -m pip install -e .
|
||||
|
||||
lightgbm:
|
||||
python -m pip install lightgbm --prefer-binary
|
||||
|
||||
rl:
|
||||
python -m pip install -e .[rl]
|
||||
|
||||
develop:
|
||||
python -m pip install -e .[dev]
|
||||
|
||||
lint:
|
||||
python -m pip install -e .[lint]
|
||||
|
||||
docs:
|
||||
python -m pip install -e .[docs]
|
||||
|
||||
package:
|
||||
python -m pip install -e .[package]
|
||||
|
||||
test:
|
||||
python -m pip install -e .[test]
|
||||
|
||||
analysis:
|
||||
python -m pip install -e .[analysis]
|
||||
|
||||
all:
|
||||
python -m pip install -e .[dev,lint,docs,package,test,analysis,rl]
|
||||
|
||||
install: prerequisite dependencies
|
||||
|
||||
dev: prerequisite all
|
||||
|
||||
########################################################################################
|
||||
# Lint and pre-commit
|
||||
########################################################################################
|
||||
|
||||
# Check lint with black.
|
||||
black:
|
||||
black . -l 120 --check --diff
|
||||
|
||||
# Check code folder with pylint.
|
||||
# TODO: These problems we will solve in the future. Important among them are: W0221, W0223, W0237, E1102
|
||||
# C0103: invalid-name
|
||||
# C0209: consider-using-f-string
|
||||
# R0402: consider-using-from-import
|
||||
# R1705: no-else-return
|
||||
# R1710: inconsistent-return-statements
|
||||
# R1725: super-with-arguments
|
||||
# R1735: use-dict-literal
|
||||
# W0102: dangerous-default-value
|
||||
# W0212: protected-access
|
||||
# W0221: arguments-differ
|
||||
# W0223: abstract-method
|
||||
# W0231: super-init-not-called
|
||||
# W0237: arguments-renamed
|
||||
# W0612: unused-variable
|
||||
# W0621: redefined-outer-name
|
||||
# W0622: redefined-builtin
|
||||
# FIXME: specify exception type
|
||||
# W0703: broad-except
|
||||
# W1309: f-string-without-interpolation
|
||||
# E1102: not-callable
|
||||
# E1136: unsubscriptable-object
|
||||
# W4904: deprecated-class
|
||||
# R0917: too-many-positional-arguments
|
||||
# E1123: unexpected-keyword-arg
|
||||
# References for disable error: https://pylint.pycqa.org/en/latest/user_guide/messages/messages_overview.html
|
||||
# We use sys.setrecursionlimit(2000) to make the recursion depth larger to ensure that pylint works properly (the default recursion depth is 1000).
|
||||
# References for parameters: https://github.com/PyCQA/pylint/issues/4577#issuecomment-1000245962
|
||||
pylint:
|
||||
pylint --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R0917,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,W4904,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1730,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0612,W0621,W0622,W0703,W1309,E1102,E1136 --const-rgx='[a-z_][a-z0-9_]{2,30}' qlib --init-hook="import astroid; astroid.context.InferenceContext.max_inferred = 500; import sys; sys.setrecursionlimit(2000)"
|
||||
pylint --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R0917,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,E1123,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0246,W0612,W0621,W0622,W0703,W1309,E1102,E1136 --const-rgx='[a-z_][a-z0-9_]{2,30}' scripts --init-hook="import astroid; astroid.context.InferenceContext.max_inferred = 500; import sys; sys.setrecursionlimit(2000)"
|
||||
|
||||
# Check code with flake8.
|
||||
# The following flake8 error codes were ignored:
|
||||
# E501 line too long
|
||||
# Description: We have used black to limit the length of each line to 120.
|
||||
# F541 f-string is missing placeholders
|
||||
# Description: The same thing is done when using pylint for detection.
|
||||
# E266 too many leading '#' for block comment
|
||||
# Description: To make the code more readable, a lot of "#" is used.
|
||||
# This error code appears centrally in:
|
||||
# qlib/backtest/executor.py
|
||||
# qlib/data/ops.py
|
||||
# qlib/utils/__init__.py
|
||||
# E402 module level import not at top of file
|
||||
# Description: There are times when module level import is not available at the top of the file.
|
||||
# W503 line break before binary operator
|
||||
# Description: Since black formats the length of each line of code, it has to perform a line break when a line of arithmetic is too long.
|
||||
# E731 do not assign a lambda expression, use a def
|
||||
# Description: Restricts the use of lambda expressions, but at some point lambda expressions are required.
|
||||
# E203 whitespace before ':'
|
||||
# Description: If there is whitespace before ":", it cannot pass the black check.
|
||||
flake8:
|
||||
flake8 --ignore=E501,F541,E266,E402,W503,E731,E203 --per-file-ignores="__init__.py:F401,F403" qlib
|
||||
|
||||
# Check code with mypy.
|
||||
# https://github.com/python/mypy/issues/10600
|
||||
mypy:
|
||||
mypy qlib --install-types --non-interactive
|
||||
mypy qlib --verbose
|
||||
|
||||
# Check ipynb with nbqa.
|
||||
nbqa:
|
||||
nbqa black . -l 120 --check --diff
|
||||
nbqa pylint . --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0612,W0621,W0622,W0703,W1309,E1102,E1136,W0719,W0104,W0404,C0412,W0611,C0410 --const-rgx='[a-z_][a-z0-9_]{2,30}'
|
||||
|
||||
# Check ipynb with nbconvert.(Run after data downloads)
|
||||
# TODO: Add more ipynb files in future
|
||||
nbconvert:
|
||||
jupyter nbconvert --to notebook --execute examples/workflow_by_code.ipynb
|
||||
|
||||
lint: black pylint flake8 mypy nbqa
|
||||
|
||||
########################################################################################
|
||||
# Package
|
||||
########################################################################################
|
||||
|
||||
# Build the package.
|
||||
build:
|
||||
python -m build --wheel
|
||||
|
||||
# Upload the package.
|
||||
upload:
|
||||
python -m twine upload dist/*
|
||||
|
||||
########################################################################################
|
||||
# Documentation
|
||||
########################################################################################
|
||||
|
||||
docs-gen:
|
||||
python -m sphinx.cmd.build -W docs $(PUBLIC_DIR)
|
||||
95
README.md
95
README.md
@@ -8,9 +8,30 @@
|
||||
[](https://gitter.im/Microsoft/qlib?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
|
||||
|
||||
## :newspaper: **What's NEW!** :sparkling_heart:
|
||||
|
||||
Recent released features
|
||||
|
||||
### Introducing <a href="https://github.com/microsoft/RD-Agent"><img src="docs/_static/img/rdagent_logo.png" alt="RD_Agent" style="height: 2em"></a>: LLM-Based Autonomous Evolving Agents for Industrial Data-Driven R&D
|
||||
|
||||
We are excited to announce the release of **RD-Agent**📢, a powerful tool that supports automated factor mining and model optimization in quant investment R&D.
|
||||
|
||||
RD-Agent is now available on [GitHub](https://github.com/microsoft/RD-Agent), and we welcome your star🌟!
|
||||
|
||||
To learn more, please visit our [♾️Demo page](https://rdagent.azurewebsites.net/). Here, you will find demo videos in both English and Chinese to help you better understand the scenario and usage of RD-Agent.
|
||||
|
||||
We have prepared several demo videos for you:
|
||||
| Scenario | Demo video (English) | Demo video (中文) |
|
||||
| -- | ------ | ------ |
|
||||
| Quant Factor Mining | [Link](https://rdagent.azurewebsites.net/factor_loop?lang=en) | [Link](https://rdagent.azurewebsites.net/factor_loop?lang=zh) |
|
||||
| Quant Factor Mining from reports | [Link](https://rdagent.azurewebsites.net/report_factor?lang=en) | [Link](https://rdagent.azurewebsites.net/report_factor?lang=zh) |
|
||||
| Quant Model Optimization | [Link](https://rdagent.azurewebsites.net/model_loop?lang=en) | [Link](https://rdagent.azurewebsites.net/model_loop?lang=zh) |
|
||||
|
||||
***
|
||||
|
||||
| Feature | Status |
|
||||
| -- | ------ |
|
||||
| BPQP for End-to-end learning | 📈Coming soon!([Under review](https://github.com/microsoft/qlib/pull/1863)) |
|
||||
| 🔥LLM-driven Auto Quant Factory🔥 | 🚀 Released in [♾️RD-Agent](https://github.com/microsoft/RD-Agent) on Aug 8, 2024 |
|
||||
| KRNN and Sandwich models | :chart_with_upwards_trend: [Released](https://github.com/microsoft/qlib/pull/1414/) on May 26, 2023 |
|
||||
| Release Qlib v0.9.0 | :octocat: [Released](https://github.com/microsoft/qlib/releases/tag/v0.9.0) on Dec 9, 2022 |
|
||||
| RL Learning Framework | :hammer: :chart_with_upwards_trend: Released on Nov 10, 2022. [#1332](https://github.com/microsoft/qlib/pull/1332), [#1322](https://github.com/microsoft/qlib/pull/1322), [#1316](https://github.com/microsoft/qlib/pull/1316),[#1299](https://github.com/microsoft/qlib/pull/1299),[#1263](https://github.com/microsoft/qlib/pull/1263), [#1244](https://github.com/microsoft/qlib/pull/1244), [#1169](https://github.com/microsoft/qlib/pull/1169), [#1125](https://github.com/microsoft/qlib/pull/1125), [#1076](https://github.com/microsoft/qlib/pull/1076)|
|
||||
@@ -40,7 +61,7 @@ Recent released features
|
||||
Features released before 2021 are not listed here.
|
||||
|
||||
<p align="center">
|
||||
<img src="http://fintech.msra.cn/images_v070/logo/1.png" />
|
||||
<img src="docs/_static/img/logo/1.png" />
|
||||
</p>
|
||||
|
||||
Qlib is an open-source, AI-oriented quantitative investment platform that aims to realize the potential, empower research, and create value using AI technologies in quantitative investment, from exploring ideas to implementing productions. Qlib supports diverse machine learning modeling paradigms, including supervised learning, market dynamics modeling, and reinforcement learning.
|
||||
@@ -132,11 +153,11 @@ Here is a quick **[demo](https://terminalizer.com/view/3f24561a4470)** shows how
|
||||
## Installation
|
||||
|
||||
This table demonstrates the supported Python version of `Qlib`:
|
||||
| | install with pip | install from source | plot |
|
||||
| ------------- |:---------------------:|:--------------------:|:----:|
|
||||
| | install with pip | install from source | plot |
|
||||
| ------------- |:---------------------:|:--------------------:|:------------------:|
|
||||
| Python 3.7 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
|
||||
| Python 3.8 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
|
||||
| Python 3.9 | :x: | :heavy_check_mark: | :x: |
|
||||
| Python 3.9 | :x: | :heavy_check_mark: | :x: |
|
||||
|
||||
**Note**:
|
||||
1. **Conda** is suggested for managing your Python environment. In some cases, using Python outside of a `conda` environment may result in missing header files, causing the installation failure of certain packages.
|
||||
@@ -166,7 +187,7 @@ Also, users can install the latest dev version ``Qlib`` by the source code accor
|
||||
* Clone the repository and install ``Qlib`` as follows.
|
||||
```bash
|
||||
git clone https://github.com/microsoft/qlib.git && cd qlib
|
||||
pip install .
|
||||
pip install . # `pip install -e .[dev]` is recommended for development. check details in docs/developer/code_standard_and_dev_guide.rst
|
||||
```
|
||||
**Note**: You can install Qlib with `python setup.py install` as well. But it is not the recommended approach. It will skip `pip` and cause obscure problems. For example, **only** the command ``pip install .`` **can** overwrite the stable version installed by ``pip install pyqlib``, while the command ``python setup.py install`` **can't**.
|
||||
|
||||
@@ -175,6 +196,20 @@ Also, users can install the latest dev version ``Qlib`` by the source code accor
|
||||
**Tips for Mac**: If you are using Mac with M1, you might encounter issues in building the wheel for LightGBM, which is due to missing dependencies from OpenMP. To solve the problem, install openmp first with ``brew install libomp`` and then run ``pip install .`` to build it successfully.
|
||||
|
||||
## Data Preparation
|
||||
❗ Due to more restrict data security policy. The offical dataset is disabled temporarily. You can try [this data source](https://github.com/chenditc/investment_data/releases) contributed by the community.
|
||||
Here is an example to download the data updated on 20240809.
|
||||
```bash
|
||||
wget https://github.com/chenditc/investment_data/releases/download/2024-08-09/qlib_bin.tar.gz
|
||||
mkdir -p ~/.qlib/qlib_data/cn_data
|
||||
tar -zxvf qlib_bin.tar.gz -C ~/.qlib/qlib_data/cn_data --strip-components=1
|
||||
rm -f qlib_bin.tar.gz
|
||||
```
|
||||
|
||||
The official dataset below will resume in short future.
|
||||
|
||||
|
||||
----
|
||||
|
||||
Load and prepare data by running the following code:
|
||||
|
||||
### Get with module
|
||||
@@ -258,6 +293,38 @@ We recommend users to prepare their own data if they have a high-quality dataset
|
||||
```
|
||||
-->
|
||||
|
||||
## Docker images
|
||||
1. Pulling a docker image from a docker hub repository
|
||||
```bash
|
||||
docker pull pyqlib/qlib_image_stable:stable
|
||||
```
|
||||
2. Start a new Docker container
|
||||
```bash
|
||||
docker run -it --name <container name> -v <Mounted local directory>:/app qlib_image_stable
|
||||
```
|
||||
3. At this point you are in the docker environment and can run the qlib scripts. An example:
|
||||
```bash
|
||||
>>> python scripts/get_data.py qlib_data --name qlib_data_simple --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
|
||||
>>> python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
```
|
||||
4. Exit the container
|
||||
```bash
|
||||
>>> exit
|
||||
```
|
||||
5. Restart the container
|
||||
```bash
|
||||
docker start -i -a <container name>
|
||||
```
|
||||
6. Stop the container
|
||||
```bash
|
||||
docker stop <container name>
|
||||
```
|
||||
7. Delete the container
|
||||
```bash
|
||||
docker rm <container name>
|
||||
```
|
||||
8. If you want to know more information, please refer to the [documentation](https://qlib.readthedocs.io/en/latest/developer/how_to_build_image.html).
|
||||
|
||||
## Auto Quant Research Workflow
|
||||
Qlib provides a tool named `qrun` to run the whole workflow automatically (including building dataset, training models, backtest and evaluation). You can start an auto quant research workflow and have a graphical reports analysis according to the following steps:
|
||||
|
||||
@@ -291,22 +358,22 @@ Qlib provides a tool named `qrun` to run the whole workflow automatically (inclu
|
||||
```
|
||||
Here are detailed documents for `qrun` and [workflow](https://qlib.readthedocs.io/en/latest/component/workflow.html).
|
||||
|
||||
2. Graphical Reports Analysis: Run `examples/workflow_by_code.ipynb` with `jupyter notebook` to get graphical reports
|
||||
2. Graphical Reports Analysis: First, run `python -m pip install .[analysis]` to install the required dependencies. Then run `examples/workflow_by_code.ipynb` with `jupyter notebook` to get graphical reports.
|
||||
- Forecasting signal (model prediction) analysis
|
||||
- Cumulative Return of groups
|
||||

|
||||

|
||||
- Return distribution
|
||||

|
||||

|
||||
- Information Coefficient (IC)
|
||||

|
||||

|
||||

|
||||

|
||||

|
||||

|
||||
- Auto Correlation of forecasting signal (model prediction)
|
||||

|
||||

|
||||
|
||||
- Portfolio analysis
|
||||
- Backtest return
|
||||

|
||||

|
||||
<!--
|
||||
- Score IC
|
||||

|
||||
@@ -485,7 +552,7 @@ Qlib data are stored in a compact format, which is efficient to be combined into
|
||||
Join IM discussion groups:
|
||||
|[Gitter](https://gitter.im/Microsoft/qlib)|
|
||||
|----|
|
||||
||
|
||||
||
|
||||
|
||||
# Contributing
|
||||
We appreciate all contributions and thank all the contributors!
|
||||
|
||||
31
build_docker_image.sh
Normal file
31
build_docker_image.sh
Normal file
@@ -0,0 +1,31 @@
|
||||
#!/bin/bash
|
||||
|
||||
docker_user="your_dockerhub_username"
|
||||
|
||||
read -p "Do you want to build the nightly version of the qlib image? (default is stable) (yes/no): " answer;
|
||||
answer=$(echo "$answer" | tr '[:upper:]' '[:lower:]')
|
||||
|
||||
if [ "$answer" = "yes" ]; then
|
||||
# Build the nightly version of the qlib image
|
||||
docker build --build-arg IS_STABLE=no -t qlib_image -f ./Dockerfile .
|
||||
image_tag="nightly"
|
||||
else
|
||||
# Build the stable version of the qlib image
|
||||
docker build -t qlib_image -f ./Dockerfile .
|
||||
image_tag="stable"
|
||||
fi
|
||||
|
||||
read -p "Is it uploaded to docker hub? (default is no) (yes/no): " answer;
|
||||
answer=$(echo "$answer" | tr '[:upper:]' '[:lower:]')
|
||||
|
||||
if [ "$answer" = "yes" ]; then
|
||||
# Log in to Docker Hub
|
||||
# If you are a new docker hub user, please verify your email address before proceeding with this step.
|
||||
docker login
|
||||
# Tag the Docker image
|
||||
docker tag qlib_image "$docker_user/qlib_image:$image_tag"
|
||||
# Push the Docker image to Docker Hub
|
||||
docker push "$docker_user/qlib_image:$image_tag"
|
||||
else
|
||||
echo "Not uploaded to docker hub."
|
||||
fi
|
||||
BIN
docs/_static/img/rdagent_logo.png
vendored
Normal file
BIN
docs/_static/img/rdagent_logo.png
vendored
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 94 KiB |
@@ -86,7 +86,7 @@ Example
|
||||
},
|
||||
}
|
||||
|
||||
# model initiaiton
|
||||
# model initialization
|
||||
model = init_instance_by_config(task["model"])
|
||||
dataset = init_instance_by_config(task["dataset"])
|
||||
|
||||
|
||||
@@ -123,7 +123,6 @@ html_logo = "_static/img/logo/1.png"
|
||||
html_theme_options = {
|
||||
"logo_only": True,
|
||||
"collapse_navigation": False,
|
||||
"display_version": False,
|
||||
"navigation_depth": 4,
|
||||
}
|
||||
|
||||
|
||||
@@ -60,4 +60,4 @@ The `[dev]` option will help you to install some related packages when developin
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install -e .[dev]
|
||||
pip install -e ".[dev]"
|
||||
81
docs/developer/how_to_build_image.rst
Normal file
81
docs/developer/how_to_build_image.rst
Normal file
@@ -0,0 +1,81 @@
|
||||
.. _docker_image:
|
||||
|
||||
==================
|
||||
Build Docker Image
|
||||
==================
|
||||
|
||||
Dockerfile
|
||||
==========
|
||||
|
||||
There is a **Dockerfile** file in the root directory of the project from which you can build the docker image. There are two build methods in Dockerfile to choose from.
|
||||
When executing the build command, use the ``--build-arg`` parameter to control the image version. The ``--build-arg`` parameter defaults to ``yes``, which builds the ``stable`` version of the qlib image.
|
||||
|
||||
1.For the ``stable`` version, use ``pip install pyqlib`` to build the qlib image.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
docker build --build-arg IS_STABLE=yes -t <image name> -f ./Dockerfile .
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
docker build -t <image name> -f ./Dockerfile .
|
||||
|
||||
2. For the ``nightly`` version, use current source code to build the qlib image.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
docker build --build-arg IS_STABLE=no -t <image name> -f ./Dockerfile .
|
||||
|
||||
Auto build of qlib images
|
||||
=========================
|
||||
|
||||
1. There is a **build_docker_image.sh** file in the root directory of your project, which can be used to automatically build docker images and upload them to your docker hub repository(Optional, configuration required).
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
sh build_docker_image.sh
|
||||
>>> Do you want to build the nightly version of the qlib image? (default is stable) (yes/no):
|
||||
>>> Is it uploaded to docker hub? (default is no) (yes/no):
|
||||
|
||||
2. If you want to upload the built image to your docker hub repository, you need to edit your **build_docker_image.sh** file first, fill in ``docker_user`` in the file, and then execute this file.
|
||||
|
||||
How to use qlib images
|
||||
======================
|
||||
1. Start a new Docker container
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
docker run -it --name <container name> -v <Mounted local directory>:/app <image name>
|
||||
|
||||
2. At this point you are in the docker environment and can run the qlib scripts. An example:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
>>> python scripts/get_data.py qlib_data --name qlib_data_simple --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
|
||||
>>> python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
|
||||
3. Exit the container
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
>>> exit
|
||||
|
||||
4. Restart the container
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
docker start -i -a <container name>
|
||||
|
||||
5. Stop the container
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
docker stop -i -a <container name>
|
||||
|
||||
6. Delete the container
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
docker rm <container name>
|
||||
|
||||
7. For more information on using docker see the `docker documentation <https://docs.docker.com/reference/cli/docker/>`_.
|
||||
@@ -61,6 +61,7 @@ Document Structure
|
||||
:caption: FOR DEVELOPERS:
|
||||
|
||||
Code Standard & Development Guidance <developer/code_standard_and_dev_guide.rst>
|
||||
How to build image <developer/how_to_build_image.rst>
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 3
|
||||
|
||||
@@ -5,3 +5,4 @@ scipy
|
||||
scikit-learn
|
||||
pandas
|
||||
tianshou
|
||||
sphinx_rtd_theme
|
||||
|
||||
19
examples/benchmarks/GeneralPtNN/README.md
Normal file
19
examples/benchmarks/GeneralPtNN/README.md
Normal file
@@ -0,0 +1,19 @@
|
||||
|
||||
|
||||
# Introduction
|
||||
|
||||
What is GeneralPtNN
|
||||
- Fix previous design that fail to support both Time-series and tabular data
|
||||
- Now you can just replace the Pytorch model structure to run a NN model.
|
||||
|
||||
We provide an example to demonstrate the effectiveness of the current design.
|
||||
- `workflow_config_gru.yaml` align with previous results [GRU(Kyunghyun Cho, et al.)](../README.md#Alpha158-dataset)
|
||||
- `workflow_config_gru2mlp.yaml` to demonstrate we can convert config from time-series to tabular data with minimal changes
|
||||
- You only have to change the net & dataset class to make the conversion.
|
||||
- `workflow_config_mlp.yaml` achieved similar functionality with [MLP](../README.md#Alpha158-dataset)
|
||||
|
||||
# TODO
|
||||
|
||||
- We will align existing models to current design.
|
||||
|
||||
- The result of `workflow_config_mlp.yaml` is different with the result of [MLP](../README.md#Alpha158-dataset) since GeneralPtNN has a different stopping method compared to previous implementations. Specificly, GeneralPtNN controls training according to epoches, whereas previous methods controlled by max_steps.
|
||||
100
examples/benchmarks/GeneralPtNN/workflow_config_gru.yaml
Executable file
100
examples/benchmarks/GeneralPtNN/workflow_config_gru.yaml
Executable file
@@ -0,0 +1,100 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: FilterCol
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
col_list: ["RESI5", "WVMA5", "RSQR5", "KLEN", "RSQR10", "CORR5", "CORD5", "CORR10",
|
||||
"ROC60", "RESI10", "VSTD5", "RSQR60", "CORR60", "WVMA60", "STD5",
|
||||
"RSQR20", "CORD60", "CORD10", "CORR20", "KLOW"
|
||||
]
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
signal: <PRED>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
start_time: 2017-01-01
|
||||
end_time: 2020-08-01
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
exchange_kwargs:
|
||||
limit_threshold: 0.095
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: GeneralPTNN
|
||||
module_path: qlib.contrib.model.pytorch_general_nn
|
||||
kwargs:
|
||||
n_epochs: 200
|
||||
lr: 2e-4
|
||||
early_stop: 10
|
||||
batch_size: 800
|
||||
metric: loss
|
||||
loss: mse
|
||||
n_jobs: 20
|
||||
GPU: 0
|
||||
pt_model_uri: "qlib.contrib.model.pytorch_gru_ts.GRUModel"
|
||||
pt_model_kwargs: {
|
||||
"d_feat": 20,
|
||||
"hidden_size": 64,
|
||||
"num_layers": 2,
|
||||
"dropout": 0.,
|
||||
}
|
||||
dataset:
|
||||
class: TSDatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha158
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
step_len: 20
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
93
examples/benchmarks/GeneralPtNN/workflow_config_gru2mlp.yaml
Normal file
93
examples/benchmarks/GeneralPtNN/workflow_config_gru2mlp.yaml
Normal file
@@ -0,0 +1,93 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: FilterCol
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
col_list: ["RESI5", "WVMA5", "RSQR5", "KLEN", "RSQR10", "CORR5", "CORD5", "CORR10",
|
||||
"ROC60", "RESI10", "VSTD5", "RSQR60", "CORR60", "WVMA60", "STD5",
|
||||
"RSQR20", "CORD60", "CORD10", "CORR20", "KLOW"
|
||||
]
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
signal: <PRED>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
start_time: 2017-01-01
|
||||
end_time: 2020-08-01
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
exchange_kwargs:
|
||||
limit_threshold: 0.095
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: GeneralPTNN
|
||||
module_path: qlib.contrib.model.pytorch_general_nn
|
||||
kwargs:
|
||||
lr: 1e-3
|
||||
n_epochs: 1
|
||||
batch_size: 800
|
||||
loss: mse
|
||||
optimizer: adam
|
||||
pt_model_uri: "qlib.contrib.model.pytorch_nn.Net"
|
||||
pt_model_kwargs:
|
||||
input_dim: 20
|
||||
layers: [20,]
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha158
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
98
examples/benchmarks/GeneralPtNN/workflow_config_mlp.yaml
Normal file
98
examples/benchmarks/GeneralPtNN/workflow_config_mlp.yaml
Normal file
@@ -0,0 +1,98 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors: [
|
||||
{
|
||||
"class" : "DropCol",
|
||||
"kwargs":{"col_list": ["VWAP0"]}
|
||||
},
|
||||
{
|
||||
"class" : "CSZFillna",
|
||||
"kwargs":{"fields_group": "feature"}
|
||||
}
|
||||
]
|
||||
learn_processors: [
|
||||
{
|
||||
"class" : "DropCol",
|
||||
"kwargs":{"col_list": ["VWAP0"]}
|
||||
},
|
||||
{
|
||||
"class" : "DropnaProcessor",
|
||||
"kwargs":{"fields_group": "feature"}
|
||||
},
|
||||
"DropnaLabel",
|
||||
{
|
||||
"class": "CSZScoreNorm",
|
||||
"kwargs": {"fields_group": "label"}
|
||||
}
|
||||
]
|
||||
process_type: "independent"
|
||||
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
signal: <PRED>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
start_time: 2017-01-01
|
||||
end_time: 2020-08-01
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
exchange_kwargs:
|
||||
limit_threshold: 0.095
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: GeneralPTNN
|
||||
module_path: qlib.contrib.model.pytorch_general_nn
|
||||
kwargs:
|
||||
# FIXME: wrong parameters.
|
||||
lr: 2e-3
|
||||
batch_size: 8192
|
||||
loss: mse
|
||||
weight_decay: 0.0002
|
||||
optimizer: adam
|
||||
pt_model_uri: "qlib.contrib.model.pytorch_nn.Net"
|
||||
pt_model_kwargs:
|
||||
input_dim: 157
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha158
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -1,14 +1,15 @@
|
||||
import argparse
|
||||
|
||||
import qlib
|
||||
import ruamel.yaml as yaml
|
||||
from ruamel.yaml import YAML
|
||||
from qlib.utils import init_instance_by_config
|
||||
|
||||
|
||||
def main(seed, config_file="configs/config_alstm.yaml"):
|
||||
# set random seed
|
||||
with open(config_file) as f:
|
||||
config = yaml.safe_load(f)
|
||||
yaml = YAML(typ="safe", pure=True)
|
||||
config = yaml.load(f)
|
||||
|
||||
# seed_suffix = "/seed1000" if "init" in config_file else f"/seed{seed}"
|
||||
seed_suffix = ""
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
@@ -35,6 +36,10 @@ class DDGDABench(DDGDA):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
GetData().qlib_data(exists_skip=True)
|
||||
auto_init()
|
||||
kwargs = {}
|
||||
if os.environ.get("PROVIDER_URI", "") == "":
|
||||
GetData().qlib_data(exists_skip=True)
|
||||
else:
|
||||
kwargs["provider_uri"] = os.environ["PROVIDER_URI"]
|
||||
auto_init(**kwargs)
|
||||
fire.Fire(DDGDABench)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
@@ -31,6 +32,10 @@ class RollingBenchmark(Rolling):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
GetData().qlib_data(exists_skip=True)
|
||||
auto_init()
|
||||
kwargs = {}
|
||||
if os.environ.get("PROVIDER_URI", "") == "":
|
||||
GetData().qlib_data(exists_skip=True)
|
||||
else:
|
||||
kwargs["provider_uri"] = os.environ["PROVIDER_URI"]
|
||||
auto_init(**kwargs)
|
||||
fire.Fire(RollingBenchmark)
|
||||
|
||||
@@ -9,8 +9,8 @@ from copy import deepcopy
|
||||
from pathlib import Path
|
||||
import pickle
|
||||
from pprint import pprint
|
||||
from ruamel.yaml import YAML
|
||||
import subprocess
|
||||
import yaml
|
||||
from qlib.log import TimeInspector
|
||||
|
||||
from qlib import init
|
||||
@@ -30,7 +30,8 @@ if __name__ == "__main__":
|
||||
subprocess.run(f"qrun {config_path}", shell=True)
|
||||
|
||||
# 2) dump handler
|
||||
task_config = yaml.safe_load(config_path.open())
|
||||
yaml = YAML(typ="safe", pure=True)
|
||||
task_config = yaml.load(config_path.open())
|
||||
hd_conf = task_config["task"]["dataset"]["kwargs"]["handler"]
|
||||
pprint(hd_conf)
|
||||
hd: DataHandlerLP = init_instance_by_config(hd_conf)
|
||||
|
||||
@@ -9,10 +9,9 @@ from copy import deepcopy
|
||||
from pathlib import Path
|
||||
import pickle
|
||||
from pprint import pprint
|
||||
from ruamel.yaml import YAML
|
||||
import subprocess
|
||||
|
||||
import yaml
|
||||
|
||||
from qlib import init
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
from qlib.log import TimeInspector
|
||||
@@ -29,7 +28,8 @@ if __name__ == "__main__":
|
||||
exp_name = "data_mem_reuse_demo"
|
||||
|
||||
config_path = DIRNAME.parent / "benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml"
|
||||
task_config = yaml.safe_load(config_path.open())
|
||||
yaml = YAML(typ="safe", pure=True)
|
||||
task_config = yaml.load(config_path.open())
|
||||
|
||||
# 1) without using processed data in memory
|
||||
with TimeInspector.logt("The original time without reusing processed data in memory:"):
|
||||
|
||||
@@ -16,7 +16,7 @@ Current version of script with default value tries to connect localhost **via de
|
||||
|
||||
Run following command to install necessary libraries
|
||||
```
|
||||
pip install pytest coverage
|
||||
pip install pytest coverage gdown
|
||||
pip install arctic # NOTE: pip may fail to resolve the right package dependency !!! Please make sure the dependency are satisfied.
|
||||
```
|
||||
|
||||
@@ -27,7 +27,8 @@ pip install arctic # NOTE: pip may fail to resolve the right package dependency
|
||||
2. Please follow following steps to download example data
|
||||
```bash
|
||||
cd examples/orderbook_data/
|
||||
python ../../scripts/get_data.py download_data --target_dir . --file_name highfreq_orderbook_example_data.zip
|
||||
gdown https://drive.google.com/uc?id=15nZF7tFT_eKVZAcMFL1qPS4jGyJflH7e # Proxies may be necessary here.
|
||||
python ../../scripts/get_data.py _unzip --file_path highfreq_orderbook_example_data.zip --target_dir .
|
||||
```
|
||||
|
||||
3. Please import the example data to your mongo db
|
||||
|
||||
@@ -20,7 +20,7 @@ We use China stock market data for our example.
|
||||
1. Prepare CSI300 weight:
|
||||
|
||||
```bash
|
||||
wget http://fintech.msra.cn/stock_data/downloads/csi300_weight.zip
|
||||
wget https://github.com/SunsetWolf/qlib_dataset/releases/download/v0/csi300_weight.zip
|
||||
unzip -d ~/.qlib/qlib_data/cn_data csi300_weight.zip
|
||||
rm -f csi300_weight.zip
|
||||
```
|
||||
|
||||
@@ -6,7 +6,6 @@ import sys
|
||||
import fire
|
||||
import time
|
||||
import glob
|
||||
import yaml
|
||||
import shutil
|
||||
import signal
|
||||
import inspect
|
||||
@@ -15,6 +14,7 @@ import functools
|
||||
import statistics
|
||||
import subprocess
|
||||
from datetime import datetime
|
||||
from ruamel.yaml import YAML
|
||||
from pathlib import Path
|
||||
from operator import xor
|
||||
from pprint import pprint
|
||||
@@ -188,7 +188,8 @@ def gen_and_save_md_table(metrics, dataset):
|
||||
# read yaml, remove seed kwargs of model, and then save file in the temp_dir
|
||||
def gen_yaml_file_without_seed_kwargs(yaml_path, temp_dir):
|
||||
with open(yaml_path, "r") as fp:
|
||||
config = yaml.safe_load(fp)
|
||||
yaml = YAML(typ="safe", pure=True)
|
||||
config = yaml.load(fp)
|
||||
try:
|
||||
del config["task"]["model"]["kwargs"]["seed"]
|
||||
except KeyError:
|
||||
|
||||
@@ -161,7 +161,7 @@
|
||||
" },\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"# model initiaiton\n",
|
||||
"# model initialization\n",
|
||||
"model = init_instance_by_config(task[\"model\"])\n",
|
||||
"dataset = init_instance_by_config(task[\"dataset\"])\n",
|
||||
"\n",
|
||||
|
||||
@@ -1,2 +1,93 @@
|
||||
[build-system]
|
||||
requires = ["setuptools", "numpy", "Cython"]
|
||||
requires = ["setuptools", "cython", "numpy>=1.24.0"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
classifiers = [
|
||||
"Operating System :: POSIX :: Linux",
|
||||
"Operating System :: Microsoft :: Windows",
|
||||
"Operating System :: MacOS",
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Development Status :: 3 - Alpha",
|
||||
"Programming Language :: Python",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.8",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
]
|
||||
name = "pyqlib"
|
||||
dynamic = ["version"]
|
||||
description = "A Quantitative-research Platform"
|
||||
requires-python = ">=3.8.0"
|
||||
readme = {file = "README.md", content-type = "text/markdown"}
|
||||
|
||||
dependencies = [
|
||||
"pyyaml",
|
||||
"numpy",
|
||||
"pandas",
|
||||
"mlflow",
|
||||
"filelock>=3.16.0",
|
||||
"redis",
|
||||
"dill",
|
||||
"fire",
|
||||
"ruamel.yaml>=0.17.38",
|
||||
"python-redis-lock",
|
||||
"tqdm",
|
||||
"pymongo",
|
||||
"loguru",
|
||||
"lightgbm",
|
||||
"gym",
|
||||
"cvxpy",
|
||||
"joblib",
|
||||
"matplotlib",
|
||||
"jupyter",
|
||||
"nbconvert",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
"pytest",
|
||||
"statsmodels",
|
||||
]
|
||||
# On macos-13 system, when using python version greater than or equal to 3.10,
|
||||
# pytorch can't fully support Numpy version above 2.0, so, when you want to install torch,
|
||||
# it will limit the version of Numpy less than 2.0.
|
||||
rl = [
|
||||
"tianshou<=0.4.10",
|
||||
"torch",
|
||||
"numpy<2.0.0",
|
||||
]
|
||||
lint = [
|
||||
"black",
|
||||
"pylint",
|
||||
"mypy<1.5.0",
|
||||
"flake8",
|
||||
"nbqa",
|
||||
]
|
||||
docs = [
|
||||
"sphinx",
|
||||
"sphinx_rtd_theme",
|
||||
"readthedocs_sphinx_ext",
|
||||
]
|
||||
package = [
|
||||
"twine",
|
||||
"build",
|
||||
]
|
||||
# test_pit dependency packages
|
||||
test = [
|
||||
"yahooquery",
|
||||
"baostock",
|
||||
]
|
||||
analysis = [
|
||||
"plotly",
|
||||
]
|
||||
|
||||
[tool.setuptools]
|
||||
packages = [
|
||||
"qlib",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
qrun = "qlib.workflow.cli:run"
|
||||
|
||||
@@ -2,11 +2,11 @@
|
||||
# Licensed under the MIT License.
|
||||
from pathlib import Path
|
||||
|
||||
__version__ = "0.9.4"
|
||||
__version__ = "0.9.6"
|
||||
__version__bak = __version__ # This version is backup for QlibConfig.reset_qlib_version
|
||||
import os
|
||||
from typing import Union
|
||||
import yaml
|
||||
from ruamel.yaml import YAML
|
||||
import logging
|
||||
import platform
|
||||
import subprocess
|
||||
@@ -176,7 +176,8 @@ def init_from_yaml_conf(conf_path, **kwargs):
|
||||
config = {}
|
||||
else:
|
||||
with open(conf_path) as f:
|
||||
config = yaml.safe_load(f)
|
||||
yaml = YAML(typ="safe", pure=True)
|
||||
config = yaml.load(f)
|
||||
config.update(kwargs)
|
||||
default_conf = config.pop("default_conf", "client")
|
||||
init(default_conf, **config)
|
||||
@@ -272,7 +273,8 @@ def auto_init(**kwargs):
|
||||
logger = get_module_logger("Initialization")
|
||||
conf_pp = pp / "config.yaml"
|
||||
with conf_pp.open() as f:
|
||||
conf = yaml.safe_load(f)
|
||||
yaml = YAML(typ="safe", pure=True)
|
||||
conf = yaml.load(f)
|
||||
|
||||
conf_type = conf.get("conf_type", "origin")
|
||||
if conf_type == "origin":
|
||||
|
||||
@@ -278,7 +278,7 @@ class BaseSingleMetric:
|
||||
raise NotImplementedError(f"Please implement the `empty` method")
|
||||
|
||||
def add(self, other: BaseSingleMetric, fill_value: float = None) -> BaseSingleMetric:
|
||||
"""Replace np.NaN with fill_value in two metrics and add them."""
|
||||
"""Replace np.nan with fill_value in two metrics and add them."""
|
||||
|
||||
raise NotImplementedError(f"Please implement the `add` method")
|
||||
|
||||
@@ -412,7 +412,7 @@ class BaseOrderIndicator:
|
||||
metrics : Union[str, List[str]]
|
||||
all metrics needs to be sumed.
|
||||
fill_value : float, optional
|
||||
fill np.NaN with value. By default None.
|
||||
fill np.nan with value. By default None.
|
||||
"""
|
||||
|
||||
raise NotImplementedError(f"Please implement the 'sum_all_indicators' method")
|
||||
|
||||
@@ -325,9 +325,9 @@ class Indicator:
|
||||
|
||||
def _update_order_fulfill_rate(self) -> None:
|
||||
def func(deal_amount, amount):
|
||||
# deal_amount is np.NaN or None when there is no inner decision. So full fill rate is 0.
|
||||
# deal_amount is np.nan or None when there is no inner decision. So full fill rate is 0.
|
||||
tmp_deal_amount = deal_amount.reindex(amount.index, 0)
|
||||
tmp_deal_amount = tmp_deal_amount.replace({np.NaN: 0})
|
||||
tmp_deal_amount = tmp_deal_amount.replace({np.nan: 0})
|
||||
return tmp_deal_amount / amount
|
||||
|
||||
self.order_indicator.transfer(func, "ffr")
|
||||
@@ -354,8 +354,8 @@ class Indicator:
|
||||
)
|
||||
|
||||
def func(trade_price, deal_amount):
|
||||
# trade_price is np.NaN instead of inf when deal_amount is zero.
|
||||
tmp_deal_amount = deal_amount.replace({0: np.NaN})
|
||||
# trade_price is np.nan instead of inf when deal_amount is zero.
|
||||
tmp_deal_amount = deal_amount.replace({0: np.nan})
|
||||
return trade_price / tmp_deal_amount
|
||||
|
||||
self.order_indicator.transfer(func, "trade_price")
|
||||
@@ -425,7 +425,7 @@ class Indicator:
|
||||
assert isinstance(price_s, idd.SingleData)
|
||||
price_s = price_s.loc[(price_s > 1e-08).data.astype(bool)]
|
||||
# NOTE ~(price_s < 1e-08) is different from price_s >= 1e-8
|
||||
# ~(np.NaN < 1e-8) -> ~(False) -> True
|
||||
# ~(np.nan < 1e-8) -> ~(False) -> True
|
||||
|
||||
assert isinstance(price_s, idd.SingleData)
|
||||
if agg == "vwap":
|
||||
|
||||
@@ -173,7 +173,11 @@ _default_config = {
|
||||
"filters": ["field_not_found"],
|
||||
}
|
||||
},
|
||||
"loggers": {"qlib": {"level": logging.DEBUG, "handlers": ["console"]}},
|
||||
# Normally this should be set to `False` to avoid duplicated logging [1].
|
||||
# However, due to bug in pytest, it requires log message to propagate to root logger to be captured by `caplog` [2].
|
||||
# [1] https://github.com/microsoft/qlib/pull/1661
|
||||
# [2] https://github.com/pytest-dev/pytest/issues/3697
|
||||
"loggers": {"qlib": {"level": logging.DEBUG, "handlers": ["console"], "propagate": False}},
|
||||
# To let qlib work with other packages, we shouldn't disable existing loggers.
|
||||
# Note that this param is default to True according to the documentation of logging.
|
||||
"disable_existing_loggers": False,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from qlib.contrib.data.loader import Alpha158DL, Alpha360DL
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
from ...data.dataset.processor import Processor
|
||||
from ...utils import get_callable_kwargs
|
||||
@@ -57,7 +58,7 @@ class Alpha360(DataHandlerLP):
|
||||
fit_end_time=None,
|
||||
filter_pipe=None,
|
||||
inst_processors=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
|
||||
learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)
|
||||
@@ -66,7 +67,7 @@ class Alpha360(DataHandlerLP):
|
||||
"class": "QlibDataLoader",
|
||||
"kwargs": {
|
||||
"config": {
|
||||
"feature": self.get_feature_config(),
|
||||
"feature": Alpha360DL.get_feature_config(),
|
||||
"label": kwargs.pop("label", self.get_label_config()),
|
||||
},
|
||||
"filter_pipe": filter_pipe,
|
||||
@@ -82,57 +83,12 @@ class Alpha360(DataHandlerLP):
|
||||
data_loader=data_loader,
|
||||
learn_processors=learn_processors,
|
||||
infer_processors=infer_processors,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def get_label_config(self):
|
||||
return ["Ref($close, -2)/Ref($close, -1) - 1"], ["LABEL0"]
|
||||
|
||||
@staticmethod
|
||||
def get_feature_config():
|
||||
# NOTE:
|
||||
# Alpha360 tries to provide a dataset with original price data
|
||||
# the original price data includes the prices and volume in the last 60 days.
|
||||
# To make it easier to learn models from this dataset, all the prices and volume
|
||||
# are normalized by the latest price and volume data ( dividing by $close, $volume)
|
||||
# So the latest normalized $close will be 1 (with name CLOSE0), the latest normalized $volume will be 1 (with name VOLUME0)
|
||||
# If further normalization are executed (e.g. centralization), CLOSE0 and VOLUME0 will be 0.
|
||||
fields = []
|
||||
names = []
|
||||
|
||||
for i in range(59, 0, -1):
|
||||
fields += ["Ref($close, %d)/$close" % i]
|
||||
names += ["CLOSE%d" % i]
|
||||
fields += ["$close/$close"]
|
||||
names += ["CLOSE0"]
|
||||
for i in range(59, 0, -1):
|
||||
fields += ["Ref($open, %d)/$close" % i]
|
||||
names += ["OPEN%d" % i]
|
||||
fields += ["$open/$close"]
|
||||
names += ["OPEN0"]
|
||||
for i in range(59, 0, -1):
|
||||
fields += ["Ref($high, %d)/$close" % i]
|
||||
names += ["HIGH%d" % i]
|
||||
fields += ["$high/$close"]
|
||||
names += ["HIGH0"]
|
||||
for i in range(59, 0, -1):
|
||||
fields += ["Ref($low, %d)/$close" % i]
|
||||
names += ["LOW%d" % i]
|
||||
fields += ["$low/$close"]
|
||||
names += ["LOW0"]
|
||||
for i in range(59, 0, -1):
|
||||
fields += ["Ref($vwap, %d)/$close" % i]
|
||||
names += ["VWAP%d" % i]
|
||||
fields += ["$vwap/$close"]
|
||||
names += ["VWAP0"]
|
||||
for i in range(59, 0, -1):
|
||||
fields += ["Ref($volume, %d)/($volume+1e-12)" % i]
|
||||
names += ["VOLUME%d" % i]
|
||||
fields += ["$volume/($volume+1e-12)"]
|
||||
names += ["VOLUME0"]
|
||||
|
||||
return fields, names
|
||||
|
||||
|
||||
class Alpha360vwap(Alpha360):
|
||||
def get_label_config(self):
|
||||
@@ -153,7 +109,7 @@ class Alpha158(DataHandlerLP):
|
||||
process_type=DataHandlerLP.PTYPE_A,
|
||||
filter_pipe=None,
|
||||
inst_processors=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
|
||||
learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)
|
||||
@@ -178,7 +134,7 @@ class Alpha158(DataHandlerLP):
|
||||
infer_processors=infer_processors,
|
||||
learn_processors=learn_processors,
|
||||
process_type=process_type,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def get_feature_config(self):
|
||||
@@ -190,242 +146,11 @@ class Alpha158(DataHandlerLP):
|
||||
},
|
||||
"rolling": {},
|
||||
}
|
||||
return self.parse_config_to_fields(conf)
|
||||
return Alpha158DL.get_feature_config(conf)
|
||||
|
||||
def get_label_config(self):
|
||||
return ["Ref($close, -2)/Ref($close, -1) - 1"], ["LABEL0"]
|
||||
|
||||
@staticmethod
|
||||
def parse_config_to_fields(config):
|
||||
"""create factors from config
|
||||
|
||||
config = {
|
||||
'kbar': {}, # whether to use some hard-code kbar features
|
||||
'price': { # whether to use raw price features
|
||||
'windows': [0, 1, 2, 3, 4], # use price at n days ago
|
||||
'feature': ['OPEN', 'HIGH', 'LOW'] # which price field to use
|
||||
},
|
||||
'volume': { # whether to use raw volume features
|
||||
'windows': [0, 1, 2, 3, 4], # use volume at n days ago
|
||||
},
|
||||
'rolling': { # whether to use rolling operator based features
|
||||
'windows': [5, 10, 20, 30, 60], # rolling windows size
|
||||
'include': ['ROC', 'MA', 'STD'], # rolling operator to use
|
||||
#if include is None we will use default operators
|
||||
'exclude': ['RANK'], # rolling operator not to use
|
||||
}
|
||||
}
|
||||
"""
|
||||
fields = []
|
||||
names = []
|
||||
if "kbar" in config:
|
||||
fields += [
|
||||
"($close-$open)/$open",
|
||||
"($high-$low)/$open",
|
||||
"($close-$open)/($high-$low+1e-12)",
|
||||
"($high-Greater($open, $close))/$open",
|
||||
"($high-Greater($open, $close))/($high-$low+1e-12)",
|
||||
"(Less($open, $close)-$low)/$open",
|
||||
"(Less($open, $close)-$low)/($high-$low+1e-12)",
|
||||
"(2*$close-$high-$low)/$open",
|
||||
"(2*$close-$high-$low)/($high-$low+1e-12)",
|
||||
]
|
||||
names += [
|
||||
"KMID",
|
||||
"KLEN",
|
||||
"KMID2",
|
||||
"KUP",
|
||||
"KUP2",
|
||||
"KLOW",
|
||||
"KLOW2",
|
||||
"KSFT",
|
||||
"KSFT2",
|
||||
]
|
||||
if "price" in config:
|
||||
windows = config["price"].get("windows", range(5))
|
||||
feature = config["price"].get("feature", ["OPEN", "HIGH", "LOW", "CLOSE", "VWAP"])
|
||||
for field in feature:
|
||||
field = field.lower()
|
||||
fields += ["Ref($%s, %d)/$close" % (field, d) if d != 0 else "$%s/$close" % field for d in windows]
|
||||
names += [field.upper() + str(d) for d in windows]
|
||||
if "volume" in config:
|
||||
windows = config["volume"].get("windows", range(5))
|
||||
fields += ["Ref($volume, %d)/($volume+1e-12)" % d if d != 0 else "$volume/($volume+1e-12)" for d in windows]
|
||||
names += ["VOLUME" + str(d) for d in windows]
|
||||
if "rolling" in config:
|
||||
windows = config["rolling"].get("windows", [5, 10, 20, 30, 60])
|
||||
include = config["rolling"].get("include", None)
|
||||
exclude = config["rolling"].get("exclude", [])
|
||||
# `exclude` in dataset config unnecessary filed
|
||||
# `include` in dataset config necessary field
|
||||
|
||||
def use(x):
|
||||
return x not in exclude and (include is None or x in include)
|
||||
|
||||
# Some factor ref: https://guorn.com/static/upload/file/3/134065454575605.pdf
|
||||
if use("ROC"):
|
||||
# https://www.investopedia.com/terms/r/rateofchange.asp
|
||||
# Rate of change, the price change in the past d days, divided by latest close price to remove unit
|
||||
fields += ["Ref($close, %d)/$close" % d for d in windows]
|
||||
names += ["ROC%d" % d for d in windows]
|
||||
if use("MA"):
|
||||
# https://www.investopedia.com/ask/answers/071414/whats-difference-between-moving-average-and-weighted-moving-average.asp
|
||||
# Simple Moving Average, the simple moving average in the past d days, divided by latest close price to remove unit
|
||||
fields += ["Mean($close, %d)/$close" % d for d in windows]
|
||||
names += ["MA%d" % d for d in windows]
|
||||
if use("STD"):
|
||||
# The standard diviation of close price for the past d days, divided by latest close price to remove unit
|
||||
fields += ["Std($close, %d)/$close" % d for d in windows]
|
||||
names += ["STD%d" % d for d in windows]
|
||||
if use("BETA"):
|
||||
# The rate of close price change in the past d days, divided by latest close price to remove unit
|
||||
# For example, price increase 10 dollar per day in the past d days, then Slope will be 10.
|
||||
fields += ["Slope($close, %d)/$close" % d for d in windows]
|
||||
names += ["BETA%d" % d for d in windows]
|
||||
if use("RSQR"):
|
||||
# The R-sqaure value of linear regression for the past d days, represent the trend linear
|
||||
fields += ["Rsquare($close, %d)" % d for d in windows]
|
||||
names += ["RSQR%d" % d for d in windows]
|
||||
if use("RESI"):
|
||||
# The redisdual for linear regression for the past d days, represent the trend linearity for past d days.
|
||||
fields += ["Resi($close, %d)/$close" % d for d in windows]
|
||||
names += ["RESI%d" % d for d in windows]
|
||||
if use("MAX"):
|
||||
# The max price for past d days, divided by latest close price to remove unit
|
||||
fields += ["Max($high, %d)/$close" % d for d in windows]
|
||||
names += ["MAX%d" % d for d in windows]
|
||||
if use("LOW"):
|
||||
# The low price for past d days, divided by latest close price to remove unit
|
||||
fields += ["Min($low, %d)/$close" % d for d in windows]
|
||||
names += ["MIN%d" % d for d in windows]
|
||||
if use("QTLU"):
|
||||
# The 80% quantile of past d day's close price, divided by latest close price to remove unit
|
||||
# Used with MIN and MAX
|
||||
fields += ["Quantile($close, %d, 0.8)/$close" % d for d in windows]
|
||||
names += ["QTLU%d" % d for d in windows]
|
||||
if use("QTLD"):
|
||||
# The 20% quantile of past d day's close price, divided by latest close price to remove unit
|
||||
fields += ["Quantile($close, %d, 0.2)/$close" % d for d in windows]
|
||||
names += ["QTLD%d" % d for d in windows]
|
||||
if use("RANK"):
|
||||
# Get the percentile of current close price in past d day's close price.
|
||||
# Represent the current price level comparing to past N days, add additional information to moving average.
|
||||
fields += ["Rank($close, %d)" % d for d in windows]
|
||||
names += ["RANK%d" % d for d in windows]
|
||||
if use("RSV"):
|
||||
# Represent the price position between upper and lower resistent price for past d days.
|
||||
fields += ["($close-Min($low, %d))/(Max($high, %d)-Min($low, %d)+1e-12)" % (d, d, d) for d in windows]
|
||||
names += ["RSV%d" % d for d in windows]
|
||||
if use("IMAX"):
|
||||
# The number of days between current date and previous highest price date.
|
||||
# Part of Aroon Indicator https://www.investopedia.com/terms/a/aroon.asp
|
||||
# The indicator measures the time between highs and the time between lows over a time period.
|
||||
# The idea is that strong uptrends will regularly see new highs, and strong downtrends will regularly see new lows.
|
||||
fields += ["IdxMax($high, %d)/%d" % (d, d) for d in windows]
|
||||
names += ["IMAX%d" % d for d in windows]
|
||||
if use("IMIN"):
|
||||
# The number of days between current date and previous lowest price date.
|
||||
# Part of Aroon Indicator https://www.investopedia.com/terms/a/aroon.asp
|
||||
# The indicator measures the time between highs and the time between lows over a time period.
|
||||
# The idea is that strong uptrends will regularly see new highs, and strong downtrends will regularly see new lows.
|
||||
fields += ["IdxMin($low, %d)/%d" % (d, d) for d in windows]
|
||||
names += ["IMIN%d" % d for d in windows]
|
||||
if use("IMXD"):
|
||||
# The time period between previous lowest-price date occur after highest price date.
|
||||
# Large value suggest downward momemtum.
|
||||
fields += ["(IdxMax($high, %d)-IdxMin($low, %d))/%d" % (d, d, d) for d in windows]
|
||||
names += ["IMXD%d" % d for d in windows]
|
||||
if use("CORR"):
|
||||
# The correlation between absolute close price and log scaled trading volume
|
||||
fields += ["Corr($close, Log($volume+1), %d)" % d for d in windows]
|
||||
names += ["CORR%d" % d for d in windows]
|
||||
if use("CORD"):
|
||||
# The correlation between price change ratio and volume change ratio
|
||||
fields += ["Corr($close/Ref($close,1), Log($volume/Ref($volume, 1)+1), %d)" % d for d in windows]
|
||||
names += ["CORD%d" % d for d in windows]
|
||||
if use("CNTP"):
|
||||
# The percentage of days in past d days that price go up.
|
||||
fields += ["Mean($close>Ref($close, 1), %d)" % d for d in windows]
|
||||
names += ["CNTP%d" % d for d in windows]
|
||||
if use("CNTN"):
|
||||
# The percentage of days in past d days that price go down.
|
||||
fields += ["Mean($close<Ref($close, 1), %d)" % d for d in windows]
|
||||
names += ["CNTN%d" % d for d in windows]
|
||||
if use("CNTD"):
|
||||
# The diff between past up day and past down day
|
||||
fields += ["Mean($close>Ref($close, 1), %d)-Mean($close<Ref($close, 1), %d)" % (d, d) for d in windows]
|
||||
names += ["CNTD%d" % d for d in windows]
|
||||
if use("SUMP"):
|
||||
# The total gain / the absolute total price changed
|
||||
# Similar to RSI indicator. https://www.investopedia.com/terms/r/rsi.asp
|
||||
fields += [
|
||||
"Sum(Greater($close-Ref($close, 1), 0), %d)/(Sum(Abs($close-Ref($close, 1)), %d)+1e-12)" % (d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["SUMP%d" % d for d in windows]
|
||||
if use("SUMN"):
|
||||
# The total lose / the absolute total price changed
|
||||
# Can be derived from SUMP by SUMN = 1 - SUMP
|
||||
# Similar to RSI indicator. https://www.investopedia.com/terms/r/rsi.asp
|
||||
fields += [
|
||||
"Sum(Greater(Ref($close, 1)-$close, 0), %d)/(Sum(Abs($close-Ref($close, 1)), %d)+1e-12)" % (d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["SUMN%d" % d for d in windows]
|
||||
if use("SUMD"):
|
||||
# The diff ratio between total gain and total lose
|
||||
# Similar to RSI indicator. https://www.investopedia.com/terms/r/rsi.asp
|
||||
fields += [
|
||||
"(Sum(Greater($close-Ref($close, 1), 0), %d)-Sum(Greater(Ref($close, 1)-$close, 0), %d))"
|
||||
"/(Sum(Abs($close-Ref($close, 1)), %d)+1e-12)" % (d, d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["SUMD%d" % d for d in windows]
|
||||
if use("VMA"):
|
||||
# Simple Volume Moving average: https://www.barchart.com/education/technical-indicators/volume_moving_average
|
||||
fields += ["Mean($volume, %d)/($volume+1e-12)" % d for d in windows]
|
||||
names += ["VMA%d" % d for d in windows]
|
||||
if use("VSTD"):
|
||||
# The standard deviation for volume in past d days.
|
||||
fields += ["Std($volume, %d)/($volume+1e-12)" % d for d in windows]
|
||||
names += ["VSTD%d" % d for d in windows]
|
||||
if use("WVMA"):
|
||||
# The volume weighted price change volatility
|
||||
fields += [
|
||||
"Std(Abs($close/Ref($close, 1)-1)*$volume, %d)/(Mean(Abs($close/Ref($close, 1)-1)*$volume, %d)+1e-12)"
|
||||
% (d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["WVMA%d" % d for d in windows]
|
||||
if use("VSUMP"):
|
||||
# The total volume increase / the absolute total volume changed
|
||||
fields += [
|
||||
"Sum(Greater($volume-Ref($volume, 1), 0), %d)/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)"
|
||||
% (d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["VSUMP%d" % d for d in windows]
|
||||
if use("VSUMN"):
|
||||
# The total volume increase / the absolute total volume changed
|
||||
# Can be derived from VSUMP by VSUMN = 1 - VSUMP
|
||||
fields += [
|
||||
"Sum(Greater(Ref($volume, 1)-$volume, 0), %d)/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)"
|
||||
% (d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["VSUMN%d" % d for d in windows]
|
||||
if use("VSUMD"):
|
||||
# The diff ratio between total volume increase and total volume decrease
|
||||
# RSI indicator for volume
|
||||
fields += [
|
||||
"(Sum(Greater($volume-Ref($volume, 1), 0), %d)-Sum(Greater(Ref($volume, 1)-$volume, 0), %d))"
|
||||
"/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)" % (d, d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["VSUMD%d" % d for d in windows]
|
||||
|
||||
return fields, names
|
||||
|
||||
|
||||
class Alpha158vwap(Alpha158):
|
||||
def get_label_config(self):
|
||||
|
||||
310
qlib/contrib/data/loader.py
Normal file
310
qlib/contrib/data/loader.py
Normal file
@@ -0,0 +1,310 @@
|
||||
from qlib.data.dataset.loader import QlibDataLoader
|
||||
|
||||
|
||||
class Alpha360DL(QlibDataLoader):
|
||||
"""Dataloader to get Alpha360"""
|
||||
|
||||
def __init__(self, config=None, **kwargs):
|
||||
_config = {
|
||||
"feature": self.get_feature_config(),
|
||||
}
|
||||
if config is not None:
|
||||
_config.update(config)
|
||||
super().__init__(config=_config, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def get_feature_config():
|
||||
# NOTE:
|
||||
# Alpha360 tries to provide a dataset with original price data
|
||||
# the original price data includes the prices and volume in the last 60 days.
|
||||
# To make it easier to learn models from this dataset, all the prices and volume
|
||||
# are normalized by the latest price and volume data ( dividing by $close, $volume)
|
||||
# So the latest normalized $close will be 1 (with name CLOSE0), the latest normalized $volume will be 1 (with name VOLUME0)
|
||||
# If further normalization are executed (e.g. centralization), CLOSE0 and VOLUME0 will be 0.
|
||||
fields = []
|
||||
names = []
|
||||
|
||||
for i in range(59, 0, -1):
|
||||
fields += ["Ref($close, %d)/$close" % i]
|
||||
names += ["CLOSE%d" % i]
|
||||
fields += ["$close/$close"]
|
||||
names += ["CLOSE0"]
|
||||
for i in range(59, 0, -1):
|
||||
fields += ["Ref($open, %d)/$close" % i]
|
||||
names += ["OPEN%d" % i]
|
||||
fields += ["$open/$close"]
|
||||
names += ["OPEN0"]
|
||||
for i in range(59, 0, -1):
|
||||
fields += ["Ref($high, %d)/$close" % i]
|
||||
names += ["HIGH%d" % i]
|
||||
fields += ["$high/$close"]
|
||||
names += ["HIGH0"]
|
||||
for i in range(59, 0, -1):
|
||||
fields += ["Ref($low, %d)/$close" % i]
|
||||
names += ["LOW%d" % i]
|
||||
fields += ["$low/$close"]
|
||||
names += ["LOW0"]
|
||||
for i in range(59, 0, -1):
|
||||
fields += ["Ref($vwap, %d)/$close" % i]
|
||||
names += ["VWAP%d" % i]
|
||||
fields += ["$vwap/$close"]
|
||||
names += ["VWAP0"]
|
||||
for i in range(59, 0, -1):
|
||||
fields += ["Ref($volume, %d)/($volume+1e-12)" % i]
|
||||
names += ["VOLUME%d" % i]
|
||||
fields += ["$volume/($volume+1e-12)"]
|
||||
names += ["VOLUME0"]
|
||||
|
||||
return fields, names
|
||||
|
||||
|
||||
class Alpha158DL(QlibDataLoader):
|
||||
"""Dataloader to get Alpha158"""
|
||||
|
||||
def __init__(self, config=None, **kwargs):
|
||||
_config = {
|
||||
"feature": self.get_feature_config(),
|
||||
}
|
||||
if config is not None:
|
||||
_config.update(config)
|
||||
super().__init__(config=_config, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def get_feature_config(
|
||||
config={
|
||||
"kbar": {},
|
||||
"price": {
|
||||
"windows": [0],
|
||||
"feature": ["OPEN", "HIGH", "LOW", "VWAP"],
|
||||
},
|
||||
"rolling": {},
|
||||
}
|
||||
):
|
||||
"""create factors from config
|
||||
|
||||
config = {
|
||||
'kbar': {}, # whether to use some hard-code kbar features
|
||||
'price': { # whether to use raw price features
|
||||
'windows': [0, 1, 2, 3, 4], # use price at n days ago
|
||||
'feature': ['OPEN', 'HIGH', 'LOW'] # which price field to use
|
||||
},
|
||||
'volume': { # whether to use raw volume features
|
||||
'windows': [0, 1, 2, 3, 4], # use volume at n days ago
|
||||
},
|
||||
'rolling': { # whether to use rolling operator based features
|
||||
'windows': [5, 10, 20, 30, 60], # rolling windows size
|
||||
'include': ['ROC', 'MA', 'STD'], # rolling operator to use
|
||||
#if include is None we will use default operators
|
||||
'exclude': ['RANK'], # rolling operator not to use
|
||||
}
|
||||
}
|
||||
"""
|
||||
fields = []
|
||||
names = []
|
||||
if "kbar" in config:
|
||||
fields += [
|
||||
"($close-$open)/$open",
|
||||
"($high-$low)/$open",
|
||||
"($close-$open)/($high-$low+1e-12)",
|
||||
"($high-Greater($open, $close))/$open",
|
||||
"($high-Greater($open, $close))/($high-$low+1e-12)",
|
||||
"(Less($open, $close)-$low)/$open",
|
||||
"(Less($open, $close)-$low)/($high-$low+1e-12)",
|
||||
"(2*$close-$high-$low)/$open",
|
||||
"(2*$close-$high-$low)/($high-$low+1e-12)",
|
||||
]
|
||||
names += [
|
||||
"KMID",
|
||||
"KLEN",
|
||||
"KMID2",
|
||||
"KUP",
|
||||
"KUP2",
|
||||
"KLOW",
|
||||
"KLOW2",
|
||||
"KSFT",
|
||||
"KSFT2",
|
||||
]
|
||||
if "price" in config:
|
||||
windows = config["price"].get("windows", range(5))
|
||||
feature = config["price"].get("feature", ["OPEN", "HIGH", "LOW", "CLOSE", "VWAP"])
|
||||
for field in feature:
|
||||
field = field.lower()
|
||||
fields += ["Ref($%s, %d)/$close" % (field, d) if d != 0 else "$%s/$close" % field for d in windows]
|
||||
names += [field.upper() + str(d) for d in windows]
|
||||
if "volume" in config:
|
||||
windows = config["volume"].get("windows", range(5))
|
||||
fields += ["Ref($volume, %d)/($volume+1e-12)" % d if d != 0 else "$volume/($volume+1e-12)" for d in windows]
|
||||
names += ["VOLUME" + str(d) for d in windows]
|
||||
if "rolling" in config:
|
||||
windows = config["rolling"].get("windows", [5, 10, 20, 30, 60])
|
||||
include = config["rolling"].get("include", None)
|
||||
exclude = config["rolling"].get("exclude", [])
|
||||
# `exclude` in dataset config unnecessary filed
|
||||
# `include` in dataset config necessary field
|
||||
|
||||
def use(x):
|
||||
return x not in exclude and (include is None or x in include)
|
||||
|
||||
# Some factor ref: https://guorn.com/static/upload/file/3/134065454575605.pdf
|
||||
if use("ROC"):
|
||||
# https://www.investopedia.com/terms/r/rateofchange.asp
|
||||
# Rate of change, the price change in the past d days, divided by latest close price to remove unit
|
||||
fields += ["Ref($close, %d)/$close" % d for d in windows]
|
||||
names += ["ROC%d" % d for d in windows]
|
||||
if use("MA"):
|
||||
# https://www.investopedia.com/ask/answers/071414/whats-difference-between-moving-average-and-weighted-moving-average.asp
|
||||
# Simple Moving Average, the simple moving average in the past d days, divided by latest close price to remove unit
|
||||
fields += ["Mean($close, %d)/$close" % d for d in windows]
|
||||
names += ["MA%d" % d for d in windows]
|
||||
if use("STD"):
|
||||
# The standard diviation of close price for the past d days, divided by latest close price to remove unit
|
||||
fields += ["Std($close, %d)/$close" % d for d in windows]
|
||||
names += ["STD%d" % d for d in windows]
|
||||
if use("BETA"):
|
||||
# The rate of close price change in the past d days, divided by latest close price to remove unit
|
||||
# For example, price increase 10 dollar per day in the past d days, then Slope will be 10.
|
||||
fields += ["Slope($close, %d)/$close" % d for d in windows]
|
||||
names += ["BETA%d" % d for d in windows]
|
||||
if use("RSQR"):
|
||||
# The R-sqaure value of linear regression for the past d days, represent the trend linear
|
||||
fields += ["Rsquare($close, %d)" % d for d in windows]
|
||||
names += ["RSQR%d" % d for d in windows]
|
||||
if use("RESI"):
|
||||
# The redisdual for linear regression for the past d days, represent the trend linearity for past d days.
|
||||
fields += ["Resi($close, %d)/$close" % d for d in windows]
|
||||
names += ["RESI%d" % d for d in windows]
|
||||
if use("MAX"):
|
||||
# The max price for past d days, divided by latest close price to remove unit
|
||||
fields += ["Max($high, %d)/$close" % d for d in windows]
|
||||
names += ["MAX%d" % d for d in windows]
|
||||
if use("LOW"):
|
||||
# The low price for past d days, divided by latest close price to remove unit
|
||||
fields += ["Min($low, %d)/$close" % d for d in windows]
|
||||
names += ["MIN%d" % d for d in windows]
|
||||
if use("QTLU"):
|
||||
# The 80% quantile of past d day's close price, divided by latest close price to remove unit
|
||||
# Used with MIN and MAX
|
||||
fields += ["Quantile($close, %d, 0.8)/$close" % d for d in windows]
|
||||
names += ["QTLU%d" % d for d in windows]
|
||||
if use("QTLD"):
|
||||
# The 20% quantile of past d day's close price, divided by latest close price to remove unit
|
||||
fields += ["Quantile($close, %d, 0.2)/$close" % d for d in windows]
|
||||
names += ["QTLD%d" % d for d in windows]
|
||||
if use("RANK"):
|
||||
# Get the percentile of current close price in past d day's close price.
|
||||
# Represent the current price level comparing to past N days, add additional information to moving average.
|
||||
fields += ["Rank($close, %d)" % d for d in windows]
|
||||
names += ["RANK%d" % d for d in windows]
|
||||
if use("RSV"):
|
||||
# Represent the price position between upper and lower resistent price for past d days.
|
||||
fields += ["($close-Min($low, %d))/(Max($high, %d)-Min($low, %d)+1e-12)" % (d, d, d) for d in windows]
|
||||
names += ["RSV%d" % d for d in windows]
|
||||
if use("IMAX"):
|
||||
# The number of days between current date and previous highest price date.
|
||||
# Part of Aroon Indicator https://www.investopedia.com/terms/a/aroon.asp
|
||||
# The indicator measures the time between highs and the time between lows over a time period.
|
||||
# The idea is that strong uptrends will regularly see new highs, and strong downtrends will regularly see new lows.
|
||||
fields += ["IdxMax($high, %d)/%d" % (d, d) for d in windows]
|
||||
names += ["IMAX%d" % d for d in windows]
|
||||
if use("IMIN"):
|
||||
# The number of days between current date and previous lowest price date.
|
||||
# Part of Aroon Indicator https://www.investopedia.com/terms/a/aroon.asp
|
||||
# The indicator measures the time between highs and the time between lows over a time period.
|
||||
# The idea is that strong uptrends will regularly see new highs, and strong downtrends will regularly see new lows.
|
||||
fields += ["IdxMin($low, %d)/%d" % (d, d) for d in windows]
|
||||
names += ["IMIN%d" % d for d in windows]
|
||||
if use("IMXD"):
|
||||
# The time period between previous lowest-price date occur after highest price date.
|
||||
# Large value suggest downward momemtum.
|
||||
fields += ["(IdxMax($high, %d)-IdxMin($low, %d))/%d" % (d, d, d) for d in windows]
|
||||
names += ["IMXD%d" % d for d in windows]
|
||||
if use("CORR"):
|
||||
# The correlation between absolute close price and log scaled trading volume
|
||||
fields += ["Corr($close, Log($volume+1), %d)" % d for d in windows]
|
||||
names += ["CORR%d" % d for d in windows]
|
||||
if use("CORD"):
|
||||
# The correlation between price change ratio and volume change ratio
|
||||
fields += ["Corr($close/Ref($close,1), Log($volume/Ref($volume, 1)+1), %d)" % d for d in windows]
|
||||
names += ["CORD%d" % d for d in windows]
|
||||
if use("CNTP"):
|
||||
# The percentage of days in past d days that price go up.
|
||||
fields += ["Mean($close>Ref($close, 1), %d)" % d for d in windows]
|
||||
names += ["CNTP%d" % d for d in windows]
|
||||
if use("CNTN"):
|
||||
# The percentage of days in past d days that price go down.
|
||||
fields += ["Mean($close<Ref($close, 1), %d)" % d for d in windows]
|
||||
names += ["CNTN%d" % d for d in windows]
|
||||
if use("CNTD"):
|
||||
# The diff between past up day and past down day
|
||||
fields += ["Mean($close>Ref($close, 1), %d)-Mean($close<Ref($close, 1), %d)" % (d, d) for d in windows]
|
||||
names += ["CNTD%d" % d for d in windows]
|
||||
if use("SUMP"):
|
||||
# The total gain / the absolute total price changed
|
||||
# Similar to RSI indicator. https://www.investopedia.com/terms/r/rsi.asp
|
||||
fields += [
|
||||
"Sum(Greater($close-Ref($close, 1), 0), %d)/(Sum(Abs($close-Ref($close, 1)), %d)+1e-12)" % (d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["SUMP%d" % d for d in windows]
|
||||
if use("SUMN"):
|
||||
# The total lose / the absolute total price changed
|
||||
# Can be derived from SUMP by SUMN = 1 - SUMP
|
||||
# Similar to RSI indicator. https://www.investopedia.com/terms/r/rsi.asp
|
||||
fields += [
|
||||
"Sum(Greater(Ref($close, 1)-$close, 0), %d)/(Sum(Abs($close-Ref($close, 1)), %d)+1e-12)" % (d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["SUMN%d" % d for d in windows]
|
||||
if use("SUMD"):
|
||||
# The diff ratio between total gain and total lose
|
||||
# Similar to RSI indicator. https://www.investopedia.com/terms/r/rsi.asp
|
||||
fields += [
|
||||
"(Sum(Greater($close-Ref($close, 1), 0), %d)-Sum(Greater(Ref($close, 1)-$close, 0), %d))"
|
||||
"/(Sum(Abs($close-Ref($close, 1)), %d)+1e-12)" % (d, d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["SUMD%d" % d for d in windows]
|
||||
if use("VMA"):
|
||||
# Simple Volume Moving average: https://www.barchart.com/education/technical-indicators/volume_moving_average
|
||||
fields += ["Mean($volume, %d)/($volume+1e-12)" % d for d in windows]
|
||||
names += ["VMA%d" % d for d in windows]
|
||||
if use("VSTD"):
|
||||
# The standard deviation for volume in past d days.
|
||||
fields += ["Std($volume, %d)/($volume+1e-12)" % d for d in windows]
|
||||
names += ["VSTD%d" % d for d in windows]
|
||||
if use("WVMA"):
|
||||
# The volume weighted price change volatility
|
||||
fields += [
|
||||
"Std(Abs($close/Ref($close, 1)-1)*$volume, %d)/(Mean(Abs($close/Ref($close, 1)-1)*$volume, %d)+1e-12)"
|
||||
% (d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["WVMA%d" % d for d in windows]
|
||||
if use("VSUMP"):
|
||||
# The total volume increase / the absolute total volume changed
|
||||
fields += [
|
||||
"Sum(Greater($volume-Ref($volume, 1), 0), %d)/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)"
|
||||
% (d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["VSUMP%d" % d for d in windows]
|
||||
if use("VSUMN"):
|
||||
# The total volume increase / the absolute total volume changed
|
||||
# Can be derived from VSUMP by VSUMN = 1 - VSUMP
|
||||
fields += [
|
||||
"Sum(Greater(Ref($volume, 1)-$volume, 0), %d)/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)"
|
||||
% (d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["VSUMN%d" % d for d in windows]
|
||||
if use("VSUMD"):
|
||||
# The diff ratio between total volume increase and total volume decrease
|
||||
# RSI indicator for volume
|
||||
fields += [
|
||||
"(Sum(Greater($volume-Ref($volume, 1), 0), %d)-Sum(Greater(Ref($volume, 1)-$volume, 0), %d))"
|
||||
"/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)" % (d, d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["VSUMD%d" % d for d in windows]
|
||||
|
||||
return fields, names
|
||||
@@ -243,7 +243,7 @@ class MetaDatasetDS(MetaTaskDataset):
|
||||
trunc_days: int = None,
|
||||
rolling_ext_days: int = 0,
|
||||
exp_name: Union[str, InternalData],
|
||||
segments: Union[Dict[Text, Tuple], float],
|
||||
segments: Union[Dict[Text, Tuple], float, str],
|
||||
hist_step_n: int = 10,
|
||||
task_mode: str = MetaTask.PROC_MODE_FULL,
|
||||
fill_method: str = "max",
|
||||
@@ -271,12 +271,16 @@ class MetaDatasetDS(MetaTaskDataset):
|
||||
- str: the name of the experiment to store the performance of data
|
||||
- InternalData: a prepared internal data
|
||||
segments: Union[Dict[Text, Tuple], float]
|
||||
the segments to divide data
|
||||
both left and right
|
||||
if the segment is a Dict
|
||||
the segments to divide data
|
||||
both left and right are included
|
||||
if segments is a float:
|
||||
the float represents the percentage of data for training
|
||||
if segments is a string:
|
||||
it will try its best to put its data in training and ensure that the date `segments` is in the test set
|
||||
hist_step_n: int
|
||||
length of historical steps for the meta infomation
|
||||
Number of steps of the data similarity information
|
||||
task_mode : str
|
||||
Please refer to the docs of MetaTask
|
||||
"""
|
||||
@@ -383,10 +387,30 @@ class MetaDatasetDS(MetaTaskDataset):
|
||||
if isinstance(self.segments, float):
|
||||
train_task_n = int(len(self.meta_task_l) * self.segments)
|
||||
if segment == "train":
|
||||
return self.meta_task_l[:train_task_n]
|
||||
train_tasks = self.meta_task_l[:train_task_n]
|
||||
get_module_logger("MetaDatasetDS").info(f"The first train meta task: {train_tasks[0]}")
|
||||
return train_tasks
|
||||
elif segment == "test":
|
||||
return self.meta_task_l[train_task_n:]
|
||||
test_tasks = self.meta_task_l[train_task_n:]
|
||||
get_module_logger("MetaDatasetDS").info(f"The first test meta task: {test_tasks[0]}")
|
||||
return test_tasks
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
elif isinstance(self.segments, str):
|
||||
train_tasks = []
|
||||
test_tasks = []
|
||||
for t in self.meta_task_l:
|
||||
test_end = t.task["dataset"]["kwargs"]["segments"]["test"][1]
|
||||
if test_end is None or pd.Timestamp(test_end) < pd.Timestamp(self.segments):
|
||||
train_tasks.append(t)
|
||||
else:
|
||||
test_tasks.append(t)
|
||||
get_module_logger("MetaDatasetDS").info(f"The first train meta task: {train_tasks[0]}")
|
||||
get_module_logger("MetaDatasetDS").info(f"The first test meta task: {test_tasks[0]}")
|
||||
if segment == "train":
|
||||
return train_tasks
|
||||
elif segment == "test":
|
||||
return test_tasks
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
|
||||
@@ -53,7 +53,12 @@ class MetaModelDS(MetaTaskModel):
|
||||
max_epoch=100,
|
||||
seed=43,
|
||||
alpha=0.0,
|
||||
loss_skip_thresh=50,
|
||||
):
|
||||
"""
|
||||
loss_skip_size: int
|
||||
The number of threshold to skip the loss calculation for each day.
|
||||
"""
|
||||
self.step = step
|
||||
self.hist_step_n = hist_step_n
|
||||
self.clip_method = clip_method
|
||||
@@ -63,6 +68,7 @@ class MetaModelDS(MetaTaskModel):
|
||||
self.max_epoch = max_epoch
|
||||
self.fitted = False
|
||||
self.alpha = alpha
|
||||
self.loss_skip_thresh = loss_skip_thresh
|
||||
torch.manual_seed(seed)
|
||||
|
||||
def run_epoch(self, phase, task_list, epoch, opt, loss_l, ignore_weight=False):
|
||||
@@ -88,12 +94,14 @@ class MetaModelDS(MetaTaskModel):
|
||||
criterion = nn.MSELoss()
|
||||
loss = criterion(pred, meta_input["y_test"])
|
||||
elif self.criterion == "ic_loss":
|
||||
criterion = ICLoss()
|
||||
criterion = ICLoss(self.loss_skip_thresh)
|
||||
try:
|
||||
loss = criterion(pred, meta_input["y_test"], meta_input["test_idx"], skip_size=50)
|
||||
loss = criterion(pred, meta_input["y_test"], meta_input["test_idx"])
|
||||
except ValueError as e:
|
||||
get_module_logger("MetaModelDS").warning(f"Exception `{e}` when calculating IC loss")
|
||||
continue
|
||||
else:
|
||||
raise ValueError(f"Unknown criterion: {self.criterion}")
|
||||
|
||||
assert not np.isnan(loss.detach().item()), "NaN loss!"
|
||||
|
||||
|
||||
@@ -10,7 +10,11 @@ from qlib.log import get_module_logger
|
||||
|
||||
|
||||
class ICLoss(nn.Module):
|
||||
def forward(self, pred, y, idx, skip_size=50):
|
||||
def __init__(self, skip_size=50):
|
||||
super().__init__()
|
||||
self.skip_size = skip_size
|
||||
|
||||
def forward(self, pred, y, idx):
|
||||
"""forward.
|
||||
FIXME:
|
||||
- Some times it will be a slightly different from the result from `pandas.corr()`
|
||||
@@ -33,7 +37,7 @@ class ICLoss(nn.Module):
|
||||
skip_n = 0
|
||||
for start_i, end_i in zip(diff_point, diff_point[1:]):
|
||||
pred_focus = pred[start_i:end_i] # TODO: just for fake
|
||||
if pred_focus.shape[0] < skip_size:
|
||||
if pred_focus.shape[0] < self.skip_size:
|
||||
# skip some days which have very small amount of stock.
|
||||
skip_n += 1
|
||||
continue
|
||||
@@ -50,6 +54,7 @@ class ICLoss(nn.Module):
|
||||
)
|
||||
ic_all += ic_day
|
||||
if len(diff_point) - 1 - skip_n <= 0:
|
||||
__import__("ipdb").set_trace()
|
||||
raise ValueError("No enough data for calculating IC")
|
||||
if skip_n > 0:
|
||||
get_module_logger("ICLoss").info(
|
||||
|
||||
@@ -33,7 +33,7 @@ class CatBoostModel(Model, FeatureInt):
|
||||
verbose_eval=20,
|
||||
evals_result=dict(),
|
||||
reweighter=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
df_train, df_valid = dataset.prepare(
|
||||
["train", "valid"],
|
||||
|
||||
@@ -31,7 +31,7 @@ class DEnsembleModel(Model, FeatureInt):
|
||||
sub_weights=None,
|
||||
epochs=100,
|
||||
early_stopping_rounds=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
self.base_model = base_model # "gbm" or "mlp", specifically, we use lgbm for "gbm"
|
||||
self.num_models = num_models # the number of sub-models
|
||||
|
||||
@@ -63,6 +63,7 @@ class LinearModel(Model):
|
||||
df_train = pd.concat([df_train, df_valid])
|
||||
except KeyError:
|
||||
get_module_logger("LinearModel").info("include_valid=True, but valid does not exist")
|
||||
df_train = df_train.dropna()
|
||||
if df_train.empty:
|
||||
raise ValueError("Empty data from dataset, please check your dataset config.")
|
||||
if reweighter is not None:
|
||||
|
||||
@@ -56,7 +56,7 @@ class ADARNN(Model):
|
||||
n_splits=2,
|
||||
GPU=0,
|
||||
seed=None,
|
||||
**_
|
||||
**_,
|
||||
):
|
||||
# Set logger.
|
||||
self.logger = get_module_logger("ADARNN")
|
||||
@@ -154,10 +154,7 @@ class ADARNN(Model):
|
||||
self.model.train()
|
||||
criterion = nn.MSELoss()
|
||||
dist_mat = torch.zeros(self.num_layers, self.len_seq).to(self.device)
|
||||
len_loader = np.inf
|
||||
for loader in train_loader_list:
|
||||
if len(loader) < len_loader:
|
||||
len_loader = len(loader)
|
||||
out_weight_list = None
|
||||
for data_all in zip(*train_loader_list):
|
||||
# for data_all in zip(*train_loader_list):
|
||||
self.train_optimizer.zero_grad()
|
||||
@@ -571,6 +568,7 @@ class TransferLoss:
|
||||
Returns:
|
||||
[tensor] -- transfer loss
|
||||
"""
|
||||
loss = None
|
||||
if self.loss_type in ("mmd_lin", "mmd"):
|
||||
mmdloss = MMD_loss(kernel_type="linear")
|
||||
loss = mmdloss(X, Y)
|
||||
|
||||
@@ -63,7 +63,7 @@ class ADD(Model):
|
||||
mu=0.05,
|
||||
GPU=0,
|
||||
seed=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
# Set logger.
|
||||
self.logger = get_module_logger("ADD")
|
||||
|
||||
@@ -52,7 +52,7 @@ class ALSTM(Model):
|
||||
optimizer="adam",
|
||||
GPU=0,
|
||||
seed=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
# Set logger.
|
||||
self.logger = get_module_logger("ALSTM")
|
||||
|
||||
@@ -56,7 +56,7 @@ class ALSTM(Model):
|
||||
n_jobs=10,
|
||||
GPU=0,
|
||||
seed=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
# Set logger.
|
||||
self.logger = get_module_logger("ALSTM")
|
||||
@@ -160,6 +160,10 @@ class ALSTM(Model):
|
||||
|
||||
if self.metric in ("", "loss"):
|
||||
return -self.loss_fn(pred[mask], label[mask])
|
||||
elif self.metric == "mse":
|
||||
mask = ~torch.isnan(label)
|
||||
weight = torch.ones_like(label)
|
||||
return -self.mse(pred[mask], label[mask], weight[mask])
|
||||
|
||||
raise ValueError("unknown metric `%s`" % self.metric)
|
||||
|
||||
|
||||
@@ -56,7 +56,7 @@ class GATs(Model):
|
||||
optimizer="adam",
|
||||
GPU=0,
|
||||
seed=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
# Set logger.
|
||||
self.logger = get_module_logger("GATs")
|
||||
|
||||
@@ -73,7 +73,7 @@ class GATs(Model):
|
||||
GPU=0,
|
||||
n_jobs=10,
|
||||
seed=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
# Set logger.
|
||||
self.logger = get_module_logger("GATs")
|
||||
|
||||
358
qlib/contrib/model/pytorch_general_nn.py
Normal file
358
qlib/contrib/model/pytorch_general_nn.py
Normal file
@@ -0,0 +1,358 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Union
|
||||
import copy
|
||||
|
||||
import torch
|
||||
import torch.optim as optim
|
||||
|
||||
from qlib.data.dataset.weight import Reweighter
|
||||
|
||||
from .pytorch_utils import count_parameters
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH, TSDatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
from ...utils import (
|
||||
init_instance_by_config,
|
||||
get_or_create_path,
|
||||
)
|
||||
from ...log import get_module_logger
|
||||
|
||||
from ...model.utils import ConcatDataset
|
||||
|
||||
|
||||
class GeneralPTNN(Model):
|
||||
"""
|
||||
Motivation:
|
||||
We want to provide a Qlib General Pytorch Model Adaptor
|
||||
You can reuse it for all kinds of Pytorch models.
|
||||
It should include the training and predict process
|
||||
|
||||
Parameters
|
||||
----------
|
||||
d_feat : int
|
||||
input dimension for each time step
|
||||
metric: str
|
||||
the evaluation metric used in early stop
|
||||
optimizer : str
|
||||
optimizer name
|
||||
GPU : str
|
||||
the GPU ID(s) used for training
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
n_epochs=200,
|
||||
lr=0.001,
|
||||
metric="",
|
||||
batch_size=2000,
|
||||
early_stop=20,
|
||||
loss="mse",
|
||||
weight_decay=0.0,
|
||||
optimizer="adam",
|
||||
n_jobs=10,
|
||||
GPU=0,
|
||||
seed=None,
|
||||
pt_model_uri="qlib.contrib.model.pytorch_gru_ts.GRUModel",
|
||||
pt_model_kwargs={
|
||||
"d_feat": 6,
|
||||
"hidden_size": 64,
|
||||
"num_layers": 2,
|
||||
"dropout": 0.0,
|
||||
},
|
||||
):
|
||||
# Set logger.
|
||||
self.logger = get_module_logger("GeneralPTNN")
|
||||
self.logger.info("GeneralPTNN pytorch version...")
|
||||
|
||||
# set hyper-parameters.
|
||||
self.n_epochs = n_epochs
|
||||
self.lr = lr
|
||||
self.metric = metric
|
||||
self.batch_size = batch_size
|
||||
self.early_stop = early_stop
|
||||
self.optimizer = optimizer.lower()
|
||||
self.loss = loss
|
||||
self.weight_decay = weight_decay
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.n_jobs = n_jobs
|
||||
self.seed = seed
|
||||
|
||||
self.pt_model_uri, self.pt_model_kwargs = pt_model_uri, pt_model_kwargs
|
||||
self.dnn_model = init_instance_by_config({"class": pt_model_uri, "kwargs": pt_model_kwargs})
|
||||
|
||||
self.logger.info(
|
||||
"GeneralPTNN parameters setting:"
|
||||
"\nn_epochs : {}"
|
||||
"\nlr : {}"
|
||||
"\nmetric : {}"
|
||||
"\nbatch_size : {}"
|
||||
"\nearly_stop : {}"
|
||||
"\noptimizer : {}"
|
||||
"\nloss_type : {}"
|
||||
"\ndevice : {}"
|
||||
"\nn_jobs : {}"
|
||||
"\nuse_GPU : {}"
|
||||
"\nweight_decay : {}"
|
||||
"\nseed : {}"
|
||||
"\npt_model_uri: {}"
|
||||
"\npt_model_kwargs: {}".format(
|
||||
n_epochs,
|
||||
lr,
|
||||
metric,
|
||||
batch_size,
|
||||
early_stop,
|
||||
optimizer.lower(),
|
||||
loss,
|
||||
self.device,
|
||||
n_jobs,
|
||||
self.use_gpu,
|
||||
weight_decay,
|
||||
seed,
|
||||
pt_model_uri,
|
||||
pt_model_kwargs,
|
||||
)
|
||||
)
|
||||
|
||||
if self.seed is not None:
|
||||
np.random.seed(self.seed)
|
||||
torch.manual_seed(self.seed)
|
||||
|
||||
self.logger.info("model:\n{:}".format(self.dnn_model))
|
||||
self.logger.info("model size: {:.4f} MB".format(count_parameters(self.dnn_model)))
|
||||
|
||||
if optimizer.lower() == "adam":
|
||||
self.train_optimizer = optim.Adam(self.dnn_model.parameters(), lr=self.lr, weight_decay=weight_decay)
|
||||
elif optimizer.lower() == "gd":
|
||||
self.train_optimizer = optim.SGD(self.dnn_model.parameters(), lr=self.lr, weight_decay=weight_decay)
|
||||
else:
|
||||
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
|
||||
|
||||
self.fitted = False
|
||||
self.dnn_model.to(self.device)
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
return self.device != torch.device("cpu")
|
||||
|
||||
def mse(self, pred, label, weight):
|
||||
loss = weight * (pred - label) ** 2
|
||||
return torch.mean(loss)
|
||||
|
||||
def loss_fn(self, pred, label, weight=None):
|
||||
mask = ~torch.isnan(label)
|
||||
|
||||
if weight is None:
|
||||
weight = torch.ones_like(label)
|
||||
|
||||
if self.loss == "mse":
|
||||
return self.mse(pred[mask], label[mask], weight[mask])
|
||||
|
||||
raise ValueError("unknown loss `%s`" % self.loss)
|
||||
|
||||
def metric_fn(self, pred, label):
|
||||
mask = torch.isfinite(label)
|
||||
|
||||
if self.metric in ("", "loss"):
|
||||
return -self.loss_fn(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown metric `%s`" % self.metric)
|
||||
|
||||
def _get_fl(self, data: torch.Tensor):
|
||||
"""
|
||||
get feature and label from data
|
||||
- Handle the different data shape of time series and tabular data
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data : torch.Tensor
|
||||
input data which maybe 3 dimension or 2 dimension
|
||||
- 3dim: [batch_size, time_step, feature_dim]
|
||||
- 2dim: [batch_size, feature_dim]
|
||||
|
||||
Returns
|
||||
-------
|
||||
Tuple[torch.Tensor, torch.Tensor]
|
||||
"""
|
||||
if data.dim() == 3:
|
||||
# it is a time series dataset
|
||||
feature = data[:, :, 0:-1].to(self.device)
|
||||
label = data[:, -1, -1].to(self.device)
|
||||
elif data.dim() == 2:
|
||||
# it is a tabular dataset
|
||||
feature = data[:, 0:-1].to(self.device)
|
||||
label = data[:, -1].to(self.device)
|
||||
else:
|
||||
raise ValueError("Unsupported data shape.")
|
||||
return feature, label
|
||||
|
||||
def train_epoch(self, data_loader):
|
||||
self.dnn_model.train()
|
||||
|
||||
for data, weight in data_loader:
|
||||
feature, label = self._get_fl(data)
|
||||
|
||||
pred = self.dnn_model(feature.float())
|
||||
loss = self.loss_fn(pred, label, weight.to(self.device))
|
||||
|
||||
self.train_optimizer.zero_grad()
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_value_(self.dnn_model.parameters(), 3.0)
|
||||
self.train_optimizer.step()
|
||||
|
||||
def test_epoch(self, data_loader):
|
||||
self.dnn_model.eval()
|
||||
|
||||
scores = []
|
||||
losses = []
|
||||
|
||||
for data, weight in data_loader:
|
||||
feature, label = self._get_fl(data)
|
||||
|
||||
with torch.no_grad():
|
||||
pred = self.dnn_model(feature.float())
|
||||
loss = self.loss_fn(pred, label, weight.to(self.device))
|
||||
losses.append(loss.item())
|
||||
|
||||
score = self.metric_fn(pred, label)
|
||||
scores.append(score.item())
|
||||
|
||||
return np.mean(losses), np.mean(scores)
|
||||
|
||||
def fit(
|
||||
self,
|
||||
dataset: Union[DatasetH, TSDatasetH],
|
||||
evals_result=dict(),
|
||||
save_path=None,
|
||||
reweighter=None,
|
||||
):
|
||||
ists = isinstance(dataset, TSDatasetH) # is this time series dataset
|
||||
|
||||
dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
|
||||
dl_valid = dataset.prepare("valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
|
||||
if dl_train.empty or dl_valid.empty:
|
||||
raise ValueError("Empty data from dataset, please check your dataset config.")
|
||||
|
||||
if reweighter is None:
|
||||
wl_train = np.ones(len(dl_train))
|
||||
wl_valid = np.ones(len(dl_valid))
|
||||
elif isinstance(reweighter, Reweighter):
|
||||
wl_train = reweighter.reweight(dl_train)
|
||||
wl_valid = reweighter.reweight(dl_valid)
|
||||
else:
|
||||
raise ValueError("Unsupported reweighter type.")
|
||||
|
||||
# Preprocess for data. To align to Dataset Interface for DataLoader
|
||||
if ists:
|
||||
dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
else:
|
||||
# If it is a tabular, we convert the dataframe to numpy to be indexable by DataLoader
|
||||
dl_train = dl_train.values
|
||||
dl_valid = dl_valid.values
|
||||
|
||||
train_loader = DataLoader(
|
||||
ConcatDataset(dl_train, wl_train),
|
||||
batch_size=self.batch_size,
|
||||
shuffle=True,
|
||||
num_workers=self.n_jobs,
|
||||
drop_last=True,
|
||||
)
|
||||
valid_loader = DataLoader(
|
||||
ConcatDataset(dl_valid, wl_valid),
|
||||
batch_size=self.batch_size,
|
||||
shuffle=False,
|
||||
num_workers=self.n_jobs,
|
||||
drop_last=True,
|
||||
)
|
||||
del dl_train, dl_valid, wl_train, wl_valid
|
||||
|
||||
save_path = get_or_create_path(save_path)
|
||||
|
||||
stop_steps = 0
|
||||
train_loss = 0
|
||||
best_score = -np.inf
|
||||
best_epoch = 0
|
||||
evals_result["train"] = []
|
||||
evals_result["valid"] = []
|
||||
|
||||
# train
|
||||
self.logger.info("training...")
|
||||
self.fitted = True
|
||||
|
||||
for step in range(self.n_epochs):
|
||||
self.logger.info("Epoch%d:", step)
|
||||
self.logger.info("training...")
|
||||
self.train_epoch(train_loader)
|
||||
self.logger.info("evaluating...")
|
||||
train_loss, train_score = self.test_epoch(train_loader)
|
||||
val_loss, val_score = self.test_epoch(valid_loader)
|
||||
self.logger.info("train %.6f, valid %.6f" % (train_score, val_score))
|
||||
evals_result["train"].append(train_score)
|
||||
evals_result["valid"].append(val_score)
|
||||
|
||||
if step == 0:
|
||||
best_param = copy.deepcopy(self.dnn_model.state_dict())
|
||||
if val_score > best_score:
|
||||
best_score = val_score
|
||||
stop_steps = 0
|
||||
best_epoch = step
|
||||
best_param = copy.deepcopy(self.dnn_model.state_dict())
|
||||
else:
|
||||
stop_steps += 1
|
||||
if stop_steps >= self.early_stop:
|
||||
self.logger.info("early stop")
|
||||
break
|
||||
|
||||
self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
|
||||
self.dnn_model.load_state_dict(best_param)
|
||||
torch.save(best_param, save_path)
|
||||
|
||||
if self.use_gpu:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def predict(
|
||||
self,
|
||||
dataset: Union[DatasetH, TSDatasetH],
|
||||
batch_size=None,
|
||||
n_jobs=None,
|
||||
):
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
dl_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)
|
||||
|
||||
if isinstance(dataset, TSDatasetH):
|
||||
dl_test.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
index = dl_test.get_index()
|
||||
else:
|
||||
# If it is a tabular, we convert the dataframe to numpy to be indexable by DataLoader
|
||||
index = dl_test.index
|
||||
dl_test = dl_test.values
|
||||
|
||||
test_loader = DataLoader(dl_test, batch_size=self.batch_size, num_workers=self.n_jobs)
|
||||
self.dnn_model.eval()
|
||||
preds = []
|
||||
|
||||
for data in test_loader:
|
||||
feature, _ = self._get_fl(data)
|
||||
feature = feature.to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
pred = self.dnn_model(feature.float()).detach().cpu().numpy()
|
||||
|
||||
preds.append(pred)
|
||||
|
||||
preds_concat = np.concatenate(preds)
|
||||
if preds_concat.ndim != 1:
|
||||
preds_concat = preds_concat.ravel()
|
||||
|
||||
return pd.Series(preds_concat, index=index)
|
||||
@@ -1,25 +1,25 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
import copy
|
||||
from typing import Text, Union
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Text, Union
|
||||
import copy
|
||||
from ...utils import get_or_create_path
|
||||
from ...log import get_module_logger
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
|
||||
from .pytorch_utils import count_parameters
|
||||
from ...model.base import Model
|
||||
from qlib.workflow import R
|
||||
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
from ...log import get_module_logger
|
||||
from ...model.base import Model
|
||||
from ...utils import get_or_create_path
|
||||
from .pytorch_utils import count_parameters
|
||||
|
||||
|
||||
class GRU(Model):
|
||||
@@ -52,7 +52,7 @@ class GRU(Model):
|
||||
optimizer="adam",
|
||||
GPU=0,
|
||||
seed=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
# Set logger.
|
||||
self.logger = get_module_logger("GRU")
|
||||
@@ -212,16 +212,31 @@ class GRU(Model):
|
||||
evals_result=dict(),
|
||||
save_path=None,
|
||||
):
|
||||
df_train, df_valid, df_test = dataset.prepare(
|
||||
["train", "valid", "test"],
|
||||
col_set=["feature", "label"],
|
||||
data_key=DataHandlerLP.DK_L,
|
||||
)
|
||||
if df_train.empty or df_valid.empty:
|
||||
raise ValueError("Empty data from dataset, please check your dataset config.")
|
||||
# prepare training and validation data
|
||||
dfs = {
|
||||
k: dataset.prepare(
|
||||
k,
|
||||
col_set=["feature", "label"],
|
||||
data_key=DataHandlerLP.DK_L,
|
||||
)
|
||||
for k in ["train", "valid"]
|
||||
if k in dataset.segments
|
||||
}
|
||||
df_train, df_valid = dfs.get("train", pd.DataFrame()), dfs.get("valid", pd.DataFrame())
|
||||
|
||||
# check if training data is empty
|
||||
if df_train.empty:
|
||||
raise ValueError("Empty training data from dataset, please check your dataset config.")
|
||||
|
||||
df_train = df_train.dropna()
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
|
||||
# check if validation data is provided
|
||||
if not df_valid.empty:
|
||||
df_valid = df_valid.dropna()
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
else:
|
||||
x_valid, y_valid = None, None
|
||||
|
||||
save_path = get_or_create_path(save_path)
|
||||
stop_steps = 0
|
||||
@@ -235,32 +250,42 @@ class GRU(Model):
|
||||
self.logger.info("training...")
|
||||
self.fitted = True
|
||||
|
||||
best_param = copy.deepcopy(self.gru_model.state_dict())
|
||||
for step in range(self.n_epochs):
|
||||
self.logger.info("Epoch%d:", step)
|
||||
self.logger.info("training...")
|
||||
self.train_epoch(x_train, y_train)
|
||||
self.logger.info("evaluating...")
|
||||
train_loss, train_score = self.test_epoch(x_train, y_train)
|
||||
val_loss, val_score = self.test_epoch(x_valid, y_valid)
|
||||
self.logger.info("train %.6f, valid %.6f" % (train_score, val_score))
|
||||
evals_result["train"].append(train_score)
|
||||
evals_result["valid"].append(val_score)
|
||||
|
||||
if val_score > best_score:
|
||||
best_score = val_score
|
||||
stop_steps = 0
|
||||
best_epoch = step
|
||||
best_param = copy.deepcopy(self.gru_model.state_dict())
|
||||
else:
|
||||
stop_steps += 1
|
||||
if stop_steps >= self.early_stop:
|
||||
self.logger.info("early stop")
|
||||
break
|
||||
# evaluate on validation data if provided
|
||||
if x_valid is not None and y_valid is not None:
|
||||
val_loss, val_score = self.test_epoch(x_valid, y_valid)
|
||||
self.logger.info("train %.6f, valid %.6f" % (train_score, val_score))
|
||||
evals_result["valid"].append(val_score)
|
||||
|
||||
if val_score > best_score:
|
||||
best_score = val_score
|
||||
stop_steps = 0
|
||||
best_epoch = step
|
||||
best_param = copy.deepcopy(self.gru_model.state_dict())
|
||||
else:
|
||||
stop_steps += 1
|
||||
if stop_steps >= self.early_stop:
|
||||
self.logger.info("early stop")
|
||||
break
|
||||
|
||||
self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
|
||||
self.gru_model.load_state_dict(best_param)
|
||||
torch.save(best_param, save_path)
|
||||
|
||||
# Logging
|
||||
rec = R.get_recorder()
|
||||
for k, v_l in evals_result.items():
|
||||
for i, v in enumerate(v_l):
|
||||
rec.log_metrics(step=i, **{k: v})
|
||||
|
||||
if self.use_gpu:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
@@ -54,7 +54,7 @@ class GRU(Model):
|
||||
n_jobs=10,
|
||||
GPU=0,
|
||||
seed=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
# Set logger.
|
||||
self.logger = get_module_logger("GRU")
|
||||
|
||||
@@ -59,7 +59,7 @@ class HIST(Model):
|
||||
optimizer="adam",
|
||||
GPU=0,
|
||||
seed=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
# Set logger.
|
||||
self.logger = get_module_logger("HIST")
|
||||
@@ -256,7 +256,7 @@ class HIST(Model):
|
||||
raise ValueError("Empty data from dataset, please check your dataset config.")
|
||||
|
||||
if not os.path.exists(self.stock2concept):
|
||||
url = "http://fintech.msra.cn/stock_data/downloads/qlib_csi300_stock2concept.npy"
|
||||
url = "https://github.com/SunsetWolf/qlib_dataset/releases/download/v0/qlib_csi300_stock2concept.npy"
|
||||
urllib.request.urlretrieve(url, self.stock2concept)
|
||||
|
||||
stock_index = np.load(self.stock_index, allow_pickle=True).item()
|
||||
|
||||
@@ -55,7 +55,7 @@ class IGMTF(Model):
|
||||
optimizer="adam",
|
||||
GPU=0,
|
||||
seed=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
# Set logger.
|
||||
self.logger = get_module_logger("IGMTF")
|
||||
|
||||
@@ -255,7 +255,7 @@ class KRNN(Model):
|
||||
optimizer="adam",
|
||||
GPU=0,
|
||||
seed=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
# Set logger.
|
||||
self.logger = get_module_logger("KRNN")
|
||||
|
||||
@@ -44,7 +44,7 @@ class LocalformerModel(Model):
|
||||
n_jobs=10,
|
||||
GPU=0,
|
||||
seed=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
# set hyper-parameters.
|
||||
self.d_model = d_model
|
||||
|
||||
@@ -42,7 +42,7 @@ class LocalformerModel(Model):
|
||||
n_jobs=10,
|
||||
GPU=0,
|
||||
seed=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
# set hyper-parameters.
|
||||
self.d_model = d_model
|
||||
|
||||
@@ -51,7 +51,7 @@ class LSTM(Model):
|
||||
optimizer="adam",
|
||||
GPU=0,
|
||||
seed=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
# Set logger.
|
||||
self.logger = get_module_logger("LSTM")
|
||||
|
||||
@@ -53,7 +53,7 @@ class LSTM(Model):
|
||||
n_jobs=10,
|
||||
GPU=0,
|
||||
seed=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
# Set logger.
|
||||
self.logger = get_module_logger("LSTM")
|
||||
|
||||
@@ -35,7 +35,7 @@ class SandwichModel(nn.Module):
|
||||
rnn_layers,
|
||||
dropout,
|
||||
device,
|
||||
**params
|
||||
**params,
|
||||
):
|
||||
"""Build a Sandwich model
|
||||
|
||||
@@ -129,7 +129,7 @@ class Sandwich(Model):
|
||||
optimizer="adam",
|
||||
GPU=0,
|
||||
seed=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
# Set logger.
|
||||
self.logger = get_module_logger("Sandwich")
|
||||
|
||||
@@ -212,7 +212,7 @@ class SFM(Model):
|
||||
optimizer="gd",
|
||||
GPU=0,
|
||||
seed=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
# Set logger.
|
||||
self.logger = get_module_logger("SFM")
|
||||
|
||||
@@ -56,7 +56,7 @@ class TCN(Model):
|
||||
optimizer="adam",
|
||||
GPU=0,
|
||||
seed=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
# Set logger.
|
||||
self.logger = get_module_logger("TCN")
|
||||
|
||||
@@ -54,7 +54,7 @@ class TCN(Model):
|
||||
n_jobs=10,
|
||||
GPU=0,
|
||||
seed=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
# Set logger.
|
||||
self.logger = get_module_logger("TCN")
|
||||
|
||||
@@ -58,7 +58,7 @@ class TCTS(Model):
|
||||
mode="soft",
|
||||
seed=None,
|
||||
lowest_valid_performance=0.993,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
# Set logger.
|
||||
self.logger = get_module_logger("TCTS")
|
||||
|
||||
@@ -43,7 +43,7 @@ class TransformerModel(Model):
|
||||
n_jobs=10,
|
||||
GPU=0,
|
||||
seed=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
# set hyper-parameters.
|
||||
self.d_model = d_model
|
||||
|
||||
@@ -41,7 +41,7 @@ class TransformerModel(Model):
|
||||
n_jobs=10,
|
||||
GPU=0,
|
||||
seed=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
# set hyper-parameters.
|
||||
self.d_model = d_model
|
||||
|
||||
@@ -28,7 +28,7 @@ class XGBModel(Model, FeatureInt):
|
||||
verbose_eval=20,
|
||||
evals_result=dict(),
|
||||
reweighter=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
df_train, df_valid = dataset.prepare(
|
||||
["train", "valid"],
|
||||
@@ -63,7 +63,7 @@ class XGBModel(Model, FeatureInt):
|
||||
early_stopping_rounds=early_stopping_rounds,
|
||||
verbose_eval=verbose_eval,
|
||||
evals_result=evals_result,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
)
|
||||
evals_result["train"] = list(evals_result["train"].values())[0]
|
||||
evals_result["valid"] = list(evals_result["valid"].values())[0]
|
||||
|
||||
@@ -4,10 +4,10 @@
|
||||
# pylint: skip-file
|
||||
# flake8: noqa
|
||||
|
||||
import yaml
|
||||
import pathlib
|
||||
import pandas as pd
|
||||
import shutil
|
||||
from ruamel.yaml import YAML
|
||||
from ...backtest.account import Account
|
||||
from .user import User
|
||||
from .utils import load_instance, save_instance
|
||||
@@ -110,7 +110,8 @@ class UserManager:
|
||||
raise ValueError("User data for {} already exists".format(user_id))
|
||||
|
||||
with config_file.open("r") as fp:
|
||||
config = yaml.safe_load(fp)
|
||||
yaml = YAML(typ="safe", pure=True)
|
||||
config = yaml.load(fp)
|
||||
# load model
|
||||
model = init_instance_by_config(config["model"])
|
||||
|
||||
|
||||
@@ -6,8 +6,8 @@
|
||||
|
||||
import pathlib
|
||||
import pickle
|
||||
import yaml
|
||||
import pandas as pd
|
||||
from ruamel.yaml import YAML
|
||||
from ...data import D
|
||||
from ...config import C
|
||||
from ...log import get_module_logger
|
||||
@@ -91,7 +91,8 @@ def prepare(um, today, user_id, exchange_config=None):
|
||||
dates.append(get_next_trading_date(dates[-1], future=True))
|
||||
if exchange_config:
|
||||
with pathlib.Path(exchange_config).open("r") as fp:
|
||||
exchange_paras = yaml.safe_load(fp)
|
||||
yaml = YAML(typ="safe", pure=True)
|
||||
exchange_paras = yaml.load(fp)
|
||||
else:
|
||||
exchange_paras = {}
|
||||
trade_exchange = Exchange(trade_dates=dates, **exchange_paras)
|
||||
|
||||
@@ -1,5 +1,17 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
"""
|
||||
Here we have a comprehensive set of analysis classes.
|
||||
|
||||
Here is an example.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from qlib.contrib.report.data.ana import FeaMeanStd
|
||||
fa = FeaMeanStd(ret_df)
|
||||
fa.plot_all(wspace=0.3, sub_figsize=(12, 3), col_n=5)
|
||||
|
||||
"""
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from qlib.contrib.report.data.base import FeaAnalyser
|
||||
@@ -152,6 +164,7 @@ class FeaSkewTurt(NumFeaAnalyser):
|
||||
self._kurt[col].plot(ax=right_ax, label="kurt", color="green")
|
||||
right_ax.set_xlabel("")
|
||||
right_ax.set_ylabel("kurt")
|
||||
right_ax.grid(None) # set the grid to None to avoid two layer of grid
|
||||
|
||||
h1, l1 = ax.get_legend_handles_labels()
|
||||
h2, l2 = right_ax.get_legend_handles_labels()
|
||||
@@ -171,12 +184,15 @@ class FeaMeanStd(NumFeaAnalyser):
|
||||
ax.set_xlabel("")
|
||||
ax.set_ylabel("mean")
|
||||
ax.legend()
|
||||
ax.tick_params(axis="x", rotation=90)
|
||||
|
||||
right_ax = ax.twinx()
|
||||
|
||||
self._std[col].plot(ax=right_ax, label="std", color="green")
|
||||
right_ax.set_xlabel("")
|
||||
right_ax.set_ylabel("std")
|
||||
right_ax.tick_params(axis="x", rotation=90)
|
||||
right_ax.grid(None) # set the grid to None to avoid two layer of grid
|
||||
|
||||
h1, l1 = ax.get_legend_handles_labels()
|
||||
h2, l2 = right_ax.get_legend_handles_labels()
|
||||
|
||||
@@ -14,6 +14,24 @@ from qlib.contrib.report.utils import sub_fig_generator
|
||||
|
||||
class FeaAnalyser:
|
||||
def __init__(self, dataset: pd.DataFrame):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
dataset : pd.DataFrame
|
||||
|
||||
We often have multiple columns for dataset. Each column corresponds to one sub figure.
|
||||
There will be a datatime column in the index levels.
|
||||
Aggretation will be used for more summarized metrics overtime.
|
||||
Here is an example of data:
|
||||
|
||||
.. code-block::
|
||||
|
||||
return
|
||||
datetime instrument
|
||||
2007-02-06 equity_tpx 0.010087
|
||||
equity_spx 0.000786
|
||||
"""
|
||||
self._dataset = dataset
|
||||
with TimeInspector.logt("calc_stat_values"):
|
||||
self.calc_stat_values()
|
||||
|
||||
@@ -176,7 +176,7 @@ class HeatmapGraph(BaseGraph):
|
||||
x=self._df.columns,
|
||||
y=self._df.index,
|
||||
z=self._df.values.tolist(),
|
||||
**self._graph_kwargs
|
||||
**self._graph_kwargs,
|
||||
)
|
||||
]
|
||||
return _data
|
||||
@@ -213,7 +213,7 @@ class SubplotsGraph:
|
||||
sub_graph_layout: dict = None,
|
||||
sub_graph_data: list = None,
|
||||
subplots_kwargs: dict = None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
|
||||
@@ -355,7 +355,7 @@ class SubplotsGraph:
|
||||
df=self._df.loc[:, [column_name]],
|
||||
name_dict={column_name: temp_name},
|
||||
graph_kwargs=_graph_kwargs,
|
||||
)
|
||||
),
|
||||
)
|
||||
else:
|
||||
raise TypeError()
|
||||
|
||||
@@ -4,7 +4,7 @@ import matplotlib.pyplot as plt
|
||||
import pandas as pd
|
||||
|
||||
|
||||
def sub_fig_generator(sub_fs=(3, 3), col_n=10, row_n=1, wspace=None, hspace=None, sharex=False, sharey=False):
|
||||
def sub_fig_generator(sub_figsize=(3, 3), col_n=10, row_n=1, wspace=None, hspace=None, sharex=False, sharey=False):
|
||||
"""sub_fig_generator.
|
||||
it will return a generator, each row contains <col_n> sub graph
|
||||
|
||||
@@ -13,7 +13,7 @@ def sub_fig_generator(sub_fs=(3, 3), col_n=10, row_n=1, wspace=None, hspace=None
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sub_fs :
|
||||
sub_figsize :
|
||||
the figure size of each subgraph in <col_n> * <row_n> subgraphs
|
||||
col_n :
|
||||
the number of subgraph in each row; It will generating a new graph after generating <col_n> of subgraphs.
|
||||
@@ -33,7 +33,7 @@ def sub_fig_generator(sub_fs=(3, 3), col_n=10, row_n=1, wspace=None, hspace=None
|
||||
|
||||
while True:
|
||||
fig, axes = plt.subplots(
|
||||
row_n, col_n, figsize=(sub_fs[0] * col_n, sub_fs[1] * row_n), sharex=sharex, sharey=sharey
|
||||
row_n, col_n, figsize=(sub_figsize[0] * col_n, sub_figsize[1] * row_n), sharex=sharex, sharey=sharey
|
||||
)
|
||||
plt.subplots_adjust(wspace=wspace, hspace=hspace)
|
||||
axes = axes.reshape(row_n, col_n)
|
||||
|
||||
@@ -2,11 +2,11 @@
|
||||
# Licensed under the MIT License.
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
from ruamel.yaml import YAML
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import fire
|
||||
import pandas as pd
|
||||
import yaml
|
||||
|
||||
from qlib import auto_init
|
||||
from qlib.log import get_module_logger
|
||||
@@ -73,8 +73,8 @@ class Rolling:
|
||||
The horizon of the prediction target.
|
||||
This is used to override the prediction horizon of the file.
|
||||
h_path : Optional[str]
|
||||
the dumped data handler;
|
||||
It may come from other data source. It will override the data handler in the config.
|
||||
It is other data source that is dumped as a handler. It will override the data handler section in the config.
|
||||
If it is not given, it will create a customized cache for the handler when `enable_handler_cache=True`
|
||||
test_end : Optional[str]
|
||||
the test end for the data. It is typically used together with the handler
|
||||
You can do the same thing with task_ext_conf in a more complicated way
|
||||
@@ -117,9 +117,10 @@ class Rolling:
|
||||
|
||||
def _raw_conf(self) -> dict:
|
||||
with self.conf_path.open("r") as f:
|
||||
return yaml.safe_load(f)
|
||||
yaml = YAML(typ="safe", pure=True)
|
||||
return yaml.load(f)
|
||||
|
||||
def _replace_hanler_with_cache(self, task: dict):
|
||||
def _replace_handler_with_cache(self, task: dict):
|
||||
"""
|
||||
Due to the data processing part in original rolling is slow. So we have to
|
||||
This class tries to add more feature
|
||||
@@ -159,13 +160,20 @@ class Rolling:
|
||||
# - get horizon automatically from the expression!!!!
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
else:
|
||||
self.logger.info("The prediction horizon is overrided")
|
||||
task["dataset"]["kwargs"]["handler"]["kwargs"]["label"] = [
|
||||
"Ref($close, -{}) / Ref($close, -1) - 1".format(self.horizon + 1)
|
||||
]
|
||||
if enable_handler_cache and self.h_path is not None:
|
||||
self.logger.info("Fail to override the horizon due to data handler cache")
|
||||
else:
|
||||
self.logger.info("The prediction horizon is overrided")
|
||||
if isinstance(task["dataset"]["kwargs"]["handler"], dict):
|
||||
task["dataset"]["kwargs"]["handler"]["kwargs"]["label"] = [
|
||||
"Ref($close, -{}) / Ref($close, -1) - 1".format(self.horizon + 1)
|
||||
]
|
||||
else:
|
||||
self.logger.warning("Try to automatically configure the lablel but failed.")
|
||||
|
||||
if enable_handler_cache:
|
||||
task = self._replace_hanler_with_cache(task)
|
||||
if self.h_path is not None or enable_handler_cache:
|
||||
# if we already have provided data source or we want to create one
|
||||
task = self._replace_handler_with_cache(task)
|
||||
task = self._update_start_end_time(task)
|
||||
|
||||
if self.task_ext_conf is not None:
|
||||
@@ -173,6 +181,16 @@ class Rolling:
|
||||
self.logger.info(task)
|
||||
return task
|
||||
|
||||
def run_basic_task(self):
|
||||
"""
|
||||
Run the basic task without rolling.
|
||||
This is for fast testing for model tunning.
|
||||
"""
|
||||
task = self.basic_task()
|
||||
print(task)
|
||||
trainer = TrainerR(experiment_name=self.exp_name)
|
||||
trainer([task])
|
||||
|
||||
def get_task_list(self) -> List[dict]:
|
||||
"""return a batch of tasks for rolling."""
|
||||
task = self.basic_task()
|
||||
|
||||
@@ -80,6 +80,11 @@ class DDGDA(Rolling):
|
||||
sim_task_model: UTIL_MODEL_TYPE = "gbdt",
|
||||
meta_1st_train_end: Optional[str] = None,
|
||||
alpha: float = 0.01,
|
||||
loss_skip_thresh: int = 50,
|
||||
fea_imp_n: Optional[int] = 30,
|
||||
meta_data_proc: Optional[str] = "V01",
|
||||
segments: Union[float, str] = 0.62,
|
||||
hist_step_n: int = 30,
|
||||
working_dir: Optional[Union[str, Path]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
@@ -94,6 +99,15 @@ class DDGDA(Rolling):
|
||||
alpha: float
|
||||
Setting the L2 regularization for ridge
|
||||
The `alpha` is only passed to MetaModelDS (it is not passed to sim_task_model currently..)
|
||||
loss_skip_thresh: int
|
||||
The thresh to skip the loss calculation for each day. If the number of item is less than it, it will skip the loss on that day.
|
||||
meta_data_proc : Optional[str]
|
||||
How we process the meta dataset for learning meta model.
|
||||
segments : Union[float, str]
|
||||
if segments is a float:
|
||||
The ratio of training data in the meta task dataset
|
||||
if segments is a string:
|
||||
it will try its best to put its data in training and ensure that the date `segments` is in the test set
|
||||
"""
|
||||
# NOTE:
|
||||
# the horizon must match the meaning in the base task template
|
||||
@@ -104,14 +118,22 @@ class DDGDA(Rolling):
|
||||
super().__init__(**kwargs)
|
||||
self.working_dir = self.conf_path.parent if working_dir is None else Path(working_dir)
|
||||
self.proxy_hd = self.working_dir / "handler_proxy.pkl"
|
||||
self.fea_imp_n = fea_imp_n
|
||||
self.meta_data_proc = meta_data_proc
|
||||
self.loss_skip_thresh = loss_skip_thresh
|
||||
self.segments = segments
|
||||
self.hist_step_n = hist_step_n
|
||||
|
||||
def _adjust_task(self, task: dict, astype: UTIL_MODEL_TYPE):
|
||||
"""
|
||||
some task are use for special purpose.
|
||||
Base on the original task, we need to do some extra things.
|
||||
|
||||
For example:
|
||||
- GBDT for calculating feature importance
|
||||
- Linear or GBDT for calculating similarity
|
||||
- Datset (well processed) that aligned to Linear that for meta learning
|
||||
|
||||
So we may need to change the dataset and model for the special purpose and other settings remains the same.
|
||||
"""
|
||||
# NOTE: here is just for aligning with previous implementation
|
||||
# It is not necessary for the current implementation
|
||||
@@ -119,12 +141,16 @@ class DDGDA(Rolling):
|
||||
if astype == "gbdt":
|
||||
task["model"] = LGBM_MODEL
|
||||
if isinstance(handler, dict):
|
||||
# We don't need preprocessing when using GBDT model
|
||||
for k in ["infer_processors", "learn_processors"]:
|
||||
if k in handler.setdefault("kwargs", {}):
|
||||
handler["kwargs"].pop(k)
|
||||
elif astype == "linear":
|
||||
task["model"] = LINEAR_MODEL
|
||||
handler["kwargs"].update(PROC_ARGS)
|
||||
if isinstance(handler, dict):
|
||||
handler["kwargs"].update(PROC_ARGS)
|
||||
else:
|
||||
self.logger.warning("The handler can't be adjusted.")
|
||||
else:
|
||||
raise ValueError(f"astype not supported: {astype}")
|
||||
return task
|
||||
@@ -155,12 +181,15 @@ class DDGDA(Rolling):
|
||||
The meta model will be trained upon the proxy forecasting model.
|
||||
This dataset is for the proxy forecasting model.
|
||||
"""
|
||||
topk = 30
|
||||
fi = self._get_feature_importance()
|
||||
col_selected = fi.nlargest(topk)
|
||||
|
||||
# NOTE: adjusting to `self.sim_task_model` just for aligning with previous implementation.
|
||||
# In previous version. The data for proxy model is using sim_task_model's way for processing
|
||||
task = self._adjust_task(self.basic_task(enable_handler_cache=False), self.sim_task_model)
|
||||
task = replace_task_handler_with_cache(task, self.working_dir)
|
||||
# if self.meta_data_proc is not None:
|
||||
# else:
|
||||
# # Otherwise, we don't need futher processing
|
||||
# task = self.basic_task()
|
||||
|
||||
dataset = init_instance_by_config(task["dataset"])
|
||||
prep_ds = dataset.prepare(slice(None), col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
|
||||
@@ -168,12 +197,18 @@ class DDGDA(Rolling):
|
||||
feature_df = prep_ds["feature"]
|
||||
label_df = prep_ds["label"]
|
||||
|
||||
feature_selected = feature_df.loc[:, col_selected.index]
|
||||
if self.fea_imp_n is not None:
|
||||
fi = self._get_feature_importance()
|
||||
col_selected = fi.nlargest(self.fea_imp_n)
|
||||
feature_selected = feature_df.loc[:, col_selected.index]
|
||||
else:
|
||||
feature_selected = feature_df
|
||||
|
||||
feature_selected = feature_selected.groupby("datetime", group_keys=False).apply(
|
||||
lambda df: (df - df.mean()).div(df.std())
|
||||
)
|
||||
feature_selected = feature_selected.fillna(0.0)
|
||||
if self.meta_data_proc == "V01":
|
||||
feature_selected = feature_selected.groupby("datetime", group_keys=False).apply(
|
||||
lambda df: (df - df.mean()).div(df.std())
|
||||
)
|
||||
feature_selected = feature_selected.fillna(0.0)
|
||||
|
||||
df_all = {
|
||||
"label": label_df.reindex(feature_selected.index),
|
||||
@@ -223,7 +258,10 @@ class DDGDA(Rolling):
|
||||
# 1) leverage the simplified proxy forecasting model to train meta model.
|
||||
# - Only the dataset part is important, in current version of meta model will integrate the
|
||||
|
||||
# the train_start for training meta model does not necessarily align with final rolling
|
||||
# NOTE:
|
||||
# - The train_start for training meta model does not necessarily align with final rolling
|
||||
# But please select a right time to make sure the finnal rolling tasks are not leaked in the training data.
|
||||
# - The test_start is automatically aligned to the next day of test_end. Validation is ignored.
|
||||
train_start = "2008-01-01" if self.train_start is None else self.train_start
|
||||
train_end = "2010-12-31" if self.meta_1st_train_end is None else self.meta_1st_train_end
|
||||
test_start = (pd.Timestamp(train_end) + pd.Timedelta(days=1)).strftime("%Y-%m-%d")
|
||||
@@ -249,9 +287,9 @@ class DDGDA(Rolling):
|
||||
kwargs = dict(
|
||||
task_tpl=proxy_forecast_model_task,
|
||||
step=self.step,
|
||||
segments=0.62, # keep test period consistent with the dataset yaml
|
||||
segments=self.segments, # keep test period consistent with the dataset yaml
|
||||
trunc_days=1 + self.horizon,
|
||||
hist_step_n=30,
|
||||
hist_step_n=self.hist_step_n,
|
||||
fill_method=fill_method,
|
||||
rolling_ext_days=0,
|
||||
)
|
||||
@@ -268,7 +306,13 @@ class DDGDA(Rolling):
|
||||
with R.start(experiment_name=self.meta_exp_name):
|
||||
R.log_params(**kwargs)
|
||||
mm = MetaModelDS(
|
||||
step=self.step, hist_step_n=kwargs["hist_step_n"], lr=0.001, max_epoch=30, seed=43, alpha=self.alpha
|
||||
step=self.step,
|
||||
hist_step_n=kwargs["hist_step_n"],
|
||||
lr=0.001,
|
||||
max_epoch=30,
|
||||
seed=43,
|
||||
alpha=self.alpha,
|
||||
loss_skip_thresh=self.loss_skip_thresh,
|
||||
)
|
||||
mm.fit(md)
|
||||
R.save_objects(model=mm)
|
||||
|
||||
@@ -4,9 +4,9 @@
|
||||
# pylint: skip-file
|
||||
# flake8: noqa
|
||||
|
||||
import yaml
|
||||
import copy
|
||||
import os
|
||||
from ruamel.yaml import YAML
|
||||
|
||||
|
||||
class TunerConfigManager:
|
||||
@@ -16,7 +16,8 @@ class TunerConfigManager:
|
||||
self.config_path = config_path
|
||||
|
||||
with open(config_path) as fp:
|
||||
config = yaml.safe_load(fp)
|
||||
yaml = YAML(typ="safe", pure=True)
|
||||
config = yaml.load(fp)
|
||||
self.config = copy.deepcopy(config)
|
||||
|
||||
self.pipeline_ex_config = PipelineExperimentConfig(config.get("experiment", dict()), self)
|
||||
|
||||
@@ -35,7 +35,7 @@ class Client:
|
||||
def connect_server(self):
|
||||
"""Connect to server."""
|
||||
try:
|
||||
self.sio.connect("ws://" + self.server_host + ":" + str(self.server_port))
|
||||
self.sio.connect(f"ws://{self.server_host}:{self.server_port}")
|
||||
except socketio.exceptions.ConnectionError:
|
||||
self.logger.error("Cannot connect to server - check your network or server status")
|
||||
|
||||
|
||||
@@ -536,7 +536,6 @@ class DatasetProvider(abc.ABC):
|
||||
"""
|
||||
if len(fields) == 0:
|
||||
raise ValueError("fields cannot be empty")
|
||||
fields = fields.copy()
|
||||
column_names = [str(f) for f in fields]
|
||||
return column_names
|
||||
|
||||
@@ -617,7 +616,7 @@ class DatasetProvider(abc.ABC):
|
||||
|
||||
data = pd.DataFrame(obj)
|
||||
if not data.empty and not np.issubdtype(data.index.dtype, np.dtype("M")):
|
||||
# If the underlaying provides the data not in datatime formmat, we'll convert it into datetime format
|
||||
# If the underlaying provides the data not in datetime format, we'll convert it into datetime format
|
||||
_calendar = Cal.calendar(freq=freq)
|
||||
data.index = _calendar[data.index.values.astype(int)]
|
||||
data.index.names = ["datetime"]
|
||||
|
||||
@@ -403,7 +403,7 @@ class TSDataSampler:
|
||||
np.full((1, self.data_arr.shape[1]), np.nan, dtype=self.data_arr.dtype),
|
||||
axis=0,
|
||||
)
|
||||
self.nan_idx = -1 # The last line is all NaN
|
||||
self.nan_idx = len(self.data_arr) - 1 # The last line is all NaN; setting it to -1 can cause bug #1716
|
||||
|
||||
# the data type will be changed
|
||||
# The index of usable data is between start_idx and end_idx
|
||||
|
||||
@@ -7,7 +7,7 @@ from pathlib import Path
|
||||
import warnings
|
||||
import pandas as pd
|
||||
|
||||
from typing import Tuple, Union, List
|
||||
from typing import Tuple, Union, List, Dict
|
||||
|
||||
from qlib.data import D
|
||||
from qlib.utils import load_dataset, init_instance_by_config, time_to_slc_point
|
||||
@@ -41,6 +41,7 @@ class DataLoader(abc.ABC):
|
||||
----------
|
||||
instruments : str or dict
|
||||
it can either be the market name or the config file of instruments generated by InstrumentProvider.
|
||||
If the value of instruments is None, it means that no filtering is done.
|
||||
start_time : str
|
||||
start of the time range.
|
||||
end_time : str
|
||||
@@ -50,6 +51,11 @@ class DataLoader(abc.ABC):
|
||||
-------
|
||||
pd.DataFrame:
|
||||
data load from the under layer source
|
||||
|
||||
Raise
|
||||
-----
|
||||
KeyError:
|
||||
if the instruments filter is not supported, raise KeyError
|
||||
"""
|
||||
|
||||
|
||||
@@ -247,10 +253,14 @@ class StaticDataLoader(DataLoader, Serializable):
|
||||
|
||||
def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame:
|
||||
self._maybe_load_raw_data()
|
||||
|
||||
# 1) Filter by instruments
|
||||
if instruments is None:
|
||||
df = self._data
|
||||
else:
|
||||
df = self._data.loc(axis=0)[:, instruments]
|
||||
|
||||
# 2) Filter by Datetime
|
||||
if start_time is None and end_time is None:
|
||||
return df # NOTE: avoid copy by loc
|
||||
# pd.Timestamp(None) == NaT, use NaT as index can not fetch correct thing, so do not change None.
|
||||
@@ -275,6 +285,61 @@ class StaticDataLoader(DataLoader, Serializable):
|
||||
self._data = self._config
|
||||
|
||||
|
||||
class NestedDataLoader(DataLoader):
|
||||
"""
|
||||
We have multiple DataLoader, we can use this class to combine them.
|
||||
"""
|
||||
|
||||
def __init__(self, dataloader_l: List[Dict], join="left") -> None:
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
dataloader_l : list[dict]
|
||||
A list of dataloader, for exmaple
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
nd = NestedDataLoader(
|
||||
dataloader_l=[
|
||||
{
|
||||
"class": "qlib.contrib.data.loader.Alpha158DL",
|
||||
}, {
|
||||
"class": "qlib.contrib.data.loader.Alpha360DL",
|
||||
"kwargs": {
|
||||
"config": {
|
||||
"label": ( ["Ref($close, -2)/Ref($close, -1) - 1"], ["LABEL0"])
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
)
|
||||
join :
|
||||
it will pass to pd.concat when merging it.
|
||||
"""
|
||||
super().__init__()
|
||||
self.data_loader_l = [
|
||||
(dl if isinstance(dl, DataLoader) else init_instance_by_config(dl)) for dl in dataloader_l
|
||||
]
|
||||
self.join = join
|
||||
|
||||
def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame:
|
||||
df_full = None
|
||||
for dl in self.data_loader_l:
|
||||
try:
|
||||
df_current = dl.load(instruments, start_time, end_time)
|
||||
except KeyError:
|
||||
warnings.warn(
|
||||
"If the value of `instruments` cannot be processed, it will set instruments to None to get all the data."
|
||||
)
|
||||
df_current = dl.load(instruments=None, start_time=start_time, end_time=end_time)
|
||||
if df_full is None:
|
||||
df_full = df_current
|
||||
else:
|
||||
df_full = pd.merge(df_full, df_current, left_index=True, right_index=True, how=self.join)
|
||||
return df_full.sort_index(axis=1)
|
||||
|
||||
|
||||
class DataLoaderDH(DataLoader):
|
||||
"""DataLoaderDH
|
||||
DataLoader based on (D)ata (H)andler
|
||||
|
||||
@@ -104,15 +104,24 @@ class HashingStockStorage(BaseHandlerStorage):
|
||||
"""
|
||||
|
||||
stock_selector = slice(None)
|
||||
time_selector = slice(None) # by default not filter by time.
|
||||
|
||||
if level is None:
|
||||
# For directly applying.
|
||||
if isinstance(selector, tuple) and self.stock_level < len(selector):
|
||||
# full selector format
|
||||
stock_selector = selector[self.stock_level]
|
||||
time_selector = selector[1 - self.stock_level]
|
||||
elif isinstance(selector, (list, str)) and self.stock_level == 0:
|
||||
# only stock selector
|
||||
stock_selector = selector
|
||||
elif level in ("instrument", self.stock_level):
|
||||
if isinstance(selector, tuple):
|
||||
# NOTE: How could the stock level selector be a tuple?
|
||||
stock_selector = selector[0]
|
||||
raise TypeError(
|
||||
"I forget why would this case appear. But I think it does not make sense. So we raise a error for that case."
|
||||
)
|
||||
elif isinstance(selector, (list, str)):
|
||||
stock_selector = selector
|
||||
|
||||
@@ -120,7 +129,7 @@ class HashingStockStorage(BaseHandlerStorage):
|
||||
raise TypeError(f"stock selector must be type str|list, or slice(None), rather than {stock_selector}")
|
||||
|
||||
if stock_selector == slice(None):
|
||||
return self.hash_df
|
||||
return self.hash_df, time_selector
|
||||
|
||||
if isinstance(stock_selector, str):
|
||||
stock_selector = [stock_selector]
|
||||
@@ -129,7 +138,7 @@ class HashingStockStorage(BaseHandlerStorage):
|
||||
for each_stock in sorted(stock_selector):
|
||||
if each_stock in self.hash_df:
|
||||
select_dict[each_stock] = self.hash_df[each_stock]
|
||||
return select_dict
|
||||
return select_dict, time_selector
|
||||
|
||||
def fetch(
|
||||
self,
|
||||
@@ -138,10 +147,13 @@ class HashingStockStorage(BaseHandlerStorage):
|
||||
col_set: Union[str, List[str]] = DataHandler.CS_ALL,
|
||||
fetch_orig: bool = True,
|
||||
) -> pd.DataFrame:
|
||||
fetch_stock_df_list = list(self._fetch_hash_df_by_stock(selector=selector, level=level).values())
|
||||
fetch_stock_df_list, time_selector = self._fetch_hash_df_by_stock(selector=selector, level=level)
|
||||
fetch_stock_df_list = list(fetch_stock_df_list.values())
|
||||
for _index, stock_df in enumerate(fetch_stock_df_list):
|
||||
fetch_col_df = fetch_df_by_col(df=stock_df, col_set=col_set)
|
||||
fetch_index_df = fetch_df_by_index(df=fetch_col_df, selector=selector, level=level, fetch_orig=fetch_orig)
|
||||
fetch_index_df = fetch_df_by_index(
|
||||
df=fetch_col_df, selector=time_selector, level="datetime", fetch_orig=fetch_orig
|
||||
)
|
||||
fetch_stock_df_list[_index] = fetch_index_df
|
||||
if len(fetch_stock_df_list) == 0:
|
||||
index_names = ("instrument", "datetime") if self.stock_level == 0 else ("datetime", "instrument")
|
||||
|
||||
@@ -9,7 +9,7 @@ if TYPE_CHECKING:
|
||||
from qlib.data.dataset import DataHandler
|
||||
|
||||
|
||||
def get_level_index(df: pd.DataFrame, level=Union[str, int]) -> int:
|
||||
def get_level_index(df: pd.DataFrame, level: Union[str, int]) -> int:
|
||||
"""
|
||||
|
||||
get the level index of `df` given `level`
|
||||
|
||||
@@ -164,6 +164,7 @@ class SeriesDFilter(BaseDFilter):
|
||||
timestamp = []
|
||||
_lbool = None
|
||||
_ltime = None
|
||||
_cur_start = None
|
||||
for _ts, _bool in timestamp_series.items():
|
||||
# there is likely to be NAN when the filter series don't have the
|
||||
# bool value, so we just change the NAN into False
|
||||
|
||||
@@ -51,3 +51,6 @@ class MetaTask:
|
||||
Return the **processed** meta_info
|
||||
"""
|
||||
return self.meta_info
|
||||
|
||||
def __repr__(self):
|
||||
return f"MetaTask(task={self.task}, meta_info={self.meta_info})"
|
||||
|
||||
@@ -41,7 +41,7 @@ def _log_task_info(task_config: dict):
|
||||
|
||||
def _exe_task(task_config: dict):
|
||||
rec = R.get_recorder()
|
||||
# model & dataset initiation
|
||||
# model & dataset initialization
|
||||
model: Model = init_instance_by_config(task_config["model"], accept_types=Model)
|
||||
dataset: Dataset = init_instance_by_config(task_config["dataset"], accept_types=Dataset)
|
||||
reweighter: Reweighter = task_config.get("reweighter", None)
|
||||
|
||||
@@ -7,8 +7,7 @@ import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
from importlib import import_module
|
||||
|
||||
import yaml
|
||||
from ruamel.yaml import YAML
|
||||
|
||||
|
||||
DELETE_KEY = "_delete_"
|
||||
@@ -57,7 +56,8 @@ def parse_backtest_config(path: str) -> dict:
|
||||
del sys.modules[tmp_module_name]
|
||||
else:
|
||||
with open(tmp_config_file.name) as input_stream:
|
||||
config = yaml.safe_load(input_stream)
|
||||
yaml = YAML(typ="safe", pure=True)
|
||||
config = yaml.load(input_stream)
|
||||
|
||||
if "_base_" in config:
|
||||
base_file_name = config.pop("_base_")
|
||||
|
||||
@@ -8,12 +8,12 @@ import random
|
||||
import sys
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from ruamel.yaml import YAML
|
||||
from typing import cast, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
import yaml
|
||||
from qlib.backtest import Order
|
||||
from qlib.backtest.decision import OrderDir
|
||||
from qlib.constant import ONE_MIN
|
||||
@@ -263,6 +263,7 @@ if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
|
||||
with open(args.config_path, "r") as input_stream:
|
||||
config = yaml.safe_load(input_stream)
|
||||
yaml = YAML(typ="safe", pure=True)
|
||||
config = yaml.load(input_stream)
|
||||
|
||||
main(config, run_training=not args.no_training, run_backtest=args.run_backtest)
|
||||
|
||||
@@ -12,15 +12,11 @@ import datetime
|
||||
from tqdm import tqdm
|
||||
from pathlib import Path
|
||||
from loguru import logger
|
||||
from cryptography.fernet import Fernet
|
||||
from qlib.utils import exists_qlib_data
|
||||
|
||||
|
||||
class GetData:
|
||||
REMOTE_URL = "https://qlibpublic.blob.core.windows.net/data/default/stock_data"
|
||||
# "?" is not included in the token.
|
||||
TOKEN = b"gAAAAABkmDhojHc0VSCDdNK1MqmRzNLeDFXe5hy8obHpa6SDQh4de6nW5gtzuD-fa6O_WZb0yyqYOL7ndOfJX_751W3xN5YB4-n-P22jK-t6ucoZqhT70KPD0Lf0_P328QPJVZ1gDnjIdjhi2YLOcP4BFTHLNYO0mvzszR8TKm9iT5AKRvuysWnpi8bbYwGU9zAcJK3x9EPL43hOGtxliFHcPNGMBoJW4g_ercdhi0-Qgv5_JLsV-29_MV-_AhuaYvJuN2dEywBy"
|
||||
KEY = "EYcA8cgorA8X9OhyMwVfuFxn_1W3jGk6jCbs3L2oPoA="
|
||||
REMOTE_URL = "https://github.com/SunsetWolf/qlib_dataset/releases/download"
|
||||
|
||||
def __init__(self, delete_zip_file=False):
|
||||
"""
|
||||
@@ -33,9 +29,45 @@ class GetData:
|
||||
self.delete_zip_file = delete_zip_file
|
||||
|
||||
def merge_remote_url(self, file_name: str):
|
||||
fernet = Fernet(self.KEY)
|
||||
token = fernet.decrypt(self.TOKEN).decode()
|
||||
return f"{self.REMOTE_URL}/{file_name}?{token}"
|
||||
"""
|
||||
Generate download links.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
file_name: str
|
||||
The name of the file to be downloaded.
|
||||
The file name can be accompanied by a version number, (e.g.: v2/qlib_data_simple_cn_1d_latest.zip),
|
||||
if no version number is attached, it will be downloaded from v0 by default.
|
||||
"""
|
||||
return f"{self.REMOTE_URL}/{file_name}" if "/" in file_name else f"{self.REMOTE_URL}/v0/{file_name}"
|
||||
|
||||
def download(self, url: str, target_path: [Path, str]):
|
||||
"""
|
||||
Download a file from the specified url.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
url: str
|
||||
The url of the data.
|
||||
target_path: str
|
||||
The location where the data is saved, including the file name.
|
||||
"""
|
||||
file_name = str(target_path).rsplit("/", maxsplit=1)[-1]
|
||||
resp = requests.get(url, stream=True, timeout=60)
|
||||
resp.raise_for_status()
|
||||
if resp.status_code != 200:
|
||||
raise requests.exceptions.HTTPError()
|
||||
|
||||
chunk_size = 1024
|
||||
logger.warning(
|
||||
f"The data for the example is collected from Yahoo Finance. Please be aware that the quality of the data might not be perfect. (You can refer to the original data source: https://finance.yahoo.com/lookup.)"
|
||||
)
|
||||
logger.info(f"{os.path.basename(file_name)} downloading......")
|
||||
with tqdm(total=int(resp.headers.get("Content-Length", 0))) as p_bar:
|
||||
with target_path.open("wb") as fp:
|
||||
for chunk in resp.iter_content(chunk_size=chunk_size):
|
||||
fp.write(chunk)
|
||||
p_bar.update(chunk_size)
|
||||
|
||||
def download_data(self, file_name: str, target_dir: [Path, str], delete_old: bool = True):
|
||||
"""
|
||||
@@ -70,21 +102,7 @@ class GetData:
|
||||
target_path = target_dir.joinpath(_target_file_name)
|
||||
|
||||
url = self.merge_remote_url(file_name)
|
||||
resp = requests.get(url, stream=True, timeout=60)
|
||||
resp.raise_for_status()
|
||||
if resp.status_code != 200:
|
||||
raise requests.exceptions.HTTPError()
|
||||
|
||||
chunk_size = 1024
|
||||
logger.warning(
|
||||
f"The data for the example is collected from Yahoo Finance. Please be aware that the quality of the data might not be perfect. (You can refer to the original data source: https://finance.yahoo.com/lookup.)"
|
||||
)
|
||||
logger.info(f"{os.path.basename(file_name)} downloading......")
|
||||
with tqdm(total=int(resp.headers.get("Content-Length", 0))) as p_bar:
|
||||
with target_path.open("wb") as fp:
|
||||
for chunk in resp.iter_content(chunk_size=chunk_size):
|
||||
fp.write(chunk)
|
||||
p_bar.update(chunk_size)
|
||||
self.download(url=url, target_path=target_path)
|
||||
|
||||
self._unzip(target_path, target_dir, delete_old)
|
||||
if self.delete_zip_file:
|
||||
@@ -99,7 +117,9 @@ class GetData:
|
||||
return status
|
||||
|
||||
@staticmethod
|
||||
def _unzip(file_path: Path, target_dir: Path, delete_old: bool = True):
|
||||
def _unzip(file_path: [Path, str], target_dir: [Path, str], delete_old: bool = True):
|
||||
file_path = Path(file_path)
|
||||
target_dir = Path(target_dir)
|
||||
if delete_old:
|
||||
logger.warning(
|
||||
f"will delete the old qlib data directory(features, instruments, calendars, features_cache, dataset_cache): {target_dir}"
|
||||
|
||||
@@ -10,7 +10,6 @@ import os
|
||||
import re
|
||||
import copy
|
||||
import json
|
||||
import yaml
|
||||
import redis
|
||||
import bisect
|
||||
import struct
|
||||
@@ -25,7 +24,13 @@ import pandas as pd
|
||||
from pathlib import Path
|
||||
from typing import List, Union, Optional, Callable
|
||||
from packaging import version
|
||||
from .file import get_or_create_path, save_multiple_parts_file, unpack_archive_with_buffer, get_tmp_file_with_buffer
|
||||
from ruamel.yaml import YAML
|
||||
from .file import (
|
||||
get_or_create_path,
|
||||
save_multiple_parts_file,
|
||||
unpack_archive_with_buffer,
|
||||
get_tmp_file_with_buffer,
|
||||
)
|
||||
from ..config import C
|
||||
from ..log import get_module_logger, set_log_with_config
|
||||
|
||||
@@ -37,7 +42,12 @@ is_deprecated_lexsorted_pandas = version.parse(pd.__version__) > version.parse("
|
||||
#################### Server ####################
|
||||
def get_redis_connection():
|
||||
"""get redis connection instance."""
|
||||
return redis.StrictRedis(host=C.redis_host, port=C.redis_port, db=C.redis_task_db, password=C.redis_password)
|
||||
return redis.StrictRedis(
|
||||
host=C.redis_host,
|
||||
port=C.redis_port,
|
||||
db=C.redis_task_db,
|
||||
password=C.redis_password,
|
||||
)
|
||||
|
||||
|
||||
#################### Data ####################
|
||||
@@ -96,7 +106,14 @@ def get_period_offset(first_year, period, quarterly):
|
||||
return offset
|
||||
|
||||
|
||||
def read_period_data(index_path, data_path, period, cur_date_int: int, quarterly, last_period_index: int = None):
|
||||
def read_period_data(
|
||||
index_path,
|
||||
data_path,
|
||||
period,
|
||||
cur_date_int: int,
|
||||
quarterly,
|
||||
last_period_index: int = None,
|
||||
):
|
||||
"""
|
||||
At `cur_date`(e.g. 20190102), read the information at `period`(e.g. 201803).
|
||||
Only the updating info before cur_date or at cur_date will be used.
|
||||
@@ -227,12 +244,13 @@ def parse_config(config):
|
||||
if not isinstance(config, str):
|
||||
return config
|
||||
# Check whether config is file
|
||||
yaml = YAML(typ="safe", pure=True)
|
||||
if os.path.exists(config):
|
||||
with open(config, "r") as f:
|
||||
return yaml.safe_load(f)
|
||||
return yaml.load(f)
|
||||
# Check whether the str can be parsed
|
||||
try:
|
||||
return yaml.safe_load(config)
|
||||
return yaml.load(config)
|
||||
except BaseException as base_exp:
|
||||
raise ValueError("cannot parse config!") from base_exp
|
||||
|
||||
@@ -273,7 +291,10 @@ def parse_field(field):
|
||||
# \uff09 -> )
|
||||
chinese_punctuation_regex = r"\u3001\uff1a\uff08\uff09"
|
||||
for pattern, new in [
|
||||
(rf"\$\$([\w{chinese_punctuation_regex}]+)", r'PFeature("\1")'), # $$ must be before $
|
||||
(
|
||||
rf"\$\$([\w{chinese_punctuation_regex}]+)",
|
||||
r'PFeature("\1")',
|
||||
), # $$ must be before $
|
||||
(rf"\$([\w{chinese_punctuation_regex}]+)", r'Feature("\1")'),
|
||||
(r"(\w+\s*)\(", r"Operators.\1("),
|
||||
]: # Features # Operators
|
||||
@@ -383,7 +404,14 @@ def get_date_range(trading_date, left_shift=0, right_shift=0, future=False):
|
||||
return calendar
|
||||
|
||||
|
||||
def get_date_by_shift(trading_date, shift, future=False, clip_shift=True, freq="day", align: Optional[str] = None):
|
||||
def get_date_by_shift(
|
||||
trading_date,
|
||||
shift,
|
||||
future=False,
|
||||
clip_shift=True,
|
||||
freq="day",
|
||||
align: Optional[str] = None,
|
||||
):
|
||||
"""get trading date with shift bias will cur_date
|
||||
e.g. : shift == 1, return next trading date
|
||||
shift == -1, return previous trading date
|
||||
@@ -569,7 +597,38 @@ def exists_qlib_data(qlib_dir):
|
||||
# check instruments
|
||||
code_names = set(map(lambda x: fname_to_code(x.name.lower()), features_dir.iterdir()))
|
||||
_instrument = instruments_dir.joinpath("all.txt")
|
||||
miss_code = set(pd.read_csv(_instrument, sep="\t", header=None).loc[:, 0].apply(str.lower)) - set(code_names)
|
||||
# Removed two possible ticker names "NA" and "NULL" from the default na_values list for column 0
|
||||
miss_code = set(
|
||||
pd.read_csv(
|
||||
_instrument,
|
||||
sep="\t",
|
||||
header=None,
|
||||
keep_default_na=False,
|
||||
na_values={
|
||||
0: [
|
||||
" ",
|
||||
"#N/A",
|
||||
"#N/A N/A",
|
||||
"#NA",
|
||||
"-1.#IND",
|
||||
"-1.#QNAN",
|
||||
"-NaN",
|
||||
"-nan",
|
||||
"1.#IND",
|
||||
"1.#QNAN",
|
||||
"<NA>",
|
||||
"N/A",
|
||||
"NaN",
|
||||
"None",
|
||||
"n/a",
|
||||
"nan",
|
||||
"null ",
|
||||
]
|
||||
},
|
||||
)
|
||||
.loc[:, 0]
|
||||
.apply(str.lower)
|
||||
) - set(code_names)
|
||||
if miss_code and any(map(lambda x: "sht" not in x, miss_code)):
|
||||
return False
|
||||
|
||||
@@ -741,6 +800,7 @@ def fill_placeholder(config: dict, config_extend: dict):
|
||||
)
|
||||
return value
|
||||
|
||||
item_keys = None
|
||||
while top < tail:
|
||||
now_item = item_queue[top]
|
||||
top += 1
|
||||
|
||||
@@ -44,7 +44,7 @@ def concat(data_list: Union[SingleData], axis=0) -> MultiData:
|
||||
all_index_map = dict(zip(all_index, range(len(all_index))))
|
||||
|
||||
# concat all
|
||||
tmp_data = np.full((len(all_index), len(data_list)), np.NaN)
|
||||
tmp_data = np.full((len(all_index), len(data_list)), np.nan)
|
||||
for data_id, index_data in enumerate(data_list):
|
||||
assert isinstance(index_data, SingleData)
|
||||
now_data_map = [all_index_map[index] for index in index_data.index]
|
||||
@@ -64,7 +64,7 @@ def sum_by_index(data_list: Union[SingleData], new_index: list, fill_value=0) ->
|
||||
new_index : list
|
||||
the new_index of new SingleData.
|
||||
fill_value : float
|
||||
fill the missing values or replace np.NaN.
|
||||
fill the missing values or replace np.nan.
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -108,6 +108,12 @@ class Index:
|
||||
self.index_map = self.idx_list = np.arange(idx_list)
|
||||
self._is_sorted = True
|
||||
else:
|
||||
# Check if all elements in idx_list are of the same type
|
||||
if not all(isinstance(x, type(idx_list[0])) for x in idx_list):
|
||||
raise TypeError("All elements in idx_list must be of the same type")
|
||||
# Check if all elements in idx_list are of the same datetime64 precision
|
||||
if isinstance(idx_list[0], np.datetime64) and not all(x.dtype == idx_list[0].dtype for x in idx_list):
|
||||
raise TypeError("All elements in idx_list must be of the same datetime64 precision")
|
||||
self.idx_list = np.array(idx_list)
|
||||
# NOTE: only the first appearance is indexed
|
||||
self.index_map = dict(zip(self.idx_list, range(len(self))))
|
||||
@@ -131,7 +137,12 @@ class Index:
|
||||
if self.idx_list.dtype.type is np.datetime64:
|
||||
if isinstance(item, pd.Timestamp):
|
||||
# This happens often when creating index based on pandas.DatetimeIndex and query with pd.Timestamp
|
||||
return item.to_numpy()
|
||||
return item.to_numpy().astype(self.idx_list.dtype)
|
||||
elif isinstance(item, np.datetime64):
|
||||
# This happens often when creating index based on np.datetime64 and query with another precision
|
||||
return item.astype(self.idx_list.dtype)
|
||||
# NOTE: It is hard to consider every case at first.
|
||||
# We just try to cover part of cases to make it more user-friendly
|
||||
return item
|
||||
|
||||
def index(self, item) -> int:
|
||||
@@ -433,7 +444,7 @@ class IndexData(metaclass=index_data_ops_creator):
|
||||
return self.__class__(~self.data.astype(bool), *self.indices)
|
||||
|
||||
def abs(self):
|
||||
"""get the abs of data except np.NaN."""
|
||||
"""get the abs of data except np.nan."""
|
||||
tmp_data = np.absolute(self.data)
|
||||
return self.__class__(tmp_data, *self.indices)
|
||||
|
||||
@@ -555,8 +566,8 @@ class SingleData(IndexData):
|
||||
f"The indexes of self and other do not meet the requirements of the four arithmetic operations"
|
||||
)
|
||||
|
||||
def reindex(self, index: Index, fill_value=np.NaN) -> SingleData:
|
||||
"""reindex data and fill the missing value with np.NaN.
|
||||
def reindex(self, index: Index, fill_value=np.nan) -> SingleData:
|
||||
"""reindex data and fill the missing value with np.nan.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -604,7 +615,7 @@ class SingleData(IndexData):
|
||||
return pd.Series(self.data, index=self.index)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return str(pd.Series(self.data, index=self.index))
|
||||
return str(pd.Series(self.data, index=self.index.tolist()))
|
||||
|
||||
|
||||
class MultiData(IndexData):
|
||||
@@ -640,4 +651,4 @@ class MultiData(IndexData):
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return str(pd.DataFrame(self.data, index=self.index, columns=self.columns))
|
||||
return str(pd.DataFrame(self.data, index=self.index.tolist(), columns=self.columns.tolist()))
|
||||
|
||||
@@ -161,7 +161,13 @@ def init_instance_by_config(
|
||||
# path like 'file:///<path to pickle file>/obj.pkl'
|
||||
pr = urlparse(config)
|
||||
if pr.scheme == "file":
|
||||
pr_path = os.path.join(pr.netloc, pr.path) if bool(pr.path) else pr.netloc
|
||||
|
||||
# To enable relative path like file://data/a/b/c.pkl. pr.netloc will be data
|
||||
path = pr.path
|
||||
if pr.netloc != "":
|
||||
path = path.lstrip("/")
|
||||
|
||||
pr_path = os.path.join(pr.netloc, path) if bool(pr.path) else pr.netloc
|
||||
with open(os.path.normpath(pr_path), "rb") as f:
|
||||
return pickle.load(f)
|
||||
else:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import threading
|
||||
from functools import partial
|
||||
from threading import Thread
|
||||
from typing import Callable, Text, Union
|
||||
@@ -9,7 +10,7 @@ from joblib import Parallel, delayed
|
||||
from joblib._parallel_backends import MultiprocessingBackend
|
||||
import pandas as pd
|
||||
|
||||
from queue import Queue
|
||||
from queue import Empty, Queue
|
||||
import concurrent
|
||||
|
||||
from qlib.config import C, QlibConfig
|
||||
@@ -85,7 +86,17 @@ class AsyncCaller:
|
||||
|
||||
def run(self):
|
||||
while True:
|
||||
data = self._q.get()
|
||||
# NOTE:
|
||||
# atexit will only trigger when all the threads ended. So it may results in deadlock.
|
||||
# So the child-threading should actively watch the status of main threading to stop itself.
|
||||
main_thread = threading.main_thread()
|
||||
if not main_thread.is_alive():
|
||||
break
|
||||
try:
|
||||
data = self._q.get(timeout=1)
|
||||
except Empty:
|
||||
# NOTE: avoid deadlock. make checking main thread possible
|
||||
continue
|
||||
if data == self.STOP_MARK:
|
||||
break
|
||||
data()
|
||||
|
||||
@@ -1,18 +1,20 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
import logging
|
||||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
import sys
|
||||
|
||||
import fire
|
||||
from jinja2 import Template, meta
|
||||
from ruamel.yaml import YAML
|
||||
|
||||
import qlib
|
||||
import fire
|
||||
import ruamel.yaml as yaml
|
||||
from qlib.config import C
|
||||
from qlib.model.trainer import task_train
|
||||
from qlib.utils.data import update_config
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.model.trainer import task_train
|
||||
from qlib.utils import set_log_with_config
|
||||
from qlib.utils.data import update_config
|
||||
|
||||
set_log_with_config(C.logging_config)
|
||||
logger = get_module_logger("qrun", logging.INFO)
|
||||
@@ -47,6 +49,39 @@ def sys_config(config, config_path):
|
||||
sys.path.append(str(Path(config_path).parent.resolve().absolute() / p))
|
||||
|
||||
|
||||
def render_template(config_path: str) -> str:
|
||||
"""
|
||||
render the template based on the environment
|
||||
|
||||
Parameters
|
||||
----------
|
||||
config_path : str
|
||||
configuration path
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
the rendered content
|
||||
"""
|
||||
with open(config_path, "r") as f:
|
||||
config = f.read()
|
||||
# Set up the Jinja2 environment
|
||||
template = Template(config)
|
||||
|
||||
# Parse the template to find undeclared variables
|
||||
env = template.environment
|
||||
parsed_content = env.parse(config)
|
||||
variables = meta.find_undeclared_variables(parsed_content)
|
||||
|
||||
# Get context from os.environ according to the variables
|
||||
context = {var: os.getenv(var, "") for var in variables if var in os.environ}
|
||||
logger.info(f"Render the template with the context: {context}")
|
||||
|
||||
# Render the template with the context
|
||||
rendered_content = template.render(context)
|
||||
return rendered_content
|
||||
|
||||
|
||||
# workflow handler function
|
||||
def workflow(config_path, experiment_name="workflow", uri_folder="mlruns"):
|
||||
"""
|
||||
@@ -67,8 +102,10 @@ def workflow(config_path, experiment_name="workflow", uri_folder="mlruns"):
|
||||
market: csi300
|
||||
|
||||
"""
|
||||
with open(config_path) as fp:
|
||||
config = yaml.safe_load(fp)
|
||||
# Render the template
|
||||
rendered_yaml = render_template(config_path)
|
||||
yaml = YAML(typ="safe", pure=True)
|
||||
config = yaml.load(rendered_yaml)
|
||||
|
||||
base_config_path = config.get("BASE_CONFIG_PATH", None)
|
||||
if base_config_path:
|
||||
@@ -90,7 +127,8 @@ def workflow(config_path, experiment_name="workflow", uri_folder="mlruns"):
|
||||
raise FileNotFoundError(f"Can't find the BASE_CONFIG file: {base_config_path}")
|
||||
|
||||
with open(path) as fp:
|
||||
base_config = yaml.safe_load(fp)
|
||||
yaml = YAML(typ="safe", pure=True)
|
||||
base_config = yaml.load(fp)
|
||||
logger.info(f"Load BASE_CONFIG_PATH succeed: {path.resolve()}")
|
||||
config = update_config(base_config, config)
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ from mlflow.exceptions import MlflowException, RESOURCE_ALREADY_EXISTS, ErrorCod
|
||||
from mlflow.entities import ViewType
|
||||
import os
|
||||
from typing import Optional, Text
|
||||
from pathlib import Path
|
||||
|
||||
from .exp import MLflowExperiment, Experiment
|
||||
from ..config import C
|
||||
@@ -233,7 +234,7 @@ class ExpManager:
|
||||
# So we supported it in the interface wrapper
|
||||
pr = urlparse(self.uri)
|
||||
if pr.scheme == "file":
|
||||
with FileLock(os.path.join(pr.netloc, pr.path, "filelock")): # pylint: disable=E0110
|
||||
with FileLock(Path(os.path.join(pr.netloc, pr.path.lstrip("/"), "filelock"))): # pylint: disable=E0110
|
||||
return self.create_exp(experiment_name), True
|
||||
# NOTE: for other schemes like http, we double check to avoid create exp conflicts
|
||||
try:
|
||||
@@ -421,7 +422,11 @@ class MLflowExpManager(ExpManager):
|
||||
|
||||
def list_experiments(self):
|
||||
# retrieve all the existing experiments
|
||||
exps = self.client.list_experiments(view_type=ViewType.ACTIVE_ONLY)
|
||||
mlflow_version = int(mlflow.__version__.split(".", maxsplit=1)[0])
|
||||
if mlflow_version >= 2:
|
||||
exps = self.client.search_experiments(view_type=ViewType.ACTIVE_ONLY)
|
||||
else:
|
||||
exps = self.client.list_experiments(view_type=ViewType.ACTIVE_ONLY) # pylint: disable=E1101
|
||||
experiments = dict()
|
||||
for exp in exps:
|
||||
experiment = MLflowExperiment(exp.experiment_id, exp.name, self.uri)
|
||||
|
||||
@@ -9,6 +9,7 @@ import shutil
|
||||
import pickle
|
||||
import tempfile
|
||||
import subprocess
|
||||
import platform
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
|
||||
@@ -316,7 +317,10 @@ class MLflowRecorder(Recorder):
|
||||
This function will return the directory path of this recorder.
|
||||
"""
|
||||
if self.artifact_uri is not None:
|
||||
local_dir_path = Path(self.artifact_uri.lstrip("file:")) / ".."
|
||||
if platform.system() == "Windows":
|
||||
local_dir_path = Path(self.artifact_uri.lstrip("file:").lstrip("/")).parent
|
||||
else:
|
||||
local_dir_path = Path(self.artifact_uri.lstrip("file:")).parent
|
||||
local_dir_path = str(local_dir_path.resolve())
|
||||
if os.path.isdir(local_dir_path):
|
||||
return local_dir_path
|
||||
|
||||
@@ -242,7 +242,7 @@ class TimeAdjuster:
|
||||
|
||||
def shift(self, seg: tuple, step: int, rtype=SHIFT_SD) -> tuple:
|
||||
"""
|
||||
Shift the datatime of segment
|
||||
Shift the datetime of segment
|
||||
|
||||
If there are None (which indicates unbounded index) in the segment, this method will return None.
|
||||
|
||||
|
||||
@@ -301,6 +301,7 @@ class Normalize:
|
||||
na_values={col: symbol_na if col == self._symbol_field_name else default_na for col in columns},
|
||||
)
|
||||
|
||||
# NOTE: It has been reported that there may be some problems here, and the specific issues will be dealt with when they are identified.
|
||||
df = self._normalize_obj.normalize(df)
|
||||
if df is not None and not df.empty:
|
||||
if self._end_date is not None:
|
||||
|
||||
@@ -28,7 +28,7 @@ termcolor==1.1.0
|
||||
tqdm==4.63.0
|
||||
trio==0.20.0
|
||||
trio-websocket==0.9.2
|
||||
urllib3==1.26.8
|
||||
urllib3==1.26.19
|
||||
wget==3.2
|
||||
wsproto==1.1.0
|
||||
yahooquery==2.2.15
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user