1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

update python version (#1868)

* update python version

* fix: Correct selector handling and add time filtering in storage.py

* fix: convert index and columns to list in repr methods

* feat: Add Makefile for managing project prerequisites

* feat: Add Cython extensions for rolling and expanding operations

* resolve install error

* fix lint error

* fix lint error

* fix lint error

* fix lint error

* fix lint error

* update build package

* update makefile

* update ci yaml

* fix docs build error

* fix ubuntu install error

* fix docs build error

* fix install error

* fix install error

* fix install error

* fix install error

* fix pylint error

* fix pylint error

* fix pylint error

* fix pylint error

* fix pylint error E1123

* fix pylint error R0917

* fix pytest error

* fix pytest error

* fix pytest error

* update code

* update code

* fix ci error

* fix pylint error

* fix black error

* fix pytest error

* fix CI error

* fix CI error

* add python version to CI

* add python version to CI

* add python version to CI

* fix pylint error

* fix pytest general nn error

* fix CI error

* optimize code

* add coments

* Extended macos version

* remove build package

---------

Co-authored-by: Young <afe.young@gmail.com>
This commit is contained in:
Linlang
2024-12-17 11:30:06 +08:00
committed by GitHub
parent 7acb4f3484
commit a0cef033cb
63 changed files with 460 additions and 426 deletions

View File

@@ -12,43 +12,23 @@ jobs:
runs-on: ${{ matrix.os }} runs-on: ${{ matrix.os }}
strategy: strategy:
matrix: matrix:
os: [windows-latest, macos-13] os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-13, macos-14, macos-latest]
# FIXME: macos-latest will raise error now. # FIXME: macos-latest will raise error now.
# not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129 # not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129
python-version: [3.7, 3.8] python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
steps: steps:
- uses: actions/checkout@v2 - uses: actions/checkout@v2
# This is because on macos systems you can install pyqlib using
# `pip install pyqlib` installs, it does not recognize the
# `pyqlib-<version>-cp38-cp38-macosx_11_0_x86_64.whl` and `pyqlib-<veresion>-cp38-cp37m-macosx_11_0_x86_64.whl`.
# So we limit the version of python, in order to generate a version of qlib that is usable for macos: `pyqlib-<veresion>-cp38-cp37m
# `pyqlib-<version>-cp38-cp38-macosx_10_15_x86_64.whl` and `pyqlib-<veresion>-cp38-cp37m-macosx_10_15_x86_64.whl`.
# Python 3.7.16, 3.8.16 can build macosx_10_15. But Python 3.7.17, 3.8.17 can build macosx_11_0
- name: Set up Python ${{ matrix.python-version }} - name: Set up Python ${{ matrix.python-version }}
if: matrix.os == 'macos-11' && matrix.python-version == '3.7'
uses: actions/setup-python@v2
with:
python-version: "3.7.16"
- name: Set up Python ${{ matrix.python-version }}
if: matrix.os == 'macos-11' && matrix.python-version == '3.8'
uses: actions/setup-python@v2
with:
python-version: "3.8.16"
- name: Set up Python ${{ matrix.python-version }}
if: matrix.os != 'macos-11'
uses: actions/setup-python@v2 uses: actions/setup-python@v2
with: with:
python-version: ${{ matrix.python-version }} python-version: ${{ matrix.python-version }}
- name: Install dependencies - name: Install dependencies
run: | run: |
python -m pip install --upgrade pip make dev
pip install setuptools wheel twine
- name: Build wheel on ${{ matrix.os }} - name: Build wheel on ${{ matrix.os }}
run: | run: |
pip install numpy make build
pip install cython
python setup.py bdist_wheel
- name: Build and publish - name: Build and publish
env: env:
TWINE_USERNAME: __token__ TWINE_USERNAME: __token__

View File

@@ -19,25 +19,15 @@ jobs:
# If you want to use python 3.7 in github action, then the latest macos system version is macos-13, # If you want to use python 3.7 in github action, then the latest macos system version is macos-13,
# after macos-13 python 3.7 is no longer supported. # after macos-13 python 3.7 is no longer supported.
# so we limit the macos version to macos-13. # so we limit the macos version to macos-13.
os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-13] os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-13, macos-14, macos-latest]
# not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129 # not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129
python-version: [3.7, 3.8] python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
steps: steps:
- name: Test qlib from source - name: Test qlib from source
uses: actions/checkout@v3 uses: actions/checkout@v3
# Since version 3.7 of python for MacOS is installed in CI, version 3.7.17, this version causes "_bz not found error".
# So we make the version number of python 3.7 for MacOS more specific.
# refs: https://github.com/actions/setup-python/issues/682
- name: Set up Python ${{ matrix.python-version }} - name: Set up Python ${{ matrix.python-version }}
if: (matrix.os == 'macos-latest' && matrix.python-version == '3.7') || (matrix.os == 'macos-13' && matrix.python-version == '3.7')
uses: actions/setup-python@v4
with:
python-version: "3.7.16"
- name: Set up Python ${{ matrix.python-version }}
if: (matrix.os != 'macos-latest' || matrix.python-version != '3.7') && (matrix.os != 'macos-13' || matrix.python-version != '3.7')
uses: actions/setup-python@v4 uses: actions/setup-python@v4
with: with:
python-version: ${{ matrix.python-version }} python-version: ${{ matrix.python-version }}
@@ -47,7 +37,7 @@ jobs:
python -m pip install --upgrade pip python -m pip install --upgrade pip
- name: Installing pytorch for macos - name: Installing pytorch for macos
if: ${{ matrix.os == 'macos-13' || matrix.os == 'macos-latest' }} if: ${{ matrix.os == 'macos-13' || matrix.os == 'macos-14' || matrix.os == 'macos-latest' }}
run: | run: |
python -m pip install torch torchvision torchaudio python -m pip install torch torchvision torchaudio
@@ -63,87 +53,33 @@ jobs:
- name: Set up Python tools - name: Set up Python tools
run: | run: |
python -m pip install --upgrade cython make dev
python -m pip install -e .[dev]
- name: Lint with Black - name: Lint with Black
# Python 3.7 will use a black with low level. So we use python with higher version for black check
if: (matrix.python-version != '3.7')
run: | run: |
pip install -U black # follow the latest version of black, previous Qlib dependency will downgrade black make black
black . -l 120 --check --diff
- name: Make html with sphinx - name: Make html with sphinx
# Since read the docs builds on ubuntu 22.04, we only need to test that the build passes on ubuntu 22.04. # Since read the docs builds on ubuntu 22.04, we only need to test that the build passes on ubuntu 22.04.
if: ${{ matrix.os == 'ubuntu-22.04' }} if: ${{ matrix.os == 'ubuntu-22.04' }}
run: | run: |
cd docs make docs-gen
sphinx-build -W --keep-going -b html . _build
cd ..
# Check Qlib with pylint
# TODO: These problems we will solve in the future. Important among them are: W0221, W0223, W0237, E1102
# C0103: invalid-name
# C0209: consider-using-f-string
# R0402: consider-using-from-import
# R1705: no-else-return
# R1710: inconsistent-return-statements
# R1725: super-with-arguments
# R1735: use-dict-literal
# W0102: dangerous-default-value
# W0212: protected-access
# W0221: arguments-differ
# W0223: abstract-method
# W0231: super-init-not-called
# W0237: arguments-renamed
# W0612: unused-variable
# W0621: redefined-outer-name
# W0622: redefined-builtin
# FIXME: specify exception type
# W0703: broad-except
# W1309: f-string-without-interpolation
# E1102: not-callable
# E1136: unsubscriptable-object
# References for parameters: https://github.com/PyCQA/pylint/issues/4577#issuecomment-1000245962
# We use sys.setrecursionlimit(2000) to make the recursion depth larger to ensure that pylint works properly (the default recursion depth is 1000).
- name: Check Qlib with pylint - name: Check Qlib with pylint
run: | run: |
pylint --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0612,W0621,W0622,W0703,W1309,E1102,E1136 --const-rgx='[a-z_][a-z0-9_]{2,30}$' qlib --init-hook "import astroid; astroid.context.InferenceContext.max_inferred = 500; import sys; sys.setrecursionlimit(2000)" make pylint
pylint --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0246,W0612,W0621,W0622,W0703,W1309,E1102,E1136 --const-rgx='[a-z_][a-z0-9_]{2,30}$' scripts --init-hook "import astroid; astroid.context.InferenceContext.max_inferred = 500; import sys; sys.setrecursionlimit(2000)"
# The following flake8 error codes were ignored:
# E501 line too long
# Description: We have used black to limit the length of each line to 120.
# F541 f-string is missing placeholders
# Description: The same thing is done when using pylint for detection.
# E266 too many leading '#' for block comment
# Description: To make the code more readable, a lot of "#" is used.
# This error code appears centrally in:
# qlib/backtest/executor.py
# qlib/data/ops.py
# qlib/utils/__init__.py
# E402 module level import not at top of file
# Description: There are times when module level import is not available at the top of the file.
# W503 line break before binary operator
# Description: Since black formats the length of each line of code, it has to perform a line break when a line of arithmetic is too long.
# E731 do not assign a lambda expression, use a def
# Description: Restricts the use of lambda expressions, but at some point lambda expressions are required.
# E203 whitespace before ':'
# Description: If there is whitespace before ":", it cannot pass the black check.
- name: Check Qlib with flake8 - name: Check Qlib with flake8
run: | run: |
flake8 --ignore=E501,F541,E266,E402,W503,E731,E203 --per-file-ignores="__init__.py:F401,F403" qlib make flake8
# https://github.com/python/mypy/issues/10600
- name: Check Qlib with mypy - name: Check Qlib with mypy
run: | run: |
mypy qlib --install-types --non-interactive || true make mypy
mypy qlib --verbose
- name: Check Qlib ipynb with nbqa - name: Check Qlib ipynb with nbqa
run: | run: |
nbqa black . -l 120 --check --diff make nbqa
nbqa pylint . --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0612,W0621,W0622,W0703,W1309,E1102,E1136,W0719,W0104,W0404,C0412,W0611,C0410 --const-rgx='[a-z_][a-z0-9_]{2,30}$'
- name: Test data downloads - name: Test data downloads
run: | run: |
@@ -151,7 +87,7 @@ jobs:
python scripts/get_data.py download_data --file_name rl_data.zip --target_dir tests/.data/rl python scripts/get_data.py download_data --file_name rl_data.zip --target_dir tests/.data/rl
- name: Install Lightgbm for MacOS - name: Install Lightgbm for MacOS
if: ${{ matrix.os == 'macos-13' || matrix.os == 'macos-latest' }} if: ${{ matrix.os == 'macos-13' || matrix.os == 'macos-14' || matrix.os == 'macos-latest' }}
run: | run: |
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Microsoft/qlib/main/.github/brew_install.sh)" /bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Microsoft/qlib/main/.github/brew_install.sh)"
HOMEBREW_NO_AUTO_UPDATE=1 brew install lightgbm HOMEBREW_NO_AUTO_UPDATE=1 brew install lightgbm
@@ -161,11 +97,9 @@ jobs:
brew unlink libomp brew unlink libomp
brew install libomp.rb brew install libomp.rb
# Run after data downloads
- name: Check Qlib ipynb with nbconvert - name: Check Qlib ipynb with nbconvert
run: | run: |
# add more ipynb files in future make nbconvert
jupyter nbconvert --to notebook --execute examples/workflow_by_code.ipynb
- name: Test workflow by config (install from source) - name: Test workflow by config (install from source)
run: | run: |

