mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-03 02:50:58 +08:00
Compare commits
8 Commits
update_pub
...
extra_df_d
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2366fe1345 | ||
|
|
6c2fa0fc71 | ||
|
|
3b6c227562 | ||
|
|
d2c68e0cc0 | ||
|
|
b4879fc9da | ||
|
|
3f86171051 | ||
|
|
ce596f9dfa | ||
|
|
13768d1dac |
@@ -1,8 +0,0 @@
|
|||||||
__pycache__
|
|
||||||
*.pyc
|
|
||||||
*.pyo
|
|
||||||
*.pyd
|
|
||||||
.Python
|
|
||||||
.env
|
|
||||||
.git
|
|
||||||
|
|
||||||
78
.github/workflows/python-publish.yml
vendored
78
.github/workflows/python-publish.yml
vendored
@@ -3,73 +3,79 @@
|
|||||||
|
|
||||||
name: Upload Python Package
|
name: Upload Python Package
|
||||||
|
|
||||||
# on:
|
|
||||||
# release:
|
|
||||||
# types: [published]
|
|
||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
release:
|
||||||
branches: [ main ]
|
types: [published]
|
||||||
pull_request:
|
|
||||||
branches: [ main ]
|
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
deploy_with_bdist_wheel:
|
deploy_with_bdist_wheel:
|
||||||
runs-on: ${{ matrix.os }}
|
runs-on: ${{ matrix.os }}
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
os: [windows-latest, macos-13, macos-latest, macos-15]
|
os: [windows-latest, macos-11]
|
||||||
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
|
# FIXME: macos-latest will raise error now.
|
||||||
exclude:
|
# not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129
|
||||||
- os: macos-13
|
python-version: [3.7, 3.8]
|
||||||
python-version: "3.11"
|
|
||||||
- os: macos-13
|
|
||||||
python-version: "3.12"
|
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v3
|
- uses: actions/checkout@v2
|
||||||
|
# This is because on macos systems you can install pyqlib using
|
||||||
|
# `pip install pyqlib` installs, it does not recognize the
|
||||||
|
# `pyqlib-<version>-cp38-cp38-macosx_11_0_x86_64.whl` and `pyqlib-<veresion>-cp38-cp37m-macosx_11_0_x86_64.whl`.
|
||||||
|
# So we limit the version of python, in order to generate a version of qlib that is usable for macos: `pyqlib-<veresion>-cp38-cp37m
|
||||||
|
# `pyqlib-<version>-cp38-cp38-macosx_10_15_x86_64.whl` and `pyqlib-<veresion>-cp38-cp37m-macosx_10_15_x86_64.whl`.
|
||||||
|
# Python 3.7.16, 3.8.16 can build macosx_10_15. But Python 3.7.17, 3.8.17 can build macosx_11_0
|
||||||
- name: Set up Python ${{ matrix.python-version }}
|
- name: Set up Python ${{ matrix.python-version }}
|
||||||
uses: actions/setup-python@v4
|
if: matrix.os == 'macos-11' && matrix.python-version == '3.7'
|
||||||
|
uses: actions/setup-python@v2
|
||||||
|
with:
|
||||||
|
python-version: "3.7.16"
|
||||||
|
- name: Set up Python ${{ matrix.python-version }}
|
||||||
|
if: matrix.os == 'macos-11' && matrix.python-version == '3.8'
|
||||||
|
uses: actions/setup-python@v2
|
||||||
|
with:
|
||||||
|
python-version: "3.8.16"
|
||||||
|
- name: Set up Python ${{ matrix.python-version }}
|
||||||
|
if: matrix.os != 'macos-11'
|
||||||
|
uses: actions/setup-python@v2
|
||||||
with:
|
with:
|
||||||
python-version: ${{ matrix.python-version }}
|
python-version: ${{ matrix.python-version }}
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
make dev
|
python -m pip install --upgrade pip
|
||||||
|
pip install setuptools wheel twine
|
||||||
- name: Build wheel on ${{ matrix.os }}
|
- name: Build wheel on ${{ matrix.os }}
|
||||||
run: |
|
run: |
|
||||||
make build
|
pip install numpy
|
||||||
|
pip install cython
|
||||||
|
python setup.py bdist_wheel
|
||||||
- name: Build and publish
|
- name: Build and publish
|
||||||
env:
|
env:
|
||||||
TWINE_USERNAME: __token__
|
TWINE_USERNAME: __token__
|
||||||
TWINE_PASSWORD: ${{ secrets.TESTPYPI_TOKEN }}
|
TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }}
|
||||||
run: |
|
run: |
|
||||||
ls dist
|
twine upload dist/*
|
||||||
twine check dist/*.whl
|
|
||||||
|
|
||||||
deploy_with_manylinux:
|
deploy_with_manylinux:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v3
|
- uses: actions/checkout@v2
|
||||||
- name: Set up Python ${{ matrix.python-version }}
|
|
||||||
uses: actions/setup-python@v4
|
|
||||||
with:
|
|
||||||
python-version: ${{ matrix.python-version }}
|
|
||||||
- name: Build wheel on Linux
|
- name: Build wheel on Linux
|
||||||
uses: RalfG/python-wheels-manylinux-build@v0.7.1-manylinux2014_x86_64
|
uses: RalfG/python-wheels-manylinux-build@v0.3.1-manylinux2010_x86_64
|
||||||
with:
|
with:
|
||||||
# not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129
|
# not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129
|
||||||
python-versions: 'cp38-cp38 cp39-cp39 cp310-cp310 cp311-cp311 cp312-cp312'
|
python-versions: 'cp37-cp37m cp38-cp38'
|
||||||
build-requirements: 'numpy cython'
|
build-requirements: 'numpy cython'
|
||||||
|
- name: Set up Python
|
||||||
|
uses: actions/setup-python@v2
|
||||||
|
with:
|
||||||
|
python-version: 3.7
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
python -m pip install twine
|
pip install twine
|
||||||
python -m pip list
|
|
||||||
- name: Build and publish
|
- name: Build and publish
|
||||||
env:
|
env:
|
||||||
TWINE_USERNAME: __token__
|
TWINE_USERNAME: __token__
|
||||||
TWINE_PASSWORD: ${{ secrets.TESTPYPI_TOKEN }}
|
TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }}
|
||||||
run: |
|
run: |
|
||||||
ls dist
|
twine upload dist/pyqlib-*-manylinux*.whl
|
||||||
twine check dist/*.whl
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
28
.github/workflows/test_qlib_from_pip.yml
vendored
28
.github/workflows/test_qlib_from_pip.yml
vendored
@@ -15,18 +15,26 @@ jobs:
|
|||||||
matrix:
|
matrix:
|
||||||
# Since macos-latest changed from 12.7.4 to 14.4.1,
|
# Since macos-latest changed from 12.7.4 to 14.4.1,
|
||||||
# the minimum python version that matches a 14.4.1 version of macos is 3.10,
|
# the minimum python version that matches a 14.4.1 version of macos is 3.10,
|
||||||
# If you want to use python 3.7 in github action, then the latest macos system version is macos-13,
|
# so we limit the macos version to macos-12.
|
||||||
# after macos-13 python 3.7 is no longer supported.
|
os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-11, macos-12]
|
||||||
# so we limit the macos version to macos-13.
|
|
||||||
os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-13, macos-14, macos-15]
|
|
||||||
# not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129
|
# not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129
|
||||||
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
|
python-version: [3.7, 3.8]
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Test qlib from pip
|
- name: Test qlib from pip
|
||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v3
|
||||||
|
|
||||||
|
# Since version 3.7 of python for MacOS is installed in CI, version 3.7.17, this version causes "_bz not found error".
|
||||||
|
# So we make the version number of python 3.7 for MacOS more specific.
|
||||||
|
# refs: https://github.com/actions/setup-python/issues/682
|
||||||
- name: Set up Python ${{ matrix.python-version }}
|
- name: Set up Python ${{ matrix.python-version }}
|
||||||
|
if: (matrix.os == 'macos-latest' && matrix.python-version == '3.7') || (matrix.os == 'macos-11' && matrix.python-version == '3.7')
|
||||||
|
uses: actions/setup-python@v4
|
||||||
|
with:
|
||||||
|
python-version: "3.7.16"
|
||||||
|
|
||||||
|
- name: Set up Python ${{ matrix.python-version }}
|
||||||
|
if: (matrix.os != 'macos-latest' || matrix.python-version != '3.7') && (matrix.os != 'macos-11' || matrix.python-version != '3.7')
|
||||||
uses: actions/setup-python@v4
|
uses: actions/setup-python@v4
|
||||||
with:
|
with:
|
||||||
python-version: ${{ matrix.python-version }}
|
python-version: ${{ matrix.python-version }}
|
||||||
@@ -37,10 +45,13 @@ jobs:
|
|||||||
|
|
||||||
- name: Qlib installation test
|
- name: Qlib installation test
|
||||||
run: |
|
run: |
|
||||||
python -m pip install --index-url https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple/ pyqlib==0.9.5.80
|
# 2024-05-30 scs has released a new version: 3.2.4.post2,
|
||||||
|
# This will cause the CI to fail, so we have limited the version of scs for now.
|
||||||
|
python -m pip install "scs<=3.2.4"
|
||||||
|
python -m pip install pyqlib
|
||||||
|
|
||||||
- name: Install Lightgbm for MacOS
|
- name: Install Lightgbm for MacOS
|
||||||
if: ${{ matrix.os == 'macos-13' || matrix.os == 'macos-14' || matrix.os == 'macos-15' }}
|
if: ${{ matrix.os == 'macos-11' || matrix.os == 'macos-latest' }}
|
||||||
run: |
|
run: |
|
||||||
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Microsoft/qlib/main/.github/brew_install.sh)"
|
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Microsoft/qlib/main/.github/brew_install.sh)"
|
||||||
HOMEBREW_NO_AUTO_UPDATE=1 brew install lightgbm
|
HOMEBREW_NO_AUTO_UPDATE=1 brew install lightgbm
|
||||||
@@ -57,5 +68,8 @@ jobs:
|
|||||||
cd qlib
|
cd qlib
|
||||||
|
|
||||||
- name: Test workflow by config
|
- name: Test workflow by config
|
||||||
|
# On macos-11 system, it will lead to "Segmentation fault: 11" error,
|
||||||
|
# which may be caused by the excessive memory overhead of macos-11 system, so we disable macos-11 temporarily here.
|
||||||
|
if: ${{ matrix.os != 'macos-11' }}
|
||||||
run: |
|
run: |
|
||||||
qrun examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
qrun examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||||
|
|||||||
185
.github/workflows/test_qlib_from_source.yml
vendored
Normal file
185
.github/workflows/test_qlib_from_source.yml
vendored
Normal file
@@ -0,0 +1,185 @@
|
|||||||
|
name: Test qlib from source
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [ main ]
|
||||||
|
pull_request:
|
||||||
|
branches: [ main ]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
build:
|
||||||
|
timeout-minutes: 180
|
||||||
|
# we may retry for 3 times for `Unit tests with Pytest`
|
||||||
|
|
||||||
|
runs-on: ${{ matrix.os }}
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
# Since macos-latest changed from 12.7.4 to 14.4.1,
|
||||||
|
# the minimum python version that matches a 14.4.1 version of macos is 3.10,
|
||||||
|
# so we limit the macos version to macos-12.
|
||||||
|
os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-11, macos-12]
|
||||||
|
# not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129
|
||||||
|
python-version: [3.7, 3.8]
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Test qlib from source
|
||||||
|
uses: actions/checkout@v3
|
||||||
|
|
||||||
|
# Since version 3.7 of python for MacOS is installed in CI, version 3.7.17, this version causes "_bz not found error".
|
||||||
|
# So we make the version number of python 3.7 for MacOS more specific.
|
||||||
|
# refs: https://github.com/actions/setup-python/issues/682
|
||||||
|
- name: Set up Python ${{ matrix.python-version }}
|
||||||
|
if: (matrix.os == 'macos-latest' && matrix.python-version == '3.7') || (matrix.os == 'macos-11' && matrix.python-version == '3.7')
|
||||||
|
uses: actions/setup-python@v4
|
||||||
|
with:
|
||||||
|
python-version: "3.7.16"
|
||||||
|
|
||||||
|
- name: Set up Python ${{ matrix.python-version }}
|
||||||
|
if: (matrix.os != 'macos-latest' || matrix.python-version != '3.7') && (matrix.os != 'macos-11' || matrix.python-version != '3.7')
|
||||||
|
uses: actions/setup-python@v4
|
||||||
|
with:
|
||||||
|
python-version: ${{ matrix.python-version }}
|
||||||
|
|
||||||
|
- name: Update pip to the latest version
|
||||||
|
run: |
|
||||||
|
python -m pip install --upgrade pip
|
||||||
|
|
||||||
|
- name: Installing pytorch for macos
|
||||||
|
if: ${{ matrix.os == 'macos-11' || matrix.os == 'macos-latest' }}
|
||||||
|
run: |
|
||||||
|
python -m pip install torch torchvision torchaudio
|
||||||
|
|
||||||
|
- name: Installing pytorch for ubuntu
|
||||||
|
if: ${{ matrix.os == 'ubuntu-20.04' || matrix.os == 'ubuntu-22.04' }}
|
||||||
|
run: |
|
||||||
|
python -m pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
|
||||||
|
|
||||||
|
- name: Installing pytorch for windows
|
||||||
|
if: ${{ matrix.os == 'windows-latest' }}
|
||||||
|
run: |
|
||||||
|
python -m pip install torch torchvision torchaudio
|
||||||
|
|
||||||
|
- name: Set up Python tools
|
||||||
|
run: |
|
||||||
|
python -m pip install --upgrade cython
|
||||||
|
python -m pip install -e .[dev]
|
||||||
|
|
||||||
|
- name: Lint with Black
|
||||||
|
# Python 3.7 will use a black with low level. So we use python with higher version for black check
|
||||||
|
if: (matrix.python-version != '3.7')
|
||||||
|
run: |
|
||||||
|
pip install -U black # follow the latest version of black, previous Qlib dependency will downgrade black
|
||||||
|
black . -l 120 --check --diff
|
||||||
|
|
||||||
|
- name: Make html with sphinx
|
||||||
|
# Since read the docs builds on ubuntu 22.04, we only need to test that the build passes on ubuntu 22.04.
|
||||||
|
if: ${{ matrix.os == 'ubuntu-22.04' }}
|
||||||
|
run: |
|
||||||
|
cd docs
|
||||||
|
sphinx-build -W --keep-going -b html . _build
|
||||||
|
cd ..
|
||||||
|
|
||||||
|
# Check Qlib with pylint
|
||||||
|
# TODO: These problems we will solve in the future. Important among them are: W0221, W0223, W0237, E1102
|
||||||
|
# C0103: invalid-name
|
||||||
|
# C0209: consider-using-f-string
|
||||||
|
# R0402: consider-using-from-import
|
||||||
|
# R1705: no-else-return
|
||||||
|
# R1710: inconsistent-return-statements
|
||||||
|
# R1725: super-with-arguments
|
||||||
|
# R1735: use-dict-literal
|
||||||
|
# W0102: dangerous-default-value
|
||||||
|
# W0212: protected-access
|
||||||
|
# W0221: arguments-differ
|
||||||
|
# W0223: abstract-method
|
||||||
|
# W0231: super-init-not-called
|
||||||
|
# W0237: arguments-renamed
|
||||||
|
# W0612: unused-variable
|
||||||
|
# W0621: redefined-outer-name
|
||||||
|
# W0622: redefined-builtin
|
||||||
|
# FIXME: specify exception type
|
||||||
|
# W0703: broad-except
|
||||||
|
# W1309: f-string-without-interpolation
|
||||||
|
# E1102: not-callable
|
||||||
|
# E1136: unsubscriptable-object
|
||||||
|
# References for parameters: https://github.com/PyCQA/pylint/issues/4577#issuecomment-1000245962
|
||||||
|
# We use sys.setrecursionlimit(2000) to make the recursion depth larger to ensure that pylint works properly (the default recursion depth is 1000).
|
||||||
|
- name: Check Qlib with pylint
|
||||||
|
run: |
|
||||||
|
pylint --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0612,W0621,W0622,W0703,W1309,E1102,E1136 --const-rgx='[a-z_][a-z0-9_]{2,30}$' qlib --init-hook "import astroid; astroid.context.InferenceContext.max_inferred = 500; import sys; sys.setrecursionlimit(2000)"
|
||||||
|
pylint --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0246,W0612,W0621,W0622,W0703,W1309,E1102,E1136 --const-rgx='[a-z_][a-z0-9_]{2,30}$' scripts --init-hook "import astroid; astroid.context.InferenceContext.max_inferred = 500; import sys; sys.setrecursionlimit(2000)"
|
||||||
|
|
||||||
|
# The following flake8 error codes were ignored:
|
||||||
|
# E501 line too long
|
||||||
|
# Description: We have used black to limit the length of each line to 120.
|
||||||
|
# F541 f-string is missing placeholders
|
||||||
|
# Description: The same thing is done when using pylint for detection.
|
||||||
|
# E266 too many leading '#' for block comment
|
||||||
|
# Description: To make the code more readable, a lot of "#" is used.
|
||||||
|
# This error code appears centrally in:
|
||||||
|
# qlib/backtest/executor.py
|
||||||
|
# qlib/data/ops.py
|
||||||
|
# qlib/utils/__init__.py
|
||||||
|
# E402 module level import not at top of file
|
||||||
|
# Description: There are times when module level import is not available at the top of the file.
|
||||||
|
# W503 line break before binary operator
|
||||||
|
# Description: Since black formats the length of each line of code, it has to perform a line break when a line of arithmetic is too long.
|
||||||
|
# E731 do not assign a lambda expression, use a def
|
||||||
|
# Description: Restricts the use of lambda expressions, but at some point lambda expressions are required.
|
||||||
|
# E203 whitespace before ':'
|
||||||
|
# Description: If there is whitespace before ":", it cannot pass the black check.
|
||||||
|
- name: Check Qlib with flake8
|
||||||
|
run: |
|
||||||
|
flake8 --ignore=E501,F541,E266,E402,W503,E731,E203 --per-file-ignores="__init__.py:F401,F403" qlib
|
||||||
|
|
||||||
|
# https://github.com/python/mypy/issues/10600
|
||||||
|
- name: Check Qlib with mypy
|
||||||
|
run: |
|
||||||
|
mypy qlib --install-types --non-interactive || true
|
||||||
|
mypy qlib --verbose
|
||||||
|
|
||||||
|
- name: Check Qlib ipynb with nbqa
|
||||||
|
run: |
|
||||||
|
nbqa black . -l 120 --check --diff
|
||||||
|
nbqa pylint . --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0612,W0621,W0622,W0703,W1309,E1102,E1136,W0719,W0104,W0404,C0412,W0611,C0410 --const-rgx='[a-z_][a-z0-9_]{2,30}$'
|
||||||
|
|
||||||
|
- name: Test data downloads
|
||||||
|
run: |
|
||||||
|
python scripts/get_data.py qlib_data --name qlib_data_simple --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
|
||||||
|
python scripts/get_data.py download_data --file_name rl_data.zip --target_dir tests/.data/rl
|
||||||
|
|
||||||
|
- name: Install Lightgbm for MacOS
|
||||||
|
if: ${{ matrix.os == 'macos-11' || matrix.os == 'macos-latest' }}
|
||||||
|
run: |
|
||||||
|
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Microsoft/qlib/main/.github/brew_install.sh)"
|
||||||
|
HOMEBREW_NO_AUTO_UPDATE=1 brew install lightgbm
|
||||||
|
# FIX MacOS error: Segmentation fault
|
||||||
|
# reference: https://github.com/microsoft/LightGBM/issues/4229
|
||||||
|
wget https://raw.githubusercontent.com/Homebrew/homebrew-core/fb8323f2b170bd4ae97e1bac9bf3e2983af3fdb0/Formula/libomp.rb
|
||||||
|
brew unlink libomp
|
||||||
|
brew install libomp.rb
|
||||||
|
|
||||||
|
# Run after data downloads
|
||||||
|
- name: Check Qlib ipynb with nbconvert
|
||||||
|
# Running the nbconvert check on a macos-11 system results in a "Kernel died" error, so we've temporarily disabled macos-11 here.
|
||||||
|
if: ${{ matrix.os != 'macos-11' }}
|
||||||
|
run: |
|
||||||
|
# add more ipynb files in future
|
||||||
|
jupyter nbconvert --to notebook --execute examples/workflow_by_code.ipynb
|
||||||
|
|
||||||
|
- name: Test workflow by config (install from source)
|
||||||
|
# On macos-11 system, it will lead to "Segmentation fault: 11" error,
|
||||||
|
# which may be caused by the excessive memory overhead of macos-11 system, so we disable macos-11 temporarily here.
|
||||||
|
if: ${{ matrix.os != 'macos-11' }}
|
||||||
|
run: |
|
||||||
|
python -m pip install numba
|
||||||
|
python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||||
|
|
||||||
|
- name: Unit tests with Pytest
|
||||||
|
uses: nick-fields/retry@v2
|
||||||
|
with:
|
||||||
|
timeout_minutes: 60
|
||||||
|
max_attempts: 3
|
||||||
|
command: |
|
||||||
|
cd tests
|
||||||
|
python -m pytest . -m "not slow" --durations=0
|
||||||
71
.github/workflows/test_qlib_from_source_slow.yml
vendored
Normal file
71
.github/workflows/test_qlib_from_source_slow.yml
vendored
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
name: Test qlib from source slow
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [ main ]
|
||||||
|
pull_request:
|
||||||
|
branches: [ main ]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
build:
|
||||||
|
timeout-minutes: 720
|
||||||
|
# we may retry for 3 times for `Unit tests with Pytest`
|
||||||
|
|
||||||
|
runs-on: ${{ matrix.os }}
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
# Since macos-latest changed from 12.7.4 to 14.4.1,
|
||||||
|
# the minimum python version that matches a 14.4.1 version of macos is 3.10,
|
||||||
|
# so we limit the macos version to macos-12.
|
||||||
|
os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-11, macos-12]
|
||||||
|
# not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129
|
||||||
|
python-version: [3.7, 3.8]
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Test qlib from source slow
|
||||||
|
uses: actions/checkout@v3
|
||||||
|
|
||||||
|
# Since version 3.7 of python for MacOS is installed in CI, version 3.7.17, this version causes "_bz not found error".
|
||||||
|
# So we make the version number of python 3.7 for MacOS more specific.
|
||||||
|
# refs: https://github.com/actions/setup-python/issues/682
|
||||||
|
- name: Set up Python ${{ matrix.python-version }}
|
||||||
|
if: (matrix.os == 'macos-latest' && matrix.python-version == '3.7') || (matrix.os == 'macos-11' && matrix.python-version == '3.7')
|
||||||
|
uses: actions/setup-python@v4
|
||||||
|
with:
|
||||||
|
python-version: "3.7.16"
|
||||||
|
|
||||||
|
- name: Set up Python ${{ matrix.python-version }}
|
||||||
|
if: (matrix.os != 'macos-latest' || matrix.python-version != '3.7') && (matrix.os != 'macos-11' || matrix.python-version != '3.7')
|
||||||
|
uses: actions/setup-python@v4
|
||||||
|
with:
|
||||||
|
python-version: ${{ matrix.python-version }}
|
||||||
|
|
||||||
|
- name: Set up Python tools
|
||||||
|
run: |
|
||||||
|
python -m pip install --upgrade pip
|
||||||
|
pip install --upgrade cython numpy
|
||||||
|
pip install -e .[dev]
|
||||||
|
|
||||||
|
- name: Downloads dependencies data
|
||||||
|
run: |
|
||||||
|
python scripts/get_data.py qlib_data --name qlib_data_simple --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
|
||||||
|
|
||||||
|
- name: Install Lightgbm for MacOS
|
||||||
|
if: ${{ matrix.os == 'macos-11' || matrix.os == 'macos-latest' }}
|
||||||
|
run: |
|
||||||
|
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Microsoft/qlib/main/.github/brew_install.sh)"
|
||||||
|
HOMEBREW_NO_AUTO_UPDATE=1 brew install lightgbm
|
||||||
|
# FIX MacOS error: Segmentation fault
|
||||||
|
# reference: https://github.com/microsoft/LightGBM/issues/4229
|
||||||
|
wget https://raw.githubusercontent.com/Homebrew/homebrew-core/fb8323f2b170bd4ae97e1bac9bf3e2983af3fdb0/Formula/libomp.rb
|
||||||
|
brew unlink libomp
|
||||||
|
brew install libomp.rb
|
||||||
|
|
||||||
|
- name: Unit tests with Pytest
|
||||||
|
uses: nick-fields/retry@v2
|
||||||
|
with:
|
||||||
|
timeout_minutes: 240
|
||||||
|
max_attempts: 3
|
||||||
|
command: |
|
||||||
|
cd tests
|
||||||
|
python -m pytest . -m "slow" --durations=0
|
||||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -48,5 +48,4 @@ tags
|
|||||||
*.swp
|
*.swp
|
||||||
|
|
||||||
./pretrain
|
./pretrain
|
||||||
.idea/
|
.idea/
|
||||||
.aider*
|
|
||||||
@@ -9,7 +9,7 @@ version: 2
|
|||||||
build:
|
build:
|
||||||
os: ubuntu-22.04
|
os: ubuntu-22.04
|
||||||
tools:
|
tools:
|
||||||
python: "3.8"
|
python: "3.7"
|
||||||
|
|
||||||
# Build documentation in the docs/ directory with Sphinx
|
# Build documentation in the docs/ directory with Sphinx
|
||||||
sphinx:
|
sphinx:
|
||||||
|
|||||||
31
Dockerfile
31
Dockerfile
@@ -1,31 +0,0 @@
|
|||||||
FROM continuumio/miniconda3:latest
|
|
||||||
|
|
||||||
WORKDIR /qlib
|
|
||||||
|
|
||||||
COPY . .
|
|
||||||
|
|
||||||
RUN apt-get update && \
|
|
||||||
apt-get install -y build-essential
|
|
||||||
|
|
||||||
RUN conda create --name qlib_env python=3.8 -y
|
|
||||||
RUN echo "conda activate qlib_env" >> ~/.bashrc
|
|
||||||
ENV PATH /opt/conda/envs/qlib_env/bin:$PATH
|
|
||||||
|
|
||||||
RUN python -m pip install --upgrade pip
|
|
||||||
|
|
||||||
RUN python -m pip install numpy==1.23.5
|
|
||||||
RUN python -m pip install pandas==1.5.3
|
|
||||||
RUN python -m pip install importlib-metadata==5.2.0
|
|
||||||
RUN python -m pip install "cloudpickle<3"
|
|
||||||
RUN python -m pip install scikit-learn==1.3.2
|
|
||||||
|
|
||||||
RUN python -m pip install cython packaging tables matplotlib statsmodels
|
|
||||||
RUN python -m pip install pybind11 cvxpy
|
|
||||||
|
|
||||||
ARG IS_STABLE="yes"
|
|
||||||
|
|
||||||
RUN if [ "$IS_STABLE" = "yes" ]; then \
|
|
||||||
python -m pip install pyqlib; \
|
|
||||||
else \
|
|
||||||
python setup.py install; \
|
|
||||||
fi
|
|
||||||
@@ -1,6 +1 @@
|
|||||||
exclude tests/*
|
include qlib/VERSION.txt
|
||||||
include qlib/*
|
|
||||||
include qlib/*/*
|
|
||||||
include qlib/*/*/*
|
|
||||||
include qlib/*/*/*/*
|
|
||||||
include qlib/*/*/*/*/*
|
|
||||||
|
|||||||
195
Makefile
195
Makefile
@@ -1,195 +0,0 @@
|
|||||||
.PHONY: clean deepclean prerequisite dependencies lightgbm rl develop lint docs package test analysis all install dev black pylint flake8 mypy nbqa nbconvert lint build upload docs-gen
|
|
||||||
#You can modify it according to your terminal
|
|
||||||
SHELL := /bin/bash
|
|
||||||
|
|
||||||
########################################################################################
|
|
||||||
# Variables
|
|
||||||
########################################################################################
|
|
||||||
|
|
||||||
# Documentation target directory, will be adapted to specific folder for readthedocs.
|
|
||||||
PUBLIC_DIR := $(shell [ "$$READTHEDOCS" = "True" ] && echo "$$READTHEDOCS_OUTPUT/html" || echo "public")
|
|
||||||
|
|
||||||
SO_DIR := qlib/data/_libs
|
|
||||||
SO_FILES := $(wildcard $(SO_DIR)/*.so)
|
|
||||||
|
|
||||||
########################################################################################
|
|
||||||
# Development Environment Management
|
|
||||||
########################################################################################
|
|
||||||
# Remove common intermediate files.
|
|
||||||
clean:
|
|
||||||
-rm -rf \
|
|
||||||
$(PUBLIC_DIR) \
|
|
||||||
qlib/data/_libs/*.cpp \
|
|
||||||
qlib/data/_libs/*.so \
|
|
||||||
mlruns \
|
|
||||||
public \
|
|
||||||
build \
|
|
||||||
.coverage \
|
|
||||||
.mypy_cache \
|
|
||||||
.pytest_cache \
|
|
||||||
.ruff_cache \
|
|
||||||
Pipfile* \
|
|
||||||
coverage.xml \
|
|
||||||
dist \
|
|
||||||
release-notes.md
|
|
||||||
|
|
||||||
find . -name '*.egg-info' -print0 | xargs -0 rm -rf
|
|
||||||
find . -name '*.pyc' -print0 | xargs -0 rm -f
|
|
||||||
find . -name '*.swp' -print0 | xargs -0 rm -f
|
|
||||||
find . -name '.DS_Store' -print0 | xargs -0 rm -f
|
|
||||||
find . -name '__pycache__' -print0 | xargs -0 rm -rf
|
|
||||||
|
|
||||||
# Remove pre-commit hook, virtual environment alongside itermediate files.
|
|
||||||
deepclean: clean
|
|
||||||
if command -v pre-commit > /dev/null 2>&1; then pre-commit uninstall --hook-type pre-push; fi
|
|
||||||
if command -v pipenv >/dev/null 2>&1 && pipenv --venv >/dev/null 2>&1; then pipenv --rm; fi
|
|
||||||
|
|
||||||
# Prerequisite section
|
|
||||||
# What this code does is compile two Cython modules, rolling and expanding, using setuptools and Cython,
|
|
||||||
# and builds them as binary expansion modules that can be imported directly into Python.
|
|
||||||
# Since pyproject.toml can't do that, we compile it here.
|
|
||||||
prerequisite:
|
|
||||||
@if [ -n "$(SO_FILES)" ]; then \
|
|
||||||
echo "Shared library files exist, skipping build."; \
|
|
||||||
else \
|
|
||||||
echo "No shared library files found, building..."; \
|
|
||||||
pip install --upgrade setuptools wheel; \
|
|
||||||
python -m pip install cython numpy; \
|
|
||||||
python -c "from setuptools import setup, Extension; from Cython.Build import cythonize; import numpy; extensions = [Extension('qlib.data._libs.rolling', ['qlib/data/_libs/rolling.pyx'], language='c++', include_dirs=[numpy.get_include()]), Extension('qlib.data._libs.expanding', ['qlib/data/_libs/expanding.pyx'], language='c++', include_dirs=[numpy.get_include()])]; setup(ext_modules=cythonize(extensions, language_level='3'), script_args=['build_ext', '--inplace'])"; \
|
|
||||||
fi
|
|
||||||
|
|
||||||
# Install the package in editable mode.
|
|
||||||
dependencies:
|
|
||||||
python -m pip install -e .
|
|
||||||
|
|
||||||
lightgbm:
|
|
||||||
python -m pip install lightgbm --prefer-binary
|
|
||||||
|
|
||||||
rl:
|
|
||||||
python -m pip install -e .[rl]
|
|
||||||
|
|
||||||
develop:
|
|
||||||
python -m pip install -e .[dev]
|
|
||||||
|
|
||||||
lint:
|
|
||||||
python -m pip install -e .[lint]
|
|
||||||
|
|
||||||
docs:
|
|
||||||
python -m pip install -e .[docs]
|
|
||||||
|
|
||||||
package:
|
|
||||||
python -m pip install -e .[package]
|
|
||||||
|
|
||||||
test:
|
|
||||||
python -m pip install -e .[test]
|
|
||||||
|
|
||||||
analysis:
|
|
||||||
python -m pip install -e .[analysis]
|
|
||||||
|
|
||||||
all:
|
|
||||||
python -m pip install -e .[dev,lint,docs,package,test,analysis,rl]
|
|
||||||
|
|
||||||
install: prerequisite dependencies
|
|
||||||
|
|
||||||
dev: prerequisite all
|
|
||||||
|
|
||||||
########################################################################################
|
|
||||||
# Lint and pre-commit
|
|
||||||
########################################################################################
|
|
||||||
|
|
||||||
# Check lint with black.
|
|
||||||
black:
|
|
||||||
black . -l 120 --check --diff
|
|
||||||
|
|
||||||
# Check code folder with pylint.
|
|
||||||
# TODO: These problems we will solve in the future. Important among them are: W0221, W0223, W0237, E1102
|
|
||||||
# C0103: invalid-name
|
|
||||||
# C0209: consider-using-f-string
|
|
||||||
# R0402: consider-using-from-import
|
|
||||||
# R1705: no-else-return
|
|
||||||
# R1710: inconsistent-return-statements
|
|
||||||
# R1725: super-with-arguments
|
|
||||||
# R1735: use-dict-literal
|
|
||||||
# W0102: dangerous-default-value
|
|
||||||
# W0212: protected-access
|
|
||||||
# W0221: arguments-differ
|
|
||||||
# W0223: abstract-method
|
|
||||||
# W0231: super-init-not-called
|
|
||||||
# W0237: arguments-renamed
|
|
||||||
# W0612: unused-variable
|
|
||||||
# W0621: redefined-outer-name
|
|
||||||
# W0622: redefined-builtin
|
|
||||||
# FIXME: specify exception type
|
|
||||||
# W0703: broad-except
|
|
||||||
# W1309: f-string-without-interpolation
|
|
||||||
# E1102: not-callable
|
|
||||||
# E1136: unsubscriptable-object
|
|
||||||
# W4904: deprecated-class
|
|
||||||
# R0917: too-many-positional-arguments
|
|
||||||
# E1123: unexpected-keyword-arg
|
|
||||||
# References for disable error: https://pylint.pycqa.org/en/latest/user_guide/messages/messages_overview.html
|
|
||||||
# We use sys.setrecursionlimit(2000) to make the recursion depth larger to ensure that pylint works properly (the default recursion depth is 1000).
|
|
||||||
# References for parameters: https://github.com/PyCQA/pylint/issues/4577#issuecomment-1000245962
|
|
||||||
pylint:
|
|
||||||
pylint --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R0917,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,W4904,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1730,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0612,W0621,W0622,W0703,W1309,E1102,E1136 --const-rgx='[a-z_][a-z0-9_]{2,30}' qlib --init-hook="import astroid; astroid.context.InferenceContext.max_inferred = 500; import sys; sys.setrecursionlimit(2000)"
|
|
||||||
pylint --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R0917,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,E1123,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0246,W0612,W0621,W0622,W0703,W1309,E1102,E1136 --const-rgx='[a-z_][a-z0-9_]{2,30}' scripts --init-hook="import astroid; astroid.context.InferenceContext.max_inferred = 500; import sys; sys.setrecursionlimit(2000)"
|
|
||||||
|
|
||||||
# Check code with flake8.
|
|
||||||
# The following flake8 error codes were ignored:
|
|
||||||
# E501 line too long
|
|
||||||
# Description: We have used black to limit the length of each line to 120.
|
|
||||||
# F541 f-string is missing placeholders
|
|
||||||
# Description: The same thing is done when using pylint for detection.
|
|
||||||
# E266 too many leading '#' for block comment
|
|
||||||
# Description: To make the code more readable, a lot of "#" is used.
|
|
||||||
# This error code appears centrally in:
|
|
||||||
# qlib/backtest/executor.py
|
|
||||||
# qlib/data/ops.py
|
|
||||||
# qlib/utils/__init__.py
|
|
||||||
# E402 module level import not at top of file
|
|
||||||
# Description: There are times when module level import is not available at the top of the file.
|
|
||||||
# W503 line break before binary operator
|
|
||||||
# Description: Since black formats the length of each line of code, it has to perform a line break when a line of arithmetic is too long.
|
|
||||||
# E731 do not assign a lambda expression, use a def
|
|
||||||
# Description: Restricts the use of lambda expressions, but at some point lambda expressions are required.
|
|
||||||
# E203 whitespace before ':'
|
|
||||||
# Description: If there is whitespace before ":", it cannot pass the black check.
|
|
||||||
flake8:
|
|
||||||
flake8 --ignore=E501,F541,E266,E402,W503,E731,E203 --per-file-ignores="__init__.py:F401,F403" qlib
|
|
||||||
|
|
||||||
# Check code with mypy.
|
|
||||||
# https://github.com/python/mypy/issues/10600
|
|
||||||
mypy:
|
|
||||||
mypy qlib --install-types --non-interactive
|
|
||||||
mypy qlib --verbose
|
|
||||||
|
|
||||||
# Check ipynb with nbqa.
|
|
||||||
nbqa:
|
|
||||||
nbqa black . -l 120 --check --diff
|
|
||||||
nbqa pylint . --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0612,W0621,W0622,W0703,W1309,E1102,E1136,W0719,W0104,W0404,C0412,W0611,C0410 --const-rgx='[a-z_][a-z0-9_]{2,30}'
|
|
||||||
|
|
||||||
# Check ipynb with nbconvert.(Run after data downloads)
|
|
||||||
# TODO: Add more ipynb files in future
|
|
||||||
nbconvert:
|
|
||||||
jupyter nbconvert --to notebook --execute examples/workflow_by_code.ipynb
|
|
||||||
|
|
||||||
lint: black pylint flake8 mypy nbqa
|
|
||||||
|
|
||||||
########################################################################################
|
|
||||||
# Package
|
|
||||||
########################################################################################
|
|
||||||
|
|
||||||
# Build the package.
|
|
||||||
build:
|
|
||||||
python -m build --wheel
|
|
||||||
|
|
||||||
# Upload the package.
|
|
||||||
upload:
|
|
||||||
python -m twine upload dist/*
|
|
||||||
|
|
||||||
########################################################################################
|
|
||||||
# Documentation
|
|
||||||
########################################################################################
|
|
||||||
|
|
||||||
docs-gen:
|
|
||||||
python -m sphinx.cmd.build -W docs $(PUBLIC_DIR)
|
|
||||||
83
README.md
83
README.md
@@ -8,30 +8,9 @@
|
|||||||
[](https://gitter.im/Microsoft/qlib?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
|
[](https://gitter.im/Microsoft/qlib?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
|
||||||
|
|
||||||
## :newspaper: **What's NEW!** :sparkling_heart:
|
## :newspaper: **What's NEW!** :sparkling_heart:
|
||||||
|
|
||||||
Recent released features
|
Recent released features
|
||||||
|
|
||||||
### Introducing <a href="https://github.com/microsoft/RD-Agent"><img src="docs/_static/img/rdagent_logo.png" alt="RD_Agent" style="height: 2em"></a>: LLM-Based Autonomous Evolving Agents for Industrial Data-Driven R&D
|
|
||||||
|
|
||||||
We are excited to announce the release of **RD-Agent**📢, a powerful tool that supports automated factor mining and model optimization in quant investment R&D.
|
|
||||||
|
|
||||||
RD-Agent is now available on [GitHub](https://github.com/microsoft/RD-Agent), and we welcome your star🌟!
|
|
||||||
|
|
||||||
To learn more, please visit our [♾️Demo page](https://rdagent.azurewebsites.net/). Here, you will find demo videos in both English and Chinese to help you better understand the scenario and usage of RD-Agent.
|
|
||||||
|
|
||||||
We have prepared several demo videos for you:
|
|
||||||
| Scenario | Demo video (English) | Demo video (中文) |
|
|
||||||
| -- | ------ | ------ |
|
|
||||||
| Quant Factor Mining | [Link](https://rdagent.azurewebsites.net/factor_loop?lang=en) | [Link](https://rdagent.azurewebsites.net/factor_loop?lang=zh) |
|
|
||||||
| Quant Factor Mining from reports | [Link](https://rdagent.azurewebsites.net/report_factor?lang=en) | [Link](https://rdagent.azurewebsites.net/report_factor?lang=zh) |
|
|
||||||
| Quant Model Optimization | [Link](https://rdagent.azurewebsites.net/model_loop?lang=en) | [Link](https://rdagent.azurewebsites.net/model_loop?lang=zh) |
|
|
||||||
|
|
||||||
***
|
|
||||||
|
|
||||||
| Feature | Status |
|
| Feature | Status |
|
||||||
| -- | ------ |
|
| -- | ------ |
|
||||||
| BPQP for End-to-end learning | 📈Coming soon!([Under review](https://github.com/microsoft/qlib/pull/1863)) |
|
|
||||||
| 🔥LLM-driven Auto Quant Factory🔥 | 🚀 Released in [♾️RD-Agent](https://github.com/microsoft/RD-Agent) on Aug 8, 2024 |
|
|
||||||
| KRNN and Sandwich models | :chart_with_upwards_trend: [Released](https://github.com/microsoft/qlib/pull/1414/) on May 26, 2023 |
|
| KRNN and Sandwich models | :chart_with_upwards_trend: [Released](https://github.com/microsoft/qlib/pull/1414/) on May 26, 2023 |
|
||||||
| Release Qlib v0.9.0 | :octocat: [Released](https://github.com/microsoft/qlib/releases/tag/v0.9.0) on Dec 9, 2022 |
|
| Release Qlib v0.9.0 | :octocat: [Released](https://github.com/microsoft/qlib/releases/tag/v0.9.0) on Dec 9, 2022 |
|
||||||
| RL Learning Framework | :hammer: :chart_with_upwards_trend: Released on Nov 10, 2022. [#1332](https://github.com/microsoft/qlib/pull/1332), [#1322](https://github.com/microsoft/qlib/pull/1322), [#1316](https://github.com/microsoft/qlib/pull/1316),[#1299](https://github.com/microsoft/qlib/pull/1299),[#1263](https://github.com/microsoft/qlib/pull/1263), [#1244](https://github.com/microsoft/qlib/pull/1244), [#1169](https://github.com/microsoft/qlib/pull/1169), [#1125](https://github.com/microsoft/qlib/pull/1125), [#1076](https://github.com/microsoft/qlib/pull/1076)|
|
| RL Learning Framework | :hammer: :chart_with_upwards_trend: Released on Nov 10, 2022. [#1332](https://github.com/microsoft/qlib/pull/1332), [#1322](https://github.com/microsoft/qlib/pull/1322), [#1316](https://github.com/microsoft/qlib/pull/1316),[#1299](https://github.com/microsoft/qlib/pull/1299),[#1263](https://github.com/microsoft/qlib/pull/1263), [#1244](https://github.com/microsoft/qlib/pull/1244), [#1169](https://github.com/microsoft/qlib/pull/1169), [#1125](https://github.com/microsoft/qlib/pull/1125), [#1076](https://github.com/microsoft/qlib/pull/1076)|
|
||||||
@@ -153,11 +132,11 @@ Here is a quick **[demo](https://terminalizer.com/view/3f24561a4470)** shows how
|
|||||||
## Installation
|
## Installation
|
||||||
|
|
||||||
This table demonstrates the supported Python version of `Qlib`:
|
This table demonstrates the supported Python version of `Qlib`:
|
||||||
| | install with pip | install from source | plot |
|
| | install with pip | install from source | plot |
|
||||||
| ------------- |:---------------------:|:--------------------:|:------------------:|
|
| ------------- |:---------------------:|:--------------------:|:----:|
|
||||||
| Python 3.7 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
|
| Python 3.7 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
|
||||||
| Python 3.8 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
|
| Python 3.8 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
|
||||||
| Python 3.9 | :x: | :heavy_check_mark: | :x: |
|
| Python 3.9 | :x: | :heavy_check_mark: | :x: |
|
||||||
|
|
||||||
**Note**:
|
**Note**:
|
||||||
1. **Conda** is suggested for managing your Python environment. In some cases, using Python outside of a `conda` environment may result in missing header files, causing the installation failure of certain packages.
|
1. **Conda** is suggested for managing your Python environment. In some cases, using Python outside of a `conda` environment may result in missing header files, causing the installation failure of certain packages.
|
||||||
@@ -197,11 +176,11 @@ Also, users can install the latest dev version ``Qlib`` by the source code accor
|
|||||||
|
|
||||||
## Data Preparation
|
## Data Preparation
|
||||||
❗ Due to more restrict data security policy. The offical dataset is disabled temporarily. You can try [this data source](https://github.com/chenditc/investment_data/releases) contributed by the community.
|
❗ Due to more restrict data security policy. The offical dataset is disabled temporarily. You can try [this data source](https://github.com/chenditc/investment_data/releases) contributed by the community.
|
||||||
Here is an example to download the data updated on 20240809.
|
Here is an example to download the data updated on 20220720.
|
||||||
```bash
|
```bash
|
||||||
wget https://github.com/chenditc/investment_data/releases/download/2024-08-09/qlib_bin.tar.gz
|
wget https://github.com/chenditc/investment_data/releases/download/20220720/qlib_bin.tar.gz
|
||||||
mkdir -p ~/.qlib/qlib_data/cn_data
|
mkdir -p ~/.qlib/qlib_data/cn_data
|
||||||
tar -zxvf qlib_bin.tar.gz -C ~/.qlib/qlib_data/cn_data --strip-components=1
|
tar -zxvf qlib_bin.tar.gz -C ~/.qlib/qlib_data/cn_data --strip-components=2
|
||||||
rm -f qlib_bin.tar.gz
|
rm -f qlib_bin.tar.gz
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -293,38 +272,6 @@ We recommend users to prepare their own data if they have a high-quality dataset
|
|||||||
```
|
```
|
||||||
-->
|
-->
|
||||||
|
|
||||||
## Docker images
|
|
||||||
1. Pulling a docker image from a docker hub repository
|
|
||||||
```bash
|
|
||||||
docker pull pyqlib/qlib_image_stable:stable
|
|
||||||
```
|
|
||||||
2. Start a new Docker container
|
|
||||||
```bash
|
|
||||||
docker run -it --name <container name> -v <Mounted local directory>:/app qlib_image_stable
|
|
||||||
```
|
|
||||||
3. At this point you are in the docker environment and can run the qlib scripts. An example:
|
|
||||||
```bash
|
|
||||||
>>> python scripts/get_data.py qlib_data --name qlib_data_simple --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
|
|
||||||
>>> python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
|
||||||
```
|
|
||||||
4. Exit the container
|
|
||||||
```bash
|
|
||||||
>>> exit
|
|
||||||
```
|
|
||||||
5. Restart the container
|
|
||||||
```bash
|
|
||||||
docker start -i -a <container name>
|
|
||||||
```
|
|
||||||
6. Stop the container
|
|
||||||
```bash
|
|
||||||
docker stop <container name>
|
|
||||||
```
|
|
||||||
7. Delete the container
|
|
||||||
```bash
|
|
||||||
docker rm <container name>
|
|
||||||
```
|
|
||||||
8. If you want to know more information, please refer to the [documentation](https://qlib.readthedocs.io/en/latest/developer/how_to_build_image.html).
|
|
||||||
|
|
||||||
## Auto Quant Research Workflow
|
## Auto Quant Research Workflow
|
||||||
Qlib provides a tool named `qrun` to run the whole workflow automatically (including building dataset, training models, backtest and evaluation). You can start an auto quant research workflow and have a graphical reports analysis according to the following steps:
|
Qlib provides a tool named `qrun` to run the whole workflow automatically (including building dataset, training models, backtest and evaluation). You can start an auto quant research workflow and have a graphical reports analysis according to the following steps:
|
||||||
|
|
||||||
@@ -358,22 +305,22 @@ Qlib provides a tool named `qrun` to run the whole workflow automatically (inclu
|
|||||||
```
|
```
|
||||||
Here are detailed documents for `qrun` and [workflow](https://qlib.readthedocs.io/en/latest/component/workflow.html).
|
Here are detailed documents for `qrun` and [workflow](https://qlib.readthedocs.io/en/latest/component/workflow.html).
|
||||||
|
|
||||||
2. Graphical Reports Analysis: First, run `python -m pip install .[analysis]` to install the required dependencies. Then run `examples/workflow_by_code.ipynb` with `jupyter notebook` to get graphical reports.
|
2. Graphical Reports Analysis: Run `examples/workflow_by_code.ipynb` with `jupyter notebook` to get graphical reports
|
||||||
- Forecasting signal (model prediction) analysis
|
- Forecasting signal (model prediction) analysis
|
||||||
- Cumulative Return of groups
|
- Cumulative Return of groups
|
||||||

|

|
||||||
- Return distribution
|
- Return distribution
|
||||||

|

|
||||||
- Information Coefficient (IC)
|
- Information Coefficient (IC)
|
||||||

|

|
||||||

|

|
||||||

|

|
||||||
- Auto Correlation of forecasting signal (model prediction)
|
- Auto Correlation of forecasting signal (model prediction)
|
||||||

|

|
||||||
|
|
||||||
- Portfolio analysis
|
- Portfolio analysis
|
||||||
- Backtest return
|
- Backtest return
|
||||||

|

|
||||||
<!--
|
<!--
|
||||||
- Score IC
|
- Score IC
|
||||||

|

|
||||||
@@ -552,7 +499,7 @@ Qlib data are stored in a compact format, which is efficient to be combined into
|
|||||||
Join IM discussion groups:
|
Join IM discussion groups:
|
||||||
|[Gitter](https://gitter.im/Microsoft/qlib)|
|
|[Gitter](https://gitter.im/Microsoft/qlib)|
|
||||||
|----|
|
|----|
|
||||||
||
|
||
|
||||||
|
|
||||||
# Contributing
|
# Contributing
|
||||||
We appreciate all contributions and thank all the contributors!
|
We appreciate all contributions and thank all the contributors!
|
||||||
|
|||||||
@@ -1,31 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
|
|
||||||
docker_user="your_dockerhub_username"
|
|
||||||
|
|
||||||
read -p "Do you want to build the nightly version of the qlib image? (default is stable) (yes/no): " answer;
|
|
||||||
answer=$(echo "$answer" | tr '[:upper:]' '[:lower:]')
|
|
||||||
|
|
||||||
if [ "$answer" = "yes" ]; then
|
|
||||||
# Build the nightly version of the qlib image
|
|
||||||
docker build --build-arg IS_STABLE=no -t qlib_image -f ./Dockerfile .
|
|
||||||
image_tag="nightly"
|
|
||||||
else
|
|
||||||
# Build the stable version of the qlib image
|
|
||||||
docker build -t qlib_image -f ./Dockerfile .
|
|
||||||
image_tag="stable"
|
|
||||||
fi
|
|
||||||
|
|
||||||
read -p "Is it uploaded to docker hub? (default is no) (yes/no): " answer;
|
|
||||||
answer=$(echo "$answer" | tr '[:upper:]' '[:lower:]')
|
|
||||||
|
|
||||||
if [ "$answer" = "yes" ]; then
|
|
||||||
# Log in to Docker Hub
|
|
||||||
# If you are a new docker hub user, please verify your email address before proceeding with this step.
|
|
||||||
docker login
|
|
||||||
# Tag the Docker image
|
|
||||||
docker tag qlib_image "$docker_user/qlib_image:$image_tag"
|
|
||||||
# Push the Docker image to Docker Hub
|
|
||||||
docker push "$docker_user/qlib_image:$image_tag"
|
|
||||||
else
|
|
||||||
echo "Not uploaded to docker hub."
|
|
||||||
fi
|
|
||||||
BIN
docs/_static/img/rdagent_logo.png
vendored
BIN
docs/_static/img/rdagent_logo.png
vendored
Binary file not shown.
|
Before Width: | Height: | Size: 94 KiB |
@@ -123,6 +123,7 @@ html_logo = "_static/img/logo/1.png"
|
|||||||
html_theme_options = {
|
html_theme_options = {
|
||||||
"logo_only": True,
|
"logo_only": True,
|
||||||
"collapse_navigation": False,
|
"collapse_navigation": False,
|
||||||
|
"display_version": False,
|
||||||
"navigation_depth": 4,
|
"navigation_depth": 4,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,81 +0,0 @@
|
|||||||
.. _docker_image:
|
|
||||||
|
|
||||||
==================
|
|
||||||
Build Docker Image
|
|
||||||
==================
|
|
||||||
|
|
||||||
Dockerfile
|
|
||||||
==========
|
|
||||||
|
|
||||||
There is a **Dockerfile** file in the root directory of the project from which you can build the docker image. There are two build methods in Dockerfile to choose from.
|
|
||||||
When executing the build command, use the ``--build-arg`` parameter to control the image version. The ``--build-arg`` parameter defaults to ``yes``, which builds the ``stable`` version of the qlib image.
|
|
||||||
|
|
||||||
1.For the ``stable`` version, use ``pip install pyqlib`` to build the qlib image.
|
|
||||||
|
|
||||||
.. code-block:: bash
|
|
||||||
|
|
||||||
docker build --build-arg IS_STABLE=yes -t <image name> -f ./Dockerfile .
|
|
||||||
|
|
||||||
.. code-block:: bash
|
|
||||||
|
|
||||||
docker build -t <image name> -f ./Dockerfile .
|
|
||||||
|
|
||||||
2. For the ``nightly`` version, use current source code to build the qlib image.
|
|
||||||
|
|
||||||
.. code-block:: bash
|
|
||||||
|
|
||||||
docker build --build-arg IS_STABLE=no -t <image name> -f ./Dockerfile .
|
|
||||||
|
|
||||||
Auto build of qlib images
|
|
||||||
=========================
|
|
||||||
|
|
||||||
1. There is a **build_docker_image.sh** file in the root directory of your project, which can be used to automatically build docker images and upload them to your docker hub repository(Optional, configuration required).
|
|
||||||
|
|
||||||
.. code-block:: bash
|
|
||||||
|
|
||||||
sh build_docker_image.sh
|
|
||||||
>>> Do you want to build the nightly version of the qlib image? (default is stable) (yes/no):
|
|
||||||
>>> Is it uploaded to docker hub? (default is no) (yes/no):
|
|
||||||
|
|
||||||
2. If you want to upload the built image to your docker hub repository, you need to edit your **build_docker_image.sh** file first, fill in ``docker_user`` in the file, and then execute this file.
|
|
||||||
|
|
||||||
How to use qlib images
|
|
||||||
======================
|
|
||||||
1. Start a new Docker container
|
|
||||||
|
|
||||||
.. code-block:: bash
|
|
||||||
|
|
||||||
docker run -it --name <container name> -v <Mounted local directory>:/app <image name>
|
|
||||||
|
|
||||||
2. At this point you are in the docker environment and can run the qlib scripts. An example:
|
|
||||||
|
|
||||||
.. code-block:: bash
|
|
||||||
|
|
||||||
>>> python scripts/get_data.py qlib_data --name qlib_data_simple --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
|
|
||||||
>>> python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
|
||||||
|
|
||||||
3. Exit the container
|
|
||||||
|
|
||||||
.. code-block:: bash
|
|
||||||
|
|
||||||
>>> exit
|
|
||||||
|
|
||||||
4. Restart the container
|
|
||||||
|
|
||||||
.. code-block:: bash
|
|
||||||
|
|
||||||
docker start -i -a <container name>
|
|
||||||
|
|
||||||
5. Stop the container
|
|
||||||
|
|
||||||
.. code-block:: bash
|
|
||||||
|
|
||||||
docker stop -i -a <container name>
|
|
||||||
|
|
||||||
6. Delete the container
|
|
||||||
|
|
||||||
.. code-block:: bash
|
|
||||||
|
|
||||||
docker rm <container name>
|
|
||||||
|
|
||||||
7. For more information on using docker see the `docker documentation <https://docs.docker.com/reference/cli/docker/>`_.
|
|
||||||
@@ -61,7 +61,6 @@ Document Structure
|
|||||||
:caption: FOR DEVELOPERS:
|
:caption: FOR DEVELOPERS:
|
||||||
|
|
||||||
Code Standard & Development Guidance <developer/code_standard_and_dev_guide.rst>
|
Code Standard & Development Guidance <developer/code_standard_and_dev_guide.rst>
|
||||||
How to build image <developer/how_to_build_image.rst>
|
|
||||||
|
|
||||||
.. toctree::
|
.. toctree::
|
||||||
:maxdepth: 3
|
:maxdepth: 3
|
||||||
|
|||||||
@@ -1,19 +0,0 @@
|
|||||||
|
|
||||||
|
|
||||||
# Introduction
|
|
||||||
|
|
||||||
What is GeneralPtNN
|
|
||||||
- Fix previous design that fail to support both Time-series and tabular data
|
|
||||||
- Now you can just replace the Pytorch model structure to run a NN model.
|
|
||||||
|
|
||||||
We provide an example to demonstrate the effectiveness of the current design.
|
|
||||||
- `workflow_config_gru.yaml` align with previous results [GRU(Kyunghyun Cho, et al.)](../README.md#Alpha158-dataset)
|
|
||||||
- `workflow_config_gru2mlp.yaml` to demonstrate we can convert config from time-series to tabular data with minimal changes
|
|
||||||
- You only have to change the net & dataset class to make the conversion.
|
|
||||||
- `workflow_config_mlp.yaml` achieved similar functionality with [MLP](../README.md#Alpha158-dataset)
|
|
||||||
|
|
||||||
# TODO
|
|
||||||
|
|
||||||
- We will align existing models to current design.
|
|
||||||
|
|
||||||
- The result of `workflow_config_mlp.yaml` is different with the result of [MLP](../README.md#Alpha158-dataset) since GeneralPtNN has a different stopping method compared to previous implementations. Specificly, GeneralPtNN controls training according to epoches, whereas previous methods controlled by max_steps.
|
|
||||||
@@ -1,100 +0,0 @@
|
|||||||
qlib_init:
|
|
||||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
|
||||||
region: cn
|
|
||||||
market: &market csi300
|
|
||||||
benchmark: &benchmark SH000300
|
|
||||||
data_handler_config: &data_handler_config
|
|
||||||
start_time: 2008-01-01
|
|
||||||
end_time: 2020-08-01
|
|
||||||
fit_start_time: 2008-01-01
|
|
||||||
fit_end_time: 2014-12-31
|
|
||||||
instruments: *market
|
|
||||||
infer_processors:
|
|
||||||
- class: FilterCol
|
|
||||||
kwargs:
|
|
||||||
fields_group: feature
|
|
||||||
col_list: ["RESI5", "WVMA5", "RSQR5", "KLEN", "RSQR10", "CORR5", "CORD5", "CORR10",
|
|
||||||
"ROC60", "RESI10", "VSTD5", "RSQR60", "CORR60", "WVMA60", "STD5",
|
|
||||||
"RSQR20", "CORD60", "CORD10", "CORR20", "KLOW"
|
|
||||||
]
|
|
||||||
- class: RobustZScoreNorm
|
|
||||||
kwargs:
|
|
||||||
fields_group: feature
|
|
||||||
clip_outlier: true
|
|
||||||
- class: Fillna
|
|
||||||
kwargs:
|
|
||||||
fields_group: feature
|
|
||||||
learn_processors:
|
|
||||||
- class: DropnaLabel
|
|
||||||
- class: CSRankNorm
|
|
||||||
kwargs:
|
|
||||||
fields_group: label
|
|
||||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
|
||||||
|
|
||||||
port_analysis_config: &port_analysis_config
|
|
||||||
strategy:
|
|
||||||
class: TopkDropoutStrategy
|
|
||||||
module_path: qlib.contrib.strategy
|
|
||||||
kwargs:
|
|
||||||
signal: <PRED>
|
|
||||||
topk: 50
|
|
||||||
n_drop: 5
|
|
||||||
backtest:
|
|
||||||
start_time: 2017-01-01
|
|
||||||
end_time: 2020-08-01
|
|
||||||
account: 100000000
|
|
||||||
benchmark: *benchmark
|
|
||||||
exchange_kwargs:
|
|
||||||
limit_threshold: 0.095
|
|
||||||
deal_price: close
|
|
||||||
open_cost: 0.0005
|
|
||||||
close_cost: 0.0015
|
|
||||||
min_cost: 5
|
|
||||||
task:
|
|
||||||
model:
|
|
||||||
class: GeneralPTNN
|
|
||||||
module_path: qlib.contrib.model.pytorch_general_nn
|
|
||||||
kwargs:
|
|
||||||
n_epochs: 200
|
|
||||||
lr: 2e-4
|
|
||||||
early_stop: 10
|
|
||||||
batch_size: 800
|
|
||||||
metric: loss
|
|
||||||
loss: mse
|
|
||||||
n_jobs: 20
|
|
||||||
GPU: 0
|
|
||||||
pt_model_uri: "qlib.contrib.model.pytorch_gru_ts.GRUModel"
|
|
||||||
pt_model_kwargs: {
|
|
||||||
"d_feat": 20,
|
|
||||||
"hidden_size": 64,
|
|
||||||
"num_layers": 2,
|
|
||||||
"dropout": 0.,
|
|
||||||
}
|
|
||||||
dataset:
|
|
||||||
class: TSDatasetH
|
|
||||||
module_path: qlib.data.dataset
|
|
||||||
kwargs:
|
|
||||||
handler:
|
|
||||||
class: Alpha158
|
|
||||||
module_path: qlib.contrib.data.handler
|
|
||||||
kwargs: *data_handler_config
|
|
||||||
segments:
|
|
||||||
train: [2008-01-01, 2014-12-31]
|
|
||||||
valid: [2015-01-01, 2016-12-31]
|
|
||||||
test: [2017-01-01, 2020-08-01]
|
|
||||||
step_len: 20
|
|
||||||
record:
|
|
||||||
- class: SignalRecord
|
|
||||||
module_path: qlib.workflow.record_temp
|
|
||||||
kwargs:
|
|
||||||
model: <MODEL>
|
|
||||||
dataset: <DATASET>
|
|
||||||
- class: SigAnaRecord
|
|
||||||
module_path: qlib.workflow.record_temp
|
|
||||||
kwargs:
|
|
||||||
ana_long_short: False
|
|
||||||
ann_scaler: 252
|
|
||||||
- class: PortAnaRecord
|
|
||||||
module_path: qlib.workflow.record_temp
|
|
||||||
kwargs:
|
|
||||||
config: *port_analysis_config
|
|
||||||
@@ -1,93 +0,0 @@
|
|||||||
qlib_init:
|
|
||||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
|
||||||
region: cn
|
|
||||||
market: &market csi300
|
|
||||||
benchmark: &benchmark SH000300
|
|
||||||
data_handler_config: &data_handler_config
|
|
||||||
start_time: 2008-01-01
|
|
||||||
end_time: 2020-08-01
|
|
||||||
fit_start_time: 2008-01-01
|
|
||||||
fit_end_time: 2014-12-31
|
|
||||||
instruments: *market
|
|
||||||
infer_processors:
|
|
||||||
- class: FilterCol
|
|
||||||
kwargs:
|
|
||||||
fields_group: feature
|
|
||||||
col_list: ["RESI5", "WVMA5", "RSQR5", "KLEN", "RSQR10", "CORR5", "CORD5", "CORR10",
|
|
||||||
"ROC60", "RESI10", "VSTD5", "RSQR60", "CORR60", "WVMA60", "STD5",
|
|
||||||
"RSQR20", "CORD60", "CORD10", "CORR20", "KLOW"
|
|
||||||
]
|
|
||||||
- class: RobustZScoreNorm
|
|
||||||
kwargs:
|
|
||||||
fields_group: feature
|
|
||||||
clip_outlier: true
|
|
||||||
- class: Fillna
|
|
||||||
kwargs:
|
|
||||||
fields_group: feature
|
|
||||||
learn_processors:
|
|
||||||
- class: DropnaLabel
|
|
||||||
- class: CSRankNorm
|
|
||||||
kwargs:
|
|
||||||
fields_group: label
|
|
||||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
|
||||||
|
|
||||||
port_analysis_config: &port_analysis_config
|
|
||||||
strategy:
|
|
||||||
class: TopkDropoutStrategy
|
|
||||||
module_path: qlib.contrib.strategy
|
|
||||||
kwargs:
|
|
||||||
signal: <PRED>
|
|
||||||
topk: 50
|
|
||||||
n_drop: 5
|
|
||||||
backtest:
|
|
||||||
start_time: 2017-01-01
|
|
||||||
end_time: 2020-08-01
|
|
||||||
account: 100000000
|
|
||||||
benchmark: *benchmark
|
|
||||||
exchange_kwargs:
|
|
||||||
limit_threshold: 0.095
|
|
||||||
deal_price: close
|
|
||||||
open_cost: 0.0005
|
|
||||||
close_cost: 0.0015
|
|
||||||
min_cost: 5
|
|
||||||
task:
|
|
||||||
model:
|
|
||||||
class: GeneralPTNN
|
|
||||||
module_path: qlib.contrib.model.pytorch_general_nn
|
|
||||||
kwargs:
|
|
||||||
lr: 1e-3
|
|
||||||
n_epochs: 1
|
|
||||||
batch_size: 800
|
|
||||||
loss: mse
|
|
||||||
optimizer: adam
|
|
||||||
pt_model_uri: "qlib.contrib.model.pytorch_nn.Net"
|
|
||||||
pt_model_kwargs:
|
|
||||||
input_dim: 20
|
|
||||||
layers: [20,]
|
|
||||||
dataset:
|
|
||||||
class: DatasetH
|
|
||||||
module_path: qlib.data.dataset
|
|
||||||
kwargs:
|
|
||||||
handler:
|
|
||||||
class: Alpha158
|
|
||||||
module_path: qlib.contrib.data.handler
|
|
||||||
kwargs: *data_handler_config
|
|
||||||
segments:
|
|
||||||
train: [2008-01-01, 2014-12-31]
|
|
||||||
valid: [2015-01-01, 2016-12-31]
|
|
||||||
test: [2017-01-01, 2020-08-01]
|
|
||||||
record:
|
|
||||||
- class: SignalRecord
|
|
||||||
module_path: qlib.workflow.record_temp
|
|
||||||
kwargs:
|
|
||||||
model: <MODEL>
|
|
||||||
dataset: <DATASET>
|
|
||||||
- class: SigAnaRecord
|
|
||||||
module_path: qlib.workflow.record_temp
|
|
||||||
kwargs:
|
|
||||||
ana_long_short: False
|
|
||||||
ann_scaler: 252
|
|
||||||
- class: PortAnaRecord
|
|
||||||
module_path: qlib.workflow.record_temp
|
|
||||||
kwargs:
|
|
||||||
config: *port_analysis_config
|
|
||||||
@@ -1,98 +0,0 @@
|
|||||||
qlib_init:
|
|
||||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
|
||||||
region: cn
|
|
||||||
market: &market csi300
|
|
||||||
benchmark: &benchmark SH000300
|
|
||||||
data_handler_config: &data_handler_config
|
|
||||||
start_time: 2008-01-01
|
|
||||||
end_time: 2020-08-01
|
|
||||||
fit_start_time: 2008-01-01
|
|
||||||
fit_end_time: 2014-12-31
|
|
||||||
instruments: *market
|
|
||||||
infer_processors: [
|
|
||||||
{
|
|
||||||
"class" : "DropCol",
|
|
||||||
"kwargs":{"col_list": ["VWAP0"]}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"class" : "CSZFillna",
|
|
||||||
"kwargs":{"fields_group": "feature"}
|
|
||||||
}
|
|
||||||
]
|
|
||||||
learn_processors: [
|
|
||||||
{
|
|
||||||
"class" : "DropCol",
|
|
||||||
"kwargs":{"col_list": ["VWAP0"]}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"class" : "DropnaProcessor",
|
|
||||||
"kwargs":{"fields_group": "feature"}
|
|
||||||
},
|
|
||||||
"DropnaLabel",
|
|
||||||
{
|
|
||||||
"class": "CSZScoreNorm",
|
|
||||||
"kwargs": {"fields_group": "label"}
|
|
||||||
}
|
|
||||||
]
|
|
||||||
process_type: "independent"
|
|
||||||
|
|
||||||
port_analysis_config: &port_analysis_config
|
|
||||||
strategy:
|
|
||||||
class: TopkDropoutStrategy
|
|
||||||
module_path: qlib.contrib.strategy
|
|
||||||
kwargs:
|
|
||||||
signal: <PRED>
|
|
||||||
topk: 50
|
|
||||||
n_drop: 5
|
|
||||||
backtest:
|
|
||||||
start_time: 2017-01-01
|
|
||||||
end_time: 2020-08-01
|
|
||||||
account: 100000000
|
|
||||||
benchmark: *benchmark
|
|
||||||
exchange_kwargs:
|
|
||||||
limit_threshold: 0.095
|
|
||||||
deal_price: close
|
|
||||||
open_cost: 0.0005
|
|
||||||
close_cost: 0.0015
|
|
||||||
min_cost: 5
|
|
||||||
task:
|
|
||||||
model:
|
|
||||||
class: GeneralPTNN
|
|
||||||
module_path: qlib.contrib.model.pytorch_general_nn
|
|
||||||
kwargs:
|
|
||||||
# FIXME: wrong parameters.
|
|
||||||
lr: 2e-3
|
|
||||||
batch_size: 8192
|
|
||||||
loss: mse
|
|
||||||
weight_decay: 0.0002
|
|
||||||
optimizer: adam
|
|
||||||
pt_model_uri: "qlib.contrib.model.pytorch_nn.Net"
|
|
||||||
pt_model_kwargs:
|
|
||||||
input_dim: 157
|
|
||||||
dataset:
|
|
||||||
class: DatasetH
|
|
||||||
module_path: qlib.data.dataset
|
|
||||||
kwargs:
|
|
||||||
handler:
|
|
||||||
class: Alpha158
|
|
||||||
module_path: qlib.contrib.data.handler
|
|
||||||
kwargs: *data_handler_config
|
|
||||||
segments:
|
|
||||||
train: [2008-01-01, 2014-12-31]
|
|
||||||
valid: [2015-01-01, 2016-12-31]
|
|
||||||
test: [2017-01-01, 2020-08-01]
|
|
||||||
record:
|
|
||||||
- class: SignalRecord
|
|
||||||
module_path: qlib.workflow.record_temp
|
|
||||||
kwargs:
|
|
||||||
model: <MODEL>
|
|
||||||
dataset: <DATASET>
|
|
||||||
- class: SigAnaRecord
|
|
||||||
module_path: qlib.workflow.record_temp
|
|
||||||
kwargs:
|
|
||||||
ana_long_short: False
|
|
||||||
ann_scaler: 252
|
|
||||||
- class: PortAnaRecord
|
|
||||||
module_path: qlib.workflow.record_temp
|
|
||||||
kwargs:
|
|
||||||
config: *port_analysis_config
|
|
||||||
@@ -1,15 +1,14 @@
|
|||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
import qlib
|
import qlib
|
||||||
from ruamel.yaml import YAML
|
import ruamel.yaml as yaml
|
||||||
from qlib.utils import init_instance_by_config
|
from qlib.utils import init_instance_by_config
|
||||||
|
|
||||||
|
|
||||||
def main(seed, config_file="configs/config_alstm.yaml"):
|
def main(seed, config_file="configs/config_alstm.yaml"):
|
||||||
# set random seed
|
# set random seed
|
||||||
with open(config_file) as f:
|
with open(config_file) as f:
|
||||||
yaml = YAML(typ="safe", pure=True)
|
config = yaml.safe_load(f)
|
||||||
config = yaml.load(f)
|
|
||||||
|
|
||||||
# seed_suffix = "/seed1000" if "init" in config_file else f"/seed{seed}"
|
# seed_suffix = "/seed1000" if "init" in config_file else f"/seed{seed}"
|
||||||
seed_suffix = ""
|
seed_suffix = ""
|
||||||
|
|||||||
@@ -9,8 +9,8 @@ from copy import deepcopy
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import pickle
|
import pickle
|
||||||
from pprint import pprint
|
from pprint import pprint
|
||||||
from ruamel.yaml import YAML
|
|
||||||
import subprocess
|
import subprocess
|
||||||
|
import yaml
|
||||||
from qlib.log import TimeInspector
|
from qlib.log import TimeInspector
|
||||||
|
|
||||||
from qlib import init
|
from qlib import init
|
||||||
@@ -30,8 +30,7 @@ if __name__ == "__main__":
|
|||||||
subprocess.run(f"qrun {config_path}", shell=True)
|
subprocess.run(f"qrun {config_path}", shell=True)
|
||||||
|
|
||||||
# 2) dump handler
|
# 2) dump handler
|
||||||
yaml = YAML(typ="safe", pure=True)
|
task_config = yaml.safe_load(config_path.open())
|
||||||
task_config = yaml.load(config_path.open())
|
|
||||||
hd_conf = task_config["task"]["dataset"]["kwargs"]["handler"]
|
hd_conf = task_config["task"]["dataset"]["kwargs"]["handler"]
|
||||||
pprint(hd_conf)
|
pprint(hd_conf)
|
||||||
hd: DataHandlerLP = init_instance_by_config(hd_conf)
|
hd: DataHandlerLP = init_instance_by_config(hd_conf)
|
||||||
|
|||||||
@@ -9,9 +9,10 @@ from copy import deepcopy
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import pickle
|
import pickle
|
||||||
from pprint import pprint
|
from pprint import pprint
|
||||||
from ruamel.yaml import YAML
|
|
||||||
import subprocess
|
import subprocess
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
from qlib import init
|
from qlib import init
|
||||||
from qlib.data.dataset.handler import DataHandlerLP
|
from qlib.data.dataset.handler import DataHandlerLP
|
||||||
from qlib.log import TimeInspector
|
from qlib.log import TimeInspector
|
||||||
@@ -28,8 +29,7 @@ if __name__ == "__main__":
|
|||||||
exp_name = "data_mem_reuse_demo"
|
exp_name = "data_mem_reuse_demo"
|
||||||
|
|
||||||
config_path = DIRNAME.parent / "benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml"
|
config_path = DIRNAME.parent / "benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml"
|
||||||
yaml = YAML(typ="safe", pure=True)
|
task_config = yaml.safe_load(config_path.open())
|
||||||
task_config = yaml.load(config_path.open())
|
|
||||||
|
|
||||||
# 1) without using processed data in memory
|
# 1) without using processed data in memory
|
||||||
with TimeInspector.logt("The original time without reusing processed data in memory:"):
|
with TimeInspector.logt("The original time without reusing processed data in memory:"):
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import sys
|
|||||||
import fire
|
import fire
|
||||||
import time
|
import time
|
||||||
import glob
|
import glob
|
||||||
|
import yaml
|
||||||
import shutil
|
import shutil
|
||||||
import signal
|
import signal
|
||||||
import inspect
|
import inspect
|
||||||
@@ -14,7 +15,6 @@ import functools
|
|||||||
import statistics
|
import statistics
|
||||||
import subprocess
|
import subprocess
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from ruamel.yaml import YAML
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from operator import xor
|
from operator import xor
|
||||||
from pprint import pprint
|
from pprint import pprint
|
||||||
@@ -188,8 +188,7 @@ def gen_and_save_md_table(metrics, dataset):
|
|||||||
# read yaml, remove seed kwargs of model, and then save file in the temp_dir
|
# read yaml, remove seed kwargs of model, and then save file in the temp_dir
|
||||||
def gen_yaml_file_without_seed_kwargs(yaml_path, temp_dir):
|
def gen_yaml_file_without_seed_kwargs(yaml_path, temp_dir):
|
||||||
with open(yaml_path, "r") as fp:
|
with open(yaml_path, "r") as fp:
|
||||||
yaml = YAML(typ="safe", pure=True)
|
config = yaml.safe_load(fp)
|
||||||
config = yaml.load(fp)
|
|
||||||
try:
|
try:
|
||||||
del config["task"]["model"]["kwargs"]["seed"]
|
del config["task"]["model"]["kwargs"]["seed"]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
|
|||||||
@@ -1,93 +1,2 @@
|
|||||||
[build-system]
|
[build-system]
|
||||||
requires = ["setuptools", "cython", "numpy>=1.24.0"]
|
requires = ["setuptools", "numpy", "Cython"]
|
||||||
build-backend = "setuptools.build_meta"
|
|
||||||
|
|
||||||
[project]
|
|
||||||
classifiers = [
|
|
||||||
"Operating System :: POSIX :: Linux",
|
|
||||||
"Operating System :: Microsoft :: Windows",
|
|
||||||
"Operating System :: MacOS",
|
|
||||||
"License :: OSI Approved :: MIT License",
|
|
||||||
"Development Status :: 3 - Alpha",
|
|
||||||
"Programming Language :: Python",
|
|
||||||
"Programming Language :: Python :: 3",
|
|
||||||
"Programming Language :: Python :: 3.8",
|
|
||||||
"Programming Language :: Python :: 3.9",
|
|
||||||
"Programming Language :: Python :: 3.10",
|
|
||||||
"Programming Language :: Python :: 3.11",
|
|
||||||
"Programming Language :: Python :: 3.12",
|
|
||||||
]
|
|
||||||
name = "pyqlib"
|
|
||||||
dynamic = ["version"]
|
|
||||||
description = "A Quantitative-research Platform"
|
|
||||||
requires-python = ">=3.8.0"
|
|
||||||
readme = {file = "README.md", content-type = "text/markdown"}
|
|
||||||
|
|
||||||
dependencies = [
|
|
||||||
"pyyaml",
|
|
||||||
"numpy",
|
|
||||||
"pandas",
|
|
||||||
"mlflow",
|
|
||||||
"filelock>=3.16.0",
|
|
||||||
"redis",
|
|
||||||
"dill",
|
|
||||||
"fire",
|
|
||||||
"ruamel.yaml>=0.17.38",
|
|
||||||
"python-redis-lock",
|
|
||||||
"tqdm",
|
|
||||||
"pymongo",
|
|
||||||
"loguru",
|
|
||||||
"lightgbm",
|
|
||||||
"gym",
|
|
||||||
"cvxpy",
|
|
||||||
"joblib",
|
|
||||||
"matplotlib",
|
|
||||||
"jupyter",
|
|
||||||
"nbconvert",
|
|
||||||
]
|
|
||||||
|
|
||||||
[project.optional-dependencies]
|
|
||||||
dev = [
|
|
||||||
"pytest",
|
|
||||||
"statsmodels",
|
|
||||||
]
|
|
||||||
# On macos-13 system, when using python version greater than or equal to 3.10,
|
|
||||||
# pytorch can't fully support Numpy version above 2.0, so, when you want to install torch,
|
|
||||||
# it will limit the version of Numpy less than 2.0.
|
|
||||||
rl = [
|
|
||||||
"tianshou<=0.4.10",
|
|
||||||
"torch",
|
|
||||||
"numpy<2.0.0",
|
|
||||||
]
|
|
||||||
lint = [
|
|
||||||
"black",
|
|
||||||
"pylint",
|
|
||||||
"mypy<1.5.0",
|
|
||||||
"flake8",
|
|
||||||
"nbqa",
|
|
||||||
]
|
|
||||||
docs = [
|
|
||||||
"sphinx",
|
|
||||||
"sphinx_rtd_theme",
|
|
||||||
"readthedocs_sphinx_ext",
|
|
||||||
]
|
|
||||||
package = [
|
|
||||||
"twine",
|
|
||||||
"build",
|
|
||||||
]
|
|
||||||
# test_pit dependency packages
|
|
||||||
test = [
|
|
||||||
"yahooquery",
|
|
||||||
"baostock",
|
|
||||||
]
|
|
||||||
analysis = [
|
|
||||||
"plotly",
|
|
||||||
]
|
|
||||||
|
|
||||||
[tool.setuptools]
|
|
||||||
packages = [
|
|
||||||
"qlib",
|
|
||||||
]
|
|
||||||
|
|
||||||
[project.scripts]
|
|
||||||
qrun = "qlib.workflow.cli:run"
|
|
||||||
|
|||||||
@@ -2,11 +2,11 @@
|
|||||||
# Licensed under the MIT License.
|
# Licensed under the MIT License.
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
__version__ = "0.9.5.80"
|
__version__ = "0.9.5.99"
|
||||||
__version__bak = __version__ # This version is backup for QlibConfig.reset_qlib_version
|
__version__bak = __version__ # This version is backup for QlibConfig.reset_qlib_version
|
||||||
import os
|
import os
|
||||||
from typing import Union
|
from typing import Union
|
||||||
from ruamel.yaml import YAML
|
import yaml
|
||||||
import logging
|
import logging
|
||||||
import platform
|
import platform
|
||||||
import subprocess
|
import subprocess
|
||||||
@@ -176,8 +176,7 @@ def init_from_yaml_conf(conf_path, **kwargs):
|
|||||||
config = {}
|
config = {}
|
||||||
else:
|
else:
|
||||||
with open(conf_path) as f:
|
with open(conf_path) as f:
|
||||||
yaml = YAML(typ="safe", pure=True)
|
config = yaml.safe_load(f)
|
||||||
config = yaml.load(f)
|
|
||||||
config.update(kwargs)
|
config.update(kwargs)
|
||||||
default_conf = config.pop("default_conf", "client")
|
default_conf = config.pop("default_conf", "client")
|
||||||
init(default_conf, **config)
|
init(default_conf, **config)
|
||||||
@@ -273,8 +272,7 @@ def auto_init(**kwargs):
|
|||||||
logger = get_module_logger("Initialization")
|
logger = get_module_logger("Initialization")
|
||||||
conf_pp = pp / "config.yaml"
|
conf_pp = pp / "config.yaml"
|
||||||
with conf_pp.open() as f:
|
with conf_pp.open() as f:
|
||||||
yaml = YAML(typ="safe", pure=True)
|
conf = yaml.safe_load(f)
|
||||||
conf = yaml.load(f)
|
|
||||||
|
|
||||||
conf_type = conf.get("conf_type", "origin")
|
conf_type = conf.get("conf_type", "origin")
|
||||||
if conf_type == "origin":
|
if conf_type == "origin":
|
||||||
|
|||||||
@@ -278,7 +278,7 @@ class BaseSingleMetric:
|
|||||||
raise NotImplementedError(f"Please implement the `empty` method")
|
raise NotImplementedError(f"Please implement the `empty` method")
|
||||||
|
|
||||||
def add(self, other: BaseSingleMetric, fill_value: float = None) -> BaseSingleMetric:
|
def add(self, other: BaseSingleMetric, fill_value: float = None) -> BaseSingleMetric:
|
||||||
"""Replace np.nan with fill_value in two metrics and add them."""
|
"""Replace np.NaN with fill_value in two metrics and add them."""
|
||||||
|
|
||||||
raise NotImplementedError(f"Please implement the `add` method")
|
raise NotImplementedError(f"Please implement the `add` method")
|
||||||
|
|
||||||
@@ -412,7 +412,7 @@ class BaseOrderIndicator:
|
|||||||
metrics : Union[str, List[str]]
|
metrics : Union[str, List[str]]
|
||||||
all metrics needs to be sumed.
|
all metrics needs to be sumed.
|
||||||
fill_value : float, optional
|
fill_value : float, optional
|
||||||
fill np.nan with value. By default None.
|
fill np.NaN with value. By default None.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
raise NotImplementedError(f"Please implement the 'sum_all_indicators' method")
|
raise NotImplementedError(f"Please implement the 'sum_all_indicators' method")
|
||||||
|
|||||||
@@ -325,9 +325,9 @@ class Indicator:
|
|||||||
|
|
||||||
def _update_order_fulfill_rate(self) -> None:
|
def _update_order_fulfill_rate(self) -> None:
|
||||||
def func(deal_amount, amount):
|
def func(deal_amount, amount):
|
||||||
# deal_amount is np.nan or None when there is no inner decision. So full fill rate is 0.
|
# deal_amount is np.NaN or None when there is no inner decision. So full fill rate is 0.
|
||||||
tmp_deal_amount = deal_amount.reindex(amount.index, 0)
|
tmp_deal_amount = deal_amount.reindex(amount.index, 0)
|
||||||
tmp_deal_amount = tmp_deal_amount.replace({np.nan: 0})
|
tmp_deal_amount = tmp_deal_amount.replace({np.NaN: 0})
|
||||||
return tmp_deal_amount / amount
|
return tmp_deal_amount / amount
|
||||||
|
|
||||||
self.order_indicator.transfer(func, "ffr")
|
self.order_indicator.transfer(func, "ffr")
|
||||||
@@ -354,8 +354,8 @@ class Indicator:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def func(trade_price, deal_amount):
|
def func(trade_price, deal_amount):
|
||||||
# trade_price is np.nan instead of inf when deal_amount is zero.
|
# trade_price is np.NaN instead of inf when deal_amount is zero.
|
||||||
tmp_deal_amount = deal_amount.replace({0: np.nan})
|
tmp_deal_amount = deal_amount.replace({0: np.NaN})
|
||||||
return trade_price / tmp_deal_amount
|
return trade_price / tmp_deal_amount
|
||||||
|
|
||||||
self.order_indicator.transfer(func, "trade_price")
|
self.order_indicator.transfer(func, "trade_price")
|
||||||
@@ -425,7 +425,7 @@ class Indicator:
|
|||||||
assert isinstance(price_s, idd.SingleData)
|
assert isinstance(price_s, idd.SingleData)
|
||||||
price_s = price_s.loc[(price_s > 1e-08).data.astype(bool)]
|
price_s = price_s.loc[(price_s > 1e-08).data.astype(bool)]
|
||||||
# NOTE ~(price_s < 1e-08) is different from price_s >= 1e-8
|
# NOTE ~(price_s < 1e-08) is different from price_s >= 1e-8
|
||||||
# ~(np.nan < 1e-8) -> ~(False) -> True
|
# ~(np.NaN < 1e-8) -> ~(False) -> True
|
||||||
|
|
||||||
assert isinstance(price_s, idd.SingleData)
|
assert isinstance(price_s, idd.SingleData)
|
||||||
if agg == "vwap":
|
if agg == "vwap":
|
||||||
|
|||||||
@@ -173,11 +173,7 @@ _default_config = {
|
|||||||
"filters": ["field_not_found"],
|
"filters": ["field_not_found"],
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
# Normally this should be set to `False` to avoid duplicated logging [1].
|
"loggers": {"qlib": {"level": logging.DEBUG, "handlers": ["console"]}},
|
||||||
# However, due to bug in pytest, it requires log message to propagate to root logger to be captured by `caplog` [2].
|
|
||||||
# [1] https://github.com/microsoft/qlib/pull/1661
|
|
||||||
# [2] https://github.com/pytest-dev/pytest/issues/3697
|
|
||||||
"loggers": {"qlib": {"level": logging.DEBUG, "handlers": ["console"], "propagate": False}},
|
|
||||||
# To let qlib work with other packages, we shouldn't disable existing loggers.
|
# To let qlib work with other packages, we shouldn't disable existing loggers.
|
||||||
# Note that this param is default to True according to the documentation of logging.
|
# Note that this param is default to True according to the documentation of logging.
|
||||||
"disable_existing_loggers": False,
|
"disable_existing_loggers": False,
|
||||||
|
|||||||
@@ -58,7 +58,7 @@ class Alpha360(DataHandlerLP):
|
|||||||
fit_end_time=None,
|
fit_end_time=None,
|
||||||
filter_pipe=None,
|
filter_pipe=None,
|
||||||
inst_processors=None,
|
inst_processors=None,
|
||||||
**kwargs,
|
**kwargs
|
||||||
):
|
):
|
||||||
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
|
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
|
||||||
learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)
|
learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)
|
||||||
@@ -83,7 +83,7 @@ class Alpha360(DataHandlerLP):
|
|||||||
data_loader=data_loader,
|
data_loader=data_loader,
|
||||||
learn_processors=learn_processors,
|
learn_processors=learn_processors,
|
||||||
infer_processors=infer_processors,
|
infer_processors=infer_processors,
|
||||||
**kwargs,
|
**kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_label_config(self):
|
def get_label_config(self):
|
||||||
@@ -109,7 +109,7 @@ class Alpha158(DataHandlerLP):
|
|||||||
process_type=DataHandlerLP.PTYPE_A,
|
process_type=DataHandlerLP.PTYPE_A,
|
||||||
filter_pipe=None,
|
filter_pipe=None,
|
||||||
inst_processors=None,
|
inst_processors=None,
|
||||||
**kwargs,
|
**kwargs
|
||||||
):
|
):
|
||||||
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
|
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
|
||||||
learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)
|
learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)
|
||||||
@@ -134,7 +134,7 @@ class Alpha158(DataHandlerLP):
|
|||||||
infer_processors=infer_processors,
|
infer_processors=infer_processors,
|
||||||
learn_processors=learn_processors,
|
learn_processors=learn_processors,
|
||||||
process_type=process_type,
|
process_type=process_type,
|
||||||
**kwargs,
|
**kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_feature_config(self):
|
def get_feature_config(self):
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ class CatBoostModel(Model, FeatureInt):
|
|||||||
verbose_eval=20,
|
verbose_eval=20,
|
||||||
evals_result=dict(),
|
evals_result=dict(),
|
||||||
reweighter=None,
|
reweighter=None,
|
||||||
**kwargs,
|
**kwargs
|
||||||
):
|
):
|
||||||
df_train, df_valid = dataset.prepare(
|
df_train, df_valid = dataset.prepare(
|
||||||
["train", "valid"],
|
["train", "valid"],
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ class DEnsembleModel(Model, FeatureInt):
|
|||||||
sub_weights=None,
|
sub_weights=None,
|
||||||
epochs=100,
|
epochs=100,
|
||||||
early_stopping_rounds=None,
|
early_stopping_rounds=None,
|
||||||
**kwargs,
|
**kwargs
|
||||||
):
|
):
|
||||||
self.base_model = base_model # "gbm" or "mlp", specifically, we use lgbm for "gbm"
|
self.base_model = base_model # "gbm" or "mlp", specifically, we use lgbm for "gbm"
|
||||||
self.num_models = num_models # the number of sub-models
|
self.num_models = num_models # the number of sub-models
|
||||||
|
|||||||
@@ -56,7 +56,7 @@ class ADARNN(Model):
|
|||||||
n_splits=2,
|
n_splits=2,
|
||||||
GPU=0,
|
GPU=0,
|
||||||
seed=None,
|
seed=None,
|
||||||
**_,
|
**_
|
||||||
):
|
):
|
||||||
# Set logger.
|
# Set logger.
|
||||||
self.logger = get_module_logger("ADARNN")
|
self.logger = get_module_logger("ADARNN")
|
||||||
@@ -154,7 +154,10 @@ class ADARNN(Model):
|
|||||||
self.model.train()
|
self.model.train()
|
||||||
criterion = nn.MSELoss()
|
criterion = nn.MSELoss()
|
||||||
dist_mat = torch.zeros(self.num_layers, self.len_seq).to(self.device)
|
dist_mat = torch.zeros(self.num_layers, self.len_seq).to(self.device)
|
||||||
out_weight_list = None
|
len_loader = np.inf
|
||||||
|
for loader in train_loader_list:
|
||||||
|
if len(loader) < len_loader:
|
||||||
|
len_loader = len(loader)
|
||||||
for data_all in zip(*train_loader_list):
|
for data_all in zip(*train_loader_list):
|
||||||
# for data_all in zip(*train_loader_list):
|
# for data_all in zip(*train_loader_list):
|
||||||
self.train_optimizer.zero_grad()
|
self.train_optimizer.zero_grad()
|
||||||
@@ -568,7 +571,6 @@ class TransferLoss:
|
|||||||
Returns:
|
Returns:
|
||||||
[tensor] -- transfer loss
|
[tensor] -- transfer loss
|
||||||
"""
|
"""
|
||||||
loss = None
|
|
||||||
if self.loss_type in ("mmd_lin", "mmd"):
|
if self.loss_type in ("mmd_lin", "mmd"):
|
||||||
mmdloss = MMD_loss(kernel_type="linear")
|
mmdloss = MMD_loss(kernel_type="linear")
|
||||||
loss = mmdloss(X, Y)
|
loss = mmdloss(X, Y)
|
||||||
|
|||||||
@@ -63,7 +63,7 @@ class ADD(Model):
|
|||||||
mu=0.05,
|
mu=0.05,
|
||||||
GPU=0,
|
GPU=0,
|
||||||
seed=None,
|
seed=None,
|
||||||
**kwargs,
|
**kwargs
|
||||||
):
|
):
|
||||||
# Set logger.
|
# Set logger.
|
||||||
self.logger = get_module_logger("ADD")
|
self.logger = get_module_logger("ADD")
|
||||||
|
|||||||
@@ -52,7 +52,7 @@ class ALSTM(Model):
|
|||||||
optimizer="adam",
|
optimizer="adam",
|
||||||
GPU=0,
|
GPU=0,
|
||||||
seed=None,
|
seed=None,
|
||||||
**kwargs,
|
**kwargs
|
||||||
):
|
):
|
||||||
# Set logger.
|
# Set logger.
|
||||||
self.logger = get_module_logger("ALSTM")
|
self.logger = get_module_logger("ALSTM")
|
||||||
|
|||||||
@@ -56,7 +56,7 @@ class ALSTM(Model):
|
|||||||
n_jobs=10,
|
n_jobs=10,
|
||||||
GPU=0,
|
GPU=0,
|
||||||
seed=None,
|
seed=None,
|
||||||
**kwargs,
|
**kwargs
|
||||||
):
|
):
|
||||||
# Set logger.
|
# Set logger.
|
||||||
self.logger = get_module_logger("ALSTM")
|
self.logger = get_module_logger("ALSTM")
|
||||||
|
|||||||
@@ -56,7 +56,7 @@ class GATs(Model):
|
|||||||
optimizer="adam",
|
optimizer="adam",
|
||||||
GPU=0,
|
GPU=0,
|
||||||
seed=None,
|
seed=None,
|
||||||
**kwargs,
|
**kwargs
|
||||||
):
|
):
|
||||||
# Set logger.
|
# Set logger.
|
||||||
self.logger = get_module_logger("GATs")
|
self.logger = get_module_logger("GATs")
|
||||||
|
|||||||
@@ -73,7 +73,7 @@ class GATs(Model):
|
|||||||
GPU=0,
|
GPU=0,
|
||||||
n_jobs=10,
|
n_jobs=10,
|
||||||
seed=None,
|
seed=None,
|
||||||
**kwargs,
|
**kwargs
|
||||||
):
|
):
|
||||||
# Set logger.
|
# Set logger.
|
||||||
self.logger = get_module_logger("GATs")
|
self.logger = get_module_logger("GATs")
|
||||||
|
|||||||
@@ -1,358 +0,0 @@
|
|||||||
# Copyright (c) Microsoft Corporation.
|
|
||||||
# Licensed under the MIT License.
|
|
||||||
from __future__ import division
|
|
||||||
from __future__ import print_function
|
|
||||||
|
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import pandas as pd
|
|
||||||
from typing import Union
|
|
||||||
import copy
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.optim as optim
|
|
||||||
|
|
||||||
from qlib.data.dataset.weight import Reweighter
|
|
||||||
|
|
||||||
from .pytorch_utils import count_parameters
|
|
||||||
from ...model.base import Model
|
|
||||||
from ...data.dataset import DatasetH, TSDatasetH
|
|
||||||
from ...data.dataset.handler import DataHandlerLP
|
|
||||||
from ...utils import (
|
|
||||||
init_instance_by_config,
|
|
||||||
get_or_create_path,
|
|
||||||
)
|
|
||||||
from ...log import get_module_logger
|
|
||||||
|
|
||||||
from ...model.utils import ConcatDataset
|
|
||||||
|
|
||||||
|
|
||||||
class GeneralPTNN(Model):
|
|
||||||
"""
|
|
||||||
Motivation:
|
|
||||||
We want to provide a Qlib General Pytorch Model Adaptor
|
|
||||||
You can reuse it for all kinds of Pytorch models.
|
|
||||||
It should include the training and predict process
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
d_feat : int
|
|
||||||
input dimension for each time step
|
|
||||||
metric: str
|
|
||||||
the evaluation metric used in early stop
|
|
||||||
optimizer : str
|
|
||||||
optimizer name
|
|
||||||
GPU : str
|
|
||||||
the GPU ID(s) used for training
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
n_epochs=200,
|
|
||||||
lr=0.001,
|
|
||||||
metric="",
|
|
||||||
batch_size=2000,
|
|
||||||
early_stop=20,
|
|
||||||
loss="mse",
|
|
||||||
weight_decay=0.0,
|
|
||||||
optimizer="adam",
|
|
||||||
n_jobs=10,
|
|
||||||
GPU=0,
|
|
||||||
seed=None,
|
|
||||||
pt_model_uri="qlib.contrib.model.pytorch_gru_ts.GRUModel",
|
|
||||||
pt_model_kwargs={
|
|
||||||
"d_feat": 6,
|
|
||||||
"hidden_size": 64,
|
|
||||||
"num_layers": 2,
|
|
||||||
"dropout": 0.0,
|
|
||||||
},
|
|
||||||
):
|
|
||||||
# Set logger.
|
|
||||||
self.logger = get_module_logger("GeneralPTNN")
|
|
||||||
self.logger.info("GeneralPTNN pytorch version...")
|
|
||||||
|
|
||||||
# set hyper-parameters.
|
|
||||||
self.n_epochs = n_epochs
|
|
||||||
self.lr = lr
|
|
||||||
self.metric = metric
|
|
||||||
self.batch_size = batch_size
|
|
||||||
self.early_stop = early_stop
|
|
||||||
self.optimizer = optimizer.lower()
|
|
||||||
self.loss = loss
|
|
||||||
self.weight_decay = weight_decay
|
|
||||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
|
||||||
self.n_jobs = n_jobs
|
|
||||||
self.seed = seed
|
|
||||||
|
|
||||||
self.pt_model_uri, self.pt_model_kwargs = pt_model_uri, pt_model_kwargs
|
|
||||||
self.dnn_model = init_instance_by_config({"class": pt_model_uri, "kwargs": pt_model_kwargs})
|
|
||||||
|
|
||||||
self.logger.info(
|
|
||||||
"GeneralPTNN parameters setting:"
|
|
||||||
"\nn_epochs : {}"
|
|
||||||
"\nlr : {}"
|
|
||||||
"\nmetric : {}"
|
|
||||||
"\nbatch_size : {}"
|
|
||||||
"\nearly_stop : {}"
|
|
||||||
"\noptimizer : {}"
|
|
||||||
"\nloss_type : {}"
|
|
||||||
"\ndevice : {}"
|
|
||||||
"\nn_jobs : {}"
|
|
||||||
"\nuse_GPU : {}"
|
|
||||||
"\nweight_decay : {}"
|
|
||||||
"\nseed : {}"
|
|
||||||
"\npt_model_uri: {}"
|
|
||||||
"\npt_model_kwargs: {}".format(
|
|
||||||
n_epochs,
|
|
||||||
lr,
|
|
||||||
metric,
|
|
||||||
batch_size,
|
|
||||||
early_stop,
|
|
||||||
optimizer.lower(),
|
|
||||||
loss,
|
|
||||||
self.device,
|
|
||||||
n_jobs,
|
|
||||||
self.use_gpu,
|
|
||||||
weight_decay,
|
|
||||||
seed,
|
|
||||||
pt_model_uri,
|
|
||||||
pt_model_kwargs,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.seed is not None:
|
|
||||||
np.random.seed(self.seed)
|
|
||||||
torch.manual_seed(self.seed)
|
|
||||||
|
|
||||||
self.logger.info("model:\n{:}".format(self.dnn_model))
|
|
||||||
self.logger.info("model size: {:.4f} MB".format(count_parameters(self.dnn_model)))
|
|
||||||
|
|
||||||
if optimizer.lower() == "adam":
|
|
||||||
self.train_optimizer = optim.Adam(self.dnn_model.parameters(), lr=self.lr, weight_decay=weight_decay)
|
|
||||||
elif optimizer.lower() == "gd":
|
|
||||||
self.train_optimizer = optim.SGD(self.dnn_model.parameters(), lr=self.lr, weight_decay=weight_decay)
|
|
||||||
else:
|
|
||||||
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
|
|
||||||
|
|
||||||
self.fitted = False
|
|
||||||
self.dnn_model.to(self.device)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def use_gpu(self):
|
|
||||||
return self.device != torch.device("cpu")
|
|
||||||
|
|
||||||
def mse(self, pred, label, weight):
|
|
||||||
loss = weight * (pred - label) ** 2
|
|
||||||
return torch.mean(loss)
|
|
||||||
|
|
||||||
def loss_fn(self, pred, label, weight=None):
|
|
||||||
mask = ~torch.isnan(label)
|
|
||||||
|
|
||||||
if weight is None:
|
|
||||||
weight = torch.ones_like(label)
|
|
||||||
|
|
||||||
if self.loss == "mse":
|
|
||||||
return self.mse(pred[mask], label[mask], weight[mask])
|
|
||||||
|
|
||||||
raise ValueError("unknown loss `%s`" % self.loss)
|
|
||||||
|
|
||||||
def metric_fn(self, pred, label):
|
|
||||||
mask = torch.isfinite(label)
|
|
||||||
|
|
||||||
if self.metric in ("", "loss"):
|
|
||||||
return -self.loss_fn(pred[mask], label[mask])
|
|
||||||
|
|
||||||
raise ValueError("unknown metric `%s`" % self.metric)
|
|
||||||
|
|
||||||
def _get_fl(self, data: torch.Tensor):
|
|
||||||
"""
|
|
||||||
get feature and label from data
|
|
||||||
- Handle the different data shape of time series and tabular data
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
data : torch.Tensor
|
|
||||||
input data which maybe 3 dimension or 2 dimension
|
|
||||||
- 3dim: [batch_size, time_step, feature_dim]
|
|
||||||
- 2dim: [batch_size, feature_dim]
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
Tuple[torch.Tensor, torch.Tensor]
|
|
||||||
"""
|
|
||||||
if data.dim() == 3:
|
|
||||||
# it is a time series dataset
|
|
||||||
feature = data[:, :, 0:-1].to(self.device)
|
|
||||||
label = data[:, -1, -1].to(self.device)
|
|
||||||
elif data.dim() == 2:
|
|
||||||
# it is a tabular dataset
|
|
||||||
feature = data[:, 0:-1].to(self.device)
|
|
||||||
label = data[:, -1].to(self.device)
|
|
||||||
else:
|
|
||||||
raise ValueError("Unsupported data shape.")
|
|
||||||
return feature, label
|
|
||||||
|
|
||||||
def train_epoch(self, data_loader):
|
|
||||||
self.dnn_model.train()
|
|
||||||
|
|
||||||
for data, weight in data_loader:
|
|
||||||
feature, label = self._get_fl(data)
|
|
||||||
|
|
||||||
pred = self.dnn_model(feature.float())
|
|
||||||
loss = self.loss_fn(pred, label, weight.to(self.device))
|
|
||||||
|
|
||||||
self.train_optimizer.zero_grad()
|
|
||||||
loss.backward()
|
|
||||||
torch.nn.utils.clip_grad_value_(self.dnn_model.parameters(), 3.0)
|
|
||||||
self.train_optimizer.step()
|
|
||||||
|
|
||||||
def test_epoch(self, data_loader):
|
|
||||||
self.dnn_model.eval()
|
|
||||||
|
|
||||||
scores = []
|
|
||||||
losses = []
|
|
||||||
|
|
||||||
for data, weight in data_loader:
|
|
||||||
feature, label = self._get_fl(data)
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
pred = self.dnn_model(feature.float())
|
|
||||||
loss = self.loss_fn(pred, label, weight.to(self.device))
|
|
||||||
losses.append(loss.item())
|
|
||||||
|
|
||||||
score = self.metric_fn(pred, label)
|
|
||||||
scores.append(score.item())
|
|
||||||
|
|
||||||
return np.mean(losses), np.mean(scores)
|
|
||||||
|
|
||||||
def fit(
|
|
||||||
self,
|
|
||||||
dataset: Union[DatasetH, TSDatasetH],
|
|
||||||
evals_result=dict(),
|
|
||||||
save_path=None,
|
|
||||||
reweighter=None,
|
|
||||||
):
|
|
||||||
ists = isinstance(dataset, TSDatasetH) # is this time series dataset
|
|
||||||
|
|
||||||
dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
|
|
||||||
dl_valid = dataset.prepare("valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
|
|
||||||
if dl_train.empty or dl_valid.empty:
|
|
||||||
raise ValueError("Empty data from dataset, please check your dataset config.")
|
|
||||||
|
|
||||||
if reweighter is None:
|
|
||||||
wl_train = np.ones(len(dl_train))
|
|
||||||
wl_valid = np.ones(len(dl_valid))
|
|
||||||
elif isinstance(reweighter, Reweighter):
|
|
||||||
wl_train = reweighter.reweight(dl_train)
|
|
||||||
wl_valid = reweighter.reweight(dl_valid)
|
|
||||||
else:
|
|
||||||
raise ValueError("Unsupported reweighter type.")
|
|
||||||
|
|
||||||
# Preprocess for data. To align to Dataset Interface for DataLoader
|
|
||||||
if ists:
|
|
||||||
dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
|
||||||
dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
|
||||||
else:
|
|
||||||
# If it is a tabular, we convert the dataframe to numpy to be indexable by DataLoader
|
|
||||||
dl_train = dl_train.values
|
|
||||||
dl_valid = dl_valid.values
|
|
||||||
|
|
||||||
train_loader = DataLoader(
|
|
||||||
ConcatDataset(dl_train, wl_train),
|
|
||||||
batch_size=self.batch_size,
|
|
||||||
shuffle=True,
|
|
||||||
num_workers=self.n_jobs,
|
|
||||||
drop_last=True,
|
|
||||||
)
|
|
||||||
valid_loader = DataLoader(
|
|
||||||
ConcatDataset(dl_valid, wl_valid),
|
|
||||||
batch_size=self.batch_size,
|
|
||||||
shuffle=False,
|
|
||||||
num_workers=self.n_jobs,
|
|
||||||
drop_last=True,
|
|
||||||
)
|
|
||||||
del dl_train, dl_valid, wl_train, wl_valid
|
|
||||||
|
|
||||||
save_path = get_or_create_path(save_path)
|
|
||||||
|
|
||||||
stop_steps = 0
|
|
||||||
train_loss = 0
|
|
||||||
best_score = -np.inf
|
|
||||||
best_epoch = 0
|
|
||||||
evals_result["train"] = []
|
|
||||||
evals_result["valid"] = []
|
|
||||||
|
|
||||||
# train
|
|
||||||
self.logger.info("training...")
|
|
||||||
self.fitted = True
|
|
||||||
|
|
||||||
for step in range(self.n_epochs):
|
|
||||||
self.logger.info("Epoch%d:", step)
|
|
||||||
self.logger.info("training...")
|
|
||||||
self.train_epoch(train_loader)
|
|
||||||
self.logger.info("evaluating...")
|
|
||||||
train_loss, train_score = self.test_epoch(train_loader)
|
|
||||||
val_loss, val_score = self.test_epoch(valid_loader)
|
|
||||||
self.logger.info("train %.6f, valid %.6f" % (train_score, val_score))
|
|
||||||
evals_result["train"].append(train_score)
|
|
||||||
evals_result["valid"].append(val_score)
|
|
||||||
|
|
||||||
if step == 0:
|
|
||||||
best_param = copy.deepcopy(self.dnn_model.state_dict())
|
|
||||||
if val_score > best_score:
|
|
||||||
best_score = val_score
|
|
||||||
stop_steps = 0
|
|
||||||
best_epoch = step
|
|
||||||
best_param = copy.deepcopy(self.dnn_model.state_dict())
|
|
||||||
else:
|
|
||||||
stop_steps += 1
|
|
||||||
if stop_steps >= self.early_stop:
|
|
||||||
self.logger.info("early stop")
|
|
||||||
break
|
|
||||||
|
|
||||||
self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
|
|
||||||
self.dnn_model.load_state_dict(best_param)
|
|
||||||
torch.save(best_param, save_path)
|
|
||||||
|
|
||||||
if self.use_gpu:
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
def predict(
|
|
||||||
self,
|
|
||||||
dataset: Union[DatasetH, TSDatasetH],
|
|
||||||
batch_size=None,
|
|
||||||
n_jobs=None,
|
|
||||||
):
|
|
||||||
if not self.fitted:
|
|
||||||
raise ValueError("model is not fitted yet!")
|
|
||||||
|
|
||||||
dl_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)
|
|
||||||
|
|
||||||
if isinstance(dataset, TSDatasetH):
|
|
||||||
dl_test.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
|
||||||
index = dl_test.get_index()
|
|
||||||
else:
|
|
||||||
# If it is a tabular, we convert the dataframe to numpy to be indexable by DataLoader
|
|
||||||
index = dl_test.index
|
|
||||||
dl_test = dl_test.values
|
|
||||||
|
|
||||||
test_loader = DataLoader(dl_test, batch_size=self.batch_size, num_workers=self.n_jobs)
|
|
||||||
self.dnn_model.eval()
|
|
||||||
preds = []
|
|
||||||
|
|
||||||
for data in test_loader:
|
|
||||||
feature, _ = self._get_fl(data)
|
|
||||||
feature = feature.to(self.device)
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
pred = self.dnn_model(feature.float()).detach().cpu().numpy()
|
|
||||||
|
|
||||||
preds.append(pred)
|
|
||||||
|
|
||||||
preds_concat = np.concatenate(preds)
|
|
||||||
if preds_concat.ndim != 1:
|
|
||||||
preds_concat = preds_concat.ravel()
|
|
||||||
|
|
||||||
return pd.Series(preds_concat, index=index)
|
|
||||||
@@ -52,7 +52,7 @@ class GRU(Model):
|
|||||||
optimizer="adam",
|
optimizer="adam",
|
||||||
GPU=0,
|
GPU=0,
|
||||||
seed=None,
|
seed=None,
|
||||||
**kwargs,
|
**kwargs
|
||||||
):
|
):
|
||||||
# Set logger.
|
# Set logger.
|
||||||
self.logger = get_module_logger("GRU")
|
self.logger = get_module_logger("GRU")
|
||||||
@@ -317,6 +317,7 @@ class GRU(Model):
|
|||||||
|
|
||||||
|
|
||||||
class GRUModel(nn.Module):
|
class GRUModel(nn.Module):
|
||||||
|
|
||||||
def __init__(self, d_feat=6, hidden_size=64, num_layers=2, dropout=0.0):
|
def __init__(self, d_feat=6, hidden_size=64, num_layers=2, dropout=0.0):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
|||||||
@@ -54,7 +54,7 @@ class GRU(Model):
|
|||||||
n_jobs=10,
|
n_jobs=10,
|
||||||
GPU=0,
|
GPU=0,
|
||||||
seed=None,
|
seed=None,
|
||||||
**kwargs,
|
**kwargs
|
||||||
):
|
):
|
||||||
# Set logger.
|
# Set logger.
|
||||||
self.logger = get_module_logger("GRU")
|
self.logger = get_module_logger("GRU")
|
||||||
|
|||||||
@@ -59,7 +59,7 @@ class HIST(Model):
|
|||||||
optimizer="adam",
|
optimizer="adam",
|
||||||
GPU=0,
|
GPU=0,
|
||||||
seed=None,
|
seed=None,
|
||||||
**kwargs,
|
**kwargs
|
||||||
):
|
):
|
||||||
# Set logger.
|
# Set logger.
|
||||||
self.logger = get_module_logger("HIST")
|
self.logger = get_module_logger("HIST")
|
||||||
@@ -256,7 +256,7 @@ class HIST(Model):
|
|||||||
raise ValueError("Empty data from dataset, please check your dataset config.")
|
raise ValueError("Empty data from dataset, please check your dataset config.")
|
||||||
|
|
||||||
if not os.path.exists(self.stock2concept):
|
if not os.path.exists(self.stock2concept):
|
||||||
url = "https://github.com/SunsetWolf/qlib_dataset/releases/download/v0/qlib_csi300_stock2concept.npy"
|
url = "http://fintech.msra.cn/stock_data/downloads/qlib_csi300_stock2concept.npy"
|
||||||
urllib.request.urlretrieve(url, self.stock2concept)
|
urllib.request.urlretrieve(url, self.stock2concept)
|
||||||
|
|
||||||
stock_index = np.load(self.stock_index, allow_pickle=True).item()
|
stock_index = np.load(self.stock_index, allow_pickle=True).item()
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ class IGMTF(Model):
|
|||||||
optimizer="adam",
|
optimizer="adam",
|
||||||
GPU=0,
|
GPU=0,
|
||||||
seed=None,
|
seed=None,
|
||||||
**kwargs,
|
**kwargs
|
||||||
):
|
):
|
||||||
# Set logger.
|
# Set logger.
|
||||||
self.logger = get_module_logger("IGMTF")
|
self.logger = get_module_logger("IGMTF")
|
||||||
|
|||||||
@@ -255,7 +255,7 @@ class KRNN(Model):
|
|||||||
optimizer="adam",
|
optimizer="adam",
|
||||||
GPU=0,
|
GPU=0,
|
||||||
seed=None,
|
seed=None,
|
||||||
**kwargs,
|
**kwargs
|
||||||
):
|
):
|
||||||
# Set logger.
|
# Set logger.
|
||||||
self.logger = get_module_logger("KRNN")
|
self.logger = get_module_logger("KRNN")
|
||||||
|
|||||||
@@ -44,7 +44,7 @@ class LocalformerModel(Model):
|
|||||||
n_jobs=10,
|
n_jobs=10,
|
||||||
GPU=0,
|
GPU=0,
|
||||||
seed=None,
|
seed=None,
|
||||||
**kwargs,
|
**kwargs
|
||||||
):
|
):
|
||||||
# set hyper-parameters.
|
# set hyper-parameters.
|
||||||
self.d_model = d_model
|
self.d_model = d_model
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ class LocalformerModel(Model):
|
|||||||
n_jobs=10,
|
n_jobs=10,
|
||||||
GPU=0,
|
GPU=0,
|
||||||
seed=None,
|
seed=None,
|
||||||
**kwargs,
|
**kwargs
|
||||||
):
|
):
|
||||||
# set hyper-parameters.
|
# set hyper-parameters.
|
||||||
self.d_model = d_model
|
self.d_model = d_model
|
||||||
|
|||||||
@@ -51,7 +51,7 @@ class LSTM(Model):
|
|||||||
optimizer="adam",
|
optimizer="adam",
|
||||||
GPU=0,
|
GPU=0,
|
||||||
seed=None,
|
seed=None,
|
||||||
**kwargs,
|
**kwargs
|
||||||
):
|
):
|
||||||
# Set logger.
|
# Set logger.
|
||||||
self.logger = get_module_logger("LSTM")
|
self.logger = get_module_logger("LSTM")
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ class LSTM(Model):
|
|||||||
n_jobs=10,
|
n_jobs=10,
|
||||||
GPU=0,
|
GPU=0,
|
||||||
seed=None,
|
seed=None,
|
||||||
**kwargs,
|
**kwargs
|
||||||
):
|
):
|
||||||
# Set logger.
|
# Set logger.
|
||||||
self.logger = get_module_logger("LSTM")
|
self.logger = get_module_logger("LSTM")
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ class SandwichModel(nn.Module):
|
|||||||
rnn_layers,
|
rnn_layers,
|
||||||
dropout,
|
dropout,
|
||||||
device,
|
device,
|
||||||
**params,
|
**params
|
||||||
):
|
):
|
||||||
"""Build a Sandwich model
|
"""Build a Sandwich model
|
||||||
|
|
||||||
@@ -129,7 +129,7 @@ class Sandwich(Model):
|
|||||||
optimizer="adam",
|
optimizer="adam",
|
||||||
GPU=0,
|
GPU=0,
|
||||||
seed=None,
|
seed=None,
|
||||||
**kwargs,
|
**kwargs
|
||||||
):
|
):
|
||||||
# Set logger.
|
# Set logger.
|
||||||
self.logger = get_module_logger("Sandwich")
|
self.logger = get_module_logger("Sandwich")
|
||||||
|
|||||||
@@ -212,7 +212,7 @@ class SFM(Model):
|
|||||||
optimizer="gd",
|
optimizer="gd",
|
||||||
GPU=0,
|
GPU=0,
|
||||||
seed=None,
|
seed=None,
|
||||||
**kwargs,
|
**kwargs
|
||||||
):
|
):
|
||||||
# Set logger.
|
# Set logger.
|
||||||
self.logger = get_module_logger("SFM")
|
self.logger = get_module_logger("SFM")
|
||||||
|
|||||||
@@ -56,7 +56,7 @@ class TCN(Model):
|
|||||||
optimizer="adam",
|
optimizer="adam",
|
||||||
GPU=0,
|
GPU=0,
|
||||||
seed=None,
|
seed=None,
|
||||||
**kwargs,
|
**kwargs
|
||||||
):
|
):
|
||||||
# Set logger.
|
# Set logger.
|
||||||
self.logger = get_module_logger("TCN")
|
self.logger = get_module_logger("TCN")
|
||||||
|
|||||||
@@ -54,7 +54,7 @@ class TCN(Model):
|
|||||||
n_jobs=10,
|
n_jobs=10,
|
||||||
GPU=0,
|
GPU=0,
|
||||||
seed=None,
|
seed=None,
|
||||||
**kwargs,
|
**kwargs
|
||||||
):
|
):
|
||||||
# Set logger.
|
# Set logger.
|
||||||
self.logger = get_module_logger("TCN")
|
self.logger = get_module_logger("TCN")
|
||||||
|
|||||||
@@ -58,7 +58,7 @@ class TCTS(Model):
|
|||||||
mode="soft",
|
mode="soft",
|
||||||
seed=None,
|
seed=None,
|
||||||
lowest_valid_performance=0.993,
|
lowest_valid_performance=0.993,
|
||||||
**kwargs,
|
**kwargs
|
||||||
):
|
):
|
||||||
# Set logger.
|
# Set logger.
|
||||||
self.logger = get_module_logger("TCTS")
|
self.logger = get_module_logger("TCTS")
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ class TransformerModel(Model):
|
|||||||
n_jobs=10,
|
n_jobs=10,
|
||||||
GPU=0,
|
GPU=0,
|
||||||
seed=None,
|
seed=None,
|
||||||
**kwargs,
|
**kwargs
|
||||||
):
|
):
|
||||||
# set hyper-parameters.
|
# set hyper-parameters.
|
||||||
self.d_model = d_model
|
self.d_model = d_model
|
||||||
|
|||||||
@@ -41,7 +41,7 @@ class TransformerModel(Model):
|
|||||||
n_jobs=10,
|
n_jobs=10,
|
||||||
GPU=0,
|
GPU=0,
|
||||||
seed=None,
|
seed=None,
|
||||||
**kwargs,
|
**kwargs
|
||||||
):
|
):
|
||||||
# set hyper-parameters.
|
# set hyper-parameters.
|
||||||
self.d_model = d_model
|
self.d_model = d_model
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ class XGBModel(Model, FeatureInt):
|
|||||||
verbose_eval=20,
|
verbose_eval=20,
|
||||||
evals_result=dict(),
|
evals_result=dict(),
|
||||||
reweighter=None,
|
reweighter=None,
|
||||||
**kwargs,
|
**kwargs
|
||||||
):
|
):
|
||||||
df_train, df_valid = dataset.prepare(
|
df_train, df_valid = dataset.prepare(
|
||||||
["train", "valid"],
|
["train", "valid"],
|
||||||
@@ -63,7 +63,7 @@ class XGBModel(Model, FeatureInt):
|
|||||||
early_stopping_rounds=early_stopping_rounds,
|
early_stopping_rounds=early_stopping_rounds,
|
||||||
verbose_eval=verbose_eval,
|
verbose_eval=verbose_eval,
|
||||||
evals_result=evals_result,
|
evals_result=evals_result,
|
||||||
**kwargs,
|
**kwargs
|
||||||
)
|
)
|
||||||
evals_result["train"] = list(evals_result["train"].values())[0]
|
evals_result["train"] = list(evals_result["train"].values())[0]
|
||||||
evals_result["valid"] = list(evals_result["valid"].values())[0]
|
evals_result["valid"] = list(evals_result["valid"].values())[0]
|
||||||
|
|||||||
@@ -4,10 +4,10 @@
|
|||||||
# pylint: skip-file
|
# pylint: skip-file
|
||||||
# flake8: noqa
|
# flake8: noqa
|
||||||
|
|
||||||
|
import yaml
|
||||||
import pathlib
|
import pathlib
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import shutil
|
import shutil
|
||||||
from ruamel.yaml import YAML
|
|
||||||
from ...backtest.account import Account
|
from ...backtest.account import Account
|
||||||
from .user import User
|
from .user import User
|
||||||
from .utils import load_instance, save_instance
|
from .utils import load_instance, save_instance
|
||||||
@@ -110,8 +110,7 @@ class UserManager:
|
|||||||
raise ValueError("User data for {} already exists".format(user_id))
|
raise ValueError("User data for {} already exists".format(user_id))
|
||||||
|
|
||||||
with config_file.open("r") as fp:
|
with config_file.open("r") as fp:
|
||||||
yaml = YAML(typ="safe", pure=True)
|
config = yaml.safe_load(fp)
|
||||||
config = yaml.load(fp)
|
|
||||||
# load model
|
# load model
|
||||||
model = init_instance_by_config(config["model"])
|
model = init_instance_by_config(config["model"])
|
||||||
|
|
||||||
|
|||||||
@@ -6,8 +6,8 @@
|
|||||||
|
|
||||||
import pathlib
|
import pathlib
|
||||||
import pickle
|
import pickle
|
||||||
|
import yaml
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from ruamel.yaml import YAML
|
|
||||||
from ...data import D
|
from ...data import D
|
||||||
from ...config import C
|
from ...config import C
|
||||||
from ...log import get_module_logger
|
from ...log import get_module_logger
|
||||||
@@ -91,8 +91,7 @@ def prepare(um, today, user_id, exchange_config=None):
|
|||||||
dates.append(get_next_trading_date(dates[-1], future=True))
|
dates.append(get_next_trading_date(dates[-1], future=True))
|
||||||
if exchange_config:
|
if exchange_config:
|
||||||
with pathlib.Path(exchange_config).open("r") as fp:
|
with pathlib.Path(exchange_config).open("r") as fp:
|
||||||
yaml = YAML(typ="safe", pure=True)
|
exchange_paras = yaml.safe_load(fp)
|
||||||
exchange_paras = yaml.load(fp)
|
|
||||||
else:
|
else:
|
||||||
exchange_paras = {}
|
exchange_paras = {}
|
||||||
trade_exchange = Exchange(trade_dates=dates, **exchange_paras)
|
trade_exchange = Exchange(trade_dates=dates, **exchange_paras)
|
||||||
|
|||||||
@@ -176,7 +176,7 @@ class HeatmapGraph(BaseGraph):
|
|||||||
x=self._df.columns,
|
x=self._df.columns,
|
||||||
y=self._df.index,
|
y=self._df.index,
|
||||||
z=self._df.values.tolist(),
|
z=self._df.values.tolist(),
|
||||||
**self._graph_kwargs,
|
**self._graph_kwargs
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
return _data
|
return _data
|
||||||
@@ -213,7 +213,7 @@ class SubplotsGraph:
|
|||||||
sub_graph_layout: dict = None,
|
sub_graph_layout: dict = None,
|
||||||
sub_graph_data: list = None,
|
sub_graph_data: list = None,
|
||||||
subplots_kwargs: dict = None,
|
subplots_kwargs: dict = None,
|
||||||
**kwargs,
|
**kwargs
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -355,7 +355,7 @@ class SubplotsGraph:
|
|||||||
df=self._df.loc[:, [column_name]],
|
df=self._df.loc[:, [column_name]],
|
||||||
name_dict={column_name: temp_name},
|
name_dict={column_name: temp_name},
|
||||||
graph_kwargs=_graph_kwargs,
|
graph_kwargs=_graph_kwargs,
|
||||||
),
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise TypeError()
|
raise TypeError()
|
||||||
|
|||||||
@@ -2,11 +2,11 @@
|
|||||||
# Licensed under the MIT License.
|
# Licensed under the MIT License.
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from ruamel.yaml import YAML
|
|
||||||
from typing import List, Optional, Union
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
import fire
|
import fire
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
import yaml
|
||||||
|
|
||||||
from qlib import auto_init
|
from qlib import auto_init
|
||||||
from qlib.log import get_module_logger
|
from qlib.log import get_module_logger
|
||||||
@@ -117,8 +117,7 @@ class Rolling:
|
|||||||
|
|
||||||
def _raw_conf(self) -> dict:
|
def _raw_conf(self) -> dict:
|
||||||
with self.conf_path.open("r") as f:
|
with self.conf_path.open("r") as f:
|
||||||
yaml = YAML(typ="safe", pure=True)
|
return yaml.safe_load(f)
|
||||||
return yaml.load(f)
|
|
||||||
|
|
||||||
def _replace_handler_with_cache(self, task: dict):
|
def _replace_handler_with_cache(self, task: dict):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -4,9 +4,9 @@
|
|||||||
# pylint: skip-file
|
# pylint: skip-file
|
||||||
# flake8: noqa
|
# flake8: noqa
|
||||||
|
|
||||||
|
import yaml
|
||||||
import copy
|
import copy
|
||||||
import os
|
import os
|
||||||
from ruamel.yaml import YAML
|
|
||||||
|
|
||||||
|
|
||||||
class TunerConfigManager:
|
class TunerConfigManager:
|
||||||
@@ -16,8 +16,7 @@ class TunerConfigManager:
|
|||||||
self.config_path = config_path
|
self.config_path = config_path
|
||||||
|
|
||||||
with open(config_path) as fp:
|
with open(config_path) as fp:
|
||||||
yaml = YAML(typ="safe", pure=True)
|
config = yaml.safe_load(fp)
|
||||||
config = yaml.load(fp)
|
|
||||||
self.config = copy.deepcopy(config)
|
self.config = copy.deepcopy(config)
|
||||||
|
|
||||||
self.pipeline_ex_config = PipelineExperimentConfig(config.get("experiment", dict()), self)
|
self.pipeline_ex_config = PipelineExperimentConfig(config.get("experiment", dict()), self)
|
||||||
|
|||||||
@@ -41,7 +41,6 @@ class DataLoader(abc.ABC):
|
|||||||
----------
|
----------
|
||||||
instruments : str or dict
|
instruments : str or dict
|
||||||
it can either be the market name or the config file of instruments generated by InstrumentProvider.
|
it can either be the market name or the config file of instruments generated by InstrumentProvider.
|
||||||
If the value of instruments is None, it means that no filtering is done.
|
|
||||||
start_time : str
|
start_time : str
|
||||||
start of the time range.
|
start of the time range.
|
||||||
end_time : str
|
end_time : str
|
||||||
@@ -51,11 +50,6 @@ class DataLoader(abc.ABC):
|
|||||||
-------
|
-------
|
||||||
pd.DataFrame:
|
pd.DataFrame:
|
||||||
data load from the under layer source
|
data load from the under layer source
|
||||||
|
|
||||||
Raise
|
|
||||||
-----
|
|
||||||
KeyError:
|
|
||||||
if the instruments filter is not supported, raise KeyError
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@@ -326,13 +320,7 @@ class NestedDataLoader(DataLoader):
|
|||||||
def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame:
|
def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame:
|
||||||
df_full = None
|
df_full = None
|
||||||
for dl in self.data_loader_l:
|
for dl in self.data_loader_l:
|
||||||
try:
|
df_current = dl.load(instruments, start_time, end_time)
|
||||||
df_current = dl.load(instruments, start_time, end_time)
|
|
||||||
except KeyError:
|
|
||||||
warnings.warn(
|
|
||||||
"If the value of `instruments` cannot be processed, it will set instruments to None to get all the data."
|
|
||||||
)
|
|
||||||
df_current = dl.load(instruments=None, start_time=start_time, end_time=end_time)
|
|
||||||
if df_full is None:
|
if df_full is None:
|
||||||
df_full = df_current
|
df_full = df_current
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -104,24 +104,15 @@ class HashingStockStorage(BaseHandlerStorage):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
stock_selector = slice(None)
|
stock_selector = slice(None)
|
||||||
time_selector = slice(None) # by default not filter by time.
|
|
||||||
|
|
||||||
if level is None:
|
if level is None:
|
||||||
# For directly applying.
|
|
||||||
if isinstance(selector, tuple) and self.stock_level < len(selector):
|
if isinstance(selector, tuple) and self.stock_level < len(selector):
|
||||||
# full selector format
|
|
||||||
stock_selector = selector[self.stock_level]
|
stock_selector = selector[self.stock_level]
|
||||||
time_selector = selector[1 - self.stock_level]
|
|
||||||
elif isinstance(selector, (list, str)) and self.stock_level == 0:
|
elif isinstance(selector, (list, str)) and self.stock_level == 0:
|
||||||
# only stock selector
|
|
||||||
stock_selector = selector
|
stock_selector = selector
|
||||||
elif level in ("instrument", self.stock_level):
|
elif level in ("instrument", self.stock_level):
|
||||||
if isinstance(selector, tuple):
|
if isinstance(selector, tuple):
|
||||||
# NOTE: How could the stock level selector be a tuple?
|
|
||||||
stock_selector = selector[0]
|
stock_selector = selector[0]
|
||||||
raise TypeError(
|
|
||||||
"I forget why would this case appear. But I think it does not make sense. So we raise a error for that case."
|
|
||||||
)
|
|
||||||
elif isinstance(selector, (list, str)):
|
elif isinstance(selector, (list, str)):
|
||||||
stock_selector = selector
|
stock_selector = selector
|
||||||
|
|
||||||
@@ -129,7 +120,7 @@ class HashingStockStorage(BaseHandlerStorage):
|
|||||||
raise TypeError(f"stock selector must be type str|list, or slice(None), rather than {stock_selector}")
|
raise TypeError(f"stock selector must be type str|list, or slice(None), rather than {stock_selector}")
|
||||||
|
|
||||||
if stock_selector == slice(None):
|
if stock_selector == slice(None):
|
||||||
return self.hash_df, time_selector
|
return self.hash_df
|
||||||
|
|
||||||
if isinstance(stock_selector, str):
|
if isinstance(stock_selector, str):
|
||||||
stock_selector = [stock_selector]
|
stock_selector = [stock_selector]
|
||||||
@@ -138,7 +129,7 @@ class HashingStockStorage(BaseHandlerStorage):
|
|||||||
for each_stock in sorted(stock_selector):
|
for each_stock in sorted(stock_selector):
|
||||||
if each_stock in self.hash_df:
|
if each_stock in self.hash_df:
|
||||||
select_dict[each_stock] = self.hash_df[each_stock]
|
select_dict[each_stock] = self.hash_df[each_stock]
|
||||||
return select_dict, time_selector
|
return select_dict
|
||||||
|
|
||||||
def fetch(
|
def fetch(
|
||||||
self,
|
self,
|
||||||
@@ -147,13 +138,10 @@ class HashingStockStorage(BaseHandlerStorage):
|
|||||||
col_set: Union[str, List[str]] = DataHandler.CS_ALL,
|
col_set: Union[str, List[str]] = DataHandler.CS_ALL,
|
||||||
fetch_orig: bool = True,
|
fetch_orig: bool = True,
|
||||||
) -> pd.DataFrame:
|
) -> pd.DataFrame:
|
||||||
fetch_stock_df_list, time_selector = self._fetch_hash_df_by_stock(selector=selector, level=level)
|
fetch_stock_df_list = list(self._fetch_hash_df_by_stock(selector=selector, level=level).values())
|
||||||
fetch_stock_df_list = list(fetch_stock_df_list.values())
|
|
||||||
for _index, stock_df in enumerate(fetch_stock_df_list):
|
for _index, stock_df in enumerate(fetch_stock_df_list):
|
||||||
fetch_col_df = fetch_df_by_col(df=stock_df, col_set=col_set)
|
fetch_col_df = fetch_df_by_col(df=stock_df, col_set=col_set)
|
||||||
fetch_index_df = fetch_df_by_index(
|
fetch_index_df = fetch_df_by_index(df=fetch_col_df, selector=selector, level=level, fetch_orig=fetch_orig)
|
||||||
df=fetch_col_df, selector=time_selector, level="datetime", fetch_orig=fetch_orig
|
|
||||||
)
|
|
||||||
fetch_stock_df_list[_index] = fetch_index_df
|
fetch_stock_df_list[_index] = fetch_index_df
|
||||||
if len(fetch_stock_df_list) == 0:
|
if len(fetch_stock_df_list) == 0:
|
||||||
index_names = ("instrument", "datetime") if self.stock_level == 0 else ("datetime", "instrument")
|
index_names = ("instrument", "datetime") if self.stock_level == 0 else ("datetime", "instrument")
|
||||||
|
|||||||
@@ -164,7 +164,6 @@ class SeriesDFilter(BaseDFilter):
|
|||||||
timestamp = []
|
timestamp = []
|
||||||
_lbool = None
|
_lbool = None
|
||||||
_ltime = None
|
_ltime = None
|
||||||
_cur_start = None
|
|
||||||
for _ts, _bool in timestamp_series.items():
|
for _ts, _bool in timestamp_series.items():
|
||||||
# there is likely to be NAN when the filter series don't have the
|
# there is likely to be NAN when the filter series don't have the
|
||||||
# bool value, so we just change the NAN into False
|
# bool value, so we just change the NAN into False
|
||||||
|
|||||||
@@ -7,7 +7,8 @@ import shutil
|
|||||||
import sys
|
import sys
|
||||||
import tempfile
|
import tempfile
|
||||||
from importlib import import_module
|
from importlib import import_module
|
||||||
from ruamel.yaml import YAML
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
|
||||||
DELETE_KEY = "_delete_"
|
DELETE_KEY = "_delete_"
|
||||||
@@ -56,8 +57,7 @@ def parse_backtest_config(path: str) -> dict:
|
|||||||
del sys.modules[tmp_module_name]
|
del sys.modules[tmp_module_name]
|
||||||
else:
|
else:
|
||||||
with open(tmp_config_file.name) as input_stream:
|
with open(tmp_config_file.name) as input_stream:
|
||||||
yaml = YAML(typ="safe", pure=True)
|
config = yaml.safe_load(input_stream)
|
||||||
config = yaml.load(input_stream)
|
|
||||||
|
|
||||||
if "_base_" in config:
|
if "_base_" in config:
|
||||||
base_file_name = config.pop("_base_")
|
base_file_name = config.pop("_base_")
|
||||||
|
|||||||
@@ -8,12 +8,12 @@ import random
|
|||||||
import sys
|
import sys
|
||||||
import warnings
|
import warnings
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from ruamel.yaml import YAML
|
|
||||||
from typing import cast, List, Optional
|
from typing import cast, List, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import torch
|
import torch
|
||||||
|
import yaml
|
||||||
from qlib.backtest import Order
|
from qlib.backtest import Order
|
||||||
from qlib.backtest.decision import OrderDir
|
from qlib.backtest.decision import OrderDir
|
||||||
from qlib.constant import ONE_MIN
|
from qlib.constant import ONE_MIN
|
||||||
@@ -263,7 +263,6 @@ if __name__ == "__main__":
|
|||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
with open(args.config_path, "r") as input_stream:
|
with open(args.config_path, "r") as input_stream:
|
||||||
yaml = YAML(typ="safe", pure=True)
|
config = yaml.safe_load(input_stream)
|
||||||
config = yaml.load(input_stream)
|
|
||||||
|
|
||||||
main(config, run_training=not args.no_training, run_backtest=args.run_backtest)
|
main(config, run_training=not args.no_training, run_backtest=args.run_backtest)
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import os
|
|||||||
import re
|
import re
|
||||||
import copy
|
import copy
|
||||||
import json
|
import json
|
||||||
|
import yaml
|
||||||
import redis
|
import redis
|
||||||
import bisect
|
import bisect
|
||||||
import struct
|
import struct
|
||||||
@@ -24,7 +25,6 @@ import pandas as pd
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Union, Optional, Callable
|
from typing import List, Union, Optional, Callable
|
||||||
from packaging import version
|
from packaging import version
|
||||||
from ruamel.yaml import YAML
|
|
||||||
from .file import (
|
from .file import (
|
||||||
get_or_create_path,
|
get_or_create_path,
|
||||||
save_multiple_parts_file,
|
save_multiple_parts_file,
|
||||||
@@ -244,13 +244,12 @@ def parse_config(config):
|
|||||||
if not isinstance(config, str):
|
if not isinstance(config, str):
|
||||||
return config
|
return config
|
||||||
# Check whether config is file
|
# Check whether config is file
|
||||||
yaml = YAML(typ="safe", pure=True)
|
|
||||||
if os.path.exists(config):
|
if os.path.exists(config):
|
||||||
with open(config, "r") as f:
|
with open(config, "r") as f:
|
||||||
return yaml.load(f)
|
return yaml.safe_load(f)
|
||||||
# Check whether the str can be parsed
|
# Check whether the str can be parsed
|
||||||
try:
|
try:
|
||||||
return yaml.load(config)
|
return yaml.safe_load(config)
|
||||||
except BaseException as base_exp:
|
except BaseException as base_exp:
|
||||||
raise ValueError("cannot parse config!") from base_exp
|
raise ValueError("cannot parse config!") from base_exp
|
||||||
|
|
||||||
@@ -800,7 +799,6 @@ def fill_placeholder(config: dict, config_extend: dict):
|
|||||||
)
|
)
|
||||||
return value
|
return value
|
||||||
|
|
||||||
item_keys = None
|
|
||||||
while top < tail:
|
while top < tail:
|
||||||
now_item = item_queue[top]
|
now_item = item_queue[top]
|
||||||
top += 1
|
top += 1
|
||||||
|
|||||||
@@ -44,7 +44,7 @@ def concat(data_list: Union[SingleData], axis=0) -> MultiData:
|
|||||||
all_index_map = dict(zip(all_index, range(len(all_index))))
|
all_index_map = dict(zip(all_index, range(len(all_index))))
|
||||||
|
|
||||||
# concat all
|
# concat all
|
||||||
tmp_data = np.full((len(all_index), len(data_list)), np.nan)
|
tmp_data = np.full((len(all_index), len(data_list)), np.NaN)
|
||||||
for data_id, index_data in enumerate(data_list):
|
for data_id, index_data in enumerate(data_list):
|
||||||
assert isinstance(index_data, SingleData)
|
assert isinstance(index_data, SingleData)
|
||||||
now_data_map = [all_index_map[index] for index in index_data.index]
|
now_data_map = [all_index_map[index] for index in index_data.index]
|
||||||
@@ -64,7 +64,7 @@ def sum_by_index(data_list: Union[SingleData], new_index: list, fill_value=0) ->
|
|||||||
new_index : list
|
new_index : list
|
||||||
the new_index of new SingleData.
|
the new_index of new SingleData.
|
||||||
fill_value : float
|
fill_value : float
|
||||||
fill the missing values or replace np.nan.
|
fill the missing values or replace np.NaN.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
@@ -444,7 +444,7 @@ class IndexData(metaclass=index_data_ops_creator):
|
|||||||
return self.__class__(~self.data.astype(bool), *self.indices)
|
return self.__class__(~self.data.astype(bool), *self.indices)
|
||||||
|
|
||||||
def abs(self):
|
def abs(self):
|
||||||
"""get the abs of data except np.nan."""
|
"""get the abs of data except np.NaN."""
|
||||||
tmp_data = np.absolute(self.data)
|
tmp_data = np.absolute(self.data)
|
||||||
return self.__class__(tmp_data, *self.indices)
|
return self.__class__(tmp_data, *self.indices)
|
||||||
|
|
||||||
@@ -566,8 +566,8 @@ class SingleData(IndexData):
|
|||||||
f"The indexes of self and other do not meet the requirements of the four arithmetic operations"
|
f"The indexes of self and other do not meet the requirements of the four arithmetic operations"
|
||||||
)
|
)
|
||||||
|
|
||||||
def reindex(self, index: Index, fill_value=np.nan) -> SingleData:
|
def reindex(self, index: Index, fill_value=np.NaN) -> SingleData:
|
||||||
"""reindex data and fill the missing value with np.nan.
|
"""reindex data and fill the missing value with np.NaN.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
@@ -615,7 +615,7 @@ class SingleData(IndexData):
|
|||||||
return pd.Series(self.data, index=self.index)
|
return pd.Series(self.data, index=self.index)
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return str(pd.Series(self.data, index=self.index.tolist()))
|
return str(pd.Series(self.data, index=self.index))
|
||||||
|
|
||||||
|
|
||||||
class MultiData(IndexData):
|
class MultiData(IndexData):
|
||||||
@@ -651,4 +651,4 @@ class MultiData(IndexData):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return str(pd.DataFrame(self.data, index=self.index.tolist(), columns=self.columns.tolist()))
|
return str(pd.DataFrame(self.data, index=self.index, columns=self.columns))
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
# Copyright (c) Microsoft Corporation.
|
# Copyright (c) Microsoft Corporation.
|
||||||
# Licensed under the MIT License.
|
# Licensed under the MIT License.
|
||||||
|
|
||||||
import threading
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
from typing import Callable, Text, Union
|
from typing import Callable, Text, Union
|
||||||
@@ -10,7 +9,7 @@ from joblib import Parallel, delayed
|
|||||||
from joblib._parallel_backends import MultiprocessingBackend
|
from joblib._parallel_backends import MultiprocessingBackend
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|
||||||
from queue import Empty, Queue
|
from queue import Queue
|
||||||
import concurrent
|
import concurrent
|
||||||
|
|
||||||
from qlib.config import C, QlibConfig
|
from qlib.config import C, QlibConfig
|
||||||
@@ -86,17 +85,7 @@ class AsyncCaller:
|
|||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
while True:
|
while True:
|
||||||
# NOTE:
|
data = self._q.get()
|
||||||
# atexit will only trigger when all the threads ended. So it may results in deadlock.
|
|
||||||
# So the child-threading should actively watch the status of main threading to stop itself.
|
|
||||||
main_thread = threading.main_thread()
|
|
||||||
if not main_thread.is_alive():
|
|
||||||
break
|
|
||||||
try:
|
|
||||||
data = self._q.get(timeout=1)
|
|
||||||
except Empty:
|
|
||||||
# NOTE: avoid deadlock. make checking main thread possible
|
|
||||||
continue
|
|
||||||
if data == self.STOP_MARK:
|
if data == self.STOP_MARK:
|
||||||
break
|
break
|
||||||
data()
|
data()
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import sys
|
|||||||
|
|
||||||
import fire
|
import fire
|
||||||
from jinja2 import Template, meta
|
from jinja2 import Template, meta
|
||||||
from ruamel.yaml import YAML
|
import ruamel.yaml as yaml
|
||||||
|
|
||||||
import qlib
|
import qlib
|
||||||
from qlib.config import C
|
from qlib.config import C
|
||||||
@@ -104,8 +104,7 @@ def workflow(config_path, experiment_name="workflow", uri_folder="mlruns"):
|
|||||||
"""
|
"""
|
||||||
# Render the template
|
# Render the template
|
||||||
rendered_yaml = render_template(config_path)
|
rendered_yaml = render_template(config_path)
|
||||||
yaml = YAML(typ="safe", pure=True)
|
config = yaml.safe_load(rendered_yaml)
|
||||||
config = yaml.load(rendered_yaml)
|
|
||||||
|
|
||||||
base_config_path = config.get("BASE_CONFIG_PATH", None)
|
base_config_path = config.get("BASE_CONFIG_PATH", None)
|
||||||
if base_config_path:
|
if base_config_path:
|
||||||
@@ -127,8 +126,7 @@ def workflow(config_path, experiment_name="workflow", uri_folder="mlruns"):
|
|||||||
raise FileNotFoundError(f"Can't find the BASE_CONFIG file: {base_config_path}")
|
raise FileNotFoundError(f"Can't find the BASE_CONFIG file: {base_config_path}")
|
||||||
|
|
||||||
with open(path) as fp:
|
with open(path) as fp:
|
||||||
yaml = YAML(typ="safe", pure=True)
|
base_config = yaml.safe_load(fp)
|
||||||
base_config = yaml.load(fp)
|
|
||||||
logger.info(f"Load BASE_CONFIG_PATH succeed: {path.resolve()}")
|
logger.info(f"Load BASE_CONFIG_PATH succeed: {path.resolve()}")
|
||||||
config = update_config(base_config, config)
|
config = update_config(base_config, config)
|
||||||
|
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ from mlflow.exceptions import MlflowException, RESOURCE_ALREADY_EXISTS, ErrorCod
|
|||||||
from mlflow.entities import ViewType
|
from mlflow.entities import ViewType
|
||||||
import os
|
import os
|
||||||
from typing import Optional, Text
|
from typing import Optional, Text
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from .exp import MLflowExperiment, Experiment
|
from .exp import MLflowExperiment, Experiment
|
||||||
from ..config import C
|
from ..config import C
|
||||||
@@ -234,7 +233,7 @@ class ExpManager:
|
|||||||
# So we supported it in the interface wrapper
|
# So we supported it in the interface wrapper
|
||||||
pr = urlparse(self.uri)
|
pr = urlparse(self.uri)
|
||||||
if pr.scheme == "file":
|
if pr.scheme == "file":
|
||||||
with FileLock(Path(os.path.join(pr.netloc, pr.path.lstrip("/"), "filelock"))): # pylint: disable=E0110
|
with FileLock(os.path.join(pr.netloc, pr.path, "filelock")): # pylint: disable=E0110
|
||||||
return self.create_exp(experiment_name), True
|
return self.create_exp(experiment_name), True
|
||||||
# NOTE: for other schemes like http, we double check to avoid create exp conflicts
|
# NOTE: for other schemes like http, we double check to avoid create exp conflicts
|
||||||
try:
|
try:
|
||||||
@@ -422,11 +421,7 @@ class MLflowExpManager(ExpManager):
|
|||||||
|
|
||||||
def list_experiments(self):
|
def list_experiments(self):
|
||||||
# retrieve all the existing experiments
|
# retrieve all the existing experiments
|
||||||
mlflow_version = int(mlflow.__version__.split(".", maxsplit=1)[0])
|
exps = self.client.list_experiments(view_type=ViewType.ACTIVE_ONLY)
|
||||||
if mlflow_version >= 2:
|
|
||||||
exps = self.client.search_experiments(view_type=ViewType.ACTIVE_ONLY)
|
|
||||||
else:
|
|
||||||
exps = self.client.list_experiments(view_type=ViewType.ACTIVE_ONLY) # pylint: disable=E1101
|
|
||||||
experiments = dict()
|
experiments = dict()
|
||||||
for exp in exps:
|
for exp in exps:
|
||||||
experiment = MLflowExperiment(exp.experiment_id, exp.name, self.uri)
|
experiment = MLflowExperiment(exp.experiment_id, exp.name, self.uri)
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ import shutil
|
|||||||
import pickle
|
import pickle
|
||||||
import tempfile
|
import tempfile
|
||||||
import subprocess
|
import subprocess
|
||||||
import platform
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
@@ -317,10 +316,7 @@ class MLflowRecorder(Recorder):
|
|||||||
This function will return the directory path of this recorder.
|
This function will return the directory path of this recorder.
|
||||||
"""
|
"""
|
||||||
if self.artifact_uri is not None:
|
if self.artifact_uri is not None:
|
||||||
if platform.system() == "Windows":
|
local_dir_path = Path(self.artifact_uri.lstrip("file:")) / ".."
|
||||||
local_dir_path = Path(self.artifact_uri.lstrip("file:").lstrip("/")).parent
|
|
||||||
else:
|
|
||||||
local_dir_path = Path(self.artifact_uri.lstrip("file:")).parent
|
|
||||||
local_dir_path = str(local_dir_path.resolve())
|
local_dir_path = str(local_dir_path.resolve())
|
||||||
if os.path.isdir(local_dir_path):
|
if os.path.isdir(local_dir_path):
|
||||||
return local_dir_path
|
return local_dir_path
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ termcolor==1.1.0
|
|||||||
tqdm==4.63.0
|
tqdm==4.63.0
|
||||||
trio==0.20.0
|
trio==0.20.0
|
||||||
trio-websocket==0.9.2
|
trio-websocket==0.9.2
|
||||||
urllib3==1.26.19
|
urllib3==1.26.8
|
||||||
wget==3.2
|
wget==3.2
|
||||||
wsproto==1.1.0
|
wsproto==1.1.0
|
||||||
yahooquery==2.2.15
|
yahooquery==2.2.15
|
||||||
|
|||||||
197
setup.py
197
setup.py
@@ -1,6 +1,9 @@
|
|||||||
from setuptools import setup, Extension
|
# Copyright (c) Microsoft Corporation.
|
||||||
import numpy
|
# Licensed under the MIT License.
|
||||||
import os
|
import os
|
||||||
|
import numpy
|
||||||
|
|
||||||
|
from setuptools import find_packages, setup, Extension
|
||||||
|
|
||||||
|
|
||||||
def read(rel_path: str) -> str:
|
def read(rel_path: str) -> str:
|
||||||
@@ -17,25 +20,185 @@ def get_version(rel_path: str) -> str:
|
|||||||
raise RuntimeError("Unable to find version string.")
|
raise RuntimeError("Unable to find version string.")
|
||||||
|
|
||||||
|
|
||||||
NUMPY_INCLUDE = numpy.get_include()
|
# Package meta-data.
|
||||||
|
NAME = "pyqlib"
|
||||||
|
DESCRIPTION = "A Quantitative-research Platform"
|
||||||
|
REQUIRES_PYTHON = ">=3.5.0"
|
||||||
|
|
||||||
VERSION = get_version("qlib/__init__.py")
|
VERSION = get_version("qlib/__init__.py")
|
||||||
|
|
||||||
|
# Detect Cython
|
||||||
|
try:
|
||||||
|
import Cython
|
||||||
|
|
||||||
|
ver = Cython.__version__
|
||||||
|
_CYTHON_INSTALLED = ver >= "0.28"
|
||||||
|
except ImportError:
|
||||||
|
_CYTHON_INSTALLED = False
|
||||||
|
|
||||||
|
if not _CYTHON_INSTALLED:
|
||||||
|
print("Required Cython version >= 0.28 is not detected!")
|
||||||
|
print('Please run "pip install --upgrade cython" first.')
|
||||||
|
exit(-1)
|
||||||
|
|
||||||
|
# What packages are required for this module to be executed?
|
||||||
|
# `estimator` may depend on other packages. In order to reduce dependencies, it is not written here.
|
||||||
|
REQUIRED = [
|
||||||
|
"numpy>=1.12.0, <1.24",
|
||||||
|
"pandas>=0.25.1",
|
||||||
|
"scipy>=1.7.3",
|
||||||
|
"requests>=2.18.0",
|
||||||
|
"sacred>=0.7.4",
|
||||||
|
"python-socketio",
|
||||||
|
"redis>=3.0.1",
|
||||||
|
"python-redis-lock>=3.3.1",
|
||||||
|
"schedule>=0.6.0",
|
||||||
|
"cvxpy>=1.0.21",
|
||||||
|
"hyperopt==0.1.2",
|
||||||
|
"fire>=0.3.1",
|
||||||
|
"statsmodels",
|
||||||
|
"xlrd>=1.0.0",
|
||||||
|
"plotly>=4.12.0",
|
||||||
|
"matplotlib>=3.3",
|
||||||
|
"tables>=3.6.1",
|
||||||
|
"pyyaml>=5.3.1",
|
||||||
|
# To ensure stable operation of the experiment manager, we have limited the version of mlflow,
|
||||||
|
# and we need to verify whether version 2.0 of mlflow can serve qlib properly.
|
||||||
|
"mlflow>=1.12.1, <=1.30.0",
|
||||||
|
# mlflow 1.30.0 requires packaging<22, so we limit the packaging version, otherwise the CI will fail.
|
||||||
|
"packaging<22",
|
||||||
|
"tqdm",
|
||||||
|
"loguru",
|
||||||
|
"lightgbm>=3.3.0",
|
||||||
|
"tornado",
|
||||||
|
"joblib>=0.17.0",
|
||||||
|
# With the upgrading of ruamel.yaml to 0.18, the safe_load method was deprecated,
|
||||||
|
# which would cause qlib.workflow.cli to not work properly,
|
||||||
|
# and no good replacement has been found, so the version of ruamel.yaml has been restricted for now.
|
||||||
|
# Refs: https://pypi.org/project/ruamel.yaml/
|
||||||
|
"ruamel.yaml<=0.17.36",
|
||||||
|
"pymongo==3.7.2", # For task management
|
||||||
|
"scikit-learn>=0.22",
|
||||||
|
"dill",
|
||||||
|
"dataclasses;python_version<'3.7'",
|
||||||
|
"filelock",
|
||||||
|
"jinja2",
|
||||||
|
"gym",
|
||||||
|
# Installing the latest version of protobuf for python versions below 3.8 will cause unit tests to fail.
|
||||||
|
"protobuf<=3.20.1;python_version<='3.8'",
|
||||||
|
"cryptography",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Numpy include
|
||||||
|
NUMPY_INCLUDE = numpy.get_include()
|
||||||
|
|
||||||
|
here = os.path.abspath(os.path.dirname(__file__))
|
||||||
|
|
||||||
|
with open(os.path.join(here, "README.md"), encoding="utf-8") as f:
|
||||||
|
long_description = f.read()
|
||||||
|
|
||||||
|
|
||||||
|
# Cython Extensions
|
||||||
|
extensions = [
|
||||||
|
Extension(
|
||||||
|
"qlib.data._libs.rolling",
|
||||||
|
["qlib/data/_libs/rolling.pyx"],
|
||||||
|
language="c++",
|
||||||
|
include_dirs=[NUMPY_INCLUDE],
|
||||||
|
),
|
||||||
|
Extension(
|
||||||
|
"qlib.data._libs.expanding",
|
||||||
|
["qlib/data/_libs/expanding.pyx"],
|
||||||
|
language="c++",
|
||||||
|
include_dirs=[NUMPY_INCLUDE],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
# Where the magic happens:
|
||||||
setup(
|
setup(
|
||||||
|
name=NAME,
|
||||||
version=VERSION,
|
version=VERSION,
|
||||||
ext_modules=[
|
license="MIT Licence",
|
||||||
Extension(
|
url="https://github.com/microsoft/qlib",
|
||||||
"qlib.data._libs.rolling",
|
description=DESCRIPTION,
|
||||||
["qlib/data/_libs/rolling.pyx"],
|
long_description=long_description,
|
||||||
language="c++",
|
long_description_content_type="text/markdown",
|
||||||
include_dirs=[NUMPY_INCLUDE],
|
python_requires=REQUIRES_PYTHON,
|
||||||
),
|
packages=find_packages(exclude=("tests",)),
|
||||||
Extension(
|
# if your package is a single module, use this instead of 'packages':
|
||||||
"qlib.data._libs.expanding",
|
# py_modules=['qlib'],
|
||||||
["qlib/data/_libs/expanding.pyx"],
|
entry_points={
|
||||||
language="c++",
|
# 'console_scripts': ['mycli=mymodule:cli'],
|
||||||
include_dirs=[NUMPY_INCLUDE],
|
"console_scripts": [
|
||||||
),
|
"qrun=qlib.workflow.cli:run",
|
||||||
]
|
],
|
||||||
|
},
|
||||||
|
ext_modules=extensions,
|
||||||
|
install_requires=REQUIRED,
|
||||||
|
extras_require={
|
||||||
|
"dev": [
|
||||||
|
"coverage",
|
||||||
|
"pytest>=3",
|
||||||
|
"sphinx",
|
||||||
|
"sphinx_rtd_theme",
|
||||||
|
"pre-commit",
|
||||||
|
# CI dependencies
|
||||||
|
"wheel",
|
||||||
|
"setuptools",
|
||||||
|
"black",
|
||||||
|
# Version 3.0 of pylint had problems with the build process, so we limited the version of pylint.
|
||||||
|
"pylint<=2.17.6",
|
||||||
|
# Using the latest versions(0.981 and 0.982) of mypy,
|
||||||
|
# the error "multiprocessing.Value()" is detected in the file "qlib/rl/utils/data_queue.py",
|
||||||
|
# If this is fixed in a subsequent version of mypy, then we will revert to the latest version of mypy.
|
||||||
|
# References: https://github.com/python/typeshed/issues/8799
|
||||||
|
"mypy<0.981",
|
||||||
|
"flake8",
|
||||||
|
"nbqa",
|
||||||
|
"jupyter",
|
||||||
|
"nbconvert",
|
||||||
|
# The 5.0.0 version of importlib-metadata removed the deprecated endpoint,
|
||||||
|
# which prevented flake8 from working properly, so we restricted the version of importlib-metadata.
|
||||||
|
# To help ensure the dependencies of flake8 https://github.com/python/importlib_metadata/issues/406
|
||||||
|
"importlib-metadata<5.0.0",
|
||||||
|
"readthedocs_sphinx_ext",
|
||||||
|
"cmake",
|
||||||
|
"lxml",
|
||||||
|
"baostock",
|
||||||
|
"yahooquery",
|
||||||
|
# 2024-05-30 scs has released a new version: 3.2.4.post2,
|
||||||
|
# this version, causes qlib installation to fail, so we've limited the scs version a bit for now.
|
||||||
|
"scs<=3.2.4",
|
||||||
|
"beautifulsoup4",
|
||||||
|
# In version 0.4.11 of tianshou, the code:
|
||||||
|
# logits, hidden = self.actor(batch.obs, state=state, info=batch.info)
|
||||||
|
# was changed in PR787,
|
||||||
|
# which causes pytest errors(AttributeError: 'dict' object has no attribute 'info') in CI,
|
||||||
|
# so we restricted the version of tianshou.
|
||||||
|
# References:
|
||||||
|
# https://github.com/thu-ml/tianshou/releases
|
||||||
|
"tianshou<=0.4.10",
|
||||||
|
"gym>=0.24", # If you do not put gym at the end, gym will degrade causing pytest results to fail.
|
||||||
|
],
|
||||||
|
"rl": [
|
||||||
|
"tianshou<=0.4.10",
|
||||||
|
"torch",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
include_package_data=True,
|
||||||
|
classifiers=[
|
||||||
|
# Trove classifiers
|
||||||
|
# Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers
|
||||||
|
# 'License :: OSI Approved :: MIT License',
|
||||||
|
"Operating System :: POSIX :: Linux",
|
||||||
|
"Operating System :: Microsoft :: Windows",
|
||||||
|
"Operating System :: MacOS",
|
||||||
|
"License :: OSI Approved :: MIT License",
|
||||||
|
"Development Status :: 3 - Alpha",
|
||||||
|
"Programming Language :: Python",
|
||||||
|
"Programming Language :: Python :: 3",
|
||||||
|
"Programming Language :: Python :: 3.7",
|
||||||
|
"Programming Language :: Python :: 3.8",
|
||||||
|
"Programming Language :: Python :: 3.9",
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -7,16 +7,14 @@ import qlib
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
sys.path.append(str(Path(__file__).resolve().parent))
|
sys.path.append(str(Path(__file__).resolve().parent))
|
||||||
from qlib.data.dataset.loader import NestedDataLoader, QlibDataLoader
|
from qlib.data.dataset.loader import NestedDataLoader
|
||||||
from qlib.data.dataset.handler import DataHandlerLP
|
|
||||||
from qlib.contrib.data.loader import Alpha158DL, Alpha360DL
|
from qlib.contrib.data.loader import Alpha158DL, Alpha360DL
|
||||||
from qlib.data import D
|
|
||||||
|
|
||||||
|
|
||||||
class TestDataLoader(unittest.TestCase):
|
class TestDataLoader(unittest.TestCase):
|
||||||
|
|
||||||
def test_nested_data_loader(self):
|
def test_nested_data_loader(self):
|
||||||
qlib.init(kernels=1)
|
qlib.init()
|
||||||
nd = NestedDataLoader(
|
nd = NestedDataLoader(
|
||||||
dataloader_l=[
|
dataloader_l=[
|
||||||
{
|
{
|
||||||
@@ -30,7 +28,7 @@ class TestDataLoader(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
# Of course you can use StaticDataLoader
|
# Of course you can use StaticDataLoader
|
||||||
|
|
||||||
dataset = nd.load(start_time="2020-01-01", end_time="2020-01-31")
|
dataset = nd.load()
|
||||||
|
|
||||||
assert dataset is not None
|
assert dataset is not None
|
||||||
|
|
||||||
@@ -46,35 +44,6 @@ class TestDataLoader(unittest.TestCase):
|
|||||||
assert "LABEL0" in columns_list
|
assert "LABEL0" in columns_list
|
||||||
|
|
||||||
# Then you can use it wth DataHandler;
|
# Then you can use it wth DataHandler;
|
||||||
# NOTE: please note that the data processors are missing!!! You should add based on your requirements
|
|
||||||
|
|
||||||
"""
|
|
||||||
dataset.to_pickle("test_df.pkl")
|
|
||||||
nested_data_loader = NestedDataLoader(
|
|
||||||
dataloader_l=[
|
|
||||||
{
|
|
||||||
"class": "qlib.contrib.data.loader.Alpha158DL",
|
|
||||||
"kwargs": {"config": {"label": (["Ref($close, -2)/Ref($close, -1) - 1"], ["LABEL0"])}},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"class": "qlib.contrib.data.loader.Alpha360DL",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"class": "qlib.data.dataset.loader.StaticDataLoader",
|
|
||||||
"kwargs": {"config": "test_df.pkl"},
|
|
||||||
},
|
|
||||||
]
|
|
||||||
)
|
|
||||||
data_handler_config = {
|
|
||||||
"start_time": "2008-01-01",
|
|
||||||
"end_time": "2020-08-01",
|
|
||||||
"instruments": "csi300",
|
|
||||||
"data_loader": nested_data_loader,
|
|
||||||
}
|
|
||||||
data_handler = DataHandlerLP(**data_handler_config)
|
|
||||||
data = data_handler.fetch()
|
|
||||||
print(data)
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
# Copyright (c) Microsoft Corporation.
|
# Copyright (c) Microsoft Corporation.
|
||||||
# Licensed under the MIT License.
|
# Licensed under the MIT License.
|
||||||
import unittest
|
import unittest
|
||||||
import platform
|
|
||||||
import mlflow
|
import mlflow
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -27,10 +26,7 @@ class MLflowTest(unittest.TestCase):
|
|||||||
_ = mlflow.tracking.MlflowClient(tracking_uri=str(self.TMP_PATH))
|
_ = mlflow.tracking.MlflowClient(tracking_uri=str(self.TMP_PATH))
|
||||||
end = time.time()
|
end = time.time()
|
||||||
elapsed = end - start
|
elapsed = end - start
|
||||||
if platform.system() == "Linux":
|
self.assertLess(elapsed, 1e-2) # it can be done in less than 10ms
|
||||||
self.assertLess(elapsed, 1e-2) # it can be done in less than 10ms
|
|
||||||
else:
|
|
||||||
self.assertLess(elapsed, 2e-2)
|
|
||||||
print(elapsed)
|
print(elapsed)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -70,7 +70,7 @@ class IndexDataTest(unittest.TestCase):
|
|||||||
print(sd.loc[:"c"])
|
print(sd.loc[:"c"])
|
||||||
|
|
||||||
def test_corner_cases(self):
|
def test_corner_cases(self):
|
||||||
sd = idd.MultiData([[1, 2], [3, np.nan]], index=["foo", "bar"], columns=["f", "g"])
|
sd = idd.MultiData([[1, 2], [3, np.NaN]], index=["foo", "bar"], columns=["f", "g"])
|
||||||
print(sd)
|
print(sd)
|
||||||
|
|
||||||
self.assertTrue(np.isnan(sd.loc["bar", "g"]))
|
self.assertTrue(np.isnan(sd.loc["bar", "g"]))
|
||||||
|
|||||||
@@ -1,80 +0,0 @@
|
|||||||
import unittest
|
|
||||||
from qlib.tests import TestAutoData
|
|
||||||
|
|
||||||
|
|
||||||
class TestNN(TestAutoData):
|
|
||||||
def test_both_dataset(self):
|
|
||||||
try:
|
|
||||||
from qlib.contrib.model.pytorch_general_nn import GeneralPTNN
|
|
||||||
from qlib.data.dataset import DatasetH, TSDatasetH
|
|
||||||
from qlib.data.dataset.handler import DataHandlerLP
|
|
||||||
except ImportError:
|
|
||||||
print("Import error.")
|
|
||||||
return
|
|
||||||
|
|
||||||
data_handler_config = {
|
|
||||||
"start_time": "2008-01-01",
|
|
||||||
"end_time": "2020-08-01",
|
|
||||||
"instruments": "csi300",
|
|
||||||
"data_loader": {
|
|
||||||
"class": "QlibDataLoader", # Assuming QlibDataLoader is a string reference to the class
|
|
||||||
"kwargs": {
|
|
||||||
"config": {
|
|
||||||
"feature": [["$high", "$close", "$low"], ["H", "C", "L"]],
|
|
||||||
"label": [["Ref($close, -2)/Ref($close, -1) - 1"], ["LABEL0"]],
|
|
||||||
},
|
|
||||||
"freq": "day",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
# TODO: processors
|
|
||||||
"learn_processors": [
|
|
||||||
{
|
|
||||||
"class": "DropnaLabel",
|
|
||||||
},
|
|
||||||
{"class": "CSZScoreNorm", "kwargs": {"fields_group": "label"}},
|
|
||||||
],
|
|
||||||
}
|
|
||||||
segments = {
|
|
||||||
"train": ["2008-01-01", "2014-12-31"],
|
|
||||||
"valid": ["2015-01-01", "2016-12-31"],
|
|
||||||
"test": ["2017-01-01", "2020-08-01"],
|
|
||||||
}
|
|
||||||
data_handler = DataHandlerLP(**data_handler_config)
|
|
||||||
|
|
||||||
# time-series dataset
|
|
||||||
tsds = TSDatasetH(handler=data_handler, segments=segments)
|
|
||||||
|
|
||||||
# tabular dataset
|
|
||||||
tbds = DatasetH(handler=data_handler, segments=segments)
|
|
||||||
|
|
||||||
model_l = [
|
|
||||||
GeneralPTNN(
|
|
||||||
n_epochs=2,
|
|
||||||
batch_size=32,
|
|
||||||
n_jobs=0,
|
|
||||||
pt_model_uri="qlib.contrib.model.pytorch_gru_ts.GRUModel",
|
|
||||||
pt_model_kwargs={
|
|
||||||
"d_feat": 3,
|
|
||||||
"hidden_size": 8,
|
|
||||||
"num_layers": 1,
|
|
||||||
"dropout": 0.0,
|
|
||||||
},
|
|
||||||
),
|
|
||||||
GeneralPTNN(
|
|
||||||
n_epochs=2,
|
|
||||||
batch_size=32,
|
|
||||||
n_jobs=0,
|
|
||||||
pt_model_uri="qlib.contrib.model.pytorch_nn.Net", # it is a MLP
|
|
||||||
pt_model_kwargs={
|
|
||||||
"input_dim": 3,
|
|
||||||
},
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
for ds, model in list(zip((tsds, tbds), model_l)):
|
|
||||||
model.fit(ds) # It works
|
|
||||||
model.predict(ds) # It works
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
unittest.main()
|
|
||||||
@@ -1,8 +1,8 @@
|
|||||||
# Copyright (c) Microsoft Corporation.
|
# Copyright (c) Microsoft Corporation.
|
||||||
# Licensed under the MIT License.
|
# Licensed under the MIT License.
|
||||||
|
|
||||||
from random import randint, choice
|
from random import randint, choice
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import logging
|
|
||||||
|
|
||||||
import re
|
import re
|
||||||
from typing import Any, Tuple
|
from typing import Any, Tuple
|
||||||
@@ -69,10 +69,6 @@ class AnyPolicy(BasePolicy):
|
|||||||
|
|
||||||
def test_simple_env_logger(caplog):
|
def test_simple_env_logger(caplog):
|
||||||
set_log_with_config(C.logging_config)
|
set_log_with_config(C.logging_config)
|
||||||
# In order for caplog to capture log messages, we configure it here:
|
|
||||||
# allow logs from the qlib logger to be passed to the parent logger.
|
|
||||||
C.logging_config["loggers"]["qlib"]["propagate"] = True
|
|
||||||
logging.config.dictConfig(C.logging_config)
|
|
||||||
for venv_cls_name in ["dummy", "shmem", "subproc"]:
|
for venv_cls_name in ["dummy", "shmem", "subproc"]:
|
||||||
writer = ConsoleWriter()
|
writer = ConsoleWriter()
|
||||||
csv_writer = CsvWriter(Path(__file__).parent / ".output")
|
csv_writer = CsvWriter(Path(__file__).parent / ".output")
|
||||||
@@ -84,12 +80,13 @@ def test_simple_env_logger(caplog):
|
|||||||
output_file = pd.read_csv(Path(__file__).parent / ".output" / "result.csv")
|
output_file = pd.read_csv(Path(__file__).parent / ".output" / "result.csv")
|
||||||
assert output_file.columns.tolist() == ["reward", "a", "c"]
|
assert output_file.columns.tolist() == ["reward", "a", "c"]
|
||||||
assert len(output_file) >= 30
|
assert len(output_file) >= 30
|
||||||
|
|
||||||
line_counter = 0
|
line_counter = 0
|
||||||
for line in caplog.text.splitlines():
|
for line in caplog.text.splitlines():
|
||||||
line = line.strip()
|
line = line.strip()
|
||||||
if line:
|
if line:
|
||||||
line_counter += 1
|
line_counter += 1
|
||||||
assert re.match(r".*reward .* {2}a .* \(([456])\.\d+\) {2}c .* \((14|15|16)\.\d+\)", line)
|
assert re.match(r".*reward .* a .* \((4|5|6)\.\d+\) c .* \((14|15|16)\.\d+\)", line)
|
||||||
assert line_counter >= 3
|
assert line_counter >= 3
|
||||||
|
|
||||||
|
|
||||||
@@ -140,17 +137,15 @@ class RandomFivePolicy(BasePolicy):
|
|||||||
|
|
||||||
def test_logger_with_env_wrapper():
|
def test_logger_with_env_wrapper():
|
||||||
with DataQueue(list(range(20)), shuffle=False) as data_iterator:
|
with DataQueue(list(range(20)), shuffle=False) as data_iterator:
|
||||||
|
env_wrapper_factory = lambda: EnvWrapper(
|
||||||
|
SimpleSimulator,
|
||||||
|
DummyStateInterpreter(),
|
||||||
|
DummyActionInterpreter(),
|
||||||
|
data_iterator,
|
||||||
|
logger=LogCollector(LogLevel.DEBUG),
|
||||||
|
)
|
||||||
|
|
||||||
def env_wrapper_factory():
|
# loglevel can be debug here because metrics can all dump into csv
|
||||||
return EnvWrapper(
|
|
||||||
SimpleSimulator,
|
|
||||||
DummyStateInterpreter(),
|
|
||||||
DummyActionInterpreter(),
|
|
||||||
data_iterator,
|
|
||||||
logger=LogCollector(LogLevel.DEBUG),
|
|
||||||
)
|
|
||||||
|
|
||||||
# loglevel can be debugged here because metrics can all dump into csv
|
|
||||||
# otherwise, csv writer might crash
|
# otherwise, csv writer might crash
|
||||||
csv_writer = CsvWriter(Path(__file__).parent / ".output", loglevel=LogLevel.DEBUG)
|
csv_writer = CsvWriter(Path(__file__).parent / ".output", loglevel=LogLevel.DEBUG)
|
||||||
venv = vectorize_env(env_wrapper_factory, "shmem", 4, csv_writer)
|
venv = vectorize_env(env_wrapper_factory, "shmem", 4, csv_writer)
|
||||||
@@ -160,7 +155,7 @@ def test_logger_with_env_wrapper():
|
|||||||
|
|
||||||
output_df = pd.read_csv(Path(__file__).parent / ".output" / "result.csv")
|
output_df = pd.read_csv(Path(__file__).parent / ".output" / "result.csv")
|
||||||
assert len(output_df) == 20
|
assert len(output_df) == 20
|
||||||
# obs has an increasing trend
|
# obs has a increasing trend
|
||||||
assert output_df["obs"].to_numpy()[:10].sum() < output_df["obs"].to_numpy()[10:].sum()
|
assert output_df["obs"].to_numpy()[:10].sum() < output_df["obs"].to_numpy()[10:].sum()
|
||||||
assert (output_df["test_a"] == 233).all()
|
assert (output_df["test_a"] == 233).all()
|
||||||
assert (output_df["test_b"] == 200).all()
|
assert (output_df["test_b"] == 200).all()
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import shutil
|
|||||||
import unittest
|
import unittest
|
||||||
import pytest
|
import pytest
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
import baostock as bs
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from qlib.data import D
|
from qlib.data import D
|
||||||
|
|||||||
Reference in New Issue
Block a user