mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-15 18:28:24 +08:00
Compare commits
31 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d5379c520f | ||
|
|
7ccf3f7658 | ||
|
|
2c21b8089a | ||
|
|
b87a2c294d | ||
|
|
3097dcc995 | ||
|
|
2fb9380b34 | ||
|
|
8fd6d5ca7e | ||
|
|
69bb755f37 | ||
|
|
39634b2158 | ||
|
|
16acb76aba | ||
|
|
4e0f5d5ec9 | ||
|
|
50c32ac15f | ||
|
|
80982f8904 | ||
|
|
477160e4ac | ||
|
|
3472e82d5c | ||
|
|
cb285bccac | ||
|
|
2e9a00a9f7 | ||
|
|
d631b4450b | ||
|
|
0826879481 | ||
|
|
2b41782f0c | ||
|
|
ac3fe9476f | ||
|
|
66c36226aa | ||
|
|
bb7ab1cf14 | ||
|
|
3dc5a7d299 | ||
|
|
7d66e4b788 | ||
|
|
213eb6c2cd | ||
|
|
94d138ec23 | ||
|
|
f26b341736 | ||
|
|
136b2ddf9a | ||
|
|
7095e755fa | ||
|
|
2d05a705e3 |
21
.commitlintrc.js
Normal file
21
.commitlintrc.js
Normal file
@@ -0,0 +1,21 @@
|
||||
module.exports = {
|
||||
extends: ["@commitlint/config-conventional"],
|
||||
rules: {
|
||||
// Configuration Format: [level, applicability, value]
|
||||
// level: Error level, usually expressed as a number:
|
||||
// 0 - disable rule
|
||||
// 1 - Warning (does not prevent commits)
|
||||
// 2 - Error (will block the commit)
|
||||
// applicability: the conditions under which the rule applies, commonly used values:
|
||||
// “always” - always apply the rule
|
||||
// “never” - never apply the rule
|
||||
// value: the specific value of the rule, e.g. a maximum length of 100.
|
||||
// Refs: https://commitlint.js.org/reference/rules-configuration.html
|
||||
"header-max-length": [2, "always", 100],
|
||||
"type-enum": [
|
||||
2,
|
||||
"always",
|
||||
["build", "chore", "ci", "docs", "feat", "fix", "perf", "refactor", "revert", "style", "test", "Release-As"]
|
||||
]
|
||||
}
|
||||
};
|
||||
13
.github/PULL_REQUEST_TEMPLATE.md
vendored
13
.github/PULL_REQUEST_TEMPLATE.md
vendored
@@ -1,3 +1,16 @@
|
||||
<!--- Thank you for submitting a Pull Request! To make our work smoother, -->
|
||||
<!--- please make sure your Pull Request meets the following requirements: -->
|
||||
<!--- 1. Provide a general summary of your changes in the Title above; -->
|
||||
<!--- 2. Add appropriate prefixes to titles, such as `build:`, `chore:`, `ci:`, `docs:`, `feat:`, `fix:`, `perf:`, `refactor:`, `revert:`, `style:`, `test:`(Ref: https://www.conventionalcommits.org/). -->
|
||||
<!--- Category: -->
|
||||
<!--- Patch Updates: `fix:` -->
|
||||
<!--- Example: fix(auth): correct login validation issue -->
|
||||
<!--- Minor Updates (introduce new functionality): `feat:` -->
|
||||
<!--- Example: feat(parser): add ability to parse arrays -->
|
||||
<!--- Major Updates (breaking changes): Include `BREAKING CHANGE:` in the commit message footer, or add `!` after the type/scope in the commit header to indicate that there is a breaking change. -->
|
||||
<!--- Example: feat(auth)!: remove support for old authentication method -->
|
||||
<!--- Other updates: `build:`, `chore:`, `ci:`, `docs:`, `perf:`, `refactor:`, `revert:`, `style:`, `test:`. -->
|
||||
|
||||
<!--- Provide a general summary of your changes in the Title above -->
|
||||
|
||||
## Description
|
||||
|
||||
6
.github/labeler.yml
vendored
6
.github/labeler.yml
vendored
@@ -1,6 +0,0 @@
|
||||
documentation:
|
||||
- 'docs/**/*'
|
||||
- '**/*.md'
|
||||
|
||||
waiting for triage:
|
||||
- any: ['**/*', '!docs/**/*', '!**/*.md']
|
||||
14
.github/workflows/labeler.yml
vendored
14
.github/workflows/labeler.yml
vendored
@@ -1,14 +0,0 @@
|
||||
name: "Add label automatically"
|
||||
on:
|
||||
- pull_request_target
|
||||
|
||||
jobs:
|
||||
triage:
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: write
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/labeler@v4
|
||||
with:
|
||||
repo-token: "${{ secrets.GITHUB_TOKEN }}"
|
||||
35
.github/workflows/lint_title.yml
vendored
Normal file
35
.github/workflows/lint_title.yml
vendored
Normal file
@@ -0,0 +1,35 @@
|
||||
name: Lint pull request title
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types:
|
||||
- opened
|
||||
- synchronize
|
||||
- reopened
|
||||
- edited
|
||||
|
||||
concurrency:
|
||||
cancel-in-progress: true
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
|
||||
jobs:
|
||||
lint-title:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
# This step is necessary because the lint title uses the .commitlintrc.js file in the project root directory.
|
||||
- name: Checkout Repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '16'
|
||||
|
||||
- name: Install commitlint
|
||||
run: npm install --save-dev @commitlint/{config-conventional,cli}
|
||||
|
||||
- name: Validate PR Title with commitlint
|
||||
env:
|
||||
BODY: ${{ github.event.pull_request.title }}
|
||||
run: |
|
||||
echo "$BODY" | npx commitlint --config .commitlintrc.js
|
||||
65
.github/workflows/python-publish.yml
vendored
65
.github/workflows/python-publish.yml
vendored
@@ -1,65 +0,0 @@
|
||||
# This workflows will upload a Python Package using Twine when a release is created
|
||||
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
|
||||
|
||||
name: Upload Python Package
|
||||
|
||||
on:
|
||||
release:
|
||||
types: [published]
|
||||
|
||||
jobs:
|
||||
deploy_with_bdist_wheel:
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [windows-latest, macos-13, macos-latest]
|
||||
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
|
||||
exclude:
|
||||
- os: macos-13
|
||||
python-version: "3.11"
|
||||
- os: macos-13
|
||||
python-version: "3.12"
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
make dev
|
||||
- name: Build wheel on ${{ matrix.os }}
|
||||
run: |
|
||||
make build
|
||||
- name: Upload to PyPi
|
||||
env:
|
||||
TWINE_USERNAME: __token__
|
||||
TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }}
|
||||
run: |
|
||||
twine check dist/*.whl
|
||||
twine upload dist/*.whl --verbose
|
||||
|
||||
deploy_with_manylinux:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Build wheel on Linux
|
||||
uses: RalfG/python-wheels-manylinux-build@v0.7.1-manylinux2014_x86_64
|
||||
with:
|
||||
python-versions: 'cp38-cp38 cp39-cp39 cp310-cp310 cp311-cp311 cp312-cp312'
|
||||
build-requirements: 'numpy cython'
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install twine
|
||||
- name: Upload to PyPi
|
||||
env:
|
||||
TWINE_USERNAME: __token__
|
||||
TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }}
|
||||
run: |
|
||||
twine check dist/pyqlib-*-manylinux*.whl
|
||||
twine upload dist/pyqlib-*-manylinux*.whl --verbose
|
||||
22
.github/workflows/release-drafter.yml
vendored
22
.github/workflows/release-drafter.yml
vendored
@@ -1,22 +0,0 @@
|
||||
name: Release Drafter
|
||||
|
||||
on:
|
||||
push:
|
||||
# branches to consider in the event; optional, defaults to all
|
||||
branches:
|
||||
- main
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
update_release_draft:
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: read
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
# Drafts your next Release notes as Pull Requests are merged into "master"
|
||||
- uses: release-drafter/release-drafter@v5.11.0
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
107
.github/workflows/release.yml
vendored
Normal file
107
.github/workflows/release.yml
vendored
Normal file
@@ -0,0 +1,107 @@
|
||||
name: Release
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
release:
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
release_created: ${{ steps.release_please.outputs.release_created }}
|
||||
|
||||
steps:
|
||||
- name: Release please
|
||||
id: release_please
|
||||
uses: googleapis/release-please-action@v4
|
||||
with:
|
||||
token: ${{ secrets.PAT }}
|
||||
release-type: simple
|
||||
|
||||
deploy_with_manylinux:
|
||||
needs: release
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: read
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
if: needs.release.outputs.release_created == 'true'
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
if: needs.release.outputs.release_created == 'true'
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Build wheel on Linux
|
||||
if: needs.release.outputs.release_created == 'true'
|
||||
uses: RalfG/python-wheels-manylinux-build@v0.7.1-manylinux2014_x86_64
|
||||
with:
|
||||
python-versions: 'cp38-cp38 cp39-cp39 cp310-cp310 cp311-cp311 cp312-cp312'
|
||||
build-requirements: 'numpy cython'
|
||||
|
||||
- name: Install dependencies
|
||||
if: needs.release.outputs.release_created == 'true'
|
||||
run: |
|
||||
python -m pip install twine
|
||||
|
||||
- name: Upload to PyPi
|
||||
if: needs.release.outputs.release_created == 'true'
|
||||
env:
|
||||
TWINE_USERNAME: __token__
|
||||
TWINE_PASSWORD: ${{ secrets.TESTPYPI }}
|
||||
run: |
|
||||
twine check dist/pyqlib-*-manylinux*.whl
|
||||
twine upload --repository-url https://test.pypi.org/legacy/ dist/pyqlib-*-manylinux*.whl --verbose
|
||||
|
||||
deploy_with_bdist_wheel:
|
||||
needs: release
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
# After testing, the whl files of pyqlib built by macos-14 and macos-15 in python environments of 3.8, 3.9, 3.10, 3.11, 3.12,
|
||||
# the filenames are exactly duplicated, which will result in the duplicated whl files not being able to be uploaded to pypi,
|
||||
# so we chose to just keep the latest macos-latest. macos-latest currently points to macos-15.
|
||||
# Also, macos-13 will stop being supported on 2025-11-14.
|
||||
# Refs: https://github.blog/changelog/2025-07-11-upcoming-changes-to-macos-hosted-runners-macos-latest-migration-and-xcode-support-policy-updates/
|
||||
os: [windows-latest, macos-latest]
|
||||
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
if: needs.release.outputs.release_created == 'true'
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
if: needs.release.outputs.release_created == 'true'
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Install dependencies
|
||||
if: needs.release.outputs.release_created == 'true'
|
||||
run: |
|
||||
make dev
|
||||
|
||||
- name: Build wheel on ${{ matrix.os }}
|
||||
if: needs.release.outputs.release_created == 'true'
|
||||
run: |
|
||||
make build
|
||||
|
||||
- name: Upload to PyPi
|
||||
if: needs.release.outputs.release_created == 'true'
|
||||
env:
|
||||
TWINE_USERNAME: __token__
|
||||
TWINE_PASSWORD: ${{ secrets.TESTPYPI }}
|
||||
run: |
|
||||
twine check dist/*.whl
|
||||
twine upload --repository-url https://test.pypi.org/legacy/ dist/*.whl --verbose
|
||||
22
.github/workflows/test_qlib_from_pip.yml
vendored
22
.github/workflows/test_qlib_from_pip.yml
vendored
@@ -1,5 +1,9 @@
|
||||
name: Test qlib from pip
|
||||
|
||||
concurrency:
|
||||
cancel-in-progress: true
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
@@ -21,7 +25,9 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Test qlib from pip
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v4
|
||||
@@ -32,19 +38,9 @@ jobs:
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
|
||||
# Will cancel this step when the next qlib version is released. The current qlib version is: 0.9.6
|
||||
- name: Installing pywinpt for windows
|
||||
if: ${{ matrix.os == 'windows-latest' }}
|
||||
run: |
|
||||
python -m pip install pywinpty --only-binary=:all:
|
||||
|
||||
# # joblib was released on 2025-05-04 with version 1.5.0, in which _backend_args was removed and replaced by _backend_kwargs.
|
||||
# This change caused the application to fail, so the version of joblib is restricted here.
|
||||
# This restriction will be removed in the next release. The current qlib version is: 0.9.6
|
||||
- name: Qlib installation test
|
||||
run: |
|
||||
python -m pip install pyqlib
|
||||
python -m pip install "joblib<=1.4.2"
|
||||
|
||||
- name: Install Lightgbm for MacOS
|
||||
if: ${{ matrix.os == 'macos-14' || matrix.os == 'macos-15' }}
|
||||
@@ -53,12 +49,10 @@ jobs:
|
||||
brew install libomp || brew reinstall libomp
|
||||
python -m pip install --no-binary=:all: lightgbm
|
||||
|
||||
# When the new version is released it should be changed to:
|
||||
# python -m qlib.cli.data qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn
|
||||
- name: Downloads dependencies data
|
||||
run: |
|
||||
cd ..
|
||||
python -m qlib.run.get_data qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn
|
||||
python -m qlib.cli.data qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn
|
||||
cd qlib
|
||||
|
||||
- name: Test workflow by config
|
||||
|
||||
11
.github/workflows/test_qlib_from_source.yml
vendored
11
.github/workflows/test_qlib_from_source.yml
vendored
@@ -1,5 +1,9 @@
|
||||
name: Test qlib from source
|
||||
|
||||
concurrency:
|
||||
cancel-in-progress: true
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
@@ -22,7 +26,9 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Test qlib from source
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v4
|
||||
@@ -74,8 +80,11 @@ jobs:
|
||||
run: |
|
||||
make mypy
|
||||
|
||||
# Due to issues that cannot be automatically fixed when running `nbqa black . -l 120 --check --diff` on Jupyter notebooks,
|
||||
# we reverted to a version of `black` earlier than 26.1.0 before performing the checks.
|
||||
- name: Check Qlib ipynb with nbqa
|
||||
run: |
|
||||
python -m pip install "black<26.1"
|
||||
make nbqa
|
||||
|
||||
- name: Test data downloads
|
||||
|
||||
10
.github/workflows/test_qlib_from_source_slow.yml
vendored
10
.github/workflows/test_qlib_from_source_slow.yml
vendored
@@ -1,5 +1,9 @@
|
||||
name: Test qlib from source slow
|
||||
|
||||
concurrency:
|
||||
cancel-in-progress: true
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
@@ -22,7 +26,9 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Test qlib from source slow
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v4
|
||||
@@ -37,8 +43,6 @@ jobs:
|
||||
run: |
|
||||
python scripts/get_data.py qlib_data --name qlib_data_simple --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
|
||||
|
||||
# install.sh file contents from: https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh
|
||||
# brew_install.sh file contents from: https://raw.githubusercontent.com/Microsoft/qlib/main/.github/brew_install.sh
|
||||
- name: Install Lightgbm for MacOS
|
||||
if: ${{ matrix.os == 'macos-14' || matrix.os == 'macos-15' }}
|
||||
run: |
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -22,6 +22,7 @@ dist/
|
||||
qlib/VERSION.txt
|
||||
qlib/data/_libs/expanding.cpp
|
||||
qlib/data/_libs/rolling.cpp
|
||||
qlib/_version.py
|
||||
examples/estimator/estimator_example/
|
||||
examples/rl/data/
|
||||
examples/rl/checkpoints/
|
||||
|
||||
0
CHANGELOG.md
Normal file
0
CHANGELOG.md
Normal file
25
Makefile
25
Makefile
@@ -74,34 +74,37 @@ prerequisite:
|
||||
|
||||
# Install the package in editable mode.
|
||||
dependencies:
|
||||
python -m pip install -e .
|
||||
python -m pip install --no-cache-dir -e .
|
||||
|
||||
lightgbm:
|
||||
python -m pip install lightgbm --prefer-binary
|
||||
python -m pip install --no-cache-dir lightgbm --prefer-binary
|
||||
|
||||
rl:
|
||||
python -m pip install -e .[rl]
|
||||
python -m pip install --no-cache-dir -e .[rl]
|
||||
|
||||
develop:
|
||||
python -m pip install -e .[dev]
|
||||
python -m pip install --no-cache-dir -e .[dev]
|
||||
|
||||
lint:
|
||||
python -m pip install -e .[lint]
|
||||
python -m pip install --no-cache-dir -e .[lint]
|
||||
|
||||
docs:
|
||||
python -m pip install -e .[docs]
|
||||
python -m pip install --no-cache-dir -e .[docs]
|
||||
|
||||
package:
|
||||
python -m pip install -e .[package]
|
||||
python -m pip install --no-cache-dir -e .[package]
|
||||
|
||||
test:
|
||||
python -m pip install -e .[test]
|
||||
python -m pip install --no-cache-dir -e .[test]
|
||||
|
||||
analysis:
|
||||
python -m pip install -e .[analysis]
|
||||
python -m pip install --no-cache-dir -e .[analysis]
|
||||
|
||||
client:
|
||||
python -m pip install --no-cache-dir -e .[client]
|
||||
|
||||
all:
|
||||
python -m pip install -e .[pywinpty,dev,lint,docs,package,test,analysis,rl]
|
||||
python -m pip install --no-cache-dir -e .[pywinpty,dev,lint,docs,package,test,analysis,rl]
|
||||
|
||||
install: prerequisite dependencies
|
||||
|
||||
@@ -113,7 +116,7 @@ dev: prerequisite all
|
||||
|
||||
# Check lint with black.
|
||||
black:
|
||||
black . -l 120 --check --diff
|
||||
black . -l 120 --check --diff --exclude qlib/_version.py
|
||||
|
||||
# Check code folder with pylint.
|
||||
# TODO: These problems we will solve in the future. Important among them are: W0221, W0223, W0237, E1102
|
||||
|
||||
11
README.md
11
README.md
@@ -17,14 +17,13 @@ We are excited to announce the release of **RD-Agent**📢, a powerful tool that
|
||||
|
||||
RD-Agent is now available on [GitHub](https://github.com/microsoft/RD-Agent), and we welcome your star🌟!
|
||||
|
||||
To learn more, please visit our [♾️Demo page](https://rdagent.azurewebsites.net/). Here, you will find demo videos in both English and Chinese to help you better understand the scenario and usage of RD-Agent.
|
||||
To learn more, please visit the [RD-Agent repository](https://github.com/microsoft/RD-Agent). We have prepared several public demo videos for you:
|
||||
|
||||
We have prepared several demo videos for you:
|
||||
| Scenario | Demo video (English) | Demo video (中文) |
|
||||
| -- | ------ | ------ |
|
||||
| Quant Factor Mining | [Link](https://rdagent.azurewebsites.net/factor_loop?lang=en) | [Link](https://rdagent.azurewebsites.net/factor_loop?lang=zh) |
|
||||
| Quant Factor Mining from reports | [Link](https://rdagent.azurewebsites.net/report_factor?lang=en) | [Link](https://rdagent.azurewebsites.net/report_factor?lang=zh) |
|
||||
| Quant Model Optimization | [Link](https://rdagent.azurewebsites.net/model_loop?lang=en) | [Link](https://rdagent.azurewebsites.net/model_loop?lang=zh) |
|
||||
| Quant Factor Mining | [YouTube](https://www.youtube.com/watch?v=X4DK2QZKaKY&t=6s) | [YouTube](https://www.youtube.com/watch?v=X4DK2QZKaKY&t=6s) |
|
||||
| Quant Factor Mining from reports | [YouTube](https://www.youtube.com/watch?v=ECLTXVcSx-c) | [YouTube](https://www.youtube.com/watch?v=ECLTXVcSx-c) |
|
||||
| Quant Model Optimization | [YouTube](https://www.youtube.com/watch?v=dm0dWL49Bc0&t=104s) | [YouTube](https://www.youtube.com/watch?v=dm0dWL49Bc0&t=104s) |
|
||||
|
||||
- 📃**Paper**: [R&D-Agent-Quant: A Multi-Agent Framework for Data-Centric Factors and Model Joint Optimization](https://arxiv.org/abs/2505.15155)
|
||||
- 👾**Code**: https://github.com/microsoft/RD-Agent/
|
||||
@@ -324,7 +323,7 @@ We recommend users to prepare their own data if they have a high-quality dataset
|
||||
```
|
||||
2. Start a new Docker container
|
||||
```bash
|
||||
docker run -it --name <container name> -v <Mounted local directory>:/app qlib_image_stable
|
||||
docker run -it --name <container name> -v <Mounted local directory>:/app pyqlib/qlib_image_stable:stable
|
||||
```
|
||||
3. At this point you are in the docker environment and can run the qlib scripts. An example:
|
||||
```bash
|
||||
|
||||
@@ -42,7 +42,7 @@ Example
|
||||
|
||||
.. math::
|
||||
|
||||
DEA = \frac{EMA(DIF, 9)}{CLOSE}
|
||||
DEA = EMA(DIF, 9)
|
||||
|
||||
Users can use ``Data Handler`` to build formulaic alphas `MACD` in qlib:
|
||||
|
||||
@@ -51,7 +51,7 @@ Users can use ``Data Handler`` to build formulaic alphas `MACD` in qlib:
|
||||
.. code-block:: python
|
||||
|
||||
>> from qlib.data.dataset.loader import QlibDataLoader
|
||||
>> MACD_EXP = '(EMA($close, 12) - EMA($close, 26))/$close - EMA((EMA($close, 12) - EMA($close, 26))/$close, 9)/$close'
|
||||
>> MACD_EXP = '2 * ((EMA($close, 12) - EMA($close, 26))/$close - EMA((EMA($close, 12) - EMA($close, 26))/$close, 9))'
|
||||
>> fields = [MACD_EXP] # MACD
|
||||
>> names = ['MACD']
|
||||
>> labels = ['Ref($close, -2)/Ref($close, -1) - 1'] # label
|
||||
@@ -66,17 +66,17 @@ Users can use ``Data Handler`` to build formulaic alphas `MACD` in qlib:
|
||||
feature label
|
||||
MACD LABEL
|
||||
datetime instrument
|
||||
2010-01-04 SH600000 -0.011547 -0.019672
|
||||
SH600004 0.002745 -0.014721
|
||||
SH600006 0.010133 0.002911
|
||||
SH600008 -0.001113 0.009818
|
||||
SH600009 0.025878 -0.017758
|
||||
2010-01-04 SH600000 0.008781 -0.019672
|
||||
SH600004 0.006699 -0.014721
|
||||
SH600006 0.005714 0.002911
|
||||
SH600008 0.000798 0.009818
|
||||
SH600009 0.017015 -0.017758
|
||||
... ... ...
|
||||
2017-12-29 SZ300124 0.007306 -0.005074
|
||||
SZ300136 -0.013492 0.056352
|
||||
SZ300144 -0.000966 0.011853
|
||||
SZ300251 0.004383 0.021739
|
||||
SZ300315 -0.030557 0.012455
|
||||
2017-12-29 SZ300124 0.015071 -0.005074
|
||||
SZ300136 -0.015466 0.056352
|
||||
SZ300144 0.013082 0.011853
|
||||
SZ300251 -0.001026 0.021739
|
||||
SZ300315 -0.007559 0.012455
|
||||
|
||||
Reference
|
||||
=========
|
||||
|
||||
@@ -21,8 +21,7 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
import pkg_resources
|
||||
|
||||
from importlib.metadata import version as ver
|
||||
|
||||
# -- General configuration ------------------------------------------------
|
||||
|
||||
@@ -63,9 +62,9 @@ author = "Microsoft"
|
||||
# built documents.
|
||||
#
|
||||
# The short X.Y version.
|
||||
version = pkg_resources.get_distribution("pyqlib").version
|
||||
version = ver("pyqlib")
|
||||
# The full version, including alpha/beta/rc tags.
|
||||
release = pkg_resources.get_distribution("pyqlib").version
|
||||
release = version
|
||||
|
||||
# The language for content autogenerated by Sphinx. Refer to documentation
|
||||
# for a list of supported languages.
|
||||
|
||||
@@ -129,7 +129,7 @@ For example, it looks quite long and complicated:
|
||||
|
||||
|
||||
But using string is not the only way to implement the expression. You can also implement expression by code.
|
||||
Here is an exmaple which does the same thing as above examples.
|
||||
Here is an example which does the same thing as above examples.
|
||||
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@@ -71,7 +71,7 @@ The Custom models need to inherit `qlib.model.base.Model <../reference/api.html#
|
||||
)
|
||||
|
||||
- Override the `predict` method
|
||||
- The parameters must include the parameter `dataset`, which will be userd to get the test dataset.
|
||||
- The parameters must include the parameter `dataset`, which will be used to get the test dataset.
|
||||
- Return the `prediction score`.
|
||||
- Please refer to `Model API <../reference/api.html#module-qlib.model.base>`_ for the parameter types of the fit method.
|
||||
- Code Example: In the following example, users need to use `LightGBM` to predict the label(such as `preds`) of test data `x_test` and return it.
|
||||
|
||||
@@ -19,7 +19,6 @@ from qlib.model.base import ModelFT
|
||||
from qlib.data.dataset import DatasetH
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
|
||||
|
||||
# To register new datasets, please add them here.
|
||||
ALLOW_DATASET = ["Alpha158", "Alpha360"]
|
||||
# To register new datasets, please add their configurations here.
|
||||
|
||||
@@ -8,7 +8,6 @@ import pandas as pd
|
||||
|
||||
from qlib.data.dataset import DatasetH
|
||||
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import pickle
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
|
||||
from qlib.utils.pickle_utils import restricted_pickle_load
|
||||
|
||||
sns.set(color_codes=True)
|
||||
plt.rcParams["font.sans-serif"] = "SimHei"
|
||||
plt.rcParams["axes.unicode_minus"] = False
|
||||
@@ -18,7 +19,7 @@ from tqdm.auto import tqdm
|
||||
|
||||
# +
|
||||
with open("./internal_data_s20.pkl", "rb") as f:
|
||||
data = pickle.load(f)
|
||||
data = restricted_pickle_load(f)
|
||||
|
||||
data.data_ic_df.columns.names = ["start_date", "end_date"]
|
||||
|
||||
@@ -52,7 +53,7 @@ pd.DataFrame(meta_m.tn.twm.linear.weight.detach().numpy()).T[0].rolling(5).mean(
|
||||
|
||||
# +
|
||||
with open("./tasks_s20.pkl", "rb") as f:
|
||||
tasks = pickle.load(f)
|
||||
tasks = restricted_pickle_load(f)
|
||||
|
||||
task_df = {}
|
||||
for t in tasks:
|
||||
|
||||
@@ -4,11 +4,11 @@
|
||||
import fire
|
||||
|
||||
import qlib
|
||||
import pickle
|
||||
from qlib.constant import REG_CN
|
||||
from qlib.config import HIGH_FREQ_CONFIG
|
||||
|
||||
from qlib.utils import init_instance_by_config
|
||||
from qlib.utils.pickle_utils import restricted_pickle_load
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
from qlib.data.ops import Operators
|
||||
from qlib.data.data import Cal
|
||||
@@ -125,10 +125,10 @@ class HighfreqWorkflow:
|
||||
del dataset, dataset_backtest
|
||||
##=============reload dataset=============
|
||||
with open("dataset.pkl", "rb") as file_dataset:
|
||||
dataset = pickle.load(file_dataset)
|
||||
dataset = restricted_pickle_load(file_dataset)
|
||||
|
||||
with open("dataset_backtest.pkl", "rb") as file_dataset_backtest:
|
||||
dataset_backtest = pickle.load(file_dataset_backtest)
|
||||
dataset_backtest = restricted_pickle_load(file_dataset_backtest)
|
||||
|
||||
self._prepare_calender_cache()
|
||||
##=============reinit dataset=============
|
||||
|
||||
@@ -9,7 +9,6 @@ from qlib.utils import init_instance_by_config
|
||||
from qlib.tests.data import GetData
|
||||
from qlib.tests.config import CSI300_GBDT_TASK
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# use default data
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
|
||||
|
||||
@@ -95,7 +95,6 @@ pos 0.000000
|
||||
[1706497:MainThread](2021-12-07 14:08:30,627) INFO - qlib.timer - [log.py:113] - Time cost: 0.014s | waiting `async_log` Done
|
||||
"""
|
||||
|
||||
|
||||
from copy import deepcopy
|
||||
import qlib
|
||||
import fire
|
||||
|
||||
@@ -7,6 +7,7 @@ There are two parts including first_train and update_online_pred.
|
||||
Firstly, we will finish the training and set the trained models to the `online` models.
|
||||
Next, we will finish updating online predictions.
|
||||
"""
|
||||
|
||||
import copy
|
||||
import fire
|
||||
import qlib
|
||||
|
||||
@@ -27,7 +27,7 @@ pip install arctic # NOTE: pip may fail to resolve the right package dependency
|
||||
2. Please follow the steps below to download the example data
|
||||
```bash
|
||||
cd examples/orderbook_data/
|
||||
gdown https://drive.google.com/uc?id=15nZF7tFT_eKVZAcMFL1qPS4jGyJflH7e # Proxies may be necessary here.
|
||||
gdown https://drive.google.com/uc?id=15FuUqWn2rkCi8uhJYGEQWKakcEqLJNDG # Proxies may be necessary here.
|
||||
python ../../scripts/get_data.py _unzip --file_path highfreq_orderbook_example_data.zip --target_dir .
|
||||
```
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ NOTE:
|
||||
- !!!!!!!!!!!!!!!TODO!!!!!!!!!!!!!!!!!!!:
|
||||
- Its structure is not well designed and very ugly, your contribution is welcome to make importing dataset easier
|
||||
"""
|
||||
|
||||
from datetime import date, datetime as dt
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
import pickle
|
||||
import os
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
|
||||
from qlib.utils.pickle_utils import restricted_pickle_load
|
||||
|
||||
for tag in ["test", "valid"]:
|
||||
files = os.listdir(os.path.join("data/orders/", tag))
|
||||
dfs = []
|
||||
for f in tqdm(files):
|
||||
df = pickle.load(open(os.path.join("data/orders/", tag, f), "rb"))
|
||||
with open(os.path.join("data/orders/", tag, f), "rb") as fr:
|
||||
df = restricted_pickle_load(fr)
|
||||
df = df.drop(["$close0"], axis=1)
|
||||
dfs.append(df)
|
||||
|
||||
|
||||
@@ -3,12 +3,12 @@
|
||||
|
||||
import qlib
|
||||
import fire
|
||||
import pickle
|
||||
|
||||
from datetime import datetime
|
||||
from qlib.constant import REG_CN
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
from qlib.utils import init_instance_by_config
|
||||
from qlib.utils.pickle_utils import restricted_pickle_load
|
||||
from qlib.tests.data import GetData
|
||||
|
||||
|
||||
@@ -42,7 +42,7 @@ class RollingDataWorkflow:
|
||||
|
||||
def _load_pre_handler(self, path):
|
||||
with open(path, "rb") as file_dataset:
|
||||
pre_handler = pickle.load(file_dataset)
|
||||
pre_handler = restricted_pickle_load(file_dataset)
|
||||
return pre_handler
|
||||
|
||||
def rolling_process(self):
|
||||
|
||||
@@ -7,6 +7,7 @@ Qlib provides two kinds of interfaces.
|
||||
|
||||
The interface of (1) is `qrun XXX.yaml`. The interface of (2) is script like this, which nearly does the same thing as `qrun XXX.yaml`
|
||||
"""
|
||||
|
||||
import qlib
|
||||
from qlib.constant import REG_CN
|
||||
from qlib.utils import init_instance_by_config, flatten_dict
|
||||
@@ -15,7 +16,6 @@ from qlib.workflow.record_temp import SignalRecord, PortAnaRecord, SigAnaRecord
|
||||
from qlib.tests.data import GetData
|
||||
from qlib.tests.config import CSI300_BENCH, CSI300_GBDT_TASK
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# use default data
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
[build-system]
|
||||
requires = ["setuptools", "cython", "numpy>=1.24.0"]
|
||||
requires = ["setuptools", "setuptools-scm", "cython", "numpy>=1.24.0"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
@@ -22,11 +22,15 @@ dynamic = ["version"]
|
||||
description = "A Quantitative-research Platform"
|
||||
requires-python = ">=3.8.0"
|
||||
readme = {file = "README.md", content-type = "text/markdown"}
|
||||
license = { text = "MIT" }
|
||||
|
||||
dependencies = [
|
||||
"pyyaml",
|
||||
"numpy",
|
||||
"pandas>=0.24",
|
||||
# Since version 1.1.0, pandas supports the ffill and bfill methods.
|
||||
# Since version 2.1.0, pandas has deprecated the method parameter of the fillna method.
|
||||
# qlib has updated the fillna method in PR 1987 and limited the minimum version of pandas.
|
||||
"pandas>=1.1",
|
||||
# I encountered an error where set_uri does not work when downloading artifacts in mlflow 3.1.1;
|
||||
# But earlier versions of mlflow do not have this problem.
|
||||
# But when I switch to 2.*.* version, another error occurs, which is even more strange...
|
||||
@@ -49,6 +53,7 @@ dependencies = [
|
||||
"nbconvert",
|
||||
"pyarrow",
|
||||
"pydantic-settings",
|
||||
"setuptools-scm",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
@@ -64,6 +69,7 @@ rl = [
|
||||
"torch",
|
||||
"numpy<2.0.0",
|
||||
]
|
||||
|
||||
lint = [
|
||||
"black",
|
||||
"pylint",
|
||||
@@ -96,6 +102,10 @@ analysis = [
|
||||
"plotly",
|
||||
"statsmodels",
|
||||
]
|
||||
client = [
|
||||
"python-socketio<6",
|
||||
"tables",
|
||||
]
|
||||
|
||||
# In the process of releasing a new version, when checking the manylinux package with twine, an error is reported:
|
||||
# InvalidDistribution: Invalid distribution metadata: unrecognized or malformed field 'license-file'
|
||||
@@ -108,3 +118,8 @@ license-files = []
|
||||
|
||||
[project.scripts]
|
||||
qrun = "qlib.cli.run:run"
|
||||
|
||||
[tool.setuptools_scm]
|
||||
local_scheme = "no-local-version"
|
||||
version_scheme = "guess-next-dev"
|
||||
write_to = "qlib/_version.py"
|
||||
|
||||
@@ -2,15 +2,22 @@
|
||||
# Licensed under the MIT License.
|
||||
from pathlib import Path
|
||||
|
||||
__version__ = "0.9.7"
|
||||
from setuptools_scm import get_version
|
||||
|
||||
try:
|
||||
from ._version import version as __version__
|
||||
except ImportError:
|
||||
__version__ = get_version(root="..", relative_to=__file__)
|
||||
__version__bak = __version__ # This version is backup for QlibConfig.reset_qlib_version
|
||||
import os
|
||||
import re
|
||||
from typing import Union
|
||||
from ruamel.yaml import YAML
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import re
|
||||
import subprocess
|
||||
from typing import Union
|
||||
|
||||
from ruamel.yaml import YAML
|
||||
|
||||
from .log import get_module_logger
|
||||
|
||||
|
||||
@@ -136,7 +143,10 @@ def _mount_nfs_uri(provider_uri, mount_path, auto_mount: bool = False):
|
||||
_command_log = [line for line in _command_log if _remote_uri in line]
|
||||
if len(_command_log) > 0:
|
||||
for _c in _command_log:
|
||||
_temp_mount = _c.decode("utf-8").split(" ")[2]
|
||||
if isinstance(_c, str):
|
||||
_temp_mount = _c.split(" ")[2]
|
||||
else:
|
||||
_temp_mount = _c.decode("utf-8").split(" ")[2]
|
||||
_temp_mount = _temp_mount[:-1] if _temp_mount.endswith("/") else _temp_mount
|
||||
if _temp_mount == _mount_path:
|
||||
_is_mount = True
|
||||
|
||||
@@ -18,7 +18,6 @@ from tqdm.auto import tqdm
|
||||
|
||||
from ..utils.time import Freq
|
||||
|
||||
|
||||
PORT_METRIC = Dict[str, Tuple[pd.DataFrame, dict]]
|
||||
INDICATOR_METRIC = Dict[str, Tuple[pd.DataFrame, Indicator]]
|
||||
|
||||
|
||||
@@ -281,13 +281,13 @@ def brinson_pa(
|
||||
|
||||
stock_group_field = stock_df[group_field].unstack().T
|
||||
# FIXME: some attributes of some suspended stocks are NaN.
|
||||
stock_group_field = stock_group_field.fillna(method="ffill")
|
||||
stock_group_field = stock_group_field.ffill()
|
||||
stock_group_field = stock_group_field.loc[start_date:end_date]
|
||||
|
||||
stock_group = get_stock_group(stock_group_field, bench_stock_weight, group_method, group_n)
|
||||
|
||||
deal_price_df = stock_df["deal_price"].unstack().T
|
||||
deal_price_df = deal_price_df.fillna(method="ffill")
|
||||
deal_price_df = deal_price_df.ffill()
|
||||
|
||||
# NOTE:
|
||||
# The return will be slightly different from the return in the report.
|
||||
|
||||
@@ -4,6 +4,5 @@
|
||||
import fire
|
||||
from qlib.tests.data import GetData
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(GetData)
|
||||
|
||||
@@ -87,7 +87,7 @@ def workflow(config_path, experiment_name="workflow", uri_folder="mlruns"):
|
||||
"""
|
||||
This is a Qlib CLI entrance.
|
||||
User can run the whole Quant research workflow defined by a configure file
|
||||
- the code is located here ``qlib/cli/run.py`
|
||||
- the code is located here ``qlib/cli/run.py``
|
||||
|
||||
User can specify a base_config file in your workflow.yml file by adding "BASE_CONFIG_PATH".
|
||||
Qlib will load the configuration in BASE_CONFIG_PATH first, and the user only needs to update the custom fields
|
||||
|
||||
@@ -10,6 +10,7 @@ Two modes are supported
|
||||
- server
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
@@ -11,7 +11,6 @@ from qlib.utils import init_instance_by_config
|
||||
|
||||
from qlib.data.dataset import DatasetH
|
||||
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
|
||||
|
||||
@@ -20,7 +20,6 @@ from ..data import D
|
||||
from ..config import C
|
||||
from ..data.dataset.utils import get_level_index
|
||||
|
||||
|
||||
logger = get_module_logger("Evaluate")
|
||||
|
||||
|
||||
|
||||
@@ -3,5 +3,4 @@
|
||||
|
||||
from .data_selection import MetaTaskDS, MetaDatasetDS, MetaModelDS
|
||||
|
||||
|
||||
__all__ = ["MetaTaskDS", "MetaDatasetDS", "MetaModelDS"]
|
||||
|
||||
@@ -4,5 +4,4 @@
|
||||
from .dataset import MetaDatasetDS, MetaTaskDS
|
||||
from .model import MetaModelDS
|
||||
|
||||
|
||||
__all__ = ["MetaDatasetDS", "MetaTaskDS", "MetaModelDS"]
|
||||
|
||||
@@ -51,7 +51,7 @@ class LGBModel(ModelFT, LightGBMFInt):
|
||||
w = reweighter.reweight(df)
|
||||
else:
|
||||
raise ValueError("Unsupported reweighter type.")
|
||||
ds_l.append((lgb.Dataset(x.values, label=y, weight=w), key))
|
||||
ds_l.append((lgb.Dataset(x.values, label=y, weight=w, free_raw_data=False), key))
|
||||
return ds_l
|
||||
|
||||
def fit(
|
||||
@@ -109,8 +109,10 @@ class LGBModel(ModelFT, LightGBMFInt):
|
||||
verbose level
|
||||
"""
|
||||
# Based on existing model and finetune by train more rounds
|
||||
dtrain, _ = self._prepare_data(dataset, reweighter) # pylint: disable=W0632
|
||||
if dtrain.empty:
|
||||
ds_l = self._prepare_data(dataset, reweighter)
|
||||
dtrain, _ = ds_l[0]
|
||||
|
||||
if dtrain.construct().num_data() == 0:
|
||||
raise ValueError("Empty data from dataset, please check your dataset config.")
|
||||
verbose_eval_callback = lgb.log_evaluation(period=verbose_eval)
|
||||
self.model = lgb.train(
|
||||
|
||||
@@ -10,6 +10,7 @@ import os
|
||||
import gc
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from packaging import version
|
||||
from typing import Callable, Optional, Text, Union
|
||||
from sklearn.metrics import roc_auc_score, mean_squared_error
|
||||
|
||||
@@ -148,7 +149,7 @@ class DNNModelPytorch(Model):
|
||||
if scheduler == "default":
|
||||
# In torch version 2.7.0, the verbose parameter has been removed. Reference Link:
|
||||
# https://github.com/pytorch/pytorch/pull/147301/files#diff-036a7470d5307f13c9a6a51c3a65dd014f00ca02f476c545488cd856bea9bcf2L1313
|
||||
if str(torch.__version__).split("+", maxsplit=1)[0] <= "2.6.0":
|
||||
if version.parse(str(torch.__version__).split("+", maxsplit=1)[0]) <= version.parse("2.6.0"):
|
||||
# Reduce learning rate when loss has stopped decreasing
|
||||
self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( # pylint: disable=E1123
|
||||
self.train_optimizer,
|
||||
|
||||
@@ -317,7 +317,7 @@ class TabnetModel(Model):
|
||||
feature = x_train_values.float().to(self.device)
|
||||
label = y_train_values.float().to(self.device)
|
||||
priors = 1 - S_mask
|
||||
(vec, sparse_loss) = self.tabnet_model(feature, priors)
|
||||
vec, sparse_loss = self.tabnet_model(feature, priors)
|
||||
f = self.tabnet_decoder(vec)
|
||||
loss = self.pretrain_loss_fn(label, f, S_mask)
|
||||
|
||||
@@ -348,7 +348,7 @@ class TabnetModel(Model):
|
||||
S_mask = S_mask.to(self.device)
|
||||
priors = 1 - S_mask
|
||||
with torch.no_grad():
|
||||
(vec, sparse_loss) = self.tabnet_model(feature, priors)
|
||||
vec, sparse_loss = self.tabnet_model(feature, priors)
|
||||
f = self.tabnet_decoder(vec)
|
||||
|
||||
loss = self.pretrain_loss_fn(label, f, S_mask)
|
||||
|
||||
@@ -12,6 +12,7 @@ from ...data import D
|
||||
from ...config import C
|
||||
from ...log import get_module_logger
|
||||
from ...utils import get_next_trading_date
|
||||
from ...utils.pickle_utils import restricted_pickle_load
|
||||
from ...backtest.exchange import Exchange
|
||||
|
||||
log = get_module_logger("utils")
|
||||
@@ -30,7 +31,7 @@ def load_instance(file_path):
|
||||
if not file_path.exists():
|
||||
raise ValueError("Cannot find file {}".format(file_path))
|
||||
with file_path.open("rb") as fr:
|
||||
instance = pickle.load(fr)
|
||||
instance = restricted_pickle_load(fr)
|
||||
return instance
|
||||
|
||||
|
||||
|
||||
@@ -135,7 +135,7 @@ class FFillNan(ElemOperator):
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
series = self.feature.load(instrument, start_index, end_index, freq)
|
||||
return series.fillna(method="ffill")
|
||||
return series.ffill()
|
||||
|
||||
|
||||
class BFillNan(ElemOperator):
|
||||
@@ -154,7 +154,7 @@ class BFillNan(ElemOperator):
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
series = self.feature.load(instrument, start_index, end_index, freq)
|
||||
return series.fillna(method="bfill")
|
||||
return series.bfill()
|
||||
|
||||
|
||||
class Date(ElemOperator):
|
||||
|
||||
@@ -3,5 +3,4 @@
|
||||
|
||||
from .analysis_model_performance import model_performance_graph
|
||||
|
||||
|
||||
__all__ = ["model_performance_graph"]
|
||||
|
||||
@@ -7,5 +7,4 @@ from .report import report_graph
|
||||
from .rank_label import rank_label_graph
|
||||
from .risk_analysis import risk_analysis_graph
|
||||
|
||||
|
||||
__all__ = ["cumulative_return_graph", "score_ic_graph", "report_graph", "rank_label_graph", "risk_analysis_graph"]
|
||||
|
||||
@@ -33,7 +33,7 @@ def parse_position(position: dict = None) -> pd.DataFrame:
|
||||
|
||||
position_weight_df = get_stock_weight_df(position)
|
||||
# If the day does not exist, use the last weight
|
||||
position_weight_df.fillna(method="ffill", inplace=True)
|
||||
position_weight_df.ffill(inplace=True)
|
||||
|
||||
previous_data = {"date": None, "code_list": []}
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ Here is an example.
|
||||
fa.plot_all(wspace=0.3, sub_figsize=(12, 3), col_n=5)
|
||||
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from qlib.contrib.report.data.base import FeaAnalyser
|
||||
|
||||
@@ -7,6 +7,7 @@ Assumptions
|
||||
- The analyse each feature individually
|
||||
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
from qlib.log import TimeInspector
|
||||
from qlib.contrib.report.utils import sub_fig_generator
|
||||
|
||||
@@ -14,6 +14,7 @@ from qlib.model.meta.task import MetaTask
|
||||
from qlib.model.trainer import TrainerR
|
||||
from qlib.typehint import Literal
|
||||
from qlib.utils import init_instance_by_config
|
||||
from qlib.utils.pickle_utils import restricted_pickle_load
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.task.utils import replace_task_handler_with_cache
|
||||
|
||||
@@ -298,7 +299,7 @@ class DDGDA(Rolling):
|
||||
# but their task test segment are not aligned! It worked in my previous experiment.
|
||||
# So the misalignment will not affect the effectiveness of the method.
|
||||
with self._internal_data_path.open("rb") as f:
|
||||
internal_data = pickle.load(f)
|
||||
internal_data = restricted_pickle_load(f)
|
||||
|
||||
md = MetaDatasetDS(exp_name=internal_data, **kwargs)
|
||||
|
||||
@@ -360,7 +361,7 @@ class DDGDA(Rolling):
|
||||
)
|
||||
|
||||
with self._internal_data_path.open("rb") as f:
|
||||
internal_data = pickle.load(f)
|
||||
internal_data = restricted_pickle_load(f)
|
||||
mds = MetaDatasetDS(exp_name=internal_data, **kwargs)
|
||||
|
||||
# 3) meta model make inference and get new qlib task
|
||||
|
||||
@@ -16,7 +16,6 @@ from .rule_strategy import (
|
||||
|
||||
from .cost_control import SoftTopkStrategy
|
||||
|
||||
|
||||
__all__ = [
|
||||
"TopkDropoutStrategy",
|
||||
"WeightStrategyBase",
|
||||
|
||||
@@ -1,101 +1,117 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
"""
|
||||
This strategy is not well maintained
|
||||
"""
|
||||
|
||||
|
||||
from .order_generator import OrderGenWInteract
|
||||
from .signal_strategy import WeightStrategyBase
|
||||
import copy
|
||||
|
||||
|
||||
class SoftTopkStrategy(WeightStrategyBase):
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
dataset,
|
||||
topk,
|
||||
model=None,
|
||||
dataset=None,
|
||||
topk=None,
|
||||
order_generator_cls_or_obj=OrderGenWInteract,
|
||||
max_sold_weight=1.0,
|
||||
trade_impact_limit=None,
|
||||
risk_degree=0.95,
|
||||
buy_method="first_fill",
|
||||
trade_exchange=None,
|
||||
level_infra=None,
|
||||
common_infra=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Refactored SoftTopkStrategy with a budget-constrained rebalancing engine.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
topk : int
|
||||
top-N stocks to buy
|
||||
The number of top-N stocks to be held in the portfolio.
|
||||
trade_impact_limit : float
|
||||
Maximum weight change for each stock in one trade. If None, fallback to max_sold_weight.
|
||||
max_sold_weight : float
|
||||
Backward-compatible alias for trade_impact_limit. Use 1.0 to effectively disable the limit.
|
||||
risk_degree : float
|
||||
position percentage of total value buy_method:
|
||||
|
||||
rank_fill: assign the weight stocks that rank high first(1/topk max)
|
||||
average_fill: assign the weight to the stocks rank high averagely.
|
||||
The target percentage of total value to be invested.
|
||||
"""
|
||||
super(SoftTopkStrategy, self).__init__(
|
||||
model, dataset, order_generator_cls_or_obj, trade_exchange, level_infra, common_infra, **kwargs
|
||||
model=model, dataset=dataset, order_generator_cls_or_obj=order_generator_cls_or_obj, **kwargs
|
||||
)
|
||||
|
||||
self.topk = topk
|
||||
self.max_sold_weight = max_sold_weight
|
||||
self.trade_impact_limit = trade_impact_limit if trade_impact_limit is not None else max_sold_weight
|
||||
self.risk_degree = risk_degree
|
||||
self.buy_method = buy_method
|
||||
|
||||
def get_risk_degree(self, trade_step=None):
|
||||
"""get_risk_degree
|
||||
Return the proportion of your total value that you will use in investment.
|
||||
Dynamically risk_degree will result in Market timing
|
||||
"""
|
||||
# It will use 95% amount of your total value by default
|
||||
return self.risk_degree
|
||||
|
||||
def generate_target_weight_position(self, score, current, trade_start_time, trade_end_time):
|
||||
def generate_target_weight_position(self, score, current, trade_start_time, trade_end_time, **kwargs):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
score:
|
||||
pred score for this trade date, pd.Series, index is stock_id, contain 'score' column
|
||||
current:
|
||||
current position, use Position() class
|
||||
trade_date:
|
||||
trade date
|
||||
|
||||
generate target position from score for this date and the current position
|
||||
|
||||
The cache is not considered in the position
|
||||
Generates target position using Proportional Budget Allocation.
|
||||
Ensures deterministic sells and synchronized buys under impact limits.
|
||||
"""
|
||||
# TODO:
|
||||
# If the current stock list is more than topk(eg. The weights are modified
|
||||
# by risk control), the weight will not be handled correctly.
|
||||
buy_signal_stocks = set(score.sort_values(ascending=False).iloc[: self.topk].index)
|
||||
cur_stock_weight = current.get_stock_weight_dict(only_stock=True)
|
||||
|
||||
if len(cur_stock_weight) == 0:
|
||||
final_stock_weight = {code: 1 / self.topk for code in buy_signal_stocks}
|
||||
else:
|
||||
final_stock_weight = copy.deepcopy(cur_stock_weight)
|
||||
sold_stock_weight = 0.0
|
||||
for stock_id in final_stock_weight:
|
||||
if stock_id not in buy_signal_stocks:
|
||||
sw = min(self.max_sold_weight, final_stock_weight[stock_id])
|
||||
sold_stock_weight += sw
|
||||
final_stock_weight[stock_id] -= sw
|
||||
if self.buy_method == "first_fill":
|
||||
for stock_id in buy_signal_stocks:
|
||||
add_weight = min(
|
||||
max(1 / self.topk - final_stock_weight.get(stock_id, 0), 0.0),
|
||||
sold_stock_weight,
|
||||
)
|
||||
final_stock_weight[stock_id] = final_stock_weight.get(stock_id, 0.0) + add_weight
|
||||
sold_stock_weight -= add_weight
|
||||
elif self.buy_method == "average_fill":
|
||||
for stock_id in buy_signal_stocks:
|
||||
final_stock_weight[stock_id] = final_stock_weight.get(stock_id, 0.0) + sold_stock_weight / len(
|
||||
buy_signal_stocks
|
||||
)
|
||||
else:
|
||||
raise ValueError("Buy method not found")
|
||||
return final_stock_weight
|
||||
if self.topk is None or self.topk <= 0:
|
||||
return {}
|
||||
|
||||
def apply_impact_limit(weight):
|
||||
return weight if self.trade_impact_limit is None else min(weight, self.trade_impact_limit)
|
||||
|
||||
ideal_per_stock = self.risk_degree / self.topk
|
||||
ideal_list = score.sort_values(ascending=False).iloc[: self.topk].index.tolist()
|
||||
|
||||
cur_weights = current.get_stock_weight_dict(only_stock=True)
|
||||
initial_total_weight = sum(cur_weights.values())
|
||||
|
||||
# --- Case A: Cold Start ---
|
||||
if not cur_weights:
|
||||
fill = apply_impact_limit(ideal_per_stock)
|
||||
return {code: fill for code in ideal_list}
|
||||
|
||||
# --- Case B: Rebalancing ---
|
||||
all_tickers = set(cur_weights.keys()) | set(ideal_list)
|
||||
next_weights = {t: cur_weights.get(t, 0.0) for t in all_tickers}
|
||||
|
||||
# Phase 1: Deterministic Sell Phase
|
||||
released_cash = 0.0
|
||||
for t in list(next_weights.keys()):
|
||||
cur = next_weights[t]
|
||||
if cur <= 1e-8:
|
||||
continue
|
||||
|
||||
if t not in ideal_list:
|
||||
sell = apply_impact_limit(cur)
|
||||
next_weights[t] -= sell
|
||||
released_cash += sell
|
||||
elif cur > ideal_per_stock + 1e-8:
|
||||
excess = cur - ideal_per_stock
|
||||
sell = apply_impact_limit(excess)
|
||||
next_weights[t] -= sell
|
||||
released_cash += sell
|
||||
|
||||
# Phase 2: Budget Calculation
|
||||
# Budget = Cash from sells + Available space from target risk degree
|
||||
total_budget = released_cash + (self.risk_degree - initial_total_weight)
|
||||
|
||||
# Phase 3: Proportional Buy Allocation
|
||||
if total_budget > 1e-8:
|
||||
shortfalls = {
|
||||
t: (ideal_per_stock - next_weights.get(t, 0.0))
|
||||
for t in ideal_list
|
||||
if next_weights.get(t, 0.0) < ideal_per_stock - 1e-8
|
||||
}
|
||||
|
||||
if shortfalls:
|
||||
total_shortfall = sum(shortfalls.values())
|
||||
# Normalize total_budget to not exceed total_shortfall
|
||||
available_to_spend = min(total_budget, total_shortfall)
|
||||
|
||||
for t, shortfall in shortfalls.items():
|
||||
# Every stock gets its fair share based on its distance to target
|
||||
share_of_budget = (shortfall / total_shortfall) * available_to_spend
|
||||
|
||||
# Capped by impact limit
|
||||
max_buy_cap = apply_impact_limit(shortfall)
|
||||
|
||||
next_weights[t] += min(share_of_budget, max_buy_cap)
|
||||
|
||||
return {k: v for k, v in next_weights.items() if v > 1e-8}
|
||||
|
||||
@@ -5,5 +5,4 @@ from .base import BaseOptimizer
|
||||
from .optimizer import PortfolioOptimizer
|
||||
from .enhanced_indexing import EnhancedIndexingOptimizer
|
||||
|
||||
|
||||
__all__ = ["BaseOptimizer", "PortfolioOptimizer", "EnhancedIndexingOptimizer"]
|
||||
|
||||
@@ -9,7 +9,6 @@ from typing import Union, Optional, Dict, Any, List
|
||||
from qlib.log import get_module_logger
|
||||
from .base import BaseOptimizer
|
||||
|
||||
|
||||
logger = get_module_logger("EnhancedIndexingOptimizer")
|
||||
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
"""
|
||||
This order generator is for strategies based on WeightStrategyBase
|
||||
"""
|
||||
|
||||
from ...backtest.position import Position
|
||||
from ...backtest.exchange import Exchange
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ This module is not a necessary part of Qlib.
|
||||
They are just some tools for convenience
|
||||
It should not be imported into the core part of qlib
|
||||
"""
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
@@ -13,7 +13,6 @@ import yaml
|
||||
|
||||
from .config import TunerConfigManager
|
||||
|
||||
|
||||
args_parser = argparse.ArgumentParser(prog="tuner")
|
||||
args_parser.add_argument(
|
||||
"-c",
|
||||
|
||||
@@ -6,7 +6,6 @@
|
||||
|
||||
from hyperopt import hp
|
||||
|
||||
|
||||
TopkAmountStrategySpace = {
|
||||
"topk": hp.choice("topk", [30, 35, 40]),
|
||||
"buffer_margin": hp.choice("buffer_margin", [200, 250, 300]),
|
||||
|
||||
@@ -8,7 +8,6 @@ import os
|
||||
import yaml
|
||||
import json
|
||||
import copy
|
||||
import pickle
|
||||
import logging
|
||||
import importlib
|
||||
import subprocess
|
||||
@@ -18,6 +17,7 @@ import numpy as np
|
||||
from abc import abstractmethod
|
||||
|
||||
from ...log import get_module_logger, TimeInspector
|
||||
from ...utils.pickle_utils import restricted_pickle_load
|
||||
from hyperopt import fmin, tpe
|
||||
from hyperopt import STATUS_OK, STATUS_FAIL
|
||||
|
||||
@@ -136,7 +136,7 @@ class QLibTuner(Tuner):
|
||||
exp_result_dir = os.path.join(self.ex_dir, QLibTuner.EXP_RESULT_DIR.format(estimator_ex_id))
|
||||
exp_result_path = os.path.join(exp_result_dir, QLibTuner.EXP_RESULT_NAME)
|
||||
with open(exp_result_path, "rb") as fp:
|
||||
analysis_df = pickle.load(fp)
|
||||
analysis_df = restricted_pickle_load(fp)
|
||||
|
||||
# 4. Get the backtest factor which user want to optimize, if user want to maximize the factor, then reverse the result
|
||||
res = analysis_df.loc[self.optim_config.report_type].loc[self.optim_config.report_factor]
|
||||
|
||||
@@ -3,5 +3,4 @@
|
||||
from .record_temp import MultiSegRecord
|
||||
from .record_temp import SignalMseRecord
|
||||
|
||||
|
||||
__all__ = ["MultiSegRecord", "SignalMseRecord"]
|
||||
|
||||
@@ -36,7 +36,6 @@ from .cache import (
|
||||
MemoryCalendarCache,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"D",
|
||||
"CalendarProvider",
|
||||
|
||||
@@ -30,6 +30,7 @@ from ..utils import (
|
||||
normalize_cache_fields,
|
||||
normalize_cache_instruments,
|
||||
)
|
||||
from ..utils.pickle_utils import restricted_pickle_load
|
||||
|
||||
from ..log import get_module_logger
|
||||
from .base import Feature
|
||||
@@ -225,7 +226,7 @@ class CacheUtils:
|
||||
cache_path = Path(cache_path)
|
||||
meta_path = cache_path.with_suffix(".meta")
|
||||
with meta_path.open("rb") as f:
|
||||
d = pickle.load(f)
|
||||
d = restricted_pickle_load(f)
|
||||
with meta_path.open("wb") as f:
|
||||
try:
|
||||
d["meta"]["last_visit"] = str(time.time())
|
||||
@@ -592,7 +593,7 @@ class DiskExpressionCache(ExpressionCache):
|
||||
|
||||
with CacheUtils.writer_lock(self.r, f"{str(C.dpm.get_data_uri())}:expression-{cache_uri}"):
|
||||
with meta_path.open("rb") as f:
|
||||
d = pickle.load(f)
|
||||
d = restricted_pickle_load(f)
|
||||
instrument = d["info"]["instrument"]
|
||||
field = d["info"]["field"]
|
||||
freq = d["info"]["freq"]
|
||||
@@ -959,7 +960,7 @@ class DiskDatasetCache(DatasetCache):
|
||||
im = DiskDatasetCache.IndexManager(cp_cache_uri)
|
||||
with CacheUtils.writer_lock(self.r, f"{str(C.dpm.get_data_uri())}:dataset-{cache_uri}"):
|
||||
with meta_path.open("rb") as f:
|
||||
d = pickle.load(f)
|
||||
d = restricted_pickle_load(f)
|
||||
instruments = d["info"]["instruments"]
|
||||
fields = d["info"]["fields"]
|
||||
freq = d["info"]["freq"]
|
||||
|
||||
@@ -2,15 +2,15 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import division, print_function
|
||||
|
||||
import json
|
||||
|
||||
import socketio
|
||||
|
||||
import qlib
|
||||
from ..config import C
|
||||
|
||||
from ..log import get_module_logger
|
||||
import pickle
|
||||
|
||||
|
||||
class Client:
|
||||
@@ -96,7 +96,7 @@ class Client:
|
||||
self.logger.debug("connected")
|
||||
# The pickle is for passing some parameters with special type(such as
|
||||
# pd.Timestamp)
|
||||
request_content = {"head": head_info, "body": pickle.dumps(request_content, protocol=C.dump_protocol_version)}
|
||||
request_content = {"head": head_info, "body": json.dumps(request_content, default=str)}
|
||||
self.sio.on(request_type + "_response", request_callback)
|
||||
self.logger.debug("try sending")
|
||||
self.sio.emit(request_type + "_request", request_content)
|
||||
|
||||
@@ -19,7 +19,6 @@ from .loader import DataLoader
|
||||
from . import processor as processor_module
|
||||
from . import loader as data_loader_module
|
||||
|
||||
|
||||
DATA_KEY_TYPE = Literal["raw", "infer", "learn"]
|
||||
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import abc
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
import warnings
|
||||
import pandas as pd
|
||||
@@ -11,6 +10,7 @@ from typing import Tuple, Union, List, Dict
|
||||
|
||||
from qlib.data import D
|
||||
from qlib.utils import load_dataset, init_instance_by_config, time_to_slc_point
|
||||
from qlib.utils.pickle_utils import restricted_pickle_load
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.utils.serial import Serializable
|
||||
|
||||
@@ -283,7 +283,7 @@ class StaticDataLoader(DataLoader, Serializable):
|
||||
self._data = pd.read_parquet(self._config, engine="pyarrow")
|
||||
else:
|
||||
with Path(self._config).open("rb") as f:
|
||||
self._data = pickle.load(f)
|
||||
self._data = restricted_pickle_load(f)
|
||||
elif isinstance(self._config, pd.DataFrame):
|
||||
self._data = self._config
|
||||
|
||||
|
||||
@@ -67,7 +67,6 @@ class NaiveDFStorage(BaseHandlerStorage):
|
||||
col_set: Union[str, List[str]] = DataHandler.CS_ALL,
|
||||
fetch_orig: bool = True,
|
||||
) -> pd.DataFrame:
|
||||
|
||||
# Following conflicts may occur
|
||||
# - Does [20200101", "20210101"] mean selecting this slice or these two days?
|
||||
# To solve this issue
|
||||
|
||||
@@ -168,7 +168,7 @@ class SeriesDFilter(BaseDFilter):
|
||||
for _ts, _bool in timestamp_series.items():
|
||||
# there is likely to be NAN when the filter series don't have the
|
||||
# bool value, so we just change the NAN into False
|
||||
if _bool == np.nan:
|
||||
if np.isnan(_bool):
|
||||
_bool = False
|
||||
if _lbool is None:
|
||||
_cur_start = _ts
|
||||
|
||||
@@ -13,6 +13,7 @@ The calculation of both <period_time, feature> and <observe_time, feature> data
|
||||
2) concatenate all the collapsed data, we will get data with format <observe_time, feature>.
|
||||
Qlib will use the operator `P` to perform the collapse.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from qlib.data.ops import ElemOperator
|
||||
|
||||
@@ -3,5 +3,4 @@
|
||||
|
||||
from .storage import CalendarStorage, InstrumentStorage, FeatureStorage, CalVT, InstVT, InstKT
|
||||
|
||||
|
||||
__all__ = ["CalendarStorage", "InstrumentStorage", "FeatureStorage", "CalVT", "InstVT", "InstKT"]
|
||||
|
||||
@@ -156,7 +156,7 @@ class FileCalendarStorage(FileStorageMixin, CalendarStorage):
|
||||
def index(self, value: CalVT) -> int:
|
||||
self.check()
|
||||
calendar = self._read_calendar()
|
||||
return int(np.argwhere(calendar == value)[0])
|
||||
return calendar.index(value)
|
||||
|
||||
def insert(self, index: int, value: CalVT):
|
||||
calendar = self._read_calendar()
|
||||
|
||||
@@ -5,5 +5,4 @@ import warnings
|
||||
|
||||
from .base import Model
|
||||
|
||||
|
||||
__all__ = ["Model", "warnings"]
|
||||
|
||||
@@ -4,5 +4,4 @@
|
||||
from .task import MetaTask
|
||||
from .dataset import MetaTaskDataset
|
||||
|
||||
|
||||
__all__ = ["MetaTask", "MetaTaskDataset"]
|
||||
|
||||
@@ -6,7 +6,6 @@ from .poet import POETCovEstimator
|
||||
from .shrink import ShrinkCovEstimator
|
||||
from .structured import StructuredCovEstimator
|
||||
|
||||
|
||||
__all__ = [
|
||||
"RiskModel",
|
||||
"POETCovEstimator",
|
||||
|
||||
@@ -9,7 +9,6 @@ import tempfile
|
||||
from importlib import import_module
|
||||
from ruamel.yaml import YAML
|
||||
|
||||
|
||||
DELETE_KEY = "_delete_"
|
||||
|
||||
|
||||
|
||||
@@ -2,17 +2,18 @@
|
||||
# Licensed under the MIT License.
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import cast, List
|
||||
from typing import List, cast
|
||||
|
||||
import cachetools
|
||||
import pandas as pd
|
||||
import pickle
|
||||
import os
|
||||
|
||||
from qlib.backtest import Exchange, Order
|
||||
from qlib.backtest.decision import TradeRange, TradeRangeByTime
|
||||
from qlib.constant import EPS_T
|
||||
from qlib.utils.pickle_utils import restricted_pickle_load
|
||||
|
||||
from .base import BaseIntradayBacktestData, BaseIntradayProcessedData, ProcessedDataProvider
|
||||
|
||||
|
||||
@@ -162,7 +163,7 @@ class HandlerIntradayProcessedData(BaseIntradayProcessedData):
|
||||
path = os.path.join(data_dir, "backtest" if backtest else "feature", f"{stock_id}.pkl")
|
||||
start_time, end_time = date.replace(hour=0, minute=0, second=0), date.replace(hour=23, minute=59, second=59)
|
||||
with open(path, "rb") as fstream:
|
||||
dataset = pickle.load(fstream)
|
||||
dataset = restricted_pickle_load(fstream)
|
||||
data = dataset.handler.fetch(pd.IndexSlice[stock_id, start_time:end_time], level=None)
|
||||
|
||||
if index_only:
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
"""
|
||||
This module covers some utility functions that operate on data or basic object
|
||||
"""
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import List, Union
|
||||
|
||||
|
||||
@@ -11,7 +11,6 @@ import contextlib
|
||||
import importlib
|
||||
import os
|
||||
from pathlib import Path
|
||||
import pickle
|
||||
import pkgutil
|
||||
import re
|
||||
import sys
|
||||
@@ -20,6 +19,7 @@ from typing import Any, Dict, List, Tuple, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from qlib.typehint import InstConf
|
||||
from qlib.utils.pickle_utils import restricted_pickle_load
|
||||
|
||||
|
||||
def get_module_by_module_path(module_path: Union[str, ModuleType]):
|
||||
@@ -161,7 +161,6 @@ def init_instance_by_config(
|
||||
# path like 'file:///<path to pickle file>/obj.pkl'
|
||||
pr = urlparse(config)
|
||||
if pr.scheme == "file":
|
||||
|
||||
# To enable relative path like file://data/a/b/c.pkl. pr.netloc will be data
|
||||
path = pr.path
|
||||
if pr.netloc != "":
|
||||
@@ -169,10 +168,10 @@ def init_instance_by_config(
|
||||
|
||||
pr_path = os.path.join(pr.netloc, path) if bool(pr.path) else pr.netloc
|
||||
with open(os.path.normpath(pr_path), "rb") as f:
|
||||
return pickle.load(f)
|
||||
return restricted_pickle_load(f)
|
||||
else:
|
||||
with config.open("rb") as f:
|
||||
return pickle.load(f)
|
||||
return restricted_pickle_load(f)
|
||||
|
||||
klass, cls_kwargs = get_callable_kwargs(config, default_module=default_module)
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from qlib.config import C
|
||||
from qlib.utils.pickle_utils import restricted_pickle_load
|
||||
|
||||
|
||||
class ObjManager:
|
||||
@@ -116,7 +117,7 @@ class FileManager(ObjManager):
|
||||
|
||||
def load_obj(self, name):
|
||||
with (self.path / name).open("rb") as f:
|
||||
return pickle.load(f)
|
||||
return restricted_pickle_load(f)
|
||||
|
||||
def exists(self, name):
|
||||
return (self.path / name).exists()
|
||||
|
||||
171
qlib/utils/pickle_utils.py
Normal file
171
qlib/utils/pickle_utils.py
Normal file
@@ -0,0 +1,171 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
"""
|
||||
Secure pickle utilities to prevent arbitrary code execution through deserialization.
|
||||
|
||||
This module provides a secure alternative to pickle.load() and pickle.loads()
|
||||
that restricts deserialization to a whitelist of safe classes.
|
||||
"""
|
||||
|
||||
import io
|
||||
import pickle
|
||||
from typing import Any, BinaryIO, Set, Tuple
|
||||
|
||||
# Whitelist of safe classes that are allowed to be unpickled
|
||||
# These are common data types used in qlib that should be safe to deserialize
|
||||
SAFE_PICKLE_CLASSES: Set[Tuple[str, str]] = {
    # python builtins
    ("builtins", "slice"),
    ("builtins", "range"),
    ("builtins", "dict"),
    ("builtins", "list"),
    ("builtins", "tuple"),
    ("builtins", "set"),
    ("builtins", "frozenset"),
    ("builtins", "bytearray"),
    ("builtins", "bytes"),
    ("builtins", "str"),
    ("builtins", "int"),
    ("builtins", "float"),
    ("builtins", "bool"),
    ("builtins", "complex"),
    ("builtins", "type"),
    ("builtins", "property"),
    # common utility classes
    ("datetime", "datetime"),
    ("datetime", "date"),
    ("datetime", "time"),
    ("datetime", "timedelta"),
    ("datetime", "timezone"),
    ("decimal", "Decimal"),
    ("collections", "OrderedDict"),
    ("collections", "defaultdict"),
    ("collections", "Counter"),
    ("collections", "namedtuple"),
    ("enum", "Enum"),
    ("pathlib", "Path"),
    ("pathlib", "PosixPath"),
    ("pathlib", "WindowsPath"),
    # qlib internal classes
    ("qlib.data.dataset.handler", "DataHandler"),
    ("qlib.data.dataset.handler", "DataHandlerLP"),
    ("qlib.data.dataset.loader", "StaticDataLoader"),
}


# Top-level packages whose modules are accepted without a per-class whitelist entry.
# NOTE(review): every global reachable under these packages becomes loadable from a
# pickle, including plain functions; consider narrowing this to the specific
# classes qlib actually needs if the threat model requires it.
TRUSTED_MODULE_PREFIXES = (
    "pandas",
    "numpy",
)


class RestrictedUnpickler(pickle.Unpickler):
    """Custom unpickler that only allows safe classes to be deserialized.

    A global referenced by the pickle stream is resolved only when its module
    belongs to one of the packages in ``TRUSTED_MODULE_PREFIXES`` or when the
    ``(module, name)`` pair is listed in ``SAFE_PICKLE_CLASSES``. Anything else
    is rejected with ``pickle.UnpicklingError``.

    Example:
        >>> with open("data.pkl", "rb") as f:
        ...     data = RestrictedUnpickler(f).load()
    """

    @staticmethod
    def _is_trusted_module(module: str) -> bool:
        """Return True if ``module`` is a trusted package or one of its submodules.

        Matching is done on a package boundary: ``"pandas"`` and
        ``"pandas.core.frame"`` match the ``"pandas"`` entry, whereas an unrelated
        module that merely shares the spelling (e.g. ``"pandasql"``) does not.
        """
        return any(
            module == package or module.startswith(package + ".")
            for package in TRUSTED_MODULE_PREFIXES
        )

    def find_class(self, module: str, name: str):
        """Override find_class to restrict allowed classes.

        Args:
            module: Module name of the class
            name: Class name

        Returns:
            The object resolved by ``pickle.Unpickler.find_class`` when the
            request is permitted

        Raises:
            pickle.UnpicklingError: If the class is neither in a trusted package
                nor in the whitelist
        """
        # 1. trusted third-party packages (whole package, matched on a dot boundary)
        if self._is_trusted_module(module):
            return super().find_class(module, name)

        # 2. explicit whitelist (builtins, stdlib helpers and qlib internal classes)
        if (module, name) in SAFE_PICKLE_CLASSES:
            return super().find_class(module, name)

        raise pickle.UnpicklingError(
            f"Forbidden class: {module}.{name}. "
            f"Only whitelisted classes are allowed for security reasons. "
            f"This is to prevent arbitrary code execution through pickle deserialization."
        )
|
||||
|
||||
|
||||
def restricted_pickle_load(file: BinaryIO) -> Any:
    """Deserialize one pickled object from ``file`` using ``RestrictedUnpickler``.

    Intended as a stand-in for ``pickle.load()``: the stream is read through
    ``RestrictedUnpickler``, so only globals it accepts can be resolved.

    Args:
        file: An opened file object in binary mode

    Returns:
        The unpickled Python object

    Raises:
        pickle.UnpicklingError: If the pickle contains forbidden classes

    Example:
        >>> with open("data.pkl", "rb") as f:
        ...     data = restricted_pickle_load(f)
    """
    unpickler = RestrictedUnpickler(file)
    return unpickler.load()
|
||||
|
||||
|
||||
def restricted_pickle_loads(data: bytes) -> Any:
    """Deserialize one pickled object from ``data`` using ``RestrictedUnpickler``.

    Intended as a stand-in for ``pickle.loads()``: the bytes are wrapped in an
    in-memory stream and read through ``RestrictedUnpickler``, so only globals
    it accepts can be resolved.

    Args:
        data: Bytes object containing pickled data

    Returns:
        The unpickled Python object

    Raises:
        pickle.UnpicklingError: If the pickle contains forbidden classes

    Example:
        >>> data = b'\\x80\\x04\\x95...'
        >>> obj = restricted_pickle_loads(data)
    """
    return RestrictedUnpickler(io.BytesIO(data)).load()
|
||||
|
||||
|
||||
def add_safe_class(module: str, name: str) -> None:
    """Register ``(module, name)`` in the module-level unpickling whitelist.

    The pair is added to ``SAFE_PICKLE_CLASSES`` in place, so the change is
    visible to every later lookup against that set. Adding an entry widens
    what a pickle is allowed to reference, so do it sparingly.

    Args:
        module: Module name of the class (e.g., 'my_package.my_module')
        name: Class name (e.g., 'MyClass')

    Warning:
        Only add classes that you fully control and trust. Adding arbitrary
        classes from external packages could introduce security risks.

    Example:
        >>> add_safe_class('my_package.models', 'CustomModel')
    """
    entry = (module, name)
    SAFE_PICKLE_CLASSES.add(entry)
|
||||
|
||||
|
||||
def get_safe_classes() -> Set[Tuple[str, str]]:
    """Return a snapshot of the unpickling whitelist.

    The result is a new set, so mutating it does not affect
    ``SAFE_PICKLE_CLASSES``.

    Returns:
        A set of (module, name) tuples representing allowed classes
    """
    return set(SAFE_PICKLE_CLASSES)
|
||||
@@ -109,7 +109,7 @@ def resam_ts_data(
|
||||
"""
|
||||
Resample value from time-series data
|
||||
|
||||
- If `feature` has MultiIndex[instrument, datetime], apply the `method` to each instruemnt data with datetime in [start_time, end_time]
|
||||
- If `feature` has MultiIndex[instrument, datetime], apply the `method` to each instrument data with datetime in [start_time, end_time]
|
||||
Example:
|
||||
|
||||
.. code-block::
|
||||
@@ -222,7 +222,7 @@ def get_valid_value(series, last=True):
|
||||
Nan | float
|
||||
the first/last valid value
|
||||
"""
|
||||
return series.fillna(method="ffill").iloc[-1] if last else series.fillna(method="bfill").iloc[0]
|
||||
return series.ffill().iloc[-1] if last else series.bfill().iloc[0]
|
||||
|
||||
|
||||
def _ts_data_valid(ts_feature, last=False):
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
"""
|
||||
Time related utils are compiled in this script
|
||||
"""
|
||||
|
||||
import bisect
|
||||
from datetime import datetime, time, date, timedelta
|
||||
from typing import List, Optional, Tuple, Union
|
||||
@@ -14,7 +15,6 @@ import pandas as pd
|
||||
from qlib.config import C
|
||||
from qlib.constant import REG_CN, REG_TW, REG_US
|
||||
|
||||
|
||||
CN_TIME = [
|
||||
datetime.strptime("9:30", "%H:%M"),
|
||||
datetime.strptime("11:30", "%H:%M"),
|
||||
|
||||
@@ -16,7 +16,6 @@ from .recorder import Recorder
|
||||
from ..log import get_module_logger
|
||||
from ..utils.exceptions import ExpAlreadyExistError
|
||||
|
||||
|
||||
logger = get_module_logger("workflow")
|
||||
|
||||
|
||||
|
||||
@@ -22,7 +22,6 @@ from ..utils.data import deepcopy_basic_type
|
||||
from ..utils.exceptions import QlibException
|
||||
from ..contrib.eva.alpha import calc_ic, calc_long_short_return, calc_long_short_prec
|
||||
|
||||
|
||||
logger = get_module_logger("workflow", logging.INFO)
|
||||
|
||||
|
||||
@@ -476,7 +475,13 @@ class PortAnaRecord(ACRecordTemp):
|
||||
if self.backtest_config["start_time"] is None:
|
||||
self.backtest_config["start_time"] = dt_values.min()
|
||||
if self.backtest_config["end_time"] is None:
|
||||
self.backtest_config["end_time"] = get_date_by_shift(dt_values.max(), 1)
|
||||
self.backtest_config["end_time"] = get_date_by_shift(dt_values.max(), -1)
|
||||
warnings.warn(
|
||||
"No explicit backtest end_time provided. "
|
||||
"Qlib requires one extra calendar step to determine the right boundary of a bar. "
|
||||
"Therefore the end_time is shifted backward by one trading day from "
|
||||
f"{dt_values.max()} -> {self.backtest_config['end_time']}."
|
||||
)
|
||||
|
||||
artifact_objects = {}
|
||||
# custom strategy and get backtest
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
"""
|
||||
TaskGenerator module can generate many tasks based on TaskGen and some task templates.
|
||||
"""
|
||||
|
||||
import abc
|
||||
import copy
|
||||
import pandas as pd
|
||||
@@ -106,15 +107,13 @@ def handler_mod(task: dict, rolling_gen):
|
||||
rg (RollingGen): an instance of RollingGen
|
||||
"""
|
||||
try:
|
||||
interval = rolling_gen.ta.cal_interval(
|
||||
task["dataset"]["kwargs"]["handler"]["kwargs"]["end_time"],
|
||||
task["dataset"]["kwargs"]["segments"][rolling_gen.test_key][1],
|
||||
)
|
||||
# if end_time < the end of test_segments, then change end_time to allow load more data
|
||||
if interval < 0:
|
||||
task["dataset"]["kwargs"]["handler"]["kwargs"]["end_time"] = copy.deepcopy(
|
||||
task["dataset"]["kwargs"]["segments"][rolling_gen.test_key][1]
|
||||
)
|
||||
handler_kwargs = task["dataset"]["kwargs"]["handler"]["kwargs"]
|
||||
handler_end_time = handler_kwargs.get("end_time")
|
||||
test_seg_end_time = task["dataset"]["kwargs"]["segments"][rolling_gen.test_key][1]
|
||||
# if the end of test_segments is None (open-ended segment, i.e., "until now") or end_time < the end of test_segments,
|
||||
# then change end_time to allow load more data
|
||||
if test_seg_end_time is None or rolling_gen.ta.cal_interval(handler_end_time, test_seg_end_time) < 0:
|
||||
handler_kwargs["end_time"] = copy.deepcopy(test_seg_end_time)
|
||||
except KeyError:
|
||||
# Maybe dataset do not have handler, then do nothing.
|
||||
pass
|
||||
|
||||
@@ -12,6 +12,7 @@ A task in TaskManager consists of 3 parts
|
||||
- tasks status: the status of the task
|
||||
- tasks result: A user can get the task with the task description and task result.
|
||||
"""
|
||||
|
||||
import concurrent
|
||||
import pickle
|
||||
import time
|
||||
@@ -28,6 +29,7 @@ from tqdm.cli import tqdm
|
||||
|
||||
from .utils import get_mongodb
|
||||
from ...config import C
|
||||
from ...utils.pickle_utils import restricted_pickle_loads
|
||||
|
||||
|
||||
class TaskManager:
|
||||
@@ -131,7 +133,7 @@ class TaskManager:
|
||||
for prefix in self.ENCODE_FIELDS_PREFIX:
|
||||
for k in list(task.keys()):
|
||||
if k.startswith(prefix):
|
||||
task[k] = pickle.loads(task[k])
|
||||
task[k] = restricted_pickle_loads(task[k])
|
||||
return task
|
||||
|
||||
def _dict_to_str(self, flt):
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
from loguru import logger
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import fire
|
||||
import pandas as pd
|
||||
import qlib
|
||||
from loguru import logger
|
||||
from tqdm import tqdm
|
||||
|
||||
import qlib
|
||||
from qlib.data import D
|
||||
|
||||
|
||||
@@ -36,6 +36,7 @@ class DataHealthChecker:
|
||||
self.large_step_threshold_price = large_step_threshold_price
|
||||
self.large_step_threshold_volume = large_step_threshold_volume
|
||||
self.missing_data_num = missing_data_num
|
||||
self.qlib_dir = os.path.abspath(os.path.expanduser(qlib_dir))
|
||||
|
||||
if csv_path:
|
||||
assert os.path.isdir(csv_path), f"{csv_path} should be a directory."
|
||||
@@ -68,6 +69,43 @@ class DataHealthChecker:
|
||||
self.data[instrument] = df
|
||||
print(df)
|
||||
|
||||
# NOTE:
|
||||
# This check is added due to a known issue in Qlib where feature paths
|
||||
# are constructed using lowercased instrument names. On case-sensitive
|
||||
# file systems (e.g. Linux), uppercase directory names under `features/`
|
||||
# will cause data loading failures.
|
||||
#
|
||||
# See: https://github.com/microsoft/qlib/issues/2053
|
||||
def check_features_dir_lowercase(self) -> Optional[pd.DataFrame]:
    """
    Check whether all subdirectories under `<qlib_dir>/features` are named in lowercase.

    This validation helps prevent data loading issues on case-sensitive
    file systems caused by uppercase instrument directory names.

    Returns
    -------
    Optional[pd.DataFrame]
        A frame with a single ``non_lowercase_dir`` column listing the offending
        directory names in sorted order, or ``None`` when ``qlib_dir`` is unset,
        the ``features`` directory is missing, or every subdirectory name is
        already lowercase.
    """
    if not self.qlib_dir:
        return None

    features_dir = os.path.join(self.qlib_dir, "features")
    if not os.path.isdir(features_dir):
        logger.warning(f"`features` directory not found under {self.qlib_dir}")
        return None

    # os.listdir yields entries in arbitrary order; sort so the report is reproducible.
    # Plain files are ignored: only directory names are checked.
    bad_dirs = sorted(
        name
        for name in os.listdir(features_dir)
        if name != name.lower() and os.path.isdir(os.path.join(features_dir, name))
    )
    if bad_dirs:
        return pd.DataFrame({"non_lowercase_dir": bad_dirs})

    logger.info(f"✅ All subdirectories under `{features_dir}` are named in lowercase.")
    return None
|
||||
|
||||
def check_missing_data(self) -> Optional[pd.DataFrame]:
|
||||
"""Check if any data is missing in the DataFrame."""
|
||||
result_dict = {
|
||||
@@ -177,11 +215,13 @@ class DataHealthChecker:
|
||||
check_large_step_changes_result = self.check_large_step_changes()
|
||||
check_required_columns_result = self.check_required_columns()
|
||||
check_missing_factor_result = self.check_missing_factor()
|
||||
check_features_dir_case_result = self.check_features_dir_lowercase()
|
||||
if (
|
||||
check_large_step_changes_result is not None
|
||||
or check_large_step_changes_result is not None
|
||||
or check_required_columns_result is not None
|
||||
or check_missing_factor_result is not None
|
||||
or check_features_dir_case_result is not None
|
||||
):
|
||||
print(f"\nSummary of data health check ({len(self.data)} files checked):")
|
||||
print("-------------------------------------------------")
|
||||
@@ -197,6 +237,11 @@ class DataHealthChecker:
|
||||
if isinstance(check_missing_factor_result, pd.DataFrame):
|
||||
logger.warning(f"The factor column does not exist or is empty")
|
||||
print(check_missing_factor_result)
|
||||
if isinstance(check_features_dir_case_result, pd.DataFrame):
|
||||
logger.warning(
|
||||
f"Some subdirectories under `{os.path.join(self.qlib_dir, 'features')}` contain uppercase letters, please rename them to lowercase manually."
|
||||
)
|
||||
print(check_features_dir_case_result)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -28,32 +28,32 @@ class InfoCollector:
|
||||
"""collect qlib related info"""
|
||||
print("Qlib version: {}".format(qlib.__version__))
|
||||
REQUIRED = [
|
||||
"setuptools",
|
||||
"wheel",
|
||||
"cython",
|
||||
"pyyaml",
|
||||
"numpy",
|
||||
"pandas",
|
||||
"scipy",
|
||||
"requests",
|
||||
"sacred",
|
||||
"python-socketio",
|
||||
"redis",
|
||||
"python-redis-lock",
|
||||
"schedule",
|
||||
"cvxpy",
|
||||
"hyperopt",
|
||||
"fire",
|
||||
"statsmodels",
|
||||
"xlrd",
|
||||
"plotly",
|
||||
"matplotlib",
|
||||
"tables",
|
||||
"pyyaml",
|
||||
"mlflow",
|
||||
"tqdm",
|
||||
"loguru",
|
||||
"lightgbm",
|
||||
"tornado",
|
||||
"joblib",
|
||||
"filelock",
|
||||
"redis",
|
||||
"dill",
|
||||
"fire",
|
||||
"ruamel.yaml",
|
||||
"python-redis-lock",
|
||||
"tqdm",
|
||||
"pymongo",
|
||||
"loguru",
|
||||
"lightgbm",
|
||||
"gym",
|
||||
"cvxpy",
|
||||
"joblib",
|
||||
"matplotlib",
|
||||
"jupyter",
|
||||
"nbconvert",
|
||||
"pyarrow",
|
||||
"pydantic-settings",
|
||||
"setuptools-scm",
|
||||
]
|
||||
|
||||
for package in REQUIRED:
|
||||
|
||||
@@ -172,7 +172,7 @@ class BaostockNormalizeHS3005min(BaseNormalize):
|
||||
@staticmethod
|
||||
def calc_change(df: pd.DataFrame, last_close: float) -> pd.Series:
|
||||
df = df.copy()
|
||||
_tmp_series = df["close"].fillna(method="ffill")
|
||||
_tmp_series = df["close"].ffill()
|
||||
_tmp_shift_series = _tmp_series.shift(1)
|
||||
if last_close is not None:
|
||||
_tmp_shift_series.iloc[0] = float(last_close)
|
||||
|
||||
@@ -280,11 +280,20 @@ class Normalize:
|
||||
self._symbol_field_name = symbol_field_name
|
||||
self._end_date = kwargs.get("end_date", None)
|
||||
self._max_workers = max_workers
|
||||
self.interval = kwargs.get("interval", "1d")
|
||||
|
||||
self._normalize_obj = normalize_class(
|
||||
date_field_name=date_field_name, symbol_field_name=symbol_field_name, **kwargs
|
||||
)
|
||||
|
||||
def format_data(self, df: pd.DataFrame):
    """Drop the last row of daily data when its date value cannot be parsed.

    Only applies when ``self.interval == "1d"``. The value in the date column of
    the final row is parsed with the ``%Y-%m-%d`` format; if parsing fails, that
    row is removed. Frames that are empty or lack the date column are returned
    unchanged instead of losing a row to a swallowed lookup error.

    Parameters
    ----------
    df : pd.DataFrame
        Raw source data for one symbol.

    Returns
    -------
    pd.DataFrame
        ``df`` itself, or ``df`` without its last row.
    """
    if self.interval != "1d" or df.empty:
        return df

    # Honour the configured date column; fall back to the historical default name.
    date_field = getattr(self, "_date_field_name", "date")
    if date_field not in df.columns:
        return df

    try:
        pd.to_datetime(df[date_field].iloc[-1], format="%Y-%m-%d", errors="raise")
    except (ValueError, TypeError, OverflowError):
        # The trailing row does not carry a well-formed date: discard it.
        return df.iloc[:-1]
    return df
|
||||
|
||||
def _executor(self, file_path: Path):
|
||||
file_path = Path(file_path)
|
||||
|
||||
@@ -300,14 +309,18 @@ class Normalize:
|
||||
keep_default_na=False,
|
||||
na_values={col: symbol_na if col == self._symbol_field_name else default_na for col in columns},
|
||||
)
|
||||
df = self.format_data(df=df)
|
||||
|
||||
# NOTE: It has been reported that there may be some problems here, and the specific issues will be dealt with when they are identified.
|
||||
df = self._normalize_obj.normalize(df)
|
||||
if df is not None and not df.empty:
|
||||
if self._end_date is not None:
|
||||
_mask = pd.to_datetime(df[self._date_field_name]) <= pd.Timestamp(self._end_date)
|
||||
df = df[_mask]
|
||||
df.to_csv(self._target_dir.joinpath(file_path.name), index=False)
|
||||
if not df.empty:
|
||||
# NOTE: It has been reported that there may be some problems here, and the specific issues will be dealt with when they are identified.
|
||||
df = self._normalize_obj.normalize(df)
|
||||
if df is not None and not df.empty:
|
||||
if self._end_date is not None:
|
||||
_mask = pd.to_datetime(df[self._date_field_name]) <= pd.Timestamp(self._end_date)
|
||||
df = df[_mask]
|
||||
df.to_csv(self._target_dir.joinpath(file_path.name), index=False)
|
||||
else:
|
||||
logger.warning(f"{file_path.stem} source data is empty and will not undergo normalization processing.")
|
||||
|
||||
def normalize(self):
|
||||
logger.info("normalize data......")
|
||||
|
||||
@@ -22,7 +22,6 @@ from data_collector.index import IndexBase
|
||||
from data_collector.utils import get_calendar_list, get_trading_date_by_shift, deco_retry
|
||||
from data_collector.utils import get_instruments
|
||||
|
||||
|
||||
NEW_COMPANIES_URL = (
|
||||
"https://oss-ch.csindex.com.cn/static/html/csindex/public/uploads/file/autofile/cons/{index_code}cons.xls"
|
||||
)
|
||||
|
||||
@@ -19,7 +19,6 @@ from time import mktime
|
||||
from datetime import datetime as dt
|
||||
import time
|
||||
|
||||
|
||||
_CG_CRYPTO_SYMBOLS = None
|
||||
|
||||
|
||||
|
||||
@@ -7,13 +7,14 @@ import sys
|
||||
from pathlib import Path
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import List
|
||||
from io import StringIO
|
||||
|
||||
import fire
|
||||
import requests
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
from loguru import logger
|
||||
|
||||
from fake_useragent import UserAgent
|
||||
|
||||
CUR_DIR = Path(__file__).resolve().parent
|
||||
sys.path.append(str(CUR_DIR.parent.parent))
|
||||
@@ -22,7 +23,6 @@ from data_collector.index import IndexBase
|
||||
from data_collector.utils import deco_retry, get_calendar_list, get_trading_date_by_shift
|
||||
from data_collector.utils import get_instruments
|
||||
|
||||
|
||||
WIKI_URL = "https://en.wikipedia.org/wiki"
|
||||
|
||||
WIKI_INDEX_NAME_MAP = {
|
||||
@@ -51,6 +51,7 @@ class WIKIIndex(IndexBase):
|
||||
)
|
||||
|
||||
self._target_url = f"{WIKI_URL}/{WIKI_INDEX_NAME_MAP[self.index_name.upper()]}"
|
||||
self._ua = UserAgent()
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
@@ -112,7 +113,8 @@ class WIKIIndex(IndexBase):
|
||||
return _calendar_list
|
||||
|
||||
def _request_new_companies(self) -> requests.Response:
|
||||
resp = requests.get(self._target_url, timeout=None)
|
||||
headers = {"User-Agent": self._ua.random}
|
||||
resp = requests.get(self._target_url, timeout=None, headers=headers)
|
||||
if resp.status_code != 200:
|
||||
raise ValueError(f"request error: {self._target_url}")
|
||||
|
||||
@@ -128,7 +130,7 @@ class WIKIIndex(IndexBase):
|
||||
def get_new_companies(self):
|
||||
logger.info(f"get new companies {self.index_name} ......")
|
||||
_data = deco_retry(retry=self._request_retry, retry_sleep=self._retry_sleep)(self._request_new_companies)()
|
||||
df_list = pd.read_html(_data.text)
|
||||
df_list = pd.read_html(StringIO(_data.text))
|
||||
for _df in df_list:
|
||||
_df = self.filter_df(_df)
|
||||
if (_df is not None) and (not _df.empty):
|
||||
@@ -226,7 +228,11 @@ class SP500Index(WIKIIndex):
|
||||
def get_changes(self) -> pd.DataFrame:
|
||||
logger.info(f"get sp500 history changes......")
|
||||
# NOTE: may update the index of the table
|
||||
changes_df = pd.read_html(self.WIKISP500_CHANGES_URL)[-1]
|
||||
# Add headers to avoid 403 Forbidden error from Wikipedia
|
||||
headers = {"User-Agent": self._ua.random}
|
||||
response = requests.get(self.WIKISP500_CHANGES_URL, headers=headers, timeout=None)
|
||||
response.raise_for_status()
|
||||
changes_df = pd.read_html(StringIO(response.text))[-1]
|
||||
changes_df = changes_df.iloc[:, [0, 1, 3]]
|
||||
changes_df.columns = [self.DATE_FIELD_NAME, self.ADD, self.REMOVE]
|
||||
changes_df[self.DATE_FIELD_NAME] = pd.to_datetime(changes_df[self.DATE_FIELD_NAME])
|
||||
|
||||
@@ -3,3 +3,4 @@ requests
|
||||
pandas
|
||||
lxml
|
||||
loguru
|
||||
fake-useragent
|
||||
|
||||
@@ -7,7 +7,6 @@ import importlib
|
||||
import time
|
||||
import bisect
|
||||
import pickle
|
||||
import random
|
||||
import requests
|
||||
import functools
|
||||
from pathlib import Path
|
||||
@@ -21,6 +20,9 @@ from tqdm import tqdm
|
||||
from functools import partial
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
from bs4 import BeautifulSoup
|
||||
import baostock as bs
|
||||
|
||||
from qlib.utils.pickle_utils import restricted_pickle_load
|
||||
|
||||
HS_SYMBOLS_URL = "http://app.finance.ifeng.com/hq/list.php?type=stock_a&class={s_type}"
|
||||
|
||||
@@ -67,9 +69,16 @@ def get_calendar_list(bench_code="CSI300") -> List[pd.Timestamp]:
|
||||
|
||||
logger.info(f"get calendar list: {bench_code}......")
|
||||
|
||||
def _get_calendar(url):
|
||||
_value_list = requests.get(url, timeout=None).json()["data"]["klines"]
|
||||
return sorted(map(lambda x: pd.Timestamp(x.split(",")[0]), _value_list))
|
||||
def _get_calendar(end_date):
    """Return the sorted trading days reported by baostock up to ``end_date``.

    Queries ``bs.query_trade_dates`` from 2005-01-01 to ``end_date``, keeps the
    rows whose ``is_trading_day`` field equals ``"1"`` and converts their
    ``calendar_date`` values to ``pd.Timestamp``.

    Parameters
    ----------
    end_date : str
        Last date to query; the caller passes a ``%Y-%m-%d`` string.

    Returns
    -------
    list of pd.Timestamp
        Trading days in ascending order.
    """
    bs.login()
    try:
        rs = bs.query_trade_dates(start_date="2005-01-01", end_date=end_date)
        data_list = []
        # Stop as soon as the result set reports an error or runs out of rows.
        while rs.error_code == "0" and rs.next():
            data_list.append(rs.get_row_data())
    finally:
        # Always release the baostock session, even if the query or iteration raised.
        bs.logout()
    df = pd.DataFrame(data_list, columns=rs.fields)
    trade_days = df[df["is_trading_day"] == "1"]["calendar_date"]
    return sorted(map(pd.Timestamp, trade_days.to_list()))
|
||||
|
||||
calendar = _CALENDAR_MAP.get(bench_code, None)
|
||||
if calendar is None:
|
||||
@@ -80,30 +89,17 @@ def get_calendar_list(bench_code="CSI300") -> List[pd.Timestamp]:
|
||||
calendar = df.index.get_level_values(level="date").map(pd.Timestamp).unique().tolist()
|
||||
else:
|
||||
if bench_code.upper() == "ALL":
|
||||
import akshare as ak # pylint: disable=C0415
|
||||
|
||||
@deco_retry
|
||||
def _get_calendar(month):
|
||||
_cal = []
|
||||
try:
|
||||
resp = requests.get(
|
||||
SZSE_CALENDAR_URL.format(month=month, random=random.random), timeout=None
|
||||
).json()
|
||||
for _r in resp["data"]:
|
||||
if int(_r["jybz"]):
|
||||
_cal.append(pd.Timestamp(_r["jyrq"]))
|
||||
except Exception as e:
|
||||
raise ValueError(f"{month}-->{e}") from e
|
||||
return _cal
|
||||
|
||||
month_range = pd.date_range(start="2000-01", end=pd.Timestamp.now() + pd.Timedelta(days=31), freq="M")
|
||||
calendar = []
|
||||
for _m in month_range:
|
||||
cal = _get_calendar(_m.strftime("%Y-%m"))
|
||||
if cal:
|
||||
calendar += cal
|
||||
calendar = list(filter(lambda x: x <= pd.Timestamp.now(), calendar))
|
||||
trade_date_df = ak.tool_trade_date_hist_sina()
|
||||
trade_date_list = trade_date_df["trade_date"].tolist()
|
||||
trade_date_list = [pd.Timestamp(d) for d in trade_date_list]
|
||||
dates = pd.DatetimeIndex(trade_date_list)
|
||||
filtered_dates = dates[(dates >= "2000-01-04") & (dates <= pd.Timestamp.today().normalize())]
|
||||
calendar = filtered_dates.tolist()
|
||||
else:
|
||||
calendar = _get_calendar(CALENDAR_BENCH_URL_MAP[bench_code])
|
||||
end_date = time.strftime("%Y-%m-%d", time.localtime())
|
||||
calendar = _get_calendar(end_date=end_date)
|
||||
_CALENDAR_MAP[bench_code] = calendar
|
||||
logger.info(f"end of get calendar list: {bench_code}.")
|
||||
return calendar
|
||||
@@ -280,7 +276,7 @@ def get_hs_stock_symbols() -> list:
|
||||
symbol_cache_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
if symbol_cache_path.exists():
|
||||
with symbol_cache_path.open("rb") as fp:
|
||||
cache_symbols = pickle.load(fp)
|
||||
cache_symbols = restricted_pickle_load(fp)
|
||||
symbols |= cache_symbols
|
||||
with symbol_cache_path.open("wb") as fp:
|
||||
pickle.dump(symbols, fp)
|
||||
@@ -297,20 +293,14 @@ def get_us_stock_symbols(qlib_data_path: [str, Path] = None) -> list:
|
||||
-------
|
||||
stock symbols
|
||||
"""
|
||||
import akshare as ak # pylint: disable=C0415
|
||||
|
||||
global _US_SYMBOLS # pylint: disable=W0603
|
||||
|
||||
@deco_retry
|
||||
def _get_eastmoney():
|
||||
url = "http://4.push2.eastmoney.com/api/qt/clist/get?pn=1&pz=10000&fs=m:105,m:106,m:107&fields=f12"
|
||||
resp = requests.get(url, timeout=None)
|
||||
if resp.status_code != 200:
|
||||
raise ValueError("request error")
|
||||
|
||||
try:
|
||||
_symbols = [_v["f12"].replace("_", "-P") for _v in resp.json()["data"]["diff"].values()]
|
||||
except Exception as e:
|
||||
logger.warning(f"request error: {e}")
|
||||
raise
|
||||
df = ak.get_us_stock_name()
|
||||
_symbols = df["symbol"].to_list()
|
||||
|
||||
if len(_symbols) < 8000:
|
||||
raise ValueError("request error")
|
||||
|
||||
@@ -371,7 +371,7 @@ class YahooNormalize(BaseNormalize):
|
||||
@staticmethod
|
||||
def calc_change(df: pd.DataFrame, last_close: float) -> pd.Series:
|
||||
df = df.copy()
|
||||
_tmp_series = df["close"].fillna(method="ffill")
|
||||
_tmp_series = df["close"].ffill()
|
||||
_tmp_shift_series = _tmp_series.shift(1)
|
||||
if last_close is not None:
|
||||
_tmp_shift_series.iloc[0] = float(last_close)
|
||||
@@ -459,7 +459,7 @@ class YahooNormalize1d(YahooNormalize, ABC):
|
||||
df.set_index(self._date_field_name, inplace=True)
|
||||
if "adjclose" in df:
|
||||
df["factor"] = df["adjclose"] / df["close"]
|
||||
df["factor"] = df["factor"].fillna(method="ffill")
|
||||
df["factor"] = df["factor"].ffill()
|
||||
else:
|
||||
df["factor"] = 1
|
||||
for _col in self.COLUMNS:
|
||||
@@ -613,10 +613,6 @@ class YahooNormalize1min(YahooNormalize, ABC):
|
||||
def symbol_to_yahoo(self, symbol):
|
||||
raise NotImplementedError("rewrite symbol_to_yahoo")
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_1d_calendar_list(self) -> Iterable[pd.Timestamp]:
|
||||
raise NotImplementedError("rewrite _get_1d_calendar_list")
|
||||
|
||||
|
||||
class YahooNormalizeUS:
|
||||
def _get_calendar_list(self) -> Iterable[pd.Timestamp]:
|
||||
|
||||
@@ -9,4 +9,5 @@ yahooquery
|
||||
joblib
|
||||
beautifulsoup4
|
||||
bs4
|
||||
soupsieve
|
||||
soupsieve
|
||||
akshare
|
||||
@@ -4,6 +4,5 @@
|
||||
import fire
|
||||
from qlib.tests.data import GetData
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(GetData)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user