View File

@@ -19,41 +19,29 @@ jobs:
# If you want to use python 3.7 in github action, then the latest macos system version is macos-13, # If you want to use python 3.7 in github action, then the latest macos system version is macos-13,
# after macos-13 python 3.7 is no longer supported. # after macos-13 python 3.7 is no longer supported.
# so we limit the macos version to macos-13. # so we limit the macos version to macos-13.
os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-13] os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-13, macos-14, macos-latest]
# not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129 # not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129
python-version: [3.7, 3.8] python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
steps: steps:
- name: Test qlib from source slow - name: Test qlib from source slow
uses: actions/checkout@v3 uses: actions/checkout@v3
# Since version 3.7 of python for MacOS is installed in CI, version 3.7.17, this version causes "_bz not found error".
# So we make the version number of python 3.7 for MacOS more specific.
# refs: https://github.com/actions/setup-python/issues/682
- name: Set up Python ${{ matrix.python-version }} - name: Set up Python ${{ matrix.python-version }}
if: (matrix.os == 'macos-latest' && matrix.python-version == '3.7') || (matrix.os == 'macos-13' && matrix.python-version == '3.7')
uses: actions/setup-python@v4
with:
python-version: "3.7.16"
- name: Set up Python ${{ matrix.python-version }}
if: (matrix.os != 'macos-latest' || matrix.python-version != '3.7') && (matrix.os != 'macos-13' || matrix.python-version != '3.7')
uses: actions/setup-python@v4 uses: actions/setup-python@v4
with: with:
python-version: ${{ matrix.python-version }} python-version: ${{ matrix.python-version }}
- name: Set up Python tools - name: Set up Python tools
run: | run: |
python -m pip install --upgrade pip make dev
pip install --upgrade cython numpy
pip install -e .[dev]
- name: Downloads dependencies data - name: Downloads dependencies data
run: | run: |
python scripts/get_data.py qlib_data --name qlib_data_simple --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn python scripts/get_data.py qlib_data --name qlib_data_simple --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
- name: Install Lightgbm for MacOS - name: Install Lightgbm for MacOS
if: ${{ matrix.os == 'macos-13' || matrix.os == 'macos-latest' }} if: ${{ matrix.os == 'macos-13' || matrix.os == 'macos-14' || matrix.os == 'macos-latest' }}
run: | run: |
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Microsoft/qlib/main/.github/brew_install.sh)" /bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Microsoft/qlib/main/.github/brew_install.sh)"
HOMEBREW_NO_AUTO_UPDATE=1 brew install lightgbm HOMEBREW_NO_AUTO_UPDATE=1 brew install lightgbm

3
.gitignore vendored
View File

@@ -48,4 +48,5 @@ tags
*.swp *.swp
./pretrain ./pretrain
.idea/ .idea/
.aider*

View File

@@ -1 +1,6 @@
include qlib/VERSION.txt exclude tests/*
include qlib/*
include qlib/*/*
include qlib/*/*/*
include qlib/*/*/*/*
include qlib/*/*/*/*/*

195
Makefile Normal file
View File

