mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 14:01:28 +08:00
Compare commits
55 Commits
v0.9.3
...
extra_df_d
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2366fe1345 | ||
|
|
6c2fa0fc71 | ||
|
|
3b6c227562 | ||
|
|
d2c68e0cc0 | ||
|
|
b4879fc9da | ||
|
|
3f86171051 | ||
|
|
ce596f9dfa | ||
|
|
13768d1dac | ||
|
|
5190332c7e | ||
|
|
cde80206e4 | ||
|
|
a339fc11d1 | ||
|
|
33482047dc | ||
|
|
47bd13295b | ||
|
|
ebc0ca893e | ||
|
|
3a348aec9f | ||
|
|
37b908792b | ||
|
|
73ec0f4003 | ||
|
|
155c17f8ff | ||
|
|
41b94059aa | ||
|
|
7db83d84b7 | ||
|
|
35e0fdd1c0 | ||
|
|
598017f634 | ||
|
|
907c888c23 | ||
|
|
02fe6b6974 | ||
|
|
b892b21045 | ||
|
|
155f80323c | ||
|
|
63021018d6 | ||
|
|
f79a0eeaff | ||
|
|
8a087d0db9 | ||
|
|
2ae4be426a | ||
|
|
6ed83f7c04 | ||
|
|
917e3a725e | ||
|
|
b1e0e77c97 | ||
|
|
ea245f5435 | ||
|
|
3779b5186a | ||
|
|
194284b1ac | ||
|
|
1bb8f2fa23 | ||
|
|
39f88daaa7 | ||
|
|
98f569eed2 | ||
|
|
ceff886f49 | ||
|
|
15b64768e2 | ||
|
|
8bf2678676 | ||
|
|
fb80e318e2 | ||
|
|
ecbeeafdc1 | ||
|
|
69e28ceab8 | ||
|
|
4c30e5827b | ||
|
|
5387ea5c1f | ||
|
|
05d67b3828 | ||
|
|
38edac5069 | ||
|
|
b4b7a2fdd4 | ||
|
|
480f233e3f | ||
|
|
953621ac7e | ||
|
|
87a026fef3 | ||
|
|
8676303077 | ||
|
|
1a32ba1806 |
10
.github/workflows/python-publish.yml
vendored
10
.github/workflows/python-publish.yml
vendored
@@ -51,8 +51,8 @@ jobs:
|
||||
python setup.py bdist_wheel
|
||||
- name: Build and publish
|
||||
env:
|
||||
TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
|
||||
TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
|
||||
TWINE_USERNAME: __token__
|
||||
TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }}
|
||||
run: |
|
||||
twine upload dist/*
|
||||
|
||||
@@ -72,10 +72,10 @@ jobs:
|
||||
python-version: 3.7
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip install twine
|
||||
pip install twine
|
||||
- name: Build and publish
|
||||
env:
|
||||
TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
|
||||
TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
|
||||
TWINE_USERNAME: __token__
|
||||
TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }}
|
||||
run: |
|
||||
twine upload dist/pyqlib-*-manylinux*.whl
|
||||
|
||||
6
.github/workflows/release-drafter.yml
vendored
6
.github/workflows/release-drafter.yml
vendored
@@ -6,8 +6,14 @@ on:
|
||||
branches:
|
||||
- main
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
update_release_draft:
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: read
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
# Drafts your next Release notes as Pull Requests are merged into "master"
|
||||
|
||||
15
.github/workflows/test_qlib_from_pip.yml
vendored
15
.github/workflows/test_qlib_from_pip.yml
vendored
@@ -8,13 +8,15 @@ on:
|
||||
|
||||
jobs:
|
||||
build:
|
||||
if: ${{ false }} # FIXME: temporarily disable... Due to we are rushing a feature
|
||||
timeout-minutes: 120
|
||||
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-11, macos-latest]
|
||||
# Since macos-latest changed from 12.7.4 to 14.4.1,
|
||||
# the minimum python version that matches a 14.4.1 version of macos is 3.10,
|
||||
# so we limit the macos version to macos-12.
|
||||
os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-11, macos-12]
|
||||
# not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129
|
||||
python-version: [3.7, 3.8]
|
||||
|
||||
@@ -43,10 +45,10 @@ jobs:
|
||||
|
||||
- name: Qlib installation test
|
||||
run: |
|
||||
# 2024-05-30 scs has released a new version: 3.2.4.post2,
|
||||
# This will cause the CI to fail, so we have limited the version of scs for now.
|
||||
python -m pip install "scs<=3.2.4"
|
||||
python -m pip install pyqlib
|
||||
# Specify the numpy version because the numpy upgrade caused the CI test to fail,
|
||||
# and this line of code will be removed when the next version of qlib is released.
|
||||
python -m pip install "numpy<1.23"
|
||||
|
||||
- name: Install Lightgbm for MacOS
|
||||
if: ${{ matrix.os == 'macos-11' || matrix.os == 'macos-latest' }}
|
||||
@@ -66,5 +68,8 @@ jobs:
|
||||
cd qlib
|
||||
|
||||
- name: Test workflow by config
|
||||
# On macos-11 system, it will lead to "Segmentation fault: 11" error,
|
||||
# which may be caused by the excessive memory overhead of macos-11 system, so we disable macos-11 temporarily here.
|
||||
if: ${{ matrix.os != 'macos-11' }}
|
||||
run: |
|
||||
qrun examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
|
||||
19
.github/workflows/test_qlib_from_source.yml
vendored
19
.github/workflows/test_qlib_from_source.yml
vendored
@@ -14,7 +14,10 @@ jobs:
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-11, macos-latest]
|
||||
# Since macos-latest changed from 12.7.4 to 14.4.1,
|
||||
# the minimum python version that matches a 14.4.1 version of macos is 3.10,
|
||||
# so we limit the macos version to macos-12.
|
||||
os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-11, macos-12]
|
||||
# not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129
|
||||
python-version: [3.7, 3.8]
|
||||
|
||||
@@ -38,10 +41,8 @@ jobs:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Update pip to the latest version
|
||||
# pip release version 23.1 on Apr.15 2023, CI failed to run, Please refer to #1495 ofr detailed logs.
|
||||
# The pip version has been temporarily fixed to 23.0
|
||||
run: |
|
||||
python -m pip install pip==23.0
|
||||
python -m pip install --upgrade pip
|
||||
|
||||
- name: Installing pytorch for macos
|
||||
if: ${{ matrix.os == 'macos-11' || matrix.os == 'macos-latest' }}
|
||||
@@ -71,8 +72,10 @@ jobs:
|
||||
black . -l 120 --check --diff
|
||||
|
||||
- name: Make html with sphinx
|
||||
# Since read the docs builds on ubuntu 22.04, we only need to test that the build passes on ubuntu 22.04.
|
||||
if: ${{ matrix.os == 'ubuntu-22.04' }}
|
||||
run: |
|
||||
cd docs
|
||||
cd docs
|
||||
sphinx-build -W --keep-going -b html . _build
|
||||
cd ..
|
||||
|
||||
@@ -104,6 +107,7 @@ jobs:
|
||||
- name: Check Qlib with pylint
|
||||
run: |
|
||||
pylint --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0612,W0621,W0622,W0703,W1309,E1102,E1136 --const-rgx='[a-z_][a-z0-9_]{2,30}$' qlib --init-hook "import astroid; astroid.context.InferenceContext.max_inferred = 500; import sys; sys.setrecursionlimit(2000)"
|
||||
pylint --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0246,W0612,W0621,W0622,W0703,W1309,E1102,E1136 --const-rgx='[a-z_][a-z0-9_]{2,30}$' scripts --init-hook "import astroid; astroid.context.InferenceContext.max_inferred = 500; import sys; sys.setrecursionlimit(2000)"
|
||||
|
||||
# The following flake8 error codes were ignored:
|
||||
# E501 line too long
|
||||
@@ -157,11 +161,16 @@ jobs:
|
||||
|
||||
# Run after data downloads
|
||||
- name: Check Qlib ipynb with nbconvert
|
||||
# Running the nbconvert check on a macos-11 system results in a "Kernel died" error, so we've temporarily disabled macos-11 here.
|
||||
if: ${{ matrix.os != 'macos-11' }}
|
||||
run: |
|
||||
# add more ipynb files in future
|
||||
jupyter nbconvert --to notebook --execute examples/workflow_by_code.ipynb
|
||||
|
||||
- name: Test workflow by config (install from source)
|
||||
# On macos-11 system, it will lead to "Segmentation fault: 11" error,
|
||||
# which may be caused by the excessive memory overhead of macos-11 system, so we disable macos-11 temporarily here.
|
||||
if: ${{ matrix.os != 'macos-11' }}
|
||||
run: |
|
||||
python -m pip install numba
|
||||
python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
|
||||
@@ -14,7 +14,10 @@ jobs:
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-11, macos-latest]
|
||||
# Since macos-latest changed from 12.7.4 to 14.4.1,
|
||||
# the minimum python version that matches a 14.4.1 version of macos is 3.10,
|
||||
# so we limit the macos version to macos-12.
|
||||
os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-11, macos-12]
|
||||
# not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129
|
||||
python-version: [3.7, 3.8]
|
||||
|
||||
@@ -38,10 +41,8 @@ jobs:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Set up Python tools
|
||||
# pip release version 23.1 on Apr.15 2023, CI failed to run, Please refer to #1495 ofr detailed logs.
|
||||
# The pip version has been temporarily fixed to 23.0
|
||||
run: |
|
||||
python -m pip install pip==23.0
|
||||
python -m pip install --upgrade pip
|
||||
pip install --upgrade cython numpy
|
||||
pip install -e .[dev]
|
||||
|
||||
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -48,4 +48,4 @@ tags
|
||||
*.swp
|
||||
|
||||
./pretrain
|
||||
.idea/
|
||||
.idea/
|
||||
@@ -5,6 +5,12 @@
|
||||
# Required
|
||||
version: 2
|
||||
|
||||
# Set the version of Python and other tools you might need
|
||||
build:
|
||||
os: ubuntu-22.04
|
||||
tools:
|
||||
python: "3.7"
|
||||
|
||||
# Build documentation in the docs/ directory with Sphinx
|
||||
sphinx:
|
||||
configuration: docs/conf.py
|
||||
@@ -14,7 +20,6 @@ formats: all
|
||||
|
||||
# Optionally set the version of Python and requirements required to build your docs
|
||||
python:
|
||||
version: 3.7
|
||||
install:
|
||||
- requirements: docs/requirements.txt
|
||||
- method: pip
|
||||
26
README.md
26
README.md
@@ -40,7 +40,7 @@ Recent released features
|
||||
Features released before 2021 are not listed here.
|
||||
|
||||
<p align="center">
|
||||
<img src="http://fintech.msra.cn/images_v070/logo/1.png" />
|
||||
<img src="docs/_static/img/logo/1.png" />
|
||||
</p>
|
||||
|
||||
Qlib is an open-source, AI-oriented quantitative investment platform that aims to realize the potential, empower research, and create value using AI technologies in quantitative investment, from exploring ideas to implementing productions. Qlib supports diverse machine learning modeling paradigms, including supervised learning, market dynamics modeling, and reinforcement learning.
|
||||
@@ -139,7 +139,7 @@ This table demonstrates the supported Python version of `Qlib`:
|
||||
| Python 3.9 | :x: | :heavy_check_mark: | :x: |
|
||||
|
||||
**Note**:
|
||||
1. **Conda** is suggested for managing your Python environment.
|
||||
1. **Conda** is suggested for managing your Python environment. In some cases, using Python outside of a `conda` environment may result in missing header files, causing the installation failure of certain packages.
|
||||
1. Please pay attention that installing cython in Python 3.6 will raise some error when installing ``Qlib`` from source. If users use Python 3.6 on their machines, it is recommended to *upgrade* Python to version 3.7 or use `conda`'s Python to install ``Qlib`` from source.
|
||||
1. For Python 3.9, `Qlib` supports running workflows such as training models, doing backtest and plot most of the related figures (those included in [notebook](examples/workflow_by_code.ipynb)). However, plotting for the *model performance* is not supported for now and we will fix this when the dependent packages are upgraded in the future.
|
||||
1. `Qlib`Requires `tables` package, `hdf5` in tables does not support python3.9.
|
||||
@@ -166,13 +166,29 @@ Also, users can install the latest dev version ``Qlib`` by the source code accor
|
||||
* Clone the repository and install ``Qlib`` as follows.
|
||||
```bash
|
||||
git clone https://github.com/microsoft/qlib.git && cd qlib
|
||||
pip install .
|
||||
pip install . # `pip install -e .[dev]` is recommended for development. check details in docs/developer/code_standard_and_dev_guide.rst
|
||||
```
|
||||
**Note**: You can install Qlib with `python setup.py install` as well. But it is not the recommended approach. It will skip `pip` and cause obscure problems. For example, **only** the command ``pip install .`` **can** overwrite the stable version installed by ``pip install pyqlib``, while the command ``python setup.py install`` **can't**.
|
||||
|
||||
**Tips**: If you fail to install `Qlib` or run the examples in your environment, comparing your steps and the [CI workflow](.github/workflows/test_qlib_from_source.yml) may help you find the problem.
|
||||
|
||||
**Tips for Mac**: If you are using Mac with M1, you might encounter issues in building the wheel for LightGBM, which is due to missing dependencies from OpenMP. To solve the problem, install openmp first with ``brew install libomp`` and then run ``pip install .`` to build it successfully.
|
||||
|
||||
## Data Preparation
|
||||
❗ Due to more restrict data security policy. The offical dataset is disabled temporarily. You can try [this data source](https://github.com/chenditc/investment_data/releases) contributed by the community.
|
||||
Here is an example to download the data updated on 20220720.
|
||||
```bash
|
||||
wget https://github.com/chenditc/investment_data/releases/download/20220720/qlib_bin.tar.gz
|
||||
mkdir -p ~/.qlib/qlib_data/cn_data
|
||||
tar -zxvf qlib_bin.tar.gz -C ~/.qlib/qlib_data/cn_data --strip-components=2
|
||||
rm -f qlib_bin.tar.gz
|
||||
```
|
||||
|
||||
The official dataset below will resume in short future.
|
||||
|
||||
|
||||
----
|
||||
|
||||
Load and prepare data by running the following code:
|
||||
|
||||
### Get with module
|
||||
@@ -321,7 +337,7 @@ Qlib provides a tool named `qrun` to run the whole workflow automatically (inclu
|
||||
The automatic workflow may not suit the research workflow of all Quant researchers. To support a flexible Quant research workflow, Qlib also provides a modularized interface to allow researchers to build their own workflow by code. [Here](examples/workflow_by_code.ipynb) is a demo for customized Quant research workflow by code.
|
||||
|
||||
# Main Challenges & Solutions in Quant Research
|
||||
Quant investment is an very unique scenario with lots of key challenges to be solved.
|
||||
Quant investment is a very unique scenario with lots of key challenges to be solved.
|
||||
Currently, Qlib provides some solutions for several of them.
|
||||
|
||||
## Forecasting: Finding Valuable Signals/Patterns
|
||||
@@ -360,7 +376,7 @@ Here is a list of models built on `Qlib`.
|
||||
|
||||
Your PR of new Quant models is highly welcomed.
|
||||
|
||||
The performance of each model on the `Alpha158` and `Alpha360` dataset can be found [here](examples/benchmarks/README.md).
|
||||
The performance of each model on the `Alpha158` and `Alpha360` datasets can be found [here](examples/benchmarks/README.md).
|
||||
|
||||
### Run a single model
|
||||
All the models listed above are runnable with ``Qlib``. Users can find the config files we provide and some details about the model through the [benchmarks](examples/benchmarks) folder. More information can be retrieved at the model files listed above.
|
||||
|
||||
@@ -52,7 +52,7 @@ Also, ``Qlib`` provides a high-frequency dataset. Users can run a high-frequency
|
||||
Qlib Format Dataset
|
||||
-------------------
|
||||
``Qlib`` has provided an off-the-shelf dataset in `.bin` format, users could use the script ``scripts/get_data.py`` to download the China-Stock dataset as follows. User can also use numpy to load `.bin` file to validate data.
|
||||
The price volume data look different from the actual dealling price because of they are **adjusted** (`adjusted price <https://www.investopedia.com/terms/a/adjusted_closing_price.asp>`_). And then you may find that the adjusted price may be different from different data sources. This is because different data sources may vary in the way of adjusting prices. Qlib normalize the price on first trading day of each stock to 1 when adjusting them.
|
||||
The price volume data look different from the actual dealing price because of they are **adjusted** (`adjusted price <https://www.investopedia.com/terms/a/adjusted_closing_price.asp>`_). And then you may find that the adjusted price may be different from different data sources. This is because different data sources may vary in the way of adjusting prices. Qlib normalize the price on first trading day of each stock to 1 when adjusting them.
|
||||
Users can leverage `$factor` to get the original trading price (e.g. `$close / $factor` to get the original close price).
|
||||
|
||||
Here are some discussions about the price adjusting of Qlib.
|
||||
@@ -140,12 +140,13 @@ Users can also provide their own data in CSV format. However, the CSV data **mus
|
||||
|
||||
where the data are in the following format:
|
||||
|
||||
.. code-block::
|
||||
+-----------+-------+
|
||||
| symbol | close |
|
||||
+===========+=======+
|
||||
| SH600000 | 120 |
|
||||
+-----------+-------+
|
||||
|
||||
symbol,close
|
||||
SH600000,120
|
||||
|
||||
- CSV file **must** includes a column for the date, and when dumping the data, user must specify the date column name. Here is an example:
|
||||
- CSV file **must** include a column for the date, and when dumping the data, user must specify the date column name. Here is an example:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
@@ -153,11 +154,13 @@ Users can also provide their own data in CSV format. However, the CSV data **mus
|
||||
|
||||
where the data are in the following format:
|
||||
|
||||
.. code-block::
|
||||
|
||||
symbol,date,close,open,volume
|
||||
SH600000,2020-11-01,120,121,12300000
|
||||
SH600000,2020-11-02,123,120,12300000
|
||||
+---------+------------+-------+------+----------+
|
||||
| symbol | date | close | open | volume |
|
||||
+=========+============+=======+======+==========+
|
||||
| SH600000| 2020-11-01 | 120 | 121 | 12300000 |
|
||||
+---------+------------+-------+------+----------+
|
||||
| SH600000| 2020-11-02 | 123 | 120 | 12300000 |
|
||||
+---------+------------+-------+------+----------+
|
||||
|
||||
|
||||
Supposed that users prepare their CSV format data in the directory ``~/.qlib/csv_data/my_data``, they can run the following command to start the conversion.
|
||||
|
||||
@@ -86,7 +86,7 @@ Example
|
||||
},
|
||||
}
|
||||
|
||||
# model initiaiton
|
||||
# model initialization
|
||||
model = init_instance_by_config(task["model"])
|
||||
dataset = init_instance_by_config(task["dataset"])
|
||||
|
||||
|
||||
@@ -60,4 +60,4 @@ The `[dev]` option will help you to install some related packages when developin
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install -e .[dev]
|
||||
pip install -e ".[dev]"
|
||||
@@ -36,7 +36,7 @@ Name Description
|
||||
the training process of models which enable algorithms controlling the
|
||||
training process.
|
||||
|
||||
`Learning Framework` layer The `Forecast Model` and `Trading Agent` are learnable. They are learned
|
||||
`Learning Framework` layer The `Forecast Model` and `Trading Agent` are trainable. They are trained
|
||||
based on the `Learning Framework` layer and then applied to multiple scenarios
|
||||
in `Workflow` layer. The supported learning paradigms can be categorized into
|
||||
reinforcement learning and supervised learning. The learning framework
|
||||
@@ -51,7 +51,7 @@ Name Description
|
||||
modules. With these signals `Decision Generator` will generate the target
|
||||
trading decisions(i.e. portfolio, orders)
|
||||
If RL-based Strategies are adopted, the `Policy` is learned in a end-to-end way,
|
||||
the trading deicsions are generated directly.
|
||||
the trading decisions are generated directly.
|
||||
Decisions will be executed by `Execution Env`
|
||||
(i.e. the trading market). There may be multiple levels of `Strategy`
|
||||
and `Executor` (e.g. an *order executor trading strategy and intraday order executor*
|
||||
|
||||
@@ -16,7 +16,7 @@ This ``Quick Start`` guide tries to demonstrate
|
||||
Installation
|
||||
============
|
||||
|
||||
Users can easily intsall ``Qlib`` according to the following steps:
|
||||
Users can easily install ``Qlib`` according to the following steps:
|
||||
|
||||
- Before installing ``Qlib`` from source, users need to install some dependencies:
|
||||
|
||||
|
||||
@@ -5,3 +5,4 @@ scipy
|
||||
scikit-learn
|
||||
pandas
|
||||
tianshou
|
||||
sphinx_rtd_theme
|
||||
|
||||
@@ -0,0 +1,78 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
start_time: 2017-01-01
|
||||
end_time: 2020-08-01
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
exchange_kwargs:
|
||||
limit_threshold: 0.095
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: LinearModel
|
||||
module_path: qlib.contrib.model.linear
|
||||
kwargs:
|
||||
estimator: ols
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha158
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: True
|
||||
ann_scaler: 252
|
||||
- class: MultiPassPortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -136,7 +136,7 @@ If you want to contribute your new models, you can follow the steps below.
|
||||
- `README.md`: a brief introduction to your models
|
||||
- `workflow_config_<model name>_<dataset>.yaml`: a configuration which can read by `qrun`. You are encouraged to run your model in all datasets.
|
||||
3. You can integrate your model as a module [in this folder](https://github.com/microsoft/qlib/tree/main/qlib/contrib/model).
|
||||
4. Please update your results in the above **Benchmark Tables**, e.g. [Alpha360](#alpha158-dataset), [Alpha158](#alpha158-dataset)(the values of each metric are the mean and std calculated based on **20 Runs** with different random seeds. You can accomplish the above operations through the automated [script](https://github.com/microsoft/qlib/blob/main/examples/run_all_model.py#LL286C22-L286C22) provided by Qlib, and get the final result in the .md file. if you don't have enough computational resource, you can ask for help in the PR).
|
||||
4. Please update your results in the above **Benchmark Tables**, e.g. [Alpha360](#alpha158-dataset), [Alpha158](#alpha158-dataset)(the values of each metric are the mean and std calculated based on **20 Runs** with different random seeds. You can accomplish the above operations through the automated [script](https://github.com/microsoft/qlib/blob/main/examples/run_all_model.py) provided by Qlib, and get the final result in the .md file. if you don't have enough computational resource, you can ask for help in the PR).
|
||||
5. Update the info in the index page in the [news list](https://github.com/microsoft/qlib#newspaper-whats-new----sparkling_heart) and [model list](https://github.com/microsoft/qlib#quant-model-paper-zoo).
|
||||
|
||||
Finally, you can send PR for review. ([here is an example](https://github.com/microsoft/qlib/pull/1040))
|
||||
|
||||
@@ -324,7 +324,6 @@ class TRAModel(Model):
|
||||
|
||||
|
||||
class LSTM(nn.Module):
|
||||
|
||||
"""LSTM Model
|
||||
|
||||
Args:
|
||||
@@ -414,7 +413,6 @@ class PositionalEncoding(nn.Module):
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
|
||||
"""Transformer Model
|
||||
|
||||
Args:
|
||||
@@ -475,7 +473,6 @@ class Transformer(nn.Module):
|
||||
|
||||
|
||||
class TRA(nn.Module):
|
||||
|
||||
"""Temporal Routing Adaptor (TRA)
|
||||
|
||||
TRA takes historical prediction errors & latent representation as inputs,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
@@ -35,6 +36,10 @@ class DDGDABench(DDGDA):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
GetData().qlib_data(exists_skip=True)
|
||||
auto_init()
|
||||
kwargs = {}
|
||||
if os.environ.get("PROVIDER_URI", "") == "":
|
||||
GetData().qlib_data(exists_skip=True)
|
||||
else:
|
||||
kwargs["provider_uri"] = os.environ["PROVIDER_URI"]
|
||||
auto_init(**kwargs)
|
||||
fire.Fire(DDGDABench)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
@@ -31,6 +32,10 @@ class RollingBenchmark(Rolling):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
GetData().qlib_data(exists_skip=True)
|
||||
auto_init()
|
||||
kwargs = {}
|
||||
if os.environ.get("PROVIDER_URI", "") == "":
|
||||
GetData().qlib_data(exists_skip=True)
|
||||
else:
|
||||
kwargs["provider_uri"] = os.environ["PROVIDER_URI"]
|
||||
auto_init(**kwargs)
|
||||
fire.Fire(RollingBenchmark)
|
||||
|
||||
@@ -16,7 +16,7 @@ Current version of script with default value tries to connect localhost **via de
|
||||
|
||||
Run following command to install necessary libraries
|
||||
```
|
||||
pip install pytest coverage
|
||||
pip install pytest coverage gdown
|
||||
pip install arctic # NOTE: pip may fail to resolve the right package dependency !!! Please make sure the dependency are satisfied.
|
||||
```
|
||||
|
||||
@@ -27,13 +27,12 @@ pip install arctic # NOTE: pip may fail to resolve the right package dependency
|
||||
2. Please follow following steps to download example data
|
||||
```bash
|
||||
cd examples/orderbook_data/
|
||||
wget http://fintech.msra.cn/stock_data/downloads/highfreq_orderboook_example_data.tar.bz2
|
||||
tar xf highfreq_orderboook_example_data.tar.bz2
|
||||
gdown https://drive.google.com/uc?id=15nZF7tFT_eKVZAcMFL1qPS4jGyJflH7e # Proxies may be necessary here.
|
||||
python ../../scripts/get_data.py _unzip --file_path highfreq_orderbook_example_data.zip --target_dir .
|
||||
```
|
||||
|
||||
3. Please import the example data to your mongo db
|
||||
```bash
|
||||
cd examples/orderbook_data/
|
||||
python create_dataset.py initialize_library # Initialization Libraries
|
||||
python create_dataset.py import_data # Initialization Libraries
|
||||
```
|
||||
@@ -42,7 +41,6 @@ python create_dataset.py import_data # Initialization Libraries
|
||||
|
||||
After importing these data, you run `example.py` to create some high-frequency features.
|
||||
```bash
|
||||
cd examples/orderbook_data/
|
||||
pytest -s --disable-warnings example.py # If you want run all examples
|
||||
pytest -s --disable-warnings example.py::TestClass::test_exp_10 # If you want to run specific example
|
||||
```
|
||||
|
||||
@@ -20,7 +20,7 @@ We use China stock market data for our example.
|
||||
1. Prepare CSI300 weight:
|
||||
|
||||
```bash
|
||||
wget http://fintech.msra.cn/stock_data/downloads/csi300_weight.zip
|
||||
wget https://github.com/SunsetWolf/qlib_dataset/releases/download/v0/csi300_weight.zip
|
||||
unzip -d ~/.qlib/qlib_data/cn_data csi300_weight.zip
|
||||
rm -f csi300_weight.zip
|
||||
```
|
||||
|
||||
@@ -161,7 +161,7 @@
|
||||
" },\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"# model initiaiton\n",
|
||||
"# model initialization\n",
|
||||
"model = init_instance_by_config(task[\"model\"])\n",
|
||||
"dataset = init_instance_by_config(task[\"dataset\"])\n",
|
||||
"\n",
|
||||
|
||||
2
pyproject.toml
Normal file
2
pyproject.toml
Normal file
@@ -0,0 +1,2 @@
|
||||
[build-system]
|
||||
requires = ["setuptools", "numpy", "Cython"]
|
||||
@@ -2,7 +2,7 @@
|
||||
# Licensed under the MIT License.
|
||||
from pathlib import Path
|
||||
|
||||
__version__ = "0.9.3"
|
||||
__version__ = "0.9.5.99"
|
||||
__version__bak = __version__ # This version is backup for QlibConfig.reset_qlib_version
|
||||
import os
|
||||
from typing import Union
|
||||
|
||||
@@ -162,13 +162,15 @@ def create_account_instance(
|
||||
init_cash=init_cash,
|
||||
position_dict=position_dict,
|
||||
pos_type=pos_type,
|
||||
benchmark_config={}
|
||||
if benchmark is None
|
||||
else {
|
||||
"benchmark": benchmark,
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
},
|
||||
benchmark_config=(
|
||||
{}
|
||||
if benchmark is None
|
||||
else {
|
||||
"benchmark": benchmark,
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -622,9 +622,11 @@ class Indicator:
|
||||
print(
|
||||
"[Indicator({}) {}]: FFR: {}, PA: {}, POS: {}".format(
|
||||
freq,
|
||||
trade_start_time
|
||||
if isinstance(trade_start_time, str)
|
||||
else trade_start_time.strftime("%Y-%m-%d %H:%M:%S"),
|
||||
(
|
||||
trade_start_time
|
||||
if isinstance(trade_start_time, str)
|
||||
else trade_start_time.strftime("%Y-%m-%d %H:%M:%S")
|
||||
),
|
||||
fulfill_rate,
|
||||
price_advantage,
|
||||
positive_rate,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from qlib.contrib.data.loader import Alpha158DL, Alpha360DL
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
from ...data.dataset.processor import Processor
|
||||
from ...utils import get_callable_kwargs
|
||||
@@ -66,7 +67,7 @@ class Alpha360(DataHandlerLP):
|
||||
"class": "QlibDataLoader",
|
||||
"kwargs": {
|
||||
"config": {
|
||||
"feature": self.get_feature_config(),
|
||||
"feature": Alpha360DL.get_feature_config(),
|
||||
"label": kwargs.pop("label", self.get_label_config()),
|
||||
},
|
||||
"filter_pipe": filter_pipe,
|
||||
@@ -88,51 +89,6 @@ class Alpha360(DataHandlerLP):
|
||||
def get_label_config(self):
|
||||
return ["Ref($close, -2)/Ref($close, -1) - 1"], ["LABEL0"]
|
||||
|
||||
@staticmethod
|
||||
def get_feature_config():
|
||||
# NOTE:
|
||||
# Alpha360 tries to provide a dataset with original price data
|
||||
# the original price data includes the prices and volume in the last 60 days.
|
||||
# To make it easier to learn models from this dataset, all the prices and volume
|
||||
# are normalized by the latest price and volume data ( dividing by $close, $volume)
|
||||
# So the latest normalized $close will be 1 (with name CLOSE0), the latest normalized $volume will be 1 (with name VOLUME0)
|
||||
# If further normalization are executed (e.g. centralization), CLOSE0 and VOLUME0 will be 0.
|
||||
fields = []
|
||||
names = []
|
||||
|
||||
for i in range(59, 0, -1):
|
||||
fields += ["Ref($close, %d)/$close" % i]
|
||||
names += ["CLOSE%d" % i]
|
||||
fields += ["$close/$close"]
|
||||
names += ["CLOSE0"]
|
||||
for i in range(59, 0, -1):
|
||||
fields += ["Ref($open, %d)/$close" % i]
|
||||
names += ["OPEN%d" % i]
|
||||
fields += ["$open/$close"]
|
||||
names += ["OPEN0"]
|
||||
for i in range(59, 0, -1):
|
||||
fields += ["Ref($high, %d)/$close" % i]
|
||||
names += ["HIGH%d" % i]
|
||||
fields += ["$high/$close"]
|
||||
names += ["HIGH0"]
|
||||
for i in range(59, 0, -1):
|
||||
fields += ["Ref($low, %d)/$close" % i]
|
||||
names += ["LOW%d" % i]
|
||||
fields += ["$low/$close"]
|
||||
names += ["LOW0"]
|
||||
for i in range(59, 0, -1):
|
||||
fields += ["Ref($vwap, %d)/$close" % i]
|
||||
names += ["VWAP%d" % i]
|
||||
fields += ["$vwap/$close"]
|
||||
names += ["VWAP0"]
|
||||
for i in range(59, 0, -1):
|
||||
fields += ["Ref($volume, %d)/($volume+1e-12)" % i]
|
||||
names += ["VOLUME%d" % i]
|
||||
fields += ["$volume/($volume+1e-12)"]
|
||||
names += ["VOLUME0"]
|
||||
|
||||
return fields, names
|
||||
|
||||
|
||||
class Alpha360vwap(Alpha360):
|
||||
def get_label_config(self):
|
||||
@@ -190,242 +146,11 @@ class Alpha158(DataHandlerLP):
|
||||
},
|
||||
"rolling": {},
|
||||
}
|
||||
return self.parse_config_to_fields(conf)
|
||||
return Alpha158DL.get_feature_config(conf)
|
||||
|
||||
def get_label_config(self):
|
||||
return ["Ref($close, -2)/Ref($close, -1) - 1"], ["LABEL0"]
|
||||
|
||||
@staticmethod
|
||||
def parse_config_to_fields(config):
|
||||
"""create factors from config
|
||||
|
||||
config = {
|
||||
'kbar': {}, # whether to use some hard-code kbar features
|
||||
'price': { # whether to use raw price features
|
||||
'windows': [0, 1, 2, 3, 4], # use price at n days ago
|
||||
'feature': ['OPEN', 'HIGH', 'LOW'] # which price field to use
|
||||
},
|
||||
'volume': { # whether to use raw volume features
|
||||
'windows': [0, 1, 2, 3, 4], # use volume at n days ago
|
||||
},
|
||||
'rolling': { # whether to use rolling operator based features
|
||||
'windows': [5, 10, 20, 30, 60], # rolling windows size
|
||||
'include': ['ROC', 'MA', 'STD'], # rolling operator to use
|
||||
#if include is None we will use default operators
|
||||
'exclude': ['RANK'], # rolling operator not to use
|
||||
}
|
||||
}
|
||||
"""
|
||||
fields = []
|
||||
names = []
|
||||
if "kbar" in config:
|
||||
fields += [
|
||||
"($close-$open)/$open",
|
||||
"($high-$low)/$open",
|
||||
"($close-$open)/($high-$low+1e-12)",
|
||||
"($high-Greater($open, $close))/$open",
|
||||
"($high-Greater($open, $close))/($high-$low+1e-12)",
|
||||
"(Less($open, $close)-$low)/$open",
|
||||
"(Less($open, $close)-$low)/($high-$low+1e-12)",
|
||||
"(2*$close-$high-$low)/$open",
|
||||
"(2*$close-$high-$low)/($high-$low+1e-12)",
|
||||
]
|
||||
names += [
|
||||
"KMID",
|
||||
"KLEN",
|
||||
"KMID2",
|
||||
"KUP",
|
||||
"KUP2",
|
||||
"KLOW",
|
||||
"KLOW2",
|
||||
"KSFT",
|
||||
"KSFT2",
|
||||
]
|
||||
if "price" in config:
|
||||
windows = config["price"].get("windows", range(5))
|
||||
feature = config["price"].get("feature", ["OPEN", "HIGH", "LOW", "CLOSE", "VWAP"])
|
||||
for field in feature:
|
||||
field = field.lower()
|
||||
fields += ["Ref($%s, %d)/$close" % (field, d) if d != 0 else "$%s/$close" % field for d in windows]
|
||||
names += [field.upper() + str(d) for d in windows]
|
||||
if "volume" in config:
|
||||
windows = config["volume"].get("windows", range(5))
|
||||
fields += ["Ref($volume, %d)/($volume+1e-12)" % d if d != 0 else "$volume/($volume+1e-12)" for d in windows]
|
||||
names += ["VOLUME" + str(d) for d in windows]
|
||||
if "rolling" in config:
|
||||
windows = config["rolling"].get("windows", [5, 10, 20, 30, 60])
|
||||
include = config["rolling"].get("include", None)
|
||||
exclude = config["rolling"].get("exclude", [])
|
||||
# `exclude` in dataset config unnecessary filed
|
||||
# `include` in dataset config necessary field
|
||||
|
||||
def use(x):
|
||||
return x not in exclude and (include is None or x in include)
|
||||
|
||||
# Some factor ref: https://guorn.com/static/upload/file/3/134065454575605.pdf
|
||||
if use("ROC"):
|
||||
# https://www.investopedia.com/terms/r/rateofchange.asp
|
||||
# Rate of change, the price change in the past d days, divided by latest close price to remove unit
|
||||
fields += ["Ref($close, %d)/$close" % d for d in windows]
|
||||
names += ["ROC%d" % d for d in windows]
|
||||
if use("MA"):
|
||||
# https://www.investopedia.com/ask/answers/071414/whats-difference-between-moving-average-and-weighted-moving-average.asp
|
||||
# Simple Moving Average, the simple moving average in the past d days, divided by latest close price to remove unit
|
||||
fields += ["Mean($close, %d)/$close" % d for d in windows]
|
||||
names += ["MA%d" % d for d in windows]
|
||||
if use("STD"):
|
||||
# The standard diviation of close price for the past d days, divided by latest close price to remove unit
|
||||
fields += ["Std($close, %d)/$close" % d for d in windows]
|
||||
names += ["STD%d" % d for d in windows]
|
||||
if use("BETA"):
|
||||
# The rate of close price change in the past d days, divided by latest close price to remove unit
|
||||
# For example, price increase 10 dollar per day in the past d days, then Slope will be 10.
|
||||
fields += ["Slope($close, %d)/$close" % d for d in windows]
|
||||
names += ["BETA%d" % d for d in windows]
|
||||
if use("RSQR"):
|
||||
# The R-sqaure value of linear regression for the past d days, represent the trend linear
|
||||
fields += ["Rsquare($close, %d)" % d for d in windows]
|
||||
names += ["RSQR%d" % d for d in windows]
|
||||
if use("RESI"):
|
||||
# The redisdual for linear regression for the past d days, represent the trend linearity for past d days.
|
||||
fields += ["Resi($close, %d)/$close" % d for d in windows]
|
||||
names += ["RESI%d" % d for d in windows]
|
||||
if use("MAX"):
|
||||
# The max price for past d days, divided by latest close price to remove unit
|
||||
fields += ["Max($high, %d)/$close" % d for d in windows]
|
||||
names += ["MAX%d" % d for d in windows]
|
||||
if use("LOW"):
|
||||
# The low price for past d days, divided by latest close price to remove unit
|
||||
fields += ["Min($low, %d)/$close" % d for d in windows]
|
||||
names += ["MIN%d" % d for d in windows]
|
||||
if use("QTLU"):
|
||||
# The 80% quantile of past d day's close price, divided by latest close price to remove unit
|
||||
# Used with MIN and MAX
|
||||
fields += ["Quantile($close, %d, 0.8)/$close" % d for d in windows]
|
||||
names += ["QTLU%d" % d for d in windows]
|
||||
if use("QTLD"):
|
||||
# The 20% quantile of past d day's close price, divided by latest close price to remove unit
|
||||
fields += ["Quantile($close, %d, 0.2)/$close" % d for d in windows]
|
||||
names += ["QTLD%d" % d for d in windows]
|
||||
if use("RANK"):
|
||||
# Get the percentile of current close price in past d day's close price.
|
||||
# Represent the current price level comparing to past N days, add additional information to moving average.
|
||||
fields += ["Rank($close, %d)" % d for d in windows]
|
||||
names += ["RANK%d" % d for d in windows]
|
||||
if use("RSV"):
|
||||
# Represent the price position between upper and lower resistent price for past d days.
|
||||
fields += ["($close-Min($low, %d))/(Max($high, %d)-Min($low, %d)+1e-12)" % (d, d, d) for d in windows]
|
||||
names += ["RSV%d" % d for d in windows]
|
||||
if use("IMAX"):
|
||||
# The number of days between current date and previous highest price date.
|
||||
# Part of Aroon Indicator https://www.investopedia.com/terms/a/aroon.asp
|
||||
# The indicator measures the time between highs and the time between lows over a time period.
|
||||
# The idea is that strong uptrends will regularly see new highs, and strong downtrends will regularly see new lows.
|
||||
fields += ["IdxMax($high, %d)/%d" % (d, d) for d in windows]
|
||||
names += ["IMAX%d" % d for d in windows]
|
||||
if use("IMIN"):
|
||||
# The number of days between current date and previous lowest price date.
|
||||
# Part of Aroon Indicator https://www.investopedia.com/terms/a/aroon.asp
|
||||
# The indicator measures the time between highs and the time between lows over a time period.
|
||||
# The idea is that strong uptrends will regularly see new highs, and strong downtrends will regularly see new lows.
|
||||
fields += ["IdxMin($low, %d)/%d" % (d, d) for d in windows]
|
||||
names += ["IMIN%d" % d for d in windows]
|
||||
if use("IMXD"):
|
||||
# The time period between previous lowest-price date occur after highest price date.
|
||||
# Large value suggest downward momemtum.
|
||||
fields += ["(IdxMax($high, %d)-IdxMin($low, %d))/%d" % (d, d, d) for d in windows]
|
||||
names += ["IMXD%d" % d for d in windows]
|
||||
if use("CORR"):
|
||||
# The correlation between absolute close price and log scaled trading volume
|
||||
fields += ["Corr($close, Log($volume+1), %d)" % d for d in windows]
|
||||
names += ["CORR%d" % d for d in windows]
|
||||
if use("CORD"):
|
||||
# The correlation between price change ratio and volume change ratio
|
||||
fields += ["Corr($close/Ref($close,1), Log($volume/Ref($volume, 1)+1), %d)" % d for d in windows]
|
||||
names += ["CORD%d" % d for d in windows]
|
||||
if use("CNTP"):
|
||||
# The percentage of days in past d days that price go up.
|
||||
fields += ["Mean($close>Ref($close, 1), %d)" % d for d in windows]
|
||||
names += ["CNTP%d" % d for d in windows]
|
||||
if use("CNTN"):
|
||||
# The percentage of days in past d days that price go down.
|
||||
fields += ["Mean($close<Ref($close, 1), %d)" % d for d in windows]
|
||||
names += ["CNTN%d" % d for d in windows]
|
||||
if use("CNTD"):
|
||||
# The diff between past up day and past down day
|
||||
fields += ["Mean($close>Ref($close, 1), %d)-Mean($close<Ref($close, 1), %d)" % (d, d) for d in windows]
|
||||
names += ["CNTD%d" % d for d in windows]
|
||||
if use("SUMP"):
|
||||
# The total gain / the absolute total price changed
|
||||
# Similar to RSI indicator. https://www.investopedia.com/terms/r/rsi.asp
|
||||
fields += [
|
||||
"Sum(Greater($close-Ref($close, 1), 0), %d)/(Sum(Abs($close-Ref($close, 1)), %d)+1e-12)" % (d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["SUMP%d" % d for d in windows]
|
||||
if use("SUMN"):
|
||||
# The total lose / the absolute total price changed
|
||||
# Can be derived from SUMP by SUMN = 1 - SUMP
|
||||
# Similar to RSI indicator. https://www.investopedia.com/terms/r/rsi.asp
|
||||
fields += [
|
||||
"Sum(Greater(Ref($close, 1)-$close, 0), %d)/(Sum(Abs($close-Ref($close, 1)), %d)+1e-12)" % (d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["SUMN%d" % d for d in windows]
|
||||
if use("SUMD"):
|
||||
# The diff ratio between total gain and total lose
|
||||
# Similar to RSI indicator. https://www.investopedia.com/terms/r/rsi.asp
|
||||
fields += [
|
||||
"(Sum(Greater($close-Ref($close, 1), 0), %d)-Sum(Greater(Ref($close, 1)-$close, 0), %d))"
|
||||
"/(Sum(Abs($close-Ref($close, 1)), %d)+1e-12)" % (d, d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["SUMD%d" % d for d in windows]
|
||||
if use("VMA"):
|
||||
# Simple Volume Moving average: https://www.barchart.com/education/technical-indicators/volume_moving_average
|
||||
fields += ["Mean($volume, %d)/($volume+1e-12)" % d for d in windows]
|
||||
names += ["VMA%d" % d for d in windows]
|
||||
if use("VSTD"):
|
||||
# The standard deviation for volume in past d days.
|
||||
fields += ["Std($volume, %d)/($volume+1e-12)" % d for d in windows]
|
||||
names += ["VSTD%d" % d for d in windows]
|
||||
if use("WVMA"):
|
||||
# The volume weighted price change volatility
|
||||
fields += [
|
||||
"Std(Abs($close/Ref($close, 1)-1)*$volume, %d)/(Mean(Abs($close/Ref($close, 1)-1)*$volume, %d)+1e-12)"
|
||||
% (d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["WVMA%d" % d for d in windows]
|
||||
if use("VSUMP"):
|
||||
# The total volume increase / the absolute total volume changed
|
||||
fields += [
|
||||
"Sum(Greater($volume-Ref($volume, 1), 0), %d)/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)"
|
||||
% (d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["VSUMP%d" % d for d in windows]
|
||||
if use("VSUMN"):
|
||||
# The total volume increase / the absolute total volume changed
|
||||
# Can be derived from VSUMP by VSUMN = 1 - VSUMP
|
||||
fields += [
|
||||
"Sum(Greater(Ref($volume, 1)-$volume, 0), %d)/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)"
|
||||
% (d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["VSUMN%d" % d for d in windows]
|
||||
if use("VSUMD"):
|
||||
# The diff ratio between total volume increase and total volume decrease
|
||||
# RSI indicator for volume
|
||||
fields += [
|
||||
"(Sum(Greater($volume-Ref($volume, 1), 0), %d)-Sum(Greater(Ref($volume, 1)-$volume, 0), %d))"
|
||||
"/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)" % (d, d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["VSUMD%d" % d for d in windows]
|
||||
|
||||
return fields, names
|
||||
|
||||
|
||||
class Alpha158vwap(Alpha158):
|
||||
def get_label_config(self):
|
||||
|
||||
310
qlib/contrib/data/loader.py
Normal file
310
qlib/contrib/data/loader.py
Normal file
@@ -0,0 +1,310 @@
|
||||
from qlib.data.dataset.loader import QlibDataLoader
|
||||
|
||||
|
||||
class Alpha360DL(QlibDataLoader):
|
||||
"""Dataloader to get Alpha360"""
|
||||
|
||||
def __init__(self, config=None, **kwargs):
|
||||
_config = {
|
||||
"feature": self.get_feature_config(),
|
||||
}
|
||||
if config is not None:
|
||||
_config.update(config)
|
||||
super().__init__(config=_config, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def get_feature_config():
|
||||
# NOTE:
|
||||
# Alpha360 tries to provide a dataset with original price data
|
||||
# the original price data includes the prices and volume in the last 60 days.
|
||||
# To make it easier to learn models from this dataset, all the prices and volume
|
||||
# are normalized by the latest price and volume data ( dividing by $close, $volume)
|
||||
# So the latest normalized $close will be 1 (with name CLOSE0), the latest normalized $volume will be 1 (with name VOLUME0)
|
||||
# If further normalization are executed (e.g. centralization), CLOSE0 and VOLUME0 will be 0.
|
||||
fields = []
|
||||
names = []
|
||||
|
||||
for i in range(59, 0, -1):
|
||||
fields += ["Ref($close, %d)/$close" % i]
|
||||
names += ["CLOSE%d" % i]
|
||||
fields += ["$close/$close"]
|
||||
names += ["CLOSE0"]
|
||||
for i in range(59, 0, -1):
|
||||
fields += ["Ref($open, %d)/$close" % i]
|
||||
names += ["OPEN%d" % i]
|
||||
fields += ["$open/$close"]
|
||||
names += ["OPEN0"]
|
||||
for i in range(59, 0, -1):
|
||||
fields += ["Ref($high, %d)/$close" % i]
|
||||
names += ["HIGH%d" % i]
|
||||
fields += ["$high/$close"]
|
||||
names += ["HIGH0"]
|
||||
for i in range(59, 0, -1):
|
||||
fields += ["Ref($low, %d)/$close" % i]
|
||||
names += ["LOW%d" % i]
|
||||
fields += ["$low/$close"]
|
||||
names += ["LOW0"]
|
||||
for i in range(59, 0, -1):
|
||||
fields += ["Ref($vwap, %d)/$close" % i]
|
||||
names += ["VWAP%d" % i]
|
||||
fields += ["$vwap/$close"]
|
||||
names += ["VWAP0"]
|
||||
for i in range(59, 0, -1):
|
||||
fields += ["Ref($volume, %d)/($volume+1e-12)" % i]
|
||||
names += ["VOLUME%d" % i]
|
||||
fields += ["$volume/($volume+1e-12)"]
|
||||
names += ["VOLUME0"]
|
||||
|
||||
return fields, names
|
||||
|
||||
|
||||
class Alpha158DL(QlibDataLoader):
|
||||
"""Dataloader to get Alpha158"""
|
||||
|
||||
def __init__(self, config=None, **kwargs):
|
||||
_config = {
|
||||
"feature": self.get_feature_config(),
|
||||
}
|
||||
if config is not None:
|
||||
_config.update(config)
|
||||
super().__init__(config=_config, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def get_feature_config(
|
||||
config={
|
||||
"kbar": {},
|
||||
"price": {
|
||||
"windows": [0],
|
||||
"feature": ["OPEN", "HIGH", "LOW", "VWAP"],
|
||||
},
|
||||
"rolling": {},
|
||||
}
|
||||
):
|
||||
"""create factors from config
|
||||
|
||||
config = {
|
||||
'kbar': {}, # whether to use some hard-code kbar features
|
||||
'price': { # whether to use raw price features
|
||||
'windows': [0, 1, 2, 3, 4], # use price at n days ago
|
||||
'feature': ['OPEN', 'HIGH', 'LOW'] # which price field to use
|
||||
},
|
||||
'volume': { # whether to use raw volume features
|
||||
'windows': [0, 1, 2, 3, 4], # use volume at n days ago
|
||||
},
|
||||
'rolling': { # whether to use rolling operator based features
|
||||
'windows': [5, 10, 20, 30, 60], # rolling windows size
|
||||
'include': ['ROC', 'MA', 'STD'], # rolling operator to use
|
||||
#if include is None we will use default operators
|
||||
'exclude': ['RANK'], # rolling operator not to use
|
||||
}
|
||||
}
|
||||
"""
|
||||
fields = []
|
||||
names = []
|
||||
if "kbar" in config:
|
||||
fields += [
|
||||
"($close-$open)/$open",
|
||||
"($high-$low)/$open",
|
||||
"($close-$open)/($high-$low+1e-12)",
|
||||
"($high-Greater($open, $close))/$open",
|
||||
"($high-Greater($open, $close))/($high-$low+1e-12)",
|
||||
"(Less($open, $close)-$low)/$open",
|
||||
"(Less($open, $close)-$low)/($high-$low+1e-12)",
|
||||
"(2*$close-$high-$low)/$open",
|
||||
"(2*$close-$high-$low)/($high-$low+1e-12)",
|
||||
]
|
||||
names += [
|
||||
"KMID",
|
||||
"KLEN",
|
||||
"KMID2",
|
||||
"KUP",
|
||||
"KUP2",
|
||||
"KLOW",
|
||||
"KLOW2",
|
||||
"KSFT",
|
||||
"KSFT2",
|
||||
]
|
||||
if "price" in config:
|
||||
windows = config["price"].get("windows", range(5))
|
||||
feature = config["price"].get("feature", ["OPEN", "HIGH", "LOW", "CLOSE", "VWAP"])
|
||||
for field in feature:
|
||||
field = field.lower()
|
||||
fields += ["Ref($%s, %d)/$close" % (field, d) if d != 0 else "$%s/$close" % field for d in windows]
|
||||
names += [field.upper() + str(d) for d in windows]
|
||||
if "volume" in config:
|
||||
windows = config["volume"].get("windows", range(5))
|
||||
fields += ["Ref($volume, %d)/($volume+1e-12)" % d if d != 0 else "$volume/($volume+1e-12)" for d in windows]
|
||||
names += ["VOLUME" + str(d) for d in windows]
|
||||
if "rolling" in config:
|
||||
windows = config["rolling"].get("windows", [5, 10, 20, 30, 60])
|
||||
include = config["rolling"].get("include", None)
|
||||
exclude = config["rolling"].get("exclude", [])
|
||||
# `exclude` in dataset config unnecessary filed
|
||||
# `include` in dataset config necessary field
|
||||
|
||||
def use(x):
|
||||
return x not in exclude and (include is None or x in include)
|
||||
|
||||
# Some factor ref: https://guorn.com/static/upload/file/3/134065454575605.pdf
|
||||
if use("ROC"):
|
||||
# https://www.investopedia.com/terms/r/rateofchange.asp
|
||||
# Rate of change, the price change in the past d days, divided by latest close price to remove unit
|
||||
fields += ["Ref($close, %d)/$close" % d for d in windows]
|
||||
names += ["ROC%d" % d for d in windows]
|
||||
if use("MA"):
|
||||
# https://www.investopedia.com/ask/answers/071414/whats-difference-between-moving-average-and-weighted-moving-average.asp
|
||||
# Simple Moving Average, the simple moving average in the past d days, divided by latest close price to remove unit
|
||||
fields += ["Mean($close, %d)/$close" % d for d in windows]
|
||||
names += ["MA%d" % d for d in windows]
|
||||
if use("STD"):
|
||||
# The standard diviation of close price for the past d days, divided by latest close price to remove unit
|
||||
fields += ["Std($close, %d)/$close" % d for d in windows]
|
||||
names += ["STD%d" % d for d in windows]
|
||||
if use("BETA"):
|
||||
# The rate of close price change in the past d days, divided by latest close price to remove unit
|
||||
# For example, price increase 10 dollar per day in the past d days, then Slope will be 10.
|
||||
fields += ["Slope($close, %d)/$close" % d for d in windows]
|
||||
names += ["BETA%d" % d for d in windows]
|
||||
if use("RSQR"):
|
||||
# The R-sqaure value of linear regression for the past d days, represent the trend linear
|
||||
fields += ["Rsquare($close, %d)" % d for d in windows]
|
||||
names += ["RSQR%d" % d for d in windows]
|
||||
if use("RESI"):
|
||||
# The redisdual for linear regression for the past d days, represent the trend linearity for past d days.
|
||||
fields += ["Resi($close, %d)/$close" % d for d in windows]
|
||||
names += ["RESI%d" % d for d in windows]
|
||||
if use("MAX"):
|
||||
# The max price for past d days, divided by latest close price to remove unit
|
||||
fields += ["Max($high, %d)/$close" % d for d in windows]
|
||||
names += ["MAX%d" % d for d in windows]
|
||||
if use("LOW"):
|
||||
# The low price for past d days, divided by latest close price to remove unit
|
||||
fields += ["Min($low, %d)/$close" % d for d in windows]
|
||||
names += ["MIN%d" % d for d in windows]
|
||||
if use("QTLU"):
|
||||
# The 80% quantile of past d day's close price, divided by latest close price to remove unit
|
||||
# Used with MIN and MAX
|
||||
fields += ["Quantile($close, %d, 0.8)/$close" % d for d in windows]
|
||||
names += ["QTLU%d" % d for d in windows]
|
||||
if use("QTLD"):
|
||||
# The 20% quantile of past d day's close price, divided by latest close price to remove unit
|
||||
fields += ["Quantile($close, %d, 0.2)/$close" % d for d in windows]
|
||||
names += ["QTLD%d" % d for d in windows]
|
||||
if use("RANK"):
|
||||
# Get the percentile of current close price in past d day's close price.
|
||||
# Represent the current price level comparing to past N days, add additional information to moving average.
|
||||
fields += ["Rank($close, %d)" % d for d in windows]
|
||||
names += ["RANK%d" % d for d in windows]
|
||||
if use("RSV"):
|
||||
# Represent the price position between upper and lower resistent price for past d days.
|
||||
fields += ["($close-Min($low, %d))/(Max($high, %d)-Min($low, %d)+1e-12)" % (d, d, d) for d in windows]
|
||||
names += ["RSV%d" % d for d in windows]
|
||||
if use("IMAX"):
|
||||
# The number of days between current date and previous highest price date.
|
||||
# Part of Aroon Indicator https://www.investopedia.com/terms/a/aroon.asp
|
||||
# The indicator measures the time between highs and the time between lows over a time period.
|
||||
# The idea is that strong uptrends will regularly see new highs, and strong downtrends will regularly see new lows.
|
||||
fields += ["IdxMax($high, %d)/%d" % (d, d) for d in windows]
|
||||
names += ["IMAX%d" % d for d in windows]
|
||||
if use("IMIN"):
|
||||
# The number of days between current date and previous lowest price date.
|
||||
# Part of Aroon Indicator https://www.investopedia.com/terms/a/aroon.asp
|
||||
# The indicator measures the time between highs and the time between lows over a time period.
|
||||
# The idea is that strong uptrends will regularly see new highs, and strong downtrends will regularly see new lows.
|
||||
fields += ["IdxMin($low, %d)/%d" % (d, d) for d in windows]
|
||||
names += ["IMIN%d" % d for d in windows]
|
||||
if use("IMXD"):
|
||||
# The time period between previous lowest-price date occur after highest price date.
|
||||
# Large value suggest downward momemtum.
|
||||
fields += ["(IdxMax($high, %d)-IdxMin($low, %d))/%d" % (d, d, d) for d in windows]
|
||||
names += ["IMXD%d" % d for d in windows]
|
||||
if use("CORR"):
|
||||
# The correlation between absolute close price and log scaled trading volume
|
||||
fields += ["Corr($close, Log($volume+1), %d)" % d for d in windows]
|
||||
names += ["CORR%d" % d for d in windows]
|
||||
if use("CORD"):
|
||||
# The correlation between price change ratio and volume change ratio
|
||||
fields += ["Corr($close/Ref($close,1), Log($volume/Ref($volume, 1)+1), %d)" % d for d in windows]
|
||||
names += ["CORD%d" % d for d in windows]
|
||||
if use("CNTP"):
|
||||
# The percentage of days in past d days that price go up.
|
||||
fields += ["Mean($close>Ref($close, 1), %d)" % d for d in windows]
|
||||
names += ["CNTP%d" % d for d in windows]
|
||||
if use("CNTN"):
|
||||
# The percentage of days in past d days that price go down.
|
||||
fields += ["Mean($close<Ref($close, 1), %d)" % d for d in windows]
|
||||
names += ["CNTN%d" % d for d in windows]
|
||||
if use("CNTD"):
|
||||
# The diff between past up day and past down day
|
||||
fields += ["Mean($close>Ref($close, 1), %d)-Mean($close<Ref($close, 1), %d)" % (d, d) for d in windows]
|
||||
names += ["CNTD%d" % d for d in windows]
|
||||
if use("SUMP"):
|
||||
# The total gain / the absolute total price changed
|
||||
# Similar to RSI indicator. https://www.investopedia.com/terms/r/rsi.asp
|
||||
fields += [
|
||||
"Sum(Greater($close-Ref($close, 1), 0), %d)/(Sum(Abs($close-Ref($close, 1)), %d)+1e-12)" % (d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["SUMP%d" % d for d in windows]
|
||||
if use("SUMN"):
|
||||
# The total lose / the absolute total price changed
|
||||
# Can be derived from SUMP by SUMN = 1 - SUMP
|
||||
# Similar to RSI indicator. https://www.investopedia.com/terms/r/rsi.asp
|
||||
fields += [
|
||||
"Sum(Greater(Ref($close, 1)-$close, 0), %d)/(Sum(Abs($close-Ref($close, 1)), %d)+1e-12)" % (d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["SUMN%d" % d for d in windows]
|
||||
if use("SUMD"):
|
||||
# The diff ratio between total gain and total lose
|
||||
# Similar to RSI indicator. https://www.investopedia.com/terms/r/rsi.asp
|
||||
fields += [
|
||||
"(Sum(Greater($close-Ref($close, 1), 0), %d)-Sum(Greater(Ref($close, 1)-$close, 0), %d))"
|
||||
"/(Sum(Abs($close-Ref($close, 1)), %d)+1e-12)" % (d, d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["SUMD%d" % d for d in windows]
|
||||
if use("VMA"):
|
||||
# Simple Volume Moving average: https://www.barchart.com/education/technical-indicators/volume_moving_average
|
||||
fields += ["Mean($volume, %d)/($volume+1e-12)" % d for d in windows]
|
||||
names += ["VMA%d" % d for d in windows]
|
||||
if use("VSTD"):
|
||||
# The standard deviation for volume in past d days.
|
||||
fields += ["Std($volume, %d)/($volume+1e-12)" % d for d in windows]
|
||||
names += ["VSTD%d" % d for d in windows]
|
||||
if use("WVMA"):
|
||||
# The volume weighted price change volatility
|
||||
fields += [
|
||||
"Std(Abs($close/Ref($close, 1)-1)*$volume, %d)/(Mean(Abs($close/Ref($close, 1)-1)*$volume, %d)+1e-12)"
|
||||
% (d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["WVMA%d" % d for d in windows]
|
||||
if use("VSUMP"):
|
||||
# The total volume increase / the absolute total volume changed
|
||||
fields += [
|
||||
"Sum(Greater($volume-Ref($volume, 1), 0), %d)/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)"
|
||||
% (d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["VSUMP%d" % d for d in windows]
|
||||
if use("VSUMN"):
|
||||
# The total volume increase / the absolute total volume changed
|
||||
# Can be derived from VSUMP by VSUMN = 1 - VSUMP
|
||||
fields += [
|
||||
"Sum(Greater(Ref($volume, 1)-$volume, 0), %d)/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)"
|
||||
% (d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["VSUMN%d" % d for d in windows]
|
||||
if use("VSUMD"):
|
||||
# The diff ratio between total volume increase and total volume decrease
|
||||
# RSI indicator for volume
|
||||
fields += [
|
||||
"(Sum(Greater($volume-Ref($volume, 1), 0), %d)-Sum(Greater(Ref($volume, 1)-$volume, 0), %d))"
|
||||
"/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)" % (d, d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["VSUMD%d" % d for d in windows]
|
||||
|
||||
return fields, names
|
||||
@@ -3,6 +3,7 @@ Here is a batch of evaluation functions.
|
||||
|
||||
The interface should be redesigned carefully in the future.
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
from typing import Tuple
|
||||
from qlib import get_module_logger
|
||||
|
||||
@@ -243,7 +243,7 @@ class MetaDatasetDS(MetaTaskDataset):
|
||||
trunc_days: int = None,
|
||||
rolling_ext_days: int = 0,
|
||||
exp_name: Union[str, InternalData],
|
||||
segments: Union[Dict[Text, Tuple], float],
|
||||
segments: Union[Dict[Text, Tuple], float, str],
|
||||
hist_step_n: int = 10,
|
||||
task_mode: str = MetaTask.PROC_MODE_FULL,
|
||||
fill_method: str = "max",
|
||||
@@ -271,12 +271,16 @@ class MetaDatasetDS(MetaTaskDataset):
|
||||
- str: the name of the experiment to store the performance of data
|
||||
- InternalData: a prepared internal data
|
||||
segments: Union[Dict[Text, Tuple], float]
|
||||
the segments to divide data
|
||||
both left and right
|
||||
if the segment is a Dict
|
||||
the segments to divide data
|
||||
both left and right are included
|
||||
if segments is a float:
|
||||
the float represents the percentage of data for training
|
||||
if segments is a string:
|
||||
it will try its best to put its data in training and ensure that the date `segments` is in the test set
|
||||
hist_step_n: int
|
||||
length of historical steps for the meta infomation
|
||||
Number of steps of the data similarity information
|
||||
task_mode : str
|
||||
Please refer to the docs of MetaTask
|
||||
"""
|
||||
@@ -383,10 +387,30 @@ class MetaDatasetDS(MetaTaskDataset):
|
||||
if isinstance(self.segments, float):
|
||||
train_task_n = int(len(self.meta_task_l) * self.segments)
|
||||
if segment == "train":
|
||||
return self.meta_task_l[:train_task_n]
|
||||
train_tasks = self.meta_task_l[:train_task_n]
|
||||
get_module_logger("MetaDatasetDS").info(f"The first train meta task: {train_tasks[0]}")
|
||||
return train_tasks
|
||||
elif segment == "test":
|
||||
return self.meta_task_l[train_task_n:]
|
||||
test_tasks = self.meta_task_l[train_task_n:]
|
||||
get_module_logger("MetaDatasetDS").info(f"The first test meta task: {test_tasks[0]}")
|
||||
return test_tasks
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
elif isinstance(self.segments, str):
|
||||
train_tasks = []
|
||||
test_tasks = []
|
||||
for t in self.meta_task_l:
|
||||
test_end = t.task["dataset"]["kwargs"]["segments"]["test"][1]
|
||||
if test_end is None or pd.Timestamp(test_end) < pd.Timestamp(self.segments):
|
||||
train_tasks.append(t)
|
||||
else:
|
||||
test_tasks.append(t)
|
||||
get_module_logger("MetaDatasetDS").info(f"The first train meta task: {train_tasks[0]}")
|
||||
get_module_logger("MetaDatasetDS").info(f"The first test meta task: {test_tasks[0]}")
|
||||
if segment == "train":
|
||||
return train_tasks
|
||||
elif segment == "test":
|
||||
return test_tasks
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
|
||||
@@ -53,7 +53,12 @@ class MetaModelDS(MetaTaskModel):
|
||||
max_epoch=100,
|
||||
seed=43,
|
||||
alpha=0.0,
|
||||
loss_skip_thresh=50,
|
||||
):
|
||||
"""
|
||||
loss_skip_size: int
|
||||
The number of threshold to skip the loss calculation for each day.
|
||||
"""
|
||||
self.step = step
|
||||
self.hist_step_n = hist_step_n
|
||||
self.clip_method = clip_method
|
||||
@@ -63,6 +68,7 @@ class MetaModelDS(MetaTaskModel):
|
||||
self.max_epoch = max_epoch
|
||||
self.fitted = False
|
||||
self.alpha = alpha
|
||||
self.loss_skip_thresh = loss_skip_thresh
|
||||
torch.manual_seed(seed)
|
||||
|
||||
def run_epoch(self, phase, task_list, epoch, opt, loss_l, ignore_weight=False):
|
||||
@@ -88,12 +94,14 @@ class MetaModelDS(MetaTaskModel):
|
||||
criterion = nn.MSELoss()
|
||||
loss = criterion(pred, meta_input["y_test"])
|
||||
elif self.criterion == "ic_loss":
|
||||
criterion = ICLoss()
|
||||
criterion = ICLoss(self.loss_skip_thresh)
|
||||
try:
|
||||
loss = criterion(pred, meta_input["y_test"], meta_input["test_idx"], skip_size=50)
|
||||
loss = criterion(pred, meta_input["y_test"], meta_input["test_idx"])
|
||||
except ValueError as e:
|
||||
get_module_logger("MetaModelDS").warning(f"Exception `{e}` when calculating IC loss")
|
||||
continue
|
||||
else:
|
||||
raise ValueError(f"Unknown criterion: {self.criterion}")
|
||||
|
||||
assert not np.isnan(loss.detach().item()), "NaN loss!"
|
||||
|
||||
|
||||
@@ -10,7 +10,11 @@ from qlib.log import get_module_logger
|
||||
|
||||
|
||||
class ICLoss(nn.Module):
|
||||
def forward(self, pred, y, idx, skip_size=50):
|
||||
def __init__(self, skip_size=50):
|
||||
super().__init__()
|
||||
self.skip_size = skip_size
|
||||
|
||||
def forward(self, pred, y, idx):
|
||||
"""forward.
|
||||
FIXME:
|
||||
- Some times it will be a slightly different from the result from `pandas.corr()`
|
||||
@@ -33,7 +37,7 @@ class ICLoss(nn.Module):
|
||||
skip_n = 0
|
||||
for start_i, end_i in zip(diff_point, diff_point[1:]):
|
||||
pred_focus = pred[start_i:end_i] # TODO: just for fake
|
||||
if pred_focus.shape[0] < skip_size:
|
||||
if pred_focus.shape[0] < self.skip_size:
|
||||
# skip some days which have very small amount of stock.
|
||||
skip_n += 1
|
||||
continue
|
||||
@@ -50,6 +54,7 @@ class ICLoss(nn.Module):
|
||||
)
|
||||
ic_all += ic_day
|
||||
if len(diff_point) - 1 - skip_n <= 0:
|
||||
__import__("ipdb").set_trace()
|
||||
raise ValueError("No enough data for calculating IC")
|
||||
if skip_n > 0:
|
||||
get_module_logger("ICLoss").info(
|
||||
|
||||
@@ -63,6 +63,7 @@ class LinearModel(Model):
|
||||
df_train = pd.concat([df_train, df_valid])
|
||||
except KeyError:
|
||||
get_module_logger("LinearModel").info("include_valid=True, but valid does not exist")
|
||||
df_train = df_train.dropna()
|
||||
if df_train.empty:
|
||||
raise ValueError("Empty data from dataset, please check your dataset config.")
|
||||
if reweighter is not None:
|
||||
|
||||
@@ -160,6 +160,10 @@ class ALSTM(Model):
|
||||
|
||||
if self.metric in ("", "loss"):
|
||||
return -self.loss_fn(pred[mask], label[mask])
|
||||
elif self.metric == "mse":
|
||||
mask = ~torch.isnan(label)
|
||||
weight = torch.ones_like(label)
|
||||
return -self.mse(pred[mask], label[mask], weight[mask])
|
||||
|
||||
raise ValueError("unknown metric `%s`" % self.metric)
|
||||
|
||||
|
||||
@@ -1,25 +1,25 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
import copy
|
||||
from typing import Text, Union
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Text, Union
|
||||
import copy
|
||||
from ...utils import get_or_create_path
|
||||
from ...log import get_module_logger
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
|
||||
from .pytorch_utils import count_parameters
|
||||
from ...model.base import Model
|
||||
from qlib.workflow import R
|
||||
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
from ...log import get_module_logger
|
||||
from ...model.base import Model
|
||||
from ...utils import get_or_create_path
|
||||
from .pytorch_utils import count_parameters
|
||||
|
||||
|
||||
class GRU(Model):
|
||||
@@ -212,16 +212,31 @@ class GRU(Model):
|
||||
evals_result=dict(),
|
||||
save_path=None,
|
||||
):
|
||||
df_train, df_valid, df_test = dataset.prepare(
|
||||
["train", "valid", "test"],
|
||||
col_set=["feature", "label"],
|
||||
data_key=DataHandlerLP.DK_L,
|
||||
)
|
||||
if df_train.empty or df_valid.empty:
|
||||
raise ValueError("Empty data from dataset, please check your dataset config.")
|
||||
# prepare training and validation data
|
||||
dfs = {
|
||||
k: dataset.prepare(
|
||||
k,
|
||||
col_set=["feature", "label"],
|
||||
data_key=DataHandlerLP.DK_L,
|
||||
)
|
||||
for k in ["train", "valid"]
|
||||
if k in dataset.segments
|
||||
}
|
||||
df_train, df_valid = dfs.get("train", pd.DataFrame()), dfs.get("valid", pd.DataFrame())
|
||||
|
||||
# check if training data is empty
|
||||
if df_train.empty:
|
||||
raise ValueError("Empty training data from dataset, please check your dataset config.")
|
||||
|
||||
df_train = df_train.dropna()
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
|
||||
# check if validation data is provided
|
||||
if not df_valid.empty:
|
||||
df_valid = df_valid.dropna()
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
else:
|
||||
x_valid, y_valid = None, None
|
||||
|
||||
save_path = get_or_create_path(save_path)
|
||||
stop_steps = 0
|
||||
@@ -235,32 +250,42 @@ class GRU(Model):
|
||||
self.logger.info("training...")
|
||||
self.fitted = True
|
||||
|
||||
best_param = copy.deepcopy(self.gru_model.state_dict())
|
||||
for step in range(self.n_epochs):
|
||||
self.logger.info("Epoch%d:", step)
|
||||
self.logger.info("training...")
|
||||
self.train_epoch(x_train, y_train)
|
||||
self.logger.info("evaluating...")
|
||||
train_loss, train_score = self.test_epoch(x_train, y_train)
|
||||
val_loss, val_score = self.test_epoch(x_valid, y_valid)
|
||||
self.logger.info("train %.6f, valid %.6f" % (train_score, val_score))
|
||||
evals_result["train"].append(train_score)
|
||||
evals_result["valid"].append(val_score)
|
||||
|
||||
if val_score > best_score:
|
||||
best_score = val_score
|
||||
stop_steps = 0
|
||||
best_epoch = step
|
||||
best_param = copy.deepcopy(self.gru_model.state_dict())
|
||||
else:
|
||||
stop_steps += 1
|
||||
if stop_steps >= self.early_stop:
|
||||
self.logger.info("early stop")
|
||||
break
|
||||
# evaluate on validation data if provided
|
||||
if x_valid is not None and y_valid is not None:
|
||||
val_loss, val_score = self.test_epoch(x_valid, y_valid)
|
||||
self.logger.info("train %.6f, valid %.6f" % (train_score, val_score))
|
||||
evals_result["valid"].append(val_score)
|
||||
|
||||
if val_score > best_score:
|
||||
best_score = val_score
|
||||
stop_steps = 0
|
||||
best_epoch = step
|
||||
best_param = copy.deepcopy(self.gru_model.state_dict())
|
||||
else:
|
||||
stop_steps += 1
|
||||
if stop_steps >= self.early_stop:
|
||||
self.logger.info("early stop")
|
||||
break
|
||||
|
||||
self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
|
||||
self.gru_model.load_state_dict(best_param)
|
||||
torch.save(best_param, save_path)
|
||||
|
||||
# Logging
|
||||
rec = R.get_recorder()
|
||||
for k, v_l in evals_result.items():
|
||||
for i, v in enumerate(v_l):
|
||||
rec.log_metrics(step=i, **{k: v})
|
||||
|
||||
if self.use_gpu:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@@ -292,6 +317,7 @@ class GRU(Model):
|
||||
|
||||
|
||||
class GRUModel(nn.Module):
|
||||
|
||||
def __init__(self, d_feat=6, hidden_size=64, num_layers=2, dropout=0.0):
|
||||
super().__init__()
|
||||
|
||||
|
||||
@@ -511,7 +511,6 @@ class TRAModel(Model):
|
||||
|
||||
|
||||
class RNN(nn.Module):
|
||||
|
||||
"""RNN Model
|
||||
|
||||
Args:
|
||||
@@ -601,7 +600,6 @@ class PositionalEncoding(nn.Module):
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
|
||||
"""Transformer Model
|
||||
|
||||
Args:
|
||||
@@ -649,7 +647,6 @@ class Transformer(nn.Module):
|
||||
|
||||
|
||||
class TRA(nn.Module):
|
||||
|
||||
"""Temporal Routing Adaptor (TRA)
|
||||
|
||||
TRA takes historical prediction errors & latent representation as inputs,
|
||||
|
||||
@@ -1,5 +1,17 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
"""
|
||||
Here we have a comprehensive set of analysis classes.
|
||||
|
||||
Here is an example.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from qlib.contrib.report.data.ana import FeaMeanStd
|
||||
fa = FeaMeanStd(ret_df)
|
||||
fa.plot_all(wspace=0.3, sub_figsize=(12, 3), col_n=5)
|
||||
|
||||
"""
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from qlib.contrib.report.data.base import FeaAnalyser
|
||||
@@ -152,6 +164,7 @@ class FeaSkewTurt(NumFeaAnalyser):
|
||||
self._kurt[col].plot(ax=right_ax, label="kurt", color="green")
|
||||
right_ax.set_xlabel("")
|
||||
right_ax.set_ylabel("kurt")
|
||||
right_ax.grid(None) # set the grid to None to avoid two layer of grid
|
||||
|
||||
h1, l1 = ax.get_legend_handles_labels()
|
||||
h2, l2 = right_ax.get_legend_handles_labels()
|
||||
@@ -171,12 +184,15 @@ class FeaMeanStd(NumFeaAnalyser):
|
||||
ax.set_xlabel("")
|
||||
ax.set_ylabel("mean")
|
||||
ax.legend()
|
||||
ax.tick_params(axis="x", rotation=90)
|
||||
|
||||
right_ax = ax.twinx()
|
||||
|
||||
self._std[col].plot(ax=right_ax, label="std", color="green")
|
||||
right_ax.set_xlabel("")
|
||||
right_ax.set_ylabel("std")
|
||||
right_ax.tick_params(axis="x", rotation=90)
|
||||
right_ax.grid(None) # set the grid to None to avoid two layer of grid
|
||||
|
||||
h1, l1 = ax.get_legend_handles_labels()
|
||||
h2, l2 = right_ax.get_legend_handles_labels()
|
||||
|
||||
@@ -14,6 +14,24 @@ from qlib.contrib.report.utils import sub_fig_generator
|
||||
|
||||
class FeaAnalyser:
|
||||
def __init__(self, dataset: pd.DataFrame):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
dataset : pd.DataFrame
|
||||
|
||||
We often have multiple columns for dataset. Each column corresponds to one sub figure.
|
||||
There will be a datatime column in the index levels.
|
||||
Aggretation will be used for more summarized metrics overtime.
|
||||
Here is an example of data:
|
||||
|
||||
.. code-block::
|
||||
|
||||
return
|
||||
datetime instrument
|
||||
2007-02-06 equity_tpx 0.010087
|
||||
equity_spx 0.000786
|
||||
"""
|
||||
self._dataset = dataset
|
||||
with TimeInspector.logt("calc_stat_values"):
|
||||
self.calc_stat_values()
|
||||
|
||||
@@ -4,7 +4,7 @@ import matplotlib.pyplot as plt
|
||||
import pandas as pd
|
||||
|
||||
|
||||
def sub_fig_generator(sub_fs=(3, 3), col_n=10, row_n=1, wspace=None, hspace=None, sharex=False, sharey=False):
|
||||
def sub_fig_generator(sub_figsize=(3, 3), col_n=10, row_n=1, wspace=None, hspace=None, sharex=False, sharey=False):
|
||||
"""sub_fig_generator.
|
||||
it will return a generator, each row contains <col_n> sub graph
|
||||
|
||||
@@ -13,7 +13,7 @@ def sub_fig_generator(sub_fs=(3, 3), col_n=10, row_n=1, wspace=None, hspace=None
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sub_fs :
|
||||
sub_figsize :
|
||||
the figure size of each subgraph in <col_n> * <row_n> subgraphs
|
||||
col_n :
|
||||
the number of subgraph in each row; It will generating a new graph after generating <col_n> of subgraphs.
|
||||
@@ -33,7 +33,7 @@ def sub_fig_generator(sub_fs=(3, 3), col_n=10, row_n=1, wspace=None, hspace=None
|
||||
|
||||
while True:
|
||||
fig, axes = plt.subplots(
|
||||
row_n, col_n, figsize=(sub_fs[0] * col_n, sub_fs[1] * row_n), sharex=sharex, sharey=sharey
|
||||
row_n, col_n, figsize=(sub_figsize[0] * col_n, sub_figsize[1] * row_n), sharex=sharex, sharey=sharey
|
||||
)
|
||||
plt.subplots_adjust(wspace=wspace, hspace=hspace)
|
||||
axes = axes.reshape(row_n, col_n)
|
||||
|
||||
@@ -73,8 +73,8 @@ class Rolling:
|
||||
The horizon of the prediction target.
|
||||
This is used to override the prediction horizon of the file.
|
||||
h_path : Optional[str]
|
||||
the dumped data handler;
|
||||
It may come from other data source. It will override the data handler in the config.
|
||||
It is other data source that is dumped as a handler. It will override the data handler section in the config.
|
||||
If it is not given, it will create a customized cache for the handler when `enable_handler_cache=True`
|
||||
test_end : Optional[str]
|
||||
the test end for the data. It is typically used together with the handler
|
||||
You can do the same thing with task_ext_conf in a more complicated way
|
||||
@@ -119,7 +119,7 @@ class Rolling:
|
||||
with self.conf_path.open("r") as f:
|
||||
return yaml.safe_load(f)
|
||||
|
||||
def _replace_hanler_with_cache(self, task: dict):
|
||||
def _replace_handler_with_cache(self, task: dict):
|
||||
"""
|
||||
Due to the data processing part in original rolling is slow. So we have to
|
||||
This class tries to add more feature
|
||||
@@ -159,13 +159,20 @@ class Rolling:
|
||||
# - get horizon automatically from the expression!!!!
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
else:
|
||||
self.logger.info("The prediction horizon is overrided")
|
||||
task["dataset"]["kwargs"]["handler"]["kwargs"]["label"] = [
|
||||
"Ref($close, -{}) / Ref($close, -1) - 1".format(self.horizon + 1)
|
||||
]
|
||||
if enable_handler_cache and self.h_path is not None:
|
||||
self.logger.info("Fail to override the horizon due to data handler cache")
|
||||
else:
|
||||
self.logger.info("The prediction horizon is overrided")
|
||||
if isinstance(task["dataset"]["kwargs"]["handler"], dict):
|
||||
task["dataset"]["kwargs"]["handler"]["kwargs"]["label"] = [
|
||||
"Ref($close, -{}) / Ref($close, -1) - 1".format(self.horizon + 1)
|
||||
]
|
||||
else:
|
||||
self.logger.warning("Try to automatically configure the lablel but failed.")
|
||||
|
||||
if enable_handler_cache:
|
||||
task = self._replace_hanler_with_cache(task)
|
||||
if self.h_path is not None or enable_handler_cache:
|
||||
# if we already have provided data source or we want to create one
|
||||
task = self._replace_handler_with_cache(task)
|
||||
task = self._update_start_end_time(task)
|
||||
|
||||
if self.task_ext_conf is not None:
|
||||
@@ -173,6 +180,16 @@ class Rolling:
|
||||
self.logger.info(task)
|
||||
return task
|
||||
|
||||
def run_basic_task(self):
|
||||
"""
|
||||
Run the basic task without rolling.
|
||||
This is for fast testing for model tunning.
|
||||
"""
|
||||
task = self.basic_task()
|
||||
print(task)
|
||||
trainer = TrainerR(experiment_name=self.exp_name)
|
||||
trainer([task])
|
||||
|
||||
def get_task_list(self) -> List[dict]:
|
||||
"""return a batch of tasks for rolling."""
|
||||
task = self.basic_task()
|
||||
|
||||
@@ -80,6 +80,11 @@ class DDGDA(Rolling):
|
||||
sim_task_model: UTIL_MODEL_TYPE = "gbdt",
|
||||
meta_1st_train_end: Optional[str] = None,
|
||||
alpha: float = 0.01,
|
||||
loss_skip_thresh: int = 50,
|
||||
fea_imp_n: Optional[int] = 30,
|
||||
meta_data_proc: Optional[str] = "V01",
|
||||
segments: Union[float, str] = 0.62,
|
||||
hist_step_n: int = 30,
|
||||
working_dir: Optional[Union[str, Path]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
@@ -94,6 +99,15 @@ class DDGDA(Rolling):
|
||||
alpha: float
|
||||
Setting the L2 regularization for ridge
|
||||
The `alpha` is only passed to MetaModelDS (it is not passed to sim_task_model currently..)
|
||||
loss_skip_thresh: int
|
||||
The thresh to skip the loss calculation for each day. If the number of item is less than it, it will skip the loss on that day.
|
||||
meta_data_proc : Optional[str]
|
||||
How we process the meta dataset for learning meta model.
|
||||
segments : Union[float, str]
|
||||
if segments is a float:
|
||||
The ratio of training data in the meta task dataset
|
||||
if segments is a string:
|
||||
it will try its best to put its data in training and ensure that the date `segments` is in the test set
|
||||
"""
|
||||
# NOTE:
|
||||
# the horizon must match the meaning in the base task template
|
||||
@@ -104,14 +118,22 @@ class DDGDA(Rolling):
|
||||
super().__init__(**kwargs)
|
||||
self.working_dir = self.conf_path.parent if working_dir is None else Path(working_dir)
|
||||
self.proxy_hd = self.working_dir / "handler_proxy.pkl"
|
||||
self.fea_imp_n = fea_imp_n
|
||||
self.meta_data_proc = meta_data_proc
|
||||
self.loss_skip_thresh = loss_skip_thresh
|
||||
self.segments = segments
|
||||
self.hist_step_n = hist_step_n
|
||||
|
||||
def _adjust_task(self, task: dict, astype: UTIL_MODEL_TYPE):
|
||||
"""
|
||||
some task are use for special purpose.
|
||||
Base on the original task, we need to do some extra things.
|
||||
|
||||
For example:
|
||||
- GBDT for calculating feature importance
|
||||
- Linear or GBDT for calculating similarity
|
||||
- Datset (well processed) that aligned to Linear that for meta learning
|
||||
|
||||
So we may need to change the dataset and model for the special purpose and other settings remains the same.
|
||||
"""
|
||||
# NOTE: here is just for aligning with previous implementation
|
||||
# It is not necessary for the current implementation
|
||||
@@ -119,12 +141,16 @@ class DDGDA(Rolling):
|
||||
if astype == "gbdt":
|
||||
task["model"] = LGBM_MODEL
|
||||
if isinstance(handler, dict):
|
||||
# We don't need preprocessing when using GBDT model
|
||||
for k in ["infer_processors", "learn_processors"]:
|
||||
if k in handler.setdefault("kwargs", {}):
|
||||
handler["kwargs"].pop(k)
|
||||
elif astype == "linear":
|
||||
task["model"] = LINEAR_MODEL
|
||||
handler["kwargs"].update(PROC_ARGS)
|
||||
if isinstance(handler, dict):
|
||||
handler["kwargs"].update(PROC_ARGS)
|
||||
else:
|
||||
self.logger.warning("The handler can't be adjusted.")
|
||||
else:
|
||||
raise ValueError(f"astype not supported: {astype}")
|
||||
return task
|
||||
@@ -155,12 +181,15 @@ class DDGDA(Rolling):
|
||||
The meta model will be trained upon the proxy forecasting model.
|
||||
This dataset is for the proxy forecasting model.
|
||||
"""
|
||||
topk = 30
|
||||
fi = self._get_feature_importance()
|
||||
col_selected = fi.nlargest(topk)
|
||||
|
||||
# NOTE: adjusting to `self.sim_task_model` just for aligning with previous implementation.
|
||||
# In previous version. The data for proxy model is using sim_task_model's way for processing
|
||||
task = self._adjust_task(self.basic_task(enable_handler_cache=False), self.sim_task_model)
|
||||
task = replace_task_handler_with_cache(task, self.working_dir)
|
||||
# if self.meta_data_proc is not None:
|
||||
# else:
|
||||
# # Otherwise, we don't need futher processing
|
||||
# task = self.basic_task()
|
||||
|
||||
dataset = init_instance_by_config(task["dataset"])
|
||||
prep_ds = dataset.prepare(slice(None), col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
|
||||
@@ -168,12 +197,18 @@ class DDGDA(Rolling):
|
||||
feature_df = prep_ds["feature"]
|
||||
label_df = prep_ds["label"]
|
||||
|
||||
feature_selected = feature_df.loc[:, col_selected.index]
|
||||
if self.fea_imp_n is not None:
|
||||
fi = self._get_feature_importance()
|
||||
col_selected = fi.nlargest(self.fea_imp_n)
|
||||
feature_selected = feature_df.loc[:, col_selected.index]
|
||||
else:
|
||||
feature_selected = feature_df
|
||||
|
||||
feature_selected = feature_selected.groupby("datetime", group_keys=False).apply(
|
||||
lambda df: (df - df.mean()).div(df.std())
|
||||
)
|
||||
feature_selected = feature_selected.fillna(0.0)
|
||||
if self.meta_data_proc == "V01":
|
||||
feature_selected = feature_selected.groupby("datetime", group_keys=False).apply(
|
||||
lambda df: (df - df.mean()).div(df.std())
|
||||
)
|
||||
feature_selected = feature_selected.fillna(0.0)
|
||||
|
||||
df_all = {
|
||||
"label": label_df.reindex(feature_selected.index),
|
||||
@@ -223,7 +258,10 @@ class DDGDA(Rolling):
|
||||
# 1) leverage the simplified proxy forecasting model to train meta model.
|
||||
# - Only the dataset part is important, in current version of meta model will integrate the
|
||||
|
||||
# the train_start for training meta model does not necessarily align with final rolling
|
||||
# NOTE:
|
||||
# - The train_start for training meta model does not necessarily align with final rolling
|
||||
# But please select a right time to make sure the finnal rolling tasks are not leaked in the training data.
|
||||
# - The test_start is automatically aligned to the next day of test_end. Validation is ignored.
|
||||
train_start = "2008-01-01" if self.train_start is None else self.train_start
|
||||
train_end = "2010-12-31" if self.meta_1st_train_end is None else self.meta_1st_train_end
|
||||
test_start = (pd.Timestamp(train_end) + pd.Timedelta(days=1)).strftime("%Y-%m-%d")
|
||||
@@ -249,9 +287,9 @@ class DDGDA(Rolling):
|
||||
kwargs = dict(
|
||||
task_tpl=proxy_forecast_model_task,
|
||||
step=self.step,
|
||||
segments=0.62, # keep test period consistent with the dataset yaml
|
||||
segments=self.segments, # keep test period consistent with the dataset yaml
|
||||
trunc_days=1 + self.horizon,
|
||||
hist_step_n=30,
|
||||
hist_step_n=self.hist_step_n,
|
||||
fill_method=fill_method,
|
||||
rolling_ext_days=0,
|
||||
)
|
||||
@@ -268,7 +306,13 @@ class DDGDA(Rolling):
|
||||
with R.start(experiment_name=self.meta_exp_name):
|
||||
R.log_params(**kwargs)
|
||||
mm = MetaModelDS(
|
||||
step=self.step, hist_step_n=kwargs["hist_step_n"], lr=0.001, max_epoch=30, seed=43, alpha=self.alpha
|
||||
step=self.step,
|
||||
hist_step_n=kwargs["hist_step_n"],
|
||||
lr=0.001,
|
||||
max_epoch=30,
|
||||
seed=43,
|
||||
alpha=self.alpha,
|
||||
loss_skip_thresh=self.loss_skip_thresh,
|
||||
)
|
||||
mm.fit(md)
|
||||
R.save_objects(model=mm)
|
||||
|
||||
@@ -373,7 +373,6 @@ class WeightStrategyBase(BaseSignalStrategy):
|
||||
|
||||
|
||||
class EnhancedIndexingStrategy(WeightStrategyBase):
|
||||
|
||||
"""Enhanced Indexing Strategy
|
||||
|
||||
Enhanced indexing combines the arts of active management and passive management,
|
||||
|
||||
@@ -35,7 +35,7 @@ class Client:
|
||||
def connect_server(self):
|
||||
"""Connect to server."""
|
||||
try:
|
||||
self.sio.connect("ws://" + self.server_host + ":" + str(self.server_port))
|
||||
self.sio.connect(f"ws://{self.server_host}:{self.server_port}")
|
||||
except socketio.exceptions.ConnectionError:
|
||||
self.logger.error("Cannot connect to server - check your network or server status")
|
||||
|
||||
|
||||
@@ -536,7 +536,6 @@ class DatasetProvider(abc.ABC):
|
||||
"""
|
||||
if len(fields) == 0:
|
||||
raise ValueError("fields cannot be empty")
|
||||
fields = fields.copy()
|
||||
column_names = [str(f) for f in fields]
|
||||
return column_names
|
||||
|
||||
@@ -617,7 +616,7 @@ class DatasetProvider(abc.ABC):
|
||||
|
||||
data = pd.DataFrame(obj)
|
||||
if not data.empty and not np.issubdtype(data.index.dtype, np.dtype("M")):
|
||||
# If the underlaying provides the data not in datatime formmat, we'll convert it into datetime format
|
||||
# If the underlaying provides the data not in datetime format, we'll convert it into datetime format
|
||||
_calendar = Cal.calendar(freq=freq)
|
||||
data.index = _calendar[data.index.values.astype(int)]
|
||||
data.index.names = ["datetime"]
|
||||
|
||||
@@ -403,7 +403,7 @@ class TSDataSampler:
|
||||
np.full((1, self.data_arr.shape[1]), np.nan, dtype=self.data_arr.dtype),
|
||||
axis=0,
|
||||
)
|
||||
self.nan_idx = -1 # The last line is all NaN
|
||||
self.nan_idx = len(self.data_arr) - 1 # The last line is all NaN; setting it to -1 can cause bug #1716
|
||||
|
||||
# the data type will be changed
|
||||
# The index of usable data is between start_idx and end_idx
|
||||
|
||||
@@ -7,7 +7,7 @@ from pathlib import Path
|
||||
import warnings
|
||||
import pandas as pd
|
||||
|
||||
from typing import Tuple, Union, List
|
||||
from typing import Tuple, Union, List, Dict
|
||||
|
||||
from qlib.data import D
|
||||
from qlib.utils import load_dataset, init_instance_by_config, time_to_slc_point
|
||||
@@ -247,10 +247,14 @@ class StaticDataLoader(DataLoader, Serializable):
|
||||
|
||||
def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame:
|
||||
self._maybe_load_raw_data()
|
||||
|
||||
# 1) Filter by instruments
|
||||
if instruments is None:
|
||||
df = self._data
|
||||
else:
|
||||
df = self._data.loc(axis=0)[:, instruments]
|
||||
|
||||
# 2) Filter by Datetime
|
||||
if start_time is None and end_time is None:
|
||||
return df # NOTE: avoid copy by loc
|
||||
# pd.Timestamp(None) == NaT, use NaT as index can not fetch correct thing, so do not change None.
|
||||
@@ -275,6 +279,55 @@ class StaticDataLoader(DataLoader, Serializable):
|
||||
self._data = self._config
|
||||
|
||||
|
||||
class NestedDataLoader(DataLoader):
|
||||
"""
|
||||
We have multiple DataLoader, we can use this class to combine them.
|
||||
"""
|
||||
|
||||
def __init__(self, dataloader_l: List[Dict], join="left") -> None:
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
dataloader_l : list[dict]
|
||||
A list of dataloader, for exmaple
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
nd = NestedDataLoader(
|
||||
dataloader_l=[
|
||||
{
|
||||
"class": "qlib.contrib.data.loader.Alpha158DL",
|
||||
}, {
|
||||
"class": "qlib.contrib.data.loader.Alpha360DL",
|
||||
"kwargs": {
|
||||
"config": {
|
||||
"label": ( ["Ref($close, -2)/Ref($close, -1) - 1"], ["LABEL0"])
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
)
|
||||
join :
|
||||
it will pass to pd.concat when merging it.
|
||||
"""
|
||||
super().__init__()
|
||||
self.data_loader_l = [
|
||||
(dl if isinstance(dl, DataLoader) else init_instance_by_config(dl)) for dl in dataloader_l
|
||||
]
|
||||
self.join = join
|
||||
|
||||
def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame:
|
||||
df_full = None
|
||||
for dl in self.data_loader_l:
|
||||
df_current = dl.load(instruments, start_time, end_time)
|
||||
if df_full is None:
|
||||
df_full = df_current
|
||||
else:
|
||||
df_full = pd.merge(df_full, df_current, left_index=True, right_index=True, how=self.join)
|
||||
return df_full.sort_index(axis=1)
|
||||
|
||||
|
||||
class DataLoaderDH(DataLoader):
|
||||
"""DataLoaderDH
|
||||
DataLoader based on (D)ata (H)andler
|
||||
|
||||
@@ -318,9 +318,13 @@ class CSZScoreNorm(Processor):
|
||||
# try not modify original dataframe
|
||||
if not isinstance(self.fields_group, list):
|
||||
self.fields_group = [self.fields_group]
|
||||
for g in self.fields_group:
|
||||
cols = get_group_columns(df, g)
|
||||
df[cols] = df[cols].groupby("datetime", group_keys=False).apply(self.zscore_func)
|
||||
# depress warning by references:
|
||||
# https://stackoverflow.com/questions/20625582/how-to-deal-with-settingwithcopywarning-in-pandas
|
||||
# https://pandas.pydata.org/pandas-docs/stable/user_guide/options.html#getting-and-setting-options
|
||||
with pd.option_context("mode.chained_assignment", None):
|
||||
for g in self.fields_group:
|
||||
cols = get_group_columns(df, g)
|
||||
df[cols] = df[cols].groupby("datetime", group_keys=False).apply(self.zscore_func)
|
||||
return df
|
||||
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ if TYPE_CHECKING:
|
||||
from qlib.data.dataset import DataHandler
|
||||
|
||||
|
||||
def get_level_index(df: pd.DataFrame, level=Union[str, int]) -> int:
|
||||
def get_level_index(df: pd.DataFrame, level: Union[str, int]) -> int:
|
||||
"""
|
||||
|
||||
get the level index of `df` given `level`
|
||||
|
||||
@@ -30,7 +30,6 @@ class Ensemble:
|
||||
|
||||
|
||||
class SingleKeyEnsemble(Ensemble):
|
||||
|
||||
"""
|
||||
Extract the object if there is only one key and value in the dict. Make the result more readable.
|
||||
{Only key: Only value} -> Only value
|
||||
@@ -64,7 +63,6 @@ class SingleKeyEnsemble(Ensemble):
|
||||
|
||||
|
||||
class RollingEnsemble(Ensemble):
|
||||
|
||||
"""Merge a dict of rolling dataframe like `prediction` or `IC` into an ensemble.
|
||||
|
||||
NOTE: The values of dict must be pd.DataFrame, and have the index "datetime".
|
||||
|
||||
@@ -51,3 +51,6 @@ class MetaTask:
|
||||
Return the **processed** meta_info
|
||||
"""
|
||||
return self.meta_info
|
||||
|
||||
def __repr__(self):
|
||||
return f"MetaTask(task={self.task}, meta_info={self.meta_info})"
|
||||
|
||||
@@ -247,9 +247,7 @@ class ShrinkCovEstimator(RiskModel):
|
||||
v1 = y.T.dot(z) / t - cov_mkt[:, None] * S
|
||||
roff1 = np.sum(v1 * cov_mkt[:, None].T) / var_mkt - np.sum(np.diag(v1) * cov_mkt) / var_mkt
|
||||
v3 = z.T.dot(z) / t - var_mkt * S
|
||||
roff3 = (
|
||||
np.sum(v3 * np.outer(cov_mkt, cov_mkt)) / var_mkt**2 - np.sum(np.diag(v3) * cov_mkt**2) / var_mkt**2
|
||||
)
|
||||
roff3 = np.sum(v3 * np.outer(cov_mkt, cov_mkt)) / var_mkt**2 - np.sum(np.diag(v3) * cov_mkt**2) / var_mkt**2
|
||||
roff = 2 * roff1 - roff3
|
||||
rho = rdiag + roff
|
||||
|
||||
|
||||
@@ -41,7 +41,7 @@ def _log_task_info(task_config: dict):
|
||||
|
||||
def _exe_task(task_config: dict):
|
||||
rec = R.get_recorder()
|
||||
# model & dataset initiation
|
||||
# model & dataset initialization
|
||||
model: Model = init_instance_by_config(task_config["model"], accept_types=Model)
|
||||
dataset: Dataset = init_instance_by_config(task_config["dataset"], accept_types=Dataset)
|
||||
reweighter: Reweighter = task_config.get("reweighter", None)
|
||||
|
||||
@@ -168,7 +168,9 @@ class TrainingVessel(TrainingVesselBase):
|
||||
self.policy.train()
|
||||
|
||||
with vector_env.collector_guard():
|
||||
collector = Collector(self.policy, vector_env, VectorReplayBuffer(self.buffer_size, len(vector_env)))
|
||||
collector = Collector(
|
||||
self.policy, vector_env, VectorReplayBuffer(self.buffer_size, len(vector_env)), exploration_noise=True
|
||||
)
|
||||
|
||||
# Number of episodes collected in each training iteration can be overridden by fast dev run.
|
||||
if self.trainer.fast_dev_run is not None:
|
||||
|
||||
@@ -12,15 +12,11 @@ import datetime
|
||||
from tqdm import tqdm
|
||||
from pathlib import Path
|
||||
from loguru import logger
|
||||
from cryptography.fernet import Fernet
|
||||
from qlib.utils import exists_qlib_data
|
||||
|
||||
|
||||
class GetData:
|
||||
REMOTE_URL = "https://qlibpublic.blob.core.windows.net/data/default/stock_data"
|
||||
# "?" is not included in the token.
|
||||
TOKEN = b"gAAAAABkmDhojHc0VSCDdNK1MqmRzNLeDFXe5hy8obHpa6SDQh4de6nW5gtzuD-fa6O_WZb0yyqYOL7ndOfJX_751W3xN5YB4-n-P22jK-t6ucoZqhT70KPD0Lf0_P328QPJVZ1gDnjIdjhi2YLOcP4BFTHLNYO0mvzszR8TKm9iT5AKRvuysWnpi8bbYwGU9zAcJK3x9EPL43hOGtxliFHcPNGMBoJW4g_ercdhi0-Qgv5_JLsV-29_MV-_AhuaYvJuN2dEywBy"
|
||||
KEY = "EYcA8cgorA8X9OhyMwVfuFxn_1W3jGk6jCbs3L2oPoA="
|
||||
REMOTE_URL = "https://github.com/SunsetWolf/qlib_dataset/releases/download"
|
||||
|
||||
def __init__(self, delete_zip_file=False):
|
||||
"""
|
||||
@@ -33,9 +29,45 @@ class GetData:
|
||||
self.delete_zip_file = delete_zip_file
|
||||
|
||||
def merge_remote_url(self, file_name: str):
|
||||
fernet = Fernet(self.KEY)
|
||||
token = fernet.decrypt(self.TOKEN).decode()
|
||||
return f"{self.REMOTE_URL}/{file_name}?{token}"
|
||||
"""
|
||||
Generate download links.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
file_name: str
|
||||
The name of the file to be downloaded.
|
||||
The file name can be accompanied by a version number, (e.g.: v2/qlib_data_simple_cn_1d_latest.zip),
|
||||
if no version number is attached, it will be downloaded from v0 by default.
|
||||
"""
|
||||
return f"{self.REMOTE_URL}/{file_name}" if "/" in file_name else f"{self.REMOTE_URL}/v0/{file_name}"
|
||||
|
||||
def download(self, url: str, target_path: [Path, str]):
|
||||
"""
|
||||
Download a file from the specified url.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
url: str
|
||||
The url of the data.
|
||||
target_path: str
|
||||
The location where the data is saved, including the file name.
|
||||
"""
|
||||
file_name = str(target_path).rsplit("/", maxsplit=1)[-1]
|
||||
resp = requests.get(url, stream=True, timeout=60)
|
||||
resp.raise_for_status()
|
||||
if resp.status_code != 200:
|
||||
raise requests.exceptions.HTTPError()
|
||||
|
||||
chunk_size = 1024
|
||||
logger.warning(
|
||||
f"The data for the example is collected from Yahoo Finance. Please be aware that the quality of the data might not be perfect. (You can refer to the original data source: https://finance.yahoo.com/lookup.)"
|
||||
)
|
||||
logger.info(f"{os.path.basename(file_name)} downloading......")
|
||||
with tqdm(total=int(resp.headers.get("Content-Length", 0))) as p_bar:
|
||||
with target_path.open("wb") as fp:
|
||||
for chunk in resp.iter_content(chunk_size=chunk_size):
|
||||
fp.write(chunk)
|
||||
p_bar.update(chunk_size)
|
||||
|
||||
def download_data(self, file_name: str, target_dir: [Path, str], delete_old: bool = True):
|
||||
"""
|
||||
@@ -70,21 +102,7 @@ class GetData:
|
||||
target_path = target_dir.joinpath(_target_file_name)
|
||||
|
||||
url = self.merge_remote_url(file_name)
|
||||
resp = requests.get(url, stream=True, timeout=60)
|
||||
resp.raise_for_status()
|
||||
if resp.status_code != 200:
|
||||
raise requests.exceptions.HTTPError()
|
||||
|
||||
chunk_size = 1024
|
||||
logger.warning(
|
||||
f"The data for the example is collected from Yahoo Finance. Please be aware that the quality of the data might not be perfect. (You can refer to the original data source: https://finance.yahoo.com/lookup.)"
|
||||
)
|
||||
logger.info(f"{os.path.basename(file_name)} downloading......")
|
||||
with tqdm(total=int(resp.headers.get("Content-Length", 0))) as p_bar:
|
||||
with target_path.open("wb") as fp:
|
||||
for chunk in resp.iter_content(chunk_size=chunk_size):
|
||||
fp.write(chunk)
|
||||
p_bar.update(chunk_size)
|
||||
self.download(url=url, target_path=target_path)
|
||||
|
||||
self._unzip(target_path, target_dir, delete_old)
|
||||
if self.delete_zip_file:
|
||||
@@ -99,7 +117,9 @@ class GetData:
|
||||
return status
|
||||
|
||||
@staticmethod
|
||||
def _unzip(file_path: Path, target_dir: Path, delete_old: bool = True):
|
||||
def _unzip(file_path: [Path, str], target_dir: [Path, str], delete_old: bool = True):
|
||||
file_path = Path(file_path)
|
||||
target_dir = Path(target_dir)
|
||||
if delete_old:
|
||||
logger.warning(
|
||||
f"will delete the old qlib data directory(features, instruments, calendars, features_cache, dataset_cache): {target_dir}"
|
||||
|
||||
@@ -25,7 +25,12 @@ import pandas as pd
|
||||
from pathlib import Path
|
||||
from typing import List, Union, Optional, Callable
|
||||
from packaging import version
|
||||
from .file import get_or_create_path, save_multiple_parts_file, unpack_archive_with_buffer, get_tmp_file_with_buffer
|
||||
from .file import (
|
||||
get_or_create_path,
|
||||
save_multiple_parts_file,
|
||||
unpack_archive_with_buffer,
|
||||
get_tmp_file_with_buffer,
|
||||
)
|
||||
from ..config import C
|
||||
from ..log import get_module_logger, set_log_with_config
|
||||
|
||||
@@ -37,7 +42,12 @@ is_deprecated_lexsorted_pandas = version.parse(pd.__version__) > version.parse("
|
||||
#################### Server ####################
|
||||
def get_redis_connection():
|
||||
"""get redis connection instance."""
|
||||
return redis.StrictRedis(host=C.redis_host, port=C.redis_port, db=C.redis_task_db, password=C.redis_password)
|
||||
return redis.StrictRedis(
|
||||
host=C.redis_host,
|
||||
port=C.redis_port,
|
||||
db=C.redis_task_db,
|
||||
password=C.redis_password,
|
||||
)
|
||||
|
||||
|
||||
#################### Data ####################
|
||||
@@ -96,7 +106,14 @@ def get_period_offset(first_year, period, quarterly):
|
||||
return offset
|
||||
|
||||
|
||||
def read_period_data(index_path, data_path, period, cur_date_int: int, quarterly, last_period_index: int = None):
|
||||
def read_period_data(
|
||||
index_path,
|
||||
data_path,
|
||||
period,
|
||||
cur_date_int: int,
|
||||
quarterly,
|
||||
last_period_index: int = None,
|
||||
):
|
||||
"""
|
||||
At `cur_date`(e.g. 20190102), read the information at `period`(e.g. 201803).
|
||||
Only the updating info before cur_date or at cur_date will be used.
|
||||
@@ -273,7 +290,10 @@ def parse_field(field):
|
||||
# \uff09 -> )
|
||||
chinese_punctuation_regex = r"\u3001\uff1a\uff08\uff09"
|
||||
for pattern, new in [
|
||||
(rf"\$\$([\w{chinese_punctuation_regex}]+)", r'PFeature("\1")'), # $$ must be before $
|
||||
(
|
||||
rf"\$\$([\w{chinese_punctuation_regex}]+)",
|
||||
r'PFeature("\1")',
|
||||
), # $$ must be before $
|
||||
(rf"\$([\w{chinese_punctuation_regex}]+)", r'Feature("\1")'),
|
||||
(r"(\w+\s*)\(", r"Operators.\1("),
|
||||
]: # Features # Operators
|
||||
@@ -383,7 +403,14 @@ def get_date_range(trading_date, left_shift=0, right_shift=0, future=False):
|
||||
return calendar
|
||||
|
||||
|
||||
def get_date_by_shift(trading_date, shift, future=False, clip_shift=True, freq="day", align: Optional[str] = None):
|
||||
def get_date_by_shift(
|
||||
trading_date,
|
||||
shift,
|
||||
future=False,
|
||||
clip_shift=True,
|
||||
freq="day",
|
||||
align: Optional[str] = None,
|
||||
):
|
||||
"""get trading date with shift bias will cur_date
|
||||
e.g. : shift == 1, return next trading date
|
||||
shift == -1, return previous trading date
|
||||
@@ -569,7 +596,38 @@ def exists_qlib_data(qlib_dir):
|
||||
# check instruments
|
||||
code_names = set(map(lambda x: fname_to_code(x.name.lower()), features_dir.iterdir()))
|
||||
_instrument = instruments_dir.joinpath("all.txt")
|
||||
miss_code = set(pd.read_csv(_instrument, sep="\t", header=None).loc[:, 0].apply(str.lower)) - set(code_names)
|
||||
# Removed two possible ticker names "NA" and "NULL" from the default na_values list for column 0
|
||||
miss_code = set(
|
||||
pd.read_csv(
|
||||
_instrument,
|
||||
sep="\t",
|
||||
header=None,
|
||||
keep_default_na=False,
|
||||
na_values={
|
||||
0: [
|
||||
" ",
|
||||
"#N/A",
|
||||
"#N/A N/A",
|
||||
"#NA",
|
||||
"-1.#IND",
|
||||
"-1.#QNAN",
|
||||
"-NaN",
|
||||
"-nan",
|
||||
"1.#IND",
|
||||
"1.#QNAN",
|
||||
"<NA>",
|
||||
"N/A",
|
||||
"NaN",
|
||||
"None",
|
||||
"n/a",
|
||||
"nan",
|
||||
"null ",
|
||||
]
|
||||
},
|
||||
)
|
||||
.loc[:, 0]
|
||||
.apply(str.lower)
|
||||
) - set(code_names)
|
||||
if miss_code and any(map(lambda x: "sht" not in x, miss_code)):
|
||||
return False
|
||||
|
||||
|
||||
@@ -108,6 +108,12 @@ class Index:
|
||||
self.index_map = self.idx_list = np.arange(idx_list)
|
||||
self._is_sorted = True
|
||||
else:
|
||||
# Check if all elements in idx_list are of the same type
|
||||
if not all(isinstance(x, type(idx_list[0])) for x in idx_list):
|
||||
raise TypeError("All elements in idx_list must be of the same type")
|
||||
# Check if all elements in idx_list are of the same datetime64 precision
|
||||
if isinstance(idx_list[0], np.datetime64) and not all(x.dtype == idx_list[0].dtype for x in idx_list):
|
||||
raise TypeError("All elements in idx_list must be of the same datetime64 precision")
|
||||
self.idx_list = np.array(idx_list)
|
||||
# NOTE: only the first appearance is indexed
|
||||
self.index_map = dict(zip(self.idx_list, range(len(self))))
|
||||
@@ -131,7 +137,12 @@ class Index:
|
||||
if self.idx_list.dtype.type is np.datetime64:
|
||||
if isinstance(item, pd.Timestamp):
|
||||
# This happens often when creating index based on pandas.DatetimeIndex and query with pd.Timestamp
|
||||
return item.to_numpy()
|
||||
return item.to_numpy().astype(self.idx_list.dtype)
|
||||
elif isinstance(item, np.datetime64):
|
||||
# This happens often when creating index based on np.datetime64 and query with another precision
|
||||
return item.astype(self.idx_list.dtype)
|
||||
# NOTE: It is hard to consider every case at first.
|
||||
# We just try to cover part of cases to make it more user-friendly
|
||||
return item
|
||||
|
||||
def index(self, item) -> int:
|
||||
|
||||
@@ -161,7 +161,13 @@ def init_instance_by_config(
|
||||
# path like 'file:///<path to pickle file>/obj.pkl'
|
||||
pr = urlparse(config)
|
||||
if pr.scheme == "file":
|
||||
pr_path = os.path.join(pr.netloc, pr.path) if bool(pr.path) else pr.netloc
|
||||
|
||||
# To enable relative path like file://data/a/b/c.pkl. pr.netloc will be data
|
||||
path = pr.path
|
||||
if pr.netloc != "":
|
||||
path = path.lstrip("/")
|
||||
|
||||
pr_path = os.path.join(pr.netloc, path) if bool(pr.path) else pr.netloc
|
||||
with open(os.path.normpath(pr_path), "rb") as f:
|
||||
return pickle.load(f)
|
||||
else:
|
||||
|
||||
@@ -1,18 +1,20 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
import logging
|
||||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
import sys
|
||||
|
||||
import fire
|
||||
from jinja2 import Template, meta
|
||||
import ruamel.yaml as yaml
|
||||
|
||||
import qlib
|
||||
import fire
|
||||
import ruamel.yaml as yaml
|
||||
from qlib.config import C
|
||||
from qlib.model.trainer import task_train
|
||||
from qlib.utils.data import update_config
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.model.trainer import task_train
|
||||
from qlib.utils import set_log_with_config
|
||||
from qlib.utils.data import update_config
|
||||
|
||||
set_log_with_config(C.logging_config)
|
||||
logger = get_module_logger("qrun", logging.INFO)
|
||||
@@ -47,6 +49,39 @@ def sys_config(config, config_path):
|
||||
sys.path.append(str(Path(config_path).parent.resolve().absolute() / p))
|
||||
|
||||
|
||||
def render_template(config_path: str) -> str:
|
||||
"""
|
||||
render the template based on the environment
|
||||
|
||||
Parameters
|
||||
----------
|
||||
config_path : str
|
||||
configuration path
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
the rendered content
|
||||
"""
|
||||
with open(config_path, "r") as f:
|
||||
config = f.read()
|
||||
# Set up the Jinja2 environment
|
||||
template = Template(config)
|
||||
|
||||
# Parse the template to find undeclared variables
|
||||
env = template.environment
|
||||
parsed_content = env.parse(config)
|
||||
variables = meta.find_undeclared_variables(parsed_content)
|
||||
|
||||
# Get context from os.environ according to the variables
|
||||
context = {var: os.getenv(var, "") for var in variables if var in os.environ}
|
||||
logger.info(f"Render the template with the context: {context}")
|
||||
|
||||
# Render the template with the context
|
||||
rendered_content = template.render(context)
|
||||
return rendered_content
|
||||
|
||||
|
||||
# workflow handler function
|
||||
def workflow(config_path, experiment_name="workflow", uri_folder="mlruns"):
|
||||
"""
|
||||
@@ -67,8 +102,9 @@ def workflow(config_path, experiment_name="workflow", uri_folder="mlruns"):
|
||||
market: csi300
|
||||
|
||||
"""
|
||||
with open(config_path) as fp:
|
||||
config = yaml.safe_load(fp)
|
||||
# Render the template
|
||||
rendered_yaml = render_template(config_path)
|
||||
config = yaml.safe_load(rendered_yaml)
|
||||
|
||||
base_config_path = config.get("BASE_CONFIG_PATH", None)
|
||||
if base_config_path:
|
||||
|
||||
@@ -90,7 +90,6 @@ class OnlineStrategy:
|
||||
|
||||
|
||||
class RollingStrategy(OnlineStrategy):
|
||||
|
||||
"""
|
||||
This example strategy always uses the latest rolling model sas online models.
|
||||
"""
|
||||
|
||||
@@ -4,8 +4,10 @@
|
||||
import logging
|
||||
import warnings
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from tqdm import trange
|
||||
from pprint import pprint
|
||||
from typing import Union, List, Optional
|
||||
from typing import Union, List, Optional, Dict
|
||||
|
||||
from qlib.utils.exceptions import LoadObjectError
|
||||
from ..contrib.evaluate import risk_analysis, indicator_analysis
|
||||
@@ -17,6 +19,7 @@ from ..log import get_module_logger
|
||||
from ..utils import fill_placeholder, flatten_dict, class_casting, get_date_by_shift
|
||||
from ..utils.time import Freq
|
||||
from ..utils.data import deepcopy_basic_type
|
||||
from ..utils.exceptions import QlibException
|
||||
from ..contrib.eva.alpha import calc_ic, calc_long_short_return, calc_long_short_prec
|
||||
|
||||
|
||||
@@ -230,9 +233,16 @@ class ACRecordTemp(RecordTemp):
|
||||
except FileNotFoundError:
|
||||
logger.warning("The dependent data does not exists. Generation skipped.")
|
||||
return
|
||||
return self._generate(*args, **kwargs)
|
||||
artifact_dict = self._generate(*args, **kwargs)
|
||||
if isinstance(artifact_dict, dict):
|
||||
self.save(**artifact_dict)
|
||||
return artifact_dict
|
||||
|
||||
def _generate(self, *args, **kwargs):
|
||||
def _generate(self, *args, **kwargs) -> Dict[str, object]:
|
||||
"""
|
||||
Run the concrete generating task, return the dictionary of the generated results.
|
||||
The caller method will save the results to the recorder.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `_generate` method")
|
||||
|
||||
|
||||
@@ -336,8 +346,8 @@ class SigAnaRecord(ACRecordTemp):
|
||||
}
|
||||
)
|
||||
self.recorder.log_metrics(**metrics)
|
||||
self.save(**objects)
|
||||
pprint(metrics)
|
||||
return objects
|
||||
|
||||
def list(self):
|
||||
paths = ["ic.pkl", "ric.pkl"]
|
||||
@@ -468,17 +478,18 @@ class PortAnaRecord(ACRecordTemp):
|
||||
if self.backtest_config["end_time"] is None:
|
||||
self.backtest_config["end_time"] = get_date_by_shift(dt_values.max(), 1)
|
||||
|
||||
artifact_objects = {}
|
||||
# custom strategy and get backtest
|
||||
portfolio_metric_dict, indicator_dict = normal_backtest(
|
||||
executor=self.executor_config, strategy=self.strategy_config, **self.backtest_config
|
||||
)
|
||||
for _freq, (report_normal, positions_normal) in portfolio_metric_dict.items():
|
||||
self.save(**{f"report_normal_{_freq}.pkl": report_normal})
|
||||
self.save(**{f"positions_normal_{_freq}.pkl": positions_normal})
|
||||
artifact_objects.update({f"report_normal_{_freq}.pkl": report_normal})
|
||||
artifact_objects.update({f"positions_normal_{_freq}.pkl": positions_normal})
|
||||
|
||||
for _freq, indicators_normal in indicator_dict.items():
|
||||
self.save(**{f"indicators_normal_{_freq}.pkl": indicators_normal[0]})
|
||||
self.save(**{f"indicators_normal_{_freq}_obj.pkl": indicators_normal[1]})
|
||||
artifact_objects.update({f"indicators_normal_{_freq}.pkl": indicators_normal[0]})
|
||||
artifact_objects.update({f"indicators_normal_{_freq}_obj.pkl": indicators_normal[1]})
|
||||
|
||||
for _analysis_freq in self.risk_analysis_freq:
|
||||
if _analysis_freq not in portfolio_metric_dict:
|
||||
@@ -500,7 +511,7 @@ class PortAnaRecord(ACRecordTemp):
|
||||
analysis_dict = flatten_dict(analysis_df["risk"].unstack().T.to_dict())
|
||||
self.recorder.log_metrics(**{f"{_analysis_freq}.{k}": v for k, v in analysis_dict.items()})
|
||||
# save results
|
||||
self.save(**{f"port_analysis_{_analysis_freq}.pkl": analysis_df})
|
||||
artifact_objects.update({f"port_analysis_{_analysis_freq}.pkl": analysis_df})
|
||||
logger.info(
|
||||
f"Portfolio analysis record 'port_analysis_{_analysis_freq}.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}"
|
||||
)
|
||||
@@ -525,12 +536,13 @@ class PortAnaRecord(ACRecordTemp):
|
||||
analysis_dict = analysis_df["value"].to_dict()
|
||||
self.recorder.log_metrics(**{f"{_analysis_freq}.{k}": v for k, v in analysis_dict.items()})
|
||||
# save results
|
||||
self.save(**{f"indicator_analysis_{_analysis_freq}.pkl": analysis_df})
|
||||
artifact_objects.update({f"indicator_analysis_{_analysis_freq}.pkl": analysis_df})
|
||||
logger.info(
|
||||
f"Indicator analysis record 'indicator_analysis_{_analysis_freq}.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}"
|
||||
)
|
||||
pprint(f"The following are analysis results of indicators({_analysis_freq}).")
|
||||
pprint(analysis_df)
|
||||
return artifact_objects
|
||||
|
||||
def list(self):
|
||||
list_path = []
|
||||
@@ -553,3 +565,124 @@ class PortAnaRecord(ACRecordTemp):
|
||||
else:
|
||||
warnings.warn(f"indicator_analysis freq {_analysis_freq} is not found")
|
||||
return list_path
|
||||
|
||||
|
||||
class MultiPassPortAnaRecord(PortAnaRecord):
|
||||
"""
|
||||
This is the Multiple Pass Portfolio Analysis Record class that run backtest multiple times and generates the analysis results such as those of backtest. This class inherits the ``PortAnaRecord`` class.
|
||||
|
||||
If shuffle_init_score enabled, the prediction score of the first backtest date will be shuffled, so that initial position will be random.
|
||||
The shuffle_init_score will only works when the signal is used as <PRED> placeholder. The placeholder will be replaced by pred.pkl saved in recorder.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
recorder : Recorder
|
||||
The recorder used to save the backtest results.
|
||||
pass_num : int
|
||||
The number of backtest passes.
|
||||
shuffle_init_score : bool
|
||||
Whether to shuffle the prediction score of the first backtest date.
|
||||
"""
|
||||
|
||||
depend_cls = SignalRecord
|
||||
|
||||
def __init__(self, recorder, pass_num=10, shuffle_init_score=True, **kwargs):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
recorder : Recorder
|
||||
The recorder used to save the backtest results.
|
||||
pass_num : int
|
||||
The number of backtest passes.
|
||||
shuffle_init_score : bool
|
||||
Whether to shuffle the prediction score of the first backtest date.
|
||||
"""
|
||||
self.pass_num = pass_num
|
||||
self.shuffle_init_score = shuffle_init_score
|
||||
|
||||
super().__init__(recorder, **kwargs)
|
||||
|
||||
# Save original strategy so that pred df can be replaced in next generate
|
||||
self.original_strategy = deepcopy_basic_type(self.strategy_config)
|
||||
if not isinstance(self.original_strategy, dict):
|
||||
raise QlibException("MultiPassPortAnaRecord require the passed in strategy to be a dict")
|
||||
if "signal" not in self.original_strategy.get("kwargs", {}):
|
||||
raise QlibException("MultiPassPortAnaRecord require the passed in strategy to have signal as a parameter")
|
||||
|
||||
def random_init(self):
|
||||
pred_df = self.load("pred.pkl")
|
||||
|
||||
all_pred_dates = pred_df.index.get_level_values("datetime")
|
||||
bt_start_date = pd.to_datetime(self.backtest_config.get("start_time"))
|
||||
if bt_start_date is None:
|
||||
first_bt_pred_date = all_pred_dates.min()
|
||||
else:
|
||||
first_bt_pred_date = all_pred_dates[all_pred_dates >= bt_start_date].min()
|
||||
|
||||
# Shuffle the first backtest date's pred score
|
||||
first_date_score = pred_df.loc[first_bt_pred_date]["score"]
|
||||
np.random.shuffle(first_date_score.values)
|
||||
|
||||
# Use shuffled signal as the strategy signal
|
||||
self.strategy_config = deepcopy_basic_type(self.original_strategy)
|
||||
self.strategy_config["kwargs"]["signal"] = pred_df
|
||||
|
||||
def _generate(self, **kwargs):
|
||||
risk_analysis_df_map = {}
|
||||
|
||||
# Collect each frequency's analysis df as df list
|
||||
for i in trange(self.pass_num):
|
||||
if self.shuffle_init_score:
|
||||
self.random_init()
|
||||
|
||||
# Not check for cache file list
|
||||
single_run_artifacts = super()._generate(**kwargs)
|
||||
|
||||
for _analysis_freq in self.risk_analysis_freq:
|
||||
risk_analysis_df_list = risk_analysis_df_map.get(_analysis_freq, [])
|
||||
risk_analysis_df_map[_analysis_freq] = risk_analysis_df_list
|
||||
|
||||
analysis_df = single_run_artifacts[f"port_analysis_{_analysis_freq}.pkl"]
|
||||
analysis_df["run_id"] = i
|
||||
risk_analysis_df_list.append(analysis_df)
|
||||
|
||||
result_artifacts = {}
|
||||
# Concat df list
|
||||
for _analysis_freq in self.risk_analysis_freq:
|
||||
combined_df = pd.concat(risk_analysis_df_map[_analysis_freq])
|
||||
|
||||
# Calculate return and information ratio's mean, std and mean/std
|
||||
multi_pass_port_analysis_df = combined_df.groupby(level=[0, 1]).apply(
|
||||
lambda x: pd.Series(
|
||||
{"mean": x["risk"].mean(), "std": x["risk"].std(), "mean_std": x["risk"].mean() / x["risk"].std()}
|
||||
)
|
||||
)
|
||||
|
||||
# Only look at "annualized_return" and "information_ratio"
|
||||
multi_pass_port_analysis_df = multi_pass_port_analysis_df.loc[
|
||||
(slice(None), ["annualized_return", "information_ratio"]), :
|
||||
]
|
||||
pprint(multi_pass_port_analysis_df)
|
||||
|
||||
# Save new df
|
||||
result_artifacts.update({f"multi_pass_port_analysis_{_analysis_freq}.pkl": multi_pass_port_analysis_df})
|
||||
|
||||
# Log metrics
|
||||
metrics = flatten_dict(
|
||||
{
|
||||
"mean": multi_pass_port_analysis_df["mean"].unstack().T.to_dict(),
|
||||
"std": multi_pass_port_analysis_df["std"].unstack().T.to_dict(),
|
||||
"mean_std": multi_pass_port_analysis_df["mean_std"].unstack().T.to_dict(),
|
||||
}
|
||||
)
|
||||
self.recorder.log_metrics(**metrics)
|
||||
return result_artifacts
|
||||
|
||||
def list(self):
|
||||
list_path = []
|
||||
for _analysis_freq in self.risk_analysis_freq:
|
||||
if _analysis_freq in self.all_freq:
|
||||
list_path.append(f"multi_pass_port_analysis_{_analysis_freq}.pkl")
|
||||
else:
|
||||
warnings.warn(f"risk_analysis freq {_analysis_freq} is not found")
|
||||
return list_path
|
||||
|
||||
@@ -242,7 +242,7 @@ class TimeAdjuster:
|
||||
|
||||
def shift(self, seg: tuple, step: int, rtype=SHIFT_SD) -> tuple:
|
||||
"""
|
||||
Shift the datatime of segment
|
||||
Shift the datetime of segment
|
||||
|
||||
If there are None (which indicates unbounded index) in the segment, this method will return None.
|
||||
|
||||
|
||||
81
scripts/data_collector/baostock_5min/README.md
Normal file
81
scripts/data_collector/baostock_5min/README.md
Normal file
@@ -0,0 +1,81 @@
|
||||
## Collector Data
|
||||
|
||||
### Get Qlib data(`bin file`)
|
||||
|
||||
- get data: `python scripts/get_data.py qlib_data`
|
||||
- parameters:
|
||||
- `target_dir`: save dir, by default *~/.qlib/qlib_data/cn_data_5min*
|
||||
- `version`: dataset version, value from [`v2`], by default `v2`
|
||||
- `v2` end date is *2022-12*
|
||||
- `interval`: `5min`
|
||||
- `region`: `hs300`
|
||||
- `delete_old`: delete existing data from `target_dir`(*features, calendars, instruments, dataset_cache, features_cache*), value from [`True`, `False`], by default `True`
|
||||
- `exists_skip`: traget_dir data already exists, skip `get_data`, value from [`True`, `False`], by default `False`
|
||||
- examples:
|
||||
```bash
|
||||
# hs300 5min
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/hs300_data_5min --region hs300 --interval 5min
|
||||
```
|
||||
|
||||
### Collector *Baostock high frequency* data to qlib
|
||||
> collector *Baostock high frequency* data and *dump* into `qlib` format.
|
||||
> If the above ready-made data can't meet users' requirements, users can follow this section to crawl the latest data and convert it to qlib-data.
|
||||
1. download data to csv: `python scripts/data_collector/baostock_5min/collector.py download_data`
|
||||
|
||||
This will download the raw data such as date, symbol, open, high, low, close, volume, amount, adjustflag from baostock to a local directory. One file per symbol.
|
||||
- parameters:
|
||||
- `source_dir`: save the directory
|
||||
- `interval`: `5min`
|
||||
- `region`: `HS300`
|
||||
- `start`: start datetime, by default *None*
|
||||
- `end`: end datetime, by default *None*
|
||||
- examples:
|
||||
```bash
|
||||
# cn 5min data
|
||||
python collector.py download_data --source_dir ~/.qlib/stock_data/source/hs300_5min_original --start 2022-01-01 --end 2022-01-30 --interval 5min --region HS300
|
||||
```
|
||||
2. normalize data: `python scripts/data_collector/baostock_5min/collector.py normalize_data`
|
||||
|
||||
This will:
|
||||
1. Normalize high, low, close, open price using adjclose.
|
||||
2. Normalize the high, low, close, open price so that the first valid trading date's close price is 1.
|
||||
- parameters:
|
||||
- `source_dir`: csv directory
|
||||
- `normalize_dir`: result directory
|
||||
- `interval`: `5min`
|
||||
> if **`interval == 5min`**, `qlib_data_1d_dir` cannot be `None`
|
||||
- `region`: `HS300`
|
||||
- `date_field_name`: column *name* identifying time in csv files, by default `date`
|
||||
- `symbol_field_name`: column *name* identifying symbol in csv files, by default `symbol`
|
||||
- `end_date`: if not `None`, normalize the last date saved (*including end_date*); if `None`, it will ignore this parameter; by default `None`
|
||||
- `qlib_data_1d_dir`: qlib directory(1d data)
|
||||
if interval==5min, qlib_data_1d_dir cannot be None, normalize 5min needs to use 1d data;
|
||||
```
|
||||
# qlib_data_1d can be obtained like this:
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn --version v3
|
||||
```
|
||||
- examples:
|
||||
```bash
|
||||
# normalize 5min cn
|
||||
python collector.py normalize_data --qlib_data_1d_dir ~/.qlib/qlib_data/cn_data --source_dir ~/.qlib/stock_data/source/hs300_5min_original --normalize_dir ~/.qlib/stock_data/source/hs300_5min_nor --region HS300 --interval 5min
|
||||
```
|
||||
3. dump data: `python scripts/dump_bin.py dump_all`
|
||||
|
||||
This will convert the normalized csv in `feature` directory as numpy array and store the normalized data one file per column and one symbol per directory.
|
||||
|
||||
- parameters:
|
||||
- `csv_path`: stock data path or directory, **normalize result(normalize_dir)**
|
||||
- `qlib_dir`: qlib(dump) data director
|
||||
- `freq`: transaction frequency, by default `day`
|
||||
> `freq_map = {1d:day, 5mih: 5min}`
|
||||
- `max_workers`: number of threads, by default *16*
|
||||
- `include_fields`: dump fields, by default `""`
|
||||
- `exclude_fields`: fields not dumped, by default `"""
|
||||
> dump_fields = `include_fields if include_fields else set(symbol_df.columns) - set(exclude_fields) exclude_fields else symbol_df.columns`
|
||||
- `symbol_field_name`: column *name* identifying symbol in csv files, by default `symbol`
|
||||
- `date_field_name`: column *name* identifying time in csv files, by default `date`
|
||||
- examples:
|
||||
```bash
|
||||
# dump 5min cn
|
||||
python dump_bin.py dump_all --csv_path ~/.qlib/stock_data/source/hs300_5min_nor --qlib_dir ~/.qlib/qlib_data/hs300_5min_bin --freq 5min --exclude_fields date,symbol
|
||||
```
|
||||
328
scripts/data_collector/baostock_5min/collector.py
Normal file
328
scripts/data_collector/baostock_5min/collector.py
Normal file
@@ -0,0 +1,328 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
import sys
|
||||
import copy
|
||||
import fire
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import baostock as bs
|
||||
from tqdm import tqdm
|
||||
from pathlib import Path
|
||||
from loguru import logger
|
||||
from typing import Iterable, List
|
||||
|
||||
import qlib
|
||||
from qlib.data import D
|
||||
|
||||
CUR_DIR = Path(__file__).resolve().parent
|
||||
sys.path.append(str(CUR_DIR.parent.parent))
|
||||
|
||||
from data_collector.base import BaseCollector, BaseNormalize, BaseRun
|
||||
from data_collector.utils import generate_minutes_calendar_from_daily, calc_adjusted_price
|
||||
|
||||
|
||||
class BaostockCollectorHS3005min(BaseCollector):
|
||||
def __init__(
|
||||
self,
|
||||
save_dir: [str, Path],
|
||||
start=None,
|
||||
end=None,
|
||||
interval="5min",
|
||||
max_workers=4,
|
||||
max_collector_count=2,
|
||||
delay=0,
|
||||
check_data_length: int = None,
|
||||
limit_nums: int = None,
|
||||
):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
save_dir: str
|
||||
stock save dir
|
||||
max_workers: int
|
||||
workers, default 4
|
||||
max_collector_count: int
|
||||
default 2
|
||||
delay: float
|
||||
time.sleep(delay), default 0
|
||||
interval: str
|
||||
freq, value from [5min], default 5min
|
||||
start: str
|
||||
start datetime, default None
|
||||
end: str
|
||||
end datetime, default None
|
||||
check_data_length: int
|
||||
check data length, by default None
|
||||
limit_nums: int
|
||||
using for debug, by default None
|
||||
"""
|
||||
bs.login()
|
||||
super(BaostockCollectorHS3005min, self).__init__(
|
||||
save_dir=save_dir,
|
||||
start=start,
|
||||
end=end,
|
||||
interval=interval,
|
||||
max_workers=max_workers,
|
||||
max_collector_count=max_collector_count,
|
||||
delay=delay,
|
||||
check_data_length=check_data_length,
|
||||
limit_nums=limit_nums,
|
||||
)
|
||||
|
||||
def get_trade_calendar(self):
|
||||
_format = "%Y-%m-%d"
|
||||
start = self.start_datetime.strftime(_format)
|
||||
end = self.end_datetime.strftime(_format)
|
||||
rs = bs.query_trade_dates(start_date=start, end_date=end)
|
||||
calendar_list = []
|
||||
while (rs.error_code == "0") & rs.next():
|
||||
calendar_list.append(rs.get_row_data())
|
||||
calendar_df = pd.DataFrame(calendar_list, columns=rs.fields)
|
||||
trade_calendar_df = calendar_df[~calendar_df["is_trading_day"].isin(["0"])]
|
||||
return trade_calendar_df["calendar_date"].values
|
||||
|
||||
@staticmethod
|
||||
def process_interval(interval: str):
|
||||
if interval == "1d":
|
||||
return {"interval": "d", "fields": "date,code,open,high,low,close,volume,amount,adjustflag"}
|
||||
if interval == "5min":
|
||||
return {"interval": "5", "fields": "date,time,code,open,high,low,close,volume,amount,adjustflag"}
|
||||
|
||||
def get_data(
|
||||
self, symbol: str, interval: str, start_datetime: pd.Timestamp, end_datetime: pd.Timestamp
|
||||
) -> pd.DataFrame:
|
||||
df = self.get_data_from_remote(
|
||||
symbol=symbol, interval=interval, start_datetime=start_datetime, end_datetime=end_datetime
|
||||
)
|
||||
df.columns = ["date", "time", "symbol", "open", "high", "low", "close", "volume", "amount", "adjustflag"]
|
||||
df["time"] = pd.to_datetime(df["time"], format="%Y%m%d%H%M%S%f")
|
||||
df["date"] = df["time"].dt.strftime("%Y-%m-%d %H:%M:%S")
|
||||
df["date"] = df["date"].map(lambda x: pd.Timestamp(x) - pd.Timedelta(minutes=5))
|
||||
df.drop(["time"], axis=1, inplace=True)
|
||||
df["symbol"] = df["symbol"].map(lambda x: str(x).replace(".", "").upper())
|
||||
return df
|
||||
|
||||
@staticmethod
|
||||
def get_data_from_remote(
|
||||
symbol: str, interval: str, start_datetime: pd.Timestamp, end_datetime: pd.Timestamp
|
||||
) -> pd.DataFrame:
|
||||
df = pd.DataFrame()
|
||||
rs = bs.query_history_k_data_plus(
|
||||
symbol,
|
||||
BaostockCollectorHS3005min.process_interval(interval=interval)["fields"],
|
||||
start_date=str(start_datetime.strftime("%Y-%m-%d")),
|
||||
end_date=str(end_datetime.strftime("%Y-%m-%d")),
|
||||
frequency=BaostockCollectorHS3005min.process_interval(interval=interval)["interval"],
|
||||
adjustflag="3",
|
||||
)
|
||||
if rs.error_code == "0" and len(rs.data) > 0:
|
||||
data_list = rs.data
|
||||
columns = rs.fields
|
||||
df = pd.DataFrame(data_list, columns=columns)
|
||||
return df
|
||||
|
||||
def get_hs300_symbols(self) -> List[str]:
|
||||
hs300_stocks = []
|
||||
trade_calendar = self.get_trade_calendar()
|
||||
with tqdm(total=len(trade_calendar)) as p_bar:
|
||||
for date in trade_calendar:
|
||||
rs = bs.query_hs300_stocks(date=date)
|
||||
while rs.error_code == "0" and rs.next():
|
||||
hs300_stocks.append(rs.get_row_data())
|
||||
p_bar.update()
|
||||
return sorted({e[1] for e in hs300_stocks})
|
||||
|
||||
def get_instrument_list(self):
|
||||
logger.info("get HS stock symbols......")
|
||||
symbols = self.get_hs300_symbols()
|
||||
logger.info(f"get {len(symbols)} symbols.")
|
||||
return symbols
|
||||
|
||||
def normalize_symbol(self, symbol: str):
|
||||
return str(symbol).replace(".", "").upper()
|
||||
|
||||
|
||||
class BaostockNormalizeHS3005min(BaseNormalize):
|
||||
COLUMNS = ["open", "close", "high", "low", "volume"]
|
||||
AM_RANGE = ("09:30:00", "11:29:00")
|
||||
PM_RANGE = ("13:00:00", "14:59:00")
|
||||
|
||||
def __init__(
|
||||
self, qlib_data_1d_dir: [str, Path], date_field_name: str = "date", symbol_field_name: str = "symbol", **kwargs
|
||||
):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
qlib_data_1d_dir: str, Path
|
||||
the qlib data to be updated for yahoo, usually from: Normalised to 5min using local 1d data
|
||||
date_field_name: str
|
||||
date field name, default is date
|
||||
symbol_field_name: str
|
||||
symbol field name, default is symbol
|
||||
"""
|
||||
bs.login()
|
||||
qlib.init(provider_uri=qlib_data_1d_dir)
|
||||
self.all_1d_data = D.features(D.instruments("all"), ["$paused", "$volume", "$factor", "$close"], freq="day")
|
||||
super(BaostockNormalizeHS3005min, self).__init__(date_field_name, symbol_field_name)
|
||||
|
||||
@staticmethod
|
||||
def calc_change(df: pd.DataFrame, last_close: float) -> pd.Series:
|
||||
df = df.copy()
|
||||
_tmp_series = df["close"].fillna(method="ffill")
|
||||
_tmp_shift_series = _tmp_series.shift(1)
|
||||
if last_close is not None:
|
||||
_tmp_shift_series.iloc[0] = float(last_close)
|
||||
change_series = _tmp_series / _tmp_shift_series - 1
|
||||
return change_series
|
||||
|
||||
def _get_calendar_list(self) -> Iterable[pd.Timestamp]:
|
||||
return self.generate_5min_from_daily(self.calendar_list_1d)
|
||||
|
||||
@property
|
||||
def calendar_list_1d(self):
|
||||
calendar_list_1d = getattr(self, "_calendar_list_1d", None)
|
||||
if calendar_list_1d is None:
|
||||
calendar_list_1d = self._get_1d_calendar_list()
|
||||
setattr(self, "_calendar_list_1d", calendar_list_1d)
|
||||
return calendar_list_1d
|
||||
|
||||
@staticmethod
|
||||
def normalize_baostock(
|
||||
df: pd.DataFrame,
|
||||
calendar_list: list = None,
|
||||
date_field_name: str = "date",
|
||||
symbol_field_name: str = "symbol",
|
||||
last_close: float = None,
|
||||
):
|
||||
if df.empty:
|
||||
return df
|
||||
symbol = df.loc[df[symbol_field_name].first_valid_index(), symbol_field_name]
|
||||
columns = copy.deepcopy(BaostockNormalizeHS3005min.COLUMNS)
|
||||
df = df.copy()
|
||||
df.set_index(date_field_name, inplace=True)
|
||||
df.index = pd.to_datetime(df.index)
|
||||
df = df[~df.index.duplicated(keep="first")]
|
||||
if calendar_list is not None:
|
||||
df = df.reindex(
|
||||
pd.DataFrame(index=calendar_list)
|
||||
.loc[pd.Timestamp(df.index.min()).date() : pd.Timestamp(df.index.max()).date() + pd.Timedelta(days=1)]
|
||||
.index
|
||||
)
|
||||
df.sort_index(inplace=True)
|
||||
df.loc[(df["volume"] <= 0) | np.isnan(df["volume"]), list(set(df.columns) - {symbol_field_name})] = np.nan
|
||||
|
||||
df["change"] = BaostockNormalizeHS3005min.calc_change(df, last_close)
|
||||
|
||||
columns += ["change"]
|
||||
df.loc[(df["volume"] <= 0) | np.isnan(df["volume"]), columns] = np.nan
|
||||
|
||||
df[symbol_field_name] = symbol
|
||||
df.index.names = [date_field_name]
|
||||
return df.reset_index()
|
||||
|
||||
def generate_5min_from_daily(self, calendars: Iterable) -> pd.Index:
|
||||
return generate_minutes_calendar_from_daily(
|
||||
calendars, freq="5min", am_range=self.AM_RANGE, pm_range=self.PM_RANGE
|
||||
)
|
||||
|
||||
def adjusted_price(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
df = calc_adjusted_price(
|
||||
df=df,
|
||||
_date_field_name=self._date_field_name,
|
||||
_symbol_field_name=self._symbol_field_name,
|
||||
frequence="5min",
|
||||
_1d_data_all=self.all_1d_data,
|
||||
)
|
||||
return df
|
||||
|
||||
def _get_1d_calendar_list(self) -> Iterable[pd.Timestamp]:
|
||||
return list(D.calendar(freq="day"))
|
||||
|
||||
def normalize(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
# normalize
|
||||
df = self.normalize_baostock(df, self._calendar_list, self._date_field_name, self._symbol_field_name)
|
||||
# adjusted price
|
||||
df = self.adjusted_price(df)
|
||||
return df
|
||||
|
||||
|
||||
class Run(BaseRun):
|
||||
def __init__(self, source_dir=None, normalize_dir=None, max_workers=1, interval="5min", region="HS300"):
|
||||
"""
|
||||
Changed the default value of: scripts.data_collector.base.BaseRun.
|
||||
"""
|
||||
super().__init__(source_dir, normalize_dir, max_workers, interval)
|
||||
self.region = region
|
||||
|
||||
@property
|
||||
def collector_class_name(self):
|
||||
return f"BaostockCollector{self.region.upper()}{self.interval}"
|
||||
|
||||
@property
|
||||
def normalize_class_name(self):
|
||||
return f"BaostockNormalize{self.region.upper()}{self.interval}"
|
||||
|
||||
@property
|
||||
def default_base_dir(self) -> [Path, str]:
|
||||
return CUR_DIR
|
||||
|
||||
def download_data(
|
||||
self,
|
||||
max_collector_count=2,
|
||||
delay=0.5,
|
||||
start=None,
|
||||
end=None,
|
||||
check_data_length=None,
|
||||
limit_nums=None,
|
||||
):
|
||||
"""download data from Baostock
|
||||
|
||||
Notes
|
||||
-----
|
||||
check_data_length, example:
|
||||
hs300 5min, a week: 4 * 60 * 5
|
||||
|
||||
Examples
|
||||
---------
|
||||
# get hs300 5min data
|
||||
$ python collector.py download_data --source_dir ~/.qlib/stock_data/source/hs300_5min_original --start 2022-01-01 --end 2022-01-30 --interval 5min --region HS300
|
||||
"""
|
||||
super(Run, self).download_data(max_collector_count, delay, start, end, check_data_length, limit_nums)
|
||||
|
||||
def normalize_data(
|
||||
self,
|
||||
date_field_name: str = "date",
|
||||
symbol_field_name: str = "symbol",
|
||||
end_date: str = None,
|
||||
qlib_data_1d_dir: str = None,
|
||||
):
|
||||
"""normalize data
|
||||
|
||||
Attention
|
||||
---------
|
||||
qlib_data_1d_dir cannot be None, normalize 5min needs to use 1d data;
|
||||
|
||||
qlib_data_1d can be obtained like this:
|
||||
$ python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn --version v3
|
||||
or:
|
||||
download 1d data, reference: https://github.com/microsoft/qlib/tree/main/scripts/data_collector/yahoo#1d-from-yahoo
|
||||
|
||||
Examples
|
||||
---------
|
||||
$ python collector.py normalize_data --qlib_data_1d_dir ~/.qlib/qlib_data/cn_data --source_dir ~/.qlib/stock_data/source/hs300_5min_original --normalize_dir ~/.qlib/stock_data/source/hs300_5min_nor --region HS300 --interval 5min
|
||||
"""
|
||||
if qlib_data_1d_dir is None or not Path(qlib_data_1d_dir).expanduser().exists():
|
||||
raise ValueError(
|
||||
"If normalize 5min, the qlib_data_1d_dir parameter must be set: --qlib_data_1d_dir <user qlib 1d data >, Reference: https://github.com/microsoft/qlib/tree/main/scripts/data_collector/yahoo#automatic-update-of-daily-frequency-datafrom-yahoo-finance"
|
||||
)
|
||||
super(Run, self).normalize_data(
|
||||
date_field_name, symbol_field_name, end_date=end_date, qlib_data_1d_dir=qlib_data_1d_dir
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(Run)
|
||||
13
scripts/data_collector/baostock_5min/requirements.txt
Normal file
13
scripts/data_collector/baostock_5min/requirements.txt
Normal file
@@ -0,0 +1,13 @@
|
||||
loguru
|
||||
fire
|
||||
requests
|
||||
numpy
|
||||
pandas
|
||||
tqdm
|
||||
lxml
|
||||
yahooquery
|
||||
joblib
|
||||
beautifulsoup4
|
||||
bs4
|
||||
soupsieve
|
||||
baostock
|
||||
@@ -8,7 +8,7 @@ import datetime
|
||||
import importlib
|
||||
from pathlib import Path
|
||||
from typing import Type, Iterable
|
||||
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
@@ -290,7 +290,7 @@ class Normalize:
|
||||
|
||||
# some symbol_field values such as TRUE, NA are decoded as True(bool), NaN(np.float) by pandas default csv parsing.
|
||||
# manually defines dtype and na_values of the symbol_field.
|
||||
default_na = pd._libs.parsers.STR_NA_VALUES
|
||||
default_na = pd._libs.parsers.STR_NA_VALUES # pylint: disable=I1101
|
||||
symbol_na = default_na.copy()
|
||||
symbol_na.remove("NA")
|
||||
columns = pd.read_csv(file_path, nrows=0).columns
|
||||
@@ -301,6 +301,7 @@ class Normalize:
|
||||
na_values={col: symbol_na if col == self._symbol_field_name else default_na for col in columns},
|
||||
)
|
||||
|
||||
# NOTE: It has been reported that there may be some problems here, and the specific issues will be dealt with when they are identified.
|
||||
df = self._normalize_obj.normalize(df)
|
||||
if df is not None and not df.empty:
|
||||
if self._end_date is not None:
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
from functools import partial
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import importlib
|
||||
import datetime
|
||||
|
||||
import fire
|
||||
@@ -98,7 +97,7 @@ class IBOVIndex(IndexBase):
|
||||
now = datetime.datetime.now()
|
||||
current_year = now.year
|
||||
current_month = now.month
|
||||
for year in [item for item in range(init_year, current_year)]:
|
||||
for year in [item for item in range(init_year, current_year)]: # pylint: disable=R1721
|
||||
for el in four_months_period:
|
||||
self.years_4_month_periods.append(str(year) + "_" + el)
|
||||
# For current year the logic must be a little different
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
async-generator==1.10
|
||||
attrs==21.4.0
|
||||
certifi==2021.10.8
|
||||
certifi==2022.12.7
|
||||
cffi==1.15.0
|
||||
charset-normalizer==2.0.12
|
||||
cryptography==36.0.1
|
||||
@@ -8,7 +8,7 @@ fire==0.4.0
|
||||
h11==0.13.0
|
||||
idna==3.3
|
||||
loguru==0.6.0
|
||||
lxml==4.8.0
|
||||
lxml==4.9.1
|
||||
multitasking==0.0.10
|
||||
numpy==1.22.2
|
||||
outcome==1.1.0
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
import re
|
||||
import abc
|
||||
import sys
|
||||
import datetime
|
||||
from io import BytesIO
|
||||
from typing import List, Iterable
|
||||
from pathlib import Path
|
||||
@@ -39,7 +38,7 @@ def retry_request(url: str, method: str = "get", exclude_status: List = None):
|
||||
if exclude_status is None:
|
||||
exclude_status = []
|
||||
method_func = getattr(requests, method)
|
||||
_resp = method_func(url, headers=REQ_HEADERS)
|
||||
_resp = method_func(url, headers=REQ_HEADERS, timeout=None)
|
||||
_status = _resp.status_code
|
||||
if _status not in exclude_status and _status != 200:
|
||||
raise ValueError(f"response status: {_status}, url={url}")
|
||||
@@ -397,14 +396,7 @@ class CSI500Index(CSIIndex):
|
||||
today = pd.Timestamp.now()
|
||||
date_range = pd.DataFrame(pd.date_range(start="2007-01-15", end=today, freq="7D"))[0].dt.date
|
||||
ret_list = []
|
||||
col = ["date", "symbol", "code_name"]
|
||||
for date in tqdm(date_range, desc="Download CSI500"):
|
||||
rs = bs.query_zz500_stocks(date=str(date))
|
||||
zz500_stocks = []
|
||||
while (rs.error_code == "0") & rs.next():
|
||||
zz500_stocks.append(rs.get_row_data())
|
||||
result = pd.DataFrame(zz500_stocks, columns=col)
|
||||
result["symbol"] = result["symbol"].apply(lambda x: x.replace(".", "").upper())
|
||||
result = self.get_data_from_baostock(date)
|
||||
ret_list.append(result[["date", "symbol"]])
|
||||
bs.logout()
|
||||
|
||||
@@ -5,3 +5,5 @@ pandas
|
||||
lxml
|
||||
loguru
|
||||
tqdm
|
||||
yahooquery
|
||||
openpyxl
|
||||
|
||||
@@ -9,7 +9,7 @@ pip install -r requirements.txt
|
||||
```
|
||||
|
||||
## Usage of the dataset
|
||||
> *Crypto dateset only support Data retrieval function but not support backtest function due to the lack of OHLC data.*
|
||||
> *Crypto dataset only support Data retrieval function but not support backtest function due to the lack of OHLC data.*
|
||||
|
||||
## Collector Data
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@ from abc import ABC
|
||||
from pathlib import Path
|
||||
|
||||
import fire
|
||||
import requests
|
||||
import pandas as pd
|
||||
from loguru import logger
|
||||
from dateutil.tz import tzlocal
|
||||
@@ -31,15 +30,15 @@ def get_cg_crypto_symbols(qlib_data_path: [str, Path] = None) -> list:
|
||||
-------
|
||||
crypto symbols in given exchanges list of coingecko
|
||||
"""
|
||||
global _CG_CRYPTO_SYMBOLS
|
||||
global _CG_CRYPTO_SYMBOLS # pylint: disable=W0603
|
||||
|
||||
@deco_retry
|
||||
def _get_coingecko():
|
||||
try:
|
||||
cg = CoinGeckoAPI()
|
||||
resp = pd.DataFrame(cg.get_coins_markets(vs_currency="usd"))
|
||||
except:
|
||||
raise ValueError("request error")
|
||||
except Exception as e:
|
||||
raise ValueError("request error") from e
|
||||
try:
|
||||
_symbols = resp["id"].to_list()
|
||||
except Exception as e:
|
||||
|
||||
@@ -107,7 +107,7 @@ class FundCollector(BaseCollector):
|
||||
url = INDEX_BENCH_URL.format(
|
||||
index_code=symbol, numberOfHistoricalDaysToCrawl=10000, startDate=start, endDate=end
|
||||
)
|
||||
resp = requests.get(url, headers={"referer": "http://fund.eastmoney.com/110022.html"})
|
||||
resp = requests.get(url, headers={"referer": "http://fund.eastmoney.com/110022.html"}, timeout=None)
|
||||
|
||||
if resp.status_code != 200:
|
||||
raise ValueError("request error")
|
||||
@@ -116,8 +116,8 @@ class FundCollector(BaseCollector):
|
||||
|
||||
# Some funds don't show the net value, example: http://fundf10.eastmoney.com/jjjz_010288.html
|
||||
SYType = data["Data"]["SYType"]
|
||||
if (SYType == "每万份收益") or (SYType == "每百份收益") or (SYType == "每百万份收益"):
|
||||
raise Exception("The fund contains 每*份收益")
|
||||
if SYType in {"每万份收益", "每百份收益", "每百万份收益"}:
|
||||
raise ValueError("The fund contains 每*份收益")
|
||||
|
||||
# TODO: should we sort the value by datetime?
|
||||
_resp = pd.DataFrame(data["Data"]["LSJZList"])
|
||||
|
||||
@@ -53,7 +53,7 @@ class CollectorFutureCalendar:
|
||||
return datetime_d.strftime(self.calendar_format)
|
||||
|
||||
def write_calendar(self, calendar: Iterable):
|
||||
calendars_list = list(map(lambda x: self._format_datetime(x), sorted(set(self.calendar_list + calendar))))
|
||||
calendars_list = [self._format_datetime(x) for x in sorted(set(self.calendar_list + calendar))]
|
||||
np.savetxt(self.future_path, calendars_list, fmt="%s", encoding="utf-8")
|
||||
|
||||
@abc.abstractmethod
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
import abc
|
||||
from functools import partial
|
||||
import sys
|
||||
import importlib
|
||||
from pathlib import Path
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import List
|
||||
@@ -113,7 +112,7 @@ class WIKIIndex(IndexBase):
|
||||
return _calendar_list
|
||||
|
||||
def _request_new_companies(self) -> requests.Response:
|
||||
resp = requests.get(self._target_url)
|
||||
resp = requests.get(self._target_url, timeout=None)
|
||||
if resp.status_code != 200:
|
||||
raise ValueError(f"request error: {self._target_url}")
|
||||
|
||||
@@ -164,7 +163,7 @@ class NASDAQ100Index(WIKIIndex):
|
||||
df = pd.read_pickle(cache_path)
|
||||
else:
|
||||
url = self.HISTORY_COMPANIES_URL.format(trade_date=trade_date)
|
||||
resp = requests.post(url)
|
||||
resp = requests.post(url, timeout=None)
|
||||
if resp.status_code != 200:
|
||||
raise ValueError(f"request error: {url}")
|
||||
df = pd.DataFrame(resp.json()["aaData"])
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import re
|
||||
import copy
|
||||
import importlib
|
||||
import time
|
||||
import bisect
|
||||
@@ -14,7 +15,6 @@ from typing import Iterable, Tuple, List
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from lxml import etree
|
||||
from loguru import logger
|
||||
from yahooquery import Ticker
|
||||
from tqdm import tqdm
|
||||
@@ -68,7 +68,7 @@ def get_calendar_list(bench_code="CSI300") -> List[pd.Timestamp]:
|
||||
logger.info(f"get calendar list: {bench_code}......")
|
||||
|
||||
def _get_calendar(url):
|
||||
_value_list = requests.get(url).json()["data"]["klines"]
|
||||
_value_list = requests.get(url, timeout=None).json()["data"]["klines"]
|
||||
return sorted(map(lambda x: pd.Timestamp(x.split(",")[0]), _value_list))
|
||||
|
||||
calendar = _CALENDAR_MAP.get(bench_code, None)
|
||||
@@ -85,12 +85,14 @@ def get_calendar_list(bench_code="CSI300") -> List[pd.Timestamp]:
|
||||
def _get_calendar(month):
|
||||
_cal = []
|
||||
try:
|
||||
resp = requests.get(SZSE_CALENDAR_URL.format(month=month, random=random.random)).json()
|
||||
resp = requests.get(
|
||||
SZSE_CALENDAR_URL.format(month=month, random=random.random), timeout=None
|
||||
).json()
|
||||
for _r in resp["data"]:
|
||||
if int(_r["jybz"]):
|
||||
_cal.append(pd.Timestamp(_r["jyrq"]))
|
||||
except Exception as e:
|
||||
raise ValueError(f"{month}-->{e}")
|
||||
raise ValueError(f"{month}-->{e}") from e
|
||||
return _cal
|
||||
|
||||
month_range = pd.date_range(start="2000-01", end=pd.Timestamp.now() + pd.Timedelta(days=31), freq="M")
|
||||
@@ -109,7 +111,7 @@ def get_calendar_list(bench_code="CSI300") -> List[pd.Timestamp]:
|
||||
|
||||
def return_date_list(date_field_name: str, file_path: Path):
|
||||
date_list = pd.read_csv(file_path, sep=",", index_col=0)[date_field_name].to_list()
|
||||
return sorted(map(lambda x: pd.Timestamp(x), date_list))
|
||||
return sorted([pd.Timestamp(x) for x in date_list])
|
||||
|
||||
|
||||
def get_calendar_list_by_ratio(
|
||||
@@ -155,7 +157,7 @@ def get_calendar_list_by_ratio(
|
||||
if date_list:
|
||||
all_oldest_list.append(date_list[0])
|
||||
for date in date_list:
|
||||
if date not in _dict_count_trade.keys():
|
||||
if date not in _dict_count_trade:
|
||||
_dict_count_trade[date] = 0
|
||||
|
||||
_dict_count_trade[date] += 1
|
||||
@@ -163,7 +165,7 @@ def get_calendar_list_by_ratio(
|
||||
p_bar.update()
|
||||
|
||||
logger.info(f"count how many funds have founded in this day......")
|
||||
_dict_count_founding = {date: _number_all_funds for date in _dict_count_trade.keys()} # dict{date:count}
|
||||
_dict_count_founding = {date: _number_all_funds for date in _dict_count_trade} # dict{date:count}
|
||||
with tqdm(total=_number_all_funds) as p_bar:
|
||||
for oldest_date in all_oldest_list:
|
||||
for date in _dict_count_founding.keys():
|
||||
@@ -171,9 +173,7 @@ def get_calendar_list_by_ratio(
|
||||
_dict_count_founding[date] -= 1
|
||||
|
||||
calendar = [
|
||||
date
|
||||
for date in _dict_count_trade
|
||||
if _dict_count_trade[date] >= max(int(_dict_count_founding[date] * threshold), minimum_count)
|
||||
date for date, count in _dict_count_trade.items() if count >= max(int(count * threshold), minimum_count)
|
||||
]
|
||||
|
||||
return calendar
|
||||
@@ -186,20 +186,46 @@ def get_hs_stock_symbols() -> list:
|
||||
-------
|
||||
stock symbols
|
||||
"""
|
||||
global _HS_SYMBOLS
|
||||
global _HS_SYMBOLS # pylint: disable=W0603
|
||||
|
||||
def _get_symbol():
|
||||
_res = set()
|
||||
for _k, _v in (("ha", "ss"), ("sa", "sz"), ("gem", "sz")):
|
||||
resp = requests.get(HS_SYMBOLS_URL.format(s_type=_k))
|
||||
_res |= set(
|
||||
map(
|
||||
lambda x: "{}.{}".format(re.findall(r"\d+", x)[0], _v),
|
||||
etree.HTML(resp.text).xpath("//div[@class='result']/ul//li/a/text()"),
|
||||
)
|
||||
)
|
||||
time.sleep(3)
|
||||
return _res
|
||||
"""
|
||||
Get the stock pool from a web page and process it into the format required by yahooquery.
|
||||
Format of data retrieved from the web page: 600519, 000001
|
||||
The data format required by yahooquery: 600519.ss, 000001.sz
|
||||
|
||||
Returns
|
||||
-------
|
||||
set: Returns the set of symbol codes.
|
||||
|
||||
Examples:
|
||||
-------
|
||||
{600000.ss, 600001.ss, 600002.ss, 600003.ss, ...}
|
||||
"""
|
||||
url = "http://99.push2.eastmoney.com/api/qt/clist/get?pn=1&pz=10000&po=1&np=1&fs=m:0+t:6,m:0+t:80,m:1+t:2,m:1+t:23,m:0+t:81+s:2048&fields=f12"
|
||||
try:
|
||||
resp = requests.get(url, timeout=None)
|
||||
resp.raise_for_status()
|
||||
except requests.exceptions.HTTPError as e:
|
||||
raise requests.exceptions.HTTPError(f"Request to {url} failed with status code {resp.status_code}") from e
|
||||
|
||||
try:
|
||||
_symbols = [_v["f12"] for _v in resp.json()["data"]["diff"]]
|
||||
except Exception as e:
|
||||
logger.warning("An error occurred while extracting data from the response.")
|
||||
raise
|
||||
|
||||
if len(_symbols) < 3900:
|
||||
raise ValueError("The complete list of stocks is not available.")
|
||||
|
||||
# Add suffix after the stock code to conform to yahooquery standard, otherwise the data will not be fetched.
|
||||
_symbols = [
|
||||
_symbol + ".ss" if _symbol.startswith("6") else _symbol + ".sz" if _symbol.startswith(("0", "3")) else None
|
||||
for _symbol in _symbols
|
||||
]
|
||||
_symbols = [_symbol for _symbol in _symbols if _symbol is not None]
|
||||
|
||||
return set(_symbols)
|
||||
|
||||
if _HS_SYMBOLS is None:
|
||||
symbols = set()
|
||||
@@ -230,12 +256,12 @@ def get_us_stock_symbols(qlib_data_path: [str, Path] = None) -> list:
|
||||
-------
|
||||
stock symbols
|
||||
"""
|
||||
global _US_SYMBOLS
|
||||
global _US_SYMBOLS # pylint: disable=W0603
|
||||
|
||||
@deco_retry
|
||||
def _get_eastmoney():
|
||||
url = "http://4.push2.eastmoney.com/api/qt/clist/get?pn=1&pz=10000&fs=m:105,m:106,m:107&fields=f12"
|
||||
resp = requests.get(url)
|
||||
resp = requests.get(url, timeout=None)
|
||||
if resp.status_code != 200:
|
||||
raise ValueError("request error")
|
||||
|
||||
@@ -277,7 +303,7 @@ def get_us_stock_symbols(qlib_data_path: [str, Path] = None) -> list:
|
||||
"maxResultsPerPage": 10000,
|
||||
"filterToken": "",
|
||||
}
|
||||
resp = requests.post(url, json=_parms)
|
||||
resp = requests.post(url, json=_parms, timeout=None)
|
||||
if resp.status_code != 200:
|
||||
raise ValueError("request error")
|
||||
|
||||
@@ -317,7 +343,7 @@ def get_in_stock_symbols(qlib_data_path: [str, Path] = None) -> list:
|
||||
-------
|
||||
stock symbols
|
||||
"""
|
||||
global _IN_SYMBOLS
|
||||
global _IN_SYMBOLS # pylint: disable=W0603
|
||||
|
||||
@deco_retry
|
||||
def _get_nifty():
|
||||
@@ -358,7 +384,7 @@ def get_br_stock_symbols(qlib_data_path: [str, Path] = None) -> list:
|
||||
-------
|
||||
B3 stock symbols
|
||||
"""
|
||||
global _BR_SYMBOLS
|
||||
global _BR_SYMBOLS # pylint: disable=W0603
|
||||
|
||||
@deco_retry
|
||||
def _get_ibovespa():
|
||||
@@ -367,7 +393,7 @@ def get_br_stock_symbols(qlib_data_path: [str, Path] = None) -> list:
|
||||
|
||||
# Request
|
||||
agent = {"User-Agent": "Mozilla/5.0"}
|
||||
page = requests.get(url, headers=agent)
|
||||
page = requests.get(url, headers=agent, timeout=None)
|
||||
|
||||
# BeautifulSoup
|
||||
soup = BeautifulSoup(page.content, "html.parser")
|
||||
@@ -375,7 +401,7 @@ def get_br_stock_symbols(qlib_data_path: [str, Path] = None) -> list:
|
||||
|
||||
children = tbody.findChildren("a", recursive=True)
|
||||
for child in children:
|
||||
_symbols.append(str(child).split('"')[-1].split(">")[1].split("<")[0])
|
||||
_symbols.append(str(child).rsplit('"', maxsplit=1)[-1].split(">")[1].split("<")[0])
|
||||
|
||||
return _symbols
|
||||
|
||||
@@ -409,12 +435,12 @@ def get_en_fund_symbols(qlib_data_path: [str, Path] = None) -> list:
|
||||
-------
|
||||
fund symbols in China
|
||||
"""
|
||||
global _EN_FUND_SYMBOLS
|
||||
global _EN_FUND_SYMBOLS # pylint: disable=W0603
|
||||
|
||||
@deco_retry
|
||||
def _get_eastmoney():
|
||||
url = "http://fund.eastmoney.com/js/fundcode_search.js"
|
||||
resp = requests.get(url)
|
||||
resp = requests.get(url, timeout=None)
|
||||
if resp.status_code != 200:
|
||||
raise ValueError("request error")
|
||||
try:
|
||||
@@ -605,5 +631,177 @@ def get_instruments(
|
||||
getattr(obj, method)()
|
||||
|
||||
|
||||
def _get_all_1d_data(_date_field_name: str, _symbol_field_name: str, _1d_data_all: pd.DataFrame):
|
||||
df = copy.deepcopy(_1d_data_all)
|
||||
df.reset_index(inplace=True)
|
||||
df.rename(columns={"datetime": _date_field_name, "instrument": _symbol_field_name}, inplace=True)
|
||||
df.columns = list(map(lambda x: x[1:] if x.startswith("$") else x, df.columns))
|
||||
return df
|
||||
|
||||
|
||||
def get_1d_data(
|
||||
_date_field_name: str,
|
||||
_symbol_field_name: str,
|
||||
symbol: str,
|
||||
start: str,
|
||||
end: str,
|
||||
_1d_data_all: pd.DataFrame,
|
||||
) -> pd.DataFrame:
|
||||
"""get 1d data
|
||||
|
||||
Returns
|
||||
------
|
||||
data_1d: pd.DataFrame
|
||||
data_1d.columns = [_date_field_name, _symbol_field_name, "paused", "volume", "factor", "close"]
|
||||
|
||||
"""
|
||||
_all_1d_data = _get_all_1d_data(_date_field_name, _symbol_field_name, _1d_data_all)
|
||||
return _all_1d_data[
|
||||
(_all_1d_data[_symbol_field_name] == symbol.upper())
|
||||
& (_all_1d_data[_date_field_name] >= pd.Timestamp(start))
|
||||
& (_all_1d_data[_date_field_name] < pd.Timestamp(end))
|
||||
]
|
||||
|
||||
|
||||
def calc_adjusted_price(
|
||||
df: pd.DataFrame,
|
||||
_1d_data_all: pd.DataFrame,
|
||||
_date_field_name: str,
|
||||
_symbol_field_name: str,
|
||||
frequence: str,
|
||||
consistent_1d: bool = True,
|
||||
calc_paused: bool = True,
|
||||
) -> pd.DataFrame:
|
||||
"""calc adjusted price
|
||||
This method does 4 things.
|
||||
1. Adds the `paused` field.
|
||||
- The added paused field comes from the paused field of the 1d data.
|
||||
2. Aligns the time of the 1d data.
|
||||
3. The data is reweighted.
|
||||
- The reweighting method:
|
||||
- volume / factor
|
||||
- open * factor
|
||||
- high * factor
|
||||
- low * factor
|
||||
- close * factor
|
||||
4. Called `calc_paused_num` method to add the `paused_num` field.
|
||||
- The `paused_num` is the number of consecutive days of trading suspension.
|
||||
"""
|
||||
# TODO: using daily data factor
|
||||
if df.empty:
|
||||
return df
|
||||
df = df.copy()
|
||||
df.drop_duplicates(subset=_date_field_name, inplace=True)
|
||||
df.sort_values(_date_field_name, inplace=True)
|
||||
symbol = df.iloc[0][_symbol_field_name]
|
||||
df[_date_field_name] = pd.to_datetime(df[_date_field_name])
|
||||
# get 1d data from qlib
|
||||
_start = pd.Timestamp(df[_date_field_name].min()).strftime("%Y-%m-%d")
|
||||
_end = (pd.Timestamp(df[_date_field_name].max()) + pd.Timedelta(days=1)).strftime("%Y-%m-%d")
|
||||
data_1d: pd.DataFrame = get_1d_data(_date_field_name, _symbol_field_name, symbol, _start, _end, _1d_data_all)
|
||||
data_1d = data_1d.copy()
|
||||
if data_1d is None or data_1d.empty:
|
||||
df["factor"] = 1 / df.loc[df["close"].first_valid_index()]["close"]
|
||||
# TODO: np.nan or 1 or 0
|
||||
df["paused"] = np.nan
|
||||
else:
|
||||
# NOTE: volume is np.nan or volume <= 0, paused = 1
|
||||
# FIXME: find a more accurate data source
|
||||
data_1d["paused"] = 0
|
||||
data_1d.loc[(data_1d["volume"].isna()) | (data_1d["volume"] <= 0), "paused"] = 1
|
||||
data_1d = data_1d.set_index(_date_field_name)
|
||||
|
||||
# add factor from 1d data
|
||||
# NOTE: 1d data info:
|
||||
# - Close price adjusted for splits. Adjusted close price adjusted for both dividends and splits.
|
||||
# - data_1d.adjclose: Adjusted close price adjusted for both dividends and splits.
|
||||
# - data_1d.close: `data_1d.adjclose / (close for the first trading day that is not np.nan)`
|
||||
def _calc_factor(df_1d: pd.DataFrame):
|
||||
try:
|
||||
_date = pd.Timestamp(pd.Timestamp(df_1d[_date_field_name].iloc[0]).date())
|
||||
df_1d["factor"] = data_1d.loc[_date]["close"] / df_1d.loc[df_1d["close"].last_valid_index()]["close"]
|
||||
df_1d["paused"] = data_1d.loc[_date]["paused"]
|
||||
except Exception:
|
||||
df_1d["factor"] = np.nan
|
||||
df_1d["paused"] = np.nan
|
||||
return df_1d
|
||||
|
||||
df = df.groupby([df[_date_field_name].dt.date], group_keys=False).apply(_calc_factor)
|
||||
if consistent_1d:
|
||||
# the date sequence is consistent with 1d
|
||||
df.set_index(_date_field_name, inplace=True)
|
||||
df = df.reindex(
|
||||
generate_minutes_calendar_from_daily(
|
||||
calendars=pd.to_datetime(data_1d.reset_index()[_date_field_name].drop_duplicates()),
|
||||
freq=frequence,
|
||||
am_range=("09:30:00", "11:29:00"),
|
||||
pm_range=("13:00:00", "14:59:00"),
|
||||
)
|
||||
)
|
||||
df[_symbol_field_name] = df.loc[df[_symbol_field_name].first_valid_index()][_symbol_field_name]
|
||||
df.index.names = [_date_field_name]
|
||||
df.reset_index(inplace=True)
|
||||
for _col in ["open", "close", "high", "low", "volume"]:
|
||||
if _col not in df.columns:
|
||||
continue
|
||||
if _col == "volume":
|
||||
df[_col] = df[_col] / df["factor"]
|
||||
else:
|
||||
df[_col] = df[_col] * df["factor"]
|
||||
if calc_paused:
|
||||
df = calc_paused_num(df, _date_field_name, _symbol_field_name)
|
||||
return df
|
||||
|
||||
|
||||
def calc_paused_num(df: pd.DataFrame, _date_field_name, _symbol_field_name):
|
||||
"""calc paused num
|
||||
This method adds the paused_num field
|
||||
- The `paused_num` is the number of consecutive days of trading suspension.
|
||||
"""
|
||||
_symbol = df.iloc[0][_symbol_field_name]
|
||||
df = df.copy()
|
||||
df["_tmp_date"] = df[_date_field_name].apply(lambda x: pd.Timestamp(x).date())
|
||||
# remove data that starts and ends with `np.nan` all day
|
||||
all_data = []
|
||||
# Record the number of consecutive trading days where the whole day is nan, to remove the last trading day where the whole day is nan
|
||||
all_nan_nums = 0
|
||||
# Record the number of consecutive occurrences of trading days that are not nan throughout the day
|
||||
not_nan_nums = 0
|
||||
for _date, _df in df.groupby("_tmp_date"):
|
||||
_df["paused"] = 0
|
||||
if not _df.loc[_df["volume"] < 0].empty:
|
||||
logger.warning(f"volume < 0, will fill np.nan: {_date} {_symbol}")
|
||||
_df.loc[_df["volume"] < 0, "volume"] = np.nan
|
||||
|
||||
check_fields = set(_df.columns) - {
|
||||
"_tmp_date",
|
||||
"paused",
|
||||
"factor",
|
||||
_date_field_name,
|
||||
_symbol_field_name,
|
||||
}
|
||||
if _df.loc[:, list(check_fields)].isna().values.all() or (_df["volume"] == 0).all():
|
||||
all_nan_nums += 1
|
||||
not_nan_nums = 0
|
||||
_df["paused"] = 1
|
||||
if all_data:
|
||||
_df["paused_num"] = not_nan_nums
|
||||
all_data.append(_df)
|
||||
else:
|
||||
all_nan_nums = 0
|
||||
not_nan_nums += 1
|
||||
_df["paused_num"] = not_nan_nums
|
||||
all_data.append(_df)
|
||||
all_data = all_data[: len(all_data) - all_nan_nums]
|
||||
if all_data:
|
||||
df = pd.concat(all_data, sort=False)
|
||||
else:
|
||||
logger.warning(f"data is empty: {_symbol}")
|
||||
df = pd.DataFrame()
|
||||
return df
|
||||
del df["_tmp_date"]
|
||||
return df
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
assert len(get_hs_stock_symbols()) >= MINIMUM_SYMBOLS_NUM
|
||||
|
||||
@@ -121,7 +121,7 @@ pip install -r requirements.txt
|
||||
|
||||
qlib_data_1d can be obtained like this:
|
||||
$ python scripts/get_data.py qlib_data --target_dir <qlib_data_1d_dir> --interval 1d
|
||||
$ python scripts/data_collector/yahoo/collector.py update_data_to_bin --qlib_data_1d_dir <qlib_data_1d_dir> --trading_date 2021-06-01
|
||||
$ python scripts/data_collector/yahoo/collector.py update_data_to_bin --qlib_data_1d_dir <qlib_data_1d_dir> --end_date <end_date>
|
||||
or:
|
||||
download 1d data from YahooFinance
|
||||
|
||||
@@ -180,9 +180,8 @@ pip install -r requirements.txt
|
||||
|
||||
* Manual update of data
|
||||
```
|
||||
python scripts/data_collector/yahoo/collector.py update_data_to_bin --qlib_data_1d_dir <user data dir> --trading_date <start date> --end_date <end date>
|
||||
python scripts/data_collector/yahoo/collector.py update_data_to_bin --qlib_data_1d_dir <user data dir> --end_date <end date>
|
||||
```
|
||||
* `trading_date`: start of trading day
|
||||
* `end_date`: end of trading day(not included)
|
||||
* `check_data_length`: check the number of rows per *symbol*, by default `None`
|
||||
> if `len(symbol_df) < check_data_length`, it will be re-fetched, with the number of re-fetches coming from the `max_collector_count` parameter
|
||||
@@ -191,10 +190,10 @@ pip install -r requirements.txt
|
||||
* `source_dir`: The directory where the raw data collected from the Internet is saved, default "Path(__file__).parent/source"
|
||||
* `normalize_dir`: Directory for normalize data, default "Path(__file__).parent/normalize"
|
||||
* `qlib_data_1d_dir`: the qlib data to be updated for yahoo, usually from: [download qlib data](https://github.com/microsoft/qlib/tree/main/scripts#download-cn-data)
|
||||
* `trading_date`: trading days to be updated, by default ``datetime.datetime.now().strftime("%Y-%m-%d")``
|
||||
* `end_date`: end datetime, default ``pd.Timestamp(trading_date + pd.Timedelta(days=1))``; open interval(excluding end)
|
||||
* `region`: region, value from ["CN", "US"], default "CN"
|
||||
|
||||
* `interval`: interval, default "1d"(Currently only supports 1d data)
|
||||
* `exists_skip`: exists skip, by default False
|
||||
|
||||
## Using qlib data
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import abc
|
||||
from re import I
|
||||
import sys
|
||||
import copy
|
||||
import time
|
||||
@@ -21,6 +20,8 @@ from loguru import logger
|
||||
from yahooquery import Ticker
|
||||
from dateutil.tz import tzlocal
|
||||
|
||||
import qlib
|
||||
from qlib.data import D
|
||||
from qlib.tests.data import GetData
|
||||
from qlib.utils import code_to_fname, fname_to_code, exists_qlib_data
|
||||
from qlib.constant import REG_CN as REGION_CN
|
||||
@@ -38,6 +39,7 @@ from data_collector.utils import (
|
||||
get_in_stock_symbols,
|
||||
get_br_stock_symbols,
|
||||
generate_minutes_calendar_from_daily,
|
||||
calc_adjusted_price,
|
||||
)
|
||||
|
||||
INDEX_BENCH_URL = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid=1.{index_code}&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg={begin}&end={end}"
|
||||
@@ -229,9 +231,9 @@ class YahooCollectorCN1d(YahooCollectorCN):
|
||||
df = pd.DataFrame(
|
||||
map(
|
||||
lambda x: x.split(","),
|
||||
requests.get(INDEX_BENCH_URL.format(index_code=_index_code, begin=_begin, end=_end)).json()[
|
||||
"data"
|
||||
]["klines"],
|
||||
requests.get(
|
||||
INDEX_BENCH_URL.format(index_code=_index_code, begin=_begin, end=_end), timeout=None
|
||||
).json()["data"]["klines"],
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -316,7 +318,7 @@ class YahooCollectorIN1min(YahooCollectorIN):
|
||||
|
||||
|
||||
class YahooCollectorBR(YahooCollector, ABC):
|
||||
def retry(cls):
|
||||
def retry(cls): # pylint: disable=E0213
|
||||
"""
|
||||
The reason to use retry=2 is due to the fact that
|
||||
Yahoo Finance unfortunately does not keep track of some
|
||||
@@ -356,12 +358,10 @@ class YahooCollectorBR(YahooCollector, ABC):
|
||||
|
||||
class YahooCollectorBR1d(YahooCollectorBR):
|
||||
retry = 2
|
||||
pass
|
||||
|
||||
|
||||
class YahooCollectorBR1min(YahooCollectorBR):
|
||||
retry = 2
|
||||
pass
|
||||
|
||||
|
||||
class YahooNormalize(BaseNormalize):
|
||||
@@ -393,6 +393,7 @@ class YahooNormalize(BaseNormalize):
|
||||
df = df.copy()
|
||||
df.set_index(date_field_name, inplace=True)
|
||||
df.index = pd.to_datetime(df.index)
|
||||
df.index = df.index.tz_localize(None)
|
||||
df = df[~df.index.duplicated(keep="first")]
|
||||
if calendar_list is not None:
|
||||
df = df.reindex(
|
||||
@@ -522,78 +523,39 @@ class YahooNormalize1dExtend(YahooNormalize1d):
|
||||
symbol field name, default is symbol
|
||||
"""
|
||||
super(YahooNormalize1dExtend, self).__init__(date_field_name, symbol_field_name)
|
||||
self._first_close_field = "first_close"
|
||||
self._ori_close_field = "ori_close"
|
||||
self.column_list = ["open", "high", "low", "close", "volume", "factor", "change"]
|
||||
self.old_qlib_data = self._get_old_data(old_qlib_data_dir)
|
||||
|
||||
def _get_old_data(self, qlib_data_dir: [str, Path]):
|
||||
import qlib
|
||||
from qlib.data import D
|
||||
|
||||
qlib_data_dir = str(Path(qlib_data_dir).expanduser().resolve())
|
||||
qlib.init(provider_uri=qlib_data_dir, expression_cache=None, dataset_cache=None)
|
||||
df = D.features(D.instruments("all"), ["$close/$factor", "$adjclose/$close"])
|
||||
df.columns = [self._ori_close_field, self._first_close_field]
|
||||
df = D.features(D.instruments("all"), ["$" + col for col in self.column_list])
|
||||
df.columns = self.column_list
|
||||
return df
|
||||
|
||||
def _get_close(self, df: pd.DataFrame, field_name: str):
|
||||
_symbol = df.loc[df[self._symbol_field_name].first_valid_index()][self._symbol_field_name].upper()
|
||||
_df = self.old_qlib_data.loc(axis=0)[_symbol]
|
||||
_close = _df.loc[_df.last_valid_index()][field_name]
|
||||
return _close
|
||||
|
||||
def _get_first_close(self, df: pd.DataFrame) -> float:
|
||||
try:
|
||||
_close = self._get_close(df, field_name=self._first_close_field)
|
||||
except KeyError:
|
||||
_close = super(YahooNormalize1dExtend, self)._get_first_close(df)
|
||||
return _close
|
||||
|
||||
def _get_last_close(self, df: pd.DataFrame) -> float:
|
||||
try:
|
||||
_close = self._get_close(df, field_name=self._ori_close_field)
|
||||
except KeyError:
|
||||
_close = None
|
||||
return _close
|
||||
|
||||
def _get_last_date(self, df: pd.DataFrame) -> pd.Timestamp:
|
||||
_symbol = df.loc[df[self._symbol_field_name].first_valid_index()][self._symbol_field_name].upper()
|
||||
try:
|
||||
_df = self.old_qlib_data.loc(axis=0)[_symbol]
|
||||
_date = _df.index.max()
|
||||
except KeyError:
|
||||
_date = None
|
||||
return _date
|
||||
|
||||
def normalize(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
_last_close = self._get_last_close(df)
|
||||
# reindex
|
||||
_last_date = self._get_last_date(df)
|
||||
if _last_date is not None:
|
||||
df = df.set_index(self._date_field_name)
|
||||
df.index = pd.to_datetime(df.index)
|
||||
df = df[~df.index.duplicated(keep="first")]
|
||||
_max_date = df.index.max()
|
||||
df = df.reindex(self._calendar_list).loc[:_max_date].reset_index()
|
||||
df = df[df[self._date_field_name] > _last_date]
|
||||
if df.empty:
|
||||
return pd.DataFrame()
|
||||
_si = df["close"].first_valid_index()
|
||||
if _si > df.index[0]:
|
||||
logger.warning(
|
||||
f"{df.loc[_si][self._symbol_field_name]} missing data: {df.loc[:_si - 1][self._date_field_name].to_list()}"
|
||||
)
|
||||
# normalize
|
||||
df = self.normalize_yahoo(
|
||||
df, self._calendar_list, self._date_field_name, self._symbol_field_name, last_close=_last_close
|
||||
)
|
||||
# adjusted price
|
||||
df = self.adjusted_price(df)
|
||||
df = self._manual_adj_data(df)
|
||||
return df
|
||||
df = super(YahooNormalize1dExtend, self).normalize(df)
|
||||
df.set_index(self._date_field_name, inplace=True)
|
||||
symbol_name = df[self._symbol_field_name].iloc[0]
|
||||
old_symbol_list = self.old_qlib_data.index.get_level_values("instrument").unique().to_list()
|
||||
if str(symbol_name).upper() not in old_symbol_list:
|
||||
return df.reset_index()
|
||||
old_df = self.old_qlib_data.loc[str(symbol_name).upper()]
|
||||
latest_date = old_df.index[-1]
|
||||
df = df.loc[latest_date:]
|
||||
new_latest_data = df.iloc[0]
|
||||
old_latest_data = old_df.loc[latest_date]
|
||||
for col in self.column_list[:-1]:
|
||||
if col == "volume":
|
||||
df[col] = df[col] / (new_latest_data[col] / old_latest_data[col])
|
||||
else:
|
||||
df[col] = df[col] * (old_latest_data[col] / new_latest_data[col])
|
||||
return df.drop(df.index[0]).reset_index()
|
||||
|
||||
|
||||
class YahooNormalize1min(YahooNormalize, ABC):
|
||||
"""Normalised to 1min using local 1d data"""
|
||||
|
||||
AM_RANGE = None # type: tuple # eg: ("09:30:00", "11:29:00")
|
||||
PM_RANGE = None # type: tuple # eg: ("13:00:00", "14:59:00")
|
||||
|
||||
@@ -601,160 +563,6 @@ class YahooNormalize1min(YahooNormalize, ABC):
|
||||
CONSISTENT_1d = True
|
||||
CALC_PAUSED_NUM = True
|
||||
|
||||
@property
|
||||
def calendar_list_1d(self):
|
||||
calendar_list_1d = getattr(self, "_calendar_list_1d", None)
|
||||
if calendar_list_1d is None:
|
||||
calendar_list_1d = self._get_1d_calendar_list()
|
||||
setattr(self, "_calendar_list_1d", calendar_list_1d)
|
||||
return calendar_list_1d
|
||||
|
||||
def generate_1min_from_daily(self, calendars: Iterable) -> pd.Index:
|
||||
return generate_minutes_calendar_from_daily(
|
||||
calendars, freq="1min", am_range=self.AM_RANGE, pm_range=self.PM_RANGE
|
||||
)
|
||||
|
||||
def get_1d_data(self, symbol: str, start: str, end: str) -> pd.DataFrame:
|
||||
"""get 1d data
|
||||
|
||||
Returns
|
||||
------
|
||||
data_1d: pd.DataFrame
|
||||
data_1d.columns = [self._date_field_name, self._symbol_field_name, "paused", "volume", "factor", "close"]
|
||||
|
||||
"""
|
||||
data_1d = YahooCollector.get_data_from_remote(self.symbol_to_yahoo(symbol), interval="1d", start=start, end=end)
|
||||
if not (data_1d is None or data_1d.empty):
|
||||
_class_name = self.__class__.__name__.replace("min", "d")
|
||||
_class: type(YahooNormalize) = getattr(importlib.import_module("collector"), _class_name)
|
||||
data_1d_obj = _class(self._date_field_name, self._symbol_field_name)
|
||||
data_1d = data_1d_obj.normalize(data_1d)
|
||||
return data_1d
|
||||
|
||||
def adjusted_price(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
# TODO: using daily data factor
|
||||
if df.empty:
|
||||
return df
|
||||
df = df.copy()
|
||||
df = df.sort_values(self._date_field_name)
|
||||
symbol = df.iloc[0][self._symbol_field_name]
|
||||
# get 1d data from yahoo
|
||||
_start = pd.Timestamp(df[self._date_field_name].min()).strftime(self.DAILY_FORMAT)
|
||||
_end = (pd.Timestamp(df[self._date_field_name].max()) + pd.Timedelta(days=1)).strftime(self.DAILY_FORMAT)
|
||||
data_1d: pd.DataFrame = self.get_1d_data(symbol, _start, _end)
|
||||
data_1d = data_1d.copy()
|
||||
if data_1d is None or data_1d.empty:
|
||||
df["factor"] = 1 / df.loc[df["close"].first_valid_index()]["close"]
|
||||
# TODO: np.nan or 1 or 0
|
||||
df["paused"] = np.nan
|
||||
else:
|
||||
# NOTE: volume is np.nan or volume <= 0, paused = 1
|
||||
# FIXME: find a more accurate data source
|
||||
data_1d["paused"] = 0
|
||||
data_1d.loc[(data_1d["volume"].isna()) | (data_1d["volume"] <= 0), "paused"] = 1
|
||||
data_1d = data_1d.set_index(self._date_field_name)
|
||||
|
||||
# add factor from 1d data
|
||||
# NOTE: yahoo 1d data info:
|
||||
# - Close price adjusted for splits. Adjusted close price adjusted for both dividends and splits.
|
||||
# - data_1d.adjclose: Adjusted close price adjusted for both dividends and splits.
|
||||
# - data_1d.close: `data_1d.adjclose / (close for the first trading day that is not np.nan)`
|
||||
def _calc_factor(df_1d: pd.DataFrame):
|
||||
try:
|
||||
_date = pd.Timestamp(pd.Timestamp(df_1d[self._date_field_name].iloc[0]).date())
|
||||
df_1d["factor"] = (
|
||||
data_1d.loc[_date]["close"] / df_1d.loc[df_1d["close"].last_valid_index()]["close"]
|
||||
)
|
||||
df_1d["paused"] = data_1d.loc[_date]["paused"]
|
||||
except Exception:
|
||||
df_1d["factor"] = np.nan
|
||||
df_1d["paused"] = np.nan
|
||||
return df_1d
|
||||
|
||||
df = df.groupby([df[self._date_field_name].dt.date]).apply(_calc_factor)
|
||||
|
||||
if self.CONSISTENT_1d:
|
||||
# the date sequence is consistent with 1d
|
||||
df.set_index(self._date_field_name, inplace=True)
|
||||
df = df.reindex(
|
||||
self.generate_1min_from_daily(
|
||||
pd.to_datetime(data_1d.reset_index()[self._date_field_name].drop_duplicates())
|
||||
)
|
||||
)
|
||||
df[self._symbol_field_name] = df.loc[df[self._symbol_field_name].first_valid_index()][
|
||||
self._symbol_field_name
|
||||
]
|
||||
df.index.names = [self._date_field_name]
|
||||
df.reset_index(inplace=True)
|
||||
for _col in self.COLUMNS:
|
||||
if _col not in df.columns:
|
||||
continue
|
||||
if _col == "volume":
|
||||
df[_col] = df[_col] / df["factor"]
|
||||
else:
|
||||
df[_col] = df[_col] * df["factor"]
|
||||
|
||||
if self.CALC_PAUSED_NUM:
|
||||
df = self.calc_paused_num(df)
|
||||
return df
|
||||
|
||||
def calc_paused_num(self, df: pd.DataFrame):
|
||||
_symbol = df.iloc[0][self._symbol_field_name]
|
||||
df = df.copy()
|
||||
df["_tmp_date"] = df[self._date_field_name].apply(lambda x: pd.Timestamp(x).date())
|
||||
# remove data that starts and ends with `np.nan` all day
|
||||
all_data = []
|
||||
# Record the number of consecutive trading days where the whole day is nan, to remove the last trading day where the whole day is nan
|
||||
all_nan_nums = 0
|
||||
# Record the number of consecutive occurrences of trading days that are not nan throughout the day
|
||||
not_nan_nums = 0
|
||||
for _date, _df in df.groupby("_tmp_date"):
|
||||
_df["paused"] = 0
|
||||
if not _df.loc[_df["volume"] < 0].empty:
|
||||
logger.warning(f"volume < 0, will fill np.nan: {_date} {_symbol}")
|
||||
_df.loc[_df["volume"] < 0, "volume"] = np.nan
|
||||
|
||||
check_fields = set(_df.columns) - {
|
||||
"_tmp_date",
|
||||
"paused",
|
||||
"factor",
|
||||
self._date_field_name,
|
||||
self._symbol_field_name,
|
||||
}
|
||||
if _df.loc[:, check_fields].isna().values.all() or (_df["volume"] == 0).all():
|
||||
all_nan_nums += 1
|
||||
not_nan_nums = 0
|
||||
_df["paused"] = 1
|
||||
if all_data:
|
||||
_df["paused_num"] = not_nan_nums
|
||||
all_data.append(_df)
|
||||
else:
|
||||
all_nan_nums = 0
|
||||
not_nan_nums += 1
|
||||
_df["paused_num"] = not_nan_nums
|
||||
all_data.append(_df)
|
||||
all_data = all_data[: len(all_data) - all_nan_nums]
|
||||
if all_data:
|
||||
df = pd.concat(all_data, sort=False)
|
||||
else:
|
||||
logger.warning(f"data is empty: {_symbol}")
|
||||
df = pd.DataFrame()
|
||||
return df
|
||||
del df["_tmp_date"]
|
||||
return df
|
||||
|
||||
@abc.abstractmethod
|
||||
def symbol_to_yahoo(self, symbol):
|
||||
raise NotImplementedError("rewrite symbol_to_yahoo")
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_1d_calendar_list(self) -> Iterable[pd.Timestamp]:
|
||||
raise NotImplementedError("rewrite _get_1d_calendar_list")
|
||||
|
||||
|
||||
class YahooNormalize1minOffline(YahooNormalize1min):
|
||||
"""Normalised to 1min using local 1d data"""
|
||||
|
||||
def __init__(
|
||||
self, qlib_data_1d_dir: [str, Path], date_field_name: str = "date", symbol_field_name: str = "symbol", **kwargs
|
||||
):
|
||||
@@ -769,42 +577,45 @@ class YahooNormalize1minOffline(YahooNormalize1min):
|
||||
symbol_field_name: str
|
||||
symbol field name, default is symbol
|
||||
"""
|
||||
self.qlib_data_1d_dir = qlib_data_1d_dir
|
||||
super(YahooNormalize1minOffline, self).__init__(date_field_name, symbol_field_name)
|
||||
self._all_1d_data = self._get_all_1d_data()
|
||||
super(YahooNormalize1min, self).__init__(date_field_name, symbol_field_name)
|
||||
qlib.init(provider_uri=qlib_data_1d_dir)
|
||||
self.all_1d_data = D.features(D.instruments("all"), ["$paused", "$volume", "$factor", "$close"], freq="day")
|
||||
|
||||
def _get_1d_calendar_list(self) -> Iterable[pd.Timestamp]:
|
||||
import qlib
|
||||
from qlib.data import D
|
||||
|
||||
qlib.init(provider_uri=self.qlib_data_1d_dir)
|
||||
return list(D.calendar(freq="day"))
|
||||
|
||||
def _get_all_1d_data(self):
|
||||
import qlib
|
||||
from qlib.data import D
|
||||
@property
|
||||
def calendar_list_1d(self):
|
||||
calendar_list_1d = getattr(self, "_calendar_list_1d", None)
|
||||
if calendar_list_1d is None:
|
||||
calendar_list_1d = self._get_1d_calendar_list()
|
||||
setattr(self, "_calendar_list_1d", calendar_list_1d)
|
||||
return calendar_list_1d
|
||||
|
||||
qlib.init(provider_uri=self.qlib_data_1d_dir)
|
||||
df = D.features(D.instruments("all"), ["$paused", "$volume", "$factor", "$close"], freq="day")
|
||||
df.reset_index(inplace=True)
|
||||
df.rename(columns={"datetime": self._date_field_name, "instrument": self._symbol_field_name}, inplace=True)
|
||||
df.columns = list(map(lambda x: x[1:] if x.startswith("$") else x, df.columns))
|
||||
def generate_1min_from_daily(self, calendars: Iterable) -> pd.Index:
|
||||
return generate_minutes_calendar_from_daily(
|
||||
calendars, freq="1min", am_range=self.AM_RANGE, pm_range=self.PM_RANGE
|
||||
)
|
||||
|
||||
def adjusted_price(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
df = calc_adjusted_price(
|
||||
df=df,
|
||||
_date_field_name=self._date_field_name,
|
||||
_symbol_field_name=self._symbol_field_name,
|
||||
frequence="1min",
|
||||
consistent_1d=self.CONSISTENT_1d,
|
||||
calc_paused=self.CALC_PAUSED_NUM,
|
||||
_1d_data_all=self.all_1d_data,
|
||||
)
|
||||
return df
|
||||
|
||||
def get_1d_data(self, symbol: str, start: str, end: str) -> pd.DataFrame:
|
||||
"""get 1d data
|
||||
@abc.abstractmethod
|
||||
def symbol_to_yahoo(self, symbol):
|
||||
raise NotImplementedError("rewrite symbol_to_yahoo")
|
||||
|
||||
Returns
|
||||
------
|
||||
data_1d: pd.DataFrame
|
||||
data_1d.columns = [self._date_field_name, self._symbol_field_name, "paused", "volume", "factor", "close"]
|
||||
|
||||
"""
|
||||
return self._all_1d_data[
|
||||
(self._all_1d_data[self._symbol_field_name] == symbol.upper())
|
||||
& (self._all_1d_data[self._date_field_name] >= pd.Timestamp(start))
|
||||
& (self._all_1d_data[self._date_field_name] < pd.Timestamp(end))
|
||||
]
|
||||
@abc.abstractmethod
|
||||
def _get_1d_calendar_list(self) -> Iterable[pd.Timestamp]:
|
||||
raise NotImplementedError("rewrite _get_1d_calendar_list")
|
||||
|
||||
|
||||
class YahooNormalizeUS:
|
||||
@@ -821,7 +632,7 @@ class YahooNormalizeUS1dExtend(YahooNormalizeUS, YahooNormalize1dExtend):
|
||||
pass
|
||||
|
||||
|
||||
class YahooNormalizeUS1min(YahooNormalizeUS, YahooNormalize1minOffline):
|
||||
class YahooNormalizeUS1min(YahooNormalizeUS, YahooNormalize1min):
|
||||
CALC_PAUSED_NUM = False
|
||||
|
||||
def _get_calendar_list(self) -> Iterable[pd.Timestamp]:
|
||||
@@ -844,7 +655,7 @@ class YahooNormalizeIN1d(YahooNormalizeIN, YahooNormalize1d):
|
||||
pass
|
||||
|
||||
|
||||
class YahooNormalizeIN1min(YahooNormalizeIN, YahooNormalize1minOffline):
|
||||
class YahooNormalizeIN1min(YahooNormalizeIN, YahooNormalize1min):
|
||||
CALC_PAUSED_NUM = False
|
||||
|
||||
def _get_calendar_list(self) -> Iterable[pd.Timestamp]:
|
||||
@@ -872,7 +683,7 @@ class YahooNormalizeCN1dExtend(YahooNormalizeCN, YahooNormalize1dExtend):
|
||||
pass
|
||||
|
||||
|
||||
class YahooNormalizeCN1min(YahooNormalizeCN, YahooNormalize1minOffline):
|
||||
class YahooNormalizeCN1min(YahooNormalizeCN, YahooNormalize1min):
|
||||
AM_RANGE = ("09:30:00", "11:29:00")
|
||||
PM_RANGE = ("13:00:00", "14:59:00")
|
||||
|
||||
@@ -899,7 +710,7 @@ class YahooNormalizeBR1d(YahooNormalizeBR, YahooNormalize1d):
|
||||
pass
|
||||
|
||||
|
||||
class YahooNormalizeBR1min(YahooNormalizeBR, YahooNormalize1minOffline):
|
||||
class YahooNormalizeBR1min(YahooNormalizeBR, YahooNormalize1min):
|
||||
CALC_PAUSED_NUM = False
|
||||
|
||||
def _get_calendar_list(self) -> Iterable[pd.Timestamp]:
|
||||
@@ -985,6 +796,9 @@ class Run(BaseRun):
|
||||
# get 1m data
|
||||
$ python collector.py download_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1m
|
||||
"""
|
||||
if self.interval == "1d" and pd.Timestamp(end) > pd.Timestamp(datetime.datetime.now().strftime("%Y-%m-%d")):
|
||||
raise ValueError(f"end_date: {end} is greater than the current date.")
|
||||
|
||||
super(Run, self).download_data(max_collector_count, delay, start, end, check_data_length, limit_nums)
|
||||
|
||||
def normalize_data(
|
||||
@@ -1123,10 +937,10 @@ class Run(BaseRun):
|
||||
def update_data_to_bin(
|
||||
self,
|
||||
qlib_data_1d_dir: str,
|
||||
trading_date: str = None,
|
||||
end_date: str = None,
|
||||
check_data_length: int = None,
|
||||
delay: float = 1,
|
||||
exists_skip: bool = False,
|
||||
):
|
||||
"""update yahoo data to bin
|
||||
|
||||
@@ -1135,14 +949,14 @@ class Run(BaseRun):
|
||||
qlib_data_1d_dir: str
|
||||
the qlib data to be updated for yahoo, usually from: https://github.com/microsoft/qlib/tree/main/scripts#download-cn-data
|
||||
|
||||
trading_date: str
|
||||
trading days to be updated, by default ``datetime.datetime.now().strftime("%Y-%m-%d")``
|
||||
end_date: str
|
||||
end datetime, default ``pd.Timestamp(trading_date + pd.Timedelta(days=1))``; open interval(excluding end)
|
||||
check_data_length: int
|
||||
check data length, if not None and greater than 0, each symbol will be considered complete if its data length is greater than or equal to this value, otherwise it will be fetched again, the maximum number of fetches being (max_collector_count). By default None.
|
||||
delay: float
|
||||
time.sleep(delay), default 1
|
||||
exists_skip: bool
|
||||
exists skip, by default False
|
||||
Notes
|
||||
-----
|
||||
If the data in qlib_data_dir is incomplete, np.nan will be populated to trading_date for the previous trading day
|
||||
@@ -1150,24 +964,24 @@ class Run(BaseRun):
|
||||
Examples
|
||||
-------
|
||||
$ python collector.py update_data_to_bin --qlib_data_1d_dir <user data dir> --trading_date <start date> --end_date <end date>
|
||||
# get 1m data
|
||||
"""
|
||||
|
||||
if self.interval.lower() != "1d":
|
||||
logger.warning(f"currently supports 1d data updates: --interval 1d")
|
||||
|
||||
# start/end date
|
||||
if trading_date is None:
|
||||
trading_date = datetime.datetime.now().strftime("%Y-%m-%d")
|
||||
logger.warning(f"trading_date is None, use the current date: {trading_date}")
|
||||
|
||||
if end_date is None:
|
||||
end_date = (pd.Timestamp(trading_date) + pd.Timedelta(days=1)).strftime("%Y-%m-%d")
|
||||
|
||||
# download qlib 1d data
|
||||
qlib_data_1d_dir = str(Path(qlib_data_1d_dir).expanduser().resolve())
|
||||
if not exists_qlib_data(qlib_data_1d_dir):
|
||||
GetData().qlib_data(target_dir=qlib_data_1d_dir, interval=self.interval, region=self.region)
|
||||
GetData().qlib_data(
|
||||
target_dir=qlib_data_1d_dir, interval=self.interval, region=self.region, exists_skip=exists_skip
|
||||
)
|
||||
|
||||
# start/end date
|
||||
calendar_df = pd.read_csv(Path(qlib_data_1d_dir).joinpath("calendars/day.txt"))
|
||||
trading_date = (pd.Timestamp(calendar_df.iloc[-1, 0]) - pd.Timedelta(days=1)).strftime("%Y-%m-%d")
|
||||
|
||||
if end_date is None:
|
||||
end_date = (pd.Timestamp(trading_date) + pd.Timedelta(days=1)).strftime("%Y-%m-%d")
|
||||
|
||||
# download data from yahoo
|
||||
# NOTE: when downloading data from YahooFinance, max_workers is recommended to be 1
|
||||
|
||||
@@ -135,7 +135,7 @@ class DumpDataBase:
|
||||
|
||||
def _get_source_data(self, file_path: Path) -> pd.DataFrame:
|
||||
df = pd.read_csv(str(file_path.resolve()), low_memory=False)
|
||||
df[self.date_field_name] = df[self.date_field_name].astype(str).astype(np.datetime64)
|
||||
df[self.date_field_name] = df[self.date_field_name].astype(str).astype("datetime64[ns]")
|
||||
# df.drop_duplicates([self.date_field_name], inplace=True)
|
||||
return df
|
||||
|
||||
@@ -146,9 +146,7 @@ class DumpDataBase:
|
||||
return (
|
||||
self._include_fields
|
||||
if self._include_fields
|
||||
else set(df_columns) - set(self._exclude_fields)
|
||||
if self._exclude_fields
|
||||
else df_columns
|
||||
else set(df_columns) - set(self._exclude_fields) if self._exclude_fields else df_columns
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -176,7 +174,7 @@ class DumpDataBase:
|
||||
def save_calendars(self, calendars_data: list):
|
||||
self._calendars_dir.mkdir(parents=True, exist_ok=True)
|
||||
calendars_path = str(self._calendars_dir.joinpath(f"{self.freq}.txt").expanduser().resolve())
|
||||
result_calendars_list = list(map(lambda x: self._format_datetime(x), calendars_data))
|
||||
result_calendars_list = [self._format_datetime(x) for x in calendars_data]
|
||||
np.savetxt(calendars_path, result_calendars_list, fmt="%s", encoding="utf-8")
|
||||
|
||||
def save_instruments(self, instruments_data: Union[list, pd.DataFrame]):
|
||||
@@ -195,7 +193,7 @@ class DumpDataBase:
|
||||
def data_merge_calendar(self, df: pd.DataFrame, calendars_list: List[pd.Timestamp]) -> pd.DataFrame:
|
||||
# calendars
|
||||
calendars_df = pd.DataFrame(data=calendars_list, columns=[self.date_field_name])
|
||||
calendars_df[self.date_field_name] = calendars_df[self.date_field_name].astype(np.datetime64)
|
||||
calendars_df[self.date_field_name] = calendars_df[self.date_field_name].astype("datetime64[ns]")
|
||||
cal_df = calendars_df[
|
||||
(calendars_df[self.date_field_name] >= df[self.date_field_name].min())
|
||||
& (calendars_df[self.date_field_name] <= df[self.date_field_name].max())
|
||||
|
||||
@@ -3,24 +3,21 @@
|
||||
"""
|
||||
TODO:
|
||||
- A more well-designed PIT database is required.
|
||||
- seperated insert, delete, update, query operations are required.
|
||||
- separated insert, delete, update, query operations are required.
|
||||
"""
|
||||
|
||||
import abc
|
||||
import shutil
|
||||
import struct
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
from typing import Iterable, List, Union
|
||||
from typing import Iterable
|
||||
from functools import partial
|
||||
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
|
||||
import fire
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
from loguru import logger
|
||||
from qlib.utils import fname_to_code, code_to_fname, get_period_offset
|
||||
from qlib.utils import fname_to_code, get_period_offset
|
||||
from qlib.config import C
|
||||
|
||||
|
||||
@@ -135,9 +132,11 @@ class DumpPitData:
|
||||
return (
|
||||
set(self._include_fields)
|
||||
if self._include_fields
|
||||
else set(df[self.field_column_name]) - set(self._exclude_fields)
|
||||
if self._exclude_fields
|
||||
else set(df[self.field_column_name])
|
||||
else (
|
||||
set(df[self.field_column_name]) - set(self._exclude_fields)
|
||||
if self._exclude_fields
|
||||
else set(df[self.field_column_name])
|
||||
)
|
||||
)
|
||||
|
||||
def get_filenames(self, symbol, field, interval):
|
||||
|
||||
18
setup.py
18
setup.py
@@ -46,7 +46,7 @@ if not _CYTHON_INSTALLED:
|
||||
REQUIRED = [
|
||||
"numpy>=1.12.0, <1.24",
|
||||
"pandas>=0.25.1",
|
||||
"scipy>=1.0.0",
|
||||
"scipy>=1.7.3",
|
||||
"requests>=2.18.0",
|
||||
"sacred>=0.7.4",
|
||||
"python-socketio",
|
||||
@@ -65,18 +65,24 @@ REQUIRED = [
|
||||
# To ensure stable operation of the experiment manager, we have limited the version of mlflow,
|
||||
# and we need to verify whether version 2.0 of mlflow can serve qlib properly.
|
||||
"mlflow>=1.12.1, <=1.30.0",
|
||||
# mlflow 1.30.0 requires packaging<22, so we limit the packaging version, otherwise the CI will fail.
|
||||
"packaging<22",
|
||||
"tqdm",
|
||||
"loguru",
|
||||
"lightgbm>=3.3.0",
|
||||
"tornado",
|
||||
"joblib>=0.17.0",
|
||||
"ruamel.yaml>=0.16.12",
|
||||
# With the upgrading of ruamel.yaml to 0.18, the safe_load method was deprecated,
|
||||
# which would cause qlib.workflow.cli to not work properly,
|
||||
# and no good replacement has been found, so the version of ruamel.yaml has been restricted for now.
|
||||
# Refs: https://pypi.org/project/ruamel.yaml/
|
||||
"ruamel.yaml<=0.17.36",
|
||||
"pymongo==3.7.2", # For task management
|
||||
"scikit-learn>=0.22",
|
||||
"dill",
|
||||
"dataclasses;python_version<'3.7'",
|
||||
"filelock",
|
||||
"jinja2<3.1.0", # for passing the readthedocs workflow.
|
||||
"jinja2",
|
||||
"gym",
|
||||
# Installing the latest version of protobuf for python versions below 3.8 will cause unit tests to fail.
|
||||
"protobuf<=3.20.1;python_version<='3.8'",
|
||||
@@ -140,7 +146,8 @@ setup(
|
||||
"wheel",
|
||||
"setuptools",
|
||||
"black",
|
||||
"pylint",
|
||||
# Version 3.0 of pylint had problems with the build process, so we limited the version of pylint.
|
||||
"pylint<=2.17.6",
|
||||
# Using the latest versions(0.981 and 0.982) of mypy,
|
||||
# the error "multiprocessing.Value()" is detected in the file "qlib/rl/utils/data_queue.py",
|
||||
# If this is fixed in a subsequent version of mypy, then we will revert to the latest version of mypy.
|
||||
@@ -159,6 +166,9 @@ setup(
|
||||
"lxml",
|
||||
"baostock",
|
||||
"yahooquery",
|
||||
# 2024-05-30 scs has released a new version: 3.2.4.post2,
|
||||
# this version, causes qlib installation to fail, so we've limited the scs version a bit for now.
|
||||
"scs<=3.2.4",
|
||||
"beautifulsoup4",
|
||||
# In version 0.4.11 of tianshou, the code:
|
||||
# logits, hidden = self.actor(batch.obs, state=state, info=batch.info)
|
||||
|
||||
50
tests/data_mid_layer_tests/test_dataloader.py
Normal file
50
tests/data_mid_layer_tests/test_dataloader.py
Normal file
@@ -0,0 +1,50 @@
|
||||
# TODO:
|
||||
# dump alpha 360 to dataframe and merge it with Alpha158
|
||||
|
||||
import sys
|
||||
import unittest
|
||||
import qlib
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.append(str(Path(__file__).resolve().parent))
|
||||
from qlib.data.dataset.loader import NestedDataLoader
|
||||
from qlib.contrib.data.loader import Alpha158DL, Alpha360DL
|
||||
|
||||
|
||||
class TestDataLoader(unittest.TestCase):
|
||||
|
||||
def test_nested_data_loader(self):
|
||||
qlib.init()
|
||||
nd = NestedDataLoader(
|
||||
dataloader_l=[
|
||||
{
|
||||
"class": "qlib.contrib.data.loader.Alpha158DL",
|
||||
},
|
||||
{
|
||||
"class": "qlib.contrib.data.loader.Alpha360DL",
|
||||
"kwargs": {"config": {"label": (["Ref($close, -2)/Ref($close, -1) - 1"], ["LABEL0"])}},
|
||||
},
|
||||
]
|
||||
)
|
||||
# Of course you can use StaticDataLoader
|
||||
|
||||
dataset = nd.load()
|
||||
|
||||
assert dataset is not None
|
||||
|
||||
columns = dataset.columns.tolist()
|
||||
columns_list = [tup[1] for tup in columns]
|
||||
|
||||
for col in Alpha158DL.get_feature_config()[1]:
|
||||
assert col in columns_list
|
||||
|
||||
for col in Alpha360DL.get_feature_config()[1]:
|
||||
assert col in columns_list
|
||||
|
||||
assert "LABEL0" in columns_list
|
||||
|
||||
# Then you can use it wth DataHandler;
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -5,8 +5,9 @@ import unittest
|
||||
import pytest
|
||||
import sys
|
||||
from qlib.tests import TestAutoData
|
||||
from qlib.data.dataset import TSDatasetH
|
||||
from qlib.data.dataset import TSDatasetH, TSDataSampler
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import time
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
|
||||
@@ -98,6 +99,54 @@ class TestDataset(TestAutoData):
|
||||
print(idx[i])
|
||||
|
||||
|
||||
class TestTSDataSampler(unittest.TestCase):
|
||||
def test_TSDataSampler(self):
|
||||
"""
|
||||
Test TSDataSampler for issue #1716
|
||||
"""
|
||||
datetime_list = ["2000-01-31", "2000-02-29", "2000-03-31", "2000-04-30", "2000-05-31"]
|
||||
instruments = ["000001", "000002", "000003", "000004", "000005"]
|
||||
index = pd.MultiIndex.from_product(
|
||||
[pd.to_datetime(datetime_list), instruments], names=["datetime", "instrument"]
|
||||
)
|
||||
data = np.random.randn(len(datetime_list) * len(instruments))
|
||||
test_df = pd.DataFrame(data=data, index=index, columns=["factor"])
|
||||
dataset = TSDataSampler(test_df, datetime_list[0], datetime_list[-1], step_len=2)
|
||||
print()
|
||||
print("--------------dataset[0]--------------")
|
||||
print(dataset[0])
|
||||
print("--------------dataset[1]--------------")
|
||||
print(dataset[1])
|
||||
assert len(dataset[0]) == 2
|
||||
self.assertTrue(np.isnan(dataset[0][0]))
|
||||
self.assertEqual(dataset[0][1], dataset[1][0])
|
||||
self.assertEqual(dataset[1][1], dataset[2][0])
|
||||
self.assertEqual(dataset[2][1], dataset[3][0])
|
||||
|
||||
def test_TSDataSampler2(self):
|
||||
"""
|
||||
Extra test TSDataSampler to prevent incorrect filling of nan for the values at the front
|
||||
"""
|
||||
datetime_list = ["2000-01-31", "2000-02-29", "2000-03-31", "2000-04-30", "2000-05-31"]
|
||||
instruments = ["000001", "000002", "000003", "000004", "000005"]
|
||||
index = pd.MultiIndex.from_product(
|
||||
[pd.to_datetime(datetime_list), instruments], names=["datetime", "instrument"]
|
||||
)
|
||||
data = np.random.randn(len(datetime_list) * len(instruments))
|
||||
test_df = pd.DataFrame(data=data, index=index, columns=["factor"])
|
||||
dataset = TSDataSampler(test_df, datetime_list[2], datetime_list[-1], step_len=3)
|
||||
print()
|
||||
print("--------------dataset[0]--------------")
|
||||
print(dataset[0])
|
||||
print("--------------dataset[1]--------------")
|
||||
print(dataset[1])
|
||||
for i in range(3):
|
||||
self.assertFalse(np.isnan(dataset[0][i]))
|
||||
self.assertFalse(np.isnan(dataset[1][i]))
|
||||
self.assertEqual(dataset[0][1], dataset[1][0])
|
||||
self.assertEqual(dataset[0][2], dataset[1][1])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=10)
|
||||
|
||||
|
||||
@@ -94,6 +94,24 @@ class IndexDataTest(unittest.TestCase):
|
||||
print(sd)
|
||||
self.assertTrue(sd.iloc[0] == 2)
|
||||
|
||||
# test different precisions of time data
|
||||
timeindex = [
|
||||
np.datetime64("2024-06-22T00:00:00.000000000"),
|
||||
np.datetime64("2024-06-21T00:00:00.000000000"),
|
||||
np.datetime64("2024-06-20T00:00:00.000000000"),
|
||||
]
|
||||
sd = idd.SingleData([1, 2, 3], index=timeindex)
|
||||
self.assertTrue(
|
||||
sd.index.index(np.datetime64("2024-06-21T00:00:00.000000000"))
|
||||
== sd.index.index(np.datetime64("2024-06-21T00:00:00"))
|
||||
)
|
||||
self.assertTrue(sd.index.index(pd.Timestamp("2024-06-21 00:00")) == 1)
|
||||
|
||||
# Bad case: the input is not aligned
|
||||
timeindex[1] = (np.datetime64("2024-06-21T00:00:00.00"),)
|
||||
with self.assertRaises(TypeError):
|
||||
sd = idd.SingleData([1, 2, 3], index=timeindex)
|
||||
|
||||
def test_ops(self):
|
||||
sd1 = idd.SingleData([1, 2, 3, 4], index=["foo", "bar", "f", "g"])
|
||||
sd2 = idd.SingleData([1, 2, 3, 4], index=["foo", "bar", "f", "g"])
|
||||
|
||||
@@ -27,7 +27,7 @@ def train(uri_path: str = None):
|
||||
model performance
|
||||
"""
|
||||
|
||||
# model initiaiton
|
||||
# model initialization
|
||||
model = init_instance_by_config(CSI300_GBDT_TASK["model"])
|
||||
dataset = init_instance_by_config(CSI300_GBDT_TASK["dataset"])
|
||||
# To test __repr__
|
||||
|
||||
@@ -13,7 +13,9 @@ from pathlib import Path
|
||||
|
||||
from qlib.data import D
|
||||
from qlib.tests.data import GetData
|
||||
from scripts.dump_pit import DumpPitData
|
||||
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
|
||||
from dump_pit import DumpPitData
|
||||
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts/data_collector/pit")))
|
||||
from collector import Run
|
||||
|
||||
@@ -9,7 +9,9 @@ from qlib.tests import TestAutoData
|
||||
|
||||
|
||||
class WorkflowTest(TestAutoData):
|
||||
TMP_PATH = Path("./.mlruns_tmp/")
|
||||
# Creating the directory manually doesn't work with mlflow,
|
||||
# so we add a subfolder named .trash when we create the directory.
|
||||
TMP_PATH = Path("./.mlruns_tmp/.trash")
|
||||
|
||||
def tearDown(self) -> None:
|
||||
if self.TMP_PATH.exists():
|
||||
@@ -17,6 +19,8 @@ class WorkflowTest(TestAutoData):
|
||||
|
||||
def test_get_local_dir(self):
|
||||
""" """
|
||||
self.TMP_PATH.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with R.start(uri=str(self.TMP_PATH)):
|
||||
pass
|
||||
|
||||
|
||||
Reference in New Issue
Block a user