mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 14:01:28 +08:00
Compare commits
101 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d5379c520f | ||
|
|
7ccf3f7658 | ||
|
|
2c21b8089a | ||
|
|
b87a2c294d | ||
|
|
3097dcc995 | ||
|
|
2fb9380b34 | ||
|
|
8fd6d5ca7e | ||
|
|
69bb755f37 | ||
|
|
39634b2158 | ||
|
|
16acb76aba | ||
|
|
4e0f5d5ec9 | ||
|
|
50c32ac15f | ||
|
|
80982f8904 | ||
|
|
477160e4ac | ||
|
|
3472e82d5c | ||
|
|
cb285bccac | ||
|
|
2e9a00a9f7 | ||
|
|
d631b4450b | ||
|
|
0826879481 | ||
|
|
2b41782f0c | ||
|
|
ac3fe9476f | ||
|
|
66c36226aa | ||
|
|
bb7ab1cf14 | ||
|
|
3dc5a7d299 | ||
|
|
7d66e4b788 | ||
|
|
213eb6c2cd | ||
|
|
94d138ec23 | ||
|
|
f26b341736 | ||
|
|
136b2ddf9a | ||
|
|
7095e755fa | ||
|
|
2d05a705e3 | ||
|
|
da920b7f95 | ||
|
|
d89fa0184c | ||
|
|
1b426503fc | ||
|
|
78b77e302b | ||
|
|
38f02d25dc | ||
|
|
de86e46ed0 | ||
|
|
ba8b6cc30f | ||
|
|
3525514704 | ||
|
|
3e72593b8c | ||
|
|
c38e799ce7 | ||
|
|
14d54aa2a1 | ||
|
|
89ae312109 | ||
|
|
3ea30c0290 | ||
|
|
4b8d70df1b | ||
|
|
a2996f7046 | ||
|
|
fbba768006 | ||
|
|
df557d29d5 | ||
|
|
be9cd9fe23 | ||
|
|
85cc74846b | ||
|
|
950408ef46 | ||
|
|
320bd65e19 | ||
|
|
e7a1b5ea1f | ||
|
|
67feeaeb00 | ||
|
|
4d621bff99 | ||
|
|
82f1ef2def | ||
|
|
186512f272 | ||
|
|
bda374180a | ||
|
|
014ff7d3fe | ||
|
|
23d9d5a0a9 | ||
|
|
7ce97c9da5 | ||
|
|
5a84aaf1dc | ||
|
|
afbb178e24 | ||
|
|
a0cef033cb | ||
|
|
7acb4f3484 | ||
|
|
431f574967 | ||
|
|
b604fe56b3 | ||
|
|
af4b8772d2 | ||
|
|
18fcdf1521 | ||
|
|
f2caf452e9 | ||
|
|
ca9f1861a4 | ||
|
|
b45b006ef2 | ||
|
|
82cf438401 | ||
|
|
9e635168c0 | ||
|
|
b7ace1a622 | ||
|
|
c9ed050ef0 | ||
|
|
2c33332dd6 | ||
|
|
a7d5a9b500 | ||
|
|
5190332c7e | ||
|
|
cde80206e4 | ||
|
|
a339fc11d1 | ||
|
|
33482047dc | ||
|
|
47bd13295b | ||
|
|
ebc0ca893e | ||
|
|
3a348aec9f | ||
|
|
37b908792b | ||
|
|
73ec0f4003 | ||
|
|
155c17f8ff | ||
|
|
41b94059aa | ||
|
|
7db83d84b7 | ||
|
|
35e0fdd1c0 | ||
|
|
598017f634 | ||
|
|
907c888c23 | ||
|
|
02fe6b6974 | ||
|
|
b892b21045 | ||
|
|
155f80323c | ||
|
|
63021018d6 | ||
|
|
f79a0eeaff | ||
|
|
8a087d0db9 | ||
|
|
2ae4be426a | ||
|
|
6ed83f7c04 |
21
.commitlintrc.js
Normal file
21
.commitlintrc.js
Normal file
@@ -0,0 +1,21 @@
|
||||
module.exports = {
|
||||
extends: ["@commitlint/config-conventional"],
|
||||
rules: {
|
||||
// Configuration Format: [level, applicability, value]
|
||||
// level: Error level, usually expressed as a number:
|
||||
// 0 - disable rule
|
||||
// 1 - Warning (does not prevent commits)
|
||||
// 2 - Error (will block the commit)
|
||||
// applicability: the conditions under which the rule applies, commonly used values:
|
||||
// “always” - always apply the rule
|
||||
// “never” - never apply the rule
|
||||
// value: the specific value of the rule, e.g. a maximum length of 100.
|
||||
// Refs: https://commitlint.js.org/reference/rules-configuration.html
|
||||
"header-max-length": [2, "always", 100],
|
||||
"type-enum": [
|
||||
2,
|
||||
"always",
|
||||
["build", "chore", "ci", "docs", "feat", "fix", "perf", "refactor", "revert", "style", "test", "Release-As"]
|
||||
]
|
||||
}
|
||||
};
|
||||
8
.dockerignore
Normal file
8
.dockerignore
Normal file
@@ -0,0 +1,8 @@
|
||||
__pycache__
|
||||
*.pyc
|
||||
*.pyo
|
||||
*.pyd
|
||||
.Python
|
||||
.env
|
||||
.git
|
||||
|
||||
13
.github/PULL_REQUEST_TEMPLATE.md
vendored
13
.github/PULL_REQUEST_TEMPLATE.md
vendored
@@ -1,3 +1,16 @@
|
||||
<!--- Thank you for submitting a Pull Request! In order to make our work smoother. -->
|
||||
<!--- please make sure your Pull Request meets the following requirements: -->
|
||||
<!--- 1. Provide a general summary of your changes in the Title above; -->
|
||||
<!--- 2. Add appropriate prefixes to titles, such as `build:`, `chore:`, `ci:`, `docs:`, `feat:`, `fix:`, `perf:`, `refactor:`, `revert:`, `style:`, `test:`(Ref: https://www.conventionalcommits.org/). -->
|
||||
<!--- Category: -->
|
||||
<!--- Patch Updates: `fix:` -->
|
||||
<!--- Example: fix(auth): correct login validation issue -->
|
||||
<!--- minor update (introduces new functionality): `feat` -->
|
||||
<!--- Example: feature(parser): add ability to parse arrays -->
|
||||
<!--- major update(destructive update): Include BREAKING CHANGE in the commit message footer, or add `! ` in the commit footer to indicate that there is a destructive update. -->
|
||||
<!--- Example: feat(auth)! : remove support for old authentication method -->
|
||||
<!--- Other updates: `build:`, `chore:`, `ci:`, `docs:`, `perf:`, `refactor:`, `revert:`, `style:`, `test:`. -->
|
||||
|
||||
<!--- Provide a general summary of your changes in the Title above -->
|
||||
|
||||
## Description
|
||||
|
||||
6
.github/labeler.yml
vendored
6
.github/labeler.yml
vendored
@@ -1,6 +0,0 @@
|
||||
documentation:
|
||||
- 'docs/**/*'
|
||||
- '**/*.md'
|
||||
|
||||
waiting for triage:
|
||||
- any: ['**/*', '!docs/**/*', '!**/*.md']
|
||||
14
.github/workflows/labeler.yml
vendored
14
.github/workflows/labeler.yml
vendored
@@ -1,14 +0,0 @@
|
||||
name: "Add label automatically"
|
||||
on:
|
||||
- pull_request_target
|
||||
|
||||
jobs:
|
||||
triage:
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: write
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/labeler@v4
|
||||
with:
|
||||
repo-token: "${{ secrets.GITHUB_TOKEN }}"
|
||||
35
.github/workflows/lint_title.yml
vendored
Normal file
35
.github/workflows/lint_title.yml
vendored
Normal file
@@ -0,0 +1,35 @@
|
||||
name: Lint pull request title
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types:
|
||||
- opened
|
||||
- synchronize
|
||||
- reopened
|
||||
- edited
|
||||
|
||||
concurrency:
|
||||
cancel-in-progress: true
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
|
||||
jobs:
|
||||
lint-title:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
# This step is necessary because the lint title uses the .commitlintrc.js file in the project root directory.
|
||||
- name: Checkout Repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '16'
|
||||
|
||||
- name: Install commitlint
|
||||
run: npm install --save-dev @commitlint/{config-conventional,cli}
|
||||
|
||||
- name: Validate PR Title with commitlint
|
||||
env:
|
||||
BODY: ${{ github.event.pull_request.title }}
|
||||
run: |
|
||||
echo "$BODY" | npx commitlint --config .commitlintrc.js
|
||||
81
.github/workflows/python-publish.yml
vendored
81
.github/workflows/python-publish.yml
vendored
@@ -1,81 +0,0 @@
|
||||
# This workflows will upload a Python Package using Twine when a release is created
|
||||
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
|
||||
|
||||
name: Upload Python Package
|
||||
|
||||
on:
|
||||
release:
|
||||
types: [published]
|
||||
|
||||
jobs:
|
||||
deploy_with_bdist_wheel:
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [windows-latest, macos-11]
|
||||
# FIXME: macos-latest will raise error now.
|
||||
# not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129
|
||||
python-version: [3.7, 3.8]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
# This is because on macos systems you can install pyqlib using
|
||||
# `pip install pyqlib` installs, it does not recognize the
|
||||
# `pyqlib-<version>-cp38-cp38-macosx_11_0_x86_64.whl` and `pyqlib-<veresion>-cp38-cp37m-macosx_11_0_x86_64.whl`.
|
||||
# So we limit the version of python, in order to generate a version of qlib that is usable for macos: `pyqlib-<veresion>-cp38-cp37m
|
||||
# `pyqlib-<version>-cp38-cp38-macosx_10_15_x86_64.whl` and `pyqlib-<veresion>-cp38-cp37m-macosx_10_15_x86_64.whl`.
|
||||
# Python 3.7.16, 3.8.16 can build macosx_10_15. But Python 3.7.17, 3.8.17 can build macosx_11_0
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
if: matrix.os == 'macos-11' && matrix.python-version == '3.7'
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: "3.7.16"
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
if: matrix.os == 'macos-11' && matrix.python-version == '3.8'
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: "3.8.16"
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
if: matrix.os != 'macos-11'
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install setuptools wheel twine
|
||||
- name: Build wheel on ${{ matrix.os }}
|
||||
run: |
|
||||
pip install numpy
|
||||
pip install cython
|
||||
python setup.py bdist_wheel
|
||||
- name: Build and publish
|
||||
env:
|
||||
TWINE_USERNAME: __token__
|
||||
TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }}
|
||||
run: |
|
||||
twine upload dist/*
|
||||
|
||||
deploy_with_manylinux:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Build wheel on Linux
|
||||
uses: RalfG/python-wheels-manylinux-build@v0.3.1-manylinux2010_x86_64
|
||||
with:
|
||||
# not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129
|
||||
python-versions: 'cp37-cp37m cp38-cp38'
|
||||
build-requirements: 'numpy cython'
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: 3.7
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip install twine
|
||||
- name: Build and publish
|
||||
env:
|
||||
TWINE_USERNAME: __token__
|
||||
TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }}
|
||||
run: |
|
||||
twine upload dist/pyqlib-*-manylinux*.whl
|
||||
22
.github/workflows/release-drafter.yml
vendored
22
.github/workflows/release-drafter.yml
vendored
@@ -1,22 +0,0 @@
|
||||
name: Release Drafter
|
||||
|
||||
on:
|
||||
push:
|
||||
# branches to consider in the event; optional, defaults to all
|
||||
branches:
|
||||
- main
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
update_release_draft:
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: read
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
# Drafts your next Release notes as Pull Requests are merged into "master"
|
||||
- uses: release-drafter/release-drafter@v5.11.0
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
107
.github/workflows/release.yml
vendored
Normal file
107
.github/workflows/release.yml
vendored
Normal file
@@ -0,0 +1,107 @@
|
||||
name: Release
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
release:
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
release_created: ${{ steps.release_please.outputs.release_created }}
|
||||
|
||||
steps:
|
||||
- name: Release please
|
||||
id: release_please
|
||||
uses: googleapis/release-please-action@v4
|
||||
with:
|
||||
token: ${{ secrets.PAT }}
|
||||
release-type: simple
|
||||
|
||||
deploy_with_manylinux:
|
||||
needs: release
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: read
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
if: needs.release.outputs.release_created == 'true'
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
if: needs.release.outputs.release_created == 'true'
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Build wheel on Linux
|
||||
if: needs.release.outputs.release_created == 'true'
|
||||
uses: RalfG/python-wheels-manylinux-build@v0.7.1-manylinux2014_x86_64
|
||||
with:
|
||||
python-versions: 'cp38-cp38 cp39-cp39 cp310-cp310 cp311-cp311 cp312-cp312'
|
||||
build-requirements: 'numpy cython'
|
||||
|
||||
- name: Install dependencies
|
||||
if: needs.release.outputs.release_created == 'true'
|
||||
run: |
|
||||
python -m pip install twine
|
||||
|
||||
- name: Upload to PyPi
|
||||
if: needs.release.outputs.release_created == 'true'
|
||||
env:
|
||||
TWINE_USERNAME: __token__
|
||||
TWINE_PASSWORD: ${{ secrets.TESTPYPI }}
|
||||
run: |
|
||||
twine check dist/pyqlib-*-manylinux*.whl
|
||||
twine upload --repository-url https://test.pypi.org/legacy/ dist/pyqlib-*-manylinux*.whl --verbose
|
||||
|
||||
deploy_with_bdist_wheel:
|
||||
needs: release
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
# After testing, the whl files of pyqlib built by macos-14 and macos-15 in python environments of 3.8, 3.9, 3.10, 3.11, 3.12,
|
||||
# the filenames are exactly duplicated, which will result in the duplicated whl files not being able to be uploaded to pypi,
|
||||
# so we chose to just keep the latest macos-latest. macos-latest currently points to macos-15.
|
||||
# Also, macos-13 will stop being supported on 2025-11-14.
|
||||
# Refs: https://github.blog/changelog/2025-07-11-upcoming-changes-to-macos-hosted-runners-macos-latest-migration-and-xcode-support-policy-updates/
|
||||
os: [windows-latest, macos-latest]
|
||||
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
if: needs.release.outputs.release_created == 'true'
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
if: needs.release.outputs.release_created == 'true'
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Install dependencies
|
||||
if: needs.release.outputs.release_created == 'true'
|
||||
run: |
|
||||
make dev
|
||||
|
||||
- name: Build wheel on ${{ matrix.os }}
|
||||
if: needs.release.outputs.release_created == 'true'
|
||||
run: |
|
||||
make build
|
||||
|
||||
- name: Upload to PyPi
|
||||
if: needs.release.outputs.release_created == 'true'
|
||||
env:
|
||||
TWINE_USERNAME: __token__
|
||||
TWINE_PASSWORD: ${{ secrets.TESTPYPI }}
|
||||
run: |
|
||||
twine check dist/*.whl
|
||||
twine upload --repository-url https://test.pypi.org/legacy/ dist/*.whl --verbose
|
||||
43
.github/workflows/test_qlib_from_pip.yml
vendored
43
.github/workflows/test_qlib_from_pip.yml
vendored
@@ -1,5 +1,9 @@
|
||||
name: Test qlib from pip
|
||||
|
||||
concurrency:
|
||||
cancel-in-progress: true
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
@@ -13,28 +17,19 @@ jobs:
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
# Since macos-latest changed from 12.7.4 to 14.4.1,
|
||||
# the minimum python version that matches a 14.4.1 version of macos is 3.10,
|
||||
# so we limit the macos version to macos-12.
|
||||
os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-11, macos-12]
|
||||
# not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129
|
||||
python-version: [3.7, 3.8]
|
||||
os: [windows-latest, ubuntu-24.04, ubuntu-22.04, macos-14, macos-15]
|
||||
# In github action, using python 3.7, pip install will not match the latest version of the package.
|
||||
# Also, python 3.7 is no longer supported from macos-14, and will be phased out from macos-13 in the near future.
|
||||
# All things considered, we have removed python 3.7.
|
||||
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
|
||||
|
||||
steps:
|
||||
- name: Test qlib from pip
|
||||
uses: actions/checkout@v3
|
||||
|
||||
# Since version 3.7 of python for MacOS is installed in CI, version 3.7.17, this version causes "_bz not found error".
|
||||
# So we make the version number of python 3.7 for MacOS more specific.
|
||||
# refs: https://github.com/actions/setup-python/issues/682
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
if: (matrix.os == 'macos-latest' && matrix.python-version == '3.7') || (matrix.os == 'macos-11' && matrix.python-version == '3.7')
|
||||
uses: actions/setup-python@v4
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
python-version: "3.7.16"
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
if: (matrix.os != 'macos-latest' || matrix.python-version != '3.7') && (matrix.os != 'macos-11' || matrix.python-version != '3.7')
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
@@ -42,26 +37,22 @@ jobs:
|
||||
- name: Update pip to the latest version
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
|
||||
|
||||
- name: Qlib installation test
|
||||
run: |
|
||||
python -m pip install pyqlib
|
||||
|
||||
- name: Install Lightgbm for MacOS
|
||||
if: ${{ matrix.os == 'macos-11' || matrix.os == 'macos-latest' }}
|
||||
if: ${{ matrix.os == 'macos-14' || matrix.os == 'macos-15' }}
|
||||
run: |
|
||||
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Microsoft/qlib/main/.github/brew_install.sh)"
|
||||
HOMEBREW_NO_AUTO_UPDATE=1 brew install lightgbm
|
||||
# FIX MacOS error: Segmentation fault
|
||||
# reference: https://github.com/microsoft/LightGBM/issues/4229
|
||||
wget https://raw.githubusercontent.com/Homebrew/homebrew-core/fb8323f2b170bd4ae97e1bac9bf3e2983af3fdb0/Formula/libomp.rb
|
||||
brew unlink libomp
|
||||
brew install libomp.rb
|
||||
brew update
|
||||
brew install libomp || brew reinstall libomp
|
||||
python -m pip install --no-binary=:all: lightgbm
|
||||
|
||||
- name: Downloads dependencies data
|
||||
run: |
|
||||
cd ..
|
||||
python -m qlib.run.get_data qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn
|
||||
python -m qlib.cli.data qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn
|
||||
cd qlib
|
||||
|
||||
- name: Test workflow by config
|
||||
|
||||
141
.github/workflows/test_qlib_from_source.yml
vendored
141
.github/workflows/test_qlib_from_source.yml
vendored
@@ -1,5 +1,9 @@
|
||||
name: Test qlib from source
|
||||
|
||||
concurrency:
|
||||
cancel-in-progress: true
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
@@ -14,28 +18,19 @@ jobs:
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
# Since macos-latest changed from 12.7.4 to 14.4.1,
|
||||
# the minimum python version that matches a 14.4.1 version of macos is 3.10,
|
||||
# so we limit the macos version to macos-12.
|
||||
os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-11, macos-12]
|
||||
# not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129
|
||||
python-version: [3.7, 3.8]
|
||||
os: [windows-latest, ubuntu-24.04, ubuntu-22.04, macos-14, macos-15]
|
||||
# In github action, using python 3.7, pip install will not match the latest version of the package.
|
||||
# Also, python 3.7 is no longer supported from macos-14, and will be phased out from macos-13 in the near future.
|
||||
# All things considered, we have removed python 3.7.
|
||||
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
|
||||
|
||||
steps:
|
||||
- name: Test qlib from source
|
||||
uses: actions/checkout@v3
|
||||
|
||||
# Since version 3.7 of python for MacOS is installed in CI, version 3.7.17, this version causes "_bz not found error".
|
||||
# So we make the version number of python 3.7 for MacOS more specific.
|
||||
# refs: https://github.com/actions/setup-python/issues/682
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
if: (matrix.os == 'macos-latest' && matrix.python-version == '3.7') || (matrix.os == 'macos-11' && matrix.python-version == '3.7')
|
||||
uses: actions/setup-python@v4
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
python-version: "3.7.16"
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
if: (matrix.os != 'macos-latest' || matrix.python-version != '3.7') && (matrix.os != 'macos-11' || matrix.python-version != '3.7')
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
@@ -45,12 +40,12 @@ jobs:
|
||||
python -m pip install --upgrade pip
|
||||
|
||||
- name: Installing pytorch for macos
|
||||
if: ${{ matrix.os == 'macos-11' || matrix.os == 'macos-latest' }}
|
||||
if: ${{ matrix.os == 'macos-14' || matrix.os == 'macos-15' }}
|
||||
run: |
|
||||
python -m pip install torch torchvision torchaudio
|
||||
|
||||
- name: Installing pytorch for ubuntu
|
||||
if: ${{ matrix.os == 'ubuntu-20.04' || matrix.os == 'ubuntu-22.04' }}
|
||||
if: ${{ matrix.os == 'ubuntu-24.04' || matrix.os == 'ubuntu-22.04' }}
|
||||
run: |
|
||||
python -m pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
|
||||
|
||||
@@ -61,85 +56,36 @@ jobs:
|
||||
|
||||
- name: Set up Python tools
|
||||
run: |
|
||||
python -m pip install --upgrade cython
|
||||
python -m pip install -e .[dev]
|
||||
make dev
|
||||
|
||||
- name: Lint with Black
|
||||
# Python 3.7 will use a black with low level. So we use python with higher version for black check
|
||||
if: (matrix.python-version != '3.7')
|
||||
run: |
|
||||
pip install -U black # follow the latest version of black, previous Qlib dependency will downgrade black
|
||||
black . -l 120 --check --diff
|
||||
make black
|
||||
|
||||
- name: Make html with sphinx
|
||||
# Since read the docs builds on ubuntu 22.04, we only need to test that the build passes on ubuntu 22.04.
|
||||
if: ${{ matrix.os == 'ubuntu-22.04' }}
|
||||
run: |
|
||||
cd docs
|
||||
sphinx-build -W --keep-going -b html . _build
|
||||
cd ..
|
||||
make docs-gen
|
||||
|
||||
# Check Qlib with pylint
|
||||
# TODO: These problems we will solve in the future. Important among them are: W0221, W0223, W0237, E1102
|
||||
# C0103: invalid-name
|
||||
# C0209: consider-using-f-string
|
||||
# R0402: consider-using-from-import
|
||||
# R1705: no-else-return
|
||||
# R1710: inconsistent-return-statements
|
||||
# R1725: super-with-arguments
|
||||
# R1735: use-dict-literal
|
||||
# W0102: dangerous-default-value
|
||||
# W0212: protected-access
|
||||
# W0221: arguments-differ
|
||||
# W0223: abstract-method
|
||||
# W0231: super-init-not-called
|
||||
# W0237: arguments-renamed
|
||||
# W0612: unused-variable
|
||||
# W0621: redefined-outer-name
|
||||
# W0622: redefined-builtin
|
||||
# FIXME: specify exception type
|
||||
# W0703: broad-except
|
||||
# W1309: f-string-without-interpolation
|
||||
# E1102: not-callable
|
||||
# E1136: unsubscriptable-object
|
||||
# References for parameters: https://github.com/PyCQA/pylint/issues/4577#issuecomment-1000245962
|
||||
# We use sys.setrecursionlimit(2000) to make the recursion depth larger to ensure that pylint works properly (the default recursion depth is 1000).
|
||||
- name: Check Qlib with pylint
|
||||
run: |
|
||||
pylint --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0612,W0621,W0622,W0703,W1309,E1102,E1136 --const-rgx='[a-z_][a-z0-9_]{2,30}$' qlib --init-hook "import astroid; astroid.context.InferenceContext.max_inferred = 500; import sys; sys.setrecursionlimit(2000)"
|
||||
pylint --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0246,W0612,W0621,W0622,W0703,W1309,E1102,E1136 --const-rgx='[a-z_][a-z0-9_]{2,30}$' scripts --init-hook "import astroid; astroid.context.InferenceContext.max_inferred = 500; import sys; sys.setrecursionlimit(2000)"
|
||||
make pylint
|
||||
|
||||
# The following flake8 error codes were ignored:
|
||||
# E501 line too long
|
||||
# Description: We have used black to limit the length of each line to 120.
|
||||
# F541 f-string is missing placeholders
|
||||
# Description: The same thing is done when using pylint for detection.
|
||||
# E266 too many leading '#' for block comment
|
||||
# Description: To make the code more readable, a lot of "#" is used.
|
||||
# This error code appears centrally in:
|
||||
# qlib/backtest/executor.py
|
||||
# qlib/data/ops.py
|
||||
# qlib/utils/__init__.py
|
||||
# E402 module level import not at top of file
|
||||
# Description: There are times when module level import is not available at the top of the file.
|
||||
# W503 line break before binary operator
|
||||
# Description: Since black formats the length of each line of code, it has to perform a line break when a line of arithmetic is too long.
|
||||
# E731 do not assign a lambda expression, use a def
|
||||
# Description: Restricts the use of lambda expressions, but at some point lambda expressions are required.
|
||||
# E203 whitespace before ':'
|
||||
# Description: If there is whitespace before ":", it cannot pass the black check.
|
||||
- name: Check Qlib with flake8
|
||||
run: |
|
||||
flake8 --ignore=E501,F541,E266,E402,W503,E731,E203 --per-file-ignores="__init__.py:F401,F403" qlib
|
||||
make flake8
|
||||
|
||||
# https://github.com/python/mypy/issues/10600
|
||||
- name: Check Qlib with mypy
|
||||
run: |
|
||||
mypy qlib --install-types --non-interactive || true
|
||||
mypy qlib --verbose
|
||||
make mypy
|
||||
|
||||
# Due to issues that cannot be automatically fixed when running `nbqa black . -l 120 --check --diff` on Jupyter notebooks,
|
||||
# we reverted to a version of `black` earlier than 26.1.0 before performing the checks.
|
||||
- name: Check Qlib ipynb with nbqa
|
||||
run: |
|
||||
nbqa black . -l 120 --check --diff
|
||||
nbqa pylint . --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0612,W0621,W0622,W0703,W1309,E1102,E1136,W0719,W0104,W0404,C0412,W0611,C0410 --const-rgx='[a-z_][a-z0-9_]{2,30}$'
|
||||
python -m pip install "black<26.1"
|
||||
make nbqa
|
||||
|
||||
- name: Test data downloads
|
||||
run: |
|
||||
@@ -147,28 +93,39 @@ jobs:
|
||||
python scripts/get_data.py download_data --file_name rl_data.zip --target_dir tests/.data/rl
|
||||
|
||||
- name: Install Lightgbm for MacOS
|
||||
if: ${{ matrix.os == 'macos-11' || matrix.os == 'macos-latest' }}
|
||||
if: ${{ matrix.os == 'macos-14' || matrix.os == 'macos-15' }}
|
||||
run: |
|
||||
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Microsoft/qlib/main/.github/brew_install.sh)"
|
||||
HOMEBREW_NO_AUTO_UPDATE=1 brew install lightgbm
|
||||
# FIX MacOS error: Segmentation fault
|
||||
# reference: https://github.com/microsoft/LightGBM/issues/4229
|
||||
wget https://raw.githubusercontent.com/Homebrew/homebrew-core/fb8323f2b170bd4ae97e1bac9bf3e2983af3fdb0/Formula/libomp.rb
|
||||
brew unlink libomp
|
||||
brew install libomp.rb
|
||||
brew update
|
||||
brew install libomp || brew reinstall libomp
|
||||
python -m pip install --no-binary=:all: lightgbm
|
||||
|
||||
# Run after data downloads
|
||||
- name: Check Qlib ipynb with nbconvert
|
||||
run: |
|
||||
# add more ipynb files in future
|
||||
jupyter nbconvert --to notebook --execute examples/workflow_by_code.ipynb
|
||||
make nbconvert
|
||||
|
||||
- name: Test workflow by config (install from source)
|
||||
run: |
|
||||
python -m pip install numba
|
||||
python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
python qlib/cli/run.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
|
||||
- name: Unit tests with Pytest
|
||||
- name: Unit tests with Pytest (MacOS)
|
||||
if: ${{ matrix.os == 'macos-14' || matrix.os == 'macos-15' }}
|
||||
uses: nick-fields/retry@v2
|
||||
with:
|
||||
timeout_minutes: 60
|
||||
max_attempts: 3
|
||||
command: |
|
||||
# Limit the number of threads in various libraries to prevent Segmentation faults caused by OpenMP multithreading conflicts under macOS.
|
||||
export OMP_NUM_THREADS=1 # Limit the number of OpenMP threads
|
||||
export MKL_NUM_THREADS=1 # Limit the number of Intel MKL threads
|
||||
export NUMEXPR_NUM_THREADS=1 # Limit the number of NumExpr threads
|
||||
export OPENBLAS_NUM_THREADS=1 # Limit the number of OpenBLAS threads
|
||||
export VECLIB_MAXIMUM_THREADS=1 # Limit the number of macOS Accelerate/vecLib threads
|
||||
cd tests
|
||||
python -m pytest . -m "not slow" --durations=0
|
||||
|
||||
- name: Unit tests with Pytest (Ubuntu and Windows)
|
||||
if: ${{ matrix.os != 'macos-13' && matrix.os != 'macos-14' && matrix.os != 'macos-15' }}
|
||||
uses: nick-fields/retry@v2
|
||||
with:
|
||||
timeout_minutes: 60
|
||||
|
||||
43
.github/workflows/test_qlib_from_source_slow.yml
vendored
43
.github/workflows/test_qlib_from_source_slow.yml
vendored
@@ -1,5 +1,9 @@
|
||||
name: Test qlib from source slow
|
||||
|
||||
concurrency:
|
||||
cancel-in-progress: true
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
@@ -14,52 +18,37 @@ jobs:
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
# Since macos-latest changed from 12.7.4 to 14.4.1,
|
||||
# the minimum python version that matches a 14.4.1 version of macos is 3.10,
|
||||
# so we limit the macos version to macos-12.
|
||||
os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-11, macos-12]
|
||||
# not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129
|
||||
python-version: [3.7, 3.8]
|
||||
os: [windows-latest, ubuntu-24.04, ubuntu-22.04, macos-14, macos-15]
|
||||
# In github action, using python 3.7, pip install will not match the latest version of the package.
|
||||
# Also, python 3.7 is no longer supported from macos-14, and will be phased out from macos-13 in the near future.
|
||||
# All things considered, we have removed python 3.7.
|
||||
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
|
||||
|
||||
steps:
|
||||
- name: Test qlib from source slow
|
||||
uses: actions/checkout@v3
|
||||
|
||||
# Since version 3.7 of python for MacOS is installed in CI, version 3.7.17, this version causes "_bz not found error".
|
||||
# So we make the version number of python 3.7 for MacOS more specific.
|
||||
# refs: https://github.com/actions/setup-python/issues/682
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
if: (matrix.os == 'macos-latest' && matrix.python-version == '3.7') || (matrix.os == 'macos-11' && matrix.python-version == '3.7')
|
||||
uses: actions/setup-python@v4
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
python-version: "3.7.16"
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
if: (matrix.os != 'macos-latest' || matrix.python-version != '3.7') && (matrix.os != 'macos-11' || matrix.python-version != '3.7')
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Set up Python tools
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install --upgrade cython numpy
|
||||
pip install -e .[dev]
|
||||
make dev
|
||||
|
||||
- name: Downloads dependencies data
|
||||
run: |
|
||||
python scripts/get_data.py qlib_data --name qlib_data_simple --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
|
||||
|
||||
- name: Install Lightgbm for MacOS
|
||||
if: ${{ matrix.os == 'macos-11' || matrix.os == 'macos-latest' }}
|
||||
if: ${{ matrix.os == 'macos-14' || matrix.os == 'macos-15' }}
|
||||
run: |
|
||||
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Microsoft/qlib/main/.github/brew_install.sh)"
|
||||
HOMEBREW_NO_AUTO_UPDATE=1 brew install lightgbm
|
||||
# FIX MacOS error: Segmentation fault
|
||||
# reference: https://github.com/microsoft/LightGBM/issues/4229
|
||||
wget https://raw.githubusercontent.com/Homebrew/homebrew-core/fb8323f2b170bd4ae97e1bac9bf3e2983af3fdb0/Formula/libomp.rb
|
||||
brew unlink libomp
|
||||
brew install libomp.rb
|
||||
brew update
|
||||
brew install libomp || brew reinstall libomp
|
||||
python -m pip install --no-binary=:all: lightgbm
|
||||
|
||||
- name: Unit tests with Pytest
|
||||
uses: nick-fields/retry@v2
|
||||
|
||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -22,6 +22,7 @@ dist/
|
||||
qlib/VERSION.txt
|
||||
qlib/data/_libs/expanding.cpp
|
||||
qlib/data/_libs/rolling.cpp
|
||||
qlib/_version.py
|
||||
examples/estimator/estimator_example/
|
||||
examples/rl/data/
|
||||
examples/rl/checkpoints/
|
||||
@@ -48,4 +49,5 @@ tags
|
||||
*.swp
|
||||
|
||||
./pretrain
|
||||
.idea/
|
||||
.idea/
|
||||
.aider*
|
||||
|
||||
@@ -9,7 +9,7 @@ version: 2
|
||||
build:
|
||||
os: ubuntu-22.04
|
||||
tools:
|
||||
python: "3.7"
|
||||
python: "3.8"
|
||||
|
||||
# Build documentation in the docs/ directory with Sphinx
|
||||
sphinx:
|
||||
|
||||
31
Dockerfile
Normal file
31
Dockerfile
Normal file
@@ -0,0 +1,31 @@
|
||||
FROM continuumio/miniconda3:latest
|
||||
|
||||
WORKDIR /qlib
|
||||
|
||||
COPY . .
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y build-essential
|
||||
|
||||
RUN conda create --name qlib_env python=3.8 -y
|
||||
RUN echo "conda activate qlib_env" >> ~/.bashrc
|
||||
ENV PATH /opt/conda/envs/qlib_env/bin:$PATH
|
||||
|
||||
RUN python -m pip install --upgrade pip
|
||||
|
||||
RUN python -m pip install numpy==1.23.5
|
||||
RUN python -m pip install pandas==1.5.3
|
||||
RUN python -m pip install importlib-metadata==5.2.0
|
||||
RUN python -m pip install "cloudpickle<3"
|
||||
RUN python -m pip install scikit-learn==1.3.2
|
||||
|
||||
RUN python -m pip install cython packaging tables matplotlib statsmodels
|
||||
RUN python -m pip install pybind11 cvxpy
|
||||
|
||||
ARG IS_STABLE="yes"
|
||||
|
||||
RUN if [ "$IS_STABLE" = "yes" ]; then \
|
||||
python -m pip install pyqlib; \
|
||||
else \
|
||||
python setup.py install; \
|
||||
fi
|
||||
@@ -1 +1,6 @@
|
||||
include qlib/VERSION.txt
|
||||
exclude tests/*
|
||||
include qlib/*
|
||||
include qlib/*/*
|
||||
include qlib/*/*/*
|
||||
include qlib/*/*/*/*
|
||||
include qlib/*/*/*/*/*
|
||||
|
||||
212
Makefile
Normal file
212
Makefile
Normal file
@@ -0,0 +1,212 @@
|
||||
.PHONY: clean deepclean prerequisite dependencies lightgbm rl develop lint docs package test analysis all install dev black pylint flake8 mypy nbqa nbconvert lint build upload docs-gen
|
||||
#You can modify it according to your terminal
|
||||
SHELL := /bin/bash
|
||||
|
||||
########################################################################################
|
||||
# Variables
|
||||
########################################################################################
|
||||
|
||||
# Documentation target directory, will be adapted to specific folder for readthedocs.
|
||||
PUBLIC_DIR := $(shell [ "$$READTHEDOCS" = "True" ] && echo "$$READTHEDOCS_OUTPUT/html" || echo "public")
|
||||
|
||||
SO_DIR := qlib/data/_libs
|
||||
SO_FILES := $(wildcard $(SO_DIR)/*.so)
|
||||
|
||||
ifeq ($(OS),Windows_NT)
|
||||
IS_WINDOWS = true
|
||||
else
|
||||
IS_WINDOWS = false
|
||||
endif
|
||||
|
||||
########################################################################################
|
||||
# Development Environment Management
|
||||
########################################################################################
|
||||
# Remove common intermediate files.
|
||||
clean:
|
||||
-rm -rf \
|
||||
$(PUBLIC_DIR) \
|
||||
qlib/data/_libs/*.cpp \
|
||||
qlib/data/_libs/*.so \
|
||||
mlruns \
|
||||
public \
|
||||
build \
|
||||
.coverage \
|
||||
.mypy_cache \
|
||||
.pytest_cache \
|
||||
.ruff_cache \
|
||||
Pipfile* \
|
||||
coverage.xml \
|
||||
dist \
|
||||
release-notes.md
|
||||
|
||||
find . -name '*.egg-info' -print0 | xargs -0 rm -rf
|
||||
find . -name '*.pyc' -print0 | xargs -0 rm -f
|
||||
find . -name '*.swp' -print0 | xargs -0 rm -f
|
||||
find . -name '.DS_Store' -print0 | xargs -0 rm -f
|
||||
find . -name '__pycache__' -print0 | xargs -0 rm -rf
|
||||
|
||||
# Remove pre-commit hook, virtual environment alongside itermediate files.
|
||||
deepclean: clean
|
||||
if command -v pre-commit > /dev/null 2>&1; then pre-commit uninstall --hook-type pre-push; fi
|
||||
if command -v pipenv >/dev/null 2>&1 && pipenv --venv >/dev/null 2>&1; then pipenv --rm; fi
|
||||
|
||||
# Prerequisite section
|
||||
# What this code does is compile two Cython modules, rolling and expanding, using setuptools and Cython,
|
||||
# and builds them as binary expansion modules that can be imported directly into Python.
|
||||
# Since pyproject.toml can't do that, we compile it here.
|
||||
|
||||
# pywinpty as a dependency of jupyter on windows, if you use pip install pywinpty installation,
|
||||
# will first download the tar.gz file, and then locally compiled and installed,
|
||||
# this will lead to some unnecessary trouble, so we choose to install the compiled whl file, to avoid trouble.
|
||||
prerequisite:
|
||||
@if [ -n "$(SO_FILES)" ]; then \
|
||||
echo "Shared library files exist, skipping build."; \
|
||||
else \
|
||||
echo "No shared library files found, building..."; \
|
||||
pip install --upgrade setuptools wheel; \
|
||||
python -m pip install cython numpy; \
|
||||
python -c "from setuptools import setup, Extension; from Cython.Build import cythonize; import numpy; extensions = [Extension('qlib.data._libs.rolling', ['qlib/data/_libs/rolling.pyx'], language='c++', include_dirs=[numpy.get_include()]), Extension('qlib.data._libs.expanding', ['qlib/data/_libs/expanding.pyx'], language='c++', include_dirs=[numpy.get_include()])]; setup(ext_modules=cythonize(extensions, language_level='3'), script_args=['build_ext', '--inplace'])"; \
|
||||
fi
|
||||
|
||||
@if [ "$(IS_WINDOWS)" = "true" ]; then \
|
||||
python -m pip install pywinpty --only-binary=:all:; \
|
||||
fi
|
||||
|
||||
# Install the package in editable mode.
|
||||
dependencies:
|
||||
python -m pip install --no-cache-dir -e .
|
||||
|
||||
lightgbm:
|
||||
python -m pip install --no-cache-dir lightgbm --prefer-binary
|
||||
|
||||
rl:
|
||||
python -m pip install --no-cache-dir -e .[rl]
|
||||
|
||||
develop:
|
||||
python -m pip install --no-cache-dir -e .[dev]
|
||||
|
||||
lint:
|
||||
python -m pip install --no-cache-dir -e .[lint]
|
||||
|
||||
docs:
|
||||
python -m pip install --no-cache-dir -e .[docs]
|
||||
|
||||
package:
|
||||
python -m pip install --no-cache-dir -e .[package]
|
||||
|
||||
test:
|
||||
python -m pip install --no-cache-dir -e .[test]
|
||||
|
||||
analysis:
|
||||
python -m pip install --no-cache-dir -e .[analysis]
|
||||
|
||||
client:
|
||||
python -m pip install --no-cache-dir -e .[client]
|
||||
|
||||
all:
|
||||
python -m pip install --no-cache-dir -e .[pywinpty,dev,lint,docs,package,test,analysis,rl]
|
||||
|
||||
install: prerequisite dependencies
|
||||
|
||||
dev: prerequisite all
|
||||
|
||||
########################################################################################
|
||||
# Lint and pre-commit
|
||||
########################################################################################
|
||||
|
||||
# Check lint with black.
|
||||
black:
|
||||
black . -l 120 --check --diff --exclude qlib/_version.py
|
||||
|
||||
# Check code folder with pylint.
|
||||
# TODO: These problems we will solve in the future. Important among them are: W0221, W0223, W0237, E1102
|
||||
# C0103: invalid-name
|
||||
# C0209: consider-using-f-string
|
||||
# R0402: consider-using-from-import
|
||||
# R1705: no-else-return
|
||||
# R1710: inconsistent-return-statements
|
||||
# R1725: super-with-arguments
|
||||
# R1735: use-dict-literal
|
||||
# W0102: dangerous-default-value
|
||||
# W0212: protected-access
|
||||
# W0221: arguments-differ
|
||||
# W0223: abstract-method
|
||||
# W0231: super-init-not-called
|
||||
# W0237: arguments-renamed
|
||||
# W0612: unused-variable
|
||||
# W0621: redefined-outer-name
|
||||
# W0622: redefined-builtin
|
||||
# FIXME: specify exception type
|
||||
# W0703: broad-except
|
||||
# W1309: f-string-without-interpolation
|
||||
# E1102: not-callable
|
||||
# E1136: unsubscriptable-object
|
||||
# W4904: deprecated-class
|
||||
# R0917: too-many-positional-arguments
|
||||
# E1123: unexpected-keyword-arg
|
||||
# References for disable error: https://pylint.pycqa.org/en/latest/user_guide/messages/messages_overview.html
|
||||
# We use sys.setrecursionlimit(2000) to make the recursion depth larger to ensure that pylint works properly (the default recursion depth is 1000).
|
||||
# References for parameters: https://github.com/PyCQA/pylint/issues/4577#issuecomment-1000245962
|
||||
pylint:
|
||||
pylint --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R0917,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,W4904,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1730,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0612,W0621,W0622,W0703,W1309,E1102,E1136 --const-rgx='[a-z_][a-z0-9_]{2,30}' qlib --init-hook="import astroid; astroid.context.InferenceContext.max_inferred = 500; import sys; sys.setrecursionlimit(2000)"
|
||||
pylint --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R0917,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,E1123,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0246,W0612,W0621,W0622,W0703,W1309,E1102,E1136 --const-rgx='[a-z_][a-z0-9_]{2,30}' scripts --init-hook="import astroid; astroid.context.InferenceContext.max_inferred = 500; import sys; sys.setrecursionlimit(2000)"
|
||||
|
||||
# Check code with flake8.
|
||||
# The following flake8 error codes were ignored:
|
||||
# E501 line too long
|
||||
# Description: We have used black to limit the length of each line to 120.
|
||||
# F541 f-string is missing placeholders
|
||||
# Description: The same thing is done when using pylint for detection.
|
||||
# E266 too many leading '#' for block comment
|
||||
# Description: To make the code more readable, a lot of "#" is used.
|
||||
# This error code appears centrally in:
|
||||
# qlib/backtest/executor.py
|
||||
# qlib/data/ops.py
|
||||
# qlib/utils/__init__.py
|
||||
# E402 module level import not at top of file
|
||||
# Description: There are times when module level import is not available at the top of the file.
|
||||
# W503 line break before binary operator
|
||||
# Description: Since black formats the length of each line of code, it has to perform a line break when a line of arithmetic is too long.
|
||||
# E731 do not assign a lambda expression, use a def
|
||||
# Description: Restricts the use of lambda expressions, but at some point lambda expressions are required.
|
||||
# E203 whitespace before ':'
|
||||
# Description: If there is whitespace before ":", it cannot pass the black check.
|
||||
flake8:
|
||||
flake8 --ignore=E501,F541,E266,E402,W503,E731,E203 --per-file-ignores="__init__.py:F401,F403" qlib
|
||||
|
||||
# Check code with mypy.
|
||||
# https://github.com/python/mypy/issues/10600
|
||||
mypy:
|
||||
mypy qlib --install-types --non-interactive
|
||||
mypy qlib --verbose
|
||||
|
||||
# Check ipynb with nbqa.
|
||||
nbqa:
|
||||
nbqa black . -l 120 --check --diff
|
||||
nbqa pylint . --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0612,W0621,W0622,W0703,W1309,E1102,E1136,W0719,W0104,W0404,C0412,W0611,C0410 --const-rgx='[a-z_][a-z0-9_]{2,30}'
|
||||
|
||||
# Check ipynb with nbconvert.(Run after data downloads)
|
||||
# TODO: Add more ipynb files in future
|
||||
nbconvert:
|
||||
jupyter nbconvert --to notebook --execute examples/workflow_by_code.ipynb
|
||||
|
||||
lint: black pylint flake8 mypy nbqa
|
||||
|
||||
########################################################################################
|
||||
# Package
|
||||
########################################################################################
|
||||
|
||||
# Build the package.
|
||||
build:
|
||||
python -m build --wheel
|
||||
|
||||
# Upload the package.
|
||||
upload:
|
||||
python -m twine upload dist/*
|
||||
|
||||
########################################################################################
|
||||
# Documentation
|
||||
########################################################################################
|
||||
|
||||
docs-gen:
|
||||
python -m sphinx.cmd.build -W docs $(PUBLIC_DIR)
|
||||
148
README.md
148
README.md
@@ -8,9 +8,44 @@
|
||||
[](https://gitter.im/Microsoft/qlib?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
|
||||
|
||||
## :newspaper: **What's NEW!** :sparkling_heart:
|
||||
|
||||
Recent released features
|
||||
|
||||
### Introducing <a href="https://github.com/microsoft/RD-Agent"><img src="docs/_static/img/rdagent_logo.png" alt="RD_Agent" style="height: 2em"></a>: LLM-Based Autonomous Evolving Agents for Industrial Data-Driven R&D
|
||||
|
||||
We are excited to announce the release of **RD-Agent**📢, a powerful tool that supports automated factor mining and model optimization in quant investment R&D.
|
||||
|
||||
RD-Agent is now available on [GitHub](https://github.com/microsoft/RD-Agent), and we welcome your star🌟!
|
||||
|
||||
To learn more, please visit the [RD-Agent repository](https://github.com/microsoft/RD-Agent). We have prepared several public demo videos for you:
|
||||
|
||||
| Scenario | Demo video (English) | Demo video (中文) |
|
||||
| -- | ------ | ------ |
|
||||
| Quant Factor Mining | [YouTube](https://www.youtube.com/watch?v=X4DK2QZKaKY&t=6s) | [YouTube](https://www.youtube.com/watch?v=X4DK2QZKaKY&t=6s) |
|
||||
| Quant Factor Mining from reports | [YouTube](https://www.youtube.com/watch?v=ECLTXVcSx-c) | [YouTube](https://www.youtube.com/watch?v=ECLTXVcSx-c) |
|
||||
| Quant Model Optimization | [YouTube](https://www.youtube.com/watch?v=dm0dWL49Bc0&t=104s) | [YouTube](https://www.youtube.com/watch?v=dm0dWL49Bc0&t=104s) |
|
||||
|
||||
- 📃**Paper**: [R&D-Agent-Quant: A Multi-Agent Framework for Data-Centric Factors and Model Joint Optimization](https://arxiv.org/abs/2505.15155)
|
||||
- 👾**Code**: https://github.com/microsoft/RD-Agent/
|
||||
```BibTeX
|
||||
@misc{li2025rdagentquant,
|
||||
title={R\&D-Agent-Quant: A Multi-Agent Framework for Data-Centric Factors and Model Joint Optimization},
|
||||
author={Yuante Li and Xu Yang and Xiao Yang and Minrui Xu and Xisen Wang and Weiqing Liu and Jiang Bian},
|
||||
year={2025},
|
||||
eprint={2505.15155},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={cs.AI}
|
||||
}
|
||||
```
|
||||

|
||||
|
||||
***
|
||||
|
||||
| Feature | Status |
|
||||
| -- | ------ |
|
||||
| [R&D-Agent-Quant](https://arxiv.org/abs/2505.15155) Published | Apply R&D-Agent to Qlib for quant trading |
|
||||
| BPQP for End-to-end learning | 📈Coming soon!([Under review](https://github.com/microsoft/qlib/pull/1863)) |
|
||||
| 🔥LLM-driven Auto Quant Factory🔥 | 🚀 Released in [♾️RD-Agent](https://github.com/microsoft/RD-Agent) on Aug 8, 2024 |
|
||||
| KRNN and Sandwich models | :chart_with_upwards_trend: [Released](https://github.com/microsoft/qlib/pull/1414/) on May 26, 2023 |
|
||||
| Release Qlib v0.9.0 | :octocat: [Released](https://github.com/microsoft/qlib/releases/tag/v0.9.0) on Dec 9, 2022 |
|
||||
| RL Learning Framework | :hammer: :chart_with_upwards_trend: Released on Nov 10, 2022. [#1332](https://github.com/microsoft/qlib/pull/1332), [#1322](https://github.com/microsoft/qlib/pull/1322), [#1316](https://github.com/microsoft/qlib/pull/1316),[#1299](https://github.com/microsoft/qlib/pull/1299),[#1263](https://github.com/microsoft/qlib/pull/1263), [#1244](https://github.com/microsoft/qlib/pull/1244), [#1169](https://github.com/microsoft/qlib/pull/1169), [#1125](https://github.com/microsoft/qlib/pull/1125), [#1076](https://github.com/microsoft/qlib/pull/1076)|
|
||||
@@ -40,7 +75,7 @@ Recent released features
|
||||
Features released before 2021 are not listed here.
|
||||
|
||||
<p align="center">
|
||||
<img src="http://fintech.msra.cn/images_v070/logo/1.png" />
|
||||
<img src="docs/_static/img/logo/1.png" />
|
||||
</p>
|
||||
|
||||
Qlib is an open-source, AI-oriented quantitative investment platform that aims to realize the potential, empower research, and create value using AI technologies in quantitative investment, from exploring ideas to implementing productions. Qlib supports diverse machine learning modeling paradigms, including supervised learning, market dynamics modeling, and reinforcement learning.
|
||||
@@ -132,17 +167,17 @@ Here is a quick **[demo](https://terminalizer.com/view/3f24561a4470)** shows how
|
||||
## Installation
|
||||
|
||||
This table demonstrates the supported Python version of `Qlib`:
|
||||
| | install with pip | install from source | plot |
|
||||
| ------------- |:---------------------:|:--------------------:|:----:|
|
||||
| Python 3.7 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
|
||||
| | install with pip | install from source | plot |
|
||||
| ------------- |:---------------------:|:--------------------:|:------------------:|
|
||||
| Python 3.8 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
|
||||
| Python 3.9 | :x: | :heavy_check_mark: | :x: |
|
||||
| Python 3.9 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
|
||||
| Python 3.10 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
|
||||
| Python 3.11 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
|
||||
| Python 3.12 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
|
||||
|
||||
**Note**:
|
||||
1. **Conda** is suggested for managing your Python environment. In some cases, using Python outside of a `conda` environment may result in missing header files, causing the installation failure of certain packages.
|
||||
1. Please pay attention that installing cython in Python 3.6 will raise some error when installing ``Qlib`` from source. If users use Python 3.6 on their machines, it is recommended to *upgrade* Python to version 3.7 or use `conda`'s Python to install ``Qlib`` from source.
|
||||
1. For Python 3.9, `Qlib` supports running workflows such as training models, doing backtest and plot most of the related figures (those included in [notebook](examples/workflow_by_code.ipynb)). However, plotting for the *model performance* is not supported for now and we will fix this when the dependent packages are upgraded in the future.
|
||||
1. `Qlib`Requires `tables` package, `hdf5` in tables does not support python3.9.
|
||||
2. Please pay attention that installing cython in Python 3.6 will raise some error when installing ``Qlib`` from source. If users use Python 3.6 on their machines, it is recommended to *upgrade* Python to version 3.8 or higher, or use `conda`'s Python to install ``Qlib`` from source.
|
||||
|
||||
### Install with pip
|
||||
Users can easily install ``Qlib`` by pip according to the following command.
|
||||
@@ -160,30 +195,43 @@ Also, users can install the latest dev version ``Qlib`` by the source code accor
|
||||
|
||||
```bash
|
||||
pip install numpy
|
||||
pip install --upgrade cython
|
||||
pip install --upgrade cython
|
||||
```
|
||||
|
||||
* Clone the repository and install ``Qlib`` as follows.
|
||||
```bash
|
||||
git clone https://github.com/microsoft/qlib.git && cd qlib
|
||||
pip install .
|
||||
pip install . # `pip install -e .[dev]` is recommended for development. check details in docs/developer/code_standard_and_dev_guide.rst
|
||||
```
|
||||
**Note**: You can install Qlib with `python setup.py install` as well. But it is not the recommended approach. It will skip `pip` and cause obscure problems. For example, **only** the command ``pip install .`` **can** overwrite the stable version installed by ``pip install pyqlib``, while the command ``python setup.py install`` **can't**.
|
||||
|
||||
**Tips**: If you fail to install `Qlib` or run the examples in your environment, comparing your steps and the [CI workflow](.github/workflows/test_qlib_from_source.yml) may help you find the problem.
|
||||
|
||||
**Tips for Mac**: If you are using Mac with M1, you might encounter issues in building the wheel for LightGBM, which is due to missing dependencies from OpenMP. To solve the problem, install openmp first with ``brew install libomp`` and then run ``pip install .`` to build it successfully.
|
||||
|
||||
## Data Preparation
|
||||
❗ Due to more restrict data security policy. The official dataset is disabled temporarily. You can try [this data source](https://github.com/chenditc/investment_data/releases) contributed by the community.
|
||||
Here is an example to download the latest data.
|
||||
```bash
|
||||
wget https://github.com/chenditc/investment_data/releases/latest/download/qlib_bin.tar.gz
|
||||
mkdir -p ~/.qlib/qlib_data/cn_data
|
||||
tar -zxvf qlib_bin.tar.gz -C ~/.qlib/qlib_data/cn_data --strip-components=1
|
||||
rm -f qlib_bin.tar.gz
|
||||
```
|
||||
|
||||
The official dataset below will resume in short future.
|
||||
|
||||
|
||||
----
|
||||
|
||||
Load and prepare data by running the following code:
|
||||
|
||||
### Get with module
|
||||
```bash
|
||||
# get 1d data
|
||||
python -m qlib.run.get_data qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn
|
||||
python -m qlib.cli.data qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn
|
||||
|
||||
# get 1min data
|
||||
python -m qlib.run.get_data qlib_data --target_dir ~/.qlib/qlib_data/cn_data_1min --region cn --interval 1min
|
||||
python -m qlib.cli.data qlib_data --target_dir ~/.qlib/qlib_data/cn_data_1min --region cn --interval 1min
|
||||
|
||||
```
|
||||
|
||||
@@ -230,6 +278,16 @@ We recommend users to prepare their own data if they have a high-quality dataset
|
||||
* *trading_date*: start of trading day
|
||||
* *end_date*: end of trading day(not included)
|
||||
|
||||
### Checking the health of the data
|
||||
* We provide a script to check the health of the data, you can run the following commands to check whether the data is healthy or not.
|
||||
```
|
||||
python scripts/check_data_health.py check_data --qlib_dir ~/.qlib/qlib_data/cn_data
|
||||
```
|
||||
* Of course, you can also add some parameters to adjust the test results, such as this.
|
||||
```
|
||||
python scripts/check_data_health.py check_data --qlib_dir ~/.qlib/qlib_data/cn_data --missing_data_num 30055 --large_step_threshold_volume 94485 --large_step_threshold_price 20
|
||||
```
|
||||
* If you want more information about `check_data_health`, please refer to the [documentation](https://qlib.readthedocs.io/en/latest/component/data.html#checking-the-health-of-the-data).
|
||||
|
||||
<!--
|
||||
- Run the initialization code and get stock data:
|
||||
@@ -258,6 +316,38 @@ We recommend users to prepare their own data if they have a high-quality dataset
|
||||
```
|
||||
-->
|
||||
|
||||
## Docker images
|
||||
1. Pulling a docker image from a docker hub repository
|
||||
```bash
|
||||
docker pull pyqlib/qlib_image_stable:stable
|
||||
```
|
||||
2. Start a new Docker container
|
||||
```bash
|
||||
docker run -it --name <container name> -v <Mounted local directory>:/app pyqlib/qlib_image_stable:stable
|
||||
```
|
||||
3. At this point you are in the docker environment and can run the qlib scripts. An example:
|
||||
```bash
|
||||
>>> python scripts/get_data.py qlib_data --name qlib_data_simple --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
|
||||
>>> python qlib/cli/run.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
```
|
||||
4. Exit the container
|
||||
```bash
|
||||
>>> exit
|
||||
```
|
||||
5. Restart the container
|
||||
```bash
|
||||
docker start -i -a <container name>
|
||||
```
|
||||
6. Stop the container
|
||||
```bash
|
||||
docker stop <container name>
|
||||
```
|
||||
7. Delete the container
|
||||
```bash
|
||||
docker rm <container name>
|
||||
```
|
||||
8. If you want to know more information, please refer to the [documentation](https://qlib.readthedocs.io/en/latest/developer/how_to_build_image.html).
|
||||
|
||||
## Auto Quant Research Workflow
|
||||
Qlib provides a tool named `qrun` to run the whole workflow automatically (including building dataset, training models, backtest and evaluation). You can start an auto quant research workflow and have a graphical reports analysis according to the following steps:
|
||||
|
||||
@@ -268,9 +358,9 @@ Qlib provides a tool named `qrun` to run the whole workflow automatically (inclu
|
||||
```
|
||||
If users want to use `qrun` under debug mode, please use the following command:
|
||||
```bash
|
||||
python -m pdb qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
python -m pdb qlib/cli/run.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
```
|
||||
The result of `qrun` is as follows, please refer to [Intraday Trading](https://qlib.readthedocs.io/en/latest/component/backtest.html) for more details about the result.
|
||||
The result of `qrun` is as follows, please refer to [docs](https://qlib.readthedocs.io/en/latest/component/strategy.html#result) for more explanations about the result.
|
||||
|
||||
```bash
|
||||
|
||||
@@ -291,22 +381,22 @@ Qlib provides a tool named `qrun` to run the whole workflow automatically (inclu
|
||||
```
|
||||
Here are detailed documents for `qrun` and [workflow](https://qlib.readthedocs.io/en/latest/component/workflow.html).
|
||||
|
||||
2. Graphical Reports Analysis: Run `examples/workflow_by_code.ipynb` with `jupyter notebook` to get graphical reports
|
||||
2. Graphical Reports Analysis: First, run `python -m pip install .[analysis]` to install the required dependencies. Then run `examples/workflow_by_code.ipynb` with `jupyter notebook` to get graphical reports.
|
||||
- Forecasting signal (model prediction) analysis
|
||||
- Cumulative Return of groups
|
||||

|
||||

|
||||
- Return distribution
|
||||

|
||||

|
||||
- Information Coefficient (IC)
|
||||

|
||||

|
||||

|
||||

|
||||

|
||||

|
||||
- Auto Correlation of forecasting signal (model prediction)
|
||||

|
||||

|
||||
|
||||
- Portfolio analysis
|
||||
- Backtest return
|
||||

|
||||

|
||||
<!--
|
||||
- Score IC
|
||||

|
||||
@@ -386,6 +476,14 @@ python run_all_model.py run 10
|
||||
|
||||
It also provides the API to run specific models at once. For more use cases, please refer to the file's [docstrings](examples/run_all_model.py).
|
||||
|
||||
### Break change
|
||||
In `pandas`, `group_key` is one of the parameters of the `groupby` method. From version 1.5 to 2.0 of `pandas`, the default value of `group_key` has been changed from `no default` to `True`, which will cause qlib to report an error during operation. So we set `group_key=False`, but it doesn't guarantee that some programmes will run correctly, including:
|
||||
* qlib\examples\rl_order_execution\scripts\gen_training_orders.py
|
||||
* qlib\examples\benchmarks\TRA\src\dataset.MTSDatasetH.py
|
||||
* qlib\examples\benchmarks\TFT\tft.py
|
||||
|
||||
|
||||
|
||||
## [Adapting to Market Dynamics](examples/benchmarks_dynamic)
|
||||
|
||||
Due to the non-stationary nature of the environment of the financial market, the data distribution may change in different periods, which makes the performance of models build on training data decays in the future test data.
|
||||
@@ -485,7 +583,7 @@ Qlib data are stored in a compact format, which is efficient to be combined into
|
||||
Join IM discussion groups:
|
||||
|[Gitter](https://gitter.im/Microsoft/qlib)|
|
||||
|----|
|
||||
||
|
||||
||
|
||||
|
||||
# Contributing
|
||||
We appreciate all contributions and thank all the contributors!
|
||||
@@ -521,7 +619,7 @@ You can find some impefect implementation in Qlib by `rg 'TODO|FIXME' qlib`
|
||||
|
||||
If you would like to become one of Qlib's maintainers to contribute more (e.g. help merge PR, triage issues), please contact us by email([qlib@microsoft.com](mailto:qlib@microsoft.com)). We are glad to help to upgrade your permission.
|
||||
|
||||
## Licence
|
||||
## License
|
||||
Most contributions require you to agree to a
|
||||
Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
|
||||
the right to use your contribution. For details, visit https://cla.opensource.microsoft.com.
|
||||
|
||||
31
build_docker_image.sh
Normal file
31
build_docker_image.sh
Normal file
@@ -0,0 +1,31 @@
|
||||
#!/bin/bash
|
||||
|
||||
docker_user="your_dockerhub_username"
|
||||
|
||||
read -p "Do you want to build the nightly version of the qlib image? (default is stable) (yes/no): " answer;
|
||||
answer=$(echo "$answer" | tr '[:upper:]' '[:lower:]')
|
||||
|
||||
if [ "$answer" = "yes" ]; then
|
||||
# Build the nightly version of the qlib image
|
||||
docker build --build-arg IS_STABLE=no -t qlib_image -f ./Dockerfile .
|
||||
image_tag="nightly"
|
||||
else
|
||||
# Build the stable version of the qlib image
|
||||
docker build -t qlib_image -f ./Dockerfile .
|
||||
image_tag="stable"
|
||||
fi
|
||||
|
||||
read -p "Is it uploaded to docker hub? (default is no) (yes/no): " answer;
|
||||
answer=$(echo "$answer" | tr '[:upper:]' '[:lower:]')
|
||||
|
||||
if [ "$answer" = "yes" ]; then
|
||||
# Log in to Docker Hub
|
||||
# If you are a new docker hub user, please verify your email address before proceeding with this step.
|
||||
docker login
|
||||
# Tag the Docker image
|
||||
docker tag qlib_image "$docker_user/qlib_image:$image_tag"
|
||||
# Push the Docker image to Docker Hub
|
||||
docker push "$docker_user/qlib_image:$image_tag"
|
||||
else
|
||||
echo "Not uploaded to docker hub."
|
||||
fi
|
||||
BIN
docs/_static/img/rdagent_logo.png
vendored
Normal file
BIN
docs/_static/img/rdagent_logo.png
vendored
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 94 KiB |
@@ -42,7 +42,7 @@ Example
|
||||
|
||||
.. math::
|
||||
|
||||
DEA = \frac{EMA(DIF, 9)}{CLOSE}
|
||||
DEA = EMA(DIF, 9)
|
||||
|
||||
Users can use ``Data Handler`` to build formulaic alphas `MACD` in qlib:
|
||||
|
||||
@@ -51,7 +51,7 @@ Users can use ``Data Handler`` to build formulaic alphas `MACD` in qlib:
|
||||
.. code-block:: python
|
||||
|
||||
>> from qlib.data.dataset.loader import QlibDataLoader
|
||||
>> MACD_EXP = '(EMA($close, 12) - EMA($close, 26))/$close - EMA((EMA($close, 12) - EMA($close, 26))/$close, 9)/$close'
|
||||
>> MACD_EXP = '2 * ((EMA($close, 12) - EMA($close, 26))/$close - EMA((EMA($close, 12) - EMA($close, 26))/$close, 9))'
|
||||
>> fields = [MACD_EXP] # MACD
|
||||
>> names = ['MACD']
|
||||
>> labels = ['Ref($close, -2)/Ref($close, -1) - 1'] # label
|
||||
@@ -66,17 +66,17 @@ Users can use ``Data Handler`` to build formulaic alphas `MACD` in qlib:
|
||||
feature label
|
||||
MACD LABEL
|
||||
datetime instrument
|
||||
2010-01-04 SH600000 -0.011547 -0.019672
|
||||
SH600004 0.002745 -0.014721
|
||||
SH600006 0.010133 0.002911
|
||||
SH600008 -0.001113 0.009818
|
||||
SH600009 0.025878 -0.017758
|
||||
2010-01-04 SH600000 0.008781 -0.019672
|
||||
SH600004 0.006699 -0.014721
|
||||
SH600006 0.005714 0.002911
|
||||
SH600008 0.000798 0.009818
|
||||
SH600009 0.017015 -0.017758
|
||||
... ... ...
|
||||
2017-12-29 SZ300124 0.007306 -0.005074
|
||||
SZ300136 -0.013492 0.056352
|
||||
SZ300144 -0.000966 0.011853
|
||||
SZ300251 0.004383 0.021739
|
||||
SZ300315 -0.030557 0.012455
|
||||
2017-12-29 SZ300124 0.015071 -0.005074
|
||||
SZ300136 -0.015466 0.056352
|
||||
SZ300144 0.013082 0.011853
|
||||
SZ300251 -0.001026 0.021739
|
||||
SZ300315 -0.007559 0.012455
|
||||
|
||||
Reference
|
||||
=========
|
||||
|
||||
@@ -108,10 +108,10 @@ Automatic update of daily frequency data
|
||||
|
||||
|
||||
|
||||
Converting CSV Format into Qlib Format
|
||||
--------------------------------------
|
||||
Converting CSV and Parquet Format into Qlib Format
|
||||
--------------------------------------------------
|
||||
|
||||
``Qlib`` has provided the script ``scripts/dump_bin.py`` to convert **any** data in CSV format into `.bin` files (``Qlib`` format) as long as they are in the correct format.
|
||||
``Qlib`` has provided the script ``scripts/dump_bin.py`` to convert **any** data in CSV or Parquet format into `.bin` files (``Qlib`` format) as long as they are in the correct format.
|
||||
|
||||
Besides downloading the prepared demo data, users could download demo data directly from the Collector as follows for reference to the CSV format.
|
||||
Here are some example:
|
||||
@@ -126,17 +126,17 @@ for 1min data:
|
||||
|
||||
python scripts/data_collector/yahoo/collector.py download_data --source_dir ~/.qlib/stock_data/source/cn_1min --region CN --start 2021-05-20 --end 2021-05-23 --delay 0.1 --interval 1min --limit_nums 10
|
||||
|
||||
Users can also provide their own data in CSV format. However, the CSV data **must satisfies** following criterions:
|
||||
Users can also provide their own data in CSV or Parquet format. However, the data **must satisfies** following criterions:
|
||||
|
||||
- CSV file is named after a specific stock *or* the CSV file includes a column of the stock name
|
||||
- CSV or Parquet file is named after a specific stock *or* the CSV or Parquet file includes a column of the stock name
|
||||
|
||||
- Name the CSV file after a stock: `SH600000.csv`, `AAPL.csv` (not case sensitive).
|
||||
- Name the CSV or Parquet file after a stock: `SH600000.csv`, `AAPL.csv` or `SH600000.parquet`, `AAPL.parquet` (not case sensitive).
|
||||
|
||||
- CSV file includes a column of the stock name. User **must** specify the column name when dumping the data. Here is an example:
|
||||
- CSV or Parquet file includes a column of the stock name. User **must** specify the column name when dumping the data. Here is an example:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python scripts/dump_bin.py dump_all ... --symbol_field_name symbol
|
||||
python scripts/dump_bin.py dump_all ... --symbol_field_name symbol --file_suffix <.csv or .parquet>
|
||||
|
||||
where the data are in the following format:
|
||||
|
||||
@@ -146,11 +146,11 @@ Users can also provide their own data in CSV format. However, the CSV data **mus
|
||||
| SH600000 | 120 |
|
||||
+-----------+-------+
|
||||
|
||||
- CSV file **must** include a column for the date, and when dumping the data, user must specify the date column name. Here is an example:
|
||||
- CSV or Parquet file **must** include a column for the date, and when dumping the data, user must specify the date column name. Here is an example:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python scripts/dump_bin.py dump_all ... --date_field_name date
|
||||
python scripts/dump_bin.py dump_all ... --date_field_name date --file_suffix <.csv or .parquet>
|
||||
|
||||
where the data are in the following format:
|
||||
|
||||
@@ -163,23 +163,23 @@ Users can also provide their own data in CSV format. However, the CSV data **mus
|
||||
+---------+------------+-------+------+----------+
|
||||
|
||||
|
||||
Supposed that users prepare their CSV format data in the directory ``~/.qlib/csv_data/my_data``, they can run the following command to start the conversion.
|
||||
Supposed that users prepare their CSV or Parquet format data in the directory ``~/.qlib/my_data``, they can run the following command to start the conversion.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python scripts/dump_bin.py dump_all --csv_path ~/.qlib/csv_data/my_data --qlib_dir ~/.qlib/qlib_data/my_data --include_fields open,close,high,low,volume,factor
|
||||
python scripts/dump_bin.py dump_all --data_path ~/.qlib/my_data --qlib_dir ~/.qlib/qlib_data/ --include_fields open,close,high,low,volume,factor --file_suffix <.csv or .parquet>
|
||||
|
||||
For other supported parameters when dumping the data into `.bin` file, users can refer to the information by running the following commands:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python dump_bin.py dump_all --help
|
||||
python scripts/dump_bin.py dump_all --help
|
||||
|
||||
After conversion, users can find their Qlib format data in the directory `~/.qlib/qlib_data/my_data`.
|
||||
After conversion, users can find their Qlib format data in the directory `~/.qlib/qlib_data/`.
|
||||
|
||||
.. note::
|
||||
|
||||
The arguments of `--include_fields` should correspond with the column names of CSV files. The columns names of dataset provided by ``Qlib`` should include open, close, high, low, volume and factor at least.
|
||||
The arguments of `--include_fields` should correspond with the column names of CSV or Parquet files. The columns names of dataset provided by ``Qlib`` should include open, close, high, low, volume and factor at least.
|
||||
|
||||
- `open`
|
||||
The adjusted opening price
|
||||
@@ -195,7 +195,58 @@ After conversion, users can find their Qlib format data in the directory `~/.qli
|
||||
The Restoration factor. Normally, ``factor = adjusted_price / original_price``, `adjusted price` reference: `split adjusted <https://www.investopedia.com/terms/s/splitadjusted.asp>`_
|
||||
|
||||
In the convention of `Qlib` data processing, `open, close, high, low, volume, money and factor` will be set to NaN if the stock is suspended.
|
||||
If you want to use your own alpha-factor which can't be calculate by OCHLV, like PE, EPS and so on, you could add it to the CSV files with OHCLV together and then dump it to the Qlib format data.
|
||||
If you want to use your own alpha-factor which can't be calculate by OCHLV, like PE, EPS and so on, you could add it to the CSV or Parquet files with OHCLV together and then dump it to the Qlib format data.
|
||||
|
||||
Checking the health of the data
|
||||
-------------------------------
|
||||
|
||||
``Qlib`` provides a script to check the health of the data.
|
||||
|
||||
- The main points to check are as follows
|
||||
|
||||
- Check if any data is missing in the DataFrame.
|
||||
|
||||
- Check if there are any large step changes above the threshold in the OHLCV columns.
|
||||
|
||||
- Check if any of the required columns (OLHCV) are missing in the DataFrame.
|
||||
|
||||
- Check if the 'factor' column is missing in the DataFrame.
|
||||
|
||||
- You can run the following commands to check whether the data is healthy or not.
|
||||
|
||||
for daily data:
|
||||
.. code-block:: bash
|
||||
|
||||
python scripts/check_data_health.py check_data --qlib_dir ~/.qlib/qlib_data/cn_data
|
||||
|
||||
for 1min data:
|
||||
.. code-block:: bash
|
||||
|
||||
python scripts/check_data_health.py check_data --qlib_dir ~/.qlib/qlib_data/cn_data_1min --freq 1min
|
||||
|
||||
- Of course, you can also add some parameters to adjust the test results.
|
||||
|
||||
- The available parameters are these.
|
||||
|
||||
- freq: Frequency of data.
|
||||
|
||||
- large_step_threshold_price: Maximum permitted price change
|
||||
|
||||
- large_step_threshold_volume: Maximum permitted volume change.
|
||||
|
||||
- missing_data_num: Maximum value for which data is allowed to be null.
|
||||
|
||||
- You can run the following commands to check whether the data is healthy or not.
|
||||
|
||||
for daily data:
|
||||
.. code-block:: bash
|
||||
|
||||
python scripts/check_data_health.py check_data --qlib_dir ~/.qlib/qlib_data/cn_data --missing_data_num 30055 --large_step_threshold_volume 94485 --large_step_threshold_price 20
|
||||
|
||||
for 1min data:
|
||||
.. code-block:: bash
|
||||
|
||||
python scripts/check_data_health.py check_data --qlib_dir ~/.qlib/qlib_data/cn_data --freq 1min --missing_data_num 35806 --large_step_threshold_volume 3205452000000 --large_step_threshold_price 0.91
|
||||
|
||||
Stock Pool (Market)
|
||||
-------------------
|
||||
|
||||
@@ -25,7 +25,7 @@ The design of the framework is shown in the yellow part in the middle of the fig
|
||||
|
||||
The frequency of the trading algorithm, decision content and execution environment can be customized by users (e.g. intraday trading, daily-frequency trading, weekly-frequency trading), and the execution environment can be nested with finer-grained trading algorithm and execution environment inside (i.e. sub-workflow in the figure, e.g. daily-frequency orders can be turned into finer-grained decisions by splitting orders within the day). The flexibility of the nested decision execution framework makes it easy for users to explore the effects of combining different levels of trading strategies and break down the optimization barriers between different levels of the trading algorithm.
|
||||
|
||||
The optimization for the nested decision execution framework can be implemented with the support of `QlibRL <https://qlib.readthedocs.io/en/latest/component/rl.html>`_. To know more about how to use the QlibRL, go to API Reference: `RL API <../reference/api.html#rl>`_.
|
||||
The optimization for the nested decision execution framework can be implemented with the support of `QlibRL <./rl/overall.html>`_. To know more about how to use the QlibRL, go to API Reference: `RL API <../reference/api.html#rl>`_.
|
||||
|
||||
Example
|
||||
=======
|
||||
|
||||
@@ -86,7 +86,7 @@ Example
|
||||
},
|
||||
}
|
||||
|
||||
# model initiaiton
|
||||
# model initialization
|
||||
model = init_instance_by_config(task["model"])
|
||||
dataset = init_instance_by_config(task["dataset"])
|
||||
|
||||
|
||||
@@ -55,13 +55,16 @@ Below is a typical config file of ``qrun``.
|
||||
n_drop: 5
|
||||
signal: <PRED>
|
||||
backtest:
|
||||
limit_threshold: 0.095
|
||||
start_time: 2017-01-01
|
||||
end_time: 2020-08-01
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
exchange_kwargs:
|
||||
limit_threshold: 0.095
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: LGBModel
|
||||
@@ -107,7 +110,7 @@ If users want to use ``qrun`` under debug mode, please use the following command
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python -m pdb qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
python -m pdb qlib/cli/run.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
|
||||
.. note::
|
||||
|
||||
|
||||
@@ -21,8 +21,7 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
import pkg_resources
|
||||
|
||||
from importlib.metadata import version as ver
|
||||
|
||||
# -- General configuration ------------------------------------------------
|
||||
|
||||
@@ -63,9 +62,9 @@ author = "Microsoft"
|
||||
# built documents.
|
||||
#
|
||||
# The short X.Y version.
|
||||
version = pkg_resources.get_distribution("pyqlib").version
|
||||
version = ver("pyqlib")
|
||||
# The full version, including alpha/beta/rc tags.
|
||||
release = pkg_resources.get_distribution("pyqlib").version
|
||||
release = version
|
||||
|
||||
# The language for content autogenerated by Sphinx. Refer to documentation
|
||||
# for a list of supported languages.
|
||||
@@ -123,7 +122,6 @@ html_logo = "_static/img/logo/1.png"
|
||||
html_theme_options = {
|
||||
"logo_only": True,
|
||||
"collapse_navigation": False,
|
||||
"display_version": False,
|
||||
"navigation_depth": 4,
|
||||
}
|
||||
|
||||
|
||||
@@ -60,4 +60,4 @@ The `[dev]` option will help you to install some related packages when developin
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install -e .[dev]
|
||||
pip install -e ".[dev]"
|
||||
81
docs/developer/how_to_build_image.rst
Normal file
81
docs/developer/how_to_build_image.rst
Normal file
@@ -0,0 +1,81 @@
|
||||
.. _docker_image:
|
||||
|
||||
==================
|
||||
Build Docker Image
|
||||
==================
|
||||
|
||||
Dockerfile
|
||||
==========
|
||||
|
||||
There is a **Dockerfile** file in the root directory of the project from which you can build the docker image. There are two build methods in Dockerfile to choose from.
|
||||
When executing the build command, use the ``--build-arg`` parameter to control the image version. The ``--build-arg`` parameter defaults to ``yes``, which builds the ``stable`` version of the qlib image.
|
||||
|
||||
1.For the ``stable`` version, use ``pip install pyqlib`` to build the qlib image.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
docker build --build-arg IS_STABLE=yes -t <image name> -f ./Dockerfile .
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
docker build -t <image name> -f ./Dockerfile .
|
||||
|
||||
2. For the ``nightly`` version, use current source code to build the qlib image.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
docker build --build-arg IS_STABLE=no -t <image name> -f ./Dockerfile .
|
||||
|
||||
Auto build of qlib images
|
||||
=========================
|
||||
|
||||
1. There is a **build_docker_image.sh** file in the root directory of your project, which can be used to automatically build docker images and upload them to your docker hub repository(Optional, configuration required).
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
sh build_docker_image.sh
|
||||
>>> Do you want to build the nightly version of the qlib image? (default is stable) (yes/no):
|
||||
>>> Is it uploaded to docker hub? (default is no) (yes/no):
|
||||
|
||||
2. If you want to upload the built image to your docker hub repository, you need to edit your **build_docker_image.sh** file first, fill in ``docker_user`` in the file, and then execute this file.
|
||||
|
||||
How to use qlib images
|
||||
======================
|
||||
1. Start a new Docker container
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
docker run -it --name <container name> -v <Mounted local directory>:/app <image name>
|
||||
|
||||
2. At this point you are in the docker environment and can run the qlib scripts. An example:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
>>> python scripts/get_data.py qlib_data --name qlib_data_simple --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
|
||||
>>> python qlib/cli/run.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
|
||||
3. Exit the container
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
>>> exit
|
||||
|
||||
4. Restart the container
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
docker start -i -a <container name>
|
||||
|
||||
5. Stop the container
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
docker stop -i -a <container name>
|
||||
|
||||
6. Delete the container
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
docker rm <container name>
|
||||
|
||||
7. For more information on using docker see the `docker documentation <https://docs.docker.com/reference/cli/docker/>`_.
|
||||
@@ -61,6 +61,7 @@ Document Structure
|
||||
:caption: FOR DEVELOPERS:
|
||||
|
||||
Code Standard & Development Guidance <developer/code_standard_and_dev_guide.rst>
|
||||
How to build image <developer/how_to_build_image.rst>
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 3
|
||||
|
||||
@@ -5,3 +5,4 @@ scipy
|
||||
scikit-learn
|
||||
pandas
|
||||
tianshou
|
||||
sphinx_rtd_theme
|
||||
|
||||
@@ -129,7 +129,7 @@ For example, it looks quite long and complicated:
|
||||
|
||||
|
||||
But using string is not the only way to implement the expression. You can also implement expression by code.
|
||||
Here is an exmaple which does the same thing as above examples.
|
||||
Here is an example which does the same thing as above examples.
|
||||
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@@ -71,7 +71,7 @@ The Custom models need to inherit `qlib.model.base.Model <../reference/api.html#
|
||||
)
|
||||
|
||||
- Override the `predict` method
|
||||
- The parameters must include the parameter `dataset`, which will be userd to get the test dataset.
|
||||
- The parameters must include the parameter `dataset`, which will be used to get the test dataset.
|
||||
- Return the `prediction score`.
|
||||
- Please refer to `Model API <../reference/api.html#module-qlib.model.base>`_ for the parameter types of the fit method.
|
||||
- Code Example: In the following example, users need to use `LightGBM` to predict the label(such as `preds`) of test data `x_test` and return it.
|
||||
|
||||
19
examples/benchmarks/GeneralPtNN/README.md
Normal file
19
examples/benchmarks/GeneralPtNN/README.md
Normal file
@@ -0,0 +1,19 @@
|
||||
|
||||
|
||||
# Introduction
|
||||
|
||||
What is GeneralPtNN
|
||||
- Fix previous design that fail to support both Time-series and tabular data
|
||||
- Now you can just replace the Pytorch model structure to run a NN model.
|
||||
|
||||
We provide an example to demonstrate the effectiveness of the current design.
|
||||
- `workflow_config_gru.yaml` align with previous results [GRU(Kyunghyun Cho, et al.)](../README.md#Alpha158-dataset)
|
||||
- `workflow_config_gru2mlp.yaml` to demonstrate we can convert config from time-series to tabular data with minimal changes
|
||||
- You only have to change the net & dataset class to make the conversion.
|
||||
- `workflow_config_mlp.yaml` achieved similar functionality with [MLP](../README.md#Alpha158-dataset)
|
||||
|
||||
# TODO
|
||||
|
||||
- We will align existing models to current design.
|
||||
|
||||
- The result of `workflow_config_mlp.yaml` is different with the result of [MLP](../README.md#Alpha158-dataset) since GeneralPtNN has a different stopping method compared to previous implementations. Specificly, GeneralPtNN controls training according to epoches, whereas previous methods controlled by max_steps.
|
||||
100
examples/benchmarks/GeneralPtNN/workflow_config_gru.yaml
Executable file
100
examples/benchmarks/GeneralPtNN/workflow_config_gru.yaml
Executable file
@@ -0,0 +1,100 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: FilterCol
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
col_list: ["RESI5", "WVMA5", "RSQR5", "KLEN", "RSQR10", "CORR5", "CORD5", "CORR10",
|
||||
"ROC60", "RESI10", "VSTD5", "RSQR60", "CORR60", "WVMA60", "STD5",
|
||||
"RSQR20", "CORD60", "CORD10", "CORR20", "KLOW"
|
||||
]
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
signal: <PRED>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
start_time: 2017-01-01
|
||||
end_time: 2020-08-01
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
exchange_kwargs:
|
||||
limit_threshold: 0.095
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: GeneralPTNN
|
||||
module_path: qlib.contrib.model.pytorch_general_nn
|
||||
kwargs:
|
||||
n_epochs: 200
|
||||
lr: 2e-4
|
||||
early_stop: 10
|
||||
batch_size: 800
|
||||
metric: loss
|
||||
loss: mse
|
||||
n_jobs: 20
|
||||
GPU: 0
|
||||
pt_model_uri: "qlib.contrib.model.pytorch_gru_ts.GRUModel"
|
||||
pt_model_kwargs: {
|
||||
"d_feat": 20,
|
||||
"hidden_size": 64,
|
||||
"num_layers": 2,
|
||||
"dropout": 0.,
|
||||
}
|
||||
dataset:
|
||||
class: TSDatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha158
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
step_len: 20
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
93
examples/benchmarks/GeneralPtNN/workflow_config_gru2mlp.yaml
Normal file
93
examples/benchmarks/GeneralPtNN/workflow_config_gru2mlp.yaml
Normal file
@@ -0,0 +1,93 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: FilterCol
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
col_list: ["RESI5", "WVMA5", "RSQR5", "KLEN", "RSQR10", "CORR5", "CORD5", "CORR10",
|
||||
"ROC60", "RESI10", "VSTD5", "RSQR60", "CORR60", "WVMA60", "STD5",
|
||||
"RSQR20", "CORD60", "CORD10", "CORR20", "KLOW"
|
||||
]
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
signal: <PRED>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
start_time: 2017-01-01
|
||||
end_time: 2020-08-01
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
exchange_kwargs:
|
||||
limit_threshold: 0.095
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: GeneralPTNN
|
||||
module_path: qlib.contrib.model.pytorch_general_nn
|
||||
kwargs:
|
||||
lr: 1e-3
|
||||
n_epochs: 1
|
||||
batch_size: 800
|
||||
loss: mse
|
||||
optimizer: adam
|
||||
pt_model_uri: "qlib.contrib.model.pytorch_nn.Net"
|
||||
pt_model_kwargs:
|
||||
input_dim: 20
|
||||
layers: [20,]
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha158
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
98
examples/benchmarks/GeneralPtNN/workflow_config_mlp.yaml
Normal file
98
examples/benchmarks/GeneralPtNN/workflow_config_mlp.yaml
Normal file
@@ -0,0 +1,98 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors: [
|
||||
{
|
||||
"class" : "DropCol",
|
||||
"kwargs":{"col_list": ["VWAP0"]}
|
||||
},
|
||||
{
|
||||
"class" : "CSZFillna",
|
||||
"kwargs":{"fields_group": "feature"}
|
||||
}
|
||||
]
|
||||
learn_processors: [
|
||||
{
|
||||
"class" : "DropCol",
|
||||
"kwargs":{"col_list": ["VWAP0"]}
|
||||
},
|
||||
{
|
||||
"class" : "DropnaProcessor",
|
||||
"kwargs":{"fields_group": "feature"}
|
||||
},
|
||||
"DropnaLabel",
|
||||
{
|
||||
"class": "CSZScoreNorm",
|
||||
"kwargs": {"fields_group": "label"}
|
||||
}
|
||||
]
|
||||
process_type: "independent"
|
||||
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
signal: <PRED>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
start_time: 2017-01-01
|
||||
end_time: 2020-08-01
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
exchange_kwargs:
|
||||
limit_threshold: 0.095
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: GeneralPTNN
|
||||
module_path: qlib.contrib.model.pytorch_general_nn
|
||||
kwargs:
|
||||
# FIXME: wrong parameters.
|
||||
lr: 2e-3
|
||||
batch_size: 8192
|
||||
loss: mse
|
||||
weight_decay: 0.0002
|
||||
optimizer: adam
|
||||
pt_model_uri: "qlib.contrib.model.pytorch_nn.Net"
|
||||
pt_model_kwargs:
|
||||
input_dim: 157
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha158
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -599,7 +599,7 @@ class TemporalFusionTransformer:
|
||||
print("Getting valid sampling locations.")
|
||||
valid_sampling_locations = []
|
||||
split_data_map = {}
|
||||
for identifier, df in data.groupby(id_col):
|
||||
for identifier, df in data.groupby(id_col, group_key=False):
|
||||
print("Getting locations for {}".format(identifier))
|
||||
num_entries = len(df)
|
||||
if num_entries >= self.time_steps:
|
||||
@@ -678,7 +678,7 @@ class TemporalFusionTransformer:
|
||||
input_cols = [tup[0] for tup in self.column_definition if tup[2] not in {InputTypes.ID, InputTypes.TIME}]
|
||||
|
||||
data_map = {}
|
||||
for _, sliced in data.groupby(id_col):
|
||||
for _, sliced in data.groupby(id_col, group_keys=False):
|
||||
col_mappings = {"identifier": [id_col], "time": [time_col], "outputs": [target_col], "inputs": input_cols}
|
||||
|
||||
for k in col_mappings:
|
||||
|
||||
@@ -19,7 +19,6 @@ from qlib.model.base import ModelFT
|
||||
from qlib.data.dataset import DatasetH
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
|
||||
|
||||
# To register new datasets, please add them here.
|
||||
ALLOW_DATASET = ["Alpha158", "Alpha360"]
|
||||
# To register new datasets, please add their configurations here.
|
||||
@@ -78,13 +77,15 @@ DATASET_SETTING = {
|
||||
|
||||
|
||||
def get_shifted_label(data_df, shifts=5, col_shift="LABEL0"):
|
||||
return data_df[[col_shift]].groupby("instrument").apply(lambda df: df.shift(shifts))
|
||||
return data_df[[col_shift]].groupby("instrument", group_keys=False).apply(lambda df: df.shift(shifts))
|
||||
|
||||
|
||||
def fill_test_na(test_df):
|
||||
test_df_res = test_df.copy()
|
||||
feature_cols = ~test_df_res.columns.str.contains("label", case=False)
|
||||
test_feature_fna = test_df_res.loc[:, feature_cols].groupby("datetime").apply(lambda df: df.fillna(df.mean()))
|
||||
test_feature_fna = (
|
||||
test_df_res.loc[:, feature_cols].groupby("datetime", group_keys=False).apply(lambda df: df.fillna(df.mean()))
|
||||
)
|
||||
test_df_res.loc[:, feature_cols] = test_feature_fna
|
||||
return test_df_res
|
||||
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
import argparse
|
||||
|
||||
import qlib
|
||||
import ruamel.yaml as yaml
|
||||
from ruamel.yaml import YAML
|
||||
from qlib.utils import init_instance_by_config
|
||||
|
||||
|
||||
def main(seed, config_file="configs/config_alstm.yaml"):
|
||||
# set random seed
|
||||
with open(config_file) as f:
|
||||
config = yaml.safe_load(f)
|
||||
yaml = YAML(typ="safe", pure=True)
|
||||
config = yaml.load(f)
|
||||
|
||||
# seed_suffix = "/seed1000" if "init" in config_file else f"/seed{seed}"
|
||||
seed_suffix = ""
|
||||
|
||||
@@ -8,7 +8,6 @@ import pandas as pd
|
||||
|
||||
from qlib.data.dataset import DatasetH
|
||||
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
|
||||
@@ -29,7 +28,7 @@ def _create_ts_slices(index, seq_len):
|
||||
assert index.is_lexsorted(), "index should be sorted"
|
||||
|
||||
# number of dates for each code
|
||||
sample_count_by_codes = pd.Series(0, index=index).groupby(level=0).size().values
|
||||
sample_count_by_codes = pd.Series(0, index=index).groupby(level=0, group_keys=False).size().values
|
||||
|
||||
# start_index for each code
|
||||
start_index_of_codes = np.roll(np.cumsum(sample_count_by_codes), 1)
|
||||
|
||||
@@ -110,7 +110,6 @@ task:
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
seq_len: 60
|
||||
horizon: 2
|
||||
input_size:
|
||||
num_states: *num_states
|
||||
batch_size: 1024
|
||||
|
||||
@@ -104,7 +104,6 @@ task:
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
seq_len: 60
|
||||
horizon: 2
|
||||
input_size:
|
||||
num_states: *num_states
|
||||
batch_size: 1024
|
||||
|
||||
@@ -104,7 +104,6 @@ task:
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
seq_len: 60
|
||||
horizon: 2
|
||||
input_size: 6
|
||||
num_states: *num_states
|
||||
batch_size: 1024
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import pickle
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
|
||||
from qlib.utils.pickle_utils import restricted_pickle_load
|
||||
|
||||
sns.set(color_codes=True)
|
||||
plt.rcParams["font.sans-serif"] = "SimHei"
|
||||
plt.rcParams["axes.unicode_minus"] = False
|
||||
@@ -18,7 +19,7 @@ from tqdm.auto import tqdm
|
||||
|
||||
# +
|
||||
with open("./internal_data_s20.pkl", "rb") as f:
|
||||
data = pickle.load(f)
|
||||
data = restricted_pickle_load(f)
|
||||
|
||||
data.data_ic_df.columns.names = ["start_date", "end_date"]
|
||||
|
||||
@@ -52,7 +53,7 @@ pd.DataFrame(meta_m.tn.twm.linear.weight.detach().numpy()).T[0].rolling(5).mean(
|
||||
|
||||
# +
|
||||
with open("./tasks_s20.pkl", "rb") as f:
|
||||
tasks = pickle.load(f)
|
||||
tasks = restricted_pickle_load(f)
|
||||
|
||||
task_df = {}
|
||||
for t in tasks:
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
@@ -35,6 +36,10 @@ class DDGDABench(DDGDA):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
GetData().qlib_data(exists_skip=True)
|
||||
auto_init()
|
||||
kwargs = {}
|
||||
if os.environ.get("PROVIDER_URI", "") == "":
|
||||
GetData().qlib_data(exists_skip=True)
|
||||
else:
|
||||
kwargs["provider_uri"] = os.environ["PROVIDER_URI"]
|
||||
auto_init(**kwargs)
|
||||
fire.Fire(DDGDABench)
|
||||
|
||||
@@ -7,7 +7,7 @@ The table below shows the performances of different solutions on different forec
|
||||
## Alpha158 Dataset
|
||||
Here is the [crowd sourced version of qlib data](data_collector/crowd_source/README.md): https://github.com/chenditc/investment_data/releases
|
||||
```bash
|
||||
wget https://github.com/chenditc/investment_data/releases/download/20220720/qlib_bin.tar.gz
|
||||
wget https://github.com/chenditc/investment_data/releases/latest/download/qlib_bin.tar.gz
|
||||
mkdir -p ~/.qlib/qlib_data/cn_data
|
||||
tar -zxvf qlib_bin.tar.gz -C ~/.qlib/qlib_data/cn_data --strip-components=2
|
||||
rm -f qlib_bin.tar.gz
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
@@ -31,6 +32,10 @@ class RollingBenchmark(Rolling):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
GetData().qlib_data(exists_skip=True)
|
||||
auto_init()
|
||||
kwargs = {}
|
||||
if os.environ.get("PROVIDER_URI", "") == "":
|
||||
GetData().qlib_data(exists_skip=True)
|
||||
else:
|
||||
kwargs["provider_uri"] = os.environ["PROVIDER_URI"]
|
||||
auto_init(**kwargs)
|
||||
fire.Fire(RollingBenchmark)
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
"""
|
||||
The motivation of this demo
|
||||
- To show the data modules of Qlib is Serializable, users can dump processed data to disk to avoid duplicated data preprocessing
|
||||
The motivation of this demo
|
||||
- To show the data modules of Qlib is Serializable, users can dump processed data to disk to avoid duplicated data preprocessing
|
||||
"""
|
||||
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
import pickle
|
||||
from pprint import pprint
|
||||
from ruamel.yaml import YAML
|
||||
import subprocess
|
||||
import yaml
|
||||
from qlib.log import TimeInspector
|
||||
|
||||
from qlib import init
|
||||
@@ -30,7 +30,8 @@ if __name__ == "__main__":
|
||||
subprocess.run(f"qrun {config_path}", shell=True)
|
||||
|
||||
# 2) dump handler
|
||||
task_config = yaml.safe_load(config_path.open())
|
||||
yaml = YAML(typ="safe", pure=True)
|
||||
task_config = yaml.load(config_path.open())
|
||||
hd_conf = task_config["task"]["dataset"]["kwargs"]["handler"]
|
||||
pprint(hd_conf)
|
||||
hd: DataHandlerLP = init_instance_by_config(hd_conf)
|
||||
|
||||
@@ -1,18 +1,17 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
"""
|
||||
The motivation of this demo
|
||||
- To show the data modules of Qlib is Serializable, users can dump processed data to disk to avoid duplicated data preprocessing
|
||||
The motivation of this demo
|
||||
- To show the data modules of Qlib is Serializable, users can dump processed data to disk to avoid duplicated data preprocessing
|
||||
"""
|
||||
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
import pickle
|
||||
from pprint import pprint
|
||||
from ruamel.yaml import YAML
|
||||
import subprocess
|
||||
|
||||
import yaml
|
||||
|
||||
from qlib import init
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
from qlib.log import TimeInspector
|
||||
@@ -29,7 +28,8 @@ if __name__ == "__main__":
|
||||
exp_name = "data_mem_reuse_demo"
|
||||
|
||||
config_path = DIRNAME.parent / "benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml"
|
||||
task_config = yaml.safe_load(config_path.open())
|
||||
yaml = YAML(typ="safe", pure=True)
|
||||
task_config = yaml.load(config_path.open())
|
||||
|
||||
# 1) without using processed data in memory
|
||||
with TimeInspector.logt("The original time without reusing processed data in memory:"):
|
||||
|
||||
@@ -25,7 +25,7 @@ class DayLast(ElemOperator):
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
_calendar = get_calendar_day(freq=freq)
|
||||
series = self.feature.load(instrument, start_index, end_index, freq)
|
||||
return series.groupby(_calendar[series.index]).transform("last")
|
||||
return series.groupby(_calendar[series.index], group_keys=False).transform("last")
|
||||
|
||||
|
||||
class FFillNan(ElemOperator):
|
||||
@@ -44,7 +44,7 @@ class FFillNan(ElemOperator):
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
series = self.feature.load(instrument, start_index, end_index, freq)
|
||||
return series.fillna(method="ffill")
|
||||
return series.ffill()
|
||||
|
||||
|
||||
class BFillNan(ElemOperator):
|
||||
@@ -63,7 +63,7 @@ class BFillNan(ElemOperator):
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
series = self.feature.load(instrument, start_index, end_index, freq)
|
||||
return series.fillna(method="bfill")
|
||||
return series.bfill()
|
||||
|
||||
|
||||
class Date(ElemOperator):
|
||||
|
||||
@@ -4,11 +4,11 @@
|
||||
import fire
|
||||
|
||||
import qlib
|
||||
import pickle
|
||||
from qlib.constant import REG_CN
|
||||
from qlib.config import HIGH_FREQ_CONFIG
|
||||
|
||||
from qlib.utils import init_instance_by_config
|
||||
from qlib.utils.pickle_utils import restricted_pickle_load
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
from qlib.data.ops import Operators
|
||||
from qlib.data.data import Cal
|
||||
@@ -125,10 +125,10 @@ class HighfreqWorkflow:
|
||||
del dataset, dataset_backtest
|
||||
##=============reload dataset=============
|
||||
with open("dataset.pkl", "rb") as file_dataset:
|
||||
dataset = pickle.load(file_dataset)
|
||||
dataset = restricted_pickle_load(file_dataset)
|
||||
|
||||
with open("dataset_backtest.pkl", "rb") as file_dataset_backtest:
|
||||
dataset_backtest = pickle.load(file_dataset_backtest)
|
||||
dataset_backtest = restricted_pickle_load(file_dataset_backtest)
|
||||
|
||||
self._prepare_calender_cache()
|
||||
##=============reinit dataset=============
|
||||
|
||||
@@ -9,7 +9,6 @@ from qlib.utils import init_instance_by_config
|
||||
from qlib.tests.data import GetData
|
||||
from qlib.tests.config import CSI300_GBDT_TASK
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# use default data
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
|
||||
|
||||
@@ -95,7 +95,6 @@ pos 0.000000
|
||||
[1706497:MainThread](2021-12-07 14:08:30,627) INFO - qlib.timer - [log.py:113] - Time cost: 0.014s | waiting `async_log` Done
|
||||
"""
|
||||
|
||||
|
||||
from copy import deepcopy
|
||||
import qlib
|
||||
import fire
|
||||
|
||||
@@ -7,6 +7,7 @@ There are two parts including first_train and update_online_pred.
|
||||
Firstly, we will finish the training and set the trained models to the `online` models.
|
||||
Next, we will finish updating online predictions.
|
||||
"""
|
||||
|
||||
import copy
|
||||
import fire
|
||||
import qlib
|
||||
|
||||
@@ -16,7 +16,7 @@ Current version of script with default value tries to connect localhost **via de
|
||||
|
||||
Run following command to install necessary libraries
|
||||
```
|
||||
pip install pytest coverage
|
||||
pip install pytest coverage gdown
|
||||
pip install arctic # NOTE: pip may fail to resolve the right package dependency !!! Please make sure the dependency are satisfied.
|
||||
```
|
||||
|
||||
@@ -27,7 +27,8 @@ pip install arctic # NOTE: pip may fail to resolve the right package dependency
|
||||
2. Please follow following steps to download example data
|
||||
```bash
|
||||
cd examples/orderbook_data/
|
||||
python ../../scripts/get_data.py download_data --target_dir . --file_name highfreq_orderbook_example_data.zip
|
||||
gdown https://drive.google.com/uc?id=15FuUqWn2rkCi8uhJYGEQWKakcEqLJNDG # Proxies may be necessary here.
|
||||
python ../../scripts/get_data.py _unzip --file_path highfreq_orderbook_example_data.zip --target_dir .
|
||||
```
|
||||
|
||||
3. Please import the example data to your mongo db
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
"""
|
||||
NOTE:
|
||||
- This scripts is a demo to import example data import Qlib
|
||||
- !!!!!!!!!!!!!!!TODO!!!!!!!!!!!!!!!!!!!:
|
||||
- Its structure is not well designed and very ugly, your contribution is welcome to make importing dataset easier
|
||||
NOTE:
|
||||
- This scripts is a demo to import example data import Qlib
|
||||
- !!!!!!!!!!!!!!!TODO!!!!!!!!!!!!!!!!!!!:
|
||||
- Its structure is not well designed and very ugly, your contribution is welcome to make importing dataset easier
|
||||
"""
|
||||
|
||||
from datetime import date, datetime as dt
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
@@ -20,7 +20,7 @@ We use China stock market data for our example.
|
||||
1. Prepare CSI300 weight:
|
||||
|
||||
```bash
|
||||
wget http://fintech.msra.cn/stock_data/downloads/csi300_weight.zip
|
||||
wget https://github.com/SunsetWolf/qlib_dataset/releases/download/v0/csi300_weight.zip
|
||||
unzip -d ~/.qlib/qlib_data/cn_data csi300_weight.zip
|
||||
rm -f csi300_weight.zip
|
||||
```
|
||||
|
||||
@@ -7,7 +7,7 @@ This folder comprises an example of Reinforcement Learning (RL) workflows for or
|
||||
### Get Data
|
||||
|
||||
```
|
||||
python -m qlib.run.get_data qlib_data qlib_data --target_dir ./data/bin --region hs300 --interval 5min
|
||||
python -m qlib.cli.data qlib_data --target_dir ./data/bin --region hs300 --interval 5min
|
||||
```
|
||||
|
||||
### Generate Pickle-Style Data
|
||||
|
||||
@@ -19,9 +19,9 @@ def generate_order(stock: str, start_idx: int, end_idx: int) -> bool:
|
||||
|
||||
df["date"] = df["datetime"].dt.date.astype("datetime64")
|
||||
df = df.set_index(["instrument", "datetime", "date"])
|
||||
df = df.groupby("date").take(range(start_idx, end_idx)).droplevel(level=0)
|
||||
df = df.groupby("date", group_keys=False).take(range(start_idx, end_idx)).droplevel(level=0)
|
||||
|
||||
order_all = pd.DataFrame(df.groupby(level=(2, 0)).mean().dropna())
|
||||
order_all = pd.DataFrame(df.groupby(level=(2, 0), group_keys=False).mean().dropna())
|
||||
order_all["amount"] = np.random.lognormal(-3.28, 1.14) * order_all["$volume0"]
|
||||
order_all = order_all[order_all["amount"] > 0.0]
|
||||
order_all["order_type"] = 0
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
import pickle
|
||||
import os
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
|
||||
from qlib.utils.pickle_utils import restricted_pickle_load
|
||||
|
||||
for tag in ["test", "valid"]:
|
||||
files = os.listdir(os.path.join("data/orders/", tag))
|
||||
dfs = []
|
||||
for f in tqdm(files):
|
||||
df = pickle.load(open(os.path.join("data/orders/", tag, f), "rb"))
|
||||
with open(os.path.join("data/orders/", tag, f), "rb") as fr:
|
||||
df = restricted_pickle_load(fr)
|
||||
df = df.drop(["$close0"], axis=1)
|
||||
dfs.append(df)
|
||||
|
||||
|
||||
@@ -3,12 +3,12 @@
|
||||
|
||||
import qlib
|
||||
import fire
|
||||
import pickle
|
||||
|
||||
from datetime import datetime
|
||||
from qlib.constant import REG_CN
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
from qlib.utils import init_instance_by_config
|
||||
from qlib.utils.pickle_utils import restricted_pickle_load
|
||||
from qlib.tests.data import GetData
|
||||
|
||||
|
||||
@@ -42,7 +42,7 @@ class RollingDataWorkflow:
|
||||
|
||||
def _load_pre_handler(self, path):
|
||||
with open(path, "rb") as file_dataset:
|
||||
pre_handler = pickle.load(file_dataset)
|
||||
pre_handler = restricted_pickle_load(file_dataset)
|
||||
return pre_handler
|
||||
|
||||
def rolling_process(self):
|
||||
|
||||
@@ -6,7 +6,6 @@ import sys
|
||||
import fire
|
||||
import time
|
||||
import glob
|
||||
import yaml
|
||||
import shutil
|
||||
import signal
|
||||
import inspect
|
||||
@@ -15,6 +14,7 @@ import functools
|
||||
import statistics
|
||||
import subprocess
|
||||
from datetime import datetime
|
||||
from ruamel.yaml import YAML
|
||||
from pathlib import Path
|
||||
from operator import xor
|
||||
from pprint import pprint
|
||||
@@ -188,7 +188,8 @@ def gen_and_save_md_table(metrics, dataset):
|
||||
# read yaml, remove seed kwargs of model, and then save file in the temp_dir
|
||||
def gen_yaml_file_without_seed_kwargs(yaml_path, temp_dir):
|
||||
with open(yaml_path, "r") as fp:
|
||||
config = yaml.safe_load(fp)
|
||||
yaml = YAML(typ="safe", pure=True)
|
||||
config = yaml.load(fp)
|
||||
try:
|
||||
del config["task"]["model"]["kwargs"]["seed"]
|
||||
except KeyError:
|
||||
|
||||
@@ -171,7 +171,9 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import plotly.graph_objects as go\n",
|
||||
"import plotly.io as pio\n",
|
||||
"\n",
|
||||
"pio.renderers.default = \"notebook\"\n",
|
||||
"fig = go.Figure(\n",
|
||||
" data=[\n",
|
||||
" go.Candlestick(\n",
|
||||
|
||||
@@ -161,7 +161,7 @@
|
||||
" },\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"# model initiaiton\n",
|
||||
"# model initialization\n",
|
||||
"model = init_instance_by_config(task[\"model\"])\n",
|
||||
"dataset = init_instance_by_config(task[\"dataset\"])\n",
|
||||
"\n",
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
"""
|
||||
Qlib provides two kinds of interfaces.
|
||||
Qlib provides two kinds of interfaces.
|
||||
(1) Users could define the Quant research workflow by a simple configuration.
|
||||
(2) Qlib is designed in a modularized way and supports creating research workflow by code just like building blocks.
|
||||
|
||||
The interface of (1) is `qrun XXX.yaml`. The interface of (2) is script like this, which nearly does the same thing as `qrun XXX.yaml`
|
||||
"""
|
||||
|
||||
import qlib
|
||||
from qlib.constant import REG_CN
|
||||
from qlib.utils import init_instance_by_config, flatten_dict
|
||||
@@ -15,7 +16,6 @@ from qlib.workflow.record_temp import SignalRecord, PortAnaRecord, SigAnaRecord
|
||||
from qlib.tests.data import GetData
|
||||
from qlib.tests.config import CSI300_BENCH, CSI300_GBDT_TASK
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# use default data
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
|
||||
|
||||
125
pyproject.toml
125
pyproject.toml
@@ -1,2 +1,125 @@
|
||||
[build-system]
|
||||
requires = ["setuptools", "numpy", "Cython"]
|
||||
requires = ["setuptools", "setuptools-scm", "cython", "numpy>=1.24.0"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
classifiers = [
|
||||
"Operating System :: POSIX :: Linux",
|
||||
"Operating System :: Microsoft :: Windows",
|
||||
"Operating System :: MacOS",
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Development Status :: 3 - Alpha",
|
||||
"Programming Language :: Python",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.8",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
]
|
||||
name = "pyqlib"
|
||||
dynamic = ["version"]
|
||||
description = "A Quantitative-research Platform"
|
||||
requires-python = ">=3.8.0"
|
||||
readme = {file = "README.md", content-type = "text/markdown"}
|
||||
license = { text = "MIT" }
|
||||
|
||||
dependencies = [
|
||||
"pyyaml",
|
||||
"numpy",
|
||||
# Since version 1.1.0, pandas supports the ffill and bfill methods.
|
||||
# Since version 2.1.0, pandas has deprecated the method parameter of the fillna method.
|
||||
# qlib has updated the fillna method in PR 1987 and limited the minimum version of pandas.
|
||||
"pandas>=1.1",
|
||||
# I encoutered an Error that the set_uri does not work when downloading artifacts in mlflow 3.1.1;
|
||||
# But earlier versions of mlflow does not have this problem.
|
||||
# But when I switch to 2.*.* version, another error occurs, which is even more strange...
|
||||
"mlflow",
|
||||
"filelock>=3.16.0",
|
||||
"redis",
|
||||
"dill",
|
||||
"fire",
|
||||
"ruamel.yaml>=0.17.38",
|
||||
"python-redis-lock",
|
||||
"tqdm",
|
||||
"pymongo",
|
||||
"loguru",
|
||||
"lightgbm",
|
||||
"gym",
|
||||
"cvxpy",
|
||||
"joblib",
|
||||
"matplotlib",
|
||||
"jupyter",
|
||||
"nbconvert",
|
||||
"pyarrow",
|
||||
"pydantic-settings",
|
||||
"setuptools-scm",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
"pytest",
|
||||
"statsmodels",
|
||||
]
|
||||
# On macos-13 system, when using python version greater than or equal to 3.10,
|
||||
# pytorch can't fully support Numpy version above 2.0, so, when you want to install torch,
|
||||
# it will limit the version of Numpy less than 2.0.
|
||||
rl = [
|
||||
"tianshou<=0.4.10",
|
||||
"torch",
|
||||
"numpy<2.0.0",
|
||||
]
|
||||
|
||||
lint = [
|
||||
"black",
|
||||
"pylint",
|
||||
"mypy<1.5.0",
|
||||
"flake8",
|
||||
"nbqa",
|
||||
]
|
||||
# snowballstemmer, a dependency of sphinx, was released on 2025-05-08 with version 3.0.0,
|
||||
# which causes errors in the build process. So we've limited the version for now.
|
||||
docs = [
|
||||
# After upgrading scipy to version 1.16.0,
|
||||
# we encountered ImportError: cannot import name '_lazywhere', in the build documentation,
|
||||
# so we restricted the version of scipy to: 1.15.3
|
||||
"scipy<=1.15.3",
|
||||
"sphinx",
|
||||
"sphinx_rtd_theme",
|
||||
"readthedocs_sphinx_ext",
|
||||
"snowballstemmer<3.0",
|
||||
]
|
||||
package = [
|
||||
"twine",
|
||||
"build",
|
||||
]
|
||||
# test_pit dependency packages
|
||||
test = [
|
||||
"yahooquery",
|
||||
"baostock",
|
||||
]
|
||||
analysis = [
|
||||
"plotly",
|
||||
"statsmodels",
|
||||
]
|
||||
client = [
|
||||
"python-socketio<6",
|
||||
"tables",
|
||||
]
|
||||
|
||||
# In the process of releasing a new version, when checking the manylinux package with twine, an error is reported:
|
||||
# InvalidDistribution: Invalid distribution metadata: unrecognized or malformed field 'license-file'
|
||||
# To solve this problem, we added license-files here. Refs: https://github.com/pypa/twine/issues/1216
|
||||
[tool.setuptools]
|
||||
packages = [
|
||||
"qlib",
|
||||
]
|
||||
license-files = []
|
||||
|
||||
[project.scripts]
|
||||
qrun = "qlib.cli.run:run"
|
||||
|
||||
[tool.setuptools_scm]
|
||||
local_scheme = "no-local-version"
|
||||
version_scheme = "guess-next-dev"
|
||||
write_to = "qlib/_version.py"
|
||||
|
||||
@@ -2,14 +2,22 @@
|
||||
# Licensed under the MIT License.
|
||||
from pathlib import Path
|
||||
|
||||
__version__ = "0.9.4.99"
|
||||
from setuptools_scm import get_version
|
||||
|
||||
try:
|
||||
from ._version import version as __version__
|
||||
except ImportError:
|
||||
__version__ = get_version(root="..", relative_to=__file__)
|
||||
__version__bak = __version__ # This version is backup for QlibConfig.reset_qlib_version
|
||||
import os
|
||||
from typing import Union
|
||||
import yaml
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import re
|
||||
import subprocess
|
||||
from typing import Union
|
||||
|
||||
from ruamel.yaml import YAML
|
||||
|
||||
from .log import get_module_logger
|
||||
|
||||
|
||||
@@ -80,34 +88,41 @@ def _mount_nfs_uri(provider_uri, mount_path, auto_mount: bool = False):
|
||||
LOG = get_module_logger("mount nfs", level=logging.INFO)
|
||||
if mount_path is None:
|
||||
raise ValueError(f"Invalid mount path: {mount_path}!")
|
||||
if not re.match(r"^[a-zA-Z0-9.:/\-_]+$", provider_uri):
|
||||
raise ValueError(f"Invalid provider_uri format: {provider_uri}")
|
||||
# FIXME: the C["provider_uri"] is modified in this function
|
||||
# If it is not modified, we can pass only provider_uri or mount_path instead of C
|
||||
mount_command = "sudo mount.nfs %s %s" % (provider_uri, mount_path)
|
||||
mount_command = ["sudo", "mount.nfs", provider_uri, mount_path]
|
||||
# If the provider uri looks like this 172.23.233.89//data/csdesign'
|
||||
# It will be a nfs path. The client provider will be used
|
||||
if not auto_mount: # pylint: disable=R1702
|
||||
if not Path(mount_path).exists():
|
||||
raise FileNotFoundError(
|
||||
f"Invalid mount path: {mount_path}! Please mount manually: {mount_command} or Set init parameter `auto_mount=True`"
|
||||
f"Invalid mount path: {mount_path}! Please mount manually: {' '.join(mount_command)} or Set init parameter `auto_mount=True`"
|
||||
)
|
||||
else:
|
||||
# Judging system type
|
||||
sys_type = platform.system()
|
||||
if "windows" in sys_type.lower():
|
||||
# system: window
|
||||
exec_result = os.popen(f"mount -o anon {provider_uri} {mount_path}")
|
||||
result = exec_result.read()
|
||||
if "85" in result:
|
||||
LOG.warning(f"{provider_uri} on Windows:{mount_path} is already mounted")
|
||||
elif "53" in result:
|
||||
raise OSError("not find network path")
|
||||
elif "error" in result or "错误" in result:
|
||||
raise OSError("Invalid mount path")
|
||||
elif provider_uri in result:
|
||||
LOG.info("window success mount..")
|
||||
else:
|
||||
raise OSError(f"unknown error: {result}")
|
||||
|
||||
try:
|
||||
subprocess.run(
|
||||
["mount", "-o", "anon", provider_uri, mount_path],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=True,
|
||||
)
|
||||
LOG.info("Mount finished.")
|
||||
except subprocess.CalledProcessError as e:
|
||||
error_output = (e.stdout or "") + (e.stderr or "")
|
||||
if e.returncode == 85:
|
||||
LOG.warning(f"{provider_uri} already mounted at {mount_path}")
|
||||
elif e.returncode == 53:
|
||||
raise OSError("Network path not found") from e
|
||||
elif "error" in error_output.lower() or "错误" in error_output:
|
||||
raise OSError("Invalid mount path") from e
|
||||
else:
|
||||
raise OSError(f"Unknown mount error: {error_output.strip()}") from e
|
||||
else:
|
||||
# system: linux/Unix/Mac
|
||||
# check mount
|
||||
@@ -119,15 +134,19 @@ def _mount_nfs_uri(provider_uri, mount_path, auto_mount: bool = False):
|
||||
_is_mount = False
|
||||
while _check_level_num:
|
||||
with subprocess.Popen(
|
||||
'mount | grep "{}"'.format(_remote_uri),
|
||||
shell=True,
|
||||
["mount"],
|
||||
text=True,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
) as shell_r:
|
||||
_command_log = shell_r.stdout.readlines()
|
||||
_command_log = [line for line in _command_log if _remote_uri in line]
|
||||
if len(_command_log) > 0:
|
||||
for _c in _command_log:
|
||||
_temp_mount = _c.decode("utf-8").split(" ")[2]
|
||||
if isinstance(_c, str):
|
||||
_temp_mount = _c.split(" ")[2]
|
||||
else:
|
||||
_temp_mount = _c.decode("utf-8").split(" ")[2]
|
||||
_temp_mount = _temp_mount[:-1] if _temp_mount.endswith("/") else _temp_mount
|
||||
if _temp_mount == _mount_path:
|
||||
_is_mount = True
|
||||
@@ -152,16 +171,16 @@ def _mount_nfs_uri(provider_uri, mount_path, auto_mount: bool = False):
|
||||
if not command_res:
|
||||
raise OSError("nfs-common is not found, please install it by execute: sudo apt install nfs-common")
|
||||
# manually mount
|
||||
command_status = os.system(mount_command)
|
||||
if command_status == 256:
|
||||
raise OSError(
|
||||
f"mount {provider_uri} on {mount_path} error! Needs SUDO! Please mount manually: {mount_command}"
|
||||
)
|
||||
elif command_status == 32512:
|
||||
# LOG.error("Command error")
|
||||
raise OSError(f"mount {provider_uri} on {mount_path} error! Command error")
|
||||
elif command_status == 0:
|
||||
LOG.info("Mount finished")
|
||||
try:
|
||||
subprocess.run(mount_command, check=True, capture_output=True, text=True)
|
||||
LOG.info("Mount finished.")
|
||||
except subprocess.CalledProcessError as e:
|
||||
if e.returncode == 256:
|
||||
raise OSError("Mount failed: requires sudo or permission denied") from e
|
||||
elif e.returncode == 32512:
|
||||
raise OSError(f"mount {provider_uri} on {mount_path} error! Command error") from e
|
||||
else:
|
||||
raise OSError(f"Mount failed: {e.stderr}") from e
|
||||
else:
|
||||
LOG.warning(f"{_remote_uri} on {_mount_path} is already mounted")
|
||||
|
||||
@@ -176,7 +195,8 @@ def init_from_yaml_conf(conf_path, **kwargs):
|
||||
config = {}
|
||||
else:
|
||||
with open(conf_path) as f:
|
||||
config = yaml.safe_load(f)
|
||||
yaml = YAML(typ="safe", pure=True)
|
||||
config = yaml.load(f)
|
||||
config.update(kwargs)
|
||||
default_conf = config.pop("default_conf", "client")
|
||||
init(default_conf, **config)
|
||||
@@ -272,7 +292,8 @@ def auto_init(**kwargs):
|
||||
logger = get_module_logger("Initialization")
|
||||
conf_pp = pp / "config.yaml"
|
||||
with conf_pp.open() as f:
|
||||
conf = yaml.safe_load(f)
|
||||
yaml = YAML(typ="safe", pure=True)
|
||||
conf = yaml.load(f)
|
||||
|
||||
conf_type = conf.get("conf_type", "origin")
|
||||
if conf_type == "origin":
|
||||
|
||||
@@ -18,7 +18,6 @@ from tqdm.auto import tqdm
|
||||
|
||||
from ..utils.time import Freq
|
||||
|
||||
|
||||
PORT_METRIC = Dict[str, Tuple[pd.DataFrame, dict]]
|
||||
INDICATOR_METRIC = Dict[str, Tuple[pd.DataFrame, Indicator]]
|
||||
|
||||
|
||||
@@ -897,6 +897,7 @@ class Exchange:
|
||||
# if we don't know current position, we choose to sell all
|
||||
# Otherwise, we clip the amount based on current position
|
||||
if position is not None:
|
||||
# TODO: make the trading shortable
|
||||
current_amount = (
|
||||
position.get_stock_amount(order.stock_id) if position.check_stock(order.stock_id) else 0
|
||||
)
|
||||
|
||||
@@ -104,7 +104,7 @@ class PandasQuote(BaseQuote):
|
||||
def __init__(self, quote_df: pd.DataFrame, freq: str) -> None:
|
||||
super().__init__(quote_df=quote_df, freq=freq)
|
||||
quote_dict = {}
|
||||
for stock_id, stock_val in quote_df.groupby(level="instrument"):
|
||||
for stock_id, stock_val in quote_df.groupby(level="instrument", group_keys=False):
|
||||
quote_dict[stock_id] = stock_val.droplevel(level="instrument")
|
||||
self.data = quote_dict
|
||||
|
||||
@@ -137,7 +137,7 @@ class NumpyQuote(BaseQuote):
|
||||
"""
|
||||
super().__init__(quote_df=quote_df, freq=freq)
|
||||
quote_dict = {}
|
||||
for stock_id, stock_val in quote_df.groupby(level="instrument"):
|
||||
for stock_id, stock_val in quote_df.groupby(level="instrument", group_keys=False):
|
||||
quote_dict[stock_id] = idd.MultiData(stock_val.droplevel(level="instrument"))
|
||||
quote_dict[stock_id].sort_index() # To support more flexible slicing, we must sort data first
|
||||
self.data = quote_dict
|
||||
@@ -278,7 +278,7 @@ class BaseSingleMetric:
|
||||
raise NotImplementedError(f"Please implement the `empty` method")
|
||||
|
||||
def add(self, other: BaseSingleMetric, fill_value: float = None) -> BaseSingleMetric:
|
||||
"""Replace np.NaN with fill_value in two metrics and add them."""
|
||||
"""Replace np.nan with fill_value in two metrics and add them."""
|
||||
|
||||
raise NotImplementedError(f"Please implement the `add` method")
|
||||
|
||||
@@ -412,7 +412,7 @@ class BaseOrderIndicator:
|
||||
metrics : Union[str, List[str]]
|
||||
all metrics needs to be sumed.
|
||||
fill_value : float, optional
|
||||
fill np.NaN with value. By default None.
|
||||
fill np.nan with value. By default None.
|
||||
"""
|
||||
|
||||
raise NotImplementedError(f"Please implement the 'sum_all_indicators' method")
|
||||
|
||||
@@ -311,7 +311,7 @@ class Position(BasePosition):
|
||||
freq=freq,
|
||||
disk_cache=True,
|
||||
).dropna()
|
||||
price_dict = price_df.groupby(["instrument"]).tail(1).reset_index(level=1, drop=True)["$close"].to_dict()
|
||||
price_dict = price_df.groupby(["instrument"], group_keys=False).tail(1)["$close"].to_dict()
|
||||
|
||||
if len(price_dict) < len(stock_list):
|
||||
lack_stock = set(stock_list) - set(price_dict)
|
||||
|
||||
@@ -281,13 +281,13 @@ def brinson_pa(
|
||||
|
||||
stock_group_field = stock_df[group_field].unstack().T
|
||||
# FIXME: some attributes of some suspend stock is NAN.
|
||||
stock_group_field = stock_group_field.fillna(method="ffill")
|
||||
stock_group_field = stock_group_field.ffill()
|
||||
stock_group_field = stock_group_field.loc[start_date:end_date]
|
||||
|
||||
stock_group = get_stock_group(stock_group_field, bench_stock_weight, group_method, group_n)
|
||||
|
||||
deal_price_df = stock_df["deal_price"].unstack().T
|
||||
deal_price_df = deal_price_df.fillna(method="ffill")
|
||||
deal_price_df = deal_price_df.ffill()
|
||||
|
||||
# NOTE:
|
||||
# The return will be slightly different from the of the return in the report.
|
||||
|
||||
@@ -114,7 +114,11 @@ class PortfolioMetrics:
|
||||
_temp_result, _ = get_higher_eq_freq_feature(_codes, fields, start_time, end_time, freq=freq)
|
||||
if len(_temp_result) == 0:
|
||||
raise ValueError(f"The benchmark {_codes} does not exist. Please provide the right benchmark")
|
||||
return _temp_result.groupby(level="datetime")[_temp_result.columns.tolist()[0]].mean().fillna(0)
|
||||
return (
|
||||
_temp_result.groupby(level="datetime", group_keys=False)[_temp_result.columns.tolist()[0]]
|
||||
.mean()
|
||||
.fillna(0)
|
||||
)
|
||||
|
||||
def _sample_benchmark(
|
||||
self,
|
||||
@@ -325,9 +329,9 @@ class Indicator:
|
||||
|
||||
def _update_order_fulfill_rate(self) -> None:
|
||||
def func(deal_amount, amount):
|
||||
# deal_amount is np.NaN or None when there is no inner decision. So full fill rate is 0.
|
||||
# deal_amount is np.nan or None when there is no inner decision. So full fill rate is 0.
|
||||
tmp_deal_amount = deal_amount.reindex(amount.index, 0)
|
||||
tmp_deal_amount = tmp_deal_amount.replace({np.NaN: 0})
|
||||
tmp_deal_amount = tmp_deal_amount.replace({np.nan: 0})
|
||||
return tmp_deal_amount / amount
|
||||
|
||||
self.order_indicator.transfer(func, "ffr")
|
||||
@@ -354,8 +358,8 @@ class Indicator:
|
||||
)
|
||||
|
||||
def func(trade_price, deal_amount):
|
||||
# trade_price is np.NaN instead of inf when deal_amount is zero.
|
||||
tmp_deal_amount = deal_amount.replace({0: np.NaN})
|
||||
# trade_price is np.nan instead of inf when deal_amount is zero.
|
||||
tmp_deal_amount = deal_amount.replace({0: np.nan})
|
||||
return trade_price / tmp_deal_amount
|
||||
|
||||
self.order_indicator.transfer(func, "trade_price")
|
||||
@@ -425,7 +429,11 @@ class Indicator:
|
||||
assert isinstance(price_s, idd.SingleData)
|
||||
price_s = price_s.loc[(price_s > 1e-08).data.astype(bool)]
|
||||
# NOTE ~(price_s < 1e-08) is different from price_s >= 1e-8
|
||||
# ~(np.NaN < 1e-8) -> ~(False) -> True
|
||||
# ~(np.nan < 1e-8) -> ~(False) -> True
|
||||
|
||||
# if price_s is empty
|
||||
if price_s.empty:
|
||||
return None, None
|
||||
|
||||
assert isinstance(price_s, idd.SingleData)
|
||||
if agg == "vwap":
|
||||
|
||||
0
qlib/cli/__init__.py
Normal file
0
qlib/cli/__init__.py
Normal file
@@ -4,6 +4,5 @@
|
||||
import fire
|
||||
from qlib.tests.data import GetData
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(GetData)
|
||||
@@ -1,18 +1,20 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
import logging
|
||||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
import sys
|
||||
|
||||
import fire
|
||||
from jinja2 import Template, meta
|
||||
from ruamel.yaml import YAML
|
||||
|
||||
import qlib
|
||||
import fire
|
||||
import ruamel.yaml as yaml
|
||||
from qlib.config import C
|
||||
from qlib.model.trainer import task_train
|
||||
from qlib.utils.data import update_config
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.model.trainer import task_train
|
||||
from qlib.utils import set_log_with_config
|
||||
from qlib.utils.data import update_config
|
||||
|
||||
set_log_with_config(C.logging_config)
|
||||
logger = get_module_logger("qrun", logging.INFO)
|
||||
@@ -47,12 +49,45 @@ def sys_config(config, config_path):
|
||||
sys.path.append(str(Path(config_path).parent.resolve().absolute() / p))
|
||||
|
||||
|
||||
def render_template(config_path: str) -> str:
|
||||
"""
|
||||
render the template based on the environment
|
||||
|
||||
Parameters
|
||||
----------
|
||||
config_path : str
|
||||
configuration path
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
the rendered content
|
||||
"""
|
||||
with open(config_path, "r") as f:
|
||||
config = f.read()
|
||||
# Set up the Jinja2 environment
|
||||
template = Template(config)
|
||||
|
||||
# Parse the template to find undeclared variables
|
||||
env = template.environment
|
||||
parsed_content = env.parse(config)
|
||||
variables = meta.find_undeclared_variables(parsed_content)
|
||||
|
||||
# Get context from os.environ according to the variables
|
||||
context = {var: os.getenv(var, "") for var in variables if var in os.environ}
|
||||
logger.info(f"Render the template with the context: {context}")
|
||||
|
||||
# Render the template with the context
|
||||
rendered_content = template.render(context)
|
||||
return rendered_content
|
||||
|
||||
|
||||
# workflow handler function
|
||||
def workflow(config_path, experiment_name="workflow", uri_folder="mlruns"):
|
||||
"""
|
||||
This is a Qlib CLI entrance.
|
||||
User can run the whole Quant research workflow defined by a configure file
|
||||
- the code is located here ``qlib/workflow/cli.py`
|
||||
- the code is located here ``qlib/cli/run.py``
|
||||
|
||||
User can specify a base_config file in your workflow.yml file by adding "BASE_CONFIG_PATH".
|
||||
Qlib will load the configuration in BASE_CONFIG_PATH first, and the user only needs to update the custom fields
|
||||
@@ -67,8 +102,10 @@ def workflow(config_path, experiment_name="workflow", uri_folder="mlruns"):
|
||||
market: csi300
|
||||
|
||||
"""
|
||||
with open(config_path) as fp:
|
||||
config = yaml.safe_load(fp)
|
||||
# Render the template
|
||||
rendered_yaml = render_template(config_path)
|
||||
yaml = YAML(typ="safe", pure=True)
|
||||
config = yaml.load(rendered_yaml)
|
||||
|
||||
base_config_path = config.get("BASE_CONFIG_PATH", None)
|
||||
if base_config_path:
|
||||
@@ -90,7 +127,8 @@ def workflow(config_path, experiment_name="workflow", uri_folder="mlruns"):
|
||||
raise FileNotFoundError(f"Can't find the BASE_CONFIG file: {base_config_path}")
|
||||
|
||||
with open(path) as fp:
|
||||
base_config = yaml.safe_load(fp)
|
||||
yaml = YAML(typ="safe", pure=True)
|
||||
base_config = yaml.load(fp)
|
||||
logger.info(f"Load BASE_CONFIG_PATH succeed: {path.resolve()}")
|
||||
config = update_config(base_config, config)
|
||||
|
||||
@@ -10,6 +10,7 @@ Two modes are supported
|
||||
- server
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
@@ -27,6 +28,38 @@ from qlib.constant import REG_CN, REG_US, REG_TW
|
||||
if TYPE_CHECKING:
|
||||
from qlib.utils.time import Freq
|
||||
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class MLflowSettings(BaseSettings):
|
||||
uri: str = "file:" + str(Path(os.getcwd()).resolve() / "mlruns")
|
||||
default_exp_name: str = "Experiment"
|
||||
|
||||
|
||||
class QSettings(BaseSettings):
|
||||
"""
|
||||
Qlib's settings.
|
||||
It tries to provide a default settings for most of Qlib's components.
|
||||
But it would be a long journey to provide a comprehensive settings for all of Qlib's components.
|
||||
|
||||
Here is some design guidelines:
|
||||
- The priority of settings is
|
||||
- Actively passed-in settings, like `qlib.init(provider_uri=...)`
|
||||
- The default settings
|
||||
- QSettings tries to provide default settings for most of Qlib's components.
|
||||
"""
|
||||
|
||||
mlflow: MLflowSettings = MLflowSettings()
|
||||
provider_uri: str = "~/.qlib/qlib_data/cn_data"
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
env_prefix="QLIB_",
|
||||
env_nested_delimiter="_",
|
||||
)
|
||||
|
||||
|
||||
QSETTINGS = QSettings()
|
||||
|
||||
|
||||
class Config:
|
||||
def __init__(self, default_conf):
|
||||
@@ -173,7 +206,11 @@ _default_config = {
|
||||
"filters": ["field_not_found"],
|
||||
}
|
||||
},
|
||||
"loggers": {"qlib": {"level": logging.DEBUG, "handlers": ["console"]}},
|
||||
# Normally this should be set to `False` to avoid duplicated logging [1].
|
||||
# However, due to bug in pytest, it requires log message to propagate to root logger to be captured by `caplog` [2].
|
||||
# [1] https://github.com/microsoft/qlib/pull/1661
|
||||
# [2] https://github.com/pytest-dev/pytest/issues/3697
|
||||
"loggers": {"qlib": {"level": logging.DEBUG, "handlers": ["console"], "propagate": False}},
|
||||
# To let qlib work with other packages, we shouldn't disable existing loggers.
|
||||
# Note that this param is default to True according to the documentation of logging.
|
||||
"disable_existing_loggers": False,
|
||||
@@ -183,8 +220,8 @@ _default_config = {
|
||||
"class": "MLflowExpManager",
|
||||
"module_path": "qlib.workflow.expm",
|
||||
"kwargs": {
|
||||
"uri": "file:" + str(Path(os.getcwd()).resolve() / "mlruns"),
|
||||
"default_exp_name": "Experiment",
|
||||
"uri": QSETTINGS.mlflow.uri,
|
||||
"default_exp_name": QSETTINGS.mlflow.default_exp_name,
|
||||
},
|
||||
},
|
||||
"pit_record_type": {
|
||||
@@ -226,7 +263,7 @@ MODE_CONF = {
|
||||
},
|
||||
"client": {
|
||||
# config it in user's own code
|
||||
"provider_uri": "~/.qlib/qlib_data/cn_data",
|
||||
"provider_uri": QSETTINGS.provider_uri,
|
||||
# cache
|
||||
# Using parameter 'remote' to announce the client is using server_cache, and the writing access will be disabled.
|
||||
# Disable cache by default. Avoid introduce advanced features for beginners
|
||||
|
||||
@@ -6,10 +6,11 @@ import torch
|
||||
import warnings
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from qlib.utils.data import guess_horizon
|
||||
from qlib.utils import init_instance_by_config
|
||||
|
||||
from qlib.data.dataset import DatasetH
|
||||
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
|
||||
@@ -32,7 +33,7 @@ def _create_ts_slices(index, seq_len):
|
||||
assert index.is_monotonic_increasing, "index should be sorted"
|
||||
|
||||
# number of dates for each instrument
|
||||
sample_count_by_insts = index.to_series().groupby(level=0).size().values
|
||||
sample_count_by_insts = index.to_series().groupby(level=0, group_keys=False).size().values
|
||||
|
||||
# start index for each instrument
|
||||
start_index_of_insts = np.roll(np.cumsum(sample_count_by_insts), 1)
|
||||
@@ -130,6 +131,14 @@ class MTSDatasetH(DatasetH):
|
||||
input_size=None,
|
||||
**kwargs,
|
||||
):
|
||||
if horizon == 0:
|
||||
# Try to guess horizon
|
||||
if isinstance(handler, (dict, str)):
|
||||
handler = init_instance_by_config(handler)
|
||||
assert "label" in getattr(handler.data_loader, "fields", None)
|
||||
label = handler.data_loader.fields["label"][0][0]
|
||||
horizon = guess_horizon([label])
|
||||
|
||||
assert num_states == 0 or horizon > 0, "please specify `horizon` to avoid data leakage"
|
||||
assert memory_mode in ["sample", "daily"], "unsupported memory mode"
|
||||
assert memory_mode == "sample" or batch_size < 0, "daily memory requires daily sampling (`batch_size < 0`)"
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from qlib.contrib.data.loader import Alpha158DL, Alpha360DL
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
from ...data.dataset.processor import Processor
|
||||
from ...utils import get_callable_kwargs
|
||||
@@ -57,7 +58,7 @@ class Alpha360(DataHandlerLP):
|
||||
fit_end_time=None,
|
||||
filter_pipe=None,
|
||||
inst_processors=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
|
||||
learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)
|
||||
@@ -66,7 +67,7 @@ class Alpha360(DataHandlerLP):
|
||||
"class": "QlibDataLoader",
|
||||
"kwargs": {
|
||||
"config": {
|
||||
"feature": self.get_feature_config(),
|
||||
"feature": Alpha360DL.get_feature_config(),
|
||||
"label": kwargs.pop("label", self.get_label_config()),
|
||||
},
|
||||
"filter_pipe": filter_pipe,
|
||||
@@ -82,57 +83,12 @@ class Alpha360(DataHandlerLP):
|
||||
data_loader=data_loader,
|
||||
learn_processors=learn_processors,
|
||||
infer_processors=infer_processors,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def get_label_config(self):
|
||||
return ["Ref($close, -2)/Ref($close, -1) - 1"], ["LABEL0"]
|
||||
|
||||
@staticmethod
|
||||
def get_feature_config():
|
||||
# NOTE:
|
||||
# Alpha360 tries to provide a dataset with original price data
|
||||
# the original price data includes the prices and volume in the last 60 days.
|
||||
# To make it easier to learn models from this dataset, all the prices and volume
|
||||
# are normalized by the latest price and volume data ( dividing by $close, $volume)
|
||||
# So the latest normalized $close will be 1 (with name CLOSE0), the latest normalized $volume will be 1 (with name VOLUME0)
|
||||
# If further normalization are executed (e.g. centralization), CLOSE0 and VOLUME0 will be 0.
|
||||
fields = []
|
||||
names = []
|
||||
|
||||
for i in range(59, 0, -1):
|
||||
fields += ["Ref($close, %d)/$close" % i]
|
||||
names += ["CLOSE%d" % i]
|
||||
fields += ["$close/$close"]
|
||||
names += ["CLOSE0"]
|
||||
for i in range(59, 0, -1):
|
||||
fields += ["Ref($open, %d)/$close" % i]
|
||||
names += ["OPEN%d" % i]
|
||||
fields += ["$open/$close"]
|
||||
names += ["OPEN0"]
|
||||
for i in range(59, 0, -1):
|
||||
fields += ["Ref($high, %d)/$close" % i]
|
||||
names += ["HIGH%d" % i]
|
||||
fields += ["$high/$close"]
|
||||
names += ["HIGH0"]
|
||||
for i in range(59, 0, -1):
|
||||
fields += ["Ref($low, %d)/$close" % i]
|
||||
names += ["LOW%d" % i]
|
||||
fields += ["$low/$close"]
|
||||
names += ["LOW0"]
|
||||
for i in range(59, 0, -1):
|
||||
fields += ["Ref($vwap, %d)/$close" % i]
|
||||
names += ["VWAP%d" % i]
|
||||
fields += ["$vwap/$close"]
|
||||
names += ["VWAP0"]
|
||||
for i in range(59, 0, -1):
|
||||
fields += ["Ref($volume, %d)/($volume+1e-12)" % i]
|
||||
names += ["VOLUME%d" % i]
|
||||
fields += ["$volume/($volume+1e-12)"]
|
||||
names += ["VOLUME0"]
|
||||
|
||||
return fields, names
|
||||
|
||||
|
||||
class Alpha360vwap(Alpha360):
|
||||
def get_label_config(self):
|
||||
@@ -153,7 +109,7 @@ class Alpha158(DataHandlerLP):
|
||||
process_type=DataHandlerLP.PTYPE_A,
|
||||
filter_pipe=None,
|
||||
inst_processors=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
|
||||
learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)
|
||||
@@ -178,7 +134,7 @@ class Alpha158(DataHandlerLP):
|
||||
infer_processors=infer_processors,
|
||||
learn_processors=learn_processors,
|
||||
process_type=process_type,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def get_feature_config(self):
|
||||
@@ -190,242 +146,11 @@ class Alpha158(DataHandlerLP):
|
||||
},
|
||||
"rolling": {},
|
||||
}
|
||||
return self.parse_config_to_fields(conf)
|
||||
return Alpha158DL.get_feature_config(conf)
|
||||
|
||||
def get_label_config(self):
|
||||
return ["Ref($close, -2)/Ref($close, -1) - 1"], ["LABEL0"]
|
||||
|
||||
@staticmethod
|
||||
def parse_config_to_fields(config):
|
||||
"""create factors from config
|
||||
|
||||
config = {
|
||||
'kbar': {}, # whether to use some hard-code kbar features
|
||||
'price': { # whether to use raw price features
|
||||
'windows': [0, 1, 2, 3, 4], # use price at n days ago
|
||||
'feature': ['OPEN', 'HIGH', 'LOW'] # which price field to use
|
||||
},
|
||||
'volume': { # whether to use raw volume features
|
||||
'windows': [0, 1, 2, 3, 4], # use volume at n days ago
|
||||
},
|
||||
'rolling': { # whether to use rolling operator based features
|
||||
'windows': [5, 10, 20, 30, 60], # rolling windows size
|
||||
'include': ['ROC', 'MA', 'STD'], # rolling operator to use
|
||||
#if include is None we will use default operators
|
||||
'exclude': ['RANK'], # rolling operator not to use
|
||||
}
|
||||
}
|
||||
"""
|
||||
fields = []
|
||||
names = []
|
||||
if "kbar" in config:
|
||||
fields += [
|
||||
"($close-$open)/$open",
|
||||
"($high-$low)/$open",
|
||||
"($close-$open)/($high-$low+1e-12)",
|
||||
"($high-Greater($open, $close))/$open",
|
||||
"($high-Greater($open, $close))/($high-$low+1e-12)",
|
||||
"(Less($open, $close)-$low)/$open",
|
||||
"(Less($open, $close)-$low)/($high-$low+1e-12)",
|
||||
"(2*$close-$high-$low)/$open",
|
||||
"(2*$close-$high-$low)/($high-$low+1e-12)",
|
||||
]
|
||||
names += [
|
||||
"KMID",
|
||||
"KLEN",
|
||||
"KMID2",
|
||||
"KUP",
|
||||
"KUP2",
|
||||
"KLOW",
|
||||
"KLOW2",
|
||||
"KSFT",
|
||||
"KSFT2",
|
||||
]
|
||||
if "price" in config:
|
||||
windows = config["price"].get("windows", range(5))
|
||||
feature = config["price"].get("feature", ["OPEN", "HIGH", "LOW", "CLOSE", "VWAP"])
|
||||
for field in feature:
|
||||
field = field.lower()
|
||||
fields += ["Ref($%s, %d)/$close" % (field, d) if d != 0 else "$%s/$close" % field for d in windows]
|
||||
names += [field.upper() + str(d) for d in windows]
|
||||
if "volume" in config:
|
||||
windows = config["volume"].get("windows", range(5))
|
||||
fields += ["Ref($volume, %d)/($volume+1e-12)" % d if d != 0 else "$volume/($volume+1e-12)" for d in windows]
|
||||
names += ["VOLUME" + str(d) for d in windows]
|
||||
if "rolling" in config:
|
||||
windows = config["rolling"].get("windows", [5, 10, 20, 30, 60])
|
||||
include = config["rolling"].get("include", None)
|
||||
exclude = config["rolling"].get("exclude", [])
|
||||
# `exclude` in dataset config unnecessary filed
|
||||
# `include` in dataset config necessary field
|
||||
|
||||
def use(x):
|
||||
return x not in exclude and (include is None or x in include)
|
||||
|
||||
# Some factor ref: https://guorn.com/static/upload/file/3/134065454575605.pdf
|
||||
if use("ROC"):
|
||||
# https://www.investopedia.com/terms/r/rateofchange.asp
|
||||
# Rate of change, the price change in the past d days, divided by latest close price to remove unit
|
||||
fields += ["Ref($close, %d)/$close" % d for d in windows]
|
||||
names += ["ROC%d" % d for d in windows]
|
||||
if use("MA"):
|
||||
# https://www.investopedia.com/ask/answers/071414/whats-difference-between-moving-average-and-weighted-moving-average.asp
|
||||
# Simple Moving Average, the simple moving average in the past d days, divided by latest close price to remove unit
|
||||
fields += ["Mean($close, %d)/$close" % d for d in windows]
|
||||
names += ["MA%d" % d for d in windows]
|
||||
if use("STD"):
|
||||
# The standard diviation of close price for the past d days, divided by latest close price to remove unit
|
||||
fields += ["Std($close, %d)/$close" % d for d in windows]
|
||||
names += ["STD%d" % d for d in windows]
|
||||
if use("BETA"):
|
||||
# The rate of close price change in the past d days, divided by latest close price to remove unit
|
||||
# For example, price increase 10 dollar per day in the past d days, then Slope will be 10.
|
||||
fields += ["Slope($close, %d)/$close" % d for d in windows]
|
||||
names += ["BETA%d" % d for d in windows]
|
||||
if use("RSQR"):
|
||||
# The R-sqaure value of linear regression for the past d days, represent the trend linear
|
||||
fields += ["Rsquare($close, %d)" % d for d in windows]
|
||||
names += ["RSQR%d" % d for d in windows]
|
||||
if use("RESI"):
|
||||
# The redisdual for linear regression for the past d days, represent the trend linearity for past d days.
|
||||
fields += ["Resi($close, %d)/$close" % d for d in windows]
|
||||
names += ["RESI%d" % d for d in windows]
|
||||
if use("MAX"):
|
||||
# The max price for past d days, divided by latest close price to remove unit
|
||||
fields += ["Max($high, %d)/$close" % d for d in windows]
|
||||
names += ["MAX%d" % d for d in windows]
|
||||
if use("LOW"):
|
||||
# The low price for past d days, divided by latest close price to remove unit
|
||||
fields += ["Min($low, %d)/$close" % d for d in windows]
|
||||
names += ["MIN%d" % d for d in windows]
|
||||
if use("QTLU"):
|
||||
# The 80% quantile of past d day's close price, divided by latest close price to remove unit
|
||||
# Used with MIN and MAX
|
||||
fields += ["Quantile($close, %d, 0.8)/$close" % d for d in windows]
|
||||
names += ["QTLU%d" % d for d in windows]
|
||||
if use("QTLD"):
|
||||
# The 20% quantile of past d day's close price, divided by latest close price to remove unit
|
||||
fields += ["Quantile($close, %d, 0.2)/$close" % d for d in windows]
|
||||
names += ["QTLD%d" % d for d in windows]
|
||||
if use("RANK"):
|
||||
# Get the percentile of current close price in past d day's close price.
|
||||
# Represent the current price level comparing to past N days, add additional information to moving average.
|
||||
fields += ["Rank($close, %d)" % d for d in windows]
|
||||
names += ["RANK%d" % d for d in windows]
|
||||
if use("RSV"):
|
||||
# Represent the price position between upper and lower resistent price for past d days.
|
||||
fields += ["($close-Min($low, %d))/(Max($high, %d)-Min($low, %d)+1e-12)" % (d, d, d) for d in windows]
|
||||
names += ["RSV%d" % d for d in windows]
|
||||
if use("IMAX"):
|
||||
# The number of days between current date and previous highest price date.
|
||||
# Part of Aroon Indicator https://www.investopedia.com/terms/a/aroon.asp
|
||||
# The indicator measures the time between highs and the time between lows over a time period.
|
||||
# The idea is that strong uptrends will regularly see new highs, and strong downtrends will regularly see new lows.
|
||||
fields += ["IdxMax($high, %d)/%d" % (d, d) for d in windows]
|
||||
names += ["IMAX%d" % d for d in windows]
|
||||
if use("IMIN"):
|
||||
# The number of days between current date and previous lowest price date.
|
||||
# Part of Aroon Indicator https://www.investopedia.com/terms/a/aroon.asp
|
||||
# The indicator measures the time between highs and the time between lows over a time period.
|
||||
# The idea is that strong uptrends will regularly see new highs, and strong downtrends will regularly see new lows.
|
||||
fields += ["IdxMin($low, %d)/%d" % (d, d) for d in windows]
|
||||
names += ["IMIN%d" % d for d in windows]
|
||||
if use("IMXD"):
|
||||
# The time period between previous lowest-price date occur after highest price date.
|
||||
# Large value suggest downward momemtum.
|
||||
fields += ["(IdxMax($high, %d)-IdxMin($low, %d))/%d" % (d, d, d) for d in windows]
|
||||
names += ["IMXD%d" % d for d in windows]
|
||||
if use("CORR"):
|
||||
# The correlation between absolute close price and log scaled trading volume
|
||||
fields += ["Corr($close, Log($volume+1), %d)" % d for d in windows]
|
||||
names += ["CORR%d" % d for d in windows]
|
||||
if use("CORD"):
|
||||
# The correlation between price change ratio and volume change ratio
|
||||
fields += ["Corr($close/Ref($close,1), Log($volume/Ref($volume, 1)+1), %d)" % d for d in windows]
|
||||
names += ["CORD%d" % d for d in windows]
|
||||
if use("CNTP"):
|
||||
# The percentage of days in past d days that price go up.
|
||||
fields += ["Mean($close>Ref($close, 1), %d)" % d for d in windows]
|
||||
names += ["CNTP%d" % d for d in windows]
|
||||
if use("CNTN"):
|
||||
# The percentage of days in past d days that price go down.
|
||||
fields += ["Mean($close<Ref($close, 1), %d)" % d for d in windows]
|
||||
names += ["CNTN%d" % d for d in windows]
|
||||
if use("CNTD"):
|
||||
# The diff between past up day and past down day
|
||||
fields += ["Mean($close>Ref($close, 1), %d)-Mean($close<Ref($close, 1), %d)" % (d, d) for d in windows]
|
||||
names += ["CNTD%d" % d for d in windows]
|
||||
if use("SUMP"):
|
||||
# The total gain / the absolute total price changed
|
||||
# Similar to RSI indicator. https://www.investopedia.com/terms/r/rsi.asp
|
||||
fields += [
|
||||
"Sum(Greater($close-Ref($close, 1), 0), %d)/(Sum(Abs($close-Ref($close, 1)), %d)+1e-12)" % (d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["SUMP%d" % d for d in windows]
|
||||
if use("SUMN"):
|
||||
# The total lose / the absolute total price changed
|
||||
# Can be derived from SUMP by SUMN = 1 - SUMP
|
||||
# Similar to RSI indicator. https://www.investopedia.com/terms/r/rsi.asp
|
||||
fields += [
|
||||
"Sum(Greater(Ref($close, 1)-$close, 0), %d)/(Sum(Abs($close-Ref($close, 1)), %d)+1e-12)" % (d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["SUMN%d" % d for d in windows]
|
||||
if use("SUMD"):
|
||||
# The diff ratio between total gain and total lose
|
||||
# Similar to RSI indicator. https://www.investopedia.com/terms/r/rsi.asp
|
||||
fields += [
|
||||
"(Sum(Greater($close-Ref($close, 1), 0), %d)-Sum(Greater(Ref($close, 1)-$close, 0), %d))"
|
||||
"/(Sum(Abs($close-Ref($close, 1)), %d)+1e-12)" % (d, d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["SUMD%d" % d for d in windows]
|
||||
if use("VMA"):
|
||||
# Simple Volume Moving average: https://www.barchart.com/education/technical-indicators/volume_moving_average
|
||||
fields += ["Mean($volume, %d)/($volume+1e-12)" % d for d in windows]
|
||||
names += ["VMA%d" % d for d in windows]
|
||||
if use("VSTD"):
|
||||
# The standard deviation for volume in past d days.
|
||||
fields += ["Std($volume, %d)/($volume+1e-12)" % d for d in windows]
|
||||
names += ["VSTD%d" % d for d in windows]
|
||||
if use("WVMA"):
|
||||
# The volume weighted price change volatility
|
||||
fields += [
|
||||
"Std(Abs($close/Ref($close, 1)-1)*$volume, %d)/(Mean(Abs($close/Ref($close, 1)-1)*$volume, %d)+1e-12)"
|
||||
% (d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["WVMA%d" % d for d in windows]
|
||||
if use("VSUMP"):
|
||||
# The total volume increase / the absolute total volume changed
|
||||
fields += [
|
||||
"Sum(Greater($volume-Ref($volume, 1), 0), %d)/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)"
|
||||
% (d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["VSUMP%d" % d for d in windows]
|
||||
if use("VSUMN"):
|
||||
# The total volume increase / the absolute total volume changed
|
||||
# Can be derived from VSUMP by VSUMN = 1 - VSUMP
|
||||
fields += [
|
||||
"Sum(Greater(Ref($volume, 1)-$volume, 0), %d)/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)"
|
||||
% (d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["VSUMN%d" % d for d in windows]
|
||||
if use("VSUMD"):
|
||||
# The diff ratio between total volume increase and total volume decrease
|
||||
# RSI indicator for volume
|
||||
fields += [
|
||||
"(Sum(Greater($volume-Ref($volume, 1), 0), %d)-Sum(Greater(Ref($volume, 1)-$volume, 0), %d))"
|
||||
"/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)" % (d, d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["VSUMD%d" % d for d in windows]
|
||||
|
||||
return fields, names
|
||||
|
||||
|
||||
class Alpha158vwap(Alpha158):
|
||||
def get_label_config(self):
|
||||
|
||||
310
qlib/contrib/data/loader.py
Normal file
310
qlib/contrib/data/loader.py
Normal file
@@ -0,0 +1,310 @@
|
||||
from qlib.data.dataset.loader import QlibDataLoader
|
||||
|
||||
|
||||
class Alpha360DL(QlibDataLoader):
|
||||
"""Dataloader to get Alpha360"""
|
||||
|
||||
def __init__(self, config=None, **kwargs):
|
||||
_config = {
|
||||
"feature": self.get_feature_config(),
|
||||
}
|
||||
if config is not None:
|
||||
_config.update(config)
|
||||
super().__init__(config=_config, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def get_feature_config():
|
||||
# NOTE:
|
||||
# Alpha360 tries to provide a dataset with original price data
|
||||
# the original price data includes the prices and volume in the last 60 days.
|
||||
# To make it easier to learn models from this dataset, all the prices and volume
|
||||
# are normalized by the latest price and volume data ( dividing by $close, $volume)
|
||||
# So the latest normalized $close will be 1 (with name CLOSE0), the latest normalized $volume will be 1 (with name VOLUME0)
|
||||
# If further normalization are executed (e.g. centralization), CLOSE0 and VOLUME0 will be 0.
|
||||
fields = []
|
||||
names = []
|
||||
|
||||
for i in range(59, 0, -1):
|
||||
fields += ["Ref($close, %d)/$close" % i]
|
||||
names += ["CLOSE%d" % i]
|
||||
fields += ["$close/$close"]
|
||||
names += ["CLOSE0"]
|
||||
for i in range(59, 0, -1):
|
||||
fields += ["Ref($open, %d)/$close" % i]
|
||||
names += ["OPEN%d" % i]
|
||||
fields += ["$open/$close"]
|
||||
names += ["OPEN0"]
|
||||
for i in range(59, 0, -1):
|
||||
fields += ["Ref($high, %d)/$close" % i]
|
||||
names += ["HIGH%d" % i]
|
||||
fields += ["$high/$close"]
|
||||
names += ["HIGH0"]
|
||||
for i in range(59, 0, -1):
|
||||
fields += ["Ref($low, %d)/$close" % i]
|
||||
names += ["LOW%d" % i]
|
||||
fields += ["$low/$close"]
|
||||
names += ["LOW0"]
|
||||
for i in range(59, 0, -1):
|
||||
fields += ["Ref($vwap, %d)/$close" % i]
|
||||
names += ["VWAP%d" % i]
|
||||
fields += ["$vwap/$close"]
|
||||
names += ["VWAP0"]
|
||||
for i in range(59, 0, -1):
|
||||
fields += ["Ref($volume, %d)/($volume+1e-12)" % i]
|
||||
names += ["VOLUME%d" % i]
|
||||
fields += ["$volume/($volume+1e-12)"]
|
||||
names += ["VOLUME0"]
|
||||
|
||||
return fields, names
|
||||
|
||||
|
||||
class Alpha158DL(QlibDataLoader):
|
||||
"""Dataloader to get Alpha158"""
|
||||
|
||||
def __init__(self, config=None, **kwargs):
|
||||
_config = {
|
||||
"feature": self.get_feature_config(),
|
||||
}
|
||||
if config is not None:
|
||||
_config.update(config)
|
||||
super().__init__(config=_config, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def get_feature_config(
|
||||
config={
|
||||
"kbar": {},
|
||||
"price": {
|
||||
"windows": [0],
|
||||
"feature": ["OPEN", "HIGH", "LOW", "VWAP"],
|
||||
},
|
||||
"rolling": {},
|
||||
}
|
||||
):
|
||||
"""create factors from config
|
||||
|
||||
config = {
|
||||
'kbar': {}, # whether to use some hard-code kbar features
|
||||
'price': { # whether to use raw price features
|
||||
'windows': [0, 1, 2, 3, 4], # use price at n days ago
|
||||
'feature': ['OPEN', 'HIGH', 'LOW'] # which price field to use
|
||||
},
|
||||
'volume': { # whether to use raw volume features
|
||||
'windows': [0, 1, 2, 3, 4], # use volume at n days ago
|
||||
},
|
||||
'rolling': { # whether to use rolling operator based features
|
||||
'windows': [5, 10, 20, 30, 60], # rolling windows size
|
||||
'include': ['ROC', 'MA', 'STD'], # rolling operator to use
|
||||
#if include is None we will use default operators
|
||||
'exclude': ['RANK'], # rolling operator not to use
|
||||
}
|
||||
}
|
||||
"""
|
||||
fields = []
|
||||
names = []
|
||||
if "kbar" in config:
|
||||
fields += [
|
||||
"($close-$open)/$open",
|
||||
"($high-$low)/$open",
|
||||
"($close-$open)/($high-$low+1e-12)",
|
||||
"($high-Greater($open, $close))/$open",
|
||||
"($high-Greater($open, $close))/($high-$low+1e-12)",
|
||||
"(Less($open, $close)-$low)/$open",
|
||||
"(Less($open, $close)-$low)/($high-$low+1e-12)",
|
||||
"(2*$close-$high-$low)/$open",
|
||||
"(2*$close-$high-$low)/($high-$low+1e-12)",
|
||||
]
|
||||
names += [
|
||||
"KMID",
|
||||
"KLEN",
|
||||
"KMID2",
|
||||
"KUP",
|
||||
"KUP2",
|
||||
"KLOW",
|
||||
"KLOW2",
|
||||
"KSFT",
|
||||
"KSFT2",
|
||||
]
|
||||
if "price" in config:
|
||||
windows = config["price"].get("windows", range(5))
|
||||
feature = config["price"].get("feature", ["OPEN", "HIGH", "LOW", "CLOSE", "VWAP"])
|
||||
for field in feature:
|
||||
field = field.lower()
|
||||
fields += ["Ref($%s, %d)/$close" % (field, d) if d != 0 else "$%s/$close" % field for d in windows]
|
||||
names += [field.upper() + str(d) for d in windows]
|
||||
if "volume" in config:
|
||||
windows = config["volume"].get("windows", range(5))
|
||||
fields += ["Ref($volume, %d)/($volume+1e-12)" % d if d != 0 else "$volume/($volume+1e-12)" for d in windows]
|
||||
names += ["VOLUME" + str(d) for d in windows]
|
||||
if "rolling" in config:
|
||||
windows = config["rolling"].get("windows", [5, 10, 20, 30, 60])
|
||||
include = config["rolling"].get("include", None)
|
||||
exclude = config["rolling"].get("exclude", [])
|
||||
# `exclude` in dataset config unnecessary filed
|
||||
# `include` in dataset config necessary field
|
||||
|
||||
def use(x):
|
||||
return x not in exclude and (include is None or x in include)
|
||||
|
||||
# Some factor ref: https://guorn.com/static/upload/file/3/134065454575605.pdf
|
||||
if use("ROC"):
|
||||
# https://www.investopedia.com/terms/r/rateofchange.asp
|
||||
# Rate of change, the price change in the past d days, divided by latest close price to remove unit
|
||||
fields += ["Ref($close, %d)/$close" % d for d in windows]
|
||||
names += ["ROC%d" % d for d in windows]
|
||||
if use("MA"):
|
||||
# https://www.investopedia.com/ask/answers/071414/whats-difference-between-moving-average-and-weighted-moving-average.asp
|
||||
# Simple Moving Average, the simple moving average in the past d days, divided by latest close price to remove unit
|
||||
fields += ["Mean($close, %d)/$close" % d for d in windows]
|
||||
names += ["MA%d" % d for d in windows]
|
||||
if use("STD"):
|
||||
# The standard diviation of close price for the past d days, divided by latest close price to remove unit
|
||||
fields += ["Std($close, %d)/$close" % d for d in windows]
|
||||
names += ["STD%d" % d for d in windows]
|
||||
if use("BETA"):
|
||||
# The rate of close price change in the past d days, divided by latest close price to remove unit
|
||||
# For example, price increase 10 dollar per day in the past d days, then Slope will be 10.
|
||||
fields += ["Slope($close, %d)/$close" % d for d in windows]
|
||||
names += ["BETA%d" % d for d in windows]
|
||||
if use("RSQR"):
|
||||
# The R-sqaure value of linear regression for the past d days, represent the trend linear
|
||||
fields += ["Rsquare($close, %d)" % d for d in windows]
|
||||
names += ["RSQR%d" % d for d in windows]
|
||||
if use("RESI"):
|
||||
# The redisdual for linear regression for the past d days, represent the trend linearity for past d days.
|
||||
fields += ["Resi($close, %d)/$close" % d for d in windows]
|
||||
names += ["RESI%d" % d for d in windows]
|
||||
if use("MAX"):
|
||||
# The max price for past d days, divided by latest close price to remove unit
|
||||
fields += ["Max($high, %d)/$close" % d for d in windows]
|
||||
names += ["MAX%d" % d for d in windows]
|
||||
if use("LOW"):
|
||||
# The low price for past d days, divided by latest close price to remove unit
|
||||
fields += ["Min($low, %d)/$close" % d for d in windows]
|
||||
names += ["MIN%d" % d for d in windows]
|
||||
if use("QTLU"):
|
||||
# The 80% quantile of past d day's close price, divided by latest close price to remove unit
|
||||
# Used with MIN and MAX
|
||||
fields += ["Quantile($close, %d, 0.8)/$close" % d for d in windows]
|
||||
names += ["QTLU%d" % d for d in windows]
|
||||
if use("QTLD"):
|
||||
# The 20% quantile of past d day's close price, divided by latest close price to remove unit
|
||||
fields += ["Quantile($close, %d, 0.2)/$close" % d for d in windows]
|
||||
names += ["QTLD%d" % d for d in windows]
|
||||
if use("RANK"):
|
||||
# Get the percentile of current close price in past d day's close price.
|
||||
# Represent the current price level comparing to past N days, add additional information to moving average.
|
||||
fields += ["Rank($close, %d)" % d for d in windows]
|
||||
names += ["RANK%d" % d for d in windows]
|
||||
if use("RSV"):
|
||||
# Represent the price position between upper and lower resistent price for past d days.
|
||||
fields += ["($close-Min($low, %d))/(Max($high, %d)-Min($low, %d)+1e-12)" % (d, d, d) for d in windows]
|
||||
names += ["RSV%d" % d for d in windows]
|
||||
if use("IMAX"):
|
||||
# The number of days between current date and previous highest price date.
|
||||
# Part of Aroon Indicator https://www.investopedia.com/terms/a/aroon.asp
|
||||
# The indicator measures the time between highs and the time between lows over a time period.
|
||||
# The idea is that strong uptrends will regularly see new highs, and strong downtrends will regularly see new lows.
|
||||
fields += ["IdxMax($high, %d)/%d" % (d, d) for d in windows]
|
||||
names += ["IMAX%d" % d for d in windows]
|
||||
if use("IMIN"):
|
||||
# The number of days between current date and previous lowest price date.
|
||||
# Part of Aroon Indicator https://www.investopedia.com/terms/a/aroon.asp
|
||||
# The indicator measures the time between highs and the time between lows over a time period.
|
||||
# The idea is that strong uptrends will regularly see new highs, and strong downtrends will regularly see new lows.
|
||||
fields += ["IdxMin($low, %d)/%d" % (d, d) for d in windows]
|
||||
names += ["IMIN%d" % d for d in windows]
|
||||
if use("IMXD"):
|
||||
# The time period between previous lowest-price date occur after highest price date.
|
||||
# Large value suggest downward momemtum.
|
||||
fields += ["(IdxMax($high, %d)-IdxMin($low, %d))/%d" % (d, d, d) for d in windows]
|
||||
names += ["IMXD%d" % d for d in windows]
|
||||
if use("CORR"):
|
||||
# The correlation between absolute close price and log scaled trading volume
|
||||
fields += ["Corr($close, Log($volume+1), %d)" % d for d in windows]
|
||||
names += ["CORR%d" % d for d in windows]
|
||||
if use("CORD"):
|
||||
# The correlation between price change ratio and volume change ratio
|
||||
fields += ["Corr($close/Ref($close,1), Log($volume/Ref($volume, 1)+1), %d)" % d for d in windows]
|
||||
names += ["CORD%d" % d for d in windows]
|
||||
if use("CNTP"):
|
||||
# The percentage of days in past d days that price go up.
|
||||
fields += ["Mean($close>Ref($close, 1), %d)" % d for d in windows]
|
||||
names += ["CNTP%d" % d for d in windows]
|
||||
if use("CNTN"):
|
||||
# The percentage of days in past d days that price go down.
|
||||
fields += ["Mean($close<Ref($close, 1), %d)" % d for d in windows]
|
||||
names += ["CNTN%d" % d for d in windows]
|
||||
if use("CNTD"):
|
||||
# The diff between past up day and past down day
|
||||
fields += ["Mean($close>Ref($close, 1), %d)-Mean($close<Ref($close, 1), %d)" % (d, d) for d in windows]
|
||||
names += ["CNTD%d" % d for d in windows]
|
||||
if use("SUMP"):
|
||||
# The total gain / the absolute total price changed
|
||||
# Similar to RSI indicator. https://www.investopedia.com/terms/r/rsi.asp
|
||||
fields += [
|
||||
"Sum(Greater($close-Ref($close, 1), 0), %d)/(Sum(Abs($close-Ref($close, 1)), %d)+1e-12)" % (d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["SUMP%d" % d for d in windows]
|
||||
if use("SUMN"):
|
||||
# The total lose / the absolute total price changed
|
||||
# Can be derived from SUMP by SUMN = 1 - SUMP
|
||||
# Similar to RSI indicator. https://www.investopedia.com/terms/r/rsi.asp
|
||||
fields += [
|
||||
"Sum(Greater(Ref($close, 1)-$close, 0), %d)/(Sum(Abs($close-Ref($close, 1)), %d)+1e-12)" % (d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["SUMN%d" % d for d in windows]
|
||||
if use("SUMD"):
|
||||
# The diff ratio between total gain and total lose
|
||||
# Similar to RSI indicator. https://www.investopedia.com/terms/r/rsi.asp
|
||||
fields += [
|
||||
"(Sum(Greater($close-Ref($close, 1), 0), %d)-Sum(Greater(Ref($close, 1)-$close, 0), %d))"
|
||||
"/(Sum(Abs($close-Ref($close, 1)), %d)+1e-12)" % (d, d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["SUMD%d" % d for d in windows]
|
||||
if use("VMA"):
|
||||
# Simple Volume Moving average: https://www.barchart.com/education/technical-indicators/volume_moving_average
|
||||
fields += ["Mean($volume, %d)/($volume+1e-12)" % d for d in windows]
|
||||
names += ["VMA%d" % d for d in windows]
|
||||
if use("VSTD"):
|
||||
# The standard deviation for volume in past d days.
|
||||
fields += ["Std($volume, %d)/($volume+1e-12)" % d for d in windows]
|
||||
names += ["VSTD%d" % d for d in windows]
|
||||
if use("WVMA"):
|
||||
# The volume weighted price change volatility
|
||||
fields += [
|
||||
"Std(Abs($close/Ref($close, 1)-1)*$volume, %d)/(Mean(Abs($close/Ref($close, 1)-1)*$volume, %d)+1e-12)"
|
||||
% (d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["WVMA%d" % d for d in windows]
|
||||
if use("VSUMP"):
|
||||
# The total volume increase / the absolute total volume changed
|
||||
fields += [
|
||||
"Sum(Greater($volume-Ref($volume, 1), 0), %d)/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)"
|
||||
% (d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["VSUMP%d" % d for d in windows]
|
||||
if use("VSUMN"):
|
||||
# The total volume increase / the absolute total volume changed
|
||||
# Can be derived from VSUMP by VSUMN = 1 - VSUMP
|
||||
fields += [
|
||||
"Sum(Greater(Ref($volume, 1)-$volume, 0), %d)/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)"
|
||||
% (d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["VSUMN%d" % d for d in windows]
|
||||
if use("VSUMD"):
|
||||
# The diff ratio between total volume increase and total volume decrease
|
||||
# RSI indicator for volume
|
||||
fields += [
|
||||
"(Sum(Greater($volume-Ref($volume, 1), 0), %d)-Sum(Greater(Ref($volume, 1)-$volume, 0), %d))"
|
||||
"/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)" % (d, d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["VSUMD%d" % d for d in windows]
|
||||
|
||||
return fields, names
|
||||
@@ -55,14 +55,18 @@ class ConfigSectionProcessor(Processor):
|
||||
|
||||
# Label
|
||||
cols = df_focus.columns[df_focus.columns.str.contains("^LABEL")]
|
||||
df_focus[cols] = df_focus[cols].groupby(level="datetime").apply(_label_norm)
|
||||
df_focus[cols] = df_focus[cols].groupby(level="datetime", group_keys=False).apply(_label_norm)
|
||||
|
||||
# Features
|
||||
cols = df_focus.columns[df_focus.columns.str.contains("^KLEN|^KLOW|^KUP")]
|
||||
df_focus[cols] = df_focus[cols].apply(lambda x: x**0.25).groupby(level="datetime").apply(_feature_norm)
|
||||
df_focus[cols] = (
|
||||
df_focus[cols].apply(lambda x: x**0.25).groupby(level="datetime", group_keys=False).apply(_feature_norm)
|
||||
)
|
||||
|
||||
cols = df_focus.columns[df_focus.columns.str.contains("^KLOW2|^KUP2")]
|
||||
df_focus[cols] = df_focus[cols].apply(lambda x: x**0.5).groupby(level="datetime").apply(_feature_norm)
|
||||
df_focus[cols] = (
|
||||
df_focus[cols].apply(lambda x: x**0.5).groupby(level="datetime", group_keys=False).apply(_feature_norm)
|
||||
)
|
||||
|
||||
_cols = [
|
||||
"KMID",
|
||||
@@ -88,25 +92,35 @@ class ConfigSectionProcessor(Processor):
|
||||
]
|
||||
pat = "|".join(["^" + x for x in _cols])
|
||||
cols = df_focus.columns[df_focus.columns.str.contains(pat) & (~df_focus.columns.isin(["HIGH0", "LOW0"]))]
|
||||
df_focus[cols] = df_focus[cols].groupby(level="datetime").apply(_feature_norm)
|
||||
df_focus[cols] = df_focus[cols].groupby(level="datetime", group_keys=False).apply(_feature_norm)
|
||||
|
||||
cols = df_focus.columns[df_focus.columns.str.contains("^STD|^VOLUME|^VMA|^VSTD")]
|
||||
df_focus[cols] = df_focus[cols].apply(np.log).groupby(level="datetime").apply(_feature_norm)
|
||||
df_focus[cols] = df_focus[cols].apply(np.log).groupby(level="datetime", group_keys=False).apply(_feature_norm)
|
||||
|
||||
cols = df_focus.columns[df_focus.columns.str.contains("^RSQR")]
|
||||
df_focus[cols] = df_focus[cols].fillna(0).groupby(level="datetime").apply(_feature_norm)
|
||||
df_focus[cols] = df_focus[cols].fillna(0).groupby(level="datetime", group_keys=False).apply(_feature_norm)
|
||||
|
||||
cols = df_focus.columns[df_focus.columns.str.contains("^MAX|^HIGH0")]
|
||||
df_focus[cols] = df_focus[cols].apply(lambda x: (x - 1) ** 0.5).groupby(level="datetime").apply(_feature_norm)
|
||||
df_focus[cols] = (
|
||||
df_focus[cols]
|
||||
.apply(lambda x: (x - 1) ** 0.5)
|
||||
.groupby(level="datetime", group_keys=False)
|
||||
.apply(_feature_norm)
|
||||
)
|
||||
|
||||
cols = df_focus.columns[df_focus.columns.str.contains("^MIN|^LOW0")]
|
||||
df_focus[cols] = df_focus[cols].apply(lambda x: (1 - x) ** 0.5).groupby(level="datetime").apply(_feature_norm)
|
||||
df_focus[cols] = (
|
||||
df_focus[cols]
|
||||
.apply(lambda x: (1 - x) ** 0.5)
|
||||
.groupby(level="datetime", group_keys=False)
|
||||
.apply(_feature_norm)
|
||||
)
|
||||
|
||||
cols = df_focus.columns[df_focus.columns.str.contains("^CORR|^CORD")]
|
||||
df_focus[cols] = df_focus[cols].apply(np.exp).groupby(level="datetime").apply(_feature_norm)
|
||||
df_focus[cols] = df_focus[cols].apply(np.exp).groupby(level="datetime", group_keys=False).apply(_feature_norm)
|
||||
|
||||
cols = df_focus.columns[df_focus.columns.str.contains("^WVMA")]
|
||||
df_focus[cols] = df_focus[cols].apply(np.log1p).groupby(level="datetime").apply(_feature_norm)
|
||||
df_focus[cols] = df_focus[cols].apply(np.log1p).groupby(level="datetime", group_keys=False).apply(_feature_norm)
|
||||
|
||||
df[selected_cols] = df_focus.values
|
||||
|
||||
|
||||
@@ -39,7 +39,7 @@ def calc_long_short_prec(
|
||||
long precision and short precision in time level
|
||||
"""
|
||||
if is_alpha:
|
||||
label = label - label.mean(level=date_col)
|
||||
label = label - label.groupby(level=date_col, group_keys=False).mean()
|
||||
if int(1 / quantile) >= len(label.index.get_level_values(1).unique()):
|
||||
raise ValueError("Need more instruments to calculate precision")
|
||||
|
||||
@@ -47,23 +47,25 @@ def calc_long_short_prec(
|
||||
if dropna:
|
||||
df.dropna(inplace=True)
|
||||
|
||||
group = df.groupby(level=date_col)
|
||||
group = df.groupby(level=date_col, group_keys=False)
|
||||
|
||||
def N(x):
|
||||
return int(len(x) * quantile)
|
||||
|
||||
# find the top/low quantile of prediction and treat them as long and short target
|
||||
long = group.apply(lambda x: x.nlargest(N(x), columns="pred").label).reset_index(level=0, drop=True)
|
||||
short = group.apply(lambda x: x.nsmallest(N(x), columns="pred").label).reset_index(level=0, drop=True)
|
||||
long = group.apply(lambda x: x.nlargest(N(x), columns="pred").label)
|
||||
short = group.apply(lambda x: x.nsmallest(N(x), columns="pred").label)
|
||||
|
||||
groupll = long.groupby(date_col)
|
||||
groupll = long.groupby(date_col, group_keys=False)
|
||||
l_dom = groupll.apply(lambda x: x > 0)
|
||||
l_c = groupll.count()
|
||||
|
||||
groups = short.groupby(date_col)
|
||||
groups = short.groupby(date_col, group_keys=False)
|
||||
s_dom = groups.apply(lambda x: x < 0)
|
||||
s_c = groups.count()
|
||||
return (l_dom.groupby(date_col).sum() / l_c), (s_dom.groupby(date_col).sum() / s_c)
|
||||
return (l_dom.groupby(date_col, group_keys=False).sum() / l_c), (
|
||||
s_dom.groupby(date_col, group_keys=False).sum() / s_c
|
||||
)
|
||||
|
||||
|
||||
def calc_long_short_return(
|
||||
@@ -100,7 +102,7 @@ def calc_long_short_return(
|
||||
df = pd.DataFrame({"pred": pred, "label": label})
|
||||
if dropna:
|
||||
df.dropna(inplace=True)
|
||||
group = df.groupby(level=date_col)
|
||||
group = df.groupby(level=date_col, group_keys=False)
|
||||
|
||||
def N(x):
|
||||
return int(len(x) * quantile)
|
||||
@@ -173,8 +175,8 @@ def calc_ic(pred: pd.Series, label: pd.Series, date_col="datetime", dropna=False
|
||||
ic and rank ic
|
||||
"""
|
||||
df = pd.DataFrame({"pred": pred, "label": label})
|
||||
ic = df.groupby(date_col).apply(lambda df: df["pred"].corr(df["label"]))
|
||||
ric = df.groupby(date_col).apply(lambda df: df["pred"].corr(df["label"], method="spearman"))
|
||||
ic = df.groupby(date_col, group_keys=False).apply(lambda df: df["pred"].corr(df["label"]))
|
||||
ric = df.groupby(date_col, group_keys=False).apply(lambda df: df["pred"].corr(df["label"], method="spearman"))
|
||||
if dropna:
|
||||
return ic.dropna(), ric.dropna()
|
||||
else:
|
||||
|
||||
@@ -7,7 +7,7 @@ from __future__ import print_function
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import warnings
|
||||
from typing import Union
|
||||
from typing import Union, Literal
|
||||
|
||||
from ..log import get_module_logger
|
||||
from ..utils import get_date_range
|
||||
@@ -20,20 +20,17 @@ from ..data import D
|
||||
from ..config import C
|
||||
from ..data.dataset.utils import get_level_index
|
||||
|
||||
|
||||
logger = get_module_logger("Evaluate")
|
||||
|
||||
|
||||
def risk_analysis(r, N: int = None, freq: str = "day"):
|
||||
def risk_analysis(r, N: int = None, freq: str = "day", mode: Literal["sum", "product"] = "sum"):
|
||||
"""Risk Analysis
|
||||
NOTE:
|
||||
The calculation of annulaized return is different from the definition of annualized return.
|
||||
The calculation of annualized return is different from the definition of annualized return.
|
||||
It is implemented by design.
|
||||
Qlib tries to cumulated returns by summation instead of production to avoid the cumulated curve being skewed exponentially.
|
||||
Qlib tries to cumulate returns by summation instead of production to avoid the cumulated curve being skewed exponentially.
|
||||
All the calculation of annualized returns follows this principle in Qlib.
|
||||
|
||||
TODO: add a parameter to enable calculating metrics with production accumulation of return.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
r : pandas.Series
|
||||
@@ -42,11 +39,14 @@ def risk_analysis(r, N: int = None, freq: str = "day"):
|
||||
scaler for annualizing information_ratio (day: 252, week: 50, month: 12), at least one of `N` and `freq` should exist
|
||||
freq: str
|
||||
analysis frequency used for calculating the scaler, at least one of `N` and `freq` should exist
|
||||
mode: Literal["sum", "product"]
|
||||
the method by which returns are accumulated:
|
||||
- "sum": Arithmetic accumulation (linear returns).
|
||||
- "product": Geometric accumulation (compounded returns).
|
||||
"""
|
||||
|
||||
def cal_risk_analysis_scaler(freq):
|
||||
_count, _freq = Freq.parse(freq)
|
||||
# len(D.calendar(start_time='2010-01-01', end_time='2019-12-31', freq='day')) = 2384
|
||||
_freq_scaler = {
|
||||
Freq.NORM_FREQ_MINUTE: 240 * 238,
|
||||
Freq.NORM_FREQ_DAY: 238,
|
||||
@@ -62,11 +62,26 @@ def risk_analysis(r, N: int = None, freq: str = "day"):
|
||||
if N is None:
|
||||
N = cal_risk_analysis_scaler(freq)
|
||||
|
||||
mean = r.mean()
|
||||
std = r.std(ddof=1)
|
||||
annualized_return = mean * N
|
||||
if mode == "sum":
|
||||
mean = r.mean()
|
||||
std = r.std(ddof=1)
|
||||
annualized_return = mean * N
|
||||
max_drawdown = (r.cumsum() - r.cumsum().cummax()).min()
|
||||
elif mode == "product":
|
||||
cumulative_curve = (1 + r).cumprod()
|
||||
# geometric mean (compound annual growth rate)
|
||||
mean = cumulative_curve.iloc[-1] ** (1 / len(r)) - 1
|
||||
# volatility of log returns
|
||||
std = np.log(1 + r).std(ddof=1)
|
||||
|
||||
cumulative_return = cumulative_curve.iloc[-1] - 1
|
||||
annualized_return = (1 + cumulative_return) ** (N / len(r)) - 1
|
||||
# max percentage drawdown from peak cumulative product
|
||||
max_drawdown = (cumulative_curve / cumulative_curve.cummax() - 1).min()
|
||||
else:
|
||||
raise ValueError(f"risk_analysis accumulation mode {mode} is not supported. Expected `sum` or `product`.")
|
||||
|
||||
information_ratio = mean / std * np.sqrt(N)
|
||||
max_drawdown = (r.cumsum() - r.cumsum().cummax()).min()
|
||||
data = {
|
||||
"mean": mean,
|
||||
"std": std,
|
||||
|
||||
@@ -3,5 +3,4 @@
|
||||
|
||||
from .data_selection import MetaTaskDS, MetaDatasetDS, MetaModelDS
|
||||
|
||||
|
||||
__all__ = ["MetaTaskDS", "MetaDatasetDS", "MetaModelDS"]
|
||||
|
||||
@@ -4,5 +4,4 @@
|
||||
from .dataset import MetaDatasetDS, MetaTaskDS
|
||||
from .model import MetaModelDS
|
||||
|
||||
|
||||
__all__ = ["MetaDatasetDS", "MetaTaskDS", "MetaModelDS"]
|
||||
|
||||
@@ -106,7 +106,7 @@ class InternalData:
|
||||
|
||||
def _calc_perf(self, pred, label):
|
||||
df = pd.DataFrame({"pred": pred, "label": label})
|
||||
df = df.groupby("datetime").corr(method="spearman")
|
||||
df = df.groupby("datetime", group_keys=False).corr(method="spearman")
|
||||
corr = df.loc(axis=0)[:, "pred"]["label"].droplevel(axis=0, level=-1)
|
||||
return corr
|
||||
|
||||
@@ -161,7 +161,7 @@ class MetaTaskDS(MetaTask):
|
||||
raise ValueError(f"Most of samples are dropped. Please check this task: {task}")
|
||||
|
||||
assert (
|
||||
d_test.groupby("datetime").size().shape[0] >= 5
|
||||
d_test.groupby("datetime", group_keys=False).size().shape[0] >= 5
|
||||
), "In this segment, this trading dates is less than 5, you'd better check the data."
|
||||
|
||||
sample_time_belong = np.zeros((d_train.shape[0], time_perf.shape[1]))
|
||||
@@ -243,7 +243,7 @@ class MetaDatasetDS(MetaTaskDataset):
|
||||
trunc_days: int = None,
|
||||
rolling_ext_days: int = 0,
|
||||
exp_name: Union[str, InternalData],
|
||||
segments: Union[Dict[Text, Tuple], float],
|
||||
segments: Union[Dict[Text, Tuple], float, str],
|
||||
hist_step_n: int = 10,
|
||||
task_mode: str = MetaTask.PROC_MODE_FULL,
|
||||
fill_method: str = "max",
|
||||
@@ -271,12 +271,16 @@ class MetaDatasetDS(MetaTaskDataset):
|
||||
- str: the name of the experiment to store the performance of data
|
||||
- InternalData: a prepared internal data
|
||||
segments: Union[Dict[Text, Tuple], float]
|
||||
the segments to divide data
|
||||
both left and right
|
||||
if the segment is a Dict
|
||||
the segments to divide data
|
||||
both left and right are included
|
||||
if segments is a float:
|
||||
the float represents the percentage of data for training
|
||||
if segments is a string:
|
||||
it will try its best to put its data in training and ensure that the date `segments` is in the test set
|
||||
hist_step_n: int
|
||||
length of historical steps for the meta infomation
|
||||
Number of steps of the data similarity information
|
||||
task_mode : str
|
||||
Please refer to the docs of MetaTask
|
||||
"""
|
||||
@@ -383,10 +387,30 @@ class MetaDatasetDS(MetaTaskDataset):
|
||||
if isinstance(self.segments, float):
|
||||
train_task_n = int(len(self.meta_task_l) * self.segments)
|
||||
if segment == "train":
|
||||
return self.meta_task_l[:train_task_n]
|
||||
train_tasks = self.meta_task_l[:train_task_n]
|
||||
get_module_logger("MetaDatasetDS").info(f"The first train meta task: {train_tasks[0]}")
|
||||
return train_tasks
|
||||
elif segment == "test":
|
||||
return self.meta_task_l[train_task_n:]
|
||||
test_tasks = self.meta_task_l[train_task_n:]
|
||||
get_module_logger("MetaDatasetDS").info(f"The first test meta task: {test_tasks[0]}")
|
||||
return test_tasks
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
elif isinstance(self.segments, str):
|
||||
train_tasks = []
|
||||
test_tasks = []
|
||||
for t in self.meta_task_l:
|
||||
test_end = t.task["dataset"]["kwargs"]["segments"]["test"][1]
|
||||
if test_end is None or pd.Timestamp(test_end) < pd.Timestamp(self.segments):
|
||||
train_tasks.append(t)
|
||||
else:
|
||||
test_tasks.append(t)
|
||||
get_module_logger("MetaDatasetDS").info(f"The first train meta task: {train_tasks[0]}")
|
||||
get_module_logger("MetaDatasetDS").info(f"The first test meta task: {test_tasks[0]}")
|
||||
if segment == "train":
|
||||
return train_tasks
|
||||
elif segment == "test":
|
||||
return test_tasks
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
|
||||
@@ -53,7 +53,12 @@ class MetaModelDS(MetaTaskModel):
|
||||
max_epoch=100,
|
||||
seed=43,
|
||||
alpha=0.0,
|
||||
loss_skip_thresh=50,
|
||||
):
|
||||
"""
|
||||
loss_skip_size: int
|
||||
The number of threshold to skip the loss calculation for each day.
|
||||
"""
|
||||
self.step = step
|
||||
self.hist_step_n = hist_step_n
|
||||
self.clip_method = clip_method
|
||||
@@ -63,6 +68,7 @@ class MetaModelDS(MetaTaskModel):
|
||||
self.max_epoch = max_epoch
|
||||
self.fitted = False
|
||||
self.alpha = alpha
|
||||
self.loss_skip_thresh = loss_skip_thresh
|
||||
torch.manual_seed(seed)
|
||||
|
||||
def run_epoch(self, phase, task_list, epoch, opt, loss_l, ignore_weight=False):
|
||||
@@ -88,12 +94,14 @@ class MetaModelDS(MetaTaskModel):
|
||||
criterion = nn.MSELoss()
|
||||
loss = criterion(pred, meta_input["y_test"])
|
||||
elif self.criterion == "ic_loss":
|
||||
criterion = ICLoss()
|
||||
criterion = ICLoss(self.loss_skip_thresh)
|
||||
try:
|
||||
loss = criterion(pred, meta_input["y_test"], meta_input["test_idx"], skip_size=50)
|
||||
loss = criterion(pred, meta_input["y_test"], meta_input["test_idx"])
|
||||
except ValueError as e:
|
||||
get_module_logger("MetaModelDS").warning(f"Exception `{e}` when calculating IC loss")
|
||||
continue
|
||||
else:
|
||||
raise ValueError(f"Unknown criterion: {self.criterion}")
|
||||
|
||||
assert not np.isnan(loss.detach().item()), "NaN loss!"
|
||||
|
||||
@@ -117,7 +125,11 @@ class MetaModelDS(MetaTaskModel):
|
||||
loss_l.setdefault(phase, []).append(running_loss)
|
||||
|
||||
pred_y_all = pd.concat(pred_y_all)
|
||||
ic = pred_y_all.groupby("datetime").apply(lambda df: df["pred"].corr(df["label"], method="spearman")).mean()
|
||||
ic = (
|
||||
pred_y_all.groupby("datetime", group_keys=False)
|
||||
.apply(lambda df: df["pred"].corr(df["label"], method="spearman"))
|
||||
.mean()
|
||||
)
|
||||
|
||||
R.log_metrics(**{f"loss/{phase}": running_loss, "step": epoch})
|
||||
R.log_metrics(**{f"ic/{phase}": ic, "step": epoch})
|
||||
|
||||
@@ -10,7 +10,11 @@ from qlib.log import get_module_logger
|
||||
|
||||
|
||||
class ICLoss(nn.Module):
|
||||
def forward(self, pred, y, idx, skip_size=50):
|
||||
def __init__(self, skip_size=50):
|
||||
super().__init__()
|
||||
self.skip_size = skip_size
|
||||
|
||||
def forward(self, pred, y, idx):
|
||||
"""forward.
|
||||
FIXME:
|
||||
- Some times it will be a slightly different from the result from `pandas.corr()`
|
||||
@@ -33,7 +37,7 @@ class ICLoss(nn.Module):
|
||||
skip_n = 0
|
||||
for start_i, end_i in zip(diff_point, diff_point[1:]):
|
||||
pred_focus = pred[start_i:end_i] # TODO: just for fake
|
||||
if pred_focus.shape[0] < skip_size:
|
||||
if pred_focus.shape[0] < self.skip_size:
|
||||
# skip some days which have very small amount of stock.
|
||||
skip_n += 1
|
||||
continue
|
||||
@@ -50,6 +54,7 @@ class ICLoss(nn.Module):
|
||||
)
|
||||
ic_all += ic_day
|
||||
if len(diff_point) - 1 - skip_n <= 0:
|
||||
__import__("ipdb").set_trace()
|
||||
raise ValueError("No enough data for calculating IC")
|
||||
if skip_n > 0:
|
||||
get_module_logger("ICLoss").info(
|
||||
|
||||
@@ -33,7 +33,7 @@ class CatBoostModel(Model, FeatureInt):
|
||||
verbose_eval=20,
|
||||
evals_result=dict(),
|
||||
reweighter=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
df_train, df_valid = dataset.prepare(
|
||||
["train", "valid"],
|
||||
|
||||
@@ -31,7 +31,7 @@ class DEnsembleModel(Model, FeatureInt):
|
||||
sub_weights=None,
|
||||
epochs=100,
|
||||
early_stopping_rounds=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
self.base_model = base_model # "gbm" or "mlp", specifically, we use lgbm for "gbm"
|
||||
self.num_models = num_models # the number of sub-models
|
||||
@@ -166,7 +166,7 @@ class DEnsembleModel(Model, FeatureInt):
|
||||
|
||||
# calculate weights
|
||||
h["bins"] = pd.cut(h["h_value"], self.bins_sr)
|
||||
h_avg = h.groupby("bins")["h_value"].mean()
|
||||
h_avg = h.groupby("bins", group_keys=False, observed=False)["h_value"].mean()
|
||||
weights = pd.Series(np.zeros(N, dtype=float))
|
||||
for b in h_avg.index:
|
||||
weights[h["bins"] == b] = 1.0 / (self.decay**k_th * h_avg[b] + 0.1)
|
||||
|
||||
@@ -51,7 +51,7 @@ class LGBModel(ModelFT, LightGBMFInt):
|
||||
w = reweighter.reweight(df)
|
||||
else:
|
||||
raise ValueError("Unsupported reweighter type.")
|
||||
ds_l.append((lgb.Dataset(x.values, label=y, weight=w), key))
|
||||
ds_l.append((lgb.Dataset(x.values, label=y, weight=w, free_raw_data=False), key))
|
||||
return ds_l
|
||||
|
||||
def fit(
|
||||
@@ -109,8 +109,10 @@ class LGBModel(ModelFT, LightGBMFInt):
|
||||
verbose level
|
||||
"""
|
||||
# Based on existing model and finetune by train more rounds
|
||||
dtrain, _ = self._prepare_data(dataset, reweighter) # pylint: disable=W0632
|
||||
if dtrain.empty:
|
||||
ds_l = self._prepare_data(dataset, reweighter)
|
||||
dtrain, _ = ds_l[0]
|
||||
|
||||
if dtrain.construct().num_data() == 0:
|
||||
raise ValueError("Empty data from dataset, please check your dataset config.")
|
||||
verbose_eval_callback = lgb.log_evaluation(period=verbose_eval)
|
||||
self.model = lgb.train(
|
||||
|
||||
@@ -90,8 +90,14 @@ class HFLGBModel(ModelFT, LightGBMFInt):
|
||||
if y_train.values.ndim == 2 and y_train.values.shape[1] == 1:
|
||||
l_name = df_train["label"].columns[0]
|
||||
# Convert label into alpha
|
||||
df_train["label"][l_name] = df_train["label"][l_name] - df_train["label"][l_name].mean(level=0)
|
||||
df_valid["label"][l_name] = df_valid["label"][l_name] - df_valid["label"][l_name].mean(level=0)
|
||||
df_train.loc[:, ("label", l_name)] = (
|
||||
df_train.loc[:, ("label", l_name)]
|
||||
- df_train.loc[:, ("label", l_name)].groupby(level=0, group_keys=False).mean()
|
||||
)
|
||||
df_valid.loc[:, ("label", l_name)] = (
|
||||
df_valid.loc[:, ("label", l_name)]
|
||||
- df_valid.loc[:, ("label", l_name)].groupby(level=0, group_keys=False).mean()
|
||||
)
|
||||
|
||||
def mapping_fn(x):
|
||||
return 0 if x < 0 else 1
|
||||
|
||||
@@ -63,6 +63,7 @@ class LinearModel(Model):
|
||||
df_train = pd.concat([df_train, df_valid])
|
||||
except KeyError:
|
||||
get_module_logger("LinearModel").info("include_valid=True, but valid does not exist")
|
||||
df_train = df_train.dropna()
|
||||
if df_train.empty:
|
||||
raise ValueError("Empty data from dataset, please check your dataset config.")
|
||||
if reweighter is not None:
|
||||
|
||||
@@ -56,7 +56,7 @@ class ADARNN(Model):
|
||||
n_splits=2,
|
||||
GPU=0,
|
||||
seed=None,
|
||||
**_
|
||||
**_,
|
||||
):
|
||||
# Set logger.
|
||||
self.logger = get_module_logger("ADARNN")
|
||||
@@ -154,10 +154,7 @@ class ADARNN(Model):
|
||||
self.model.train()
|
||||
criterion = nn.MSELoss()
|
||||
dist_mat = torch.zeros(self.num_layers, self.len_seq).to(self.device)
|
||||
len_loader = np.inf
|
||||
for loader in train_loader_list:
|
||||
if len(loader) < len_loader:
|
||||
len_loader = len(loader)
|
||||
out_weight_list = None
|
||||
for data_all in zip(*train_loader_list):
|
||||
# for data_all in zip(*train_loader_list):
|
||||
self.train_optimizer.zero_grad()
|
||||
@@ -217,8 +214,10 @@ class ADARNN(Model):
|
||||
def calc_all_metrics(pred):
|
||||
"""pred is a pandas dataframe that has two attributes: score (pred) and label (real)"""
|
||||
res = {}
|
||||
ic = pred.groupby(level="datetime").apply(lambda x: x.label.corr(x.score))
|
||||
rank_ic = pred.groupby(level="datetime").apply(lambda x: x.label.corr(x.score, method="spearman"))
|
||||
ic = pred.groupby(level="datetime", group_keys=False).apply(lambda x: x.label.corr(x.score))
|
||||
rank_ic = pred.groupby(level="datetime", group_keys=False).apply(
|
||||
lambda x: x.label.corr(x.score, method="spearman")
|
||||
)
|
||||
res["ic"] = ic.mean()
|
||||
res["icir"] = ic.mean() / ic.std()
|
||||
res["ric"] = rank_ic.mean()
|
||||
@@ -571,6 +570,7 @@ class TransferLoss:
|
||||
Returns:
|
||||
[tensor] -- transfer loss
|
||||
"""
|
||||
loss = None
|
||||
if self.loss_type in ("mmd_lin", "mmd"):
|
||||
mmdloss = MMD_loss(kernel_type="linear")
|
||||
loss = mmdloss(X, Y)
|
||||
|
||||
@@ -63,7 +63,7 @@ class ADD(Model):
|
||||
mu=0.05,
|
||||
GPU=0,
|
||||
seed=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
# Set logger.
|
||||
self.logger = get_module_logger("ADD")
|
||||
@@ -226,7 +226,7 @@ class ADD(Model):
|
||||
|
||||
def get_daily_inter(self, df, shuffle=False):
|
||||
# organize the train data into daily batches
|
||||
daily_count = df.groupby(level=0).size().values
|
||||
daily_count = df.groupby(level=0, group_keys=False).size().values
|
||||
daily_index = np.roll(np.cumsum(daily_count), 1)
|
||||
daily_index[0] = 0
|
||||
if shuffle:
|
||||
@@ -349,7 +349,7 @@ class ADD(Model):
|
||||
return best_score
|
||||
|
||||
def gen_market_label(self, df, raw_label):
|
||||
market_label = raw_label.groupby("datetime").mean().squeeze()
|
||||
market_label = raw_label.groupby("datetime", group_keys=False).mean().squeeze()
|
||||
bins = [-np.inf, self.lo, self.hi, np.inf]
|
||||
market_label = pd.cut(market_label, bins, labels=False)
|
||||
market_label.name = ("market_return", "market_return")
|
||||
@@ -357,7 +357,7 @@ class ADD(Model):
|
||||
return df
|
||||
|
||||
def fit_thresh(self, train_label):
|
||||
market_label = train_label.groupby("datetime").mean().squeeze()
|
||||
market_label = train_label.groupby("datetime", group_keys=False).mean().squeeze()
|
||||
self.lo, self.hi = market_label.quantile([1 / 3, 2 / 3])
|
||||
|
||||
def fit(
|
||||
|
||||
@@ -52,7 +52,7 @@ class ALSTM(Model):
|
||||
optimizer="adam",
|
||||
GPU=0,
|
||||
seed=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
# Set logger.
|
||||
self.logger = get_module_logger("ALSTM")
|
||||
|
||||
@@ -56,7 +56,7 @@ class ALSTM(Model):
|
||||
n_jobs=10,
|
||||
GPU=0,
|
||||
seed=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
# Set logger.
|
||||
self.logger = get_module_logger("ALSTM")
|
||||
@@ -160,6 +160,10 @@ class ALSTM(Model):
|
||||
|
||||
if self.metric in ("", "loss"):
|
||||
return -self.loss_fn(pred[mask], label[mask])
|
||||
elif self.metric == "mse":
|
||||
mask = ~torch.isnan(label)
|
||||
weight = torch.ones_like(label)
|
||||
return -self.mse(pred[mask], label[mask], weight[mask])
|
||||
|
||||
raise ValueError("unknown metric `%s`" % self.metric)
|
||||
|
||||
|
||||
@@ -56,7 +56,7 @@ class GATs(Model):
|
||||
optimizer="adam",
|
||||
GPU=0,
|
||||
seed=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
# Set logger.
|
||||
self.logger = get_module_logger("GATs")
|
||||
@@ -163,7 +163,7 @@ class GATs(Model):
|
||||
|
||||
def get_daily_inter(self, df, shuffle=False):
|
||||
# organize the train data into daily batches
|
||||
daily_count = df.groupby(level=0).size().values
|
||||
daily_count = df.groupby(level=0, group_keys=False).size().values
|
||||
daily_index = np.roll(np.cumsum(daily_count), 1)
|
||||
daily_index[0] = 0
|
||||
if shuffle:
|
||||
|
||||
@@ -27,7 +27,9 @@ class DailyBatchSampler(Sampler):
|
||||
def __init__(self, data_source):
|
||||
self.data_source = data_source
|
||||
# calculate number of samples in each batch
|
||||
self.daily_count = pd.Series(index=self.data_source.get_index()).groupby("datetime").size().values
|
||||
self.daily_count = (
|
||||
pd.Series(index=self.data_source.get_index()).groupby("datetime", group_keys=False).size().values
|
||||
)
|
||||
self.daily_index = np.roll(np.cumsum(self.daily_count), 1) # calculate begin index of each batch
|
||||
self.daily_index[0] = 0
|
||||
|
||||
@@ -73,7 +75,7 @@ class GATs(Model):
|
||||
GPU=0,
|
||||
n_jobs=10,
|
||||
seed=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
# Set logger.
|
||||
self.logger = get_module_logger("GATs")
|
||||
@@ -181,7 +183,7 @@ class GATs(Model):
|
||||
|
||||
def get_daily_inter(self, df, shuffle=False):
|
||||
# organize the train data into daily batches
|
||||
daily_count = df.groupby(level=0).size().values
|
||||
daily_count = df.groupby(level=0, group_keys=False).size().values
|
||||
daily_index = np.roll(np.cumsum(daily_count), 1)
|
||||
daily_index[0] = 0
|
||||
if shuffle:
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user