@@ -0,0 +1,195 @@
.PHONY: clean deepclean prerequisite dependencies lightgbm rl develop lint docs package test analysis all install dev black pylint flake8 mypy nbqa nbconvert lint build upload docs-gen
#You can modify it according to your terminal
SHELL := /bin/bash
########################################################################################
# Variables
########################################################################################
# Documentation target directory, will be adapted to specific folder for readthedocs.
PUBLIC_DIR := $(shell [ "$$READTHEDOCS" = "True" ] && echo "$$READTHEDOCS_OUTPUT/html" || echo "public")
SO_DIR := qlib/data/_libs
SO_FILES := $(wildcard $(SO_DIR)/*.so)
########################################################################################
# Development Environment Management
########################################################################################
# Remove common intermediate files.
clean:
-rm -rf \
$(PUBLIC_DIR) \
qlib/data/_libs/*.cpp \
qlib/data/_libs/*.so \
mlruns \
public \
build \
.coverage \
.mypy_cache \
.pytest_cache \
.ruff_cache \
Pipfile* \
coverage.xml \
dist \
release-notes.md
find . -name '*.egg-info' -print0 | xargs -0 rm -rf
find . -name '*.pyc' -print0 | xargs -0 rm -f
find . -name '*.swp' -print0 | xargs -0 rm -f
find . -name '.DS_Store' -print0 | xargs -0 rm -f
find . -name '__pycache__' -print0 | xargs -0 rm -rf
# Remove pre-commit hook, virtual environment alongside itermediate files.
deepclean: clean
if command -v pre-commit > /dev/null 2>&1; then pre-commit uninstall --hook-type pre-push; fi
if command -v pipenv >/dev/null 2>&1 && pipenv --venv >/dev/null 2>&1; then pipenv --rm; fi
# Prerequisite section
# What this code does is compile two Cython modules, rolling and expanding, using setuptools and Cython,
# and builds them as binary expansion modules that can be imported directly into Python.
# Since pyproject.toml can't do that, we compile it here.
prerequisite:
@if [ -n "$(SO_FILES)" ]; then \
echo "Shared library files exist, skipping build."; \
else \
echo "No shared library files found, building..."; \
pip install --upgrade setuptools wheel; \
python -m pip install cython numpy; \
python -c "from setuptools import setup, Extension; from Cython.Build import cythonize; import numpy; extensions = [Extension('qlib.data._libs.rolling', ['qlib/data/_libs/rolling.pyx'], language='c++', include_dirs=[numpy.get_include()]), Extension('qlib.data._libs.expanding', ['qlib/data/_libs/expanding.pyx'], language='c++', include_dirs=[numpy.get_include()])]; setup(ext_modules=cythonize(extensions, language_level='3'), script_args=['build_ext', '--inplace'])"; \
fi
# Install the package in editable mode.
dependencies:
python -m pip install -e .
lightgbm:
python -m pip install lightgbm --prefer-binary
rl:
python -m pip install -e .[rl]
develop:
python -m pip install -e .[dev]
lint:
python -m pip install -e .[lint]
docs:
python -m pip install -e .[docs]
package:
python -m pip install -e .[package]
test:
python -m pip install -e .[test]
analysis:
python -m pip install -e .[analysis]
all:
python -m pip install -e .[dev,lint,docs,package,test,analysis,rl]
install: prerequisite dependencies
dev: prerequisite all
########################################################################################
# Lint and pre-commit
########################################################################################
# Check lint with black.
black:
black . -l 120 --check --diff
# Check code folder with pylint.
# TODO: These problems we will solve in the future. Important among them are: W0221, W0223, W0237, E1102
# C0103: invalid-name
# C0209: consider-using-f-string
# R0402: consider-using-from-import
# R1705: no-else-return
# R1710: inconsistent-return-statements
# R1725: super-with-arguments
# R1735: use-dict-literal
# W0102: dangerous-default-value
# W0212: protected-access
# W0221: arguments-differ
# W0223: abstract-method
# W0231: super-init-not-called
# W0237: arguments-renamed
# W0612: unused-variable
# W0621: redefined-outer-name
# W0622: redefined-builtin
# FIXME: specify exception type
# W0703: broad-except
# W1309: f-string-without-interpolation
# E1102: not-callable
# E1136: unsubscriptable-object
# W4904: deprecated-class
# R0917: too-many-positional-arguments
# E1123: unexpected-keyword-arg
# References for disable error: https://pylint.pycqa.org/en/latest/user_guide/messages/messages_overview.html
# We use sys.setrecursionlimit(2000) to make the recursion depth larger to ensure that pylint works properly (the default recursion depth is 1000).
# References for parameters: https://github.com/PyCQA/pylint/issues/4577#issuecomment-1000245962
pylint:
pylint --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R0917,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,W4904,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1730,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0612,W0621,W0622,W0703,W1309,E1102,E1136 --const-rgx='[a-z_][a-z0-9_]{2,30}' qlib --init-hook="import astroid; astroid.context.InferenceContext.max_inferred = 500; import sys; sys.setrecursionlimit(2000)"
pylint --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R0917,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,E1123,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0246,W0612,W0621,W0622,W0703,W1309,E1102,E1136 --const-rgx='[a-z_][a-z0-9_]{2,30}' scripts --init-hook="import astroid; astroid.context.InferenceContext.max_inferred = 500; import sys; sys.setrecursionlimit(2000)"
# Check code with flake8.
# The following flake8 error codes were ignored:
# E501 line too long
# Description: We have used black to limit the length of each line to 120.
# F541 f-string is missing placeholders
# Description: The same thing is done when using pylint for detection.
# E266 too many leading '#' for block comment
# Description: To make the code more readable, a lot of "#" is used.
# This error code appears centrally in:
# qlib/backtest/executor.py
# qlib/data/ops.py
# qlib/utils/__init__.py
# E402 module level import not at top of file
# Description: There are times when module level import is not available at the top of the file.
# W503 line break before binary operator
# Description: Since black formats the length of each line of code, it has to perform a line break when a line of arithmetic is too long.
# E731 do not assign a lambda expression, use a def
# Description: Restricts the use of lambda expressions, but at some point lambda expressions are required.
# E203 whitespace before ':'
# Description: If there is whitespace before ":", it cannot pass the black check.
flake8:
flake8 --ignore=E501,F541,E266,E402,W503,E731,E203 --per-file-ignores="__init__.py:F401,F403" qlib
# Check code with mypy.
# https://github.com/python/mypy/issues/10600
mypy:
mypy qlib --install-types --non-interactive
mypy qlib --verbose
# Check ipynb with nbqa.
nbqa:
nbqa black . -l 120 --check --diff
nbqa pylint . --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0612,W0621,W0622,W0703,W1309,E1102,E1136,W0719,W0104,W0404,C0412,W0611,C0410 --const-rgx='[a-z_][a-z0-9_]{2,30}'
# Check ipynb with nbconvert.(Run after data downloads)
# TODO: Add more ipynb files in future
nbconvert:
jupyter nbconvert --to notebook --execute examples/workflow_by_code.ipynb
lint: black pylint flake8 mypy nbqa
########################################################################################
# Package
########################################################################################
# Build the package.
build:
python -m build
# Upload the package.
upload:
python -m twine upload dist/*
########################################################################################
# Documentation
########################################################################################
docs-gen:
python -m sphinx.cmd.build -W docs $(PUBLIC_DIR)

View File

@@ -358,7 +358,7 @@ Qlib provides a tool named `qrun` to run the whole workflow automatically (inclu
``` ```
Here are detailed documents for `qrun` and [workflow](https://qlib.readthedocs.io/en/latest/component/workflow.html). Here are detailed documents for `qrun` and [workflow](https://qlib.readthedocs.io/en/latest/component/workflow.html).
2. Graphical Reports Analysis: Run `examples/workflow_by_code.ipynb` with `jupyter notebook` to get graphical reports 2. Graphical Reports Analysis: First, run `python -m pip install .[analysis]` to install the required dependencies. Then run `examples/workflow_by_code.ipynb` with `jupyter notebook` to get graphical reports.
- Forecasting signal (model prediction) analysis - Forecasting signal (model prediction) analysis
- Cumulative Return of groups - Cumulative Return of groups
![Cumulative Return](https://github.com/microsoft/qlib/blob/main/docs/_static/img/analysis/analysis_model_cumulative_return.png) ![Cumulative Return](https://github.com/microsoft/qlib/blob/main/docs/_static/img/analysis/analysis_model_cumulative_return.png)

View File

@@ -1,14 +1,15 @@
import argparse import argparse
import qlib import qlib
import ruamel.yaml as yaml from ruamel.yaml import YAML
from qlib.utils import init_instance_by_config from qlib.utils import init_instance_by_config
def main(seed, config_file="configs/config_alstm.yaml"): def main(seed, config_file="configs/config_alstm.yaml"):
# set random seed # set random seed
with open(config_file) as f: with open(config_file) as f:
config = yaml.safe_load(f) yaml = YAML(typ="safe", pure=True)
config = yaml.load(f)
# seed_suffix = "/seed1000" if "init" in config_file else f"/seed{seed}" # seed_suffix = "/seed1000" if "init" in config_file else f"/seed{seed}"
seed_suffix = "" seed_suffix = ""

View File

@@ -9,8 +9,8 @@ from copy import deepcopy
from pathlib import Path from pathlib import Path
import pickle import pickle
from pprint import pprint from pprint import pprint
from ruamel.yaml import YAML
import subprocess import subprocess
import yaml
from qlib.log import TimeInspector from qlib.log import TimeInspector
from qlib import init from qlib import init
@@ -30,7 +30,8 @@ if __name__ == "__main__":
subprocess.run(f"qrun {config_path}", shell=True) subprocess.run(f"qrun {config_path}", shell=True)
# 2) dump handler # 2) dump handler
task_config = yaml.safe_load(config_path.open()) yaml = YAML(typ="safe", pure=True)
task_config = yaml.load(config_path.open())
hd_conf = task_config["task"]["dataset"]["kwargs"]["handler"] hd_conf = task_config["task"]["dataset"]["kwargs"]["handler"]
pprint(hd_conf) pprint(hd_conf)
hd: DataHandlerLP = init_instance_by_config(hd_conf) hd: DataHandlerLP = init_instance_by_config(hd_conf)

View File

@@ -9,10 +9,9 @@ from copy import deepcopy
from pathlib import Path from pathlib import Path
import pickle import pickle
from pprint import pprint from pprint import pprint
from ruamel.yaml import YAML
import subprocess import subprocess
import yaml
from qlib import init from qlib import init
from qlib.data.dataset.handler import DataHandlerLP from qlib.data.dataset.handler import DataHandlerLP
from qlib.log import TimeInspector from qlib.log import TimeInspector
@@ -29,7 +28,8 @@ if __name__ == "__main__":
exp_name = "data_mem_reuse_demo" exp_name = "data_mem_reuse_demo"
config_path = DIRNAME.parent / "benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml" config_path = DIRNAME.parent / "benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml"
task_config = yaml.safe_load(config_path.open()) yaml = YAML(typ="safe", pure=True)
task_config = yaml.load(config_path.open())
# 1) without using processed data in memory # 1) without using processed data in memory
with TimeInspector.logt("The original time without reusing processed data in memory:"): with TimeInspector.logt("The original time without reusing processed data in memory:"):

View File

@@ -6,7 +6,6 @@ import sys
import fire import fire
import time import time
import glob import glob
import yaml
import shutil import shutil
import signal import signal
import inspect import inspect
@@ -15,6 +14,7 @@ import functools
import statistics import statistics
import subprocess import subprocess
from datetime import datetime from datetime import datetime
from ruamel.yaml import YAML
from pathlib import Path from pathlib import Path
from operator import xor from operator import xor
from pprint import pprint from pprint import pprint
@@ -188,7 +188,8 @@ def gen_and_save_md_table(metrics, dataset):
# read yaml, remove seed kwargs of model, and then save file in the temp_dir # read yaml, remove seed kwargs of model, and then save file in the temp_dir
def gen_yaml_file_without_seed_kwargs(yaml_path, temp_dir): def gen_yaml_file_without_seed_kwargs(yaml_path, temp_dir):
with open(yaml_path, "r") as fp: with open(yaml_path, "r") as fp:
config = yaml.safe_load(fp) yaml = YAML(typ="safe", pure=True)
config = yaml.load(fp)
try: try:
del config["task"]["model"]["kwargs"]["seed"] del config["task"]["model"]["kwargs"]["seed"]
except KeyError: except KeyError:

View File

@@ -1,2 +1,92 @@
[build-system] [build-system]
requires = ["setuptools", "numpy", "Cython"] requires = ["setuptools", "cython", "numpy>=1.24.0"]
build-backend = "setuptools.build_meta"
[project]
classifiers = [
"Operating System :: POSIX :: Linux",
"Operating System :: Microsoft :: Windows",
"Operating System :: MacOS",
"License :: OSI Approved :: MIT License",
"Development Status :: 3 - Alpha",
"Programming Language :: Python",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
]
name = "pyqlib"
dynamic = ["version"]
description = "A Quantitative-research Platform"
requires-python = ">=3.8.0"
dependencies = [
"pyyaml",
"numpy",
"pandas",
"mlflow",
"filelock>=3.16.0",
"redis",
"dill",
"fire",
"ruamel.yaml>=0.17.38",
"python-redis-lock",
"tqdm",
"pymongo",
"loguru",
"lightgbm",
"gym",
"cvxpy",
"joblib",
"matplotlib",
"jupyter",
"nbconvert",
]
[project.optional-dependencies]
dev = [
"pytest",
"statsmodels",
]
# On macos-13 system, when using python version greater than or equal to 3.10,
# pytorch can't fully support Numpy version above 2.0, so, when you want to install torch,
# it will limit the version of Numpy less than 2.0.
rl = [
"tianshou<=0.4.10",
"torch",
"numpy<2.0.0",
]
lint = [
"black",
"pylint",
"mypy<1.5.0",
"flake8",
"nbqa",
]
docs = [
"sphinx",
"sphinx_rtd_theme",
"readthedocs_sphinx_ext",
]
package = [
"twine",
"build",
]
# test_pit dependency packages
test = [
"yahooquery",
"baostock",
]
analysis = [
"plotly",
]
[tool.setuptools]
packages = [
"qlib",
]
[project.scripts]
qrun = "qlib.workflow.cli:run"

View File

@@ -6,7 +6,7 @@ __version__ = "0.9.5.99"
__version__bak = __version__ # This version is backup for QlibConfig.reset_qlib_version __version__bak = __version__ # This version is backup for QlibConfig.reset_qlib_version
import os import os
from typing import Union from typing import Union
import yaml from ruamel.yaml import YAML
import logging import logging
import platform import platform
import subprocess import subprocess
@@ -176,7 +176,8 @@ def init_from_yaml_conf(conf_path, **kwargs):
config = {} config = {}
else: else:
with open(conf_path) as f: with open(conf_path) as f:
config = yaml.safe_load(f) yaml = YAML(typ="safe", pure=True)
config = yaml.load(f)
config.update(kwargs) config.update(kwargs)
default_conf = config.pop("default_conf", "client") default_conf = config.pop("default_conf", "client")
init(default_conf, **config) init(default_conf, **config)
@@ -272,7 +273,8 @@ def auto_init(**kwargs):
logger = get_module_logger("Initialization") logger = get_module_logger("Initialization")
conf_pp = pp / "config.yaml" conf_pp = pp / "config.yaml"
with conf_pp.open() as f: with conf_pp.open() as f:
conf = yaml.safe_load(f) yaml = YAML(typ="safe", pure=True)
conf = yaml.load(f)
conf_type = conf.get("conf_type", "origin") conf_type = conf.get("conf_type", "origin")
if conf_type == "origin": if conf_type == "origin":

View File

@@ -278,7 +278,7 @@ class BaseSingleMetric:
raise NotImplementedError(f"Please implement the `empty` method") raise NotImplementedError(f"Please implement the `empty` method")
def add(self, other: BaseSingleMetric, fill_value: float = None) -> BaseSingleMetric: def add(self, other: BaseSingleMetric, fill_value: float = None) -> BaseSingleMetric:
"""Replace np.NaN with fill_value in two metrics and add them.""" """Replace np.nan with fill_value in two metrics and add them."""
raise NotImplementedError(f"Please implement the `add` method") raise NotImplementedError(f"Please implement the `add` method")
@@ -412,7 +412,7 @@ class BaseOrderIndicator:
metrics : Union[str, List[str]] metrics : Union[str, List[str]]
all metrics needs to be sumed. all metrics needs to be sumed.
fill_value : float, optional fill_value : float, optional
fill np.NaN with value. By default None. fill np.nan with value. By default None.
""" """
raise NotImplementedError(f"Please implement the 'sum_all_indicators' method") raise NotImplementedError(f"Please implement the 'sum_all_indicators' method")

View File

@@ -325,9 +325,9 @@ class Indicator:
def _update_order_fulfill_rate(self) -> None: def _update_order_fulfill_rate(self) -> None:
def func(deal_amount, amount): def func(deal_amount, amount):
# deal_amount is np.NaN or None when there is no inner decision. So full fill rate is 0. # deal_amount is np.nan or None when there is no inner decision. So full fill rate is 0.
tmp_deal_amount = deal_amount.reindex(amount.index, 0) tmp_deal_amount = deal_amount.reindex(amount.index, 0)
tmp_deal_amount = tmp_deal_amount.replace({np.NaN: 0}) tmp_deal_amount = tmp_deal_amount.replace({np.nan: 0})
return tmp_deal_amount / amount return tmp_deal_amount / amount
self.order_indicator.transfer(func, "ffr") self.order_indicator.transfer(func, "ffr")
@@ -354,8 +354,8 @@ class Indicator:
) )
def func(trade_price, deal_amount): def func(trade_price, deal_amount):
# trade_price is np.NaN instead of inf when deal_amount is zero. # trade_price is np.nan instead of inf when deal_amount is zero.
tmp_deal_amount = deal_amount.replace({0: np.NaN}) tmp_deal_amount = deal_amount.replace({0: np.nan})
return trade_price / tmp_deal_amount return trade_price / tmp_deal_amount
self.order_indicator.transfer(func, "trade_price") self.order_indicator.transfer(func, "trade_price")
@@ -425,7 +425,7 @@ class Indicator:
assert isinstance(price_s, idd.SingleData) assert isinstance(price_s, idd.SingleData)
price_s = price_s.loc[(price_s > 1e-08).data.astype(bool)] price_s = price_s.loc[(price_s > 1e-08).data.astype(bool)]
# NOTE ~(price_s < 1e-08) is different from price_s >= 1e-8 # NOTE ~(price_s < 1e-08) is different from price_s >= 1e-8
# ~(np.NaN < 1e-8) -> ~(False) -> True # ~(np.nan < 1e-8) -> ~(False) -> True
assert isinstance(price_s, idd.SingleData) assert isinstance(price_s, idd.SingleData)
if agg == "vwap": if agg == "vwap":

View File

@@ -58,7 +58,7 @@ class Alpha360(DataHandlerLP):
fit_end_time=None, fit_end_time=None,
filter_pipe=None, filter_pipe=None,
inst_processors=None, inst_processors=None,
**kwargs **kwargs,
): ):
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time) infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time) learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)
@@ -83,7 +83,7 @@ class Alpha360(DataHandlerLP):
data_loader=data_loader, data_loader=data_loader,
learn_processors=learn_processors, learn_processors=learn_processors,
infer_processors=infer_processors, infer_processors=infer_processors,
**kwargs **kwargs,
) )
def get_label_config(self): def get_label_config(self):
@@ -109,7 +109,7 @@ class Alpha158(DataHandlerLP):
process_type=DataHandlerLP.PTYPE_A, process_type=DataHandlerLP.PTYPE_A,
filter_pipe=None, filter_pipe=None,
inst_processors=None, inst_processors=None,
**kwargs **kwargs,
): ):
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time) infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time) learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)
@@ -134,7 +134,7 @@ class Alpha158(DataHandlerLP):
infer_processors=infer_processors, infer_processors=infer_processors,
learn_processors=learn_processors, learn_processors=learn_processors,
process_type=process_type, process_type=process_type,
**kwargs **kwargs,
) )
def get_feature_config(self): def get_feature_config(self):

View File

@@ -33,7 +33,7 @@ class CatBoostModel(Model, FeatureInt):
verbose_eval=20, verbose_eval=20,
evals_result=dict(), evals_result=dict(),
reweighter=None, reweighter=None,
**kwargs **kwargs,
): ):
df_train, df_valid = dataset.prepare( df_train, df_valid = dataset.prepare(
["train", "valid"], ["train", "valid"],

View File

@@ -31,7 +31,7 @@ class DEnsembleModel(Model, FeatureInt):
sub_weights=None, sub_weights=None,
epochs=100, epochs=100,
early_stopping_rounds=None, early_stopping_rounds=None,
**kwargs **kwargs,
): ):
self.base_model = base_model # "gbm" or "mlp", specifically, we use lgbm for "gbm" self.base_model = base_model # "gbm" or "mlp", specifically, we use lgbm for "gbm"
self.num_models = num_models # the number of sub-models self.num_models = num_models # the number of sub-models

View File

@@ -56,7 +56,7 @@ class ADARNN(Model):
n_splits=2, n_splits=2,
GPU=0, GPU=0,
seed=None, seed=None,
**_ **_,
): ):
# Set logger. # Set logger.
self.logger = get_module_logger("ADARNN") self.logger = get_module_logger("ADARNN")
@@ -154,10 +154,7 @@ class ADARNN(Model):
self.model.train() self.model.train()
criterion = nn.MSELoss() criterion = nn.MSELoss()
dist_mat = torch.zeros(self.num_layers, self.len_seq).to(self.device) dist_mat = torch.zeros(self.num_layers, self.len_seq).to(self.device)
len_loader = np.inf out_weight_list = None
for loader in train_loader_list:
if len(loader) < len_loader:
len_loader = len(loader)
for data_all in zip(*train_loader_list): for data_all in zip(*train_loader_list):
# for data_all in zip(*train_loader_list): # for data_all in zip(*train_loader_list):
self.train_optimizer.zero_grad() self.train_optimizer.zero_grad()
@@ -571,6 +568,7 @@ class TransferLoss:
Returns: Returns:
[tensor] -- transfer loss [tensor] -- transfer loss
""" """
loss = None
if self.loss_type in ("mmd_lin", "mmd"): if self.loss_type in ("mmd_lin", "mmd"):
mmdloss = MMD_loss(kernel_type="linear") mmdloss = MMD_loss(kernel_type="linear")
loss = mmdloss(X, Y) loss = mmdloss(X, Y)

View File

@@ -63,7 +63,7 @@ class ADD(Model):
mu=0.05, mu=0.05,
GPU=0, GPU=0,
seed=None, seed=None,
**kwargs **kwargs,
): ):
# Set logger. # Set logger.
self.logger = get_module_logger("ADD") self.logger = get_module_logger("ADD")

View File

@@ -52,7 +52,7 @@ class ALSTM(Model):
optimizer="adam", optimizer="adam",
GPU=0, GPU=0,
seed=None, seed=None,
**kwargs **kwargs,
): ):
# Set logger. # Set logger.
self.logger = get_module_logger("ALSTM") self.logger = get_module_logger("ALSTM")

View File

@@ -56,7 +56,7 @@ class ALSTM(Model):
n_jobs=10, n_jobs=10,
GPU=0, GPU=0,
seed=None, seed=None,
**kwargs **kwargs,
): ):
# Set logger. # Set logger.
self.logger = get_module_logger("ALSTM") self.logger = get_module_logger("ALSTM")

View File

@@ -56,7 +56,7 @@ class GATs(Model):
optimizer="adam", optimizer="adam",
GPU=0, GPU=0,
seed=None, seed=None,
**kwargs **kwargs,
): ):
# Set logger. # Set logger.
self.logger = get_module_logger("GATs") self.logger = get_module_logger("GATs")

View File

@@ -73,7 +73,7 @@ class GATs(Model):
GPU=0, GPU=0,
n_jobs=10, n_jobs=10,
seed=None, seed=None,
**kwargs **kwargs,
): ):
# Set logger. # Set logger.
self.logger = get_module_logger("GATs") self.logger = get_module_logger("GATs")

View File

@@ -319,7 +319,12 @@ class GeneralPTNN(Model):
if self.use_gpu: if self.use_gpu:
torch.cuda.empty_cache() torch.cuda.empty_cache()
def predict(self, dataset: Union[DatasetH, TSDatasetH]): def predict(
self,
dataset: Union[DatasetH, TSDatasetH],
batch_size=None,
n_jobs=None,
):
if not self.fitted: if not self.fitted:
raise ValueError("model is not fitted yet!") raise ValueError("model is not fitted yet!")

View File

@@ -52,7 +52,7 @@ class GRU(Model):
optimizer="adam", optimizer="adam",
GPU=0, GPU=0,
seed=None, seed=None,
**kwargs **kwargs,
): ):
# Set logger. # Set logger.
self.logger = get_module_logger("GRU") self.logger = get_module_logger("GRU")

View File

@@ -54,7 +54,7 @@ class GRU(Model):
n_jobs=10, n_jobs=10,
GPU=0, GPU=0,
seed=None, seed=None,
**kwargs **kwargs,
): ):
# Set logger. # Set logger.
self.logger = get_module_logger("GRU") self.logger = get_module_logger("GRU")

View File

@@ -59,7 +59,7 @@ class HIST(Model):
optimizer="adam", optimizer="adam",
GPU=0, GPU=0,
seed=None, seed=None,
**kwargs **kwargs,
): ):
# Set logger. # Set logger.
self.logger = get_module_logger("HIST") self.logger = get_module_logger("HIST")

View File

@@ -55,7 +55,7 @@ class IGMTF(Model):
optimizer="adam", optimizer="adam",
GPU=0, GPU=0,
seed=None, seed=None,
**kwargs **kwargs,
): ):
# Set logger. # Set logger.
self.logger = get_module_logger("IGMTF") self.logger = get_module_logger("IGMTF")

View File

@@ -255,7 +255,7 @@ class KRNN(Model):
optimizer="adam", optimizer="adam",
GPU=0, GPU=0,
seed=None, seed=None,
**kwargs **kwargs,
): ):
# Set logger. # Set logger.
self.logger = get_module_logger("KRNN") self.logger = get_module_logger("KRNN")

View File

@@ -44,7 +44,7 @@ class LocalformerModel(Model):
n_jobs=10, n_jobs=10,
GPU=0, GPU=0,
seed=None, seed=None,
**kwargs **kwargs,
): ):
# set hyper-parameters. # set hyper-parameters.
self.d_model = d_model self.d_model = d_model

View File

@@ -42,7 +42,7 @@ class LocalformerModel(Model):
n_jobs=10, n_jobs=10,
GPU=0, GPU=0,
seed=None, seed=None,
**kwargs **kwargs,
): ):
# set hyper-parameters. # set hyper-parameters.
self.d_model = d_model self.d_model = d_model

View File

@@ -51,7 +51,7 @@ class LSTM(Model):
optimizer="adam", optimizer="adam",
GPU=0, GPU=0,
seed=None, seed=None,
**kwargs **kwargs,
): ):
# Set logger. # Set logger.
self.logger = get_module_logger("LSTM") self.logger = get_module_logger("LSTM")

View File

@@ -53,7 +53,7 @@ class LSTM(Model):
n_jobs=10, n_jobs=10,
GPU=0, GPU=0,
seed=None, seed=None,
**kwargs **kwargs,
): ):
# Set logger. # Set logger.
self.logger = get_module_logger("LSTM") self.logger = get_module_logger("LSTM")

View File

@@ -35,7 +35,7 @@ class SandwichModel(nn.Module):
rnn_layers, rnn_layers,
dropout, dropout,
device, device,
**params **params,
): ):
"""Build a Sandwich model """Build a Sandwich model
@@ -129,7 +129,7 @@ class Sandwich(Model):
optimizer="adam", optimizer="adam",
GPU=0, GPU=0,
seed=None, seed=None,
**kwargs **kwargs,
): ):
# Set logger. # Set logger.
self.logger = get_module_logger("Sandwich") self.logger = get_module_logger("Sandwich")

View File

@@ -212,7 +212,7 @@ class SFM(Model):
optimizer="gd", optimizer="gd",
GPU=0, GPU=0,
seed=None, seed=None,
**kwargs **kwargs,
): ):
# Set logger. # Set logger.
self.logger = get_module_logger("SFM") self.logger = get_module_logger("SFM")

View File

@@ -56,7 +56,7 @@ class TCN(Model):
optimizer="adam", optimizer="adam",
GPU=0, GPU=0,
seed=None, seed=None,
**kwargs **kwargs,
): ):
# Set logger. # Set logger.
self.logger = get_module_logger("TCN") self.logger = get_module_logger("TCN")

View File

@@ -54,7 +54,7 @@ class TCN(Model):
n_jobs=10, n_jobs=10,
GPU=0, GPU=0,
seed=None, seed=None,
**kwargs **kwargs,
): ):
# Set logger. # Set logger.
self.logger = get_module_logger("TCN") self.logger = get_module_logger("TCN")

View File

@@ -58,7 +58,7 @@ class TCTS(Model):
mode="soft", mode="soft",
seed=None, seed=None,
lowest_valid_performance=0.993, lowest_valid_performance=0.993,
**kwargs **kwargs,
): ):
# Set logger. # Set logger.
self.logger = get_module_logger("TCTS") self.logger = get_module_logger("TCTS")

View File

@@ -43,7 +43,7 @@ class TransformerModel(Model):
n_jobs=10, n_jobs=10,
GPU=0, GPU=0,
seed=None, seed=None,
**kwargs **kwargs,
): ):
# set hyper-parameters. # set hyper-parameters.
self.d_model = d_model self.d_model = d_model

View File

@@ -41,7 +41,7 @@ class TransformerModel(Model):
n_jobs=10, n_jobs=10,
GPU=0, GPU=0,
seed=None, seed=None,
**kwargs **kwargs,
): ):
# set hyper-parameters. # set hyper-parameters.
self.d_model = d_model self.d_model = d_model

View File

@@ -28,7 +28,7 @@ class XGBModel(Model, FeatureInt):
verbose_eval=20, verbose_eval=20,
evals_result=dict(), evals_result=dict(),
reweighter=None, reweighter=None,
**kwargs **kwargs,
): ):
df_train, df_valid = dataset.prepare( df_train, df_valid = dataset.prepare(
["train", "valid"], ["train", "valid"],
@@ -63,7 +63,7 @@ class XGBModel(Model, FeatureInt):
early_stopping_rounds=early_stopping_rounds, early_stopping_rounds=early_stopping_rounds,
verbose_eval=verbose_eval, verbose_eval=verbose_eval,
evals_result=evals_result, evals_result=evals_result,
**kwargs **kwargs,
) )
evals_result["train"] = list(evals_result["train"].values())[0] evals_result["train"] = list(evals_result["train"].values())[0]
evals_result["valid"] = list(evals_result["valid"].values())[0] evals_result["valid"] = list(evals_result["valid"].values())[0]

View File

@@ -4,10 +4,10 @@
# pylint: skip-file # pylint: skip-file
# flake8: noqa # flake8: noqa
import yaml
import pathlib import pathlib
import pandas as pd import pandas as pd
import shutil import shutil
from ruamel.yaml import YAML
from ...backtest.account import Account from ...backtest.account import Account
from .user import User from .user import User
from .utils import load_instance, save_instance from .utils import load_instance, save_instance
@@ -110,7 +110,8 @@ class UserManager:
raise ValueError("User data for {} already exists".format(user_id)) raise ValueError("User data for {} already exists".format(user_id))
with config_file.open("r") as fp: with config_file.open("r") as fp:
config = yaml.safe_load(fp) yaml = YAML(typ="safe", pure=True)
config = yaml.load(fp)
# load model # load model
model = init_instance_by_config(config["model"]) model = init_instance_by_config(config["model"])

View File

@@ -6,8 +6,8 @@
import pathlib import pathlib
import pickle import pickle
import yaml
import pandas as pd import pandas as pd
from ruamel.yaml import YAML
from ...data import D from ...data import D
from ...config import C from ...config import C
from ...log import get_module_logger from ...log import get_module_logger
@@ -91,7 +91,8 @@ def prepare(um, today, user_id, exchange_config=None):
dates.append(get_next_trading_date(dates[-1], future=True)) dates.append(get_next_trading_date(dates[-1], future=True))
if exchange_config: if exchange_config:
with pathlib.Path(exchange_config).open("r") as fp: with pathlib.Path(exchange_config).open("r") as fp:
exchange_paras = yaml.safe_load(fp) yaml = YAML(typ="safe", pure=True)
exchange_paras = yaml.load(fp)
else: else:
exchange_paras = {} exchange_paras = {}
trade_exchange = Exchange(trade_dates=dates, **exchange_paras) trade_exchange = Exchange(trade_dates=dates, **exchange_paras)

View File

@@ -176,7 +176,7 @@ class HeatmapGraph(BaseGraph):
x=self._df.columns, x=self._df.columns,
y=self._df.index, y=self._df.index,
z=self._df.values.tolist(), z=self._df.values.tolist(),
**self._graph_kwargs **self._graph_kwargs,
) )
] ]
return _data return _data
@@ -213,7 +213,7 @@ class SubplotsGraph:
sub_graph_layout: dict = None, sub_graph_layout: dict = None,
sub_graph_data: list = None, sub_graph_data: list = None,
subplots_kwargs: dict = None, subplots_kwargs: dict = None,
**kwargs **kwargs,
): ):
""" """
@@ -355,7 +355,7 @@ class SubplotsGraph:
df=self._df.loc[:, [column_name]], df=self._df.loc[:, [column_name]],
name_dict={column_name: temp_name}, name_dict={column_name: temp_name},
graph_kwargs=_graph_kwargs, graph_kwargs=_graph_kwargs,
) ),
) )
else: else:
raise TypeError() raise TypeError()

View File

@@ -2,11 +2,11 @@
# Licensed under the MIT License. # Licensed under the MIT License.
from copy import deepcopy from copy import deepcopy
from pathlib import Path from pathlib import Path
from ruamel.yaml import YAML
from typing import List, Optional, Union from typing import List, Optional, Union
import fire import fire
import pandas as pd import pandas as pd
import yaml
from qlib import auto_init from qlib import auto_init
from qlib.log import get_module_logger from qlib.log import get_module_logger
@@ -117,7 +117,8 @@ class Rolling:
def _raw_conf(self) -> dict: def _raw_conf(self) -> dict:
with self.conf_path.open("r") as f: with self.conf_path.open("r") as f:
return yaml.safe_load(f) yaml = YAML(typ="safe", pure=True)
return yaml.load(f)
def _replace_handler_with_cache(self, task: dict): def _replace_handler_with_cache(self, task: dict):
""" """

View File

@@ -4,9 +4,9 @@
# pylint: skip-file # pylint: skip-file
# flake8: noqa # flake8: noqa
import yaml
import copy import copy
import os import os
from ruamel.yaml import YAML
class TunerConfigManager: class TunerConfigManager:
@@ -16,7 +16,8 @@ class TunerConfigManager:
self.config_path = config_path self.config_path = config_path
with open(config_path) as fp: with open(config_path) as fp:
config = yaml.safe_load(fp) yaml = YAML(typ="safe", pure=True)
config = yaml.load(fp)
self.config = copy.deepcopy(config) self.config = copy.deepcopy(config)
self.pipeline_ex_config = PipelineExperimentConfig(config.get("experiment", dict()), self) self.pipeline_ex_config = PipelineExperimentConfig(config.get("experiment", dict()), self)

View File

@@ -104,15 +104,24 @@ class HashingStockStorage(BaseHandlerStorage):
""" """
stock_selector = slice(None) stock_selector = slice(None)
time_selector = slice(None) # by default not filter by time.
if level is None: if level is None:
# For directly applying.
if isinstance(selector, tuple) and self.stock_level < len(selector): if isinstance(selector, tuple) and self.stock_level < len(selector):
# full selector format
stock_selector = selector[self.stock_level] stock_selector = selector[self.stock_level]
time_selector = selector[1 - self.stock_level]
elif isinstance(selector, (list, str)) and self.stock_level == 0: elif isinstance(selector, (list, str)) and self.stock_level == 0:
# only stock selector
stock_selector = selector stock_selector = selector
elif level in ("instrument", self.stock_level): elif level in ("instrument", self.stock_level):
if isinstance(selector, tuple): if isinstance(selector, tuple):
# NOTE: How could the stock level selector be a tuple?
stock_selector = selector[0] stock_selector = selector[0]
raise TypeError(
"I forget why would this case appear. But I think it does not make sense. So we raise a error for that case."
)
elif isinstance(selector, (list, str)): elif isinstance(selector, (list, str)):
stock_selector = selector stock_selector = selector
@@ -120,7 +129,7 @@ class HashingStockStorage(BaseHandlerStorage):
raise TypeError(f"stock selector must be type str|list, or slice(None), rather than {stock_selector}") raise TypeError(f"stock selector must be type str|list, or slice(None), rather than {stock_selector}")
if stock_selector == slice(None): if stock_selector == slice(None):
return self.hash_df return self.hash_df, time_selector
if isinstance(stock_selector, str): if isinstance(stock_selector, str):
stock_selector = [stock_selector] stock_selector = [stock_selector]
@@ -129,7 +138,7 @@ class HashingStockStorage(BaseHandlerStorage):
for each_stock in sorted(stock_selector): for each_stock in sorted(stock_selector):
if each_stock in self.hash_df: if each_stock in self.hash_df:
select_dict[each_stock] = self.hash_df[each_stock] select_dict[each_stock] = self.hash_df[each_stock]
return select_dict return select_dict, time_selector
def fetch( def fetch(
self, self,
@@ -138,10 +147,13 @@ class HashingStockStorage(BaseHandlerStorage):
col_set: Union[str, List[str]] = DataHandler.CS_ALL, col_set: Union[str, List[str]] = DataHandler.CS_ALL,
fetch_orig: bool = True, fetch_orig: bool = True,
) -> pd.DataFrame: ) -> pd.DataFrame:
fetch_stock_df_list = list(self._fetch_hash_df_by_stock(selector=selector, level=level).values()) fetch_stock_df_list, time_selector = self._fetch_hash_df_by_stock(selector=selector, level=level)
fetch_stock_df_list = list(fetch_stock_df_list.values())
for _index, stock_df in enumerate(fetch_stock_df_list): for _index, stock_df in enumerate(fetch_stock_df_list):
fetch_col_df = fetch_df_by_col(df=stock_df, col_set=col_set) fetch_col_df = fetch_df_by_col(df=stock_df, col_set=col_set)
fetch_index_df = fetch_df_by_index(df=fetch_col_df, selector=selector, level=level, fetch_orig=fetch_orig) fetch_index_df = fetch_df_by_index(
df=fetch_col_df, selector=time_selector, level="datetime", fetch_orig=fetch_orig
)
fetch_stock_df_list[_index] = fetch_index_df fetch_stock_df_list[_index] = fetch_index_df
if len(fetch_stock_df_list) == 0: if len(fetch_stock_df_list) == 0:
index_names = ("instrument", "datetime") if self.stock_level == 0 else ("datetime", "instrument") index_names = ("instrument", "datetime") if self.stock_level == 0 else ("datetime", "instrument")

View File

@@ -164,6 +164,7 @@ class SeriesDFilter(BaseDFilter):
timestamp = [] timestamp = []
_lbool = None _lbool = None
_ltime = None _ltime = None
_cur_start = None
for _ts, _bool in timestamp_series.items(): for _ts, _bool in timestamp_series.items():
# there is likely to be NAN when the filter series don't have the # there is likely to be NAN when the filter series don't have the
# bool value, so we just change the NAN into False # bool value, so we just change the NAN into False

View File

@@ -7,8 +7,7 @@ import shutil
import sys import sys
import tempfile import tempfile
from importlib import import_module from importlib import import_module
from ruamel.yaml import YAML
import yaml
DELETE_KEY = "_delete_" DELETE_KEY = "_delete_"
@@ -57,7 +56,8 @@ def parse_backtest_config(path: str) -> dict:
del sys.modules[tmp_module_name] del sys.modules[tmp_module_name]
else: else:
with open(tmp_config_file.name) as input_stream: with open(tmp_config_file.name) as input_stream:
config = yaml.safe_load(input_stream) yaml = YAML(typ="safe", pure=True)
config = yaml.load(input_stream)
if "_base_" in config: if "_base_" in config:
base_file_name = config.pop("_base_") base_file_name = config.pop("_base_")

View File

@@ -8,12 +8,12 @@ import random
import sys import sys
import warnings import warnings
from pathlib import Path from pathlib import Path
from ruamel.yaml import YAML
from typing import cast, List, Optional from typing import cast, List, Optional
import numpy as np import numpy as np
import pandas as pd import pandas as pd
import torch import torch
import yaml
from qlib.backtest import Order from qlib.backtest import Order
from qlib.backtest.decision import OrderDir from qlib.backtest.decision import OrderDir
from qlib.constant import ONE_MIN from qlib.constant import ONE_MIN
@@ -263,6 +263,7 @@ if __name__ == "__main__":
args = parser.parse_args() args = parser.parse_args()
with open(args.config_path, "r") as input_stream: with open(args.config_path, "r") as input_stream:
config = yaml.safe_load(input_stream) yaml = YAML(typ="safe", pure=True)
config = yaml.load(input_stream)
main(config, run_training=not args.no_training, run_backtest=args.run_backtest) main(config, run_training=not args.no_training, run_backtest=args.run_backtest)

View File

@@ -10,7 +10,6 @@ import os
import re import re
import copy import copy
import json import json
import yaml
import redis import redis
import bisect import bisect
import struct import struct
@@ -25,6 +24,7 @@ import pandas as pd
from pathlib import Path from pathlib import Path
from typing import List, Union, Optional, Callable from typing import List, Union, Optional, Callable
from packaging import version from packaging import version
from ruamel.yaml import YAML
from .file import ( from .file import (
get_or_create_path, get_or_create_path,
save_multiple_parts_file, save_multiple_parts_file,
@@ -244,12 +244,13 @@ def parse_config(config):
if not isinstance(config, str): if not isinstance(config, str):
return config return config
# Check whether config is file # Check whether config is file
yaml = YAML(typ="safe", pure=True)
if os.path.exists(config): if os.path.exists(config):
with open(config, "r") as f: with open(config, "r") as f:
return yaml.safe_load(f) return yaml.load(f)
# Check whether the str can be parsed # Check whether the str can be parsed
try: try:
return yaml.safe_load(config) return yaml.load(config)
except BaseException as base_exp: except BaseException as base_exp:
raise ValueError("cannot parse config!") from base_exp raise ValueError("cannot parse config!") from base_exp
@@ -799,6 +800,7 @@ def fill_placeholder(config: dict, config_extend: dict):
) )
return value return value
item_keys = None
while top < tail: while top < tail:
now_item = item_queue[top] now_item = item_queue[top]
top += 1 top += 1

View File

@@ -44,7 +44,7 @@ def concat(data_list: Union[SingleData], axis=0) -> MultiData:
all_index_map = dict(zip(all_index, range(len(all_index)))) all_index_map = dict(zip(all_index, range(len(all_index))))
# concat all # concat all
tmp_data = np.full((len(all_index), len(data_list)), np.NaN) tmp_data = np.full((len(all_index), len(data_list)), np.nan)
for data_id, index_data in enumerate(data_list): for data_id, index_data in enumerate(data_list):
assert isinstance(index_data, SingleData) assert isinstance(index_data, SingleData)
now_data_map = [all_index_map[index] for index in index_data.index] now_data_map = [all_index_map[index] for index in index_data.index]
@@ -64,7 +64,7 @@ def sum_by_index(data_list: Union[SingleData], new_index: list, fill_value=0) ->
new_index : list new_index : list
the new_index of new SingleData. the new_index of new SingleData.
fill_value : float fill_value : float
fill the missing values or replace np.NaN. fill the missing values or replace np.nan.
Returns Returns
------- -------
@@ -444,7 +444,7 @@ class IndexData(metaclass=index_data_ops_creator):
return self.__class__(~self.data.astype(bool), *self.indices) return self.__class__(~self.data.astype(bool), *self.indices)
def abs(self): def abs(self):
"""get the abs of data except np.NaN.""" """get the abs of data except np.nan."""
tmp_data = np.absolute(self.data) tmp_data = np.absolute(self.data)
return self.__class__(tmp_data, *self.indices) return self.__class__(tmp_data, *self.indices)
@@ -566,8 +566,8 @@ class SingleData(IndexData):
f"The indexes of self and other do not meet the requirements of the four arithmetic operations" f"The indexes of self and other do not meet the requirements of the four arithmetic operations"
) )
def reindex(self, index: Index, fill_value=np.NaN) -> SingleData: def reindex(self, index: Index, fill_value=np.nan) -> SingleData:
"""reindex data and fill the missing value with np.NaN. """reindex data and fill the missing value with np.nan.
Parameters Parameters
---------- ----------
@@ -615,7 +615,7 @@ class SingleData(IndexData):
return pd.Series(self.data, index=self.index) return pd.Series(self.data, index=self.index)
def __repr__(self) -> str: def __repr__(self) -> str:
return str(pd.Series(self.data, index=self.index)) return str(pd.Series(self.data, index=self.index.tolist()))
class MultiData(IndexData): class MultiData(IndexData):
@@ -651,4 +651,4 @@ class MultiData(IndexData):
) )
def __repr__(self) -> str: def __repr__(self) -> str:
return str(pd.DataFrame(self.data, index=self.index, columns=self.columns)) return str(pd.DataFrame(self.data, index=self.index.tolist(), columns=self.columns.tolist()))

View File

@@ -7,7 +7,7 @@ import sys
import fire import fire
from jinja2 import Template, meta from jinja2 import Template, meta
import ruamel.yaml as yaml from ruamel.yaml import YAML
import qlib import qlib
from qlib.config import C from qlib.config import C
@@ -104,7 +104,8 @@ def workflow(config_path, experiment_name="workflow", uri_folder="mlruns"):
""" """
# Render the template # Render the template
rendered_yaml = render_template(config_path) rendered_yaml = render_template(config_path)
config = yaml.safe_load(rendered_yaml) yaml = YAML(typ="safe", pure=True)
config = yaml.load(rendered_yaml)
base_config_path = config.get("BASE_CONFIG_PATH", None) base_config_path = config.get("BASE_CONFIG_PATH", None)
if base_config_path: if base_config_path:
@@ -126,7 +127,8 @@ def workflow(config_path, experiment_name="workflow", uri_folder="mlruns"):
raise FileNotFoundError(f"Can't find the BASE_CONFIG file: {base_config_path}") raise FileNotFoundError(f"Can't find the BASE_CONFIG file: {base_config_path}")
with open(path) as fp: with open(path) as fp:
base_config = yaml.safe_load(fp) yaml = YAML(typ="safe", pure=True)
base_config = yaml.load(fp)
logger.info(f"Load BASE_CONFIG_PATH succeed: {path.resolve()}") logger.info(f"Load BASE_CONFIG_PATH succeed: {path.resolve()}")
config = update_config(base_config, config) config = update_config(base_config, config)

View File

@@ -8,6 +8,7 @@ from mlflow.exceptions import MlflowException, RESOURCE_ALREADY_EXISTS, ErrorCod
from mlflow.entities import ViewType from mlflow.entities import ViewType
import os import os
from typing import Optional, Text from typing import Optional, Text
from pathlib import Path
from .exp import MLflowExperiment, Experiment from .exp import MLflowExperiment, Experiment
from ..config import C from ..config import C
@@ -233,7 +234,7 @@ class ExpManager:
# So we supported it in the interface wrapper # So we supported it in the interface wrapper
pr = urlparse(self.uri) pr = urlparse(self.uri)
if pr.scheme == "file": if pr.scheme == "file":
with FileLock(os.path.join(pr.netloc, pr.path, "filelock")): # pylint: disable=E0110 with FileLock(Path(os.path.join(pr.netloc, pr.path.lstrip("/"), "filelock"))): # pylint: disable=E0110
return self.create_exp(experiment_name), True return self.create_exp(experiment_name), True
# NOTE: for other schemes like http, we double check to avoid create exp conflicts # NOTE: for other schemes like http, we double check to avoid create exp conflicts
try: try:
@@ -421,7 +422,11 @@ class MLflowExpManager(ExpManager):
def list_experiments(self): def list_experiments(self):
# retrieve all the existing experiments # retrieve all the existing experiments
exps = self.client.list_experiments(view_type=ViewType.ACTIVE_ONLY) mlflow_version = int(mlflow.__version__.split(".", maxsplit=1)[0])
if mlflow_version >= 2:
exps = self.client.search_experiments(view_type=ViewType.ACTIVE_ONLY)
else:
exps = self.client.list_experiments(view_type=ViewType.ACTIVE_ONLY) # pylint: disable=E1101
experiments = dict() experiments = dict()
for exp in exps: for exp in exps:
experiment = MLflowExperiment(exp.experiment_id, exp.name, self.uri) experiment = MLflowExperiment(exp.experiment_id, exp.name, self.uri)

View File

@@ -9,6 +9,7 @@ import shutil
import pickle import pickle
import tempfile import tempfile
import subprocess import subprocess
import platform
from pathlib import Path from pathlib import Path
from datetime import datetime from datetime import datetime
@@ -316,7 +317,10 @@ class MLflowRecorder(Recorder):
This function will return the directory path of this recorder. This function will return the directory path of this recorder.
""" """
if self.artifact_uri is not None: if self.artifact_uri is not None:
local_dir_path = Path(self.artifact_uri.lstrip("file:")) / ".." if platform.system() == "Windows":
local_dir_path = Path(self.artifact_uri.lstrip("file:").lstrip("/")).parent
else:
local_dir_path = Path(self.artifact_uri.lstrip("file:")).parent
local_dir_path = str(local_dir_path.resolve()) local_dir_path = str(local_dir_path.resolve())
if os.path.isdir(local_dir_path): if os.path.isdir(local_dir_path):
return local_dir_path return local_dir_path

3
setup.cfg Normal file
View File

@@ -0,0 +1,3 @@
[metadata]
name = qlib
version = attr: qlib.__version__

208
setup.py
View File

@@ -1,208 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
import numpy
from setuptools import find_packages, setup, Extension
def read(rel_path: str) -> str:
here = os.path.abspath(os.path.dirname(__file__))
with open(os.path.join(here, rel_path), encoding="utf-8") as fp:
return fp.read()
def get_version(rel_path: str) -> str:
for line in read(rel_path).splitlines():
if line.startswith("__version__"):
delim = '"' if '"' in line else "'"
return line.split(delim)[1]
raise RuntimeError("Unable to find version string.")
# Package meta-data.
NAME = "pyqlib"
DESCRIPTION = "A Quantitative-research Platform"
REQUIRES_PYTHON = ">=3.5.0"
VERSION = get_version("qlib/__init__.py")
# Detect Cython
try:
import Cython
ver = Cython.__version__
_CYTHON_INSTALLED = ver >= "0.28"
except ImportError:
_CYTHON_INSTALLED = False
if not _CYTHON_INSTALLED:
print("Required Cython version >= 0.28 is not detected!")
print('Please run "pip install --upgrade cython" first.')
exit(-1)
# What packages are required for this module to be executed?
# `estimator` may depend on other packages. In order to reduce dependencies, it is not written here.
REQUIRED = [
"numpy>=1.12.0, <1.24",
"pandas>=0.25.1",
"scipy>=1.7.3",
# scs is a dependency package,
# and the latest version of scs: scs-3.2.4.post3.tar.gz causes the documentation build to fail,
# so we have temporarily limited the version of scs.
"scs<=3.2.4",
"requests>=2.18.0",
"sacred>=0.7.4",
"python-socketio",
"redis>=3.0.1",
"python-redis-lock>=3.3.1",
"schedule>=0.6.0",
"cvxpy>=1.0.21",
"hyperopt==0.1.2",
"fire>=0.3.1",
"statsmodels",
"xlrd>=1.0.0",
"plotly>=4.12.0",
"matplotlib>=3.3",
"tables>=3.6.1",
"pyyaml>=5.3.1",
# To ensure stable operation of the experiment manager, we have limited the version of mlflow,
# and we need to verify whether version 2.0 of mlflow can serve qlib properly.
"mlflow>=1.12.1, <=1.30.0",
# mlflow 1.30.0 requires packaging<22, so we limit the packaging version, otherwise the CI will fail.
"packaging<22",
"tqdm",
"loguru",
"lightgbm>=3.3.0",
"tornado",
"joblib>=0.17.0",
# With the upgrading of ruamel.yaml to 0.18, the safe_load method was deprecated,
# which would cause qlib.workflow.cli to not work properly,
# and no good replacement has been found, so the version of ruamel.yaml has been restricted for now.
# Refs: https://pypi.org/project/ruamel.yaml/
"ruamel.yaml<=0.17.36",
"pymongo==3.7.2", # For task management
"scikit-learn>=0.22",
"dill",
"dataclasses;python_version<'3.7'",
"filelock",
"jinja2",
"gym",
# Installing the latest version of protobuf for python versions below 3.8 will cause unit tests to fail.
"protobuf<=3.20.1;python_version<='3.8'",
"cryptography",
]
# Numpy include
NUMPY_INCLUDE = numpy.get_include()
here = os.path.abspath(os.path.dirname(__file__))
with open(os.path.join(here, "README.md"), encoding="utf-8") as f:
long_description = f.read()
# Cython Extensions
extensions = [
Extension(
"qlib.data._libs.rolling",
["qlib/data/_libs/rolling.pyx"],
language="c++",
include_dirs=[NUMPY_INCLUDE],
),
Extension(
"qlib.data._libs.expanding",
["qlib/data/_libs/expanding.pyx"],
language="c++",
include_dirs=[NUMPY_INCLUDE],
),
]
# Where the magic happens:
setup(
name=NAME,
version=VERSION,
license="MIT Licence",
url="https://github.com/microsoft/qlib",
description=DESCRIPTION,
long_description=long_description,
long_description_content_type="text/markdown",
python_requires=REQUIRES_PYTHON,
packages=find_packages(exclude=("tests",)),
# if your package is a single module, use this instead of 'packages':
# py_modules=['qlib'],
entry_points={
# 'console_scripts': ['mycli=mymodule:cli'],
"console_scripts": [
"qrun=qlib.workflow.cli:run",
],
},
ext_modules=extensions,
install_requires=REQUIRED,
extras_require={
"dev": [
"coverage",
"pytest>=3",
"sphinx",
"sphinx_rtd_theme",
"pre-commit",
# CI dependencies
"wheel",
"setuptools",
"black",
# Version 3.0 of pylint had problems with the build process, so we limited the version of pylint.
"pylint<=2.17.6",
# Using the latest versions(0.981 and 0.982) of mypy,
# the error "multiprocessing.Value()" is detected in the file "qlib/rl/utils/data_queue.py",
# If this is fixed in a subsequent version of mypy, then we will revert to the latest version of mypy.
# References: https://github.com/python/typeshed/issues/8799
"mypy<0.981",
"flake8",
"nbqa",
"jupyter",
"nbconvert",
# The 5.0.0 version of importlib-metadata removed the deprecated endpoint,
# which prevented flake8 from working properly, so we restricted the version of importlib-metadata.
# To help ensure the dependencies of flake8 https://github.com/python/importlib_metadata/issues/406
"importlib-metadata<5.0.0",
"readthedocs_sphinx_ext",
"cmake",
"lxml",
"baostock",
"yahooquery",
# 2024-05-30 scs has released a new version: 3.2.4.post2,
# this version, causes qlib installation to fail, so we've limited the scs version a bit for now.
"scs<=3.2.4",
"beautifulsoup4",
# In version 0.4.11 of tianshou, the code:
# logits, hidden = self.actor(batch.obs, state=state, info=batch.info)
# was changed in PR787,
# which causes pytest errors(AttributeError: 'dict' object has no attribute 'info') in CI,
# so we restricted the version of tianshou.
# References:
# https://github.com/thu-ml/tianshou/releases
"tianshou<=0.4.10",
"gym>=0.24", # If you do not put gym at the end, gym will degrade causing pytest results to fail.
],
"rl": [
"tianshou<=0.4.10",
"torch",
],
},
include_package_data=True,
classifiers=[
# Trove classifiers
# Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers
# 'License :: OSI Approved :: MIT License',
"Operating System :: POSIX :: Linux",
"Operating System :: Microsoft :: Windows",
"Operating System :: MacOS",
"License :: OSI Approved :: MIT License",
"Development Status :: 3 - Alpha",
"Programming Language :: Python",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
],
)

View File

@@ -16,7 +16,7 @@ from qlib.data import D
class TestDataLoader(unittest.TestCase): class TestDataLoader(unittest.TestCase):
def test_nested_data_loader(self): def test_nested_data_loader(self):
qlib.init() qlib.init(kernels=1)
nd = NestedDataLoader( nd = NestedDataLoader(
dataloader_l=[ dataloader_l=[
{ {
@@ -30,7 +30,7 @@ class TestDataLoader(unittest.TestCase):
) )
# Of course you can use StaticDataLoader # Of course you can use StaticDataLoader
dataset = nd.load() dataset = nd.load(start_time="2020-01-01", end_time="2020-01-31")
assert dataset is not None assert dataset is not None

View File

@@ -1,6 +1,7 @@
# Copyright (c) Microsoft Corporation. # Copyright (c) Microsoft Corporation.
# Licensed under the MIT License. # Licensed under the MIT License.
import unittest import unittest
import platform
import mlflow import mlflow
import time import time
from pathlib import Path from pathlib import Path
@@ -26,7 +27,10 @@ class MLflowTest(unittest.TestCase):
_ = mlflow.tracking.MlflowClient(tracking_uri=str(self.TMP_PATH)) _ = mlflow.tracking.MlflowClient(tracking_uri=str(self.TMP_PATH))
end = time.time() end = time.time()
elapsed = end - start elapsed = end - start
self.assertLess(elapsed, 1e-2) # it can be done in less than 10ms if platform.system() == "Linux":
self.assertLess(elapsed, 1e-2) # it can be done in less than 10ms
else:
self.assertLess(elapsed, 2e-2)
print(elapsed) print(elapsed)

View File

@@ -70,7 +70,7 @@ class IndexDataTest(unittest.TestCase):
print(sd.loc[:"c"]) print(sd.loc[:"c"])
def test_corner_cases(self): def test_corner_cases(self):
sd = idd.MultiData([[1, 2], [3, np.NaN]], index=["foo", "bar"], columns=["f", "g"]) sd = idd.MultiData([[1, 2], [3, np.nan]], index=["foo", "bar"], columns=["f", "g"])
print(sd) print(sd)
self.assertTrue(np.isnan(sd.loc["bar", "g"])) self.assertTrue(np.isnan(sd.loc["bar", "g"]))

View File

@@ -50,6 +50,8 @@ class TestNN(TestAutoData):
model_l = [ model_l = [
GeneralPTNN( GeneralPTNN(
n_epochs=2, n_epochs=2,
batch_size=32,
n_jobs=0,
pt_model_uri="qlib.contrib.model.pytorch_gru_ts.GRUModel", pt_model_uri="qlib.contrib.model.pytorch_gru_ts.GRUModel",
pt_model_kwargs={ pt_model_kwargs={
"d_feat": 3, "d_feat": 3,
@@ -60,6 +62,8 @@ class TestNN(TestAutoData):
), ),
GeneralPTNN( GeneralPTNN(
n_epochs=2, n_epochs=2,
batch_size=32,
n_jobs=0,
pt_model_uri="qlib.contrib.model.pytorch_nn.Net", # it is a MLP pt_model_uri="qlib.contrib.model.pytorch_nn.Net", # it is a MLP
pt_model_kwargs={ pt_model_kwargs={
"input_dim": 3, "input_dim": 3,

View File

@@ -8,7 +8,6 @@ import shutil
import unittest import unittest
import pytest import pytest
import pandas as pd import pandas as pd
import baostock as bs
from pathlib import Path from pathlib import Path
from qlib.data import D from qlib.data import D