mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-12 00:40:58 +08:00
Compare commits
18 Commits
ptnn4both
...
update_rea
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8ac25ff4bd | ||
|
|
b101006750 | ||
|
|
5a84aaf1dc | ||
|
|
afbb178e24 | ||
|
|
a0cef033cb | ||
|
|
7acb4f3484 | ||
|
|
431f574967 | ||
|
|
b604fe56b3 | ||
|
|
af4b8772d2 | ||
|
|
18fcdf1521 | ||
|
|
f2caf452e9 | ||
|
|
ca9f1861a4 | ||
|
|
b45b006ef2 | ||
|
|
82cf438401 | ||
|
|
9e635168c0 | ||
|
|
b7ace1a622 | ||
|
|
c9ed050ef0 | ||
|
|
2c33332dd6 |
8
.dockerignore
Normal file
8
.dockerignore
Normal file
@@ -0,0 +1,8 @@
|
||||
__pycache__
|
||||
*.pyc
|
||||
*.pyo
|
||||
*.pyd
|
||||
.Python
|
||||
.env
|
||||
.git
|
||||
|
||||
66
.github/workflows/python-publish.yml
vendored
66
.github/workflows/python-publish.yml
vendored
@@ -12,70 +12,54 @@ jobs:
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [windows-latest, macos-11]
|
||||
# FIXME: macos-latest will raise error now.
|
||||
# not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129
|
||||
python-version: [3.7, 3.8]
|
||||
os: [windows-latest, macos-13, macos-latest]
|
||||
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
|
||||
exclude:
|
||||
- os: macos-13
|
||||
python-version: "3.11"
|
||||
- os: macos-13
|
||||
python-version: "3.12"
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
# This is because on macos systems you can install pyqlib using
|
||||
# `pip install pyqlib` installs, it does not recognize the
|
||||
# `pyqlib-<version>-cp38-cp38-macosx_11_0_x86_64.whl` and `pyqlib-<veresion>-cp38-cp37m-macosx_11_0_x86_64.whl`.
|
||||
# So we limit the version of python, in order to generate a version of qlib that is usable for macos: `pyqlib-<veresion>-cp38-cp37m
|
||||
# `pyqlib-<version>-cp38-cp38-macosx_10_15_x86_64.whl` and `pyqlib-<veresion>-cp38-cp37m-macosx_10_15_x86_64.whl`.
|
||||
# Python 3.7.16, 3.8.16 can build macosx_10_15. But Python 3.7.17, 3.8.17 can build macosx_11_0
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
if: matrix.os == 'macos-11' && matrix.python-version == '3.7'
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: "3.7.16"
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
if: matrix.os == 'macos-11' && matrix.python-version == '3.8'
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: "3.8.16"
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
if: matrix.os != 'macos-11'
|
||||
uses: actions/setup-python@v2
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install setuptools wheel twine
|
||||
make dev
|
||||
- name: Build wheel on ${{ matrix.os }}
|
||||
run: |
|
||||
pip install numpy
|
||||
pip install cython
|
||||
python setup.py bdist_wheel
|
||||
- name: Build and publish
|
||||
make build
|
||||
- name: Upload to PyPi
|
||||
env:
|
||||
TWINE_USERNAME: __token__
|
||||
TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }}
|
||||
run: |
|
||||
twine upload dist/*
|
||||
twine check dist/*.whl
|
||||
twine upload dist/*.whl --verbose
|
||||
|
||||
deploy_with_manylinux:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Build wheel on Linux
|
||||
uses: RalfG/python-wheels-manylinux-build@v0.3.1-manylinux2010_x86_64
|
||||
uses: RalfG/python-wheels-manylinux-build@v0.7.1-manylinux2014_x86_64
|
||||
with:
|
||||
# not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129
|
||||
python-versions: 'cp37-cp37m cp38-cp38'
|
||||
python-versions: 'cp38-cp38 cp39-cp39 cp310-cp310 cp311-cp311 cp312-cp312'
|
||||
build-requirements: 'numpy cython'
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: 3.7
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip install twine
|
||||
- name: Build and publish
|
||||
python -m pip install twine
|
||||
- name: Upload to PyPi
|
||||
env:
|
||||
TWINE_USERNAME: __token__
|
||||
TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }}
|
||||
run: |
|
||||
twine upload dist/pyqlib-*-manylinux*.whl
|
||||
twine check dist/pyqlib-*-manylinux*.whl
|
||||
twine upload dist/pyqlib-*-manylinux*.whl --verbose
|
||||
|
||||
29
.github/workflows/test_qlib_from_pip.yml
vendored
29
.github/workflows/test_qlib_from_pip.yml
vendored
@@ -13,28 +13,17 @@ jobs:
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
# Since macos-latest changed from 12.7.4 to 14.4.1,
|
||||
# the minimum python version that matches a 14.4.1 version of macos is 3.10,
|
||||
# so we limit the macos version to macos-12.
|
||||
os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-11, macos-12]
|
||||
# not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129
|
||||
python-version: [3.7, 3.8]
|
||||
os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-13, macos-14, macos-15]
|
||||
# In github action, using python 3.7, pip install will not match the latest version of the package.
|
||||
# Also, python 3.7 is no longer supported from macos-14, and will be phased out from macos-13 in the near future.
|
||||
# All things considered, we have removed python 3.7.
|
||||
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
|
||||
|
||||
steps:
|
||||
- name: Test qlib from pip
|
||||
uses: actions/checkout@v3
|
||||
|
||||
# Since version 3.7 of python for MacOS is installed in CI, version 3.7.17, this version causes "_bz not found error".
|
||||
# So we make the version number of python 3.7 for MacOS more specific.
|
||||
# refs: https://github.com/actions/setup-python/issues/682
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
if: (matrix.os == 'macos-latest' && matrix.python-version == '3.7') || (matrix.os == 'macos-11' && matrix.python-version == '3.7')
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.7.16"
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
if: (matrix.os != 'macos-latest' || matrix.python-version != '3.7') && (matrix.os != 'macos-11' || matrix.python-version != '3.7')
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
@@ -45,13 +34,10 @@ jobs:
|
||||
|
||||
- name: Qlib installation test
|
||||
run: |
|
||||
# 2024-05-30 scs has released a new version: 3.2.4.post2,
|
||||
# This will cause the CI to fail, so we have limited the version of scs for now.
|
||||
python -m pip install "scs<=3.2.4"
|
||||
python -m pip install pyqlib
|
||||
|
||||
- name: Install Lightgbm for MacOS
|
||||
if: ${{ matrix.os == 'macos-11' || matrix.os == 'macos-latest' }}
|
||||
if: ${{ matrix.os == 'macos-13' || matrix.os == 'macos-14' || matrix.os == 'macos-15' }}
|
||||
run: |
|
||||
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Microsoft/qlib/main/.github/brew_install.sh)"
|
||||
HOMEBREW_NO_AUTO_UPDATE=1 brew install lightgbm
|
||||
@@ -68,8 +54,5 @@ jobs:
|
||||
cd qlib
|
||||
|
||||
- name: Test workflow by config
|
||||
# On macos-11 system, it will lead to "Segmentation fault: 11" error,
|
||||
# which may be caused by the excessive memory overhead of macos-11 system, so we disable macos-11 temporarily here.
|
||||
if: ${{ matrix.os != 'macos-11' }}
|
||||
run: |
|
||||
qrun examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
|
||||
102
.github/workflows/test_qlib_from_source.yml
vendored
102
.github/workflows/test_qlib_from_source.yml
vendored
@@ -14,28 +14,17 @@ jobs:
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
# Since macos-latest changed from 12.7.4 to 14.4.1,
|
||||
# the minimum python version that matches a 14.4.1 version of macos is 3.10,
|
||||
# so we limit the macos version to macos-12.
|
||||
os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-11, macos-12]
|
||||
# not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129
|
||||
python-version: [3.7, 3.8]
|
||||
os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-13, macos-14, macos-15]
|
||||
# In github action, using python 3.7, pip install will not match the latest version of the package.
|
||||
# Also, python 3.7 is no longer supported from macos-14, and will be phased out from macos-13 in the near future.
|
||||
# All things considered, we have removed python 3.7.
|
||||
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
|
||||
|
||||
steps:
|
||||
- name: Test qlib from source
|
||||
uses: actions/checkout@v3
|
||||
|
||||
# Since version 3.7 of python for MacOS is installed in CI, version 3.7.17, this version causes "_bz not found error".
|
||||
# So we make the version number of python 3.7 for MacOS more specific.
|
||||
# refs: https://github.com/actions/setup-python/issues/682
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
if: (matrix.os == 'macos-latest' && matrix.python-version == '3.7') || (matrix.os == 'macos-11' && matrix.python-version == '3.7')
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.7.16"
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
if: (matrix.os != 'macos-latest' || matrix.python-version != '3.7') && (matrix.os != 'macos-11' || matrix.python-version != '3.7')
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
@@ -45,7 +34,7 @@ jobs:
|
||||
python -m pip install --upgrade pip
|
||||
|
||||
- name: Installing pytorch for macos
|
||||
if: ${{ matrix.os == 'macos-11' || matrix.os == 'macos-latest' }}
|
||||
if: ${{ matrix.os == 'macos-13' || matrix.os == 'macos-14' || matrix.os == 'macos-15' }}
|
||||
run: |
|
||||
python -m pip install torch torchvision torchaudio
|
||||
|
||||
@@ -61,87 +50,33 @@ jobs:
|
||||
|
||||
- name: Set up Python tools
|
||||
run: |
|
||||
python -m pip install --upgrade cython
|
||||
python -m pip install -e .[dev]
|
||||
make dev
|
||||
|
||||
- name: Lint with Black
|
||||
# Python 3.7 will use a black with low level. So we use python with higher version for black check
|
||||
if: (matrix.python-version != '3.7')
|
||||
run: |
|
||||
pip install -U black # follow the latest version of black, previous Qlib dependency will downgrade black
|
||||
black . -l 120 --check --diff
|
||||
make black
|
||||
|
||||
- name: Make html with sphinx
|
||||
# Since read the docs builds on ubuntu 22.04, we only need to test that the build passes on ubuntu 22.04.
|
||||
if: ${{ matrix.os == 'ubuntu-22.04' }}
|
||||
run: |
|
||||
cd docs
|
||||
sphinx-build -W --keep-going -b html . _build
|
||||
cd ..
|
||||
make docs-gen
|
||||
|
||||
# Check Qlib with pylint
|
||||
# TODO: These problems we will solve in the future. Important among them are: W0221, W0223, W0237, E1102
|
||||
# C0103: invalid-name
|
||||
# C0209: consider-using-f-string
|
||||
# R0402: consider-using-from-import
|
||||
# R1705: no-else-return
|
||||
# R1710: inconsistent-return-statements
|
||||
# R1725: super-with-arguments
|
||||
# R1735: use-dict-literal
|
||||
# W0102: dangerous-default-value
|
||||
# W0212: protected-access
|
||||
# W0221: arguments-differ
|
||||
# W0223: abstract-method
|
||||
# W0231: super-init-not-called
|
||||
# W0237: arguments-renamed
|
||||
# W0612: unused-variable
|
||||
# W0621: redefined-outer-name
|
||||
# W0622: redefined-builtin
|
||||
# FIXME: specify exception type
|
||||
# W0703: broad-except
|
||||
# W1309: f-string-without-interpolation
|
||||
# E1102: not-callable
|
||||
# E1136: unsubscriptable-object
|
||||
# References for parameters: https://github.com/PyCQA/pylint/issues/4577#issuecomment-1000245962
|
||||
# We use sys.setrecursionlimit(2000) to make the recursion depth larger to ensure that pylint works properly (the default recursion depth is 1000).
|
||||
- name: Check Qlib with pylint
|
||||
run: |
|
||||
pylint --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0612,W0621,W0622,W0703,W1309,E1102,E1136 --const-rgx='[a-z_][a-z0-9_]{2,30}$' qlib --init-hook "import astroid; astroid.context.InferenceContext.max_inferred = 500; import sys; sys.setrecursionlimit(2000)"
|
||||
pylint --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0246,W0612,W0621,W0622,W0703,W1309,E1102,E1136 --const-rgx='[a-z_][a-z0-9_]{2,30}$' scripts --init-hook "import astroid; astroid.context.InferenceContext.max_inferred = 500; import sys; sys.setrecursionlimit(2000)"
|
||||
make pylint
|
||||
|
||||
# The following flake8 error codes were ignored:
|
||||
# E501 line too long
|
||||
# Description: We have used black to limit the length of each line to 120.
|
||||
# F541 f-string is missing placeholders
|
||||
# Description: The same thing is done when using pylint for detection.
|
||||
# E266 too many leading '#' for block comment
|
||||
# Description: To make the code more readable, a lot of "#" is used.
|
||||
# This error code appears centrally in:
|
||||
# qlib/backtest/executor.py
|
||||
# qlib/data/ops.py
|
||||
# qlib/utils/__init__.py
|
||||
# E402 module level import not at top of file
|
||||
# Description: There are times when module level import is not available at the top of the file.
|
||||
# W503 line break before binary operator
|
||||
# Description: Since black formats the length of each line of code, it has to perform a line break when a line of arithmetic is too long.
|
||||
# E731 do not assign a lambda expression, use a def
|
||||
# Description: Restricts the use of lambda expressions, but at some point lambda expressions are required.
|
||||
# E203 whitespace before ':'
|
||||
# Description: If there is whitespace before ":", it cannot pass the black check.
|
||||
- name: Check Qlib with flake8
|
||||
run: |
|
||||
flake8 --ignore=E501,F541,E266,E402,W503,E731,E203 --per-file-ignores="__init__.py:F401,F403" qlib
|
||||
make flake8
|
||||
|
||||
# https://github.com/python/mypy/issues/10600
|
||||
- name: Check Qlib with mypy
|
||||
run: |
|
||||
mypy qlib --install-types --non-interactive || true
|
||||
mypy qlib --verbose
|
||||
make mypy
|
||||
|
||||
- name: Check Qlib ipynb with nbqa
|
||||
run: |
|
||||
nbqa black . -l 120 --check --diff
|
||||
nbqa pylint . --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0612,W0621,W0622,W0703,W1309,E1102,E1136,W0719,W0104,W0404,C0412,W0611,C0410 --const-rgx='[a-z_][a-z0-9_]{2,30}$'
|
||||
make nbqa
|
||||
|
||||
- name: Test data downloads
|
||||
run: |
|
||||
@@ -149,7 +84,7 @@ jobs:
|
||||
python scripts/get_data.py download_data --file_name rl_data.zip --target_dir tests/.data/rl
|
||||
|
||||
- name: Install Lightgbm for MacOS
|
||||
if: ${{ matrix.os == 'macos-11' || matrix.os == 'macos-latest' }}
|
||||
if: ${{ matrix.os == 'macos-13' || matrix.os == 'macos-14' || matrix.os == 'macos-15' }}
|
||||
run: |
|
||||
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Microsoft/qlib/main/.github/brew_install.sh)"
|
||||
HOMEBREW_NO_AUTO_UPDATE=1 brew install lightgbm
|
||||
@@ -159,18 +94,11 @@ jobs:
|
||||
brew unlink libomp
|
||||
brew install libomp.rb
|
||||
|
||||
# Run after data downloads
|
||||
- name: Check Qlib ipynb with nbconvert
|
||||
# Running the nbconvert check on a macos-11 system results in a "Kernel died" error, so we've temporarily disabled macos-11 here.
|
||||
if: ${{ matrix.os != 'macos-11' }}
|
||||
run: |
|
||||
# add more ipynb files in future
|
||||
jupyter nbconvert --to notebook --execute examples/workflow_by_code.ipynb
|
||||
make nbconvert
|
||||
|
||||
- name: Test workflow by config (install from source)
|
||||
# On macos-11 system, it will lead to "Segmentation fault: 11" error,
|
||||
# which may be caused by the excessive memory overhead of macos-11 system, so we disable macos-11 temporarily here.
|
||||
if: ${{ matrix.os != 'macos-11' }}
|
||||
run: |
|
||||
python -m pip install numba
|
||||
python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
|
||||
27
.github/workflows/test_qlib_from_source_slow.yml
vendored
27
.github/workflows/test_qlib_from_source_slow.yml
vendored
@@ -14,44 +14,31 @@ jobs:
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
# Since macos-latest changed from 12.7.4 to 14.4.1,
|
||||
# the minimum python version that matches a 14.4.1 version of macos is 3.10,
|
||||
# so we limit the macos version to macos-12.
|
||||
os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-11, macos-12]
|
||||
# not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129
|
||||
python-version: [3.7, 3.8]
|
||||
os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-13, macos-14, macos-15]
|
||||
# In github action, using python 3.7, pip install will not match the latest version of the package.
|
||||
# Also, python 3.7 is no longer supported from macos-14, and will be phased out from macos-13 in the near future.
|
||||
# All things considered, we have removed python 3.7.
|
||||
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
|
||||
|
||||
steps:
|
||||
- name: Test qlib from source slow
|
||||
uses: actions/checkout@v3
|
||||
|
||||
# Since version 3.7 of python for MacOS is installed in CI, version 3.7.17, this version causes "_bz not found error".
|
||||
# So we make the version number of python 3.7 for MacOS more specific.
|
||||
# refs: https://github.com/actions/setup-python/issues/682
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
if: (matrix.os == 'macos-latest' && matrix.python-version == '3.7') || (matrix.os == 'macos-11' && matrix.python-version == '3.7')
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.7.16"
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
if: (matrix.os != 'macos-latest' || matrix.python-version != '3.7') && (matrix.os != 'macos-11' || matrix.python-version != '3.7')
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Set up Python tools
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install --upgrade cython numpy
|
||||
pip install -e .[dev]
|
||||
make dev
|
||||
|
||||
- name: Downloads dependencies data
|
||||
run: |
|
||||
python scripts/get_data.py qlib_data --name qlib_data_simple --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
|
||||
|
||||
- name: Install Lightgbm for MacOS
|
||||
if: ${{ matrix.os == 'macos-11' || matrix.os == 'macos-latest' }}
|
||||
if: ${{ matrix.os == 'macos-13' || matrix.os == 'macos-14' || matrix.os == 'macos-15' }}
|
||||
run: |
|
||||
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Microsoft/qlib/main/.github/brew_install.sh)"
|
||||
HOMEBREW_NO_AUTO_UPDATE=1 brew install lightgbm
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -48,4 +48,5 @@ tags
|
||||
*.swp
|
||||
|
||||
./pretrain
|
||||
.idea/
|
||||
.idea/
|
||||
.aider*
|
||||
|
||||
@@ -9,7 +9,7 @@ version: 2
|
||||
build:
|
||||
os: ubuntu-22.04
|
||||
tools:
|
||||
python: "3.7"
|
||||
python: "3.8"
|
||||
|
||||
# Build documentation in the docs/ directory with Sphinx
|
||||
sphinx:
|
||||
|
||||
31
Dockerfile
Normal file
31
Dockerfile
Normal file
@@ -0,0 +1,31 @@
|
||||
FROM continuumio/miniconda3:latest
|
||||
|
||||
WORKDIR /qlib
|
||||
|
||||
COPY . .
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y build-essential
|
||||
|
||||
RUN conda create --name qlib_env python=3.8 -y
|
||||
RUN echo "conda activate qlib_env" >> ~/.bashrc
|
||||
ENV PATH /opt/conda/envs/qlib_env/bin:$PATH
|
||||
|
||||
RUN python -m pip install --upgrade pip
|
||||
|
||||
RUN python -m pip install numpy==1.23.5
|
||||
RUN python -m pip install pandas==1.5.3
|
||||
RUN python -m pip install importlib-metadata==5.2.0
|
||||
RUN python -m pip install "cloudpickle<3"
|
||||
RUN python -m pip install scikit-learn==1.3.2
|
||||
|
||||
RUN python -m pip install cython packaging tables matplotlib statsmodels
|
||||
RUN python -m pip install pybind11 cvxpy
|
||||
|
||||
ARG IS_STABLE="yes"
|
||||
|
||||
RUN if [ "$IS_STABLE" = "yes" ]; then \
|
||||
python -m pip install pyqlib; \
|
||||
else \
|
||||
python setup.py install; \
|
||||
fi
|
||||
@@ -1 +1,6 @@
|
||||
include qlib/VERSION.txt
|
||||
exclude tests/*
|
||||
include qlib/*
|
||||
include qlib/*/*
|
||||
include qlib/*/*/*
|
||||
include qlib/*/*/*/*
|
||||
include qlib/*/*/*/*/*
|
||||
|
||||
195
Makefile
Normal file
195
Makefile
Normal file
@@ -0,0 +1,195 @@
|
||||
.PHONY: clean deepclean prerequisite dependencies lightgbm rl develop lint docs package test analysis all install dev black pylint flake8 mypy nbqa nbconvert lint build upload docs-gen
|
||||
#You can modify it according to your terminal
|
||||
SHELL := /bin/bash
|
||||
|
||||
########################################################################################
|
||||
# Variables
|
||||
########################################################################################
|
||||
|
||||
# Documentation target directory, will be adapted to specific folder for readthedocs.
|
||||
PUBLIC_DIR := $(shell [ "$$READTHEDOCS" = "True" ] && echo "$$READTHEDOCS_OUTPUT/html" || echo "public")
|
||||
|
||||
SO_DIR := qlib/data/_libs
|
||||
SO_FILES := $(wildcard $(SO_DIR)/*.so)
|
||||
|
||||
########################################################################################
|
||||
# Development Environment Management
|
||||
########################################################################################
|
||||
# Remove common intermediate files.
|
||||
clean:
|
||||
-rm -rf \
|
||||
$(PUBLIC_DIR) \
|
||||
qlib/data/_libs/*.cpp \
|
||||
qlib/data/_libs/*.so \
|
||||
mlruns \
|
||||
public \
|
||||
build \
|
||||
.coverage \
|
||||
.mypy_cache \
|
||||
.pytest_cache \
|
||||
.ruff_cache \
|
||||
Pipfile* \
|
||||
coverage.xml \
|
||||
dist \
|
||||
release-notes.md
|
||||
|
||||
find . -name '*.egg-info' -print0 | xargs -0 rm -rf
|
||||
find . -name '*.pyc' -print0 | xargs -0 rm -f
|
||||
find . -name '*.swp' -print0 | xargs -0 rm -f
|
||||
find . -name '.DS_Store' -print0 | xargs -0 rm -f
|
||||
find . -name '__pycache__' -print0 | xargs -0 rm -rf
|
||||
|
||||
# Remove pre-commit hook, virtual environment alongside itermediate files.
|
||||
deepclean: clean
|
||||
if command -v pre-commit > /dev/null 2>&1; then pre-commit uninstall --hook-type pre-push; fi
|
||||
if command -v pipenv >/dev/null 2>&1 && pipenv --venv >/dev/null 2>&1; then pipenv --rm; fi
|
||||
|
||||
# Prerequisite section
|
||||
# What this code does is compile two Cython modules, rolling and expanding, using setuptools and Cython,
|
||||
# and builds them as binary expansion modules that can be imported directly into Python.
|
||||
# Since pyproject.toml can't do that, we compile it here.
|
||||
prerequisite:
|
||||
@if [ -n "$(SO_FILES)" ]; then \
|
||||
echo "Shared library files exist, skipping build."; \
|
||||
else \
|
||||
echo "No shared library files found, building..."; \
|
||||
pip install --upgrade setuptools wheel; \
|
||||
python -m pip install cython numpy; \
|
||||
python -c "from setuptools import setup, Extension; from Cython.Build import cythonize; import numpy; extensions = [Extension('qlib.data._libs.rolling', ['qlib/data/_libs/rolling.pyx'], language='c++', include_dirs=[numpy.get_include()]), Extension('qlib.data._libs.expanding', ['qlib/data/_libs/expanding.pyx'], language='c++', include_dirs=[numpy.get_include()])]; setup(ext_modules=cythonize(extensions, language_level='3'), script_args=['build_ext', '--inplace'])"; \
|
||||
fi
|
||||
|
||||
# Install the package in editable mode.
|
||||
dependencies:
|
||||
python -m pip install -e .
|
||||
|
||||
lightgbm:
|
||||
python -m pip install lightgbm --prefer-binary
|
||||
|
||||
rl:
|
||||
python -m pip install -e .[rl]
|
||||
|
||||
develop:
|
||||
python -m pip install -e .[dev]
|
||||
|
||||
lint:
|
||||
python -m pip install -e .[lint]
|
||||
|
||||
docs:
|
||||
python -m pip install -e .[docs]
|
||||
|
||||
package:
|
||||
python -m pip install -e .[package]
|
||||
|
||||
test:
|
||||
python -m pip install -e .[test]
|
||||
|
||||
analysis:
|
||||
python -m pip install -e .[analysis]
|
||||
|
||||
all:
|
||||
python -m pip install -e .[dev,lint,docs,package,test,analysis,rl]
|
||||
|
||||
install: prerequisite dependencies
|
||||
|
||||
dev: prerequisite all
|
||||
|
||||
########################################################################################
|
||||
# Lint and pre-commit
|
||||
########################################################################################
|
||||
|
||||
# Check lint with black.
|
||||
black:
|
||||
black . -l 120 --check --diff
|
||||
|
||||
# Check code folder with pylint.
|
||||
# TODO: These problems we will solve in the future. Important among them are: W0221, W0223, W0237, E1102
|
||||
# C0103: invalid-name
|
||||
# C0209: consider-using-f-string
|
||||
# R0402: consider-using-from-import
|
||||
# R1705: no-else-return
|
||||
# R1710: inconsistent-return-statements
|
||||
# R1725: super-with-arguments
|
||||
# R1735: use-dict-literal
|
||||
# W0102: dangerous-default-value
|
||||
# W0212: protected-access
|
||||
# W0221: arguments-differ
|
||||
# W0223: abstract-method
|
||||
# W0231: super-init-not-called
|
||||
# W0237: arguments-renamed
|
||||
# W0612: unused-variable
|
||||
# W0621: redefined-outer-name
|
||||
# W0622: redefined-builtin
|
||||
# FIXME: specify exception type
|
||||
# W0703: broad-except
|
||||
# W1309: f-string-without-interpolation
|
||||
# E1102: not-callable
|
||||
# E1136: unsubscriptable-object
|
||||
# W4904: deprecated-class
|
||||
# R0917: too-many-positional-arguments
|
||||
# E1123: unexpected-keyword-arg
|
||||
# References for disable error: https://pylint.pycqa.org/en/latest/user_guide/messages/messages_overview.html
|
||||
# We use sys.setrecursionlimit(2000) to make the recursion depth larger to ensure that pylint works properly (the default recursion depth is 1000).
|
||||
# References for parameters: https://github.com/PyCQA/pylint/issues/4577#issuecomment-1000245962
|
||||
pylint:
|
||||
pylint --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R0917,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,W4904,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1730,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0612,W0621,W0622,W0703,W1309,E1102,E1136 --const-rgx='[a-z_][a-z0-9_]{2,30}' qlib --init-hook="import astroid; astroid.context.InferenceContext.max_inferred = 500; import sys; sys.setrecursionlimit(2000)"
|
||||
pylint --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R0917,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,E1123,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0246,W0612,W0621,W0622,W0703,W1309,E1102,E1136 --const-rgx='[a-z_][a-z0-9_]{2,30}' scripts --init-hook="import astroid; astroid.context.InferenceContext.max_inferred = 500; import sys; sys.setrecursionlimit(2000)"
|
||||
|
||||
# Check code with flake8.
|
||||
# The following flake8 error codes were ignored:
|
||||
# E501 line too long
|
||||
# Description: We have used black to limit the length of each line to 120.
|
||||
# F541 f-string is missing placeholders
|
||||
# Description: The same thing is done when using pylint for detection.
|
||||
# E266 too many leading '#' for block comment
|
||||
# Description: To make the code more readable, a lot of "#" is used.
|
||||
# This error code appears centrally in:
|
||||
# qlib/backtest/executor.py
|
||||
# qlib/data/ops.py
|
||||
# qlib/utils/__init__.py
|
||||
# E402 module level import not at top of file
|
||||
# Description: There are times when module level import is not available at the top of the file.
|
||||
# W503 line break before binary operator
|
||||
# Description: Since black formats the length of each line of code, it has to perform a line break when a line of arithmetic is too long.
|
||||
# E731 do not assign a lambda expression, use a def
|
||||
# Description: Restricts the use of lambda expressions, but at some point lambda expressions are required.
|
||||
# E203 whitespace before ':'
|
||||
# Description: If there is whitespace before ":", it cannot pass the black check.
|
||||
flake8:
|
||||
flake8 --ignore=E501,F541,E266,E402,W503,E731,E203 --per-file-ignores="__init__.py:F401,F403" qlib
|
||||
|
||||
# Check code with mypy.
|
||||
# https://github.com/python/mypy/issues/10600
|
||||
mypy:
|
||||
mypy qlib --install-types --non-interactive
|
||||
mypy qlib --verbose
|
||||
|
||||
# Check ipynb with nbqa.
|
||||
nbqa:
|
||||
nbqa black . -l 120 --check --diff
|
||||
nbqa pylint . --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0612,W0621,W0622,W0703,W1309,E1102,E1136,W0719,W0104,W0404,C0412,W0611,C0410 --const-rgx='[a-z_][a-z0-9_]{2,30}'
|
||||
|
||||
# Check ipynb with nbconvert.(Run after data downloads)
|
||||
# TODO: Add more ipynb files in future
|
||||
nbconvert:
|
||||
jupyter nbconvert --to notebook --execute examples/workflow_by_code.ipynb
|
||||
|
||||
lint: black pylint flake8 mypy nbqa
|
||||
|
||||
########################################################################################
|
||||
# Package
|
||||
########################################################################################
|
||||
|
||||
# Build the package.
|
||||
build:
|
||||
python -m build --wheel
|
||||
|
||||
# Upload the package.
|
||||
upload:
|
||||
python -m twine upload dist/*
|
||||
|
||||
########################################################################################
|
||||
# Documentation
|
||||
########################################################################################
|
||||
|
||||
docs-gen:
|
||||
python -m sphinx.cmd.build -W docs $(PUBLIC_DIR)
|
||||
95
README.md
95
README.md
@@ -8,9 +8,30 @@
|
||||
[](https://gitter.im/Microsoft/qlib?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
|
||||
|
||||
## :newspaper: **What's NEW!** :sparkling_heart:
|
||||
|
||||
Recent released features
|
||||
|
||||
### Introducing <a href="https://github.com/microsoft/RD-Agent"><img src="docs/_static/img/rdagent_logo.png" alt="RD_Agent" style="height: 2em"></a>: LLM-Based Autonomous Evolving Agents for Industrial Data-Driven R&D
|
||||
|
||||
We are excited to announce the release of **RD-Agent**📢, a powerful tool that supports automated factor mining and model optimization in quant investment R&D.
|
||||
|
||||
RD-Agent is now available on [GitHub](https://github.com/microsoft/RD-Agent), and we welcome your star🌟!
|
||||
|
||||
To learn more, please visit our [♾️Demo page](https://rdagent.azurewebsites.net/). Here, you will find demo videos in both English and Chinese to help you better understand the scenario and usage of RD-Agent.
|
||||
|
||||
We have prepared several demo videos for you:
|
||||
| Scenario | Demo video (English) | Demo video (中文) |
|
||||
| -- | ------ | ------ |
|
||||
| Quant Factor Mining | [Link](https://rdagent.azurewebsites.net/factor_loop?lang=en) | [Link](https://rdagent.azurewebsites.net/factor_loop?lang=zh) |
|
||||
| Quant Factor Mining from reports | [Link](https://rdagent.azurewebsites.net/report_factor?lang=en) | [Link](https://rdagent.azurewebsites.net/report_factor?lang=zh) |
|
||||
| Quant Model Optimization | [Link](https://rdagent.azurewebsites.net/model_loop?lang=en) | [Link](https://rdagent.azurewebsites.net/model_loop?lang=zh) |
|
||||
|
||||
***
|
||||
|
||||
| Feature | Status |
|
||||
| -- | ------ |
|
||||
| BPQP for End-to-end learning | 📈Coming soon!([Under review](https://github.com/microsoft/qlib/pull/1863)) |
|
||||
| 🔥LLM-driven Auto Quant Factory🔥 | 🚀 Released in [♾️RD-Agent](https://github.com/microsoft/RD-Agent) on Aug 8, 2024 |
|
||||
| KRNN and Sandwich models | :chart_with_upwards_trend: [Released](https://github.com/microsoft/qlib/pull/1414/) on May 26, 2023 |
|
||||
| Release Qlib v0.9.0 | :octocat: [Released](https://github.com/microsoft/qlib/releases/tag/v0.9.0) on Dec 9, 2022 |
|
||||
| RL Learning Framework | :hammer: :chart_with_upwards_trend: Released on Nov 10, 2022. [#1332](https://github.com/microsoft/qlib/pull/1332), [#1322](https://github.com/microsoft/qlib/pull/1322), [#1316](https://github.com/microsoft/qlib/pull/1316),[#1299](https://github.com/microsoft/qlib/pull/1299),[#1263](https://github.com/microsoft/qlib/pull/1263), [#1244](https://github.com/microsoft/qlib/pull/1244), [#1169](https://github.com/microsoft/qlib/pull/1169), [#1125](https://github.com/microsoft/qlib/pull/1125), [#1076](https://github.com/microsoft/qlib/pull/1076)|
|
||||
@@ -132,17 +153,18 @@ Here is a quick **[demo](https://terminalizer.com/view/3f24561a4470)** shows how
|
||||
## Installation
|
||||
|
||||
This table demonstrates the supported Python version of `Qlib`:
|
||||
| | install with pip | install from source | plot |
|
||||
| ------------- |:---------------------:|:--------------------:|:----:|
|
||||
| Python 3.7 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
|
||||
| | install with pip | install from source | plot |
|
||||
| ------------- |:---------------------:|:--------------------:|:------------------:|
|
||||
| Python 3.8 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
|
||||
| Python 3.9 | :x: | :heavy_check_mark: | :x: |
|
||||
| Python 3.9 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
|
||||
| Python 3.10 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
|
||||
| Python 3.11 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
|
||||
| Python 3.12 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
|
||||
|
||||
**Note**:
|
||||
1. **Conda** is suggested for managing your Python environment. In some cases, using Python outside of a `conda` environment may result in missing header files, causing the installation failure of certain packages.
|
||||
1. Please pay attention that installing cython in Python 3.6 will raise some error when installing ``Qlib`` from source. If users use Python 3.6 on their machines, it is recommended to *upgrade* Python to version 3.7 or use `conda`'s Python to install ``Qlib`` from source.
|
||||
1. For Python 3.9, `Qlib` supports running workflows such as training models, doing backtest and plot most of the related figures (those included in [notebook](examples/workflow_by_code.ipynb)). However, plotting for the *model performance* is not supported for now and we will fix this when the dependent packages are upgraded in the future.
|
||||
1. `Qlib`Requires `tables` package, `hdf5` in tables does not support python3.9.
|
||||
2. Please pay attention that installing cython in Python 3.6 will raise some error when installing ``Qlib`` from source. If users use Python 3.6 on their machines, it is recommended to *upgrade* Python to version 3.8 or higher, or use `conda`'s Python to install ``Qlib`` from source.
|
||||
3. For Python 3.9, `Qlib` supports running workflows such as training models, doing backtest and plot most of the related figures (those included in [notebook](examples/workflow_by_code.ipynb)). However, plotting for the *model performance* is not supported for now and we will fix this when the dependent packages are upgraded in the future.
|
||||
|
||||
### Install with pip
|
||||
Users can easily install ``Qlib`` by pip according to the following command.
|
||||
@@ -160,7 +182,7 @@ Also, users can install the latest dev version ``Qlib`` by the source code accor
|
||||
|
||||
```bash
|
||||
pip install numpy
|
||||
pip install --upgrade cython
|
||||
pip install --upgrade cython
|
||||
```
|
||||
|
||||
* Clone the repository and install ``Qlib`` as follows.
|
||||
@@ -168,7 +190,6 @@ Also, users can install the latest dev version ``Qlib`` by the source code accor
|
||||
git clone https://github.com/microsoft/qlib.git && cd qlib
|
||||
pip install . # `pip install -e .[dev]` is recommended for development. check details in docs/developer/code_standard_and_dev_guide.rst
|
||||
```
|
||||
**Note**: You can install Qlib with `python setup.py install` as well. But it is not the recommended approach. It will skip `pip` and cause obscure problems. For example, **only** the command ``pip install .`` **can** overwrite the stable version installed by ``pip install pyqlib``, while the command ``python setup.py install`` **can't**.
|
||||
|
||||
**Tips**: If you fail to install `Qlib` or run the examples in your environment, comparing your steps and the [CI workflow](.github/workflows/test_qlib_from_source.yml) may help you find the problem.
|
||||
|
||||
@@ -176,11 +197,11 @@ Also, users can install the latest dev version ``Qlib`` by the source code accor
|
||||
|
||||
## Data Preparation
|
||||
❗ Due to more restrict data security policy. The offical dataset is disabled temporarily. You can try [this data source](https://github.com/chenditc/investment_data/releases) contributed by the community.
|
||||
Here is an example to download the data updated on 20220720.
|
||||
Here is an example to download the data updated on 20240809.
|
||||
```bash
|
||||
wget https://github.com/chenditc/investment_data/releases/download/20220720/qlib_bin.tar.gz
|
||||
wget https://github.com/chenditc/investment_data/releases/download/2024-08-09/qlib_bin.tar.gz
|
||||
mkdir -p ~/.qlib/qlib_data/cn_data
|
||||
tar -zxvf qlib_bin.tar.gz -C ~/.qlib/qlib_data/cn_data --strip-components=2
|
||||
tar -zxvf qlib_bin.tar.gz -C ~/.qlib/qlib_data/cn_data --strip-components=1
|
||||
rm -f qlib_bin.tar.gz
|
||||
```
|
||||
|
||||
@@ -272,6 +293,38 @@ We recommend users to prepare their own data if they have a high-quality dataset
|
||||
```
|
||||
-->
|
||||
|
||||
## Docker images
|
||||
1. Pulling a docker image from a docker hub repository
|
||||
```bash
|
||||
docker pull pyqlib/qlib_image_stable:stable
|
||||
```
|
||||
2. Start a new Docker container
|
||||
```bash
|
||||
docker run -it --name <container name> -v <Mounted local directory>:/app qlib_image_stable
|
||||
```
|
||||
3. At this point you are in the docker environment and can run the qlib scripts. An example:
|
||||
```bash
|
||||
>>> python scripts/get_data.py qlib_data --name qlib_data_simple --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
|
||||
>>> python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
```
|
||||
4. Exit the container
|
||||
```bash
|
||||
>>> exit
|
||||
```
|
||||
5. Restart the container
|
||||
```bash
|
||||
docker start -i -a <container name>
|
||||
```
|
||||
6. Stop the container
|
||||
```bash
|
||||
docker stop <container name>
|
||||
```
|
||||
7. Delete the container
|
||||
```bash
|
||||
docker rm <container name>
|
||||
```
|
||||
8. If you want to know more information, please refer to the [documentation](https://qlib.readthedocs.io/en/latest/developer/how_to_build_image.html).
|
||||
|
||||
## Auto Quant Research Workflow
|
||||
Qlib provides a tool named `qrun` to run the whole workflow automatically (including building dataset, training models, backtest and evaluation). You can start an auto quant research workflow and have a graphical reports analysis according to the following steps:
|
||||
|
||||
@@ -305,22 +358,22 @@ Qlib provides a tool named `qrun` to run the whole workflow automatically (inclu
|
||||
```
|
||||
Here are detailed documents for `qrun` and [workflow](https://qlib.readthedocs.io/en/latest/component/workflow.html).
|
||||
|
||||
2. Graphical Reports Analysis: Run `examples/workflow_by_code.ipynb` with `jupyter notebook` to get graphical reports
|
||||
2. Graphical Reports Analysis: First, run `python -m pip install .[analysis]` to install the required dependencies. Then run `examples/workflow_by_code.ipynb` with `jupyter notebook` to get graphical reports.
|
||||
- Forecasting signal (model prediction) analysis
|
||||
- Cumulative Return of groups
|
||||

|
||||

|
||||
- Return distribution
|
||||

|
||||

|
||||
- Information Coefficient (IC)
|
||||

|
||||

|
||||

|
||||

|
||||

|
||||

|
||||
- Auto Correlation of forecasting signal (model prediction)
|
||||

|
||||

|
||||
|
||||
- Portfolio analysis
|
||||
- Backtest return
|
||||

|
||||

|
||||
<!--
|
||||
- Score IC
|
||||

|
||||
@@ -499,7 +552,7 @@ Qlib data are stored in a compact format, which is efficient to be combined into
|
||||
Join IM discussion groups:
|
||||
|[Gitter](https://gitter.im/Microsoft/qlib)|
|
||||
|----|
|
||||
||
|
||||
||
|
||||
|
||||
# Contributing
|
||||
We appreciate all contributions and thank all the contributors!
|
||||
|
||||
31
build_docker_image.sh
Normal file
31
build_docker_image.sh
Normal file
@@ -0,0 +1,31 @@
|
||||
#!/bin/bash
|
||||
|
||||
docker_user="your_dockerhub_username"
|
||||
|
||||
read -p "Do you want to build the nightly version of the qlib image? (default is stable) (yes/no): " answer;
|
||||
answer=$(echo "$answer" | tr '[:upper:]' '[:lower:]')
|
||||
|
||||
if [ "$answer" = "yes" ]; then
|
||||
# Build the nightly version of the qlib image
|
||||
docker build --build-arg IS_STABLE=no -t qlib_image -f ./Dockerfile .
|
||||
image_tag="nightly"
|
||||
else
|
||||
# Build the stable version of the qlib image
|
||||
docker build -t qlib_image -f ./Dockerfile .
|
||||
image_tag="stable"
|
||||
fi
|
||||
|
||||
read -p "Is it uploaded to docker hub? (default is no) (yes/no): " answer;
|
||||
answer=$(echo "$answer" | tr '[:upper:]' '[:lower:]')
|
||||
|
||||
if [ "$answer" = "yes" ]; then
|
||||
# Log in to Docker Hub
|
||||
# If you are a new docker hub user, please verify your email address before proceeding with this step.
|
||||
docker login
|
||||
# Tag the Docker image
|
||||
docker tag qlib_image "$docker_user/qlib_image:$image_tag"
|
||||
# Push the Docker image to Docker Hub
|
||||
docker push "$docker_user/qlib_image:$image_tag"
|
||||
else
|
||||
echo "Not uploaded to docker hub."
|
||||
fi
|
||||
BIN
docs/_static/img/rdagent_logo.png
vendored
Normal file
BIN
docs/_static/img/rdagent_logo.png
vendored
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 94 KiB |
@@ -123,7 +123,6 @@ html_logo = "_static/img/logo/1.png"
|
||||
html_theme_options = {
|
||||
"logo_only": True,
|
||||
"collapse_navigation": False,
|
||||
"display_version": False,
|
||||
"navigation_depth": 4,
|
||||
}
|
||||
|
||||
|
||||
81
docs/developer/how_to_build_image.rst
Normal file
81
docs/developer/how_to_build_image.rst
Normal file
@@ -0,0 +1,81 @@
|
||||
.. _docker_image:
|
||||
|
||||
==================
|
||||
Build Docker Image
|
||||
==================
|
||||
|
||||
Dockerfile
|
||||
==========
|
||||
|
||||
There is a **Dockerfile** file in the root directory of the project from which you can build the docker image. There are two build methods in Dockerfile to choose from.
|
||||
When executing the build command, use the ``--build-arg`` parameter to control the image version. The ``--build-arg`` parameter defaults to ``yes``, which builds the ``stable`` version of the qlib image.
|
||||
|
||||
1.For the ``stable`` version, use ``pip install pyqlib`` to build the qlib image.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
docker build --build-arg IS_STABLE=yes -t <image name> -f ./Dockerfile .
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
docker build -t <image name> -f ./Dockerfile .
|
||||
|
||||
2. For the ``nightly`` version, use current source code to build the qlib image.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
docker build --build-arg IS_STABLE=no -t <image name> -f ./Dockerfile .
|
||||
|
||||
Auto build of qlib images
|
||||
=========================
|
||||
|
||||
1. There is a **build_docker_image.sh** file in the root directory of your project, which can be used to automatically build docker images and upload them to your docker hub repository(Optional, configuration required).
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
sh build_docker_image.sh
|
||||
>>> Do you want to build the nightly version of the qlib image? (default is stable) (yes/no):
|
||||
>>> Is it uploaded to docker hub? (default is no) (yes/no):
|
||||
|
||||
2. If you want to upload the built image to your docker hub repository, you need to edit your **build_docker_image.sh** file first, fill in ``docker_user`` in the file, and then execute this file.
|
||||
|
||||
How to use qlib images
|
||||
======================
|
||||
1. Start a new Docker container
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
docker run -it --name <container name> -v <Mounted local directory>:/app <image name>
|
||||
|
||||
2. At this point you are in the docker environment and can run the qlib scripts. An example:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
>>> python scripts/get_data.py qlib_data --name qlib_data_simple --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
|
||||
>>> python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
|
||||
3. Exit the container
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
>>> exit
|
||||
|
||||
4. Restart the container
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
docker start -i -a <container name>
|
||||
|
||||
5. Stop the container
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
docker stop -i -a <container name>
|
||||
|
||||
6. Delete the container
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
docker rm <container name>
|
||||
|
||||
7. For more information on using docker see the `docker documentation <https://docs.docker.com/reference/cli/docker/>`_.
|
||||
@@ -61,6 +61,7 @@ Document Structure
|
||||
:caption: FOR DEVELOPERS:
|
||||
|
||||
Code Standard & Development Guidance <developer/code_standard_and_dev_guide.rst>
|
||||
How to build image <developer/how_to_build_image.rst>
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 3
|
||||
|
||||
@@ -7,9 +7,13 @@ What is GeneralPtNN
|
||||
- Now you can just replace the Pytorch model structure to run a NN model.
|
||||
|
||||
We provide an example to demonstrate the effectiveness of the current design.
|
||||
- `workflow_config_gru.yaml` align with previous results [GRU(Kyunghyun Cho, et al.)](../README.md#Alpha158 dataset)
|
||||
- `workflow_config_mlp.yaml` align with previous results [MLP](../README.md#Alpha158 dataset)
|
||||
- `workflow_config_gru.yaml` align with previous results [GRU(Kyunghyun Cho, et al.)](../README.md#Alpha158-dataset)
|
||||
- `workflow_config_gru2mlp.yaml` to demonstrate we can convert config from time-series to tabular data with minimal changes
|
||||
- You only have to change the net & dataset class to make the conversion.
|
||||
- `workflow_config_mlp.yaml` achieved similar functionality with [MLP](../README.md#Alpha158-dataset)
|
||||
|
||||
# TODO
|
||||
|
||||
We will align existing models to current design.
|
||||
- We will align existing models to current design.
|
||||
|
||||
- The result of `workflow_config_mlp.yaml` is different with the result of [MLP](../README.md#Alpha158-dataset) since GeneralPtNN has a different stopping method compared to previous implementations. Specificly, GeneralPtNN controls training according to epoches, whereas previous methods controlled by max_steps.
|
||||
|
||||
@@ -55,10 +55,6 @@ task:
|
||||
class: GeneralPTNN
|
||||
module_path: qlib.contrib.model.pytorch_general_nn
|
||||
kwargs:
|
||||
d_feat: 20
|
||||
hidden_size: 64
|
||||
num_layers: 2
|
||||
dropout: 0.0
|
||||
n_epochs: 200
|
||||
lr: 2e-4
|
||||
early_stop: 10
|
||||
@@ -67,6 +63,13 @@ task:
|
||||
loss: mse
|
||||
n_jobs: 20
|
||||
GPU: 0
|
||||
pt_model_uri: "qlib.contrib.model.pytorch_gru_ts.GRUModel"
|
||||
pt_model_kwargs: {
|
||||
"d_feat": 20,
|
||||
"hidden_size": 64,
|
||||
"num_layers": 2,
|
||||
"dropout": 0.,
|
||||
}
|
||||
dataset:
|
||||
class: TSDatasetH
|
||||
module_path: qlib.data.dataset
|
||||
|
||||
93
examples/benchmarks/GeneralPtNN/workflow_config_gru2mlp.yaml
Normal file
93
examples/benchmarks/GeneralPtNN/workflow_config_gru2mlp.yaml
Normal file
@@ -0,0 +1,93 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: FilterCol
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
col_list: ["RESI5", "WVMA5", "RSQR5", "KLEN", "RSQR10", "CORR5", "CORD5", "CORR10",
|
||||
"ROC60", "RESI10", "VSTD5", "RSQR60", "CORR60", "WVMA60", "STD5",
|
||||
"RSQR20", "CORD60", "CORD10", "CORR20", "KLOW"
|
||||
]
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
signal: <PRED>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
start_time: 2017-01-01
|
||||
end_time: 2020-08-01
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
exchange_kwargs:
|
||||
limit_threshold: 0.095
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: GeneralPTNN
|
||||
module_path: qlib.contrib.model.pytorch_general_nn
|
||||
kwargs:
|
||||
lr: 1e-3
|
||||
n_epochs: 1
|
||||
batch_size: 800
|
||||
loss: mse
|
||||
optimizer: adam
|
||||
pt_model_uri: "qlib.contrib.model.pytorch_nn.Net"
|
||||
pt_model_kwargs:
|
||||
input_dim: 20
|
||||
layers: [20,]
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha158
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -60,15 +60,15 @@ task:
|
||||
class: GeneralPTNN
|
||||
module_path: qlib.contrib.model.pytorch_general_nn
|
||||
kwargs:
|
||||
loss: mse
|
||||
lr: 0.002
|
||||
optimizer: adam
|
||||
max_steps: 8000
|
||||
# FIXME: wrong parameters.
|
||||
lr: 2e-3
|
||||
batch_size: 8192
|
||||
GPU: 0
|
||||
loss: mse
|
||||
weight_decay: 0.0002
|
||||
pt_model_kwargs:
|
||||
input_dim: 157
|
||||
optimizer: adam
|
||||
pt_model_uri: "qlib.contrib.model.pytorch_nn.Net"
|
||||
pt_model_kwargs:
|
||||
input_dim: 157
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
import argparse
|
||||
|
||||
import qlib
|
||||
import ruamel.yaml as yaml
|
||||
from ruamel.yaml import YAML
|
||||
from qlib.utils import init_instance_by_config
|
||||
|
||||
|
||||
def main(seed, config_file="configs/config_alstm.yaml"):
|
||||
# set random seed
|
||||
with open(config_file) as f:
|
||||
config = yaml.safe_load(f)
|
||||
yaml = YAML(typ="safe", pure=True)
|
||||
config = yaml.load(f)
|
||||
|
||||
# seed_suffix = "/seed1000" if "init" in config_file else f"/seed{seed}"
|
||||
seed_suffix = ""
|
||||
|
||||
@@ -9,8 +9,8 @@ from copy import deepcopy
|
||||
from pathlib import Path
|
||||
import pickle
|
||||
from pprint import pprint
|
||||
from ruamel.yaml import YAML
|
||||
import subprocess
|
||||
import yaml
|
||||
from qlib.log import TimeInspector
|
||||
|
||||
from qlib import init
|
||||
@@ -30,7 +30,8 @@ if __name__ == "__main__":
|
||||
subprocess.run(f"qrun {config_path}", shell=True)
|
||||
|
||||
# 2) dump handler
|
||||
task_config = yaml.safe_load(config_path.open())
|
||||
yaml = YAML(typ="safe", pure=True)
|
||||
task_config = yaml.load(config_path.open())
|
||||
hd_conf = task_config["task"]["dataset"]["kwargs"]["handler"]
|
||||
pprint(hd_conf)
|
||||
hd: DataHandlerLP = init_instance_by_config(hd_conf)
|
||||
|
||||
@@ -9,10 +9,9 @@ from copy import deepcopy
|
||||
from pathlib import Path
|
||||
import pickle
|
||||
from pprint import pprint
|
||||
from ruamel.yaml import YAML
|
||||
import subprocess
|
||||
|
||||
import yaml
|
||||
|
||||
from qlib import init
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
from qlib.log import TimeInspector
|
||||
@@ -29,7 +28,8 @@ if __name__ == "__main__":
|
||||
exp_name = "data_mem_reuse_demo"
|
||||
|
||||
config_path = DIRNAME.parent / "benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml"
|
||||
task_config = yaml.safe_load(config_path.open())
|
||||
yaml = YAML(typ="safe", pure=True)
|
||||
task_config = yaml.load(config_path.open())
|
||||
|
||||
# 1) without using processed data in memory
|
||||
with TimeInspector.logt("The original time without reusing processed data in memory:"):
|
||||
|
||||
@@ -6,7 +6,6 @@ import sys
|
||||
import fire
|
||||
import time
|
||||
import glob
|
||||
import yaml
|
||||
import shutil
|
||||
import signal
|
||||
import inspect
|
||||
@@ -15,6 +14,7 @@ import functools
|
||||
import statistics
|
||||
import subprocess
|
||||
from datetime import datetime
|
||||
from ruamel.yaml import YAML
|
||||
from pathlib import Path
|
||||
from operator import xor
|
||||
from pprint import pprint
|
||||
@@ -188,7 +188,8 @@ def gen_and_save_md_table(metrics, dataset):
|
||||
# read yaml, remove seed kwargs of model, and then save file in the temp_dir
|
||||
def gen_yaml_file_without_seed_kwargs(yaml_path, temp_dir):
|
||||
with open(yaml_path, "r") as fp:
|
||||
config = yaml.safe_load(fp)
|
||||
yaml = YAML(typ="safe", pure=True)
|
||||
config = yaml.load(fp)
|
||||
try:
|
||||
del config["task"]["model"]["kwargs"]["seed"]
|
||||
except KeyError:
|
||||
|
||||
@@ -1,2 +1,93 @@
|
||||
[build-system]
|
||||
requires = ["setuptools", "numpy", "Cython"]
|
||||
requires = ["setuptools", "cython", "numpy>=1.24.0"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
classifiers = [
|
||||
"Operating System :: POSIX :: Linux",
|
||||
"Operating System :: Microsoft :: Windows",
|
||||
"Operating System :: MacOS",
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Development Status :: 3 - Alpha",
|
||||
"Programming Language :: Python",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.8",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
]
|
||||
name = "pyqlib"
|
||||
dynamic = ["version"]
|
||||
description = "A Quantitative-research Platform"
|
||||
requires-python = ">=3.8.0"
|
||||
readme = {file = "README.md", content-type = "text/markdown"}
|
||||
|
||||
dependencies = [
|
||||
"pyyaml",
|
||||
"numpy",
|
||||
"pandas",
|
||||
"mlflow",
|
||||
"filelock>=3.16.0",
|
||||
"redis",
|
||||
"dill",
|
||||
"fire",
|
||||
"ruamel.yaml>=0.17.38",
|
||||
"python-redis-lock",
|
||||
"tqdm",
|
||||
"pymongo",
|
||||
"loguru",
|
||||
"lightgbm",
|
||||
"gym",
|
||||
"cvxpy",
|
||||
"joblib",
|
||||
"matplotlib",
|
||||
"jupyter",
|
||||
"nbconvert",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
"pytest",
|
||||
"statsmodels",
|
||||
]
|
||||
# On macos-13 system, when using python version greater than or equal to 3.10,
|
||||
# pytorch can't fully support Numpy version above 2.0, so, when you want to install torch,
|
||||
# it will limit the version of Numpy less than 2.0.
|
||||
rl = [
|
||||
"tianshou<=0.4.10",
|
||||
"torch",
|
||||
"numpy<2.0.0",
|
||||
]
|
||||
lint = [
|
||||
"black",
|
||||
"pylint",
|
||||
"mypy<1.5.0",
|
||||
"flake8",
|
||||
"nbqa",
|
||||
]
|
||||
docs = [
|
||||
"sphinx",
|
||||
"sphinx_rtd_theme",
|
||||
"readthedocs_sphinx_ext",
|
||||
]
|
||||
package = [
|
||||
"twine",
|
||||
"build",
|
||||
]
|
||||
# test_pit dependency packages
|
||||
test = [
|
||||
"yahooquery",
|
||||
"baostock",
|
||||
]
|
||||
analysis = [
|
||||
"plotly",
|
||||
]
|
||||
|
||||
[tool.setuptools]
|
||||
packages = [
|
||||
"qlib",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
qrun = "qlib.workflow.cli:run"
|
||||
|
||||
@@ -2,11 +2,11 @@
|
||||
# Licensed under the MIT License.
|
||||
from pathlib import Path
|
||||
|
||||
__version__ = "0.9.5.99"
|
||||
__version__ = "0.9.6.99"
|
||||
__version__bak = __version__ # This version is backup for QlibConfig.reset_qlib_version
|
||||
import os
|
||||
from typing import Union
|
||||
import yaml
|
||||
from ruamel.yaml import YAML
|
||||
import logging
|
||||
import platform
|
||||
import subprocess
|
||||
@@ -176,7 +176,8 @@ def init_from_yaml_conf(conf_path, **kwargs):
|
||||
config = {}
|
||||
else:
|
||||
with open(conf_path) as f:
|
||||
config = yaml.safe_load(f)
|
||||
yaml = YAML(typ="safe", pure=True)
|
||||
config = yaml.load(f)
|
||||
config.update(kwargs)
|
||||
default_conf = config.pop("default_conf", "client")
|
||||
init(default_conf, **config)
|
||||
@@ -272,7 +273,8 @@ def auto_init(**kwargs):
|
||||
logger = get_module_logger("Initialization")
|
||||
conf_pp = pp / "config.yaml"
|
||||
with conf_pp.open() as f:
|
||||
conf = yaml.safe_load(f)
|
||||
yaml = YAML(typ="safe", pure=True)
|
||||
conf = yaml.load(f)
|
||||
|
||||
conf_type = conf.get("conf_type", "origin")
|
||||
if conf_type == "origin":
|
||||
|
||||
@@ -278,7 +278,7 @@ class BaseSingleMetric:
|
||||
raise NotImplementedError(f"Please implement the `empty` method")
|
||||
|
||||
def add(self, other: BaseSingleMetric, fill_value: float = None) -> BaseSingleMetric:
|
||||
"""Replace np.NaN with fill_value in two metrics and add them."""
|
||||
"""Replace np.nan with fill_value in two metrics and add them."""
|
||||
|
||||
raise NotImplementedError(f"Please implement the `add` method")
|
||||
|
||||
@@ -412,7 +412,7 @@ class BaseOrderIndicator:
|
||||
metrics : Union[str, List[str]]
|
||||
all metrics needs to be sumed.
|
||||
fill_value : float, optional
|
||||
fill np.NaN with value. By default None.
|
||||
fill np.nan with value. By default None.
|
||||
"""
|
||||
|
||||
raise NotImplementedError(f"Please implement the 'sum_all_indicators' method")
|
||||
|
||||
@@ -325,9 +325,9 @@ class Indicator:
|
||||
|
||||
def _update_order_fulfill_rate(self) -> None:
|
||||
def func(deal_amount, amount):
|
||||
# deal_amount is np.NaN or None when there is no inner decision. So full fill rate is 0.
|
||||
# deal_amount is np.nan or None when there is no inner decision. So full fill rate is 0.
|
||||
tmp_deal_amount = deal_amount.reindex(amount.index, 0)
|
||||
tmp_deal_amount = tmp_deal_amount.replace({np.NaN: 0})
|
||||
tmp_deal_amount = tmp_deal_amount.replace({np.nan: 0})
|
||||
return tmp_deal_amount / amount
|
||||
|
||||
self.order_indicator.transfer(func, "ffr")
|
||||
@@ -354,8 +354,8 @@ class Indicator:
|
||||
)
|
||||
|
||||
def func(trade_price, deal_amount):
|
||||
# trade_price is np.NaN instead of inf when deal_amount is zero.
|
||||
tmp_deal_amount = deal_amount.replace({0: np.NaN})
|
||||
# trade_price is np.nan instead of inf when deal_amount is zero.
|
||||
tmp_deal_amount = deal_amount.replace({0: np.nan})
|
||||
return trade_price / tmp_deal_amount
|
||||
|
||||
self.order_indicator.transfer(func, "trade_price")
|
||||
@@ -425,7 +425,7 @@ class Indicator:
|
||||
assert isinstance(price_s, idd.SingleData)
|
||||
price_s = price_s.loc[(price_s > 1e-08).data.astype(bool)]
|
||||
# NOTE ~(price_s < 1e-08) is different from price_s >= 1e-8
|
||||
# ~(np.NaN < 1e-8) -> ~(False) -> True
|
||||
# ~(np.nan < 1e-8) -> ~(False) -> True
|
||||
|
||||
assert isinstance(price_s, idd.SingleData)
|
||||
if agg == "vwap":
|
||||
|
||||
@@ -173,7 +173,11 @@ _default_config = {
|
||||
"filters": ["field_not_found"],
|
||||
}
|
||||
},
|
||||
"loggers": {"qlib": {"level": logging.DEBUG, "handlers": ["console"]}},
|
||||
# Normally this should be set to `False` to avoid duplicated logging [1].
|
||||
# However, due to bug in pytest, it requires log message to propagate to root logger to be captured by `caplog` [2].
|
||||
# [1] https://github.com/microsoft/qlib/pull/1661
|
||||
# [2] https://github.com/pytest-dev/pytest/issues/3697
|
||||
"loggers": {"qlib": {"level": logging.DEBUG, "handlers": ["console"], "propagate": False}},
|
||||
# To let qlib work with other packages, we shouldn't disable existing loggers.
|
||||
# Note that this param is default to True according to the documentation of logging.
|
||||
"disable_existing_loggers": False,
|
||||
|
||||
@@ -58,7 +58,7 @@ class Alpha360(DataHandlerLP):
|
||||
fit_end_time=None,
|
||||
filter_pipe=None,
|
||||
inst_processors=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
|
||||
learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)
|
||||
@@ -83,7 +83,7 @@ class Alpha360(DataHandlerLP):
|
||||
data_loader=data_loader,
|
||||
learn_processors=learn_processors,
|
||||
infer_processors=infer_processors,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def get_label_config(self):
|
||||
@@ -109,7 +109,7 @@ class Alpha158(DataHandlerLP):
|
||||
process_type=DataHandlerLP.PTYPE_A,
|
||||
filter_pipe=None,
|
||||
inst_processors=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
|
||||
learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)
|
||||
@@ -134,7 +134,7 @@ class Alpha158(DataHandlerLP):
|
||||
infer_processors=infer_processors,
|
||||
learn_processors=learn_processors,
|
||||
process_type=process_type,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def get_feature_config(self):
|
||||
|
||||
@@ -33,7 +33,7 @@ class CatBoostModel(Model, FeatureInt):
|
||||
verbose_eval=20,
|
||||
evals_result=dict(),
|
||||
reweighter=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
df_train, df_valid = dataset.prepare(
|
||||
["train", "valid"],
|
||||
|
||||
@@ -31,7 +31,7 @@ class DEnsembleModel(Model, FeatureInt):
|
||||
sub_weights=None,
|
||||
epochs=100,
|
||||
early_stopping_rounds=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
self.base_model = base_model # "gbm" or "mlp", specifically, we use lgbm for "gbm"
|
||||
self.num_models = num_models # the number of sub-models
|
||||
|
||||
@@ -56,7 +56,7 @@ class ADARNN(Model):
|
||||
n_splits=2,
|
||||
GPU=0,
|
||||
seed=None,
|
||||
**_
|
||||
**_,
|
||||
):
|
||||
# Set logger.
|
||||
self.logger = get_module_logger("ADARNN")
|
||||
@@ -154,10 +154,7 @@ class ADARNN(Model):
|
||||
self.model.train()
|
||||
criterion = nn.MSELoss()
|
||||
dist_mat = torch.zeros(self.num_layers, self.len_seq).to(self.device)
|
||||
len_loader = np.inf
|
||||
for loader in train_loader_list:
|
||||
if len(loader) < len_loader:
|
||||
len_loader = len(loader)
|
||||
out_weight_list = None
|
||||
for data_all in zip(*train_loader_list):
|
||||
# for data_all in zip(*train_loader_list):
|
||||
self.train_optimizer.zero_grad()
|
||||
@@ -571,6 +568,7 @@ class TransferLoss:
|
||||
Returns:
|
||||
[tensor] -- transfer loss
|
||||
"""
|
||||
loss = None
|
||||
if self.loss_type in ("mmd_lin", "mmd"):
|
||||
mmdloss = MMD_loss(kernel_type="linear")
|
||||
loss = mmdloss(X, Y)
|
||||
|
||||
@@ -63,7 +63,7 @@ class ADD(Model):
|
||||
mu=0.05,
|
||||
GPU=0,
|
||||
seed=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
# Set logger.
|
||||
self.logger = get_module_logger("ADD")
|
||||
|
||||
@@ -52,7 +52,7 @@ class ALSTM(Model):
|
||||
optimizer="adam",
|
||||
GPU=0,
|
||||
seed=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
# Set logger.
|
||||
self.logger = get_module_logger("ALSTM")
|
||||
|
||||
@@ -56,7 +56,7 @@ class ALSTM(Model):
|
||||
n_jobs=10,
|
||||
GPU=0,
|
||||
seed=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
# Set logger.
|
||||
self.logger = get_module_logger("ALSTM")
|
||||
|
||||
@@ -56,7 +56,7 @@ class GATs(Model):
|
||||
optimizer="adam",
|
||||
GPU=0,
|
||||
seed=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
# Set logger.
|
||||
self.logger = get_module_logger("GATs")
|
||||
|
||||
@@ -73,7 +73,7 @@ class GATs(Model):
|
||||
GPU=0,
|
||||
n_jobs=10,
|
||||
seed=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
# Set logger.
|
||||
self.logger = get_module_logger("GATs")
|
||||
|
||||
@@ -3,19 +3,16 @@
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from torch.utils.data import DataLoader, RandomSampler, StackDataset
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Callable, Optional, Text, Union
|
||||
from sklearn.metrics import roc_auc_score, mean_squared_error
|
||||
from typing import Union
|
||||
import copy
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import StackDataset
|
||||
|
||||
from qlib.data.dataset.weight import Reweighter
|
||||
|
||||
@@ -24,336 +21,14 @@ from ...model.base import Model
|
||||
from ...data.dataset import DatasetH, TSDatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
from ...utils import (
|
||||
auto_filter_kwargs,
|
||||
init_instance_by_config,
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
get_or_create_path,
|
||||
)
|
||||
from ...log import get_module_logger
|
||||
from ...workflow import R
|
||||
from qlib.contrib.meta.data_selection.utils import ICLoss
|
||||
from torch.nn import DataParallel
|
||||
|
||||
|
||||
class GeneralPTNN(Model):
|
||||
"""General Pytorch Neural Network Model
|
||||
Parameters
|
||||
----------
|
||||
input_dim : int
|
||||
input dimension
|
||||
output_dim : int
|
||||
output dimension
|
||||
layers : tuple
|
||||
layer sizes
|
||||
lr : float
|
||||
learning rate
|
||||
optimizer : str
|
||||
optimizer name
|
||||
GPU : int
|
||||
the GPU ID used for training
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
lr=0.001,
|
||||
max_steps=300,
|
||||
batch_size=2000,
|
||||
early_stop_rounds=50,
|
||||
eval_steps=20,
|
||||
optimizer="gd",
|
||||
loss="mse",
|
||||
GPU=0,
|
||||
seed=None,
|
||||
weight_decay=0.0,
|
||||
data_parall=False,
|
||||
scheduler: Optional[Union[Callable]] = "default", # when it is Callable, it accept one argument named optimizer
|
||||
init_model=None,
|
||||
eval_train_metric=False,
|
||||
pt_model_uri="qlib.contrib.model.pytorch_nn.Net",
|
||||
pt_model_kwargs={
|
||||
"input_dim": 360,
|
||||
"layers": (256,),
|
||||
},
|
||||
valid_key=DataHandlerLP.DK_L,
|
||||
# TODO: Infer Key is a more reasonable key. But it requires more detailed processing on label processing
|
||||
):
|
||||
# Set logger.
|
||||
self.logger = get_module_logger("DNNModelPytorch")
|
||||
self.logger.info("DNN pytorch version...")
|
||||
|
||||
# set hyper-parameters.
|
||||
self.lr = lr
|
||||
self.max_steps = max_steps
|
||||
self.batch_size = batch_size
|
||||
self.early_stop_rounds = early_stop_rounds
|
||||
self.eval_steps = eval_steps
|
||||
self.optimizer = optimizer.lower()
|
||||
self.loss_type = loss
|
||||
if isinstance(GPU, str):
|
||||
self.device = torch.device(GPU)
|
||||
else:
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.seed = seed
|
||||
self.weight_decay = weight_decay
|
||||
self.data_parall = data_parall
|
||||
self.eval_train_metric = eval_train_metric
|
||||
self.valid_key = valid_key
|
||||
|
||||
self.best_step = None
|
||||
|
||||
self.logger.info(
|
||||
"DNN parameters setting:"
|
||||
f"\nlr : {lr}"
|
||||
f"\nmax_steps : {max_steps}"
|
||||
f"\nbatch_size : {batch_size}"
|
||||
f"\nearly_stop_rounds : {early_stop_rounds}"
|
||||
f"\neval_steps : {eval_steps}"
|
||||
f"\noptimizer : {optimizer}"
|
||||
f"\nloss_type : {loss}"
|
||||
f"\nseed : {seed}"
|
||||
f"\ndevice : {self.device}"
|
||||
f"\nuse_GPU : {self.use_gpu}"
|
||||
f"\nweight_decay : {weight_decay}"
|
||||
f"\nenable data parall : {self.data_parall}"
|
||||
f"\npt_model_uri: {pt_model_uri}"
|
||||
f"\npt_model_kwargs: {pt_model_kwargs}"
|
||||
)
|
||||
|
||||
if self.seed is not None:
|
||||
np.random.seed(self.seed)
|
||||
torch.manual_seed(self.seed)
|
||||
|
||||
if loss not in {"mse", "binary"}:
|
||||
raise NotImplementedError("loss {} is not supported!".format(loss))
|
||||
self._scorer = mean_squared_error if loss == "mse" else roc_auc_score
|
||||
|
||||
if init_model is None:
|
||||
self.dnn_model = init_instance_by_config({"class": pt_model_uri, "kwargs": pt_model_kwargs})
|
||||
|
||||
if self.data_parall:
|
||||
self.dnn_model = DataParallel(self.dnn_model).to(self.device)
|
||||
else:
|
||||
self.dnn_model = init_model
|
||||
|
||||
self.logger.info("model:\n{:}".format(self.dnn_model))
|
||||
self.logger.info("model size: {:.4f} MB".format(count_parameters(self.dnn_model)))
|
||||
|
||||
if optimizer.lower() == "adam":
|
||||
self.train_optimizer = optim.Adam(self.dnn_model.parameters(), lr=self.lr, weight_decay=self.weight_decay)
|
||||
elif optimizer.lower() == "gd":
|
||||
self.train_optimizer = optim.SGD(self.dnn_model.parameters(), lr=self.lr, weight_decay=self.weight_decay)
|
||||
else:
|
||||
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
|
||||
|
||||
if scheduler == "default":
|
||||
# Reduce learning rate when loss has stopped decrease
|
||||
self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
||||
self.train_optimizer,
|
||||
mode="min",
|
||||
factor=0.5,
|
||||
patience=10,
|
||||
verbose=True,
|
||||
threshold=0.0001,
|
||||
threshold_mode="rel",
|
||||
cooldown=0,
|
||||
min_lr=0.00001,
|
||||
eps=1e-08,
|
||||
)
|
||||
elif scheduler is None:
|
||||
self.scheduler = None
|
||||
else:
|
||||
self.scheduler = scheduler(optimizer=self.train_optimizer)
|
||||
|
||||
self.dnn_model.to(self.device)
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
return self.device != torch.device("cpu")
|
||||
|
||||
|
||||
def _eval_valid_dl(self, valid_loader, val_index):
|
||||
with torch.no_grad():
|
||||
self.dnn_model.eval()
|
||||
val_loss = []
|
||||
val_pred = []
|
||||
val_label = []
|
||||
for x_batch, y_batch in valid_loader:
|
||||
x_batch = x_batch.to(self.device)
|
||||
y_batch = y_batch.to(self.device)
|
||||
cur_loss = self.get_loss(preds, y_batch, self.loss_type)
|
||||
val_loss.append(cur_loss.detach().cpu().numpy().item())
|
||||
val_loss = np.mean(val_loss)
|
||||
val_pred = torch.cat(val_pred, axis=0).detach().cpu().numpy()
|
||||
val_label = torch.cat(val_label, axis=0).detach().cpu().numpy()
|
||||
val_metric = self.get_metric(val_pred, val_label, val_index).detach().cpu().numpy().item()
|
||||
return val_loss, val_metric
|
||||
|
||||
def fit(
|
||||
self,
|
||||
dataset: Union[DatasetH, TSDatasetH],
|
||||
verbose=True,
|
||||
save_path=None,
|
||||
):
|
||||
|
||||
ists = isinstance(dataset, TSDatasetH) # is this time series dataset
|
||||
|
||||
# prepare training
|
||||
train_x = dataset.prepare("train", col_set="feature", data_key=DataHandlerLP.DK_L)
|
||||
train_y = dataset.prepare("train", col_set="label", data_key=DataHandlerLP.DK_L)
|
||||
train_ds = StackDataset(train_x, train_y)
|
||||
train_sampler = RandomSampler(train_ds)
|
||||
train_loader = DataLoader(train_ds, batch_size=self.batch_size, sampler=train_sampler)
|
||||
|
||||
# prepare validation
|
||||
valid_x = dataset.prepare("train", col_set="feature", data_key=DataHandlerLP.DK_L)
|
||||
valid_y = dataset.prepare("train", col_set="label", data_key=DataHandlerLP.DK_L)
|
||||
valid_ds = StackDataset(valid_x, valid_y)
|
||||
valid_loader = DataLoader(valid_ds, batch_size=self.batch_size, shuffle=False)
|
||||
if ists:
|
||||
val_index = valid_x.data_index
|
||||
else:
|
||||
val_index = valid_x.index
|
||||
|
||||
|
||||
save_path = get_or_create_path(save_path)
|
||||
stop_steps = 0
|
||||
train_loss = 0
|
||||
best_loss = np.inf
|
||||
# train
|
||||
self.logger.info("training...")
|
||||
|
||||
|
||||
for step in range(1, self.max_steps + 1):
|
||||
if stop_steps >= self.early_stop_rounds:
|
||||
if verbose:
|
||||
self.logger.info("\tearly stop")
|
||||
break
|
||||
loss = AverageMeter()
|
||||
self.dnn_model.train()
|
||||
self.train_optimizer.zero_grad()
|
||||
|
||||
for x_batch, y_batch in train_loader:
|
||||
x_batch = x_batch.to(self.device)
|
||||
y_batch = y_batch.to(self.device)
|
||||
|
||||
# forward
|
||||
preds = self.dnn_model(x_batch)
|
||||
cur_loss = self.get_loss(preds, y_batch, self.loss_type)
|
||||
cur_loss.backward()
|
||||
self.train_optimizer.step()
|
||||
loss.update(cur_loss.item())
|
||||
R.log_metrics(train_loss=loss.avg, step=step)
|
||||
|
||||
# validation
|
||||
train_loss += loss.val
|
||||
# for every `eval_steps` steps or at the last steps, we will evaluate the model.
|
||||
if step % self.eval_steps == 0 or step == self.max_steps:
|
||||
stop_steps += 1
|
||||
train_loss /= self.eval_steps
|
||||
|
||||
val_loss, val_metric = self._eval_valid_dl(valid_loader, val_index)
|
||||
R.log_metrics(val_loss=val_loss, step=step)
|
||||
R.log_metrics(val_metric=val_metric, step=step)
|
||||
|
||||
if val_loss < best_loss:
|
||||
if verbose:
|
||||
self.logger.info(
|
||||
"\tvalid loss update from {:.6f} to {:.6f}, save checkpoint.".format(
|
||||
best_loss, val_loss
|
||||
)
|
||||
)
|
||||
best_loss = val_loss
|
||||
self.best_step = step
|
||||
R.log_metrics(best_step=self.best_step, step=step)
|
||||
stop_steps = 0
|
||||
torch.save(self.dnn_model.state_dict(), save_path)
|
||||
train_loss = 0
|
||||
# update learning rate
|
||||
if self.scheduler is not None:
|
||||
auto_filter_kwargs(self.scheduler.step, warning=False)(metrics=val_loss, epoch=step)
|
||||
R.log_metrics(lr=self.get_lr(), step=step)
|
||||
|
||||
# restore the optimal parameters after training
|
||||
self.dnn_model.load_state_dict(torch.load(save_path, map_location=self.device))
|
||||
if self.use_gpu:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def get_lr(self):
|
||||
assert len(self.train_optimizer.param_groups) == 1
|
||||
return self.train_optimizer.param_groups[0]["lr"]
|
||||
|
||||
def get_loss(self, pred, target, loss_type, w=None):
|
||||
pred, target = pred.reshape(-1), target.reshape(-1)
|
||||
if w is None:
|
||||
# make it ones and the same size with pred
|
||||
w = torch.ones_like(pred).to(pred.device)
|
||||
|
||||
if loss_type == "mse":
|
||||
sqr_loss = torch.mul(pred - target, pred - target)
|
||||
loss = torch.mul(sqr_loss, w).mean()
|
||||
return loss
|
||||
elif loss_type == "binary":
|
||||
loss = nn.BCEWithLogitsLoss(weight=w)
|
||||
return loss(pred, target)
|
||||
else:
|
||||
raise NotImplementedError("loss {} is not supported!".format(loss_type))
|
||||
|
||||
def get_metric(self, pred, target, index):
|
||||
# NOTE: the order of the index must follow <datetime, instrument> sorted order
|
||||
return -ICLoss()(pred, target, index) # pylint: disable=E1130
|
||||
|
||||
def _nn_predict(self, data, return_cpu=True):
|
||||
"""Reusing predicting NN.
|
||||
Scenarios
|
||||
1) test inference (data may come from CPU and expect the output data is on CPU)
|
||||
2) evaluation on training (data may come from GPU)
|
||||
"""
|
||||
if not isinstance(data, torch.Tensor):
|
||||
if isinstance(data, pd.DataFrame):
|
||||
data = data.values
|
||||
data = torch.Tensor(data)
|
||||
data = data.to(self.device)
|
||||
preds = []
|
||||
self.dnn_model.eval()
|
||||
with torch.no_grad():
|
||||
batch_size = 8096
|
||||
for i in range(0, len(data), batch_size):
|
||||
x = data[i : i + batch_size]
|
||||
preds.append(self.dnn_model(x.to(self.device)).detach().reshape(-1))
|
||||
if return_cpu:
|
||||
preds = np.concatenate([pr.cpu().numpy() for pr in preds])
|
||||
else:
|
||||
preds = torch.cat(preds, axis=0)
|
||||
return preds
|
||||
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
x_test_pd = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
preds = self._nn_predict(x_test_pd)
|
||||
return pd.Series(preds.reshape(-1), index=x_test_pd.index)
|
||||
|
||||
|
||||
class AverageMeter:
|
||||
"""Computes and stores the average and current value"""
|
||||
|
||||
def __init__(self):
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.val = 0
|
||||
self.avg = 0
|
||||
self.sum = 0
|
||||
self.count = 0
|
||||
|
||||
def update(self, val, n=1):
|
||||
self.val = val
|
||||
self.sum += val * n
|
||||
self.count += n
|
||||
self.avg = self.sum / self.count
|
||||
|
||||
|
||||
from ...model.utils import ConcatDataset
|
||||
|
||||
|
||||
class GeneralPTNN(Model):
|
||||
"""
|
||||
Motivation:
|
||||
@@ -381,16 +56,17 @@ class GeneralPTNN(Model):
|
||||
batch_size=2000,
|
||||
early_stop=20,
|
||||
loss="mse",
|
||||
weight_decay=0.0,
|
||||
optimizer="adam",
|
||||
n_jobs=10,
|
||||
GPU=0,
|
||||
seed=None,
|
||||
pt_model_uri="qlib.contrib.model.pytorch_gru_ts.GRUModel",
|
||||
pt_model_kwargs={
|
||||
"d_feat":6,
|
||||
"hidden_size":64,
|
||||
"num_layers":2,
|
||||
"dropout":0.,
|
||||
"d_feat": 6,
|
||||
"hidden_size": 64,
|
||||
"num_layers": 2,
|
||||
"dropout": 0.0,
|
||||
},
|
||||
):
|
||||
# Set logger.
|
||||
@@ -405,6 +81,7 @@ class GeneralPTNN(Model):
|
||||
self.early_stop = early_stop
|
||||
self.optimizer = optimizer.lower()
|
||||
self.loss = loss
|
||||
self.weight_decay = weight_decay
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.n_jobs = n_jobs
|
||||
self.seed = seed
|
||||
@@ -424,6 +101,7 @@ class GeneralPTNN(Model):
|
||||
"\ndevice : {}"
|
||||
"\nn_jobs : {}"
|
||||
"\nuse_GPU : {}"
|
||||
"\nweight_decay : {}"
|
||||
"\nseed : {}"
|
||||
"\npt_model_uri: {}"
|
||||
"\npt_model_kwargs: {}".format(
|
||||
@@ -437,11 +115,11 @@ class GeneralPTNN(Model):
|
||||
self.device,
|
||||
n_jobs,
|
||||
self.use_gpu,
|
||||
weight_decay,
|
||||
seed,
|
||||
pt_model_uri,
|
||||
pt_model_kwargs,
|
||||
)
|
||||
|
||||
)
|
||||
|
||||
if self.seed is not None:
|
||||
@@ -452,9 +130,9 @@ class GeneralPTNN(Model):
|
||||
self.logger.info("model size: {:.4f} MB".format(count_parameters(self.dnn_model)))
|
||||
|
||||
if optimizer.lower() == "adam":
|
||||
self.train_optimizer = optim.Adam(self.dnn_model.parameters(), lr=self.lr)
|
||||
self.train_optimizer = optim.Adam(self.dnn_model.parameters(), lr=self.lr, weight_decay=weight_decay)
|
||||
elif optimizer.lower() == "gd":
|
||||
self.train_optimizer = optim.SGD(self.dnn_model.parameters(), lr=self.lr)
|
||||
self.train_optimizer = optim.SGD(self.dnn_model.parameters(), lr=self.lr, weight_decay=weight_decay)
|
||||
else:
|
||||
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
|
||||
|
||||
@@ -488,7 +166,6 @@ class GeneralPTNN(Model):
|
||||
|
||||
raise ValueError("unknown metric `%s`" % self.metric)
|
||||
|
||||
|
||||
def _get_fl(self, data: torch.Tensor):
|
||||
"""
|
||||
get feature and label from data
|
||||
@@ -521,7 +198,7 @@ class GeneralPTNN(Model):
|
||||
self.dnn_model.train()
|
||||
|
||||
for data, weight in data_loader:
|
||||
feature , label = self._get_fl(data)
|
||||
feature, label = self._get_fl(data)
|
||||
|
||||
pred = self.dnn_model(feature.float())
|
||||
loss = self.loss_fn(pred, label, weight.to(self.device))
|
||||
@@ -538,9 +215,7 @@ class GeneralPTNN(Model):
|
||||
losses = []
|
||||
|
||||
for data, weight in data_loader:
|
||||
feature = data[:, :, 0:-1].to(self.device)
|
||||
# feature[torch.isnan(feature)] = 0
|
||||
label = data[:, -1, -1].to(self.device)
|
||||
feature, label = self._get_fl(data)
|
||||
|
||||
with torch.no_grad():
|
||||
pred = self.dnn_model(feature.float())
|
||||
@@ -624,6 +299,8 @@ class GeneralPTNN(Model):
|
||||
evals_result["train"].append(train_score)
|
||||
evals_result["valid"].append(val_score)
|
||||
|
||||
if step == 0:
|
||||
best_param = copy.deepcopy(self.dnn_model.state_dict())
|
||||
if val_score > best_score:
|
||||
best_score = val_score
|
||||
stop_steps = 0
|
||||
@@ -642,22 +319,40 @@ class GeneralPTNN(Model):
|
||||
if self.use_gpu:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def predict(self, dataset: Union[DatasetH, TSDatasetH]):
|
||||
def predict(
|
||||
self,
|
||||
dataset: Union[DatasetH, TSDatasetH],
|
||||
batch_size=None,
|
||||
n_jobs=None,
|
||||
):
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
dl_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)
|
||||
dl_test.config(fillna_type="ffill+bfill")
|
||||
|
||||
if isinstance(dataset, TSDatasetH):
|
||||
dl_test.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
index = dl_test.get_index()
|
||||
else:
|
||||
# If it is a tabular, we convert the dataframe to numpy to be indexable by DataLoader
|
||||
index = dl_test.index
|
||||
dl_test = dl_test.values
|
||||
|
||||
test_loader = DataLoader(dl_test, batch_size=self.batch_size, num_workers=self.n_jobs)
|
||||
self.dnn_model.eval()
|
||||
preds = []
|
||||
|
||||
for data in test_loader:
|
||||
feature = data[:, :, 0:-1].to(self.device)
|
||||
feature, _ = self._get_fl(data)
|
||||
feature = feature.to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
pred = self.dnn_model(feature.float()).detach().cpu().numpy()
|
||||
|
||||
preds.append(pred)
|
||||
|
||||
return pd.Series(np.concatenate(preds), index=dl_test.get_index())
|
||||
preds_concat = np.concatenate(preds)
|
||||
if preds_concat.ndim != 1:
|
||||
preds_concat = preds_concat.ravel()
|
||||
|
||||
return pd.Series(preds_concat, index=index)
|
||||
|
||||
@@ -52,7 +52,7 @@ class GRU(Model):
|
||||
optimizer="adam",
|
||||
GPU=0,
|
||||
seed=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
# Set logger.
|
||||
self.logger = get_module_logger("GRU")
|
||||
@@ -317,7 +317,6 @@ class GRU(Model):
|
||||
|
||||
|
||||
class GRUModel(nn.Module):
|
||||
|
||||
def __init__(self, d_feat=6, hidden_size=64, num_layers=2, dropout=0.0):
|
||||
super().__init__()
|
||||
|
||||
|
||||
@@ -54,7 +54,7 @@ class GRU(Model):
|
||||
n_jobs=10,
|
||||
GPU=0,
|
||||
seed=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
# Set logger.
|
||||
self.logger = get_module_logger("GRU")
|
||||
|
||||
@@ -59,7 +59,7 @@ class HIST(Model):
|
||||
optimizer="adam",
|
||||
GPU=0,
|
||||
seed=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
# Set logger.
|
||||
self.logger = get_module_logger("HIST")
|
||||
@@ -256,7 +256,7 @@ class HIST(Model):
|
||||
raise ValueError("Empty data from dataset, please check your dataset config.")
|
||||
|
||||
if not os.path.exists(self.stock2concept):
|
||||
url = "http://fintech.msra.cn/stock_data/downloads/qlib_csi300_stock2concept.npy"
|
||||
url = "https://github.com/SunsetWolf/qlib_dataset/releases/download/v0/qlib_csi300_stock2concept.npy"
|
||||
urllib.request.urlretrieve(url, self.stock2concept)
|
||||
|
||||
stock_index = np.load(self.stock_index, allow_pickle=True).item()
|
||||
|
||||
@@ -55,7 +55,7 @@ class IGMTF(Model):
|
||||
optimizer="adam",
|
||||
GPU=0,
|
||||
seed=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
# Set logger.
|
||||
self.logger = get_module_logger("IGMTF")
|
||||
|
||||
@@ -255,7 +255,7 @@ class KRNN(Model):
|
||||
optimizer="adam",
|
||||
GPU=0,
|
||||
seed=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
# Set logger.
|
||||
self.logger = get_module_logger("KRNN")
|
||||
|
||||
@@ -44,7 +44,7 @@ class LocalformerModel(Model):
|
||||
n_jobs=10,
|
||||
GPU=0,
|
||||
seed=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
# set hyper-parameters.
|
||||
self.d_model = d_model
|
||||
|
||||
@@ -42,7 +42,7 @@ class LocalformerModel(Model):
|
||||
n_jobs=10,
|
||||
GPU=0,
|
||||
seed=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
# set hyper-parameters.
|
||||
self.d_model = d_model
|
||||
|
||||
@@ -51,7 +51,7 @@ class LSTM(Model):
|
||||
optimizer="adam",
|
||||
GPU=0,
|
||||
seed=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
# Set logger.
|
||||
self.logger = get_module_logger("LSTM")
|
||||
|
||||
@@ -53,7 +53,7 @@ class LSTM(Model):
|
||||
n_jobs=10,
|
||||
GPU=0,
|
||||
seed=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
# Set logger.
|
||||
self.logger = get_module_logger("LSTM")
|
||||
|
||||
@@ -35,7 +35,7 @@ class SandwichModel(nn.Module):
|
||||
rnn_layers,
|
||||
dropout,
|
||||
device,
|
||||
**params
|
||||
**params,
|
||||
):
|
||||
"""Build a Sandwich model
|
||||
|
||||
@@ -129,7 +129,7 @@ class Sandwich(Model):
|
||||
optimizer="adam",
|
||||
GPU=0,
|
||||
seed=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
# Set logger.
|
||||
self.logger = get_module_logger("Sandwich")
|
||||
|
||||
@@ -212,7 +212,7 @@ class SFM(Model):
|
||||
optimizer="gd",
|
||||
GPU=0,
|
||||
seed=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
# Set logger.
|
||||
self.logger = get_module_logger("SFM")
|
||||
|
||||
@@ -56,7 +56,7 @@ class TCN(Model):
|
||||
optimizer="adam",
|
||||
GPU=0,
|
||||
seed=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
# Set logger.
|
||||
self.logger = get_module_logger("TCN")
|
||||
|
||||
@@ -54,7 +54,7 @@ class TCN(Model):
|
||||
n_jobs=10,
|
||||
GPU=0,
|
||||
seed=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
# Set logger.
|
||||
self.logger = get_module_logger("TCN")
|
||||
|
||||
@@ -58,7 +58,7 @@ class TCTS(Model):
|
||||
mode="soft",
|
||||
seed=None,
|
||||
lowest_valid_performance=0.993,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
# Set logger.
|
||||
self.logger = get_module_logger("TCTS")
|
||||
|
||||
@@ -43,7 +43,7 @@ class TransformerModel(Model):
|
||||
n_jobs=10,
|
||||
GPU=0,
|
||||
seed=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
# set hyper-parameters.
|
||||
self.d_model = d_model
|
||||
|
||||
@@ -41,7 +41,7 @@ class TransformerModel(Model):
|
||||
n_jobs=10,
|
||||
GPU=0,
|
||||
seed=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
# set hyper-parameters.
|
||||
self.d_model = d_model
|
||||
|
||||
@@ -28,7 +28,7 @@ class XGBModel(Model, FeatureInt):
|
||||
verbose_eval=20,
|
||||
evals_result=dict(),
|
||||
reweighter=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
df_train, df_valid = dataset.prepare(
|
||||
["train", "valid"],
|
||||
@@ -63,7 +63,7 @@ class XGBModel(Model, FeatureInt):
|
||||
early_stopping_rounds=early_stopping_rounds,
|
||||
verbose_eval=verbose_eval,
|
||||
evals_result=evals_result,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
)
|
||||
evals_result["train"] = list(evals_result["train"].values())[0]
|
||||
evals_result["valid"] = list(evals_result["valid"].values())[0]
|
||||
|
||||
@@ -4,10 +4,10 @@
|
||||
# pylint: skip-file
|
||||
# flake8: noqa
|
||||
|
||||
import yaml
|
||||
import pathlib
|
||||
import pandas as pd
|
||||
import shutil
|
||||
from ruamel.yaml import YAML
|
||||
from ...backtest.account import Account
|
||||
from .user import User
|
||||
from .utils import load_instance, save_instance
|
||||
@@ -110,7 +110,8 @@ class UserManager:
|
||||
raise ValueError("User data for {} already exists".format(user_id))
|
||||
|
||||
with config_file.open("r") as fp:
|
||||
config = yaml.safe_load(fp)
|
||||
yaml = YAML(typ="safe", pure=True)
|
||||
config = yaml.load(fp)
|
||||
# load model
|
||||
model = init_instance_by_config(config["model"])
|
||||
|
||||
|
||||
@@ -6,8 +6,8 @@
|
||||
|
||||
import pathlib
|
||||
import pickle
|
||||
import yaml
|
||||
import pandas as pd
|
||||
from ruamel.yaml import YAML
|
||||
from ...data import D
|
||||
from ...config import C
|
||||
from ...log import get_module_logger
|
||||
@@ -91,7 +91,8 @@ def prepare(um, today, user_id, exchange_config=None):
|
||||
dates.append(get_next_trading_date(dates[-1], future=True))
|
||||
if exchange_config:
|
||||
with pathlib.Path(exchange_config).open("r") as fp:
|
||||
exchange_paras = yaml.safe_load(fp)
|
||||
yaml = YAML(typ="safe", pure=True)
|
||||
exchange_paras = yaml.load(fp)
|
||||
else:
|
||||
exchange_paras = {}
|
||||
trade_exchange = Exchange(trade_dates=dates, **exchange_paras)
|
||||
|
||||
@@ -176,7 +176,7 @@ class HeatmapGraph(BaseGraph):
|
||||
x=self._df.columns,
|
||||
y=self._df.index,
|
||||
z=self._df.values.tolist(),
|
||||
**self._graph_kwargs
|
||||
**self._graph_kwargs,
|
||||
)
|
||||
]
|
||||
return _data
|
||||
@@ -213,7 +213,7 @@ class SubplotsGraph:
|
||||
sub_graph_layout: dict = None,
|
||||
sub_graph_data: list = None,
|
||||
subplots_kwargs: dict = None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
|
||||
@@ -355,7 +355,7 @@ class SubplotsGraph:
|
||||
df=self._df.loc[:, [column_name]],
|
||||
name_dict={column_name: temp_name},
|
||||
graph_kwargs=_graph_kwargs,
|
||||
)
|
||||
),
|
||||
)
|
||||
else:
|
||||
raise TypeError()
|
||||
|
||||
@@ -2,11 +2,11 @@
|
||||
# Licensed under the MIT License.
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
from ruamel.yaml import YAML
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import fire
|
||||
import pandas as pd
|
||||
import yaml
|
||||
|
||||
from qlib import auto_init
|
||||
from qlib.log import get_module_logger
|
||||
@@ -117,7 +117,8 @@ class Rolling:
|
||||
|
||||
def _raw_conf(self) -> dict:
|
||||
with self.conf_path.open("r") as f:
|
||||
return yaml.safe_load(f)
|
||||
yaml = YAML(typ="safe", pure=True)
|
||||
return yaml.load(f)
|
||||
|
||||
def _replace_handler_with_cache(self, task: dict):
|
||||
"""
|
||||
|
||||
@@ -4,9 +4,9 @@
|
||||
# pylint: skip-file
|
||||
# flake8: noqa
|
||||
|
||||
import yaml
|
||||
import copy
|
||||
import os
|
||||
from ruamel.yaml import YAML
|
||||
|
||||
|
||||
class TunerConfigManager:
|
||||
@@ -16,7 +16,8 @@ class TunerConfigManager:
|
||||
self.config_path = config_path
|
||||
|
||||
with open(config_path) as fp:
|
||||
config = yaml.safe_load(fp)
|
||||
yaml = YAML(typ="safe", pure=True)
|
||||
config = yaml.load(fp)
|
||||
self.config = copy.deepcopy(config)
|
||||
|
||||
self.pipeline_ex_config = PipelineExperimentConfig(config.get("experiment", dict()), self)
|
||||
|
||||
@@ -41,6 +41,7 @@ class DataLoader(abc.ABC):
|
||||
----------
|
||||
instruments : str or dict
|
||||
it can either be the market name or the config file of instruments generated by InstrumentProvider.
|
||||
If the value of instruments is None, it means that no filtering is done.
|
||||
start_time : str
|
||||
start of the time range.
|
||||
end_time : str
|
||||
@@ -50,6 +51,11 @@ class DataLoader(abc.ABC):
|
||||
-------
|
||||
pd.DataFrame:
|
||||
data load from the under layer source
|
||||
|
||||
Raise
|
||||
-----
|
||||
KeyError:
|
||||
if the instruments filter is not supported, raise KeyError
|
||||
"""
|
||||
|
||||
|
||||
@@ -320,7 +326,13 @@ class NestedDataLoader(DataLoader):
|
||||
def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame:
|
||||
df_full = None
|
||||
for dl in self.data_loader_l:
|
||||
df_current = dl.load(instruments, start_time, end_time)
|
||||
try:
|
||||
df_current = dl.load(instruments, start_time, end_time)
|
||||
except KeyError:
|
||||
warnings.warn(
|
||||
"If the value of `instruments` cannot be processed, it will set instruments to None to get all the data."
|
||||
)
|
||||
df_current = dl.load(instruments=None, start_time=start_time, end_time=end_time)
|
||||
if df_full is None:
|
||||
df_full = df_current
|
||||
else:
|
||||
|
||||
@@ -104,15 +104,24 @@ class HashingStockStorage(BaseHandlerStorage):
|
||||
"""
|
||||
|
||||
stock_selector = slice(None)
|
||||
time_selector = slice(None) # by default not filter by time.
|
||||
|
||||
if level is None:
|
||||
# For directly applying.
|
||||
if isinstance(selector, tuple) and self.stock_level < len(selector):
|
||||
# full selector format
|
||||
stock_selector = selector[self.stock_level]
|
||||
time_selector = selector[1 - self.stock_level]
|
||||
elif isinstance(selector, (list, str)) and self.stock_level == 0:
|
||||
# only stock selector
|
||||
stock_selector = selector
|
||||
elif level in ("instrument", self.stock_level):
|
||||
if isinstance(selector, tuple):
|
||||
# NOTE: How could the stock level selector be a tuple?
|
||||
stock_selector = selector[0]
|
||||
raise TypeError(
|
||||
"I forget why would this case appear. But I think it does not make sense. So we raise a error for that case."
|
||||
)
|
||||
elif isinstance(selector, (list, str)):
|
||||
stock_selector = selector
|
||||
|
||||
@@ -120,7 +129,7 @@ class HashingStockStorage(BaseHandlerStorage):
|
||||
raise TypeError(f"stock selector must be type str|list, or slice(None), rather than {stock_selector}")
|
||||
|
||||
if stock_selector == slice(None):
|
||||
return self.hash_df
|
||||
return self.hash_df, time_selector
|
||||
|
||||
if isinstance(stock_selector, str):
|
||||
stock_selector = [stock_selector]
|
||||
@@ -129,7 +138,7 @@ class HashingStockStorage(BaseHandlerStorage):
|
||||
for each_stock in sorted(stock_selector):
|
||||
if each_stock in self.hash_df:
|
||||
select_dict[each_stock] = self.hash_df[each_stock]
|
||||
return select_dict
|
||||
return select_dict, time_selector
|
||||
|
||||
def fetch(
|
||||
self,
|
||||
@@ -138,10 +147,13 @@ class HashingStockStorage(BaseHandlerStorage):
|
||||
col_set: Union[str, List[str]] = DataHandler.CS_ALL,
|
||||
fetch_orig: bool = True,
|
||||
) -> pd.DataFrame:
|
||||
fetch_stock_df_list = list(self._fetch_hash_df_by_stock(selector=selector, level=level).values())
|
||||
fetch_stock_df_list, time_selector = self._fetch_hash_df_by_stock(selector=selector, level=level)
|
||||
fetch_stock_df_list = list(fetch_stock_df_list.values())
|
||||
for _index, stock_df in enumerate(fetch_stock_df_list):
|
||||
fetch_col_df = fetch_df_by_col(df=stock_df, col_set=col_set)
|
||||
fetch_index_df = fetch_df_by_index(df=fetch_col_df, selector=selector, level=level, fetch_orig=fetch_orig)
|
||||
fetch_index_df = fetch_df_by_index(
|
||||
df=fetch_col_df, selector=time_selector, level="datetime", fetch_orig=fetch_orig
|
||||
)
|
||||
fetch_stock_df_list[_index] = fetch_index_df
|
||||
if len(fetch_stock_df_list) == 0:
|
||||
index_names = ("instrument", "datetime") if self.stock_level == 0 else ("datetime", "instrument")
|
||||
|
||||
@@ -164,6 +164,7 @@ class SeriesDFilter(BaseDFilter):
|
||||
timestamp = []
|
||||
_lbool = None
|
||||
_ltime = None
|
||||
_cur_start = None
|
||||
for _ts, _bool in timestamp_series.items():
|
||||
# there is likely to be NAN when the filter series don't have the
|
||||
# bool value, so we just change the NAN into False
|
||||
|
||||
@@ -7,8 +7,7 @@ import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
from importlib import import_module
|
||||
|
||||
import yaml
|
||||
from ruamel.yaml import YAML
|
||||
|
||||
|
||||
DELETE_KEY = "_delete_"
|
||||
@@ -57,7 +56,8 @@ def parse_backtest_config(path: str) -> dict:
|
||||
del sys.modules[tmp_module_name]
|
||||
else:
|
||||
with open(tmp_config_file.name) as input_stream:
|
||||
config = yaml.safe_load(input_stream)
|
||||
yaml = YAML(typ="safe", pure=True)
|
||||
config = yaml.load(input_stream)
|
||||
|
||||
if "_base_" in config:
|
||||
base_file_name = config.pop("_base_")
|
||||
|
||||
@@ -8,12 +8,12 @@ import random
|
||||
import sys
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from ruamel.yaml import YAML
|
||||
from typing import cast, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
import yaml
|
||||
from qlib.backtest import Order
|
||||
from qlib.backtest.decision import OrderDir
|
||||
from qlib.constant import ONE_MIN
|
||||
@@ -263,6 +263,7 @@ if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
|
||||
with open(args.config_path, "r") as input_stream:
|
||||
config = yaml.safe_load(input_stream)
|
||||
yaml = YAML(typ="safe", pure=True)
|
||||
config = yaml.load(input_stream)
|
||||
|
||||
main(config, run_training=not args.no_training, run_backtest=args.run_backtest)
|
||||
|
||||
@@ -10,7 +10,6 @@ import os
|
||||
import re
|
||||
import copy
|
||||
import json
|
||||
import yaml
|
||||
import redis
|
||||
import bisect
|
||||
import struct
|
||||
@@ -25,6 +24,7 @@ import pandas as pd
|
||||
from pathlib import Path
|
||||
from typing import List, Union, Optional, Callable
|
||||
from packaging import version
|
||||
from ruamel.yaml import YAML
|
||||
from .file import (
|
||||
get_or_create_path,
|
||||
save_multiple_parts_file,
|
||||
@@ -244,12 +244,13 @@ def parse_config(config):
|
||||
if not isinstance(config, str):
|
||||
return config
|
||||
# Check whether config is file
|
||||
yaml = YAML(typ="safe", pure=True)
|
||||
if os.path.exists(config):
|
||||
with open(config, "r") as f:
|
||||
return yaml.safe_load(f)
|
||||
return yaml.load(f)
|
||||
# Check whether the str can be parsed
|
||||
try:
|
||||
return yaml.safe_load(config)
|
||||
return yaml.load(config)
|
||||
except BaseException as base_exp:
|
||||
raise ValueError("cannot parse config!") from base_exp
|
||||
|
||||
@@ -799,6 +800,7 @@ def fill_placeholder(config: dict, config_extend: dict):
|
||||
)
|
||||
return value
|
||||
|
||||
item_keys = None
|
||||
while top < tail:
|
||||
now_item = item_queue[top]
|
||||
top += 1
|
||||
|
||||
@@ -44,7 +44,7 @@ def concat(data_list: Union[SingleData], axis=0) -> MultiData:
|
||||
all_index_map = dict(zip(all_index, range(len(all_index))))
|
||||
|
||||
# concat all
|
||||
tmp_data = np.full((len(all_index), len(data_list)), np.NaN)
|
||||
tmp_data = np.full((len(all_index), len(data_list)), np.nan)
|
||||
for data_id, index_data in enumerate(data_list):
|
||||
assert isinstance(index_data, SingleData)
|
||||
now_data_map = [all_index_map[index] for index in index_data.index]
|
||||
@@ -64,7 +64,7 @@ def sum_by_index(data_list: Union[SingleData], new_index: list, fill_value=0) ->
|
||||
new_index : list
|
||||
the new_index of new SingleData.
|
||||
fill_value : float
|
||||
fill the missing values or replace np.NaN.
|
||||
fill the missing values or replace np.nan.
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -444,7 +444,7 @@ class IndexData(metaclass=index_data_ops_creator):
|
||||
return self.__class__(~self.data.astype(bool), *self.indices)
|
||||
|
||||
def abs(self):
|
||||
"""get the abs of data except np.NaN."""
|
||||
"""get the abs of data except np.nan."""
|
||||
tmp_data = np.absolute(self.data)
|
||||
return self.__class__(tmp_data, *self.indices)
|
||||
|
||||
@@ -566,8 +566,8 @@ class SingleData(IndexData):
|
||||
f"The indexes of self and other do not meet the requirements of the four arithmetic operations"
|
||||
)
|
||||
|
||||
def reindex(self, index: Index, fill_value=np.NaN) -> SingleData:
|
||||
"""reindex data and fill the missing value with np.NaN.
|
||||
def reindex(self, index: Index, fill_value=np.nan) -> SingleData:
|
||||
"""reindex data and fill the missing value with np.nan.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -615,7 +615,7 @@ class SingleData(IndexData):
|
||||
return pd.Series(self.data, index=self.index)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return str(pd.Series(self.data, index=self.index))
|
||||
return str(pd.Series(self.data, index=self.index.tolist()))
|
||||
|
||||
|
||||
class MultiData(IndexData):
|
||||
@@ -651,4 +651,4 @@ class MultiData(IndexData):
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return str(pd.DataFrame(self.data, index=self.index, columns=self.columns))
|
||||
return str(pd.DataFrame(self.data, index=self.index.tolist(), columns=self.columns.tolist()))
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import threading
|
||||
from functools import partial
|
||||
from threading import Thread
|
||||
from typing import Callable, Text, Union
|
||||
@@ -9,7 +10,7 @@ from joblib import Parallel, delayed
|
||||
from joblib._parallel_backends import MultiprocessingBackend
|
||||
import pandas as pd
|
||||
|
||||
from queue import Queue
|
||||
from queue import Empty, Queue
|
||||
import concurrent
|
||||
|
||||
from qlib.config import C, QlibConfig
|
||||
@@ -85,7 +86,17 @@ class AsyncCaller:
|
||||
|
||||
def run(self):
|
||||
while True:
|
||||
data = self._q.get()
|
||||
# NOTE:
|
||||
# atexit will only trigger when all the threads ended. So it may results in deadlock.
|
||||
# So the child-threading should actively watch the status of main threading to stop itself.
|
||||
main_thread = threading.main_thread()
|
||||
if not main_thread.is_alive():
|
||||
break
|
||||
try:
|
||||
data = self._q.get(timeout=1)
|
||||
except Empty:
|
||||
# NOTE: avoid deadlock. make checking main thread possible
|
||||
continue
|
||||
if data == self.STOP_MARK:
|
||||
break
|
||||
data()
|
||||
|
||||
@@ -7,7 +7,7 @@ import sys
|
||||
|
||||
import fire
|
||||
from jinja2 import Template, meta
|
||||
import ruamel.yaml as yaml
|
||||
from ruamel.yaml import YAML
|
||||
|
||||
import qlib
|
||||
from qlib.config import C
|
||||
@@ -104,7 +104,8 @@ def workflow(config_path, experiment_name="workflow", uri_folder="mlruns"):
|
||||
"""
|
||||
# Render the template
|
||||
rendered_yaml = render_template(config_path)
|
||||
config = yaml.safe_load(rendered_yaml)
|
||||
yaml = YAML(typ="safe", pure=True)
|
||||
config = yaml.load(rendered_yaml)
|
||||
|
||||
base_config_path = config.get("BASE_CONFIG_PATH", None)
|
||||
if base_config_path:
|
||||
@@ -126,7 +127,8 @@ def workflow(config_path, experiment_name="workflow", uri_folder="mlruns"):
|
||||
raise FileNotFoundError(f"Can't find the BASE_CONFIG file: {base_config_path}")
|
||||
|
||||
with open(path) as fp:
|
||||
base_config = yaml.safe_load(fp)
|
||||
yaml = YAML(typ="safe", pure=True)
|
||||
base_config = yaml.load(fp)
|
||||
logger.info(f"Load BASE_CONFIG_PATH succeed: {path.resolve()}")
|
||||
config = update_config(base_config, config)
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ from mlflow.exceptions import MlflowException, RESOURCE_ALREADY_EXISTS, ErrorCod
|
||||
from mlflow.entities import ViewType
|
||||
import os
|
||||
from typing import Optional, Text
|
||||
from pathlib import Path
|
||||
|
||||
from .exp import MLflowExperiment, Experiment
|
||||
from ..config import C
|
||||
@@ -233,7 +234,7 @@ class ExpManager:
|
||||
# So we supported it in the interface wrapper
|
||||
pr = urlparse(self.uri)
|
||||
if pr.scheme == "file":
|
||||
with FileLock(os.path.join(pr.netloc, pr.path, "filelock")): # pylint: disable=E0110
|
||||
with FileLock(Path(os.path.join(pr.netloc, pr.path.lstrip("/"), "filelock"))): # pylint: disable=E0110
|
||||
return self.create_exp(experiment_name), True
|
||||
# NOTE: for other schemes like http, we double check to avoid create exp conflicts
|
||||
try:
|
||||
@@ -421,7 +422,11 @@ class MLflowExpManager(ExpManager):
|
||||
|
||||
def list_experiments(self):
|
||||
# retrieve all the existing experiments
|
||||
exps = self.client.list_experiments(view_type=ViewType.ACTIVE_ONLY)
|
||||
mlflow_version = int(mlflow.__version__.split(".", maxsplit=1)[0])
|
||||
if mlflow_version >= 2:
|
||||
exps = self.client.search_experiments(view_type=ViewType.ACTIVE_ONLY)
|
||||
else:
|
||||
exps = self.client.list_experiments(view_type=ViewType.ACTIVE_ONLY) # pylint: disable=E1101
|
||||
experiments = dict()
|
||||
for exp in exps:
|
||||
experiment = MLflowExperiment(exp.experiment_id, exp.name, self.uri)
|
||||
|
||||
@@ -9,6 +9,7 @@ import shutil
|
||||
import pickle
|
||||
import tempfile
|
||||
import subprocess
|
||||
import platform
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
|
||||
@@ -316,7 +317,10 @@ class MLflowRecorder(Recorder):
|
||||
This function will return the directory path of this recorder.
|
||||
"""
|
||||
if self.artifact_uri is not None:
|
||||
local_dir_path = Path(self.artifact_uri.lstrip("file:")) / ".."
|
||||
if platform.system() == "Windows":
|
||||
local_dir_path = Path(self.artifact_uri.lstrip("file:").lstrip("/")).parent
|
||||
else:
|
||||
local_dir_path = Path(self.artifact_uri.lstrip("file:")).parent
|
||||
local_dir_path = str(local_dir_path.resolve())
|
||||
if os.path.isdir(local_dir_path):
|
||||
return local_dir_path
|
||||
|
||||
@@ -28,7 +28,7 @@ termcolor==1.1.0
|
||||
tqdm==4.63.0
|
||||
trio==0.20.0
|
||||
trio-websocket==0.9.2
|
||||
urllib3==1.26.8
|
||||
urllib3==1.26.19
|
||||
wget==3.2
|
||||
wsproto==1.1.0
|
||||
yahooquery==2.2.15
|
||||
|
||||
195
setup.py
195
setup.py
@@ -1,9 +1,6 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
import os
|
||||
from setuptools import setup, Extension
|
||||
import numpy
|
||||
|
||||
from setuptools import find_packages, setup, Extension
|
||||
import os
|
||||
|
||||
|
||||
def read(rel_path: str) -> str:
|
||||
@@ -20,185 +17,25 @@ def get_version(rel_path: str) -> str:
|
||||
raise RuntimeError("Unable to find version string.")
|
||||
|
||||
|
||||
# Package meta-data.
|
||||
NAME = "pyqlib"
|
||||
DESCRIPTION = "A Quantitative-research Platform"
|
||||
REQUIRES_PYTHON = ">=3.5.0"
|
||||
NUMPY_INCLUDE = numpy.get_include()
|
||||
|
||||
VERSION = get_version("qlib/__init__.py")
|
||||
|
||||
# Detect Cython
|
||||
try:
|
||||
import Cython
|
||||
|
||||
ver = Cython.__version__
|
||||
_CYTHON_INSTALLED = ver >= "0.28"
|
||||
except ImportError:
|
||||
_CYTHON_INSTALLED = False
|
||||
|
||||
if not _CYTHON_INSTALLED:
|
||||
print("Required Cython version >= 0.28 is not detected!")
|
||||
print('Please run "pip install --upgrade cython" first.')
|
||||
exit(-1)
|
||||
|
||||
# What packages are required for this module to be executed?
|
||||
# `estimator` may depend on other packages. In order to reduce dependencies, it is not written here.
|
||||
REQUIRED = [
|
||||
"numpy>=1.12.0, <1.24",
|
||||
"pandas>=0.25.1",
|
||||
"scipy>=1.7.3",
|
||||
"requests>=2.18.0",
|
||||
"sacred>=0.7.4",
|
||||
"python-socketio",
|
||||
"redis>=3.0.1",
|
||||
"python-redis-lock>=3.3.1",
|
||||
"schedule>=0.6.0",
|
||||
"cvxpy>=1.0.21",
|
||||
"hyperopt==0.1.2",
|
||||
"fire>=0.3.1",
|
||||
"statsmodels",
|
||||
"xlrd>=1.0.0",
|
||||
"plotly>=4.12.0",
|
||||
"matplotlib>=3.3",
|
||||
"tables>=3.6.1",
|
||||
"pyyaml>=5.3.1",
|
||||
# To ensure stable operation of the experiment manager, we have limited the version of mlflow,
|
||||
# and we need to verify whether version 2.0 of mlflow can serve qlib properly.
|
||||
"mlflow>=1.12.1, <=1.30.0",
|
||||
# mlflow 1.30.0 requires packaging<22, so we limit the packaging version, otherwise the CI will fail.
|
||||
"packaging<22",
|
||||
"tqdm",
|
||||
"loguru",
|
||||
"lightgbm>=3.3.0",
|
||||
"tornado",
|
||||
"joblib>=0.17.0",
|
||||
# With the upgrading of ruamel.yaml to 0.18, the safe_load method was deprecated,
|
||||
# which would cause qlib.workflow.cli to not work properly,
|
||||
# and no good replacement has been found, so the version of ruamel.yaml has been restricted for now.
|
||||
# Refs: https://pypi.org/project/ruamel.yaml/
|
||||
"ruamel.yaml<=0.17.36",
|
||||
"pymongo==3.7.2", # For task management
|
||||
"scikit-learn>=0.22",
|
||||
"dill",
|
||||
"dataclasses;python_version<'3.7'",
|
||||
"filelock",
|
||||
"jinja2",
|
||||
"gym",
|
||||
# Installing the latest version of protobuf for python versions below 3.8 will cause unit tests to fail.
|
||||
"protobuf<=3.20.1;python_version<='3.8'",
|
||||
"cryptography",
|
||||
]
|
||||
|
||||
# Numpy include
|
||||
NUMPY_INCLUDE = numpy.get_include()
|
||||
|
||||
here = os.path.abspath(os.path.dirname(__file__))
|
||||
|
||||
with open(os.path.join(here, "README.md"), encoding="utf-8") as f:
|
||||
long_description = f.read()
|
||||
|
||||
|
||||
# Cython Extensions
|
||||
extensions = [
|
||||
Extension(
|
||||
"qlib.data._libs.rolling",
|
||||
["qlib/data/_libs/rolling.pyx"],
|
||||
language="c++",
|
||||
include_dirs=[NUMPY_INCLUDE],
|
||||
),
|
||||
Extension(
|
||||
"qlib.data._libs.expanding",
|
||||
["qlib/data/_libs/expanding.pyx"],
|
||||
language="c++",
|
||||
include_dirs=[NUMPY_INCLUDE],
|
||||
),
|
||||
]
|
||||
|
||||
# Where the magic happens:
|
||||
setup(
|
||||
name=NAME,
|
||||
version=VERSION,
|
||||
license="MIT Licence",
|
||||
url="https://github.com/microsoft/qlib",
|
||||
description=DESCRIPTION,
|
||||
long_description=long_description,
|
||||
long_description_content_type="text/markdown",
|
||||
python_requires=REQUIRES_PYTHON,
|
||||
packages=find_packages(exclude=("tests",)),
|
||||
# if your package is a single module, use this instead of 'packages':
|
||||
# py_modules=['qlib'],
|
||||
entry_points={
|
||||
# 'console_scripts': ['mycli=mymodule:cli'],
|
||||
"console_scripts": [
|
||||
"qrun=qlib.workflow.cli:run",
|
||||
],
|
||||
},
|
||||
ext_modules=extensions,
|
||||
install_requires=REQUIRED,
|
||||
extras_require={
|
||||
"dev": [
|
||||
"coverage",
|
||||
"pytest>=3",
|
||||
"sphinx",
|
||||
"sphinx_rtd_theme",
|
||||
"pre-commit",
|
||||
# CI dependencies
|
||||
"wheel",
|
||||
"setuptools",
|
||||
"black",
|
||||
# Version 3.0 of pylint had problems with the build process, so we limited the version of pylint.
|
||||
"pylint<=2.17.6",
|
||||
# Using the latest versions(0.981 and 0.982) of mypy,
|
||||
# the error "multiprocessing.Value()" is detected in the file "qlib/rl/utils/data_queue.py",
|
||||
# If this is fixed in a subsequent version of mypy, then we will revert to the latest version of mypy.
|
||||
# References: https://github.com/python/typeshed/issues/8799
|
||||
"mypy<0.981",
|
||||
"flake8",
|
||||
"nbqa",
|
||||
"jupyter",
|
||||
"nbconvert",
|
||||
# The 5.0.0 version of importlib-metadata removed the deprecated endpoint,
|
||||
# which prevented flake8 from working properly, so we restricted the version of importlib-metadata.
|
||||
# To help ensure the dependencies of flake8 https://github.com/python/importlib_metadata/issues/406
|
||||
"importlib-metadata<5.0.0",
|
||||
"readthedocs_sphinx_ext",
|
||||
"cmake",
|
||||
"lxml",
|
||||
"baostock",
|
||||
"yahooquery",
|
||||
# 2024-05-30 scs has released a new version: 3.2.4.post2,
|
||||
# this version, causes qlib installation to fail, so we've limited the scs version a bit for now.
|
||||
"scs<=3.2.4",
|
||||
"beautifulsoup4",
|
||||
# In version 0.4.11 of tianshou, the code:
|
||||
# logits, hidden = self.actor(batch.obs, state=state, info=batch.info)
|
||||
# was changed in PR787,
|
||||
# which causes pytest errors(AttributeError: 'dict' object has no attribute 'info') in CI,
|
||||
# so we restricted the version of tianshou.
|
||||
# References:
|
||||
# https://github.com/thu-ml/tianshou/releases
|
||||
"tianshou<=0.4.10",
|
||||
"gym>=0.24", # If you do not put gym at the end, gym will degrade causing pytest results to fail.
|
||||
],
|
||||
"rl": [
|
||||
"tianshou<=0.4.10",
|
||||
"torch",
|
||||
],
|
||||
},
|
||||
include_package_data=True,
|
||||
classifiers=[
|
||||
# Trove classifiers
|
||||
# Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers
|
||||
# 'License :: OSI Approved :: MIT License',
|
||||
"Operating System :: POSIX :: Linux",
|
||||
"Operating System :: Microsoft :: Windows",
|
||||
"Operating System :: MacOS",
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Development Status :: 3 - Alpha",
|
||||
"Programming Language :: Python",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.7",
|
||||
"Programming Language :: Python :: 3.8",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
ext_modules=[
|
||||
Extension(
|
||||
"qlib.data._libs.rolling",
|
||||
["qlib/data/_libs/rolling.pyx"],
|
||||
language="c++",
|
||||
include_dirs=[NUMPY_INCLUDE],
|
||||
),
|
||||
Extension(
|
||||
"qlib.data._libs.expanding",
|
||||
["qlib/data/_libs/expanding.pyx"],
|
||||
language="c++",
|
||||
include_dirs=[NUMPY_INCLUDE],
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
@@ -7,14 +7,16 @@ import qlib
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.append(str(Path(__file__).resolve().parent))
|
||||
from qlib.data.dataset.loader import NestedDataLoader
|
||||
from qlib.data.dataset.loader import NestedDataLoader, QlibDataLoader
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
from qlib.contrib.data.loader import Alpha158DL, Alpha360DL
|
||||
from qlib.data import D
|
||||
|
||||
|
||||
class TestDataLoader(unittest.TestCase):
|
||||
|
||||
def test_nested_data_loader(self):
|
||||
qlib.init()
|
||||
qlib.init(kernels=1)
|
||||
nd = NestedDataLoader(
|
||||
dataloader_l=[
|
||||
{
|
||||
@@ -28,7 +30,7 @@ class TestDataLoader(unittest.TestCase):
|
||||
)
|
||||
# Of course you can use StaticDataLoader
|
||||
|
||||
dataset = nd.load()
|
||||
dataset = nd.load(start_time="2020-01-01", end_time="2020-01-31")
|
||||
|
||||
assert dataset is not None
|
||||
|
||||
@@ -44,6 +46,35 @@ class TestDataLoader(unittest.TestCase):
|
||||
assert "LABEL0" in columns_list
|
||||
|
||||
# Then you can use it wth DataHandler;
|
||||
# NOTE: please note that the data processors are missing!!! You should add based on your requirements
|
||||
|
||||
"""
|
||||
dataset.to_pickle("test_df.pkl")
|
||||
nested_data_loader = NestedDataLoader(
|
||||
dataloader_l=[
|
||||
{
|
||||
"class": "qlib.contrib.data.loader.Alpha158DL",
|
||||
"kwargs": {"config": {"label": (["Ref($close, -2)/Ref($close, -1) - 1"], ["LABEL0"])}},
|
||||
},
|
||||
{
|
||||
"class": "qlib.contrib.data.loader.Alpha360DL",
|
||||
},
|
||||
{
|
||||
"class": "qlib.data.dataset.loader.StaticDataLoader",
|
||||
"kwargs": {"config": "test_df.pkl"},
|
||||
},
|
||||
]
|
||||
)
|
||||
data_handler_config = {
|
||||
"start_time": "2008-01-01",
|
||||
"end_time": "2020-08-01",
|
||||
"instruments": "csi300",
|
||||
"data_loader": nested_data_loader,
|
||||
}
|
||||
data_handler = DataHandlerLP(**data_handler_config)
|
||||
data = data_handler.fetch()
|
||||
print(data)
|
||||
"""
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
import unittest
|
||||
import platform
|
||||
import mlflow
|
||||
import time
|
||||
from pathlib import Path
|
||||
@@ -26,7 +27,10 @@ class MLflowTest(unittest.TestCase):
|
||||
_ = mlflow.tracking.MlflowClient(tracking_uri=str(self.TMP_PATH))
|
||||
end = time.time()
|
||||
elapsed = end - start
|
||||
self.assertLess(elapsed, 1e-2) # it can be done in less than 10ms
|
||||
if platform.system() == "Linux":
|
||||
self.assertLess(elapsed, 1e-2) # it can be done in less than 10ms
|
||||
else:
|
||||
self.assertLess(elapsed, 2e-2)
|
||||
print(elapsed)
|
||||
|
||||
|
||||
|
||||
@@ -70,7 +70,7 @@ class IndexDataTest(unittest.TestCase):
|
||||
print(sd.loc[:"c"])
|
||||
|
||||
def test_corner_cases(self):
|
||||
sd = idd.MultiData([[1, 2], [3, np.NaN]], index=["foo", "bar"], columns=["f", "g"])
|
||||
sd = idd.MultiData([[1, 2], [3, np.nan]], index=["foo", "bar"], columns=["f", "g"])
|
||||
print(sd)
|
||||
|
||||
self.assertTrue(np.isnan(sd.loc["bar", "g"]))
|
||||
|
||||
@@ -1,15 +1,17 @@
|
||||
|
||||
import unittest
|
||||
|
||||
from qlib.contrib.model.pytorch_general_nn import GeneralPTNN
|
||||
from qlib.data.dataset import DatasetH, TSDatasetH
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
from qlib.tests import TestAutoData
|
||||
|
||||
|
||||
class TestNN(TestAutoData):
|
||||
|
||||
def test_both_dataset(self):
|
||||
try:
|
||||
from qlib.contrib.model.pytorch_general_nn import GeneralPTNN
|
||||
from qlib.data.dataset import DatasetH, TSDatasetH
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
except ImportError:
|
||||
print("Import error.")
|
||||
return
|
||||
|
||||
data_handler_config = {
|
||||
"start_time": "2008-01-01",
|
||||
"end_time": "2020-08-01",
|
||||
@@ -18,35 +20,24 @@ class TestNN(TestAutoData):
|
||||
"class": "QlibDataLoader", # Assuming QlibDataLoader is a string reference to the class
|
||||
"kwargs": {
|
||||
"config": {
|
||||
"feature": [
|
||||
["$high", "$close", "$low"],
|
||||
["H", "C", "L"]
|
||||
],
|
||||
"label": [
|
||||
["Ref($close, -2)/Ref($close, -1) - 1"],
|
||||
["LABEL0"]
|
||||
]
|
||||
"feature": [["$high", "$close", "$low"], ["H", "C", "L"]],
|
||||
"label": [["Ref($close, -2)/Ref($close, -1) - 1"], ["LABEL0"]],
|
||||
},
|
||||
"freq": "day"
|
||||
}
|
||||
"freq": "day",
|
||||
},
|
||||
},
|
||||
# TODO: processors
|
||||
"learn_processors": [
|
||||
{
|
||||
"class": "DropnaLabel",
|
||||
"class": "DropnaLabel",
|
||||
},
|
||||
{
|
||||
"class": "CSZScoreNorm",
|
||||
"kwargs": {
|
||||
"fields_group": "label"
|
||||
}
|
||||
}
|
||||
]
|
||||
{"class": "CSZScoreNorm", "kwargs": {"fields_group": "label"}},
|
||||
],
|
||||
}
|
||||
segments = {
|
||||
"train": ["2008-01-01", "2014-12-31"],
|
||||
"valid": ["2015-01-01", "2016-12-31"],
|
||||
"test": ["2017-01-01", "2020-08-01"]
|
||||
"test": ["2017-01-01", "2020-08-01"],
|
||||
}
|
||||
data_handler = DataHandlerLP(**data_handler_config)
|
||||
|
||||
@@ -59,27 +50,30 @@ class TestNN(TestAutoData):
|
||||
model_l = [
|
||||
GeneralPTNN(
|
||||
n_epochs=2,
|
||||
batch_size=32,
|
||||
n_jobs=0,
|
||||
pt_model_uri="qlib.contrib.model.pytorch_gru_ts.GRUModel",
|
||||
pt_model_kwargs={
|
||||
"d_feat":3,
|
||||
"hidden_size":8,
|
||||
"num_layers":1,
|
||||
"dropout":0.,
|
||||
"d_feat": 3,
|
||||
"hidden_size": 8,
|
||||
"num_layers": 1,
|
||||
"dropout": 0.0,
|
||||
},
|
||||
),
|
||||
GeneralPTNN(
|
||||
n_epochs=2,
|
||||
batch_size=32,
|
||||
n_jobs=0,
|
||||
pt_model_uri="qlib.contrib.model.pytorch_nn.Net", # it is a MLP
|
||||
pt_model_kwargs={
|
||||
"input_dim":3,
|
||||
"input_dim": 3,
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
for ds, model in reversed(list(zip((tsds, tbds), model_l))):
|
||||
|
||||
for ds, model in list(zip((tsds, tbds), model_l)):
|
||||
model.fit(ds) # It works
|
||||
model.predict(ds) # It works
|
||||
break
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from random import randint, choice
|
||||
from pathlib import Path
|
||||
import logging
|
||||
|
||||
import re
|
||||
from typing import Any, Tuple
|
||||
@@ -69,6 +69,10 @@ class AnyPolicy(BasePolicy):
|
||||
|
||||
def test_simple_env_logger(caplog):
|
||||
set_log_with_config(C.logging_config)
|
||||
# In order for caplog to capture log messages, we configure it here:
|
||||
# allow logs from the qlib logger to be passed to the parent logger.
|
||||
C.logging_config["loggers"]["qlib"]["propagate"] = True
|
||||
logging.config.dictConfig(C.logging_config)
|
||||
for venv_cls_name in ["dummy", "shmem", "subproc"]:
|
||||
writer = ConsoleWriter()
|
||||
csv_writer = CsvWriter(Path(__file__).parent / ".output")
|
||||
@@ -80,13 +84,12 @@ def test_simple_env_logger(caplog):
|
||||
output_file = pd.read_csv(Path(__file__).parent / ".output" / "result.csv")
|
||||
assert output_file.columns.tolist() == ["reward", "a", "c"]
|
||||
assert len(output_file) >= 30
|
||||
|
||||
line_counter = 0
|
||||
for line in caplog.text.splitlines():
|
||||
line = line.strip()
|
||||
if line:
|
||||
line_counter += 1
|
||||
assert re.match(r".*reward .* a .* \((4|5|6)\.\d+\) c .* \((14|15|16)\.\d+\)", line)
|
||||
assert re.match(r".*reward .* {2}a .* \(([456])\.\d+\) {2}c .* \((14|15|16)\.\d+\)", line)
|
||||
assert line_counter >= 3
|
||||
|
||||
|
||||
@@ -137,15 +140,17 @@ class RandomFivePolicy(BasePolicy):
|
||||
|
||||
def test_logger_with_env_wrapper():
|
||||
with DataQueue(list(range(20)), shuffle=False) as data_iterator:
|
||||
env_wrapper_factory = lambda: EnvWrapper(
|
||||
SimpleSimulator,
|
||||
DummyStateInterpreter(),
|
||||
DummyActionInterpreter(),
|
||||
data_iterator,
|
||||
logger=LogCollector(LogLevel.DEBUG),
|
||||
)
|
||||
|
||||
# loglevel can be debug here because metrics can all dump into csv
|
||||
def env_wrapper_factory():
|
||||
return EnvWrapper(
|
||||
SimpleSimulator,
|
||||
DummyStateInterpreter(),
|
||||
DummyActionInterpreter(),
|
||||
data_iterator,
|
||||
logger=LogCollector(LogLevel.DEBUG),
|
||||
)
|
||||
|
||||
# loglevel can be debugged here because metrics can all dump into csv
|
||||
# otherwise, csv writer might crash
|
||||
csv_writer = CsvWriter(Path(__file__).parent / ".output", loglevel=LogLevel.DEBUG)
|
||||
venv = vectorize_env(env_wrapper_factory, "shmem", 4, csv_writer)
|
||||
@@ -155,7 +160,7 @@ def test_logger_with_env_wrapper():
|
||||
|
||||
output_df = pd.read_csv(Path(__file__).parent / ".output" / "result.csv")
|
||||
assert len(output_df) == 20
|
||||
# obs has a increasing trend
|
||||
# obs has an increasing trend
|
||||
assert output_df["obs"].to_numpy()[:10].sum() < output_df["obs"].to_numpy()[10:].sum()
|
||||
assert (output_df["test_a"] == 233).all()
|
||||
assert (output_df["test_b"] == 200).all()
|
||||
|
||||
@@ -8,7 +8,6 @@ import shutil
|
||||
import unittest
|
||||
import pytest
|
||||
import pandas as pd
|
||||
import baostock as bs
|
||||
from pathlib import Path
|
||||
|
||||
from qlib.data import D
|
||||
|
||||
Reference in New Issue
Block a user