mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 14:01:28 +08:00
Compare commits
30 Commits
update_rea
...
v0.9.7
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
da920b7f95 | ||
|
|
d89fa0184c | ||
|
|
1b426503fc | ||
|
|
78b77e302b | ||
|
|
38f02d25dc | ||
|
|
de86e46ed0 | ||
|
|
ba8b6cc30f | ||
|
|
3525514704 | ||
|
|
3e72593b8c | ||
|
|
c38e799ce7 | ||
|
|
14d54aa2a1 | ||
|
|
89ae312109 | ||
|
|
3ea30c0290 | ||
|
|
4b8d70df1b | ||
|
|
a2996f7046 | ||
|
|
fbba768006 | ||
|
|
df557d29d5 | ||
|
|
be9cd9fe23 | ||
|
|
85cc74846b | ||
|
|
950408ef46 | ||
|
|
320bd65e19 | ||
|
|
e7a1b5ea1f | ||
|
|
67feeaeb00 | ||
|
|
4d621bff99 | ||
|
|
82f1ef2def | ||
|
|
186512f272 | ||
|
|
bda374180a | ||
|
|
014ff7d3fe | ||
|
|
23d9d5a0a9 | ||
|
|
7ce97c9da5 |
28
.github/workflows/test_qlib_from_pip.yml
vendored
28
.github/workflows/test_qlib_from_pip.yml
vendored
@@ -13,7 +13,7 @@ jobs:
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-13, macos-14, macos-15]
|
||||
os: [windows-latest, ubuntu-24.04, ubuntu-22.04, macos-14, macos-15]
|
||||
# In github action, using python 3.7, pip install will not match the latest version of the package.
|
||||
# Also, python 3.7 is no longer supported from macos-14, and will be phased out from macos-13 in the near future.
|
||||
# All things considered, we have removed python 3.7.
|
||||
@@ -31,22 +31,30 @@ jobs:
|
||||
- name: Update pip to the latest version
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
|
||||
# Will cancel this step when the next qlib version is released. The current qlib version is: 0.9.6
|
||||
- name: Installing pywinpt for windows
|
||||
if: ${{ matrix.os == 'windows-latest' }}
|
||||
run: |
|
||||
python -m pip install pywinpty --only-binary=:all:
|
||||
|
||||
# # joblib was released on 2025-05-04 with version 1.5.0, in which _backend_args was removed and replaced by _backend_kwargs.
|
||||
# This change caused the application to fail, so the version of joblib is restricted here.
|
||||
# This restriction will be removed in the next release. The current qlib version is: 0.9.6
|
||||
- name: Qlib installation test
|
||||
run: |
|
||||
python -m pip install --index-url https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple/ pyqlib==0.9.5.80
|
||||
python -m pip install pyqlib
|
||||
python -m pip install "joblib<=1.4.2"
|
||||
|
||||
- name: Install Lightgbm for MacOS
|
||||
if: ${{ matrix.os == 'macos-13' || matrix.os == 'macos-14' || matrix.os == 'macos-15' }}
|
||||
if: ${{ matrix.os == 'macos-14' || matrix.os == 'macos-15' }}
|
||||
run: |
|
||||
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Microsoft/qlib/main/.github/brew_install.sh)"
|
||||
HOMEBREW_NO_AUTO_UPDATE=1 brew install lightgbm
|
||||
# FIX MacOS error: Segmentation fault
|
||||
# reference: https://github.com/microsoft/LightGBM/issues/4229
|
||||
wget https://raw.githubusercontent.com/Homebrew/homebrew-core/fb8323f2b170bd4ae97e1bac9bf3e2983af3fdb0/Formula/libomp.rb
|
||||
brew unlink libomp
|
||||
brew install libomp.rb
|
||||
brew update
|
||||
brew install libomp || brew reinstall libomp
|
||||
python -m pip install --no-binary=:all: lightgbm
|
||||
|
||||
# When the new version is released it should be changed to:
|
||||
# python -m qlib.cli.data qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn
|
||||
- name: Downloads dependencies data
|
||||
run: |
|
||||
cd ..
|
||||
|
||||
39
.github/workflows/test_qlib_from_source.yml
vendored
39
.github/workflows/test_qlib_from_source.yml
vendored
@@ -14,7 +14,7 @@ jobs:
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-13, macos-14, macos-15]
|
||||
os: [windows-latest, ubuntu-24.04, ubuntu-22.04, macos-14, macos-15]
|
||||
# In github action, using python 3.7, pip install will not match the latest version of the package.
|
||||
# Also, python 3.7 is no longer supported from macos-14, and will be phased out from macos-13 in the near future.
|
||||
# All things considered, we have removed python 3.7.
|
||||
@@ -34,12 +34,12 @@ jobs:
|
||||
python -m pip install --upgrade pip
|
||||
|
||||
- name: Installing pytorch for macos
|
||||
if: ${{ matrix.os == 'macos-13' || matrix.os == 'macos-14' || matrix.os == 'macos-15' }}
|
||||
if: ${{ matrix.os == 'macos-14' || matrix.os == 'macos-15' }}
|
||||
run: |
|
||||
python -m pip install torch torchvision torchaudio
|
||||
|
||||
- name: Installing pytorch for ubuntu
|
||||
if: ${{ matrix.os == 'ubuntu-20.04' || matrix.os == 'ubuntu-22.04' }}
|
||||
if: ${{ matrix.os == 'ubuntu-24.04' || matrix.os == 'ubuntu-22.04' }}
|
||||
run: |
|
||||
python -m pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
|
||||
|
||||
@@ -84,15 +84,11 @@ jobs:
|
||||
python scripts/get_data.py download_data --file_name rl_data.zip --target_dir tests/.data/rl
|
||||
|
||||
- name: Install Lightgbm for MacOS
|
||||
if: ${{ matrix.os == 'macos-13' || matrix.os == 'macos-14' || matrix.os == 'macos-15' }}
|
||||
if: ${{ matrix.os == 'macos-14' || matrix.os == 'macos-15' }}
|
||||
run: |
|
||||
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Microsoft/qlib/main/.github/brew_install.sh)"
|
||||
HOMEBREW_NO_AUTO_UPDATE=1 brew install lightgbm
|
||||
# FIX MacOS error: Segmentation fault
|
||||
# reference: https://github.com/microsoft/LightGBM/issues/4229
|
||||
wget https://raw.githubusercontent.com/Homebrew/homebrew-core/fb8323f2b170bd4ae97e1bac9bf3e2983af3fdb0/Formula/libomp.rb
|
||||
brew unlink libomp
|
||||
brew install libomp.rb
|
||||
brew update
|
||||
brew install libomp || brew reinstall libomp
|
||||
python -m pip install --no-binary=:all: lightgbm
|
||||
|
||||
- name: Check Qlib ipynb with nbconvert
|
||||
run: |
|
||||
@@ -101,9 +97,26 @@ jobs:
|
||||
- name: Test workflow by config (install from source)
|
||||
run: |
|
||||
python -m pip install numba
|
||||
python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
python qlib/cli/run.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
|
||||
- name: Unit tests with Pytest
|
||||
- name: Unit tests with Pytest (MacOS)
|
||||
if: ${{ matrix.os == 'macos-14' || matrix.os == 'macos-15' }}
|
||||
uses: nick-fields/retry@v2
|
||||
with:
|
||||
timeout_minutes: 60
|
||||
max_attempts: 3
|
||||
command: |
|
||||
# Limit the number of threads in various libraries to prevent Segmentation faults caused by OpenMP multithreading conflicts under macOS.
|
||||
export OMP_NUM_THREADS=1 # Limit the number of OpenMP threads
|
||||
export MKL_NUM_THREADS=1 # Limit the number of Intel MKL threads
|
||||
export NUMEXPR_NUM_THREADS=1 # Limit the number of NumExpr threads
|
||||
export OPENBLAS_NUM_THREADS=1 # Limit the number of OpenBLAS threads
|
||||
export VECLIB_MAXIMUM_THREADS=1 # Limit the number of macOS Accelerate/vecLib threads
|
||||
cd tests
|
||||
python -m pytest . -m "not slow" --durations=0
|
||||
|
||||
- name: Unit tests with Pytest (Ubuntu and Windows)
|
||||
if: ${{ matrix.os != 'macos-13' && matrix.os != 'macos-14' && matrix.os != 'macos-15' }}
|
||||
uses: nick-fields/retry@v2
|
||||
with:
|
||||
timeout_minutes: 60
|
||||
|
||||
16
.github/workflows/test_qlib_from_source_slow.yml
vendored
16
.github/workflows/test_qlib_from_source_slow.yml
vendored
@@ -14,7 +14,7 @@ jobs:
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-13, macos-14, macos-15]
|
||||
os: [windows-latest, ubuntu-24.04, ubuntu-22.04, macos-14, macos-15]
|
||||
# In github action, using python 3.7, pip install will not match the latest version of the package.
|
||||
# Also, python 3.7 is no longer supported from macos-14, and will be phased out from macos-13 in the near future.
|
||||
# All things considered, we have removed python 3.7.
|
||||
@@ -37,16 +37,14 @@ jobs:
|
||||
run: |
|
||||
python scripts/get_data.py qlib_data --name qlib_data_simple --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
|
||||
|
||||
# install.sh file contents from: https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh
|
||||
# brew_install.sh file contents from: https://raw.githubusercontent.com/Microsoft/qlib/main/.github/brew_install.sh
|
||||
- name: Install Lightgbm for MacOS
|
||||
if: ${{ matrix.os == 'macos-13' || matrix.os == 'macos-14' || matrix.os == 'macos-15' }}
|
||||
if: ${{ matrix.os == 'macos-14' || matrix.os == 'macos-15' }}
|
||||
run: |
|
||||
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Microsoft/qlib/main/.github/brew_install.sh)"
|
||||
HOMEBREW_NO_AUTO_UPDATE=1 brew install lightgbm
|
||||
# FIX MacOS error: Segmentation fault
|
||||
# reference: https://github.com/microsoft/LightGBM/issues/4229
|
||||
wget https://raw.githubusercontent.com/Homebrew/homebrew-core/fb8323f2b170bd4ae97e1bac9bf3e2983af3fdb0/Formula/libomp.rb
|
||||
brew unlink libomp
|
||||
brew install libomp.rb
|
||||
brew update
|
||||
brew install libomp || brew reinstall libomp
|
||||
python -m pip install --no-binary=:all: lightgbm
|
||||
|
||||
- name: Unit tests with Pytest
|
||||
uses: nick-fields/retry@v2
|
||||
|
||||
16
Makefile
16
Makefile
@@ -12,6 +12,12 @@ PUBLIC_DIR := $(shell [ "$$READTHEDOCS" = "True" ] && echo "$$READTHEDOCS_OUTPUT
|
||||
SO_DIR := qlib/data/_libs
|
||||
SO_FILES := $(wildcard $(SO_DIR)/*.so)
|
||||
|
||||
ifeq ($(OS),Windows_NT)
|
||||
IS_WINDOWS = true
|
||||
else
|
||||
IS_WINDOWS = false
|
||||
endif
|
||||
|
||||
########################################################################################
|
||||
# Development Environment Management
|
||||
########################################################################################
|
||||
@@ -48,6 +54,10 @@ deepclean: clean
|
||||
# What this code does is compile two Cython modules, rolling and expanding, using setuptools and Cython,
|
||||
# and builds them as binary expansion modules that can be imported directly into Python.
|
||||
# Since pyproject.toml can't do that, we compile it here.
|
||||
|
||||
# pywinpty as a dependency of jupyter on windows, if you use pip install pywinpty installation,
|
||||
# will first download the tar.gz file, and then locally compiled and installed,
|
||||
# this will lead to some unnecessary trouble, so we choose to install the compiled whl file, to avoid trouble.
|
||||
prerequisite:
|
||||
@if [ -n "$(SO_FILES)" ]; then \
|
||||
echo "Shared library files exist, skipping build."; \
|
||||
@@ -58,6 +68,10 @@ prerequisite:
|
||||
python -c "from setuptools import setup, Extension; from Cython.Build import cythonize; import numpy; extensions = [Extension('qlib.data._libs.rolling', ['qlib/data/_libs/rolling.pyx'], language='c++', include_dirs=[numpy.get_include()]), Extension('qlib.data._libs.expanding', ['qlib/data/_libs/expanding.pyx'], language='c++', include_dirs=[numpy.get_include()])]; setup(ext_modules=cythonize(extensions, language_level='3'), script_args=['build_ext', '--inplace'])"; \
|
||||
fi
|
||||
|
||||
@if [ "$(IS_WINDOWS)" = "true" ]; then \
|
||||
python -m pip install pywinpty --only-binary=:all:; \
|
||||
fi
|
||||
|
||||
# Install the package in editable mode.
|
||||
dependencies:
|
||||
python -m pip install -e .
|
||||
@@ -87,7 +101,7 @@ analysis:
|
||||
python -m pip install -e .[analysis]
|
||||
|
||||
all:
|
||||
python -m pip install -e .[dev,lint,docs,package,test,analysis,rl]
|
||||
python -m pip install -e .[pywinpty,dev,lint,docs,package,test,analysis,rl]
|
||||
|
||||
install: prerequisite dependencies
|
||||
|
||||
|
||||
64
README.md
64
README.md
@@ -26,10 +26,25 @@ We have prepared several demo videos for you:
|
||||
| Quant Factor Mining from reports | [Link](https://rdagent.azurewebsites.net/report_factor?lang=en) | [Link](https://rdagent.azurewebsites.net/report_factor?lang=zh) |
|
||||
| Quant Model Optimization | [Link](https://rdagent.azurewebsites.net/model_loop?lang=en) | [Link](https://rdagent.azurewebsites.net/model_loop?lang=zh) |
|
||||
|
||||
- 📃**Paper**: [R&D-Agent-Quant: A Multi-Agent Framework for Data-Centric Factors and Model Joint Optimization](https://arxiv.org/abs/2505.15155)
|
||||
- 👾**Code**: https://github.com/microsoft/RD-Agent/
|
||||
```BibTeX
|
||||
@misc{li2025rdagentquant,
|
||||
title={R\&D-Agent-Quant: A Multi-Agent Framework for Data-Centric Factors and Model Joint Optimization},
|
||||
author={Yuante Li and Xu Yang and Xiao Yang and Minrui Xu and Xisen Wang and Weiqing Liu and Jiang Bian},
|
||||
year={2025},
|
||||
eprint={2505.15155},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={cs.AI}
|
||||
}
|
||||
```
|
||||

|
||||
|
||||
***
|
||||
|
||||
| Feature | Status |
|
||||
| -- | ------ |
|
||||
| [R&D-Agent-Quant](https://arxiv.org/abs/2505.15155) Published | Apply R&D-Agent to Qlib for quant trading |
|
||||
| BPQP for End-to-end learning | 📈Coming soon!([Under review](https://github.com/microsoft/qlib/pull/1863)) |
|
||||
| 🔥LLM-driven Auto Quant Factory🔥 | 🚀 Released in [♾️RD-Agent](https://github.com/microsoft/RD-Agent) on Aug 8, 2024 |
|
||||
| KRNN and Sandwich models | :chart_with_upwards_trend: [Released](https://github.com/microsoft/qlib/pull/1414/) on May 26, 2023 |
|
||||
@@ -155,15 +170,15 @@ Here is a quick **[demo](https://terminalizer.com/view/3f24561a4470)** shows how
|
||||
This table demonstrates the supported Python version of `Qlib`:
|
||||
| | install with pip | install from source | plot |
|
||||
| ------------- |:---------------------:|:--------------------:|:------------------:|
|
||||
| Python 3.7 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
|
||||
| Python 3.8 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
|
||||
| Python 3.9 | :x: | :heavy_check_mark: | :x: |
|
||||
| Python 3.9 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
|
||||
| Python 3.10 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
|
||||
| Python 3.11 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
|
||||
| Python 3.12 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
|
||||
|
||||
**Note**:
|
||||
1. **Conda** is suggested for managing your Python environment. In some cases, using Python outside of a `conda` environment may result in missing header files, causing the installation failure of certain packages.
|
||||
1. Please pay attention that installing cython in Python 3.6 will raise some error when installing ``Qlib`` from source. If users use Python 3.6 on their machines, it is recommended to *upgrade* Python to version 3.7 or use `conda`'s Python to install ``Qlib`` from source.
|
||||
1. For Python 3.9, `Qlib` supports running workflows such as training models, doing backtest and plot most of the related figures (those included in [notebook](examples/workflow_by_code.ipynb)). However, plotting for the *model performance* is not supported for now and we will fix this when the dependent packages are upgraded in the future.
|
||||
1. `Qlib`Requires `tables` package, `hdf5` in tables does not support python3.9.
|
||||
2. Please pay attention that installing cython in Python 3.6 will raise some error when installing ``Qlib`` from source. If users use Python 3.6 on their machines, it is recommended to *upgrade* Python to version 3.8 or higher, or use `conda`'s Python to install ``Qlib`` from source.
|
||||
|
||||
### Install with pip
|
||||
Users can easily install ``Qlib`` by pip according to the following command.
|
||||
@@ -181,7 +196,7 @@ Also, users can install the latest dev version ``Qlib`` by the source code accor
|
||||
|
||||
```bash
|
||||
pip install numpy
|
||||
pip install --upgrade cython
|
||||
pip install --upgrade cython
|
||||
```
|
||||
|
||||
* Clone the repository and install ``Qlib`` as follows.
|
||||
@@ -189,17 +204,16 @@ Also, users can install the latest dev version ``Qlib`` by the source code accor
|
||||
git clone https://github.com/microsoft/qlib.git && cd qlib
|
||||
pip install . # `pip install -e .[dev]` is recommended for development. check details in docs/developer/code_standard_and_dev_guide.rst
|
||||
```
|
||||
**Note**: You can install Qlib with `python setup.py install` as well. But it is not the recommended approach. It will skip `pip` and cause obscure problems. For example, **only** the command ``pip install .`` **can** overwrite the stable version installed by ``pip install pyqlib``, while the command ``python setup.py install`` **can't**.
|
||||
|
||||
**Tips**: If you fail to install `Qlib` or run the examples in your environment, comparing your steps and the [CI workflow](.github/workflows/test_qlib_from_source.yml) may help you find the problem.
|
||||
|
||||
**Tips for Mac**: If you are using Mac with M1, you might encounter issues in building the wheel for LightGBM, which is due to missing dependencies from OpenMP. To solve the problem, install openmp first with ``brew install libomp`` and then run ``pip install .`` to build it successfully.
|
||||
|
||||
## Data Preparation
|
||||
❗ Due to more restrict data security policy. The offical dataset is disabled temporarily. You can try [this data source](https://github.com/chenditc/investment_data/releases) contributed by the community.
|
||||
Here is an example to download the data updated on 20240809.
|
||||
❗ Due to more restrict data security policy. The official dataset is disabled temporarily. You can try [this data source](https://github.com/chenditc/investment_data/releases) contributed by the community.
|
||||
Here is an example to download the latest data.
|
||||
```bash
|
||||
wget https://github.com/chenditc/investment_data/releases/download/2024-08-09/qlib_bin.tar.gz
|
||||
wget https://github.com/chenditc/investment_data/releases/latest/download/qlib_bin.tar.gz
|
||||
mkdir -p ~/.qlib/qlib_data/cn_data
|
||||
tar -zxvf qlib_bin.tar.gz -C ~/.qlib/qlib_data/cn_data --strip-components=1
|
||||
rm -f qlib_bin.tar.gz
|
||||
@@ -215,10 +229,10 @@ Load and prepare data by running the following code:
|
||||
### Get with module
|
||||
```bash
|
||||
# get 1d data
|
||||
python -m qlib.run.get_data qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn
|
||||
python -m qlib.cli.data qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn
|
||||
|
||||
# get 1min data
|
||||
python -m qlib.run.get_data qlib_data --target_dir ~/.qlib/qlib_data/cn_data_1min --region cn --interval 1min
|
||||
python -m qlib.cli.data qlib_data --target_dir ~/.qlib/qlib_data/cn_data_1min --region cn --interval 1min
|
||||
|
||||
```
|
||||
|
||||
@@ -265,6 +279,16 @@ We recommend users to prepare their own data if they have a high-quality dataset
|
||||
* *trading_date*: start of trading day
|
||||
* *end_date*: end of trading day(not included)
|
||||
|
||||
### Checking the health of the data
|
||||
* We provide a script to check the health of the data, you can run the following commands to check whether the data is healthy or not.
|
||||
```
|
||||
python scripts/check_data_health.py check_data --qlib_dir ~/.qlib/qlib_data/cn_data
|
||||
```
|
||||
* Of course, you can also add some parameters to adjust the test results, such as this.
|
||||
```
|
||||
python scripts/check_data_health.py check_data --qlib_dir ~/.qlib/qlib_data/cn_data --missing_data_num 30055 --large_step_threshold_volume 94485 --large_step_threshold_price 20
|
||||
```
|
||||
* If you want more information about `check_data_health`, please refer to the [documentation](https://qlib.readthedocs.io/en/latest/component/data.html#checking-the-health-of-the-data).
|
||||
|
||||
<!--
|
||||
- Run the initialization code and get stock data:
|
||||
@@ -305,7 +329,7 @@ We recommend users to prepare their own data if they have a high-quality dataset
|
||||
3. At this point you are in the docker environment and can run the qlib scripts. An example:
|
||||
```bash
|
||||
>>> python scripts/get_data.py qlib_data --name qlib_data_simple --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
|
||||
>>> python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
>>> python qlib/cli/run.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
```
|
||||
4. Exit the container
|
||||
```bash
|
||||
@@ -335,9 +359,9 @@ Qlib provides a tool named `qrun` to run the whole workflow automatically (inclu
|
||||
```
|
||||
If users want to use `qrun` under debug mode, please use the following command:
|
||||
```bash
|
||||
python -m pdb qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
python -m pdb qlib/cli/run.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
```
|
||||
The result of `qrun` is as follows, please refer to [Intraday Trading](https://qlib.readthedocs.io/en/latest/component/backtest.html) for more details about the result.
|
||||
The result of `qrun` is as follows, please refer to [docs](https://qlib.readthedocs.io/en/latest/component/strategy.html#result) for more explanations about the result.
|
||||
|
||||
```bash
|
||||
|
||||
@@ -453,6 +477,14 @@ python run_all_model.py run 10
|
||||
|
||||
It also provides the API to run specific models at once. For more use cases, please refer to the file's [docstrings](examples/run_all_model.py).
|
||||
|
||||
### Break change
|
||||
In `pandas`, `group_key` is one of the parameters of the `groupby` method. From version 1.5 to 2.0 of `pandas`, the default value of `group_key` has been changed from `no default` to `True`, which will cause qlib to report an error during operation. So we set `group_key=False`, but it doesn't guarantee that some programmes will run correctly, including:
|
||||
* qlib\examples\rl_order_execution\scripts\gen_training_orders.py
|
||||
* qlib\examples\benchmarks\TRA\src\dataset.MTSDatasetH.py
|
||||
* qlib\examples\benchmarks\TFT\tft.py
|
||||
|
||||
|
||||
|
||||
## [Adapting to Market Dynamics](examples/benchmarks_dynamic)
|
||||
|
||||
Due to the non-stationary nature of the environment of the financial market, the data distribution may change in different periods, which makes the performance of models build on training data decays in the future test data.
|
||||
@@ -588,7 +620,7 @@ You can find some impefect implementation in Qlib by `rg 'TODO|FIXME' qlib`
|
||||
|
||||
If you would like to become one of Qlib's maintainers to contribute more (e.g. help merge PR, triage issues), please contact us by email([qlib@microsoft.com](mailto:qlib@microsoft.com)). We are glad to help to upgrade your permission.
|
||||
|
||||
## Licence
|
||||
## License
|
||||
Most contributions require you to agree to a
|
||||
Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
|
||||
the right to use your contribution. For details, visit https://cla.opensource.microsoft.com.
|
||||
|
||||
@@ -108,10 +108,10 @@ Automatic update of daily frequency data
|
||||
|
||||
|
||||
|
||||
Converting CSV Format into Qlib Format
|
||||
--------------------------------------
|
||||
Converting CSV and Parquet Format into Qlib Format
|
||||
--------------------------------------------------
|
||||
|
||||
``Qlib`` has provided the script ``scripts/dump_bin.py`` to convert **any** data in CSV format into `.bin` files (``Qlib`` format) as long as they are in the correct format.
|
||||
``Qlib`` has provided the script ``scripts/dump_bin.py`` to convert **any** data in CSV or Parquet format into `.bin` files (``Qlib`` format) as long as they are in the correct format.
|
||||
|
||||
Besides downloading the prepared demo data, users could download demo data directly from the Collector as follows for reference to the CSV format.
|
||||
Here are some example:
|
||||
@@ -126,17 +126,17 @@ for 1min data:
|
||||
|
||||
python scripts/data_collector/yahoo/collector.py download_data --source_dir ~/.qlib/stock_data/source/cn_1min --region CN --start 2021-05-20 --end 2021-05-23 --delay 0.1 --interval 1min --limit_nums 10
|
||||
|
||||
Users can also provide their own data in CSV format. However, the CSV data **must satisfies** following criterions:
|
||||
Users can also provide their own data in CSV or Parquet format. However, the data **must satisfies** following criterions:
|
||||
|
||||
- CSV file is named after a specific stock *or* the CSV file includes a column of the stock name
|
||||
- CSV or Parquet file is named after a specific stock *or* the CSV or Parquet file includes a column of the stock name
|
||||
|
||||
- Name the CSV file after a stock: `SH600000.csv`, `AAPL.csv` (not case sensitive).
|
||||
- Name the CSV or Parquet file after a stock: `SH600000.csv`, `AAPL.csv` or `SH600000.parquet`, `AAPL.parquet` (not case sensitive).
|
||||
|
||||
- CSV file includes a column of the stock name. User **must** specify the column name when dumping the data. Here is an example:
|
||||
- CSV or Parquet file includes a column of the stock name. User **must** specify the column name when dumping the data. Here is an example:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python scripts/dump_bin.py dump_all ... --symbol_field_name symbol
|
||||
python scripts/dump_bin.py dump_all ... --symbol_field_name symbol --file_suffix <.csv or .parquet>
|
||||
|
||||
where the data are in the following format:
|
||||
|
||||
@@ -146,11 +146,11 @@ Users can also provide their own data in CSV format. However, the CSV data **mus
|
||||
| SH600000 | 120 |
|
||||
+-----------+-------+
|
||||
|
||||
- CSV file **must** include a column for the date, and when dumping the data, user must specify the date column name. Here is an example:
|
||||
- CSV or Parquet file **must** include a column for the date, and when dumping the data, user must specify the date column name. Here is an example:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python scripts/dump_bin.py dump_all ... --date_field_name date
|
||||
python scripts/dump_bin.py dump_all ... --date_field_name date --file_suffix <.csv or .parquet>
|
||||
|
||||
where the data are in the following format:
|
||||
|
||||
@@ -163,23 +163,23 @@ Users can also provide their own data in CSV format. However, the CSV data **mus
|
||||
+---------+------------+-------+------+----------+
|
||||
|
||||
|
||||
Supposed that users prepare their CSV format data in the directory ``~/.qlib/csv_data/my_data``, they can run the following command to start the conversion.
|
||||
Supposed that users prepare their CSV or Parquet format data in the directory ``~/.qlib/my_data``, they can run the following command to start the conversion.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python scripts/dump_bin.py dump_all --csv_path ~/.qlib/csv_data/my_data --qlib_dir ~/.qlib/qlib_data/my_data --include_fields open,close,high,low,volume,factor
|
||||
python scripts/dump_bin.py dump_all --data_path ~/.qlib/my_data --qlib_dir ~/.qlib/qlib_data/ --include_fields open,close,high,low,volume,factor --file_suffix <.csv or .parquet>
|
||||
|
||||
For other supported parameters when dumping the data into `.bin` file, users can refer to the information by running the following commands:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python dump_bin.py dump_all --help
|
||||
python scripts/dump_bin.py dump_all --help
|
||||
|
||||
After conversion, users can find their Qlib format data in the directory `~/.qlib/qlib_data/my_data`.
|
||||
After conversion, users can find their Qlib format data in the directory `~/.qlib/qlib_data/`.
|
||||
|
||||
.. note::
|
||||
|
||||
The arguments of `--include_fields` should correspond with the column names of CSV files. The columns names of dataset provided by ``Qlib`` should include open, close, high, low, volume and factor at least.
|
||||
The arguments of `--include_fields` should correspond with the column names of CSV or Parquet files. The columns names of dataset provided by ``Qlib`` should include open, close, high, low, volume and factor at least.
|
||||
|
||||
- `open`
|
||||
The adjusted opening price
|
||||
@@ -195,7 +195,58 @@ After conversion, users can find their Qlib format data in the directory `~/.qli
|
||||
The Restoration factor. Normally, ``factor = adjusted_price / original_price``, `adjusted price` reference: `split adjusted <https://www.investopedia.com/terms/s/splitadjusted.asp>`_
|
||||
|
||||
In the convention of `Qlib` data processing, `open, close, high, low, volume, money and factor` will be set to NaN if the stock is suspended.
|
||||
If you want to use your own alpha-factor which can't be calculate by OCHLV, like PE, EPS and so on, you could add it to the CSV files with OHCLV together and then dump it to the Qlib format data.
|
||||
If you want to use your own alpha-factor which can't be calculate by OCHLV, like PE, EPS and so on, you could add it to the CSV or Parquet files with OHCLV together and then dump it to the Qlib format data.
|
||||
|
||||
Checking the health of the data
|
||||
-------------------------------
|
||||
|
||||
``Qlib`` provides a script to check the health of the data.
|
||||
|
||||
- The main points to check are as follows
|
||||
|
||||
- Check if any data is missing in the DataFrame.
|
||||
|
||||
- Check if there are any large step changes above the threshold in the OHLCV columns.
|
||||
|
||||
- Check if any of the required columns (OLHCV) are missing in the DataFrame.
|
||||
|
||||
- Check if the 'factor' column is missing in the DataFrame.
|
||||
|
||||
- You can run the following commands to check whether the data is healthy or not.
|
||||
|
||||
for daily data:
|
||||
.. code-block:: bash
|
||||
|
||||
python scripts/check_data_health.py check_data --qlib_dir ~/.qlib/qlib_data/cn_data
|
||||
|
||||
for 1min data:
|
||||
.. code-block:: bash
|
||||
|
||||
python scripts/check_data_health.py check_data --qlib_dir ~/.qlib/qlib_data/cn_data_1min --freq 1min
|
||||
|
||||
- Of course, you can also add some parameters to adjust the test results.
|
||||
|
||||
- The available parameters are these.
|
||||
|
||||
- freq: Frequency of data.
|
||||
|
||||
- large_step_threshold_price: Maximum permitted price change
|
||||
|
||||
- large_step_threshold_volume: Maximum permitted volume change.
|
||||
|
||||
- missing_data_num: Maximum value for which data is allowed to be null.
|
||||
|
||||
- You can run the following commands to check whether the data is healthy or not.
|
||||
|
||||
for daily data:
|
||||
.. code-block:: bash
|
||||
|
||||
python scripts/check_data_health.py check_data --qlib_dir ~/.qlib/qlib_data/cn_data --missing_data_num 30055 --large_step_threshold_volume 94485 --large_step_threshold_price 20
|
||||
|
||||
for 1min data:
|
||||
.. code-block:: bash
|
||||
|
||||
python scripts/check_data_health.py check_data --qlib_dir ~/.qlib/qlib_data/cn_data --freq 1min --missing_data_num 35806 --large_step_threshold_volume 3205452000000 --large_step_threshold_price 0.91
|
||||
|
||||
Stock Pool (Market)
|
||||
-------------------
|
||||
|
||||
@@ -25,7 +25,7 @@ The design of the framework is shown in the yellow part in the middle of the fig
|
||||
|
||||
The frequency of the trading algorithm, decision content and execution environment can be customized by users (e.g. intraday trading, daily-frequency trading, weekly-frequency trading), and the execution environment can be nested with finer-grained trading algorithm and execution environment inside (i.e. sub-workflow in the figure, e.g. daily-frequency orders can be turned into finer-grained decisions by splitting orders within the day). The flexibility of the nested decision execution framework makes it easy for users to explore the effects of combining different levels of trading strategies and break down the optimization barriers between different levels of the trading algorithm.
|
||||
|
||||
The optimization for the nested decision execution framework can be implemented with the support of `QlibRL <https://qlib.readthedocs.io/en/latest/component/rl.html>`_. To know more about how to use the QlibRL, go to API Reference: `RL API <../reference/api.html#rl>`_.
|
||||
The optimization for the nested decision execution framework can be implemented with the support of `QlibRL <./rl/overall.html>`_. To know more about how to use the QlibRL, go to API Reference: `RL API <../reference/api.html#rl>`_.
|
||||
|
||||
Example
|
||||
=======
|
||||
|
||||
@@ -55,13 +55,16 @@ Below is a typical config file of ``qrun``.
|
||||
n_drop: 5
|
||||
signal: <PRED>
|
||||
backtest:
|
||||
limit_threshold: 0.095
|
||||
start_time: 2017-01-01
|
||||
end_time: 2020-08-01
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
exchange_kwargs:
|
||||
limit_threshold: 0.095
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: LGBModel
|
||||
@@ -107,7 +110,7 @@ If users want to use ``qrun`` under debug mode, please use the following command
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python -m pdb qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
python -m pdb qlib/cli/run.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
|
||||
.. note::
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@ How to use qlib images
|
||||
.. code-block:: bash
|
||||
|
||||
>>> python scripts/get_data.py qlib_data --name qlib_data_simple --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
|
||||
>>> python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
>>> python qlib/cli/run.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
|
||||
3. Exit the container
|
||||
|
||||
|
||||
@@ -599,7 +599,7 @@ class TemporalFusionTransformer:
|
||||
print("Getting valid sampling locations.")
|
||||
valid_sampling_locations = []
|
||||
split_data_map = {}
|
||||
for identifier, df in data.groupby(id_col):
|
||||
for identifier, df in data.groupby(id_col, group_key=False):
|
||||
print("Getting locations for {}".format(identifier))
|
||||
num_entries = len(df)
|
||||
if num_entries >= self.time_steps:
|
||||
@@ -678,7 +678,7 @@ class TemporalFusionTransformer:
|
||||
input_cols = [tup[0] for tup in self.column_definition if tup[2] not in {InputTypes.ID, InputTypes.TIME}]
|
||||
|
||||
data_map = {}
|
||||
for _, sliced in data.groupby(id_col):
|
||||
for _, sliced in data.groupby(id_col, group_keys=False):
|
||||
col_mappings = {"identifier": [id_col], "time": [time_col], "outputs": [target_col], "inputs": input_cols}
|
||||
|
||||
for k in col_mappings:
|
||||
|
||||
@@ -78,13 +78,15 @@ DATASET_SETTING = {
|
||||
|
||||
|
||||
def get_shifted_label(data_df, shifts=5, col_shift="LABEL0"):
|
||||
return data_df[[col_shift]].groupby("instrument").apply(lambda df: df.shift(shifts))
|
||||
return data_df[[col_shift]].groupby("instrument", group_keys=False).apply(lambda df: df.shift(shifts))
|
||||
|
||||
|
||||
def fill_test_na(test_df):
|
||||
test_df_res = test_df.copy()
|
||||
feature_cols = ~test_df_res.columns.str.contains("label", case=False)
|
||||
test_feature_fna = test_df_res.loc[:, feature_cols].groupby("datetime").apply(lambda df: df.fillna(df.mean()))
|
||||
test_feature_fna = (
|
||||
test_df_res.loc[:, feature_cols].groupby("datetime", group_keys=False).apply(lambda df: df.fillna(df.mean()))
|
||||
)
|
||||
test_df_res.loc[:, feature_cols] = test_feature_fna
|
||||
return test_df_res
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ def _create_ts_slices(index, seq_len):
|
||||
assert index.is_lexsorted(), "index should be sorted"
|
||||
|
||||
# number of dates for each code
|
||||
sample_count_by_codes = pd.Series(0, index=index).groupby(level=0).size().values
|
||||
sample_count_by_codes = pd.Series(0, index=index).groupby(level=0, group_keys=False).size().values
|
||||
|
||||
# start_index for each code
|
||||
start_index_of_codes = np.roll(np.cumsum(sample_count_by_codes), 1)
|
||||
|
||||
@@ -110,7 +110,6 @@ task:
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
seq_len: 60
|
||||
horizon: 2
|
||||
input_size:
|
||||
num_states: *num_states
|
||||
batch_size: 1024
|
||||
|
||||
@@ -104,7 +104,6 @@ task:
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
seq_len: 60
|
||||
horizon: 2
|
||||
input_size:
|
||||
num_states: *num_states
|
||||
batch_size: 1024
|
||||
|
||||
@@ -104,7 +104,6 @@ task:
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
seq_len: 60
|
||||
horizon: 2
|
||||
input_size: 6
|
||||
num_states: *num_states
|
||||
batch_size: 1024
|
||||
|
||||
@@ -7,7 +7,7 @@ The table below shows the performances of different solutions on different forec
|
||||
## Alpha158 Dataset
|
||||
Here is the [crowd sourced version of qlib data](data_collector/crowd_source/README.md): https://github.com/chenditc/investment_data/releases
|
||||
```bash
|
||||
wget https://github.com/chenditc/investment_data/releases/download/20220720/qlib_bin.tar.gz
|
||||
wget https://github.com/chenditc/investment_data/releases/latest/download/qlib_bin.tar.gz
|
||||
mkdir -p ~/.qlib/qlib_data/cn_data
|
||||
tar -zxvf qlib_bin.tar.gz -C ~/.qlib/qlib_data/cn_data --strip-components=2
|
||||
rm -f qlib_bin.tar.gz
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
"""
|
||||
The motivation of this demo
|
||||
- To show the data modules of Qlib is Serializable, users can dump processed data to disk to avoid duplicated data preprocessing
|
||||
The motivation of this demo
|
||||
- To show the data modules of Qlib is Serializable, users can dump processed data to disk to avoid duplicated data preprocessing
|
||||
"""
|
||||
|
||||
from copy import deepcopy
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
"""
|
||||
The motivation of this demo
|
||||
- To show the data modules of Qlib is Serializable, users can dump processed data to disk to avoid duplicated data preprocessing
|
||||
The motivation of this demo
|
||||
- To show the data modules of Qlib is Serializable, users can dump processed data to disk to avoid duplicated data preprocessing
|
||||
"""
|
||||
|
||||
from copy import deepcopy
|
||||
|
||||
@@ -25,7 +25,7 @@ class DayLast(ElemOperator):
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
_calendar = get_calendar_day(freq=freq)
|
||||
series = self.feature.load(instrument, start_index, end_index, freq)
|
||||
return series.groupby(_calendar[series.index]).transform("last")
|
||||
return series.groupby(_calendar[series.index], group_keys=False).transform("last")
|
||||
|
||||
|
||||
class FFillNan(ElemOperator):
|
||||
@@ -44,7 +44,7 @@ class FFillNan(ElemOperator):
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
series = self.feature.load(instrument, start_index, end_index, freq)
|
||||
return series.fillna(method="ffill")
|
||||
return series.ffill()
|
||||
|
||||
|
||||
class BFillNan(ElemOperator):
|
||||
@@ -63,7 +63,7 @@ class BFillNan(ElemOperator):
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
series = self.feature.load(instrument, start_index, end_index, freq)
|
||||
return series.fillna(method="bfill")
|
||||
return series.bfill()
|
||||
|
||||
|
||||
class Date(ElemOperator):
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
"""
|
||||
NOTE:
|
||||
- This scripts is a demo to import example data import Qlib
|
||||
- !!!!!!!!!!!!!!!TODO!!!!!!!!!!!!!!!!!!!:
|
||||
- Its structure is not well designed and very ugly, your contribution is welcome to make importing dataset easier
|
||||
NOTE:
|
||||
- This scripts is a demo to import example data import Qlib
|
||||
- !!!!!!!!!!!!!!!TODO!!!!!!!!!!!!!!!!!!!:
|
||||
- Its structure is not well designed and very ugly, your contribution is welcome to make importing dataset easier
|
||||
"""
|
||||
from datetime import date, datetime as dt
|
||||
import os
|
||||
|
||||
@@ -7,7 +7,7 @@ This folder comprises an example of Reinforcement Learning (RL) workflows for or
|
||||
### Get Data
|
||||
|
||||
```
|
||||
python -m qlib.run.get_data qlib_data qlib_data --target_dir ./data/bin --region hs300 --interval 5min
|
||||
python -m qlib.cli.data qlib_data --target_dir ./data/bin --region hs300 --interval 5min
|
||||
```
|
||||
|
||||
### Generate Pickle-Style Data
|
||||
|
||||
@@ -19,9 +19,9 @@ def generate_order(stock: str, start_idx: int, end_idx: int) -> bool:
|
||||
|
||||
df["date"] = df["datetime"].dt.date.astype("datetime64")
|
||||
df = df.set_index(["instrument", "datetime", "date"])
|
||||
df = df.groupby("date").take(range(start_idx, end_idx)).droplevel(level=0)
|
||||
df = df.groupby("date", group_keys=False).take(range(start_idx, end_idx)).droplevel(level=0)
|
||||
|
||||
order_all = pd.DataFrame(df.groupby(level=(2, 0)).mean().dropna())
|
||||
order_all = pd.DataFrame(df.groupby(level=(2, 0), group_keys=False).mean().dropna())
|
||||
order_all["amount"] = np.random.lognormal(-3.28, 1.14) * order_all["$volume0"]
|
||||
order_all = order_all[order_all["amount"] > 0.0]
|
||||
order_all["order_type"] = 0
|
||||
|
||||
@@ -171,7 +171,9 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import plotly.graph_objects as go\n",
|
||||
"import plotly.io as pio\n",
|
||||
"\n",
|
||||
"pio.renderers.default = \"notebook\"\n",
|
||||
"fig = go.Figure(\n",
|
||||
" data=[\n",
|
||||
" go.Candlestick(\n",
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
"""
|
||||
Qlib provides two kinds of interfaces.
|
||||
Qlib provides two kinds of interfaces.
|
||||
(1) Users could define the Quant research workflow by a simple configuration.
|
||||
(2) Qlib is designed in a modularized way and supports creating research workflow by code just like building blocks.
|
||||
|
||||
|
||||
@@ -26,7 +26,10 @@ readme = {file = "README.md", content-type = "text/markdown"}
|
||||
dependencies = [
|
||||
"pyyaml",
|
||||
"numpy",
|
||||
"pandas",
|
||||
"pandas>=0.24",
|
||||
# I encoutered an Error that the set_uri does not work when downloading artifacts in mlflow 3.1.1;
|
||||
# But earlier versions of mlflow does not have this problem.
|
||||
# But when I switch to 2.*.* version, another error occurs, which is even more strange...
|
||||
"mlflow",
|
||||
"filelock>=3.16.0",
|
||||
"redis",
|
||||
@@ -44,6 +47,8 @@ dependencies = [
|
||||
"matplotlib",
|
||||
"jupyter",
|
||||
"nbconvert",
|
||||
"pyarrow",
|
||||
"pydantic-settings",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
@@ -66,10 +71,17 @@ lint = [
|
||||
"flake8",
|
||||
"nbqa",
|
||||
]
|
||||
# snowballstemmer, a dependency of sphinx, was released on 2025-05-08 with version 3.0.0,
|
||||
# which causes errors in the build process. So we've limited the version for now.
|
||||
docs = [
|
||||
# After upgrading scipy to version 1.16.0,
|
||||
# we encountered ImportError: cannot import name '_lazywhere', in the build documentation,
|
||||
# so we restricted the version of scipy to: 1.15.3
|
||||
"scipy<=1.15.3",
|
||||
"sphinx",
|
||||
"sphinx_rtd_theme",
|
||||
"readthedocs_sphinx_ext",
|
||||
"snowballstemmer<3.0",
|
||||
]
|
||||
package = [
|
||||
"twine",
|
||||
@@ -82,12 +94,17 @@ test = [
|
||||
]
|
||||
analysis = [
|
||||
"plotly",
|
||||
"statsmodels",
|
||||
]
|
||||
|
||||
# In the process of releasing a new version, when checking the manylinux package with twine, an error is reported:
|
||||
# InvalidDistribution: Invalid distribution metadata: unrecognized or malformed field 'license-file'
|
||||
# To solve this problem, we added license-files here. Refs: https://github.com/pypa/twine/issues/1216
|
||||
[tool.setuptools]
|
||||
packages = [
|
||||
"qlib",
|
||||
]
|
||||
license-files = []
|
||||
|
||||
[project.scripts]
|
||||
qrun = "qlib.workflow.cli:run"
|
||||
qrun = "qlib.cli.run:run"
|
||||
|
||||
@@ -2,9 +2,10 @@
|
||||
# Licensed under the MIT License.
|
||||
from pathlib import Path
|
||||
|
||||
__version__ = "0.9.6"
|
||||
__version__ = "0.9.7"
|
||||
__version__bak = __version__ # This version is backup for QlibConfig.reset_qlib_version
|
||||
import os
|
||||
import re
|
||||
from typing import Union
|
||||
from ruamel.yaml import YAML
|
||||
import logging
|
||||
@@ -80,34 +81,41 @@ def _mount_nfs_uri(provider_uri, mount_path, auto_mount: bool = False):
|
||||
LOG = get_module_logger("mount nfs", level=logging.INFO)
|
||||
if mount_path is None:
|
||||
raise ValueError(f"Invalid mount path: {mount_path}!")
|
||||
if not re.match(r"^[a-zA-Z0-9.:/\-_]+$", provider_uri):
|
||||
raise ValueError(f"Invalid provider_uri format: {provider_uri}")
|
||||
# FIXME: the C["provider_uri"] is modified in this function
|
||||
# If it is not modified, we can pass only provider_uri or mount_path instead of C
|
||||
mount_command = "sudo mount.nfs %s %s" % (provider_uri, mount_path)
|
||||
mount_command = ["sudo", "mount.nfs", provider_uri, mount_path]
|
||||
# If the provider uri looks like this 172.23.233.89//data/csdesign'
|
||||
# It will be a nfs path. The client provider will be used
|
||||
if not auto_mount: # pylint: disable=R1702
|
||||
if not Path(mount_path).exists():
|
||||
raise FileNotFoundError(
|
||||
f"Invalid mount path: {mount_path}! Please mount manually: {mount_command} or Set init parameter `auto_mount=True`"
|
||||
f"Invalid mount path: {mount_path}! Please mount manually: {' '.join(mount_command)} or Set init parameter `auto_mount=True`"
|
||||
)
|
||||
else:
|
||||
# Judging system type
|
||||
sys_type = platform.system()
|
||||
if "windows" in sys_type.lower():
|
||||
# system: window
|
||||
exec_result = os.popen(f"mount -o anon {provider_uri} {mount_path}")
|
||||
result = exec_result.read()
|
||||
if "85" in result:
|
||||
LOG.warning(f"{provider_uri} on Windows:{mount_path} is already mounted")
|
||||
elif "53" in result:
|
||||
raise OSError("not find network path")
|
||||
elif "error" in result or "错误" in result:
|
||||
raise OSError("Invalid mount path")
|
||||
elif provider_uri in result:
|
||||
LOG.info("window success mount..")
|
||||
else:
|
||||
raise OSError(f"unknown error: {result}")
|
||||
|
||||
try:
|
||||
subprocess.run(
|
||||
["mount", "-o", "anon", provider_uri, mount_path],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=True,
|
||||
)
|
||||
LOG.info("Mount finished.")
|
||||
except subprocess.CalledProcessError as e:
|
||||
error_output = (e.stdout or "") + (e.stderr or "")
|
||||
if e.returncode == 85:
|
||||
LOG.warning(f"{provider_uri} already mounted at {mount_path}")
|
||||
elif e.returncode == 53:
|
||||
raise OSError("Network path not found") from e
|
||||
elif "error" in error_output.lower() or "错误" in error_output:
|
||||
raise OSError("Invalid mount path") from e
|
||||
else:
|
||||
raise OSError(f"Unknown mount error: {error_output.strip()}") from e
|
||||
else:
|
||||
# system: linux/Unix/Mac
|
||||
# check mount
|
||||
@@ -119,12 +127,13 @@ def _mount_nfs_uri(provider_uri, mount_path, auto_mount: bool = False):
|
||||
_is_mount = False
|
||||
while _check_level_num:
|
||||
with subprocess.Popen(
|
||||
'mount | grep "{}"'.format(_remote_uri),
|
||||
shell=True,
|
||||
["mount"],
|
||||
text=True,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
) as shell_r:
|
||||
_command_log = shell_r.stdout.readlines()
|
||||
_command_log = [line for line in _command_log if _remote_uri in line]
|
||||
if len(_command_log) > 0:
|
||||
for _c in _command_log:
|
||||
_temp_mount = _c.decode("utf-8").split(" ")[2]
|
||||
@@ -152,16 +161,16 @@ def _mount_nfs_uri(provider_uri, mount_path, auto_mount: bool = False):
|
||||
if not command_res:
|
||||
raise OSError("nfs-common is not found, please install it by execute: sudo apt install nfs-common")
|
||||
# manually mount
|
||||
command_status = os.system(mount_command)
|
||||
if command_status == 256:
|
||||
raise OSError(
|
||||
f"mount {provider_uri} on {mount_path} error! Needs SUDO! Please mount manually: {mount_command}"
|
||||
)
|
||||
elif command_status == 32512:
|
||||
# LOG.error("Command error")
|
||||
raise OSError(f"mount {provider_uri} on {mount_path} error! Command error")
|
||||
elif command_status == 0:
|
||||
LOG.info("Mount finished")
|
||||
try:
|
||||
subprocess.run(mount_command, check=True, capture_output=True, text=True)
|
||||
LOG.info("Mount finished.")
|
||||
except subprocess.CalledProcessError as e:
|
||||
if e.returncode == 256:
|
||||
raise OSError("Mount failed: requires sudo or permission denied") from e
|
||||
elif e.returncode == 32512:
|
||||
raise OSError(f"mount {provider_uri} on {mount_path} error! Command error") from e
|
||||
else:
|
||||
raise OSError(f"Mount failed: {e.stderr}") from e
|
||||
else:
|
||||
LOG.warning(f"{_remote_uri} on {_mount_path} is already mounted")
|
||||
|
||||
|
||||
@@ -897,6 +897,7 @@ class Exchange:
|
||||
# if we don't know current position, we choose to sell all
|
||||
# Otherwise, we clip the amount based on current position
|
||||
if position is not None:
|
||||
# TODO: make the trading shortable
|
||||
current_amount = (
|
||||
position.get_stock_amount(order.stock_id) if position.check_stock(order.stock_id) else 0
|
||||
)
|
||||
|
||||
@@ -104,7 +104,7 @@ class PandasQuote(BaseQuote):
|
||||
def __init__(self, quote_df: pd.DataFrame, freq: str) -> None:
|
||||
super().__init__(quote_df=quote_df, freq=freq)
|
||||
quote_dict = {}
|
||||
for stock_id, stock_val in quote_df.groupby(level="instrument"):
|
||||
for stock_id, stock_val in quote_df.groupby(level="instrument", group_keys=False):
|
||||
quote_dict[stock_id] = stock_val.droplevel(level="instrument")
|
||||
self.data = quote_dict
|
||||
|
||||
@@ -137,7 +137,7 @@ class NumpyQuote(BaseQuote):
|
||||
"""
|
||||
super().__init__(quote_df=quote_df, freq=freq)
|
||||
quote_dict = {}
|
||||
for stock_id, stock_val in quote_df.groupby(level="instrument"):
|
||||
for stock_id, stock_val in quote_df.groupby(level="instrument", group_keys=False):
|
||||
quote_dict[stock_id] = idd.MultiData(stock_val.droplevel(level="instrument"))
|
||||
quote_dict[stock_id].sort_index() # To support more flexible slicing, we must sort data first
|
||||
self.data = quote_dict
|
||||
|
||||
@@ -311,7 +311,7 @@ class Position(BasePosition):
|
||||
freq=freq,
|
||||
disk_cache=True,
|
||||
).dropna()
|
||||
price_dict = price_df.groupby(["instrument"]).tail(1).reset_index(level=1, drop=True)["$close"].to_dict()
|
||||
price_dict = price_df.groupby(["instrument"], group_keys=False).tail(1)["$close"].to_dict()
|
||||
|
||||
if len(price_dict) < len(stock_list):
|
||||
lack_stock = set(stock_list) - set(price_dict)
|
||||
|
||||
@@ -114,7 +114,11 @@ class PortfolioMetrics:
|
||||
_temp_result, _ = get_higher_eq_freq_feature(_codes, fields, start_time, end_time, freq=freq)
|
||||
if len(_temp_result) == 0:
|
||||
raise ValueError(f"The benchmark {_codes} does not exist. Please provide the right benchmark")
|
||||
return _temp_result.groupby(level="datetime")[_temp_result.columns.tolist()[0]].mean().fillna(0)
|
||||
return (
|
||||
_temp_result.groupby(level="datetime", group_keys=False)[_temp_result.columns.tolist()[0]]
|
||||
.mean()
|
||||
.fillna(0)
|
||||
)
|
||||
|
||||
def _sample_benchmark(
|
||||
self,
|
||||
@@ -427,6 +431,10 @@ class Indicator:
|
||||
# NOTE ~(price_s < 1e-08) is different from price_s >= 1e-8
|
||||
# ~(np.nan < 1e-8) -> ~(False) -> True
|
||||
|
||||
# if price_s is empty
|
||||
if price_s.empty:
|
||||
return None, None
|
||||
|
||||
assert isinstance(price_s, idd.SingleData)
|
||||
if agg == "vwap":
|
||||
volume_s = trade_exchange.get_volume(inst, trade_start_time, trade_end_time, method=None)
|
||||
|
||||
@@ -87,7 +87,7 @@ def workflow(config_path, experiment_name="workflow", uri_folder="mlruns"):
|
||||
"""
|
||||
This is a Qlib CLI entrance.
|
||||
User can run the whole Quant research workflow defined by a configure file
|
||||
- the code is located here ``qlib/workflow/cli.py`
|
||||
- the code is located here ``qlib/cli/run.py`
|
||||
|
||||
User can specify a base_config file in your workflow.yml file by adding "BASE_CONFIG_PATH".
|
||||
Qlib will load the configuration in BASE_CONFIG_PATH first, and the user only needs to update the custom fields
|
||||
@@ -27,6 +27,38 @@ from qlib.constant import REG_CN, REG_US, REG_TW
|
||||
if TYPE_CHECKING:
|
||||
from qlib.utils.time import Freq
|
||||
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class MLflowSettings(BaseSettings):
|
||||
uri: str = "file:" + str(Path(os.getcwd()).resolve() / "mlruns")
|
||||
default_exp_name: str = "Experiment"
|
||||
|
||||
|
||||
class QSettings(BaseSettings):
|
||||
"""
|
||||
Qlib's settings.
|
||||
It tries to provide a default settings for most of Qlib's components.
|
||||
But it would be a long journey to provide a comprehensive settings for all of Qlib's components.
|
||||
|
||||
Here is some design guidelines:
|
||||
- The priority of settings is
|
||||
- Actively passed-in settings, like `qlib.init(provider_uri=...)`
|
||||
- The default settings
|
||||
- QSettings tries to provide default settings for most of Qlib's components.
|
||||
"""
|
||||
|
||||
mlflow: MLflowSettings = MLflowSettings()
|
||||
provider_uri: str = "~/.qlib/qlib_data/cn_data"
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
env_prefix="QLIB_",
|
||||
env_nested_delimiter="_",
|
||||
)
|
||||
|
||||
|
||||
QSETTINGS = QSettings()
|
||||
|
||||
|
||||
class Config:
|
||||
def __init__(self, default_conf):
|
||||
@@ -187,8 +219,8 @@ _default_config = {
|
||||
"class": "MLflowExpManager",
|
||||
"module_path": "qlib.workflow.expm",
|
||||
"kwargs": {
|
||||
"uri": "file:" + str(Path(os.getcwd()).resolve() / "mlruns"),
|
||||
"default_exp_name": "Experiment",
|
||||
"uri": QSETTINGS.mlflow.uri,
|
||||
"default_exp_name": QSETTINGS.mlflow.default_exp_name,
|
||||
},
|
||||
},
|
||||
"pit_record_type": {
|
||||
@@ -230,7 +262,7 @@ MODE_CONF = {
|
||||
},
|
||||
"client": {
|
||||
# config it in user's own code
|
||||
"provider_uri": "~/.qlib/qlib_data/cn_data",
|
||||
"provider_uri": QSETTINGS.provider_uri,
|
||||
# cache
|
||||
# Using parameter 'remote' to announce the client is using server_cache, and the writing access will be disabled.
|
||||
# Disable cache by default. Avoid introduce advanced features for beginners
|
||||
|
||||
@@ -6,6 +6,8 @@ import torch
|
||||
import warnings
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from qlib.utils.data import guess_horizon
|
||||
from qlib.utils import init_instance_by_config
|
||||
|
||||
from qlib.data.dataset import DatasetH
|
||||
|
||||
@@ -32,7 +34,7 @@ def _create_ts_slices(index, seq_len):
|
||||
assert index.is_monotonic_increasing, "index should be sorted"
|
||||
|
||||
# number of dates for each instrument
|
||||
sample_count_by_insts = index.to_series().groupby(level=0).size().values
|
||||
sample_count_by_insts = index.to_series().groupby(level=0, group_keys=False).size().values
|
||||
|
||||
# start index for each instrument
|
||||
start_index_of_insts = np.roll(np.cumsum(sample_count_by_insts), 1)
|
||||
@@ -130,6 +132,14 @@ class MTSDatasetH(DatasetH):
|
||||
input_size=None,
|
||||
**kwargs,
|
||||
):
|
||||
if horizon == 0:
|
||||
# Try to guess horizon
|
||||
if isinstance(handler, (dict, str)):
|
||||
handler = init_instance_by_config(handler)
|
||||
assert "label" in getattr(handler.data_loader, "fields", None)
|
||||
label = handler.data_loader.fields["label"][0][0]
|
||||
horizon = guess_horizon([label])
|
||||
|
||||
assert num_states == 0 or horizon > 0, "please specify `horizon` to avoid data leakage"
|
||||
assert memory_mode in ["sample", "daily"], "unsupported memory mode"
|
||||
assert memory_mode == "sample" or batch_size < 0, "daily memory requires daily sampling (`batch_size < 0`)"
|
||||
|
||||
@@ -55,14 +55,18 @@ class ConfigSectionProcessor(Processor):
|
||||
|
||||
# Label
|
||||
cols = df_focus.columns[df_focus.columns.str.contains("^LABEL")]
|
||||
df_focus[cols] = df_focus[cols].groupby(level="datetime").apply(_label_norm)
|
||||
df_focus[cols] = df_focus[cols].groupby(level="datetime", group_keys=False).apply(_label_norm)
|
||||
|
||||
# Features
|
||||
cols = df_focus.columns[df_focus.columns.str.contains("^KLEN|^KLOW|^KUP")]
|
||||
df_focus[cols] = df_focus[cols].apply(lambda x: x**0.25).groupby(level="datetime").apply(_feature_norm)
|
||||
df_focus[cols] = (
|
||||
df_focus[cols].apply(lambda x: x**0.25).groupby(level="datetime", group_keys=False).apply(_feature_norm)
|
||||
)
|
||||
|
||||
cols = df_focus.columns[df_focus.columns.str.contains("^KLOW2|^KUP2")]
|
||||
df_focus[cols] = df_focus[cols].apply(lambda x: x**0.5).groupby(level="datetime").apply(_feature_norm)
|
||||
df_focus[cols] = (
|
||||
df_focus[cols].apply(lambda x: x**0.5).groupby(level="datetime", group_keys=False).apply(_feature_norm)
|
||||
)
|
||||
|
||||
_cols = [
|
||||
"KMID",
|
||||
@@ -88,25 +92,35 @@ class ConfigSectionProcessor(Processor):
|
||||
]
|
||||
pat = "|".join(["^" + x for x in _cols])
|
||||
cols = df_focus.columns[df_focus.columns.str.contains(pat) & (~df_focus.columns.isin(["HIGH0", "LOW0"]))]
|
||||
df_focus[cols] = df_focus[cols].groupby(level="datetime").apply(_feature_norm)
|
||||
df_focus[cols] = df_focus[cols].groupby(level="datetime", group_keys=False).apply(_feature_norm)
|
||||
|
||||
cols = df_focus.columns[df_focus.columns.str.contains("^STD|^VOLUME|^VMA|^VSTD")]
|
||||
df_focus[cols] = df_focus[cols].apply(np.log).groupby(level="datetime").apply(_feature_norm)
|
||||
df_focus[cols] = df_focus[cols].apply(np.log).groupby(level="datetime", group_keys=False).apply(_feature_norm)
|
||||
|
||||
cols = df_focus.columns[df_focus.columns.str.contains("^RSQR")]
|
||||
df_focus[cols] = df_focus[cols].fillna(0).groupby(level="datetime").apply(_feature_norm)
|
||||
df_focus[cols] = df_focus[cols].fillna(0).groupby(level="datetime", group_keys=False).apply(_feature_norm)
|
||||
|
||||
cols = df_focus.columns[df_focus.columns.str.contains("^MAX|^HIGH0")]
|
||||
df_focus[cols] = df_focus[cols].apply(lambda x: (x - 1) ** 0.5).groupby(level="datetime").apply(_feature_norm)
|
||||
df_focus[cols] = (
|
||||
df_focus[cols]
|
||||
.apply(lambda x: (x - 1) ** 0.5)
|
||||
.groupby(level="datetime", group_keys=False)
|
||||
.apply(_feature_norm)
|
||||
)
|
||||
|
||||
cols = df_focus.columns[df_focus.columns.str.contains("^MIN|^LOW0")]
|
||||
df_focus[cols] = df_focus[cols].apply(lambda x: (1 - x) ** 0.5).groupby(level="datetime").apply(_feature_norm)
|
||||
df_focus[cols] = (
|
||||
df_focus[cols]
|
||||
.apply(lambda x: (1 - x) ** 0.5)
|
||||
.groupby(level="datetime", group_keys=False)
|
||||
.apply(_feature_norm)
|
||||
)
|
||||
|
||||
cols = df_focus.columns[df_focus.columns.str.contains("^CORR|^CORD")]
|
||||
df_focus[cols] = df_focus[cols].apply(np.exp).groupby(level="datetime").apply(_feature_norm)
|
||||
df_focus[cols] = df_focus[cols].apply(np.exp).groupby(level="datetime", group_keys=False).apply(_feature_norm)
|
||||
|
||||
cols = df_focus.columns[df_focus.columns.str.contains("^WVMA")]
|
||||
df_focus[cols] = df_focus[cols].apply(np.log1p).groupby(level="datetime").apply(_feature_norm)
|
||||
df_focus[cols] = df_focus[cols].apply(np.log1p).groupby(level="datetime", group_keys=False).apply(_feature_norm)
|
||||
|
||||
df[selected_cols] = df_focus.values
|
||||
|
||||
|
||||
@@ -39,7 +39,7 @@ def calc_long_short_prec(
|
||||
long precision and short precision in time level
|
||||
"""
|
||||
if is_alpha:
|
||||
label = label - label.mean(level=date_col)
|
||||
label = label - label.groupby(level=date_col, group_keys=False).mean()
|
||||
if int(1 / quantile) >= len(label.index.get_level_values(1).unique()):
|
||||
raise ValueError("Need more instruments to calculate precision")
|
||||
|
||||
@@ -47,23 +47,25 @@ def calc_long_short_prec(
|
||||
if dropna:
|
||||
df.dropna(inplace=True)
|
||||
|
||||
group = df.groupby(level=date_col)
|
||||
group = df.groupby(level=date_col, group_keys=False)
|
||||
|
||||
def N(x):
|
||||
return int(len(x) * quantile)
|
||||
|
||||
# find the top/low quantile of prediction and treat them as long and short target
|
||||
long = group.apply(lambda x: x.nlargest(N(x), columns="pred").label).reset_index(level=0, drop=True)
|
||||
short = group.apply(lambda x: x.nsmallest(N(x), columns="pred").label).reset_index(level=0, drop=True)
|
||||
long = group.apply(lambda x: x.nlargest(N(x), columns="pred").label)
|
||||
short = group.apply(lambda x: x.nsmallest(N(x), columns="pred").label)
|
||||
|
||||
groupll = long.groupby(date_col)
|
||||
groupll = long.groupby(date_col, group_keys=False)
|
||||
l_dom = groupll.apply(lambda x: x > 0)
|
||||
l_c = groupll.count()
|
||||
|
||||
groups = short.groupby(date_col)
|
||||
groups = short.groupby(date_col, group_keys=False)
|
||||
s_dom = groups.apply(lambda x: x < 0)
|
||||
s_c = groups.count()
|
||||
return (l_dom.groupby(date_col).sum() / l_c), (s_dom.groupby(date_col).sum() / s_c)
|
||||
return (l_dom.groupby(date_col, group_keys=False).sum() / l_c), (
|
||||
s_dom.groupby(date_col, group_keys=False).sum() / s_c
|
||||
)
|
||||
|
||||
|
||||
def calc_long_short_return(
|
||||
@@ -100,7 +102,7 @@ def calc_long_short_return(
|
||||
df = pd.DataFrame({"pred": pred, "label": label})
|
||||
if dropna:
|
||||
df.dropna(inplace=True)
|
||||
group = df.groupby(level=date_col)
|
||||
group = df.groupby(level=date_col, group_keys=False)
|
||||
|
||||
def N(x):
|
||||
return int(len(x) * quantile)
|
||||
@@ -173,8 +175,8 @@ def calc_ic(pred: pd.Series, label: pd.Series, date_col="datetime", dropna=False
|
||||
ic and rank ic
|
||||
"""
|
||||
df = pd.DataFrame({"pred": pred, "label": label})
|
||||
ic = df.groupby(date_col).apply(lambda df: df["pred"].corr(df["label"]))
|
||||
ric = df.groupby(date_col).apply(lambda df: df["pred"].corr(df["label"], method="spearman"))
|
||||
ic = df.groupby(date_col, group_keys=False).apply(lambda df: df["pred"].corr(df["label"]))
|
||||
ric = df.groupby(date_col, group_keys=False).apply(lambda df: df["pred"].corr(df["label"], method="spearman"))
|
||||
if dropna:
|
||||
return ic.dropna(), ric.dropna()
|
||||
else:
|
||||
|
||||
@@ -7,7 +7,7 @@ from __future__ import print_function
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import warnings
|
||||
from typing import Union
|
||||
from typing import Union, Literal
|
||||
|
||||
from ..log import get_module_logger
|
||||
from ..utils import get_date_range
|
||||
@@ -24,16 +24,14 @@ from ..data.dataset.utils import get_level_index
|
||||
logger = get_module_logger("Evaluate")
|
||||
|
||||
|
||||
def risk_analysis(r, N: int = None, freq: str = "day"):
|
||||
def risk_analysis(r, N: int = None, freq: str = "day", mode: Literal["sum", "product"] = "sum"):
|
||||
"""Risk Analysis
|
||||
NOTE:
|
||||
The calculation of annulaized return is different from the definition of annualized return.
|
||||
The calculation of annualized return is different from the definition of annualized return.
|
||||
It is implemented by design.
|
||||
Qlib tries to cumulated returns by summation instead of production to avoid the cumulated curve being skewed exponentially.
|
||||
Qlib tries to cumulate returns by summation instead of production to avoid the cumulated curve being skewed exponentially.
|
||||
All the calculation of annualized returns follows this principle in Qlib.
|
||||
|
||||
TODO: add a parameter to enable calculating metrics with production accumulation of return.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
r : pandas.Series
|
||||
@@ -42,11 +40,14 @@ def risk_analysis(r, N: int = None, freq: str = "day"):
|
||||
scaler for annualizing information_ratio (day: 252, week: 50, month: 12), at least one of `N` and `freq` should exist
|
||||
freq: str
|
||||
analysis frequency used for calculating the scaler, at least one of `N` and `freq` should exist
|
||||
mode: Literal["sum", "product"]
|
||||
the method by which returns are accumulated:
|
||||
- "sum": Arithmetic accumulation (linear returns).
|
||||
- "product": Geometric accumulation (compounded returns).
|
||||
"""
|
||||
|
||||
def cal_risk_analysis_scaler(freq):
|
||||
_count, _freq = Freq.parse(freq)
|
||||
# len(D.calendar(start_time='2010-01-01', end_time='2019-12-31', freq='day')) = 2384
|
||||
_freq_scaler = {
|
||||
Freq.NORM_FREQ_MINUTE: 240 * 238,
|
||||
Freq.NORM_FREQ_DAY: 238,
|
||||
@@ -62,11 +63,26 @@ def risk_analysis(r, N: int = None, freq: str = "day"):
|
||||
if N is None:
|
||||
N = cal_risk_analysis_scaler(freq)
|
||||
|
||||
mean = r.mean()
|
||||
std = r.std(ddof=1)
|
||||
annualized_return = mean * N
|
||||
if mode == "sum":
|
||||
mean = r.mean()
|
||||
std = r.std(ddof=1)
|
||||
annualized_return = mean * N
|
||||
max_drawdown = (r.cumsum() - r.cumsum().cummax()).min()
|
||||
elif mode == "product":
|
||||
cumulative_curve = (1 + r).cumprod()
|
||||
# geometric mean (compound annual growth rate)
|
||||
mean = cumulative_curve.iloc[-1] ** (1 / len(r)) - 1
|
||||
# volatility of log returns
|
||||
std = np.log(1 + r).std(ddof=1)
|
||||
|
||||
cumulative_return = cumulative_curve.iloc[-1] - 1
|
||||
annualized_return = (1 + cumulative_return) ** (N / len(r)) - 1
|
||||
# max percentage drawdown from peak cumulative product
|
||||
max_drawdown = (cumulative_curve / cumulative_curve.cummax() - 1).min()
|
||||
else:
|
||||
raise ValueError(f"risk_analysis accumulation mode {mode} is not supported. Expected `sum` or `product`.")
|
||||
|
||||
information_ratio = mean / std * np.sqrt(N)
|
||||
max_drawdown = (r.cumsum() - r.cumsum().cummax()).min()
|
||||
data = {
|
||||
"mean": mean,
|
||||
"std": std,
|
||||
|
||||
@@ -106,7 +106,7 @@ class InternalData:
|
||||
|
||||
def _calc_perf(self, pred, label):
|
||||
df = pd.DataFrame({"pred": pred, "label": label})
|
||||
df = df.groupby("datetime").corr(method="spearman")
|
||||
df = df.groupby("datetime", group_keys=False).corr(method="spearman")
|
||||
corr = df.loc(axis=0)[:, "pred"]["label"].droplevel(axis=0, level=-1)
|
||||
return corr
|
||||
|
||||
@@ -161,7 +161,7 @@ class MetaTaskDS(MetaTask):
|
||||
raise ValueError(f"Most of samples are dropped. Please check this task: {task}")
|
||||
|
||||
assert (
|
||||
d_test.groupby("datetime").size().shape[0] >= 5
|
||||
d_test.groupby("datetime", group_keys=False).size().shape[0] >= 5
|
||||
), "In this segment, this trading dates is less than 5, you'd better check the data."
|
||||
|
||||
sample_time_belong = np.zeros((d_train.shape[0], time_perf.shape[1]))
|
||||
|
||||
@@ -125,7 +125,11 @@ class MetaModelDS(MetaTaskModel):
|
||||
loss_l.setdefault(phase, []).append(running_loss)
|
||||
|
||||
pred_y_all = pd.concat(pred_y_all)
|
||||
ic = pred_y_all.groupby("datetime").apply(lambda df: df["pred"].corr(df["label"], method="spearman")).mean()
|
||||
ic = (
|
||||
pred_y_all.groupby("datetime", group_keys=False)
|
||||
.apply(lambda df: df["pred"].corr(df["label"], method="spearman"))
|
||||
.mean()
|
||||
)
|
||||
|
||||
R.log_metrics(**{f"loss/{phase}": running_loss, "step": epoch})
|
||||
R.log_metrics(**{f"ic/{phase}": ic, "step": epoch})
|
||||
|
||||
@@ -166,7 +166,7 @@ class DEnsembleModel(Model, FeatureInt):
|
||||
|
||||
# calculate weights
|
||||
h["bins"] = pd.cut(h["h_value"], self.bins_sr)
|
||||
h_avg = h.groupby("bins")["h_value"].mean()
|
||||
h_avg = h.groupby("bins", group_keys=False, observed=False)["h_value"].mean()
|
||||
weights = pd.Series(np.zeros(N, dtype=float))
|
||||
for b in h_avg.index:
|
||||
weights[h["bins"] == b] = 1.0 / (self.decay**k_th * h_avg[b] + 0.1)
|
||||
|
||||
@@ -90,8 +90,14 @@ class HFLGBModel(ModelFT, LightGBMFInt):
|
||||
if y_train.values.ndim == 2 and y_train.values.shape[1] == 1:
|
||||
l_name = df_train["label"].columns[0]
|
||||
# Convert label into alpha
|
||||
df_train["label"][l_name] = df_train["label"][l_name] - df_train["label"][l_name].mean(level=0)
|
||||
df_valid["label"][l_name] = df_valid["label"][l_name] - df_valid["label"][l_name].mean(level=0)
|
||||
df_train.loc[:, ("label", l_name)] = (
|
||||
df_train.loc[:, ("label", l_name)]
|
||||
- df_train.loc[:, ("label", l_name)].groupby(level=0, group_keys=False).mean()
|
||||
)
|
||||
df_valid.loc[:, ("label", l_name)] = (
|
||||
df_valid.loc[:, ("label", l_name)]
|
||||
- df_valid.loc[:, ("label", l_name)].groupby(level=0, group_keys=False).mean()
|
||||
)
|
||||
|
||||
def mapping_fn(x):
|
||||
return 0 if x < 0 else 1
|
||||
|
||||
@@ -214,8 +214,10 @@ class ADARNN(Model):
|
||||
def calc_all_metrics(pred):
|
||||
"""pred is a pandas dataframe that has two attributes: score (pred) and label (real)"""
|
||||
res = {}
|
||||
ic = pred.groupby(level="datetime").apply(lambda x: x.label.corr(x.score))
|
||||
rank_ic = pred.groupby(level="datetime").apply(lambda x: x.label.corr(x.score, method="spearman"))
|
||||
ic = pred.groupby(level="datetime", group_keys=False).apply(lambda x: x.label.corr(x.score))
|
||||
rank_ic = pred.groupby(level="datetime", group_keys=False).apply(
|
||||
lambda x: x.label.corr(x.score, method="spearman")
|
||||
)
|
||||
res["ic"] = ic.mean()
|
||||
res["icir"] = ic.mean() / ic.std()
|
||||
res["ric"] = rank_ic.mean()
|
||||
|
||||
@@ -226,7 +226,7 @@ class ADD(Model):
|
||||
|
||||
def get_daily_inter(self, df, shuffle=False):
|
||||
# organize the train data into daily batches
|
||||
daily_count = df.groupby(level=0).size().values
|
||||
daily_count = df.groupby(level=0, group_keys=False).size().values
|
||||
daily_index = np.roll(np.cumsum(daily_count), 1)
|
||||
daily_index[0] = 0
|
||||
if shuffle:
|
||||
@@ -349,7 +349,7 @@ class ADD(Model):
|
||||
return best_score
|
||||
|
||||
def gen_market_label(self, df, raw_label):
|
||||
market_label = raw_label.groupby("datetime").mean().squeeze()
|
||||
market_label = raw_label.groupby("datetime", group_keys=False).mean().squeeze()
|
||||
bins = [-np.inf, self.lo, self.hi, np.inf]
|
||||
market_label = pd.cut(market_label, bins, labels=False)
|
||||
market_label.name = ("market_return", "market_return")
|
||||
@@ -357,7 +357,7 @@ class ADD(Model):
|
||||
return df
|
||||
|
||||
def fit_thresh(self, train_label):
|
||||
market_label = train_label.groupby("datetime").mean().squeeze()
|
||||
market_label = train_label.groupby("datetime", group_keys=False).mean().squeeze()
|
||||
self.lo, self.hi = market_label.quantile([1 / 3, 2 / 3])
|
||||
|
||||
def fit(
|
||||
|
||||
@@ -163,7 +163,7 @@ class GATs(Model):
|
||||
|
||||
def get_daily_inter(self, df, shuffle=False):
|
||||
# organize the train data into daily batches
|
||||
daily_count = df.groupby(level=0).size().values
|
||||
daily_count = df.groupby(level=0, group_keys=False).size().values
|
||||
daily_index = np.roll(np.cumsum(daily_count), 1)
|
||||
daily_index[0] = 0
|
||||
if shuffle:
|
||||
|
||||
@@ -27,7 +27,9 @@ class DailyBatchSampler(Sampler):
|
||||
def __init__(self, data_source):
|
||||
self.data_source = data_source
|
||||
# calculate number of samples in each batch
|
||||
self.daily_count = pd.Series(index=self.data_source.get_index()).groupby("datetime").size().values
|
||||
self.daily_count = (
|
||||
pd.Series(index=self.data_source.get_index()).groupby("datetime", group_keys=False).size().values
|
||||
)
|
||||
self.daily_index = np.roll(np.cumsum(self.daily_count), 1) # calculate begin index of each batch
|
||||
self.daily_index[0] = 0
|
||||
|
||||
@@ -181,7 +183,7 @@ class GATs(Model):
|
||||
|
||||
def get_daily_inter(self, df, shuffle=False):
|
||||
# organize the train data into daily batches
|
||||
daily_count = df.groupby(level=0).size().values
|
||||
daily_count = df.groupby(level=0, group_keys=False).size().values
|
||||
daily_index = np.roll(np.cumsum(daily_count), 1)
|
||||
daily_index[0] = 0
|
||||
if shuffle:
|
||||
|
||||
@@ -13,6 +13,7 @@ import copy
|
||||
|
||||
import torch
|
||||
import torch.optim as optim
|
||||
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
||||
|
||||
from qlib.data.dataset.weight import Reweighter
|
||||
|
||||
@@ -136,6 +137,10 @@ class GeneralPTNN(Model):
|
||||
else:
|
||||
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
|
||||
|
||||
# === ReduceLROnPlateau learning rate scheduler ===
|
||||
self.lr_scheduler = ReduceLROnPlateau(
|
||||
self.train_optimizer, mode="min", factor=0.5, patience=5, min_lr=1e-6, threshold=1e-5
|
||||
)
|
||||
self.fitted = False
|
||||
self.dnn_model.to(self.device)
|
||||
|
||||
@@ -154,7 +159,7 @@ class GeneralPTNN(Model):
|
||||
weight = torch.ones_like(label)
|
||||
|
||||
if self.loss == "mse":
|
||||
return self.mse(pred[mask], label[mask], weight[mask])
|
||||
return self.mse(pred[mask], label[mask].view(-1, 1), weight[mask])
|
||||
|
||||
raise ValueError("unknown loss `%s`" % self.loss)
|
||||
|
||||
@@ -162,7 +167,7 @@ class GeneralPTNN(Model):
|
||||
mask = torch.isfinite(label)
|
||||
|
||||
if self.metric in ("", "loss"):
|
||||
return -self.loss_fn(pred[mask], label[mask])
|
||||
return self.loss_fn(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown metric `%s`" % self.metric)
|
||||
|
||||
@@ -238,6 +243,8 @@ class GeneralPTNN(Model):
|
||||
|
||||
dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
|
||||
dl_valid = dataset.prepare("valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
|
||||
self.logger.info(f"Train samples: {len(dl_train)}")
|
||||
self.logger.info(f"Valid samples: {len(dl_valid)}")
|
||||
if dl_train.empty or dl_valid.empty:
|
||||
raise ValueError("Empty data from dataset, please check your dataset config.")
|
||||
|
||||
@@ -279,7 +286,7 @@ class GeneralPTNN(Model):
|
||||
|
||||
stop_steps = 0
|
||||
train_loss = 0
|
||||
best_score = -np.inf
|
||||
best_score = np.inf
|
||||
best_epoch = 0
|
||||
evals_result["train"] = []
|
||||
evals_result["valid"] = []
|
||||
@@ -295,13 +302,18 @@ class GeneralPTNN(Model):
|
||||
self.logger.info("evaluating...")
|
||||
train_loss, train_score = self.test_epoch(train_loader)
|
||||
val_loss, val_score = self.test_epoch(valid_loader)
|
||||
self.logger.info("train %.6f, valid %.6f" % (train_score, val_score))
|
||||
self.logger.info("Epoch%d: train %.6f, valid %.6f" % (step, train_score, val_score))
|
||||
evals_result["train"].append(train_score)
|
||||
evals_result["valid"].append(val_score)
|
||||
|
||||
# current_lr = self.train_optimizer.param_groups[0]["lr"]
|
||||
# self.logger.info("Current learning rate: %.6e" % current_lr)
|
||||
|
||||
self.lr_scheduler.step(val_score)
|
||||
|
||||
if step == 0:
|
||||
best_param = copy.deepcopy(self.dnn_model.state_dict())
|
||||
if val_score > best_score:
|
||||
if val_score < best_score:
|
||||
best_score = val_score
|
||||
stop_steps = 0
|
||||
best_epoch = step
|
||||
@@ -312,7 +324,7 @@ class GeneralPTNN(Model):
|
||||
self.logger.info("early stop")
|
||||
break
|
||||
|
||||
self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
|
||||
self.logger.info("best score: %.6lf @ %d epoch" % (best_score, best_epoch))
|
||||
self.dnn_model.load_state_dict(best_param)
|
||||
torch.save(best_param, save_path)
|
||||
|
||||
@@ -329,6 +341,7 @@ class GeneralPTNN(Model):
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
dl_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)
|
||||
self.logger.info(f"Test samples: {len(dl_test)}")
|
||||
|
||||
if isinstance(dataset, TSDatasetH):
|
||||
dl_test.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
|
||||
@@ -177,7 +177,7 @@ class HIST(Model):
|
||||
|
||||
def get_daily_inter(self, df, shuffle=False):
|
||||
# organize the train data into daily batches
|
||||
daily_count = df.groupby(level=0).size().values
|
||||
daily_count = df.groupby(level=0, group_keys=False).size().values
|
||||
daily_index = np.roll(np.cumsum(daily_count), 1)
|
||||
daily_index[0] = 0
|
||||
if shuffle:
|
||||
|
||||
@@ -170,7 +170,7 @@ class IGMTF(Model):
|
||||
|
||||
def get_daily_inter(self, df, shuffle=False):
|
||||
# organize the train data into daily batches
|
||||
daily_count = df.groupby(level=0).size().values
|
||||
daily_count = df.groupby(level=0, group_keys=False).size().values
|
||||
daily_index = np.roll(np.cumsum(daily_count), 1)
|
||||
daily_index[0] = 0
|
||||
if shuffle:
|
||||
|
||||
@@ -368,7 +368,7 @@ class KRNN(Model):
|
||||
|
||||
def get_daily_inter(self, df, shuffle=False):
|
||||
# organize the train data into daily batches
|
||||
daily_count = df.groupby(level=0).size().values
|
||||
daily_count = df.groupby(level=0, group_keys=False).size().values
|
||||
daily_index = np.roll(np.cumsum(daily_count), 1)
|
||||
daily_index[0] = 0
|
||||
if shuffle:
|
||||
|
||||
@@ -146,19 +146,34 @@ class DNNModelPytorch(Model):
|
||||
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
|
||||
|
||||
if scheduler == "default":
|
||||
# Reduce learning rate when loss has stopped decrease
|
||||
self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
||||
self.train_optimizer,
|
||||
mode="min",
|
||||
factor=0.5,
|
||||
patience=10,
|
||||
verbose=True,
|
||||
threshold=0.0001,
|
||||
threshold_mode="rel",
|
||||
cooldown=0,
|
||||
min_lr=0.00001,
|
||||
eps=1e-08,
|
||||
)
|
||||
# In torch version 2.7.0, the verbose parameter has been removed. Reference Link:
|
||||
# https://github.com/pytorch/pytorch/pull/147301/files#diff-036a7470d5307f13c9a6a51c3a65dd014f00ca02f476c545488cd856bea9bcf2L1313
|
||||
if str(torch.__version__).split("+", maxsplit=1)[0] <= "2.6.0":
|
||||
# Reduce learning rate when loss has stopped decrease
|
||||
self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( # pylint: disable=E1123
|
||||
self.train_optimizer,
|
||||
mode="min",
|
||||
factor=0.5,
|
||||
patience=10,
|
||||
verbose=True,
|
||||
threshold=0.0001,
|
||||
threshold_mode="rel",
|
||||
cooldown=0,
|
||||
min_lr=0.00001,
|
||||
eps=1e-08,
|
||||
)
|
||||
else:
|
||||
self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
||||
self.train_optimizer,
|
||||
mode="min",
|
||||
factor=0.5,
|
||||
patience=10,
|
||||
threshold=0.0001,
|
||||
threshold_mode="rel",
|
||||
cooldown=0,
|
||||
min_lr=0.00001,
|
||||
eps=1e-08,
|
||||
)
|
||||
elif scheduler is None:
|
||||
self.scheduler = None
|
||||
else:
|
||||
|
||||
@@ -96,7 +96,7 @@ class DayCumsum(ElemOperator):
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
_calendar = get_calendar_day(freq=freq)
|
||||
series = self.feature.load(instrument, start_index, end_index, freq)
|
||||
return series.groupby(_calendar[series.index]).transform(self.period_cusum)
|
||||
return series.groupby(_calendar[series.index], group_keys=False).transform(self.period_cusum)
|
||||
|
||||
|
||||
class DayLast(ElemOperator):
|
||||
@@ -116,7 +116,7 @@ class DayLast(ElemOperator):
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
_calendar = get_calendar_day(freq=freq)
|
||||
series = self.feature.load(instrument, start_index, end_index, freq)
|
||||
return series.groupby(_calendar[series.index]).transform("last")
|
||||
return series.groupby(_calendar[series.index], group_keys=False).transform("last")
|
||||
|
||||
|
||||
class FFillNan(ElemOperator):
|
||||
|
||||
@@ -38,7 +38,7 @@ def _group_return(pred_label: pd.DataFrame = None, reverse: bool = False, N: int
|
||||
t_df = pd.DataFrame(
|
||||
{
|
||||
"Group%d"
|
||||
% (i + 1): pred_label_drop.groupby(level="datetime")["label"].apply(
|
||||
% (i + 1): pred_label_drop.groupby(level="datetime", group_keys=False)["label"].apply(
|
||||
lambda x: x[len(x) // N * i : len(x) // N * (i + 1)].mean() # pylint: disable=W0640
|
||||
)
|
||||
for i in range(N)
|
||||
@@ -50,7 +50,7 @@ def _group_return(pred_label: pd.DataFrame = None, reverse: bool = False, N: int
|
||||
t_df["long-short"] = t_df["Group1"] - t_df["Group%d" % N]
|
||||
|
||||
# Long-Average
|
||||
t_df["long-average"] = t_df["Group1"] - pred_label.groupby(level="datetime")["label"].mean()
|
||||
t_df["long-average"] = t_df["Group1"] - pred_label.groupby(level="datetime", group_keys=False)["label"].mean()
|
||||
|
||||
t_df = t_df.dropna(how="all") # for days which does not contain label
|
||||
# Cumulative Return By Group
|
||||
@@ -137,7 +137,9 @@ def _pred_ic(
|
||||
|
||||
ic_df = pd.concat(
|
||||
[
|
||||
pred_label.groupby(level="datetime").apply(partial(_corr_series, method=_methods_mapping[m])).rename(m)
|
||||
pred_label.groupby(level="datetime", group_keys=False)
|
||||
.apply(partial(_corr_series, method=_methods_mapping[m]))
|
||||
.rename(m)
|
||||
for m in methods
|
||||
],
|
||||
axis=1,
|
||||
@@ -145,7 +147,7 @@ def _pred_ic(
|
||||
_ic = ic_df.iloc(axis=1)[0]
|
||||
|
||||
_index = _ic.index.get_level_values(0).astype("str").str.replace("-", "").str.slice(0, 6)
|
||||
_monthly_ic = _ic.groupby(_index).mean()
|
||||
_monthly_ic = _ic.groupby(_index, group_keys=False).mean()
|
||||
_monthly_ic.index = pd.MultiIndex.from_arrays(
|
||||
[_monthly_ic.index.str.slice(0, 4), _monthly_ic.index.str.slice(4, 6)],
|
||||
names=["year", "month"],
|
||||
@@ -220,8 +222,10 @@ def _pred_ic(
|
||||
|
||||
def _pred_autocorr(pred_label: pd.DataFrame, lag=1, **kwargs) -> tuple:
|
||||
pred = pred_label.copy()
|
||||
pred["score_last"] = pred.groupby(level="instrument")["score"].shift(lag)
|
||||
ac = pred.groupby(level="datetime").apply(lambda x: x["score"].rank(pct=True).corr(x["score_last"].rank(pct=True)))
|
||||
pred["score_last"] = pred.groupby(level="instrument", group_keys=False)["score"].shift(lag)
|
||||
ac = pred.groupby(level="datetime", group_keys=False).apply(
|
||||
lambda x: x["score"].rank(pct=True).corr(x["score_last"].rank(pct=True))
|
||||
)
|
||||
_df = ac.to_frame("value")
|
||||
ac_figure = ScatterGraph(
|
||||
_df,
|
||||
@@ -235,13 +239,13 @@ def _pred_autocorr(pred_label: pd.DataFrame, lag=1, **kwargs) -> tuple:
|
||||
|
||||
def _pred_turnover(pred_label: pd.DataFrame, N=5, lag=1, **kwargs) -> tuple:
|
||||
pred = pred_label.copy()
|
||||
pred["score_last"] = pred.groupby(level="instrument")["score"].shift(lag)
|
||||
top = pred.groupby(level="datetime").apply(
|
||||
pred["score_last"] = pred.groupby(level="instrument", group_keys=False)["score"].shift(lag)
|
||||
top = pred.groupby(level="datetime", group_keys=False).apply(
|
||||
lambda x: 1
|
||||
- x.nlargest(len(x) // N, columns="score").index.isin(x.nlargest(len(x) // N, columns="score_last").index).sum()
|
||||
/ (len(x) // N)
|
||||
)
|
||||
bottom = pred.groupby(level="datetime").apply(
|
||||
bottom = pred.groupby(level="datetime", group_keys=False).apply(
|
||||
lambda x: 1
|
||||
- x.nsmallest(len(x) // N, columns="score")
|
||||
.index.isin(x.nsmallest(len(x) // N, columns="score_last").index)
|
||||
@@ -313,7 +317,7 @@ def model_performance_graph(
|
||||
2017-12-15 -0.102778 -0.102778
|
||||
|
||||
|
||||
:param lag: `pred.groupby(level='instrument')['score'].shift(lag)`. It will be only used in the auto-correlation computing.
|
||||
:param lag: `pred.groupby(level='instrument', group_keys=False)['score'].shift(lag)`. It will be only used in the auto-correlation computing.
|
||||
:param N: group number, default 5.
|
||||
:param reverse: if `True`, `pred['score'] *= -1`.
|
||||
:param rank: if **True**, calculate rank ic.
|
||||
|
||||
@@ -38,7 +38,7 @@ def _get_cum_return_data_with_position(
|
||||
|
||||
_cumulative_return_df["label"] = _cumulative_return_df["label"] - _cumulative_return_df["bench"]
|
||||
_cumulative_return_df = _cumulative_return_df.dropna()
|
||||
df_gp = _cumulative_return_df.groupby(level="datetime")
|
||||
df_gp = _cumulative_return_df.groupby(level="datetime", group_keys=False)
|
||||
result_list = []
|
||||
for gp in df_gp:
|
||||
date = gp[0]
|
||||
|
||||
@@ -132,7 +132,7 @@ def _calculate_label_rank(df: pd.DataFrame) -> pd.DataFrame:
|
||||
g_df["excess_return"] = g_df[_label_name] - g_df[_label_name].mean()
|
||||
return g_df
|
||||
|
||||
return df.groupby(level="datetime").apply(_calculate_day_value)
|
||||
return df.groupby(level="datetime", group_keys=False).apply(_calculate_day_value)
|
||||
|
||||
|
||||
def get_position_data(
|
||||
|
||||
@@ -31,7 +31,7 @@ def _get_figure_with_position(
|
||||
)
|
||||
|
||||
res_dict = dict()
|
||||
_pos_gp = _position_df.groupby(level=1)
|
||||
_pos_gp = _position_df.groupby(level=1, group_keys=False)
|
||||
for _item in _pos_gp:
|
||||
_date = _item[0]
|
||||
_day_df = _item[1]
|
||||
|
||||
@@ -63,9 +63,11 @@ def _get_monthly_risk_analysis_with_report(report_normal_df: pd.DataFrame) -> pd
|
||||
"""
|
||||
|
||||
# Group by month
|
||||
report_normal_gp = report_normal_df.groupby([report_normal_df.index.year, report_normal_df.index.month])
|
||||
report_normal_gp = report_normal_df.groupby(
|
||||
[report_normal_df.index.year, report_normal_df.index.month], group_keys=False
|
||||
)
|
||||
# report_long_short_gp = report_long_short_df.groupby(
|
||||
# [report_long_short_df.index.year, report_long_short_df.index.month]
|
||||
# [report_long_short_df.index.year, report_long_short_df.index.month], group_keys=False
|
||||
# )
|
||||
|
||||
gp_month = sorted(set(report_normal_gp.size().index))
|
||||
@@ -97,7 +99,7 @@ def _get_monthly_analysis_with_feature(monthly_df: pd.DataFrame, feature: str =
|
||||
:param feature:
|
||||
:return:
|
||||
"""
|
||||
_monthly_df_gp = monthly_df.reset_index().groupby(["level_1"])
|
||||
_monthly_df_gp = monthly_df.reset_index().groupby(["level_1"], group_keys=False)
|
||||
|
||||
_name_df = _monthly_df_gp.get_group(feature).set_index(["level_0", "level_1"])
|
||||
_temp_df = _name_df.pivot_table(index="date", values=["risk"], columns=_name_df.index)
|
||||
|
||||
@@ -15,8 +15,10 @@ def _get_score_ic(pred_label: pd.DataFrame):
|
||||
"""
|
||||
concat_data = pred_label.copy()
|
||||
concat_data.dropna(axis=0, how="any", inplace=True)
|
||||
_ic = concat_data.groupby(level="datetime").apply(lambda x: x["label"].corr(x["score"]))
|
||||
_rank_ic = concat_data.groupby(level="datetime").apply(lambda x: x["label"].corr(x["score"], method="spearman"))
|
||||
_ic = concat_data.groupby(level="datetime", group_keys=False).apply(lambda x: x["label"].corr(x["score"]))
|
||||
_rank_ic = concat_data.groupby(level="datetime", group_keys=False).apply(
|
||||
lambda x: x["label"].corr(x["score"], method="spearman")
|
||||
)
|
||||
return pd.DataFrame({"ic": _ic, "rank_ic": _rank_ic})
|
||||
|
||||
|
||||
|
||||
@@ -72,10 +72,10 @@ class ValueCNT(FeaAnalyser):
|
||||
self._val_cnt = {}
|
||||
for col, item in self._dataset.items():
|
||||
if not super().skip(col):
|
||||
self._val_cnt[col] = item.groupby(DT_COL_NAME).apply(lambda s: len(s.unique()))
|
||||
self._val_cnt[col] = item.groupby(DT_COL_NAME, group_keys=False).apply(lambda s: len(s.unique()))
|
||||
self._val_cnt = pd.DataFrame(self._val_cnt)
|
||||
if self.ratio:
|
||||
self._val_cnt = self._val_cnt.div(self._dataset.groupby(DT_COL_NAME).size(), axis=0)
|
||||
self._val_cnt = self._val_cnt.div(self._dataset.groupby(DT_COL_NAME, group_keys=False).size(), axis=0)
|
||||
|
||||
# TODO: transfer this feature to other analysers
|
||||
ymin, ymax = self._val_cnt.min().min(), self._val_cnt.max().max()
|
||||
@@ -98,7 +98,7 @@ class FeaInfAna(NumFeaAnalyser):
|
||||
self._inf_cnt = {}
|
||||
for col, item in self._dataset.items():
|
||||
if not super().skip(col):
|
||||
self._inf_cnt[col] = item.apply(np.isinf).astype(np.int).groupby(DT_COL_NAME).sum()
|
||||
self._inf_cnt[col] = item.apply(np.isinf).astype(np.int).groupby(DT_COL_NAME, group_keys=False).sum()
|
||||
self._inf_cnt = pd.DataFrame(self._inf_cnt)
|
||||
|
||||
def skip(self, col):
|
||||
@@ -111,7 +111,7 @@ class FeaInfAna(NumFeaAnalyser):
|
||||
|
||||
class FeaNanAna(FeaAnalyser):
|
||||
def calc_stat_values(self):
|
||||
self._nan_cnt = self._dataset.isna().groupby(DT_COL_NAME).sum()
|
||||
self._nan_cnt = self._dataset.isna().groupby(DT_COL_NAME, group_keys=False).sum()
|
||||
|
||||
def skip(self, col):
|
||||
return (col not in self._nan_cnt) or (self._nan_cnt[col].sum() == 0)
|
||||
@@ -123,8 +123,8 @@ class FeaNanAna(FeaAnalyser):
|
||||
|
||||
class FeaNanAnaRatio(FeaAnalyser):
|
||||
def calc_stat_values(self):
|
||||
self._nan_cnt = self._dataset.isna().groupby(DT_COL_NAME).sum()
|
||||
self._total_cnt = self._dataset.groupby(DT_COL_NAME).size()
|
||||
self._nan_cnt = self._dataset.isna().groupby(DT_COL_NAME, group_keys=False).sum()
|
||||
self._total_cnt = self._dataset.groupby(DT_COL_NAME, group_keys=False).size()
|
||||
|
||||
def skip(self, col):
|
||||
return (col not in self._nan_cnt) or (self._nan_cnt[col].sum() == 0)
|
||||
@@ -176,8 +176,8 @@ class FeaSkewTurt(NumFeaAnalyser):
|
||||
|
||||
class FeaMeanStd(NumFeaAnalyser):
|
||||
def calc_stat_values(self):
|
||||
self._std = self._dataset.groupby(DT_COL_NAME).std()
|
||||
self._mean = self._dataset.groupby(DT_COL_NAME).mean()
|
||||
self._std = self._dataset.groupby(DT_COL_NAME, group_keys=False).std()
|
||||
self._mean = self._dataset.groupby(DT_COL_NAME, group_keys=False).mean()
|
||||
|
||||
def plot_single(self, col, ax):
|
||||
self._mean[col].plot(ax=ax, label="mean")
|
||||
|
||||
@@ -326,8 +326,10 @@ class SBBStrategyEMA(SBBStrategyBase):
|
||||
if instruments is None:
|
||||
warnings.warn("`instruments` is not set, will load all stocks")
|
||||
self.instruments = "all"
|
||||
if isinstance(instruments, str):
|
||||
elif isinstance(instruments, str):
|
||||
self.instruments = D.instruments(instruments)
|
||||
elif isinstance(instruments, List):
|
||||
self.instruments = instruments
|
||||
self.freq = freq
|
||||
super(SBBStrategyEMA, self).__init__(
|
||||
outer_trade_decision, level_infra, common_infra, trade_exchange=trade_exchange, **kwargs
|
||||
@@ -345,7 +347,7 @@ class SBBStrategyEMA(SBBStrategyBase):
|
||||
self.signal = {}
|
||||
|
||||
if not signal_df.empty:
|
||||
for stock_id, stock_val in signal_df.groupby(level="instrument"):
|
||||
for stock_id, stock_val in signal_df.groupby(level="instrument", group_keys=False):
|
||||
self.signal[stock_id] = stock_val["signal"].droplevel(level="instrument")
|
||||
|
||||
def reset_level_infra(self, level_infra):
|
||||
@@ -432,7 +434,7 @@ class ACStrategy(BaseStrategy):
|
||||
self.signal = {}
|
||||
|
||||
if not signal_df.empty:
|
||||
for stock_id, stock_val in signal_df.groupby(level="instrument"):
|
||||
for stock_id, stock_val in signal_df.groupby(level="instrument", group_keys=False):
|
||||
self.signal[stock_id] = stock_val["volatility"].droplevel(level="instrument")
|
||||
|
||||
def reset_level_infra(self, level_infra):
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
"""
|
||||
This module is not a necessary part of Qlib.
|
||||
They are just some tools for convenience
|
||||
It is should not imported into the core part of qlib
|
||||
This module is not a necessary part of Qlib.
|
||||
They are just some tools for convenience
|
||||
It is should not imported into the core part of qlib
|
||||
"""
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
@@ -842,7 +842,7 @@ class DiskDatasetCache(DatasetCache):
|
||||
def build_index_from_data(data, start_index=0):
|
||||
if data.empty:
|
||||
return pd.DataFrame()
|
||||
line_data = data.groupby("datetime").size()
|
||||
line_data = data.groupby("datetime", group_keys=False).size()
|
||||
line_data.sort_index(inplace=True)
|
||||
index_end = line_data.cumsum()
|
||||
index_start = index_end.shift(1, fill_value=0)
|
||||
|
||||
@@ -226,13 +226,8 @@ class DatasetH(Dataset):
|
||||
------
|
||||
NotImplementedError:
|
||||
"""
|
||||
logger = get_module_logger("DatasetH")
|
||||
seg_kwargs = {"col_set": col_set}
|
||||
seg_kwargs = {"col_set": col_set, "data_key": data_key}
|
||||
seg_kwargs.update(kwargs)
|
||||
if "data_key" in getfullargspec(self.handler.fetch).args:
|
||||
seg_kwargs["data_key"] = data_key
|
||||
else:
|
||||
logger.info(f"data_key[{data_key}] is ignored.")
|
||||
|
||||
# Conflictions may happen here
|
||||
# - The fetched data and the segment key may both be string
|
||||
@@ -240,9 +235,11 @@ class DatasetH(Dataset):
|
||||
# - The segment name will have higher priorities
|
||||
|
||||
# 1) Use it as segment name first
|
||||
# 1.1) directly fetch split like "train" "valid" "test"
|
||||
if isinstance(segments, str) and segments in self.segments:
|
||||
return self._prepare_seg(self.segments[segments], **seg_kwargs)
|
||||
|
||||
# 1.2) fetch multiple splits like ["train", "valid"] ["train", "valid", "test"]
|
||||
if isinstance(segments, (list, tuple)) and all(seg in self.segments for seg in segments):
|
||||
return [self._prepare_seg(self.segments[seg], **seg_kwargs) for seg in segments]
|
||||
|
||||
@@ -262,7 +259,7 @@ class DatasetH(Dataset):
|
||||
def _get_extrema(segments, idx: int, cmp: Callable, key_func=pd.Timestamp):
|
||||
"""it will act like sort and return the max value or None"""
|
||||
candidate = None
|
||||
for k, seg in segments.items():
|
||||
for _, seg in segments.items():
|
||||
point = seg[idx]
|
||||
if point is None:
|
||||
# None indicates unbounded, return directly
|
||||
@@ -376,6 +373,8 @@ class TSDataSampler:
|
||||
ffill with previous samples first and fill with later samples second
|
||||
flt_data : pd.Series
|
||||
a column of data(True or False) to filter data. Its index order is <"datetime", "instrument">
|
||||
This feature is essential because:
|
||||
- We want some sample not included due to label-based filtering, but we can't filter them at the beginning due to the features is still important in the feature.
|
||||
None:
|
||||
kepp all data
|
||||
|
||||
@@ -661,8 +660,9 @@ class TSDatasetH(DatasetH):
|
||||
|
||||
DEFAULT_STEP_LEN = 30
|
||||
|
||||
def __init__(self, step_len=DEFAULT_STEP_LEN, **kwargs):
|
||||
def __init__(self, step_len=DEFAULT_STEP_LEN, flt_col: Optional[str] = None, **kwargs):
|
||||
self.step_len = step_len
|
||||
self.flt_col = flt_col
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def config(self, **kwargs):
|
||||
@@ -693,10 +693,10 @@ class TSDatasetH(DatasetH):
|
||||
dtype = kwargs.pop("dtype", None)
|
||||
if not isinstance(slc, slice):
|
||||
slc = slice(*slc)
|
||||
start, end = slc.start, slc.stop
|
||||
flt_col = kwargs.pop("flt_col", None)
|
||||
# TSDatasetH will retrieve more data for complete time-series
|
||||
if (flt_col := kwargs.pop("flt_col", None)) is None:
|
||||
flt_col = self.flt_col
|
||||
|
||||
# TSDatasetH will retrieve more data for complete time-series
|
||||
ext_slice = self._extend_slice(slc, self.cal, self.step_len)
|
||||
data = super()._prepare_seg(ext_slice, **kwargs)
|
||||
|
||||
@@ -710,8 +710,8 @@ class TSDatasetH(DatasetH):
|
||||
|
||||
tsds = TSDataSampler(
|
||||
data=data,
|
||||
start=start,
|
||||
end=end,
|
||||
start=slc.start,
|
||||
end=slc.stop,
|
||||
step_len=self.step_len,
|
||||
dtype=dtype,
|
||||
flt_data=flt_data,
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
# coding=utf-8
|
||||
from abc import abstractmethod
|
||||
import warnings
|
||||
from typing import Callable, Union, Tuple, List, Iterator, Optional
|
||||
|
||||
@@ -19,9 +20,59 @@ from . import processor as processor_module
|
||||
from . import loader as data_loader_module
|
||||
|
||||
|
||||
# TODO: A more general handler interface which does not relies on internal pd.DataFrame is needed.
|
||||
class DataHandler(Serializable):
|
||||
DATA_KEY_TYPE = Literal["raw", "infer", "learn"]
|
||||
|
||||
|
||||
class DataHandlerABC(Serializable):
|
||||
"""
|
||||
Interface for data handler.
|
||||
|
||||
This class does not assume the internal data structure of the data handler.
|
||||
It only defines the interface for external users (uses DataFrame as the internal data structure).
|
||||
|
||||
In the future, the data handler's more detailed implementation should be refactored. Here are some guidelines:
|
||||
|
||||
It covers several components:
|
||||
|
||||
- [data loader] -> internal representation of the data -> data preprocessing -> interface adaptor for the fetch interface
|
||||
- The workflow to combine them all:
|
||||
The workflow may be very complicated. DataHandlerLP is one of the practices, but it can't satisfy all the requirements.
|
||||
So leaving the flexibility to the user to implement the workflow is a more reasonable choice.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs): # pylint: disable=W0246
|
||||
"""
|
||||
We should define how to get ready for the fetching.
|
||||
"""
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
CS_ALL = "__all" # return all columns with single-level index column
|
||||
CS_RAW = "__raw" # return raw data with multi-level index column
|
||||
|
||||
# data key
|
||||
DK_R: DATA_KEY_TYPE = "raw"
|
||||
DK_I: DATA_KEY_TYPE = "infer"
|
||||
DK_L: DATA_KEY_TYPE = "learn"
|
||||
|
||||
@abstractmethod
|
||||
def fetch(
|
||||
self,
|
||||
selector: Union[pd.Timestamp, slice, str, pd.Index] = slice(None, None),
|
||||
level: Union[str, int] = "datetime",
|
||||
col_set: Union[str, List[str]] = CS_ALL,
|
||||
data_key: DATA_KEY_TYPE = DK_I,
|
||||
) -> pd.DataFrame:
|
||||
pass
|
||||
|
||||
|
||||
class DataHandler(DataHandlerABC):
|
||||
"""
|
||||
The motivation of DataHandler:
|
||||
|
||||
- It provides an implementation of BaseDataHandler that we implement with:
|
||||
- Handling responses with an internal loaded DataFrame
|
||||
- The DataFrame is loaded by a data loader.
|
||||
|
||||
The steps to using a handler
|
||||
1. initialized data handler (call by `init`).
|
||||
2. use the data.
|
||||
@@ -144,16 +195,14 @@ class DataHandler(Serializable):
|
||||
self._data = lazy_sort_index(self.data_loader.load(self.instruments, self.start_time, self.end_time))
|
||||
# TODO: cache
|
||||
|
||||
CS_ALL = "__all" # return all columns with single-level index column
|
||||
CS_RAW = "__raw" # return raw data with multi-level index column
|
||||
|
||||
def fetch(
|
||||
self,
|
||||
selector: Union[pd.Timestamp, slice, str, pd.Index] = slice(None, None),
|
||||
level: Union[str, int] = "datetime",
|
||||
col_set: Union[str, List[str]] = CS_ALL,
|
||||
col_set: Union[str, List[str]] = DataHandlerABC.CS_ALL,
|
||||
data_key: DATA_KEY_TYPE = DataHandlerABC.DK_I,
|
||||
squeeze: bool = False,
|
||||
proc_func: Callable = None,
|
||||
proc_func: Optional[Callable] = None,
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
fetch data from underlying data source
|
||||
@@ -216,6 +265,8 @@ class DataHandler(Serializable):
|
||||
-------
|
||||
pd.DataFrame.
|
||||
"""
|
||||
# DataHandler is an example with only one dataframe, so data_key is not used.
|
||||
_ = data_key # avoid linting errors (e.g., unused-argument)
|
||||
return self._fetch_data(
|
||||
data_storage=self._data,
|
||||
selector=selector,
|
||||
@@ -230,7 +281,7 @@ class DataHandler(Serializable):
|
||||
data_storage,
|
||||
selector: Union[pd.Timestamp, slice, str, pd.Index] = slice(None, None),
|
||||
level: Union[str, int] = "datetime",
|
||||
col_set: Union[str, List[str]] = CS_ALL,
|
||||
col_set: Union[str, List[str]] = DataHandlerABC.CS_ALL,
|
||||
squeeze: bool = False,
|
||||
proc_func: Callable = None,
|
||||
):
|
||||
@@ -261,16 +312,9 @@ class DataHandler(Serializable):
|
||||
data_df = fetch_df_by_col(data_df, col_set)
|
||||
data_df = fetch_df_by_index(data_df, selector, level, fetch_orig=self.fetch_orig)
|
||||
elif isinstance(data_storage, BaseHandlerStorage):
|
||||
if not data_storage.is_proc_func_supported():
|
||||
if proc_func is not None:
|
||||
raise ValueError(f"proc_func is not supported by the storage {type(data_storage)}")
|
||||
data_df = data_storage.fetch(
|
||||
selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig
|
||||
)
|
||||
else:
|
||||
data_df = data_storage.fetch(
|
||||
selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig, proc_func=proc_func
|
||||
)
|
||||
if proc_func is not None:
|
||||
raise ValueError(f"proc_func is not supported by the storage {type(data_storage)}")
|
||||
data_df = data_storage.fetch(selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig)
|
||||
else:
|
||||
raise TypeError(f"data_storage should be pd.DataFrame|HashingStockStorage, not {type(data_storage)}")
|
||||
|
||||
@@ -282,7 +326,7 @@ class DataHandler(Serializable):
|
||||
data_df = data_df.reset_index(level=level, drop=True)
|
||||
return data_df
|
||||
|
||||
def get_cols(self, col_set=CS_ALL) -> list:
|
||||
def get_cols(self, col_set=DataHandlerABC.CS_ALL) -> list:
|
||||
"""
|
||||
get the column names
|
||||
|
||||
@@ -336,11 +380,12 @@ class DataHandler(Serializable):
|
||||
yield cur_date, self.fetch(selector, **kwargs)
|
||||
|
||||
|
||||
DATA_KEY_TYPE = Literal["raw", "infer", "learn"]
|
||||
|
||||
|
||||
class DataHandlerLP(DataHandler):
|
||||
"""
|
||||
Motivation:
|
||||
- For the case that we hope using different processor workflows for learning and inference;
|
||||
|
||||
|
||||
DataHandler with **(L)earnable (P)rocessor**
|
||||
|
||||
This handler will produce three pieces of data in pd.DataFrame format.
|
||||
@@ -374,12 +419,8 @@ class DataHandlerLP(DataHandler):
|
||||
_infer: pd.DataFrame # data for inference
|
||||
_learn: pd.DataFrame # data for learning models
|
||||
|
||||
# data key
|
||||
DK_R: DATA_KEY_TYPE = "raw"
|
||||
DK_I: DATA_KEY_TYPE = "infer"
|
||||
DK_L: DATA_KEY_TYPE = "learn"
|
||||
# map data_key to attribute name
|
||||
ATTR_MAP = {DK_R: "_data", DK_I: "_infer", DK_L: "_learn"}
|
||||
ATTR_MAP = {DataHandler.DK_R: "_data", DataHandler.DK_I: "_infer", DataHandler.DK_L: "_learn"}
|
||||
|
||||
# process type
|
||||
PTYPE_I = "independent"
|
||||
@@ -622,7 +663,7 @@ class DataHandlerLP(DataHandler):
|
||||
|
||||
# TODO: Be able to cache handler data. Save the memory for data processing
|
||||
|
||||
def _get_df_by_key(self, data_key: DATA_KEY_TYPE = DK_I) -> pd.DataFrame:
|
||||
def _get_df_by_key(self, data_key: DATA_KEY_TYPE = DataHandlerABC.DK_I) -> pd.DataFrame:
|
||||
if data_key == self.DK_R and self.drop_raw:
|
||||
raise AttributeError(
|
||||
"DataHandlerLP has not attribute _data, please set drop_raw = False if you want to use raw data"
|
||||
@@ -635,7 +676,7 @@ class DataHandlerLP(DataHandler):
|
||||
selector: Union[pd.Timestamp, slice, str] = slice(None, None),
|
||||
level: Union[str, int] = "datetime",
|
||||
col_set=DataHandler.CS_ALL,
|
||||
data_key: DATA_KEY_TYPE = DK_I,
|
||||
data_key: DATA_KEY_TYPE = DataHandler.DK_I,
|
||||
squeeze: bool = False,
|
||||
proc_func: Callable = None,
|
||||
) -> pd.DataFrame:
|
||||
@@ -669,7 +710,7 @@ class DataHandlerLP(DataHandler):
|
||||
proc_func=proc_func,
|
||||
)
|
||||
|
||||
def get_cols(self, col_set=DataHandler.CS_ALL, data_key: DATA_KEY_TYPE = DK_I) -> list:
|
||||
def get_cols(self, col_set=DataHandler.CS_ALL, data_key: DATA_KEY_TYPE = DataHandlerABC.DK_I) -> list:
|
||||
"""
|
||||
get the column names
|
||||
|
||||
|
||||
@@ -279,8 +279,11 @@ class StaticDataLoader(DataLoader, Serializable):
|
||||
)
|
||||
self._data.sort_index(inplace=True)
|
||||
elif isinstance(self._config, (str, Path)):
|
||||
with Path(self._config).open("rb") as f:
|
||||
self._data = pickle.load(f)
|
||||
if str(self._config).strip().endswith(".parquet"):
|
||||
self._data = pd.read_parquet(self._config, engine="pyarrow")
|
||||
else:
|
||||
with Path(self._config).open("rb") as f:
|
||||
self._data = pickle.load(f)
|
||||
elif isinstance(self._config, pd.DataFrame):
|
||||
self._data = self._config
|
||||
|
||||
@@ -336,6 +339,10 @@ class NestedDataLoader(DataLoader):
|
||||
if df_full is None:
|
||||
df_full = df_current
|
||||
else:
|
||||
current_columns = df_current.columns.tolist()
|
||||
full_columns = df_full.columns.tolist()
|
||||
columns_to_drop = [col for col in current_columns if col in full_columns]
|
||||
df_full.drop(columns=columns_to_drop, inplace=True)
|
||||
df_full = pd.merge(df_full, df_current, left_index=True, right_index=True, how=self.join)
|
||||
return df_full.sort_index(axis=1)
|
||||
|
||||
|
||||
@@ -187,14 +187,9 @@ class Fillna(Processor):
|
||||
if self.fields_group is None:
|
||||
df.fillna(self.fill_value, inplace=True)
|
||||
else:
|
||||
cols = get_group_columns(df, self.fields_group)
|
||||
# this implementation is extremely slow
|
||||
# df.fillna({col: self.fill_value for col in cols}, inplace=True)
|
||||
|
||||
# So we use numpy to accelerate filling values
|
||||
nan_select = np.isnan(df.values)
|
||||
nan_select[:, ~df.columns.isin(cols)] = False
|
||||
df.values[nan_select] = self.fill_value
|
||||
df[self.fields_group] = df[self.fields_group].fillna(self.fill_value)
|
||||
return df
|
||||
|
||||
|
||||
@@ -357,7 +352,7 @@ class CSRankNorm(Processor):
|
||||
def __call__(self, df):
|
||||
# try not modify original dataframe
|
||||
cols = get_group_columns(df, self.fields_group)
|
||||
t = df[cols].groupby("datetime").rank(pct=True)
|
||||
t = df[cols].groupby("datetime", group_keys=False).rank(pct=True)
|
||||
t -= 0.5
|
||||
t *= 3.46 # NOTE: towards unit std
|
||||
df[cols] = t
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
from abc import abstractmethod
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
|
||||
from .handler import DataHandler
|
||||
from typing import Union, List, Callable
|
||||
from typing import Union, List
|
||||
from qlib.log import get_module_logger
|
||||
|
||||
from .utils import get_level_index, fetch_df_by_index, fetch_df_by_col
|
||||
|
||||
@@ -14,14 +16,13 @@ class BaseHandlerStorage:
|
||||
- If users want to use custom data storage, they should define subclass inherited BaseHandlerStorage, and implement the following method
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def fetch(
|
||||
self,
|
||||
selector: Union[pd.Timestamp, slice, str, list] = slice(None, None),
|
||||
selector: Union[pd.Timestamp, slice, str, pd.Index] = slice(None, None),
|
||||
level: Union[str, int] = "datetime",
|
||||
col_set: Union[str, List[str]] = DataHandler.CS_ALL,
|
||||
fetch_orig: bool = True,
|
||||
proc_func: Callable = None,
|
||||
**kwargs,
|
||||
) -> pd.DataFrame:
|
||||
"""fetch data from the data storage
|
||||
|
||||
@@ -41,8 +42,6 @@ class BaseHandlerStorage:
|
||||
select several sets of meaningful columns, the returned data has multiple level
|
||||
fetch_orig : bool
|
||||
Return the original data instead of copy if possible.
|
||||
proc_func: Callable
|
||||
please refer to the doc of DataHandler.fetch
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -51,13 +50,40 @@ class BaseHandlerStorage:
|
||||
"""
|
||||
raise NotImplementedError("fetch is method not implemented!")
|
||||
|
||||
@staticmethod
|
||||
def from_df(df: pd.DataFrame):
|
||||
raise NotImplementedError("from_df method is not implemented!")
|
||||
|
||||
def is_proc_func_supported(self):
|
||||
"""whether the arg `proc_func` in `fetch` method is supported."""
|
||||
raise NotImplementedError("is_proc_func_supported method is not implemented!")
|
||||
class NaiveDFStorage(BaseHandlerStorage):
|
||||
"""Naive data storage for datahandler
|
||||
- NaiveDFStorage is a naive data storage for datahandler
|
||||
- NaiveDFStorage will input a pandas.DataFrame as and provide interface support for fetching data
|
||||
"""
|
||||
|
||||
def __init__(self, df: pd.DataFrame):
|
||||
self.df = df
|
||||
|
||||
def fetch(
|
||||
self,
|
||||
selector: Union[pd.Timestamp, slice, str, pd.Index] = slice(None, None),
|
||||
level: Union[str, int] = "datetime",
|
||||
col_set: Union[str, List[str]] = DataHandler.CS_ALL,
|
||||
fetch_orig: bool = True,
|
||||
) -> pd.DataFrame:
|
||||
|
||||
# Following conflicts may occur
|
||||
# - Does [20200101", "20210101"] mean selecting this slice or these two days?
|
||||
# To solve this issue
|
||||
# - slice have higher priorities (except when level is none)
|
||||
if isinstance(selector, (tuple, list)) and level is not None:
|
||||
# when level is None, the argument will be passed in directly
|
||||
# we don't have to convert it into slice
|
||||
try:
|
||||
selector = slice(*selector)
|
||||
except ValueError:
|
||||
get_module_logger("DataHandlerLP").info(f"Fail to converting to query to slice. It will used directly")
|
||||
|
||||
data_df = self.df
|
||||
data_df = fetch_df_by_col(data_df, col_set)
|
||||
data_df = fetch_df_by_index(data_df, selector, level, fetch_orig=fetch_orig)
|
||||
return data_df
|
||||
|
||||
|
||||
class HashingStockStorage(BaseHandlerStorage):
|
||||
@@ -77,7 +103,7 @@ class HashingStockStorage(BaseHandlerStorage):
|
||||
def __init__(self, df):
|
||||
self.hash_df = dict()
|
||||
self.stock_level = get_level_index(df, "instrument")
|
||||
for k, v in df.groupby(level="instrument"):
|
||||
for k, v in df.groupby(level="instrument", group_keys=False):
|
||||
self.hash_df[k] = v
|
||||
self.columns = df.columns
|
||||
|
||||
@@ -142,7 +168,7 @@ class HashingStockStorage(BaseHandlerStorage):
|
||||
|
||||
def fetch(
|
||||
self,
|
||||
selector: Union[pd.Timestamp, slice, str] = slice(None, None),
|
||||
selector: Union[pd.Timestamp, slice, str, pd.Index] = slice(None, None),
|
||||
level: Union[str, int] = "datetime",
|
||||
col_set: Union[str, List[str]] = DataHandler.CS_ALL,
|
||||
fetch_orig: bool = True,
|
||||
@@ -164,7 +190,3 @@ class HashingStockStorage(BaseHandlerStorage):
|
||||
return fetch_stock_df_list[0]
|
||||
else:
|
||||
return pd.concat(fetch_stock_df_list, sort=False, copy=~fetch_orig)
|
||||
|
||||
def is_proc_func_supported(self):
|
||||
"""the arg `proc_func` in `fetch` method is not supported in HashingStockStorage"""
|
||||
return False
|
||||
|
||||
@@ -126,7 +126,7 @@ class AverageEnsemble(Ensemble):
|
||||
# NOTE: this may change the style underlying data!!!!
|
||||
# from pd.DataFrame to pd.Series
|
||||
results = pd.concat(values, axis=1)
|
||||
results = results.groupby("datetime").apply(lambda df: (df - df.mean()) / df.std())
|
||||
results = results.groupby("datetime", group_keys=False).apply(lambda df: (df - df.mean()) / df.std())
|
||||
results = results.mean(axis=1)
|
||||
results = results.sort_index()
|
||||
return results
|
||||
|
||||
@@ -240,7 +240,9 @@ class TrainerR(Trainer):
|
||||
self.train_func = train_func
|
||||
self._call_in_subproc = call_in_subproc
|
||||
|
||||
def train(self, tasks: list, train_func: Callable = None, experiment_name: str = None, **kwargs) -> List[Recorder]:
|
||||
def train(
|
||||
self, tasks: list, train_func: Optional[Callable] = None, experiment_name: Optional[str] = None, **kwargs
|
||||
) -> List[Recorder]:
|
||||
"""
|
||||
Given a list of `tasks` and return a list of trained Recorder. The order can be guaranteed.
|
||||
|
||||
|
||||
@@ -200,7 +200,7 @@ class Trainer:
|
||||
|
||||
if ckpt_path is not None:
|
||||
_logger.info("Resuming states from %s", str(ckpt_path))
|
||||
self.load_state_dict(torch.load(ckpt_path))
|
||||
self.load_state_dict(torch.load(ckpt_path, weights_only=False))
|
||||
else:
|
||||
self.initialize()
|
||||
|
||||
|
||||
@@ -187,7 +187,7 @@ class MockInstrumentStorage(MockStorageBase, InstrumentStorage):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__()
|
||||
instruments = {}
|
||||
for symbol, group in self.df.groupby(by="symbol"):
|
||||
for symbol, group in self.df.groupby(by="symbol", group_keys=False):
|
||||
start = group["datetime"].iloc[0]
|
||||
end = group["datetime"].iloc[-1]
|
||||
instruments[symbol] = [(start, end)]
|
||||
|
||||
@@ -5,8 +5,11 @@ This module covers some utility functions that operate on data or basic object
|
||||
"""
|
||||
from copy import deepcopy
|
||||
from typing import List, Union
|
||||
import pandas as pd
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from qlib.data.data import DatasetProvider
|
||||
|
||||
|
||||
def robust_zscore(x: pd.Series, zscore=False):
|
||||
@@ -103,3 +106,12 @@ def update_config(base_config: dict, ext_config: Union[dict, List[dict]]):
|
||||
# one of then are not dict. Then replace
|
||||
base_config[key] = ec[key]
|
||||
return base_config
|
||||
|
||||
|
||||
def guess_horizon(label: List):
|
||||
"""
|
||||
Try to guess the horizon by parsing label
|
||||
"""
|
||||
expr = DatasetProvider.parse_fields(label)[0]
|
||||
lft_etd, rght_etd = expr.get_extended_window_size()
|
||||
return rght_etd
|
||||
|
||||
@@ -6,6 +6,7 @@ from functools import partial
|
||||
from threading import Thread
|
||||
from typing import Callable, Text, Union
|
||||
|
||||
import joblib
|
||||
from joblib import Parallel, delayed
|
||||
from joblib._parallel_backends import MultiprocessingBackend
|
||||
import pandas as pd
|
||||
@@ -21,11 +22,16 @@ class ParallelExt(Parallel):
|
||||
maxtasksperchild = kwargs.pop("maxtasksperchild", None)
|
||||
super(ParallelExt, self).__init__(*args, **kwargs)
|
||||
if isinstance(self._backend, MultiprocessingBackend):
|
||||
self._backend_args["maxtasksperchild"] = maxtasksperchild
|
||||
# 2025-05-04 joblib released version 1.5.0, in which _backend_args was removed and replaced by _backend_kwargs.
|
||||
# Ref: https://github.com/joblib/joblib/pull/1525/files#diff-e4dff8042ce45b443faf49605b75a58df35b8c195978d4a57f4afa695b406bdc
|
||||
if joblib.__version__ < "1.5.0":
|
||||
self._backend_args["maxtasksperchild"] = maxtasksperchild # pylint: disable=E1101
|
||||
else:
|
||||
self._backend_kwargs["maxtasksperchild"] = maxtasksperchild # pylint: disable=E1101
|
||||
|
||||
|
||||
def datetime_groupby_apply(
|
||||
df, apply_func: Union[Callable, Text], axis=0, level="datetime", resample_rule="M", n_jobs=-1
|
||||
df, apply_func: Union[Callable, Text], axis=0, level="datetime", resample_rule="ME", n_jobs=-1
|
||||
):
|
||||
"""datetime_groupby_apply
|
||||
This function will apply the `apply_func` on the datetime level index.
|
||||
@@ -51,12 +57,12 @@ def datetime_groupby_apply(
|
||||
|
||||
def _naive_group_apply(df):
|
||||
if isinstance(apply_func, str):
|
||||
return getattr(df.groupby(axis=axis, level=level), apply_func)()
|
||||
return df.groupby(axis=axis, level=level).apply(apply_func)
|
||||
return getattr(df.groupby(axis=axis, level=level, group_keys=False), apply_func)()
|
||||
return df.groupby(level=level, group_keys=False).apply(apply_func)
|
||||
|
||||
if n_jobs != 1:
|
||||
dfs = ParallelExt(n_jobs=n_jobs)(
|
||||
delayed(_naive_group_apply)(sub_df) for idx, sub_df in df.resample(resample_rule, axis=axis, level=level)
|
||||
delayed(_naive_group_apply)(sub_df) for idx, sub_df in df.resample(resample_rule, level=level)
|
||||
)
|
||||
return pd.concat(dfs, axis=axis).sort_index()
|
||||
else:
|
||||
|
||||
@@ -194,9 +194,9 @@ def resam_ts_data(
|
||||
if isinstance(feature.index, pd.MultiIndex):
|
||||
if callable(method):
|
||||
method_func = method
|
||||
return feature.groupby(level="instrument").apply(method_func, **method_kwargs)
|
||||
return feature.groupby(level="instrument", group_keys=False).apply(method_func, **method_kwargs)
|
||||
elif isinstance(method, str):
|
||||
return getattr(feature.groupby(level="instrument"), method)(**method_kwargs)
|
||||
return getattr(feature.groupby(level="instrument", group_keys=False), method)(**method_kwargs)
|
||||
else:
|
||||
if callable(method):
|
||||
method_func = method
|
||||
|
||||
@@ -1,5 +1,18 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
"""
|
||||
Motivation of this design (instead of using mlflow directly):
|
||||
- Better design than mlflow native design
|
||||
- we have record object with a lot of methods(more intuitive), instead of use run_id everytime in mlflow
|
||||
- So the recorder's interfaces like log, start, will be more intuitive.
|
||||
- Provide richer and tailerd features than mlflow native
|
||||
- Logging code diff at the start of run.
|
||||
- log_object and load_object to for Python object directly instead log_artifact and download_artifact
|
||||
- (weak) Allow diverse backend support
|
||||
|
||||
To be honest, design always add burdens. For example,
|
||||
- You need to create an experiment before you can get a recorder. (In MLflow, experiments are more like tags, and you often just use a run_id in many interfaces without first defining an experiment.)
|
||||
"""
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import Text, Optional, Any, Dict
|
||||
|
||||
@@ -652,7 +652,7 @@ class MultiPassPortAnaRecord(PortAnaRecord):
|
||||
combined_df = pd.concat(risk_analysis_df_map[_analysis_freq])
|
||||
|
||||
# Calculate return and information ratio's mean, std and mean/std
|
||||
multi_pass_port_analysis_df = combined_df.groupby(level=[0, 1]).apply(
|
||||
multi_pass_port_analysis_df = combined_df.groupby(level=[0, 1], group_keys=False).apply(
|
||||
lambda x: pd.Series(
|
||||
{"mean": x["risk"].mean(), "std": x["risk"].std(), "mean_std": x["risk"].mean() / x["risk"].std()}
|
||||
)
|
||||
|
||||
@@ -71,6 +71,6 @@ qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
## Use Crowd Sourced Data
|
||||
The is also a [crowd sourced version of qlib data](data_collector/crowd_source/README.md): https://github.com/chenditc/investment_data/releases
|
||||
```bash
|
||||
wget https://github.com/chenditc/investment_data/releases/download/20220720/qlib_bin.tar.gz
|
||||
wget https://github.com/chenditc/investment_data/releases/latest/download/qlib_bin.tar.gz
|
||||
tar -zxvf qlib_bin.tar.gz -C ~/.qlib/qlib_data/cn_data --strip-components=2
|
||||
```
|
||||
|
||||
203
scripts/check_data_health.py
Normal file
203
scripts/check_data_health.py
Normal file
@@ -0,0 +1,203 @@
|
||||
from loguru import logger
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import fire
|
||||
import pandas as pd
|
||||
import qlib
|
||||
from tqdm import tqdm
|
||||
|
||||
from qlib.data import D
|
||||
|
||||
|
||||
class DataHealthChecker:
|
||||
"""Checks a dataset for data completeness and correctness. The data will be converted to a pd.DataFrame and checked for the following problems:
|
||||
- any of the columns ["open", "high", "low", "close", "volume"] are missing
|
||||
- any data is missing
|
||||
- any step change in the OHLCV columns is above a threshold (default: 0.5 for price, 3 for volume)
|
||||
- any factor is missing
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
csv_path=None,
|
||||
qlib_dir=None,
|
||||
freq="day",
|
||||
large_step_threshold_price=0.5,
|
||||
large_step_threshold_volume=3,
|
||||
missing_data_num=0,
|
||||
):
|
||||
assert csv_path or qlib_dir, "One of csv_path or qlib_dir should be provided."
|
||||
assert not (csv_path and qlib_dir), "Only one of csv_path or qlib_dir should be provided."
|
||||
|
||||
self.data = {}
|
||||
self.problems = {}
|
||||
self.freq = freq
|
||||
self.large_step_threshold_price = large_step_threshold_price
|
||||
self.large_step_threshold_volume = large_step_threshold_volume
|
||||
self.missing_data_num = missing_data_num
|
||||
|
||||
if csv_path:
|
||||
assert os.path.isdir(csv_path), f"{csv_path} should be a directory."
|
||||
files = [f for f in os.listdir(csv_path) if f.endswith(".csv")]
|
||||
for filename in tqdm(files, desc="Loading data"):
|
||||
df = pd.read_csv(os.path.join(csv_path, filename))
|
||||
self.data[filename] = df
|
||||
|
||||
elif qlib_dir:
|
||||
qlib.init(provider_uri=qlib_dir)
|
||||
self.load_qlib_data()
|
||||
|
||||
def load_qlib_data(self):
|
||||
instruments = D.instruments(market="all")
|
||||
instrument_list = D.list_instruments(instruments=instruments, as_list=True, freq=self.freq)
|
||||
required_fields = ["$open", "$close", "$low", "$high", "$volume", "$factor"]
|
||||
for instrument in instrument_list:
|
||||
df = D.features([instrument], required_fields, freq=self.freq)
|
||||
df.rename(
|
||||
columns={
|
||||
"$open": "open",
|
||||
"$close": "close",
|
||||
"$low": "low",
|
||||
"$high": "high",
|
||||
"$volume": "volume",
|
||||
"$factor": "factor",
|
||||
},
|
||||
inplace=True,
|
||||
)
|
||||
self.data[instrument] = df
|
||||
print(df)
|
||||
|
||||
def check_missing_data(self) -> Optional[pd.DataFrame]:
|
||||
"""Check if any data is missing in the DataFrame."""
|
||||
result_dict = {
|
||||
"instruments": [],
|
||||
"open": [],
|
||||
"high": [],
|
||||
"low": [],
|
||||
"close": [],
|
||||
"volume": [],
|
||||
}
|
||||
for filename, df in self.data.items():
|
||||
missing_data_columns = df.isnull().sum()[df.isnull().sum() > self.missing_data_num].index.tolist()
|
||||
if len(missing_data_columns) > 0:
|
||||
result_dict["instruments"].append(filename)
|
||||
result_dict["open"].append(df.isnull().sum()["open"])
|
||||
result_dict["high"].append(df.isnull().sum()["high"])
|
||||
result_dict["low"].append(df.isnull().sum()["low"])
|
||||
result_dict["close"].append(df.isnull().sum()["close"])
|
||||
result_dict["volume"].append(df.isnull().sum()["volume"])
|
||||
|
||||
result_df = pd.DataFrame(result_dict).set_index("instruments")
|
||||
if not result_df.empty:
|
||||
return result_df
|
||||
else:
|
||||
logger.info(f"✅ There are no missing data.")
|
||||
return None
|
||||
|
||||
def check_large_step_changes(self) -> Optional[pd.DataFrame]:
|
||||
"""Check if there are any large step changes above the threshold in the OHLCV columns."""
|
||||
result_dict = {
|
||||
"instruments": [],
|
||||
"col_name": [],
|
||||
"date": [],
|
||||
"pct_change": [],
|
||||
}
|
||||
for filename, df in self.data.items():
|
||||
affected_columns = []
|
||||
for col in ["open", "high", "low", "close", "volume"]:
|
||||
if col in df.columns:
|
||||
pct_change = df[col].pct_change(fill_method=None).abs()
|
||||
threshold = self.large_step_threshold_volume if col == "volume" else self.large_step_threshold_price
|
||||
if pct_change.max() > threshold:
|
||||
large_steps = pct_change[pct_change > threshold]
|
||||
result_dict["instruments"].append(filename)
|
||||
result_dict["col_name"].append(col)
|
||||
result_dict["date"].append(large_steps.index.to_list()[0][1].strftime("%Y-%m-%d"))
|
||||
result_dict["pct_change"].append(pct_change.max())
|
||||
affected_columns.append(col)
|
||||
|
||||
result_df = pd.DataFrame(result_dict).set_index("instruments")
|
||||
if not result_df.empty:
|
||||
return result_df
|
||||
else:
|
||||
logger.info(f"✅ There are no large step changes in the OHLCV column above the threshold.")
|
||||
return None
|
||||
|
||||
def check_required_columns(self) -> Optional[pd.DataFrame]:
|
||||
"""Check if any of the required columns (OLHCV) are missing in the DataFrame."""
|
||||
required_columns = ["open", "high", "low", "close", "volume"]
|
||||
result_dict = {
|
||||
"instruments": [],
|
||||
"missing_col": [],
|
||||
}
|
||||
for filename, df in self.data.items():
|
||||
if not all(column in df.columns for column in required_columns):
|
||||
missing_required_columns = [column for column in required_columns if column not in df.columns]
|
||||
result_dict["instruments"].append(filename)
|
||||
result_dict["missing_col"] += missing_required_columns
|
||||
|
||||
result_df = pd.DataFrame(result_dict).set_index("instruments")
|
||||
if not result_df.empty:
|
||||
return result_df
|
||||
else:
|
||||
logger.info(f"✅ The columns (OLHCV) are complete and not missing.")
|
||||
return None
|
||||
|
||||
def check_missing_factor(self) -> Optional[pd.DataFrame]:
|
||||
"""Check if the 'factor' column is missing in the DataFrame."""
|
||||
result_dict = {
|
||||
"instruments": [],
|
||||
"missing_factor_col": [],
|
||||
"missing_factor_data": [],
|
||||
}
|
||||
for filename, df in self.data.items():
|
||||
if "000300" in filename or "000903" in filename or "000905" in filename:
|
||||
continue
|
||||
if "factor" not in df.columns:
|
||||
result_dict["instruments"].append(filename)
|
||||
result_dict["missing_factor_col"].append(True)
|
||||
if df["factor"].isnull().all():
|
||||
if filename in result_dict["instruments"]:
|
||||
result_dict["missing_factor_data"].append(True)
|
||||
else:
|
||||
result_dict["instruments"].append(filename)
|
||||
result_dict["missing_factor_col"].append(False)
|
||||
result_dict["missing_factor_data"].append(True)
|
||||
|
||||
result_df = pd.DataFrame(result_dict).set_index("instruments")
|
||||
if not result_df.empty:
|
||||
return result_df
|
||||
else:
|
||||
logger.info(f"✅ The `factor` column already exists and is not empty.")
|
||||
return None
|
||||
|
||||
def check_data(self):
|
||||
check_missing_data_result = self.check_missing_data()
|
||||
check_large_step_changes_result = self.check_large_step_changes()
|
||||
check_required_columns_result = self.check_required_columns()
|
||||
check_missing_factor_result = self.check_missing_factor()
|
||||
if (
|
||||
check_large_step_changes_result is not None
|
||||
or check_large_step_changes_result is not None
|
||||
or check_required_columns_result is not None
|
||||
or check_missing_factor_result is not None
|
||||
):
|
||||
print(f"\nSummary of data health check ({len(self.data)} files checked):")
|
||||
print("-------------------------------------------------")
|
||||
if isinstance(check_missing_data_result, pd.DataFrame):
|
||||
logger.warning(f"There is missing data.")
|
||||
print(check_missing_data_result)
|
||||
if isinstance(check_large_step_changes_result, pd.DataFrame):
|
||||
logger.warning(f"The OHLCV column has large step changes.")
|
||||
print(check_large_step_changes_result)
|
||||
if isinstance(check_required_columns_result, pd.DataFrame):
|
||||
logger.warning(f"Columns (OLHCV) are missing.")
|
||||
print(check_required_columns_result)
|
||||
if isinstance(check_missing_factor_result, pd.DataFrame):
|
||||
logger.warning(f"The factor column does not exist or is empty")
|
||||
print(check_missing_factor_result)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(DataHealthChecker)
|
||||
@@ -64,7 +64,7 @@
|
||||
This will convert the normalized csv in `feature` directory as numpy array and store the normalized data one file per column and one symbol per directory.
|
||||
|
||||
- parameters:
|
||||
- `csv_path`: stock data path or directory, **normalize result(normalize_dir)**
|
||||
- `data_path`: stock data path or directory, **normalize result(normalize_dir)**
|
||||
- `qlib_dir`: qlib(dump) data director
|
||||
- `freq`: transaction frequency, by default `day`
|
||||
> `freq_map = {1d:day, 5mih: 5min}`
|
||||
@@ -74,8 +74,9 @@
|
||||
> dump_fields = `include_fields if include_fields else set(symbol_df.columns) - set(exclude_fields) exclude_fields else symbol_df.columns`
|
||||
- `symbol_field_name`: column *name* identifying symbol in csv files, by default `symbol`
|
||||
- `date_field_name`: column *name* identifying time in csv files, by default `date`
|
||||
- `file_suffix`: stock data file format, by default ".csv"
|
||||
- examples:
|
||||
```bash
|
||||
# dump 5min cn
|
||||
python dump_bin.py dump_all --csv_path ~/.qlib/stock_data/source/hs300_5min_nor --qlib_dir ~/.qlib/qlib_data/hs300_5min_bin --freq 5min --exclude_fields date,symbol
|
||||
python dump_bin.py dump_all --data_path ~/.qlib/stock_data/source/hs300_5min_nor --qlib_dir ~/.qlib/qlib_data/hs300_5min_bin --freq 5min --exclude_fields date,symbol
|
||||
```
|
||||
@@ -23,7 +23,9 @@ from data_collector.utils import get_calendar_list, get_trading_date_by_shift, d
|
||||
from data_collector.utils import get_instruments
|
||||
|
||||
|
||||
NEW_COMPANIES_URL = "https://csi-web-dev.oss-cn-shanghai-finance-1-pub.aliyuncs.com/static/html/csindex/public/uploads/file/autofile/cons/{index_code}cons.xls"
|
||||
NEW_COMPANIES_URL = (
|
||||
"https://oss-ch.csindex.com.cn/static/html/csindex/public/uploads/file/autofile/cons/{index_code}cons.xls"
|
||||
)
|
||||
|
||||
|
||||
INDEX_CHANGES_URL = "https://www.csindex.com.cn/csindex-home/search/search-content?lang=cn&searchInput=%E5%85%B3%E4%BA%8E%E8%B0%83%E6%95%B4%E6%B2%AA%E6%B7%B1300%E5%92%8C%E4%B8%AD%E8%AF%81%E9%A6%99%E6%B8%AF100%E7%AD%89%E6%8C%87%E6%95%B0%E6%A0%B7%E6%9C%AC&pageNum={page_num}&pageSize={page_size}&sortField=date&dateRange=all&contentType=announcement"
|
||||
|
||||
@@ -16,9 +16,9 @@ The packaged docker runtime is hosted on dockerhub: https://hub.docker.com/repos
|
||||
|
||||
## How to use it in qlib
|
||||
### Option 1: Download release bin data
|
||||
User can download data in qlib bin format and use it directly: https://github.com/chenditc/investment_data/releases/tag/20220720
|
||||
User can download data in qlib bin format and use it directly: https://github.com/chenditc/investment_data/releases/latest
|
||||
```bash
|
||||
wget https://github.com/chenditc/investment_data/releases/download/20220720/qlib_bin.tar.gz
|
||||
wget https://github.com/chenditc/investment_data/releases/latest/download/qlib_bin.tar.gz
|
||||
tar -zxvf qlib_bin.tar.gz -C ~/.qlib/qlib_data/cn_data --strip-components=2
|
||||
```
|
||||
|
||||
|
||||
@@ -28,7 +28,7 @@ python collector.py normalize_data --source_dir ~/.qlib/crypto_data/source/1d --
|
||||
|
||||
# dump data
|
||||
cd qlib/scripts
|
||||
python dump_bin.py dump_all --csv_path ~/.qlib/crypto_data/source/1d_nor --qlib_dir ~/.qlib/qlib_data/crypto_data --freq day --date_field_name date --include_fields prices,total_volumes,market_caps
|
||||
python dump_bin.py dump_all --data_path ~/.qlib/crypto_data/source/1d_nor --qlib_dir ~/.qlib/qlib_data/crypto_data --freq day --date_field_name date --include_fields prices,total_volumes,market_caps
|
||||
|
||||
```
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@ python collector.py normalize_data --source_dir ~/.qlib/fund_data/source/cn_data
|
||||
|
||||
# dump data
|
||||
cd qlib/scripts
|
||||
python dump_bin.py dump_all --csv_path ~/.qlib/fund_data/source/cn_1d_nor --qlib_dir ~/.qlib/qlib_data/cn_fund_data --freq day --date_field_name FSRQ --include_fields DWJZ,LJJZ
|
||||
python dump_bin.py dump_all --data_path ~/.qlib/fund_data/source/cn_1d_nor --qlib_dir ~/.qlib/qlib_data/cn_fund_data --freq day --date_field_name FSRQ --include_fields DWJZ,LJJZ
|
||||
|
||||
```
|
||||
|
||||
|
||||
@@ -36,5 +36,5 @@ python collector.py normalize_data --interval quarterly --source_dir ~/.qlib/sto
|
||||
|
||||
```bash
|
||||
cd qlib/scripts
|
||||
python dump_pit.py dump --csv_path ~/.qlib/stock_data/source/pit_normalized --qlib_dir ~/.qlib/qlib_data/cn_data --interval quarterly
|
||||
python dump_pit.py dump --data_path ~/.qlib/stock_data/source/pit_normalized --qlib_dir ~/.qlib/qlib_data/cn_data --interval quarterly
|
||||
```
|
||||
|
||||
@@ -202,18 +202,59 @@ def get_hs_stock_symbols() -> list:
|
||||
-------
|
||||
{600000.ss, 600001.ss, 600002.ss, 600003.ss, ...}
|
||||
"""
|
||||
url = "http://99.push2.eastmoney.com/api/qt/clist/get?pn=1&pz=10000&po=1&np=1&fs=m:0+t:6,m:0+t:80,m:1+t:2,m:1+t:23,m:0+t:81+s:2048&fields=f12"
|
||||
try:
|
||||
resp = requests.get(url, timeout=None)
|
||||
resp.raise_for_status()
|
||||
except requests.exceptions.HTTPError as e:
|
||||
raise requests.exceptions.HTTPError(f"Request to {url} failed with status code {resp.status_code}") from e
|
||||
# url = "http://99.push2.eastmoney.com/api/qt/clist/get?pn=1&pz=10000&po=1&np=1&fs=m:0+t:6,m:0+t:80,m:1+t:2,m:1+t:23,m:0+t:81+s:2048&fields=f12"
|
||||
|
||||
try:
|
||||
_symbols = [_v["f12"] for _v in resp.json()["data"]["diff"]]
|
||||
except Exception as e:
|
||||
logger.warning("An error occurred while extracting data from the response.")
|
||||
raise
|
||||
base_url = "http://99.push2.eastmoney.com/api/qt/clist/get"
|
||||
params = {
|
||||
"pn": 1, # page number
|
||||
"pz": 100, # page size, default to 100
|
||||
"po": 1,
|
||||
"np": 1,
|
||||
"fs": "m:0+t:6,m:0+t:80,m:1+t:2,m:1+t:23,m:0+t:81+s:2048",
|
||||
"fields": "f12",
|
||||
}
|
||||
|
||||
_symbols = []
|
||||
page = 1
|
||||
|
||||
while True:
|
||||
params["pn"] = page
|
||||
try:
|
||||
resp = requests.get(base_url, params=params, timeout=None)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
|
||||
# Check if response contains valid data
|
||||
if not data or "data" not in data or not data["data"] or "diff" not in data["data"]:
|
||||
logger.warning(f"Invalid response structure on page {page}")
|
||||
break
|
||||
|
||||
# fetch the current page data
|
||||
current_symbols = [_v["f12"] for _v in data["data"]["diff"]]
|
||||
|
||||
if not current_symbols: # It's the last page if there is no data in current page
|
||||
logger.info(f"Last page reached: {page - 1}")
|
||||
break
|
||||
|
||||
_symbols.extend(current_symbols)
|
||||
|
||||
# show progress
|
||||
logger.info(
|
||||
f"Page {page}: fetch {len(current_symbols)} stocks:[{current_symbols[0]} ... {current_symbols[-1]}]"
|
||||
)
|
||||
|
||||
page += 1
|
||||
|
||||
# sleep time to avoid overloading the server
|
||||
time.sleep(0.5)
|
||||
|
||||
except requests.exceptions.HTTPError as e:
|
||||
raise requests.exceptions.HTTPError(
|
||||
f"Request to {base_url} failed with status code {resp.status_code}"
|
||||
) from e
|
||||
except Exception as e:
|
||||
logger.warning("An error occurred while extracting data from the response.")
|
||||
raise
|
||||
|
||||
if len(_symbols) < 3900:
|
||||
raise ValueError("The complete list of stocks is not available.")
|
||||
@@ -767,7 +808,7 @@ def calc_paused_num(df: pd.DataFrame, _date_field_name, _symbol_field_name):
|
||||
all_nan_nums = 0
|
||||
# Record the number of consecutive occurrences of trading days that are not nan throughout the day
|
||||
not_nan_nums = 0
|
||||
for _date, _df in df.groupby("_tmp_date"):
|
||||
for _date, _df in df.groupby("_tmp_date", group_keys=False):
|
||||
_df["paused"] = 0
|
||||
if not _df.loc[_df["volume"] < 0].empty:
|
||||
logger.warning(f"volume < 0, will fill np.nan: {_date} {_symbol}")
|
||||
|
||||
@@ -50,12 +50,6 @@ pip install -r requirements.txt
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data_1min --region cn --interval 1min
|
||||
# us 1d
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/us_data --region us --interval 1d
|
||||
# us 1min
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/us_data_1min --region us --interval 1min
|
||||
# in 1d
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/in_data --region in --interval 1d
|
||||
# in 1min
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/in_data_1min --region in --interval 1min
|
||||
```
|
||||
|
||||
### Collector *YahooFinance* data to qlib
|
||||
@@ -145,7 +139,7 @@ pip install -r requirements.txt
|
||||
This will convert the normalized csv in `feature` directory as numpy array and store the normalized data one file per column and one symbol per directory.
|
||||
|
||||
- parameters:
|
||||
- `csv_path`: stock data path or directory, **normalize result(normalize_dir)**
|
||||
- `data_path`: stock data path or directory, **normalize result(normalize_dir)**
|
||||
- `qlib_dir`: qlib(dump) data director
|
||||
- `freq`: transaction frequency, by default `day`
|
||||
> `freq_map = {1d:day, 1mih: 1min}`
|
||||
@@ -155,12 +149,13 @@ pip install -r requirements.txt
|
||||
> dump_fields = `include_fields if include_fields else set(symbol_df.columns) - set(exclude_fields) exclude_fields else symbol_df.columns`
|
||||
- `symbol_field_name`: column *name* identifying symbol in csv files, by default `symbol`
|
||||
- `date_field_name`: column *name* identifying time in csv files, by default `date`
|
||||
- `file_suffix`: stock data file format, by default ".csv"
|
||||
- examples:
|
||||
```bash
|
||||
# dump 1d cn
|
||||
python dump_bin.py dump_all --csv_path ~/.qlib/stock_data/source/cn_1d_nor --qlib_dir ~/.qlib/qlib_data/cn_data --freq day --exclude_fields date,symbol
|
||||
python dump_bin.py dump_all --data_path ~/.qlib/stock_data/source/cn_1d_nor --qlib_dir ~/.qlib/qlib_data/cn_data --freq day --exclude_fields date,symbol --file_suffix .csv
|
||||
# dump 1min cn
|
||||
python dump_bin.py dump_all --csv_path ~/.qlib/stock_data/source/cn_1min_nor --qlib_dir ~/.qlib/qlib_data/cn_data_1min --freq 1min --exclude_fields date,symbol
|
||||
python dump_bin.py dump_all --data_path ~/.qlib/stock_data/source/cn_1min_nor --qlib_dir ~/.qlib/qlib_data/cn_data_1min --freq 1min --exclude_fields date,symbol --file_suffix .csv
|
||||
```
|
||||
|
||||
### Automatic update of daily frequency data(from yahoo finance)
|
||||
|
||||
@@ -856,7 +856,7 @@ class Run(BaseRun):
|
||||
|
||||
3. normalize new source data(from step 2): python scripts/data_collector/yahoo/collector.py normalize_data_1d_extend --old_qlib_dir <dir1> --source_dir <dir2> --normalize_dir <dir3> --region CN --interval 1d
|
||||
|
||||
4. dump data: python scripts/dump_bin.py dump_update --csv_path <dir3> --qlib_dir <dir1> --freq day --date_field_name date --symbol_field_name symbol --exclude_fields symbol,date
|
||||
4. dump data: python scripts/dump_bin.py dump_update --data_path <dir3> --qlib_dir <dir1> --freq day --date_field_name date --symbol_field_name symbol --exclude_fields symbol,date
|
||||
|
||||
5. update instrument(eg. csi300): python python scripts/data_collector/cn_index/collector.py --index_name CSI300 --qlib_dir <dir1> --method parse_instruments
|
||||
|
||||
@@ -997,7 +997,7 @@ class Run(BaseRun):
|
||||
|
||||
# dump bin
|
||||
_dump = DumpDataUpdate(
|
||||
csv_path=self.normalize_dir,
|
||||
data_path=self.normalize_dir,
|
||||
qlib_dir=qlib_data_1d_dir,
|
||||
exclude_fields="symbol,date",
|
||||
max_workers=self.max_workers,
|
||||
|
||||
@@ -17,6 +17,39 @@ from loguru import logger
|
||||
from qlib.utils import fname_to_code, code_to_fname
|
||||
|
||||
|
||||
def read_as_df(file_path: Union[str, Path], **kwargs) -> pd.DataFrame:
|
||||
"""
|
||||
Read a csv or parquet file into a pandas DataFrame.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
file_path : Union[str, Path]
|
||||
Path to the data file.
|
||||
**kwargs :
|
||||
Additional keyword arguments passed to the underlying pandas
|
||||
reader.
|
||||
|
||||
Returns
|
||||
-------
|
||||
pd.DataFrame
|
||||
"""
|
||||
file_path = Path(file_path).expanduser()
|
||||
suffix = file_path.suffix.lower()
|
||||
|
||||
keep_keys = {".csv": ("low_memory",)}
|
||||
kept_kwargs = {}
|
||||
for k in keep_keys.get(suffix, []):
|
||||
if k in kwargs:
|
||||
kept_kwargs[k] = kwargs[k]
|
||||
|
||||
if suffix == ".csv":
|
||||
return pd.read_csv(file_path, **kept_kwargs)
|
||||
elif suffix == ".parquet":
|
||||
return pd.read_parquet(file_path, **kept_kwargs)
|
||||
else:
|
||||
raise ValueError(f"Unsupported file format: {suffix}")
|
||||
|
||||
|
||||
class DumpDataBase:
|
||||
INSTRUMENTS_START_FIELD = "start_datetime"
|
||||
INSTRUMENTS_END_FIELD = "end_datetime"
|
||||
@@ -34,7 +67,7 @@ class DumpDataBase:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
csv_path: str,
|
||||
data_path: str,
|
||||
qlib_dir: str,
|
||||
backup_dir: str = None,
|
||||
freq: str = "day",
|
||||
@@ -50,7 +83,7 @@ class DumpDataBase:
|
||||
|
||||
Parameters
|
||||
----------
|
||||
csv_path: str
|
||||
data_path: str
|
||||
stock data path or directory
|
||||
qlib_dir: str
|
||||
qlib(dump) data director
|
||||
@@ -73,7 +106,7 @@ class DumpDataBase:
|
||||
limit_nums: int
|
||||
Use when debugging, default None
|
||||
"""
|
||||
csv_path = Path(csv_path).expanduser()
|
||||
data_path = Path(data_path).expanduser()
|
||||
if isinstance(exclude_fields, str):
|
||||
exclude_fields = exclude_fields.split(",")
|
||||
if isinstance(include_fields, str):
|
||||
@@ -82,9 +115,9 @@ class DumpDataBase:
|
||||
self._include_fields = tuple(filter(lambda x: len(x) > 0, map(str.strip, include_fields)))
|
||||
self.file_suffix = file_suffix
|
||||
self.symbol_field_name = symbol_field_name
|
||||
self.csv_files = sorted(csv_path.glob(f"*{self.file_suffix}") if csv_path.is_dir() else [csv_path])
|
||||
self.df_files = sorted(data_path.glob(f"*{self.file_suffix}") if data_path.is_dir() else [data_path])
|
||||
if limit_nums is not None:
|
||||
self.csv_files = self.csv_files[: int(limit_nums)]
|
||||
self.df_files = self.df_files[: int(limit_nums)]
|
||||
self.qlib_dir = Path(qlib_dir).expanduser()
|
||||
self.backup_dir = backup_dir if backup_dir is None else Path(backup_dir).expanduser()
|
||||
if backup_dir is not None:
|
||||
@@ -134,13 +167,14 @@ class DumpDataBase:
|
||||
return _calendars.tolist()
|
||||
|
||||
def _get_source_data(self, file_path: Path) -> pd.DataFrame:
|
||||
df = pd.read_csv(str(file_path.resolve()), low_memory=False)
|
||||
df[self.date_field_name] = df[self.date_field_name].astype(str).astype("datetime64[ns]")
|
||||
df = read_as_df(file_path, low_memory=False)
|
||||
if self.date_field_name in df.columns:
|
||||
df[self.date_field_name] = pd.to_datetime(df[self.date_field_name])
|
||||
# df.drop_duplicates([self.date_field_name], inplace=True)
|
||||
return df
|
||||
|
||||
def get_symbol_from_file(self, file_path: Path) -> str:
|
||||
return fname_to_code(file_path.name[: -len(self.file_suffix)].strip().lower())
|
||||
return fname_to_code(file_path.stem.strip().lower())
|
||||
|
||||
def get_dump_fields(self, df_columns: Iterable[str]) -> Iterable[str]:
|
||||
return (
|
||||
@@ -274,10 +308,10 @@ class DumpDataAll(DumpDataBase):
|
||||
all_datetime = set()
|
||||
date_range_list = []
|
||||
_fun = partial(self._get_date, as_set=True, is_begin_end=True)
|
||||
with tqdm(total=len(self.csv_files)) as p_bar:
|
||||
with tqdm(total=len(self.df_files)) as p_bar:
|
||||
with ProcessPoolExecutor(max_workers=self.works) as executor:
|
||||
for file_path, ((_begin_time, _end_time), _set_calendars) in zip(
|
||||
self.csv_files, executor.map(_fun, self.csv_files)
|
||||
self.df_files, executor.map(_fun, self.df_files)
|
||||
):
|
||||
all_datetime = all_datetime | _set_calendars
|
||||
if isinstance(_begin_time, pd.Timestamp) and isinstance(_end_time, pd.Timestamp):
|
||||
@@ -305,9 +339,9 @@ class DumpDataAll(DumpDataBase):
|
||||
def _dump_features(self):
|
||||
logger.info("start dump features......")
|
||||
_dump_func = partial(self._dump_bin, calendar_list=self._calendars_list)
|
||||
with tqdm(total=len(self.csv_files)) as p_bar:
|
||||
with tqdm(total=len(self.df_files)) as p_bar:
|
||||
with ProcessPoolExecutor(max_workers=self.works) as executor:
|
||||
for _ in executor.map(_dump_func, self.csv_files):
|
||||
for _ in executor.map(_dump_func, self.df_files):
|
||||
p_bar.update()
|
||||
|
||||
logger.info("end of features dump.\n")
|
||||
@@ -325,16 +359,15 @@ class DumpDataFix(DumpDataAll):
|
||||
_fun = partial(self._get_date, is_begin_end=True)
|
||||
new_stock_files = sorted(
|
||||
filter(
|
||||
lambda x: fname_to_code(x.name[: -len(self.file_suffix)].strip().lower()).upper()
|
||||
not in self._old_instruments,
|
||||
self.csv_files,
|
||||
lambda x: self.get_symbol_from_file(x).upper() not in self._old_instruments,
|
||||
self.df_files,
|
||||
)
|
||||
)
|
||||
with tqdm(total=len(new_stock_files)) as p_bar:
|
||||
with ProcessPoolExecutor(max_workers=self.works) as execute:
|
||||
for file_path, (_begin_time, _end_time) in zip(new_stock_files, execute.map(_fun, new_stock_files)):
|
||||
if isinstance(_begin_time, pd.Timestamp) and isinstance(_end_time, pd.Timestamp):
|
||||
symbol = fname_to_code(self.get_symbol_from_file(file_path).lower()).upper()
|
||||
symbol = self.get_symbol_from_file(file_path).upper()
|
||||
_dt_map = self._old_instruments.setdefault(symbol, dict())
|
||||
_dt_map[self.INSTRUMENTS_START_FIELD] = self._format_datetime(_begin_time)
|
||||
_dt_map[self.INSTRUMENTS_END_FIELD] = self._format_datetime(_end_time)
|
||||
@@ -359,7 +392,7 @@ class DumpDataFix(DumpDataAll):
|
||||
class DumpDataUpdate(DumpDataBase):
|
||||
def __init__(
|
||||
self,
|
||||
csv_path: str,
|
||||
data_path: str,
|
||||
qlib_dir: str,
|
||||
backup_dir: str = None,
|
||||
freq: str = "day",
|
||||
@@ -375,7 +408,7 @@ class DumpDataUpdate(DumpDataBase):
|
||||
|
||||
Parameters
|
||||
----------
|
||||
csv_path: str
|
||||
data_path: str
|
||||
stock data path or directory
|
||||
qlib_dir: str
|
||||
qlib(dump) data director
|
||||
@@ -399,7 +432,7 @@ class DumpDataUpdate(DumpDataBase):
|
||||
Use when debugging, default None
|
||||
"""
|
||||
super().__init__(
|
||||
csv_path,
|
||||
data_path,
|
||||
qlib_dir,
|
||||
backup_dir,
|
||||
freq,
|
||||
@@ -431,15 +464,19 @@ class DumpDataUpdate(DumpDataBase):
|
||||
logger.info("start load all source data....")
|
||||
all_df = []
|
||||
|
||||
def _read_csv(file_path: Path):
|
||||
_df = pd.read_csv(file_path, parse_dates=[self.date_field_name])
|
||||
def _read_df(file_path: Path):
|
||||
_df = read_as_df(file_path)
|
||||
if self.date_field_name in _df.columns and not np.issubdtype(
|
||||
_df[self.date_field_name].dtype, np.datetime64
|
||||
):
|
||||
_df[self.date_field_name] = pd.to_datetime(_df[self.date_field_name])
|
||||
if self.symbol_field_name not in _df.columns:
|
||||
_df[self.symbol_field_name] = self.get_symbol_from_file(file_path)
|
||||
return _df
|
||||
|
||||
with tqdm(total=len(self.csv_files)) as p_bar:
|
||||
with tqdm(total=len(self.df_files)) as p_bar:
|
||||
with ThreadPoolExecutor(max_workers=self.works) as executor:
|
||||
for df in executor.map(_read_csv, self.csv_files):
|
||||
for df in executor.map(_read_df, self.df_files):
|
||||
if not df.empty:
|
||||
all_df.append(df)
|
||||
p_bar.update()
|
||||
@@ -458,7 +495,7 @@ class DumpDataUpdate(DumpDataBase):
|
||||
error_code = {}
|
||||
with ProcessPoolExecutor(max_workers=self.works) as executor:
|
||||
futures = {}
|
||||
for _code, _df in self._all_data.groupby(self.symbol_field_name):
|
||||
for _code, _df in self._all_data.groupby(self.symbol_field_name, group_keys=False):
|
||||
_code = fname_to_code(str(_code).lower()).upper()
|
||||
_start, _end = self._get_date(_df, is_begin_end=True)
|
||||
if not (isinstance(_start, pd.Timestamp) and isinstance(_end, pd.Timestamp)):
|
||||
|
||||
@@ -10,6 +10,7 @@ sys.path.append(str(Path(__file__).resolve().parent))
|
||||
from qlib.data.dataset.loader import NestedDataLoader, QlibDataLoader
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
from qlib.contrib.data.loader import Alpha158DL, Alpha360DL
|
||||
from qlib.data.dataset.processor import Fillna
|
||||
from qlib.data import D
|
||||
|
||||
|
||||
@@ -30,7 +31,7 @@ class TestDataLoader(unittest.TestCase):
|
||||
)
|
||||
# Of course you can use StaticDataLoader
|
||||
|
||||
dataset = nd.load(start_time="2020-01-01", end_time="2020-01-31")
|
||||
dataset = nd.load(instruments="csi300", start_time="2020-01-01", end_time="2020-01-31")
|
||||
|
||||
assert dataset is not None
|
||||
|
||||
@@ -45,6 +46,13 @@ class TestDataLoader(unittest.TestCase):
|
||||
|
||||
assert "LABEL0" in columns_list
|
||||
|
||||
assert dataset.isna().any().any()
|
||||
|
||||
fn = Fillna(fields_group="feature", fill_value=0)
|
||||
fn_dataset = fn.__call__(dataset)
|
||||
|
||||
assert not fn_dataset.isna().any().any()
|
||||
|
||||
# Then you can use it wth DataHandler;
|
||||
# NOTE: please note that the data processors are missing!!! You should add based on your requirements
|
||||
|
||||
|
||||
@@ -7,8 +7,8 @@ from qlib.tests import TestAutoData
|
||||
class TestDataset(TestAutoData):
|
||||
def testCSI300(self):
|
||||
close_p = D.features(D.instruments("csi300"), ["$close"])
|
||||
size = close_p.groupby("datetime").size()
|
||||
cnt = close_p.groupby("datetime").count()["$close"]
|
||||
size = close_p.groupby("datetime", group_keys=False).size()
|
||||
cnt = close_p.groupby("datetime", group_keys=False).count()["$close"]
|
||||
size_desc = size.describe(percentiles=np.arange(0.1, 1.0, 0.1))
|
||||
cnt_desc = cnt.describe(percentiles=np.arange(0.1, 1.0, 0.1))
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ from qlib.config import C
|
||||
from qlib.log import TimeInspector
|
||||
from qlib.constant import REG_CN, REG_US, REG_TW
|
||||
from qlib.utils.time import cal_sam_minute as cal_sam_minute_new, get_min_cal, CN_TIME, US_TIME, TW_TIME
|
||||
from qlib.utils.data import guess_horizon
|
||||
|
||||
REG_MAP = {REG_CN: CN_TIME, REG_US: US_TIME, REG_TW: TW_TIME}
|
||||
|
||||
@@ -112,5 +113,24 @@ class TimeUtils(TestCase):
|
||||
cal_sam_minute_new(*args, region=region)
|
||||
|
||||
|
||||
class DataUtils(TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
init()
|
||||
|
||||
def test_guess_horizon(self):
|
||||
label = ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
result = guess_horizon(label)
|
||||
assert result == 2
|
||||
|
||||
label = ["Ref($close, -5) / Ref($close, -1) - 1"]
|
||||
result = guess_horizon(label)
|
||||
assert result == 5
|
||||
|
||||
label = ["Ref($close, -1) / Ref($close, -1) - 1"]
|
||||
result = guess_horizon(label)
|
||||
assert result == 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -194,7 +194,7 @@ def test_trainer_checkpoint():
|
||||
assert (output_dir / "002.pth").exists()
|
||||
assert os.readlink(output_dir / "latest.pth") == str(output_dir / "002.pth")
|
||||
|
||||
trainer.load_state_dict(torch.load(output_dir / "001.pth"))
|
||||
trainer.load_state_dict(torch.load(output_dir / "001.pth", weights_only=False))
|
||||
assert trainer.current_iter == 1
|
||||
assert trainer.current_episode == 100
|
||||
|
||||
|
||||
@@ -36,7 +36,7 @@ class TestDumpData(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls) -> None:
|
||||
GetData().download_data(file_name="csv_data_cn.zip", target_dir=SOURCE_DIR)
|
||||
TestDumpData.DUMP_DATA = DumpDataAll(csv_path=SOURCE_DIR, qlib_dir=QLIB_DIR, include_fields=cls.FIELDS)
|
||||
TestDumpData.DUMP_DATA = DumpDataAll(data_path=SOURCE_DIR, qlib_dir=QLIB_DIR, include_fields=cls.FIELDS)
|
||||
TestDumpData.STOCK_NAMES = list(map(lambda x: x.name[:-4].upper(), SOURCE_DIR.glob("*.csv")))
|
||||
provider_uri = str(QLIB_DIR.resolve())
|
||||
qlib.init(
|
||||
@@ -76,7 +76,7 @@ class TestDumpData(unittest.TestCase):
|
||||
def test_4_dump_features_simple(self):
|
||||
stock = self.STOCK_NAMES[0]
|
||||
dump_data = DumpDataFix(
|
||||
csv_path=SOURCE_DIR.joinpath(f"{stock.lower()}.csv"), qlib_dir=QLIB_DIR, include_fields=self.FIELDS
|
||||
data_path=SOURCE_DIR.joinpath(f"{stock.lower()}.csv"), qlib_dir=QLIB_DIR, include_fields=self.FIELDS
|
||||
)
|
||||
dump_data.dump()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user