mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-29 00:51:19 +08:00
Compare commits
113 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
92055d64ec | ||
|
|
b9809a4c33 | ||
|
|
fc243fd29b | ||
|
|
b6a8bd5b80 | ||
|
|
6ee0fe366c | ||
|
|
55b6ff123e | ||
|
|
45ea4bae4e | ||
|
|
17d472cf01 | ||
|
|
c500a01226 | ||
|
|
114c38b4c3 | ||
|
|
414c3082c0 | ||
|
|
3fc2f8c93c | ||
|
|
66ff3e5bf6 | ||
|
|
8ff68a182e | ||
|
|
a105ef1d76 | ||
|
|
d02965ea70 | ||
|
|
b8d1e08010 | ||
|
|
51709c20d8 | ||
|
|
28c99c77be | ||
|
|
bb5cdfe050 | ||
|
|
fb21c591bb | ||
|
|
5279e71423 | ||
|
|
f35254c288 | ||
|
|
5e82c18cb2 | ||
|
|
2759e8c28d | ||
|
|
2461575d30 | ||
|
|
867667531d | ||
|
|
0fc52333b7 | ||
|
|
ab9b6dc47a | ||
|
|
4c5a4d5cd7 | ||
|
|
e84cc23589 | ||
|
|
707399a245 | ||
|
|
6e88ccca88 | ||
|
|
ee5f3de800 | ||
|
|
3605cd7b96 | ||
|
|
d1cbf4c3d9 | ||
|
|
6011a21308 | ||
|
|
76a05f37a9 | ||
|
|
c99494eb76 | ||
|
|
e8126b0c39 | ||
|
|
8f4d320832 | ||
|
|
e2739ac72c | ||
|
|
19d15ddc38 | ||
|
|
12af8f304b | ||
|
|
25b771ddf1 | ||
|
|
1158472489 | ||
|
|
84d2cb3226 | ||
|
|
509bfcb02e | ||
|
|
6608a40965 | ||
|
|
3e75cead93 | ||
|
|
6697f209d4 | ||
|
|
e3b57b1901 | ||
|
|
82a5223166 | ||
|
|
398131cff7 | ||
|
|
e71e2f941c | ||
|
|
0483406c12 | ||
|
|
da1f4db968 | ||
|
|
a7c41b6969 | ||
|
|
5b7b48e376 | ||
|
|
4f9f978909 | ||
|
|
319a2f38cc | ||
|
|
a2c38c979e | ||
|
|
07655f2d5b | ||
|
|
9303415666 | ||
|
|
05d28469ad | ||
|
|
dc6859bdd9 | ||
|
|
a6f9dde006 | ||
|
|
1d22ee56d3 | ||
|
|
3810a4cd33 | ||
|
|
48af7126b6 | ||
|
|
025b1dcff9 | ||
|
|
29e66b2dea | ||
|
|
698e59ac72 | ||
|
|
e006ef40ad | ||
|
|
59d4bc9394 | ||
|
|
b07e0bffb1 | ||
|
|
161343018f | ||
|
|
bee031af68 | ||
|
|
35840606a8 | ||
|
|
2df9b6e076 | ||
|
|
0c3eaf3f16 | ||
|
|
2eee064eb8 | ||
|
|
096ef5a62b | ||
|
|
dd0eebed53 | ||
|
|
7b20abeda1 | ||
|
|
5519420efd | ||
|
|
eb3c5b3088 | ||
|
|
f03df874bf | ||
|
|
8fa22bd2e1 | ||
|
|
d1c8d885aa | ||
|
|
bf7732e284 | ||
|
|
3f5334ab39 | ||
|
|
c97a96363d | ||
|
|
2023f714c9 | ||
|
|
f8a2b0533b | ||
|
|
3183a232df | ||
|
|
8b715268bd | ||
|
|
28cb827a23 | ||
|
|
b723f14619 | ||
|
|
47535ba530 | ||
|
|
d70e5a4f88 | ||
|
|
3b8087677c | ||
|
|
4ec41ea0e7 | ||
|
|
cfcd9fb1f8 | ||
|
|
457dcaa466 | ||
|
|
3c740fc2de | ||
|
|
6d91f28474 | ||
|
|
be8653c505 | ||
|
|
a8974ce535 | ||
|
|
79026e5390 | ||
|
|
4610e16ac2 | ||
|
|
b504cc6ac8 | ||
|
|
d5059e609f |
2
.github/workflows/python-publish.yml
vendored
2
.github/workflows/python-publish.yml
vendored
@@ -13,7 +13,7 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
os: [windows-latest, macos-latest]
|
||||
python-version: [3.6, 3.7, 3.8]
|
||||
python-version: [3.6, 3.7, 3.8, 3.9]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
|
||||
95
.github/workflows/test.yml
vendored
95
.github/workflows/test.yml
vendored
@@ -1,4 +1,4 @@
|
||||
name: Test
|
||||
name: Test
|
||||
|
||||
on:
|
||||
push:
|
||||
@@ -12,8 +12,8 @@ jobs:
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [windows-latest, ubuntu-16.04, ubuntu-18.04, ubuntu-20.04, macos-latest]
|
||||
python-version: [3.6, 3.7, 3.8, 3.9]
|
||||
os: [windows-latest, ubuntu-18.04, ubuntu-20.04]
|
||||
python-version: [3.6, 3.7, 3.8]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
@@ -25,96 +25,41 @@ jobs:
|
||||
|
||||
- name: Lint with Black
|
||||
run: |
|
||||
cd ..
|
||||
if [ "$RUNNER_OS" == "Windows" ]; then
|
||||
$CONDA\\python.exe -m pip install black
|
||||
$CONDA\\python.exe -m black qlib -l 120 --check --diff
|
||||
else
|
||||
sudo $CONDA/bin/python -m pip install black
|
||||
$CONDA/bin/python -m black qlib -l 120 --check --diff
|
||||
fi
|
||||
shell: bash
|
||||
pip install --upgrade pip
|
||||
pip install black wheel
|
||||
black qlib -l 120 --check --diff
|
||||
|
||||
# Test Qlib installed with pip
|
||||
- name: Install Qlib with pip
|
||||
run: |
|
||||
if [ "$RUNNER_OS" == "Windows" ]; then
|
||||
$CONDA\\python.exe -m pip install numpy==1.19.5
|
||||
$CONDA\\python.exe -m pip install pyqlib --ignore-installed ruamel.yaml numpy --user
|
||||
else
|
||||
sudo $CONDA/bin/python -m pip install numpy==1.19.5
|
||||
sudo $CONDA/bin/python -m pip install pyqlib --ignore-installed ruamel.yaml numpy
|
||||
fi
|
||||
shell: bash
|
||||
|
||||
- name: Install Lightgbm for MacOS
|
||||
if: runner.os == 'macOS'
|
||||
run: |
|
||||
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Microsoft/qlib/main/.github/brew_install.sh)"
|
||||
HOMEBREW_NO_AUTO_UPDATE=1 brew install lightgbm
|
||||
pip install numpy==1.19.5 ruamel.yaml
|
||||
pip install pyqlib --ignore-installed
|
||||
|
||||
- name: Test data downloads
|
||||
run: |
|
||||
if [ "$RUNNER_OS" == "Windows" ]; then
|
||||
$CONDA\\python.exe scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
|
||||
else
|
||||
$CONDA/bin/python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
|
||||
fi
|
||||
shell: bash
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
|
||||
|
||||
- name: Test workflow by config (install from pip)
|
||||
run: |
|
||||
if [ "$RUNNER_OS" == "Windows" ]; then
|
||||
$CONDA\\python.exe qlib\\workflow\\cli.py examples\\benchmarks\\LightGBM\\workflow_config_lightgbm_Alpha158.yaml
|
||||
$CONDA\\python.exe -m pip uninstall -y pyqlib
|
||||
else
|
||||
$CONDA/bin/python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
sudo $CONDA/bin/python -m pip uninstall -y pyqlib
|
||||
fi
|
||||
shell: bash
|
||||
|
||||
# Test Qlib installed from source
|
||||
python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
python -m pip uninstall -y pyqlib
|
||||
|
||||
# Test Qlib installed from source
|
||||
- name: Install Qlib from source
|
||||
run: |
|
||||
if [ "$RUNNER_OS" == "Windows" ]; then
|
||||
$CONDA\\python.exe -m pip install --upgrade cython
|
||||
$CONDA\\python.exe -m pip install numpy jupyter jupyter_contrib_nbextensions
|
||||
$CONDA\\python.exe -m pip install -U scipy scikit-learn # installing without this line will cause errors on GitHub Actions, while instsalling locally won't
|
||||
$CONDA\\python.exe setup.py install
|
||||
else
|
||||
sudo $CONDA/bin/python -m pip install --upgrade cython
|
||||
sudo $CONDA/bin/python -m pip install numpy jupyter jupyter_contrib_nbextensions
|
||||
sudo $CONDA/bin/python -m pip install -U scipy scikit-learn # installing without this line will cause errors on GitHub Actions, while instsalling locally won't
|
||||
sudo $CONDA/bin/python setup.py install
|
||||
fi
|
||||
shell: bash
|
||||
pip install --upgrade cython jupyter jupyter_contrib_nbextensions numpy scipy scikit-learn # installing without this line will cause errors on GitHub Actions, while instsalling locally won't
|
||||
pip install -e .
|
||||
|
||||
- name: Install test dependencies
|
||||
run: |
|
||||
if [ "$RUNNER_OS" == "Windows" ]; then
|
||||
$CONDA\\python.exe -m pip install --upgrade pip
|
||||
$CONDA\\python.exe -m pip install black pytest
|
||||
else
|
||||
sudo $CONDA/bin/python -m pip install --upgrade pip
|
||||
sudo $CONDA/bin/python -m pip install black pytest
|
||||
fi
|
||||
shell: bash
|
||||
pip install --upgrade pip
|
||||
pip install black pytest
|
||||
|
||||
- name: Unit tests with Pytest
|
||||
run: |
|
||||
cd tests
|
||||
if [ "$RUNNER_OS" == "Windows" ]; then
|
||||
$CONDA\\python.exe -m pytest . --durations=0
|
||||
else
|
||||
$CONDA/bin/python -m pytest . --durations=0
|
||||
fi
|
||||
shell: bash
|
||||
python -m pytest . --durations=10
|
||||
|
||||
- name: Test workflow by config (install from source)
|
||||
run: |
|
||||
if [ "$RUNNER_OS" == "Windows" ]; then
|
||||
$CONDA\\python.exe qlib\\workflow\\cli.py examples\\benchmarks\\LightGBM\\workflow_config_lightgbm_Alpha158.yaml
|
||||
else
|
||||
$CONDA/bin/python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
fi
|
||||
shell: bash
|
||||
python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
|
||||
|
||||
72
.github/workflows/test_macos.yml
vendored
Normal file
72
.github/workflows/test_macos.yml
vendored
Normal file
@@ -0,0 +1,72 @@
|
||||
# There are some issues (in the downloading data phase) on MacOS when running with other tests. So we split it into an individual config.
|
||||
name: Test MacOS
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
|
||||
jobs:
|
||||
build:
|
||||
|
||||
runs-on: macos-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: [3.6, 3.7, 3.8]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Lint with Black
|
||||
run: |
|
||||
cd ..
|
||||
python -m pip install pip --upgrade
|
||||
python -m pip install wheel --upgrade
|
||||
python -m pip install black
|
||||
python -m black qlib -l 120 --check --diff
|
||||
# Test Qlib installed with pip
|
||||
- name: Install Qlib with pip
|
||||
run: |
|
||||
python -m pip install numpy==1.19.5
|
||||
python -m pip install pyqlib --ignore-installed ruamel.yaml numpy
|
||||
- name: Install Lightgbm for MacOS
|
||||
run: |
|
||||
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Microsoft/qlib/main/.github/brew_install.sh)"
|
||||
HOMEBREW_NO_AUTO_UPDATE=1 brew install lightgbm
|
||||
# FIX MacOS error: Segmentation fault
|
||||
# reference: https://github.com/microsoft/LightGBM/issues/4229
|
||||
wget https://raw.githubusercontent.com/Homebrew/homebrew-core/fb8323f2b170bd4ae97e1bac9bf3e2983af3fdb0/Formula/libomp.rb
|
||||
brew unlink libomp
|
||||
brew install libomp.rb
|
||||
- name: Test data downloads
|
||||
run: |
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
|
||||
- name: Test workflow by config (install from pip)
|
||||
run: |
|
||||
python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
python -m pip uninstall -y pyqlib
|
||||
# Test Qlib installed from source
|
||||
- name: Install Qlib from source
|
||||
run: |
|
||||
python -m pip install --upgrade cython
|
||||
python -m pip install numpy jupyter jupyter_contrib_nbextensions
|
||||
python -m pip install -U scipy scikit-learn # installing without this line will cause errors on GitHub Actions, while instsalling locally won't
|
||||
python setup.py install
|
||||
- name: Install test dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install -U pyopenssl idna
|
||||
python -m pip install black pytest
|
||||
- name: Unit tests with Pytest
|
||||
run: |
|
||||
cd tests
|
||||
python -m pytest . --durations=0
|
||||
- name: Test workflow by config (install from source)
|
||||
run: |
|
||||
python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
1
MANIFEST.in
Normal file
1
MANIFEST.in
Normal file
@@ -0,0 +1 @@
|
||||
include qlib/VERSION.txt
|
||||
34
README.md
34
README.md
@@ -11,6 +11,9 @@
|
||||
Recent released features
|
||||
| Feature | Status |
|
||||
| -- | ------ |
|
||||
|Temporal Routing Adaptor (TRA) | [Released](https://github.com/microsoft/qlib/pull/531) on July 30, 2021 |
|
||||
| Transformer & Localformer | [Released](https://github.com/microsoft/qlib/pull/508) on July 22, 2021 |
|
||||
| Release Qlib v0.7.0 | [Released](https://github.com/microsoft/qlib/releases/tag/v0.7.0) on July 12, 2021 |
|
||||
| TCTS Model | [Released](https://github.com/microsoft/qlib/pull/491) on July 1, 2021 |
|
||||
| Online serving and automatic model rolling | :star: [Released](https://github.com/microsoft/qlib/pull/290) on May 17, 2021 |
|
||||
| DoubleEnsemble Model | [Released](https://github.com/microsoft/qlib/pull/286) on Mar 2, 2021 |
|
||||
@@ -21,8 +24,6 @@ Recent released features
|
||||
|
||||
Features released before 2021 are not listed here.
|
||||
|
||||
|
||||
|
||||
<p align="center">
|
||||
<img src="http://fintech.msra.cn/images_v060/logo/1.png" />
|
||||
</p>
|
||||
@@ -43,7 +44,7 @@ For more details, please refer to our paper ["Qlib: An AI-oriented Quantitative
|
||||
- [Data Preparation](#data-preparation)
|
||||
- [Auto Quant Research Workflow](#auto-quant-research-workflow)
|
||||
- [Building Customized Quant Research Workflow by Code](#building-customized-quant-research-workflow-by-code)
|
||||
- [**Quant Model Zoo**](#quant-model-zoo)
|
||||
- [**Quant Model(Paper) Zoo**](#quant-model-paper-zoo)
|
||||
- [Run a single model](#run-a-single-model)
|
||||
- [Run multiple models](#run-multiple-models)
|
||||
- [**Quant Dataset Zoo**](#quant-dataset-zoo)
|
||||
@@ -105,8 +106,9 @@ This table demonstrates the supported Python version of `Qlib`:
|
||||
| Python 3.9 | :x: | :heavy_check_mark: | :x: |
|
||||
|
||||
**Note**:
|
||||
1. **Conda** is suggested for managing your Python environment.
|
||||
1. Please pay attention that installing cython in Python 3.6 will raise some error when installing ``Qlib`` from source. If users use Python 3.6 on their machines, it is recommended to *upgrade* Python to version 3.7 or use `conda`'s Python to install ``Qlib`` from source.
|
||||
2. For Python 3.9, `Qlib` supports running workflows such as training models, doing backtest and plot most of the related figures (those included in [notebook](examples/workflow_by_code.ipynb)). However, plotting for the *model performance* is not supported for now and we will fix this when the dependent packages are upgraded in the future.
|
||||
1. For Python 3.9, `Qlib` supports running workflows such as training models, doing backtest and plot most of the related figures (those included in [notebook](examples/workflow_by_code.ipynb)). However, plotting for the *model performance* is not supported for now and we will fix this when the dependent packages are upgraded in the future.
|
||||
|
||||
### Install with pip
|
||||
Users can easily install ``Qlib`` by pip according to the following command.
|
||||
@@ -160,7 +162,7 @@ Users could create the same dataset with it.
|
||||
*Please pay **ATTENTION** that the data is collected from [Yahoo Finance](https://finance.yahoo.com/lookup), and the data might not be perfect.
|
||||
We recommend users to prepare their own data if they have a high-quality dataset. For more information, users can refer to the [related document](https://qlib.readthedocs.io/en/latest/component/data.html#converting-csv-format-into-qlib-format)*.
|
||||
|
||||
### Automatic update of daily frequency data(from yahoo finance)
|
||||
### Automatic update of daily frequency data (from yahoo finance)
|
||||
> It is recommended that users update the data manually once (--trading_date 2021-05-25) and then set it to update automatically.
|
||||
|
||||
> For more information refer to: [yahoo collector](https://github.com/microsoft/qlib/tree/main/scripts/data_collector/yahoo#automatic-update-of-daily-frequency-datafrom-yahoo-finance)
|
||||
@@ -274,7 +276,7 @@ Qlib provides a tool named `qrun` to run the whole workflow automatically (inclu
|
||||
The automatic workflow may not suit the research workflow of all Quant researchers. To support a flexible Quant research workflow, Qlib also provides a modularized interface to allow researchers to build their own workflow by code. [Here](examples/workflow_by_code.ipynb) is a demo for customized Quant research workflow by code.
|
||||
|
||||
|
||||
# [Quant Model Zoo](examples/benchmarks)
|
||||
# [Quant Model (Paper) Zoo](examples/benchmarks)
|
||||
|
||||
Here is a list of models built on `Qlib`.
|
||||
- [GBDT based on XGBoost (Tianqi Chen, et al. KDD 2016)](qlib/contrib/model/xgboost.py)
|
||||
@@ -290,6 +292,9 @@ Here is a list of models built on `Qlib`.
|
||||
- [TabNet based on pytorch (Sercan O. Arik, et al. AAAI 2019)](qlib/contrib/model/pytorch_tabnet.py)
|
||||
- [DoubleEnsemble based on LightGBM (Chuheng Zhang, et al. ICDM 2020)](qlib/contrib/model/double_ensemble.py)
|
||||
- [TCTS based on pytorch (Xueqing Wu, et al. ICML 2021)](qlib/contrib/model/pytorch_tcts.py)
|
||||
- [Transformer based on pytorch (Ashish Vaswani, et al. NeurIPS 2017)](qlib/contrib/model/pytorch_transformer.py)
|
||||
- [Localformer based on pytorch (Juyong Jiang, et al.)](qlib/contrib/model/pytorch_localformer.py)
|
||||
- [TRA based on pytorch (Hengxu, Dong, et al. KDD 2021)](qlib/contrib/model/pytorch_tra.py)
|
||||
|
||||
Your PR of new Quant models is highly welcomed.
|
||||
|
||||
@@ -303,9 +308,10 @@ All the models listed above are runnable with ``Qlib``. Users can find the confi
|
||||
- Users can create a `workflow_by_code` python script based on the [one](examples/workflow_by_code.py) listed in the `examples` folder.
|
||||
|
||||
- Users can use the script [`run_all_model.py`](examples/run_all_model.py) listed in the `examples` folder to run a model. Here is an example of the specific shell command to be used: `python run_all_model.py --models=lightgbm`, where the `--models` arguments can take any number of models listed above(the available models can be found in [benchmarks](examples/benchmarks/)). For more use cases, please refer to the file's [docstrings](examples/run_all_model.py).
|
||||
- **NOTE**: Each baseline has different environment dependencies, please make sure that your python version aligns with the requirements(e.g. TFT only supports Python 3.6~3.7 due to the limitation of `tensorflow==1.15.0`)
|
||||
|
||||
## Run multiple models
|
||||
`Qlib` also provides a script [`run_all_model.py`](examples/run_all_model.py) which can run multiple models for several iterations. (**Note**: the script only support *Linux* for now. Other OS will be supported in the future. Besides, it doesn't support parrallel running the same model for multiple times as well, and this will be fixed in the future development too.)
|
||||
`Qlib` also provides a script [`run_all_model.py`](examples/run_all_model.py) which can run multiple models for several iterations. (**Note**: the script only support *Linux* for now. Other OS will be supported in the future. Besides, it doesn't support parallel running the same model for multiple times as well, and this will be fixed in the future development too.)
|
||||
|
||||
The script will create a unique virtual environment for each model, and delete the environments after training. Thus, only experiment results such as `IC` and `backtest` results will be generated and stored.
|
||||
|
||||
@@ -370,9 +376,7 @@ Such overheads greatly slow down the data loading process.
|
||||
Qlib data are stored in a compact format, which is efficient to be combined into arrays for scientific computation.
|
||||
|
||||
# Related Reports
|
||||
- [【华泰金工林晓明团队】图神经网络选股与Qlib实践——华泰人工智能系列之四十二](https://mp.weixin.qq.com/s/w5fDB6oAv9dO6vlhf1kmhA)
|
||||
- [Guide To Qlib: Microsoft’s AI Investment Platform](https://analyticsindiamag.com/qlib/)
|
||||
- [【华泰金工林晓明团队】微软AI量化投资平台Qlib体验——华泰人工智能系列之四十](https://mp.weixin.qq.com/s/Brcd7im4NibJOJzZfMn6tQ)
|
||||
- [微软也搞AI量化平台?还是开源的!](https://mp.weixin.qq.com/s/47bP5YwxfTp2uTHjUBzJQQ)
|
||||
- [微矿Qlib:业内首个AI量化投资开源平台](https://mp.weixin.qq.com/s/vsJv7lsgjEi-ALYUz4CvtQ)
|
||||
|
||||
@@ -389,7 +393,17 @@ Join IM discussion groups:
|
||||
|
||||
# Contributing
|
||||
|
||||
This project welcomes contributions and suggestions. Most contributions require you to agree to a
|
||||
This project welcomes contributions and suggestions.
|
||||
**Here are some
|
||||
[code standards](docs/developer/code_standard.rst) when you submit a pull request.**
|
||||
|
||||
If you want to contribute to Qlib's document, you can follow the steps in the figure below.
|
||||
<p align="center">
|
||||
<img src="https://github.com/demon143/qlib/blob/main/docs/_static/img/change%20doc.gif" />
|
||||
</p>
|
||||
|
||||
|
||||
Most contributions require you to agree to a
|
||||
Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
|
||||
the right to use your contribution. For details, visit https://cla.opensource.microsoft.com.
|
||||
|
||||
|
||||
1
VERSION.txt
Normal file
1
VERSION.txt
Normal file
@@ -0,0 +1 @@
|
||||
0.7.2
|
||||
@@ -97,4 +97,57 @@ Also, feel free to post a new issue in our GitHub repository. We always check ea
|
||||
|
||||
python setup.py build_ext --inplace
|
||||
|
||||
- If the error occurs when importing ``qlib`` package with command ``python`` , users need to change the running directory to ensure that the script does not run in the project directory.
|
||||
- If the error occurs when importing ``qlib`` package with command ``python`` , users need to change the running directory to ensure that the script does not run in the project directory.
|
||||
|
||||
|
||||
4. BadNamespaceError: / is not a connected namespace
|
||||
------------------------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
File "qlib_online.py", line 35, in <module>
|
||||
cal = D.calendar()
|
||||
File "e:\code\python\microsoft\qlib_latest\qlib\qlib\data\data.py", line 973, in calendar
|
||||
return Cal.calendar(start_time, end_time, freq, future=future)
|
||||
File "e:\code\python\microsoft\qlib_latest\qlib\qlib\data\data.py", line 798, in calendar
|
||||
self.conn.send_request(
|
||||
File "e:\code\python\microsoft\qlib_latest\qlib\qlib\data\client.py", line 101, in send_request
|
||||
self.sio.emit(request_type + "_request", request_content)
|
||||
File "G:\apps\miniconda\envs\qlib\lib\site-packages\python_socketio-5.3.0-py3.8.egg\socketio\client.py", line 369, in emit
|
||||
raise exceptions.BadNamespaceError(
|
||||
BadNamespaceError: / is not a connected namespace.
|
||||
|
||||
- The version of ``python-socketio`` in qlib needs to be the same as the version of ``python-socketio`` in qlib-server:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install -U python-socketio==<qlib-server python-socketio version>
|
||||
|
||||
|
||||
5. TypeError: send() got an unexpected keyword argument 'binary'
|
||||
------------------------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
File "qlib_online.py", line 35, in <module>
|
||||
cal = D.calendar()
|
||||
File "e:\code\python\microsoft\qlib_latest\qlib\qlib\data\data.py", line 973, in calendar
|
||||
return Cal.calendar(start_time, end_time, freq, future=future)
|
||||
File "e:\code\python\microsoft\qlib_latest\qlib\qlib\data\data.py", line 798, in calendar
|
||||
self.conn.send_request(
|
||||
File "e:\code\python\microsoft\qlib_latest\qlib\qlib\data\client.py", line 101, in send_request
|
||||
self.sio.emit(request_type + "_request", request_content)
|
||||
File "G:\apps\miniconda\envs\qlib\lib\site-packages\socketio\client.py", line 263, in emit
|
||||
self._send_packet(packet.Packet(packet.EVENT, namespace=namespace,
|
||||
File "G:\apps\miniconda\envs\qlib\lib\site-packages\socketio\client.py", line 339, in _send_packet
|
||||
self.eio.send(ep, binary=binary)
|
||||
TypeError: send() got an unexpected keyword argument 'binary'
|
||||
|
||||
|
||||
- The ``python-engineio`` version needs to be compatible with the ``python-socketio`` version, reference: https://github.com/miguelgrinberg/python-socketio#version-compatibility
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install -U python-engineio==<compatible python-socketio version>
|
||||
# or
|
||||
pip install -U python-socketio==3.1.2 python-engineio==3.13.2
|
||||
|
||||
BIN
docs/_static/img/change doc.gif
vendored
Normal file
BIN
docs/_static/img/change doc.gif
vendored
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 1.3 MiB |
@@ -179,6 +179,7 @@ After conversion, users can find their Qlib format data in the directory `~/.qli
|
||||
The Restoration factor. Normally, ``factor = adjusted_price / original_price``, `adjusted price` reference: `split adjusted <https://www.investopedia.com/terms/s/splitadjusted.asp>`_
|
||||
|
||||
In the convention of `Qlib` data processing, `open, close, high, low, volume, money and factor` will be set to NaN if the stock is suspended.
|
||||
If you want to use your own alpha-factor which can't be calculate by OCHLV, like PE, EPS and so on, you could add it to the CSV files with OHCLV together and then dump it to the Qlib format data.
|
||||
|
||||
Stock Pool (Market)
|
||||
--------------------------------
|
||||
|
||||
@@ -21,6 +21,8 @@ which including `Online Manager <#Online Manager>`_, `Online Strategy <#Online S
|
||||
If you have many models or `task` needs to be managed, please consider `Task Management <../advanced/task_management.html>`_.
|
||||
The `examples <https://github.com/microsoft/qlib/tree/main/examples/online_srv>`_ are based on some components in `Task Management <../advanced/task_management.html>`_ such as ``TrainerRM`` or ``Collector``.
|
||||
|
||||
**NOTE**: User should keep his data source updated to support online serving. For example, Qlib provides `a batch of scripts <https://github.com/microsoft/qlib/blob/main/scripts/data_collector/yahoo/README.md#automatic-update-of-daily-frequency-datafrom-yahoo-finance>`_ to help users update Yahoo daily data.
|
||||
|
||||
Online Manager
|
||||
=============
|
||||
|
||||
@@ -43,4 +45,4 @@ Updater
|
||||
=============
|
||||
|
||||
.. automodule:: qlib.workflow.online.update
|
||||
:members:
|
||||
:members:
|
||||
|
||||
22
docs/developer/code_standard.rst
Normal file
22
docs/developer/code_standard.rst
Normal file
@@ -0,0 +1,22 @@
|
||||
.. _code_standard:
|
||||
|
||||
=================================
|
||||
Code Standard
|
||||
=================================
|
||||
|
||||
Docstring
|
||||
=================================
|
||||
Please use the `Numpydoc Style <https://stackoverflow.com/a/24385103>`_.
|
||||
|
||||
Continuous Integration
|
||||
=================================
|
||||
Continuous Integration (CI) tools help you stick to the quality standards by running tests every time you push a new commit and reporting the results to a pull request.
|
||||
|
||||
When you submit a PR request, you can check whether your code passes the CI tests in the "check" section at the bottom of the web page.
|
||||
|
||||
A common error is the mixed use of space and tab. You can fix the bug by inputing the following code in the command line.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
pip install black
|
||||
python -m black . -l 120
|
||||
@@ -241,6 +241,7 @@ Online Tool
|
||||
.. automodule:: qlib.workflow.online.utils
|
||||
:members:
|
||||
|
||||
|
||||
RecordUpdater
|
||||
--------------------
|
||||
.. automodule:: qlib.workflow.online.update
|
||||
@@ -257,4 +258,4 @@ Serializable
|
||||
:members:
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -77,7 +77,8 @@ Besides `provider_uri` and `region`, `qlib.init` has other parameters. The follo
|
||||
})
|
||||
- `mongo`
|
||||
Type: dict, optional parameter, the setting of `MongoDB <https://www.mongodb.com/>`_ which will be used in some features such as `Task Management <../advanced/task_management.html>`_, with high performance and clustered processing.
|
||||
Users need finished `installation <https://www.mongodb.com/try/download/community>`_ firstly, and run it in a fixed URL.
|
||||
Users need to follow the steps in `installation <https://www.mongodb.com/try/download/community>`_ to install MongoDB firstly and then access it via a URI.
|
||||
Users can access mongodb with credential by setting "task_url" to a string like `"mongodb://%s:%s@%s" % (user, pwd, host + ":" + port)`.
|
||||
|
||||
.. code-block:: Python
|
||||
|
||||
|
||||
16
examples/benchmarks/LightGBM/features_sample.py
Normal file
16
examples/benchmarks/LightGBM/features_sample.py
Normal file
@@ -0,0 +1,16 @@
|
||||
import datetime
|
||||
import pandas as pd
|
||||
|
||||
from qlib.data.inst_processor import InstProcessor
|
||||
|
||||
|
||||
class Resample1minProcessor(InstProcessor):
|
||||
def __init__(self, hour: int, minute: int, **kwargs):
|
||||
self.hour = hour
|
||||
self.minute = minute
|
||||
|
||||
def __call__(self, df: pd.DataFrame, *args, **kwargs):
|
||||
df.index = pd.to_datetime(df.index)
|
||||
df = df.loc[df.index.time == datetime.time(self.hour, self.minute)]
|
||||
df.index = df.index.normalize()
|
||||
return df
|
||||
@@ -63,4 +63,4 @@ task:
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
config: *port_analysis_config
|
||||
|
||||
@@ -0,0 +1,83 @@
|
||||
qlib_init:
|
||||
provider_uri:
|
||||
day: "~/.qlib/qlib_data/cn_data"
|
||||
1min: "~/.qlib/qlib_data/cn_data_1min"
|
||||
region: cn
|
||||
dataset_cache: null
|
||||
maxtasksperchild: 1
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
# 1min closing time is 15:00:00
|
||||
end_time: "2020-08-01 15:00:00"
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
freq:
|
||||
label: day
|
||||
feature: 1min
|
||||
# with label as reference
|
||||
inst_processor:
|
||||
feature:
|
||||
- class: Resample1minProcessor
|
||||
module_path: features_sample.py
|
||||
kwargs:
|
||||
hour: 14
|
||||
minute: 56
|
||||
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: LGBModel
|
||||
module_path: qlib.contrib.model.gbdt
|
||||
kwargs:
|
||||
loss: mse
|
||||
colsample_bytree: 0.8879
|
||||
learning_rate: 0.2
|
||||
subsample: 0.8789
|
||||
lambda_l1: 205.6999
|
||||
lambda_l2: 580.9768
|
||||
max_depth: 8
|
||||
num_leaves: 210
|
||||
num_threads: 20
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha158
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -78,4 +78,4 @@ task:
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
config: *port_analysis_config
|
||||
|
||||
3
examples/benchmarks/Localformer/requirements.txt
Normal file
3
examples/benchmarks/Localformer/requirements.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
numpy==1.17.4
|
||||
pandas==1.1.2
|
||||
torch==1.2.0
|
||||
@@ -0,0 +1,82 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: FilterCol
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
col_list: ["RESI5", "WVMA5", "RSQR5", "KLEN", "RSQR10", "CORR5", "CORD5", "CORR10",
|
||||
"ROC60", "RESI10", "VSTD5", "RSQR60", "CORR60", "WVMA60", "STD5",
|
||||
"RSQR20", "CORD60", "CORD10", "CORR20", "KLOW"
|
||||
]
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: LocalformerModel
|
||||
module_path: qlib.contrib.model.pytorch_localformer_ts
|
||||
kwargs:
|
||||
seed: 0
|
||||
n_jobs: 20
|
||||
dataset:
|
||||
class: TSDatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha158
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
step_len: 20
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -0,0 +1,73 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: LocalformerModel
|
||||
module_path: qlib.contrib.model.pytorch_localformer
|
||||
kwargs:
|
||||
d_feat: 6
|
||||
seed: 0
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha360
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -1,6 +1,6 @@
|
||||
# Benchmarks Performance
|
||||
|
||||
Here are the results of each benchmark model running on Qlib's `Alpha360` and `Alpha158` dataset with China's A shared-stock & CSI300 data respectively. The values of each metric are the mean and std calculated based on 20 runs.
|
||||
Here are the results of each benchmark model running on Qlib's `Alpha360` and `Alpha158` dataset with China's A shared-stock & CSI300 data respectively. The values of each metric are the mean and std calculated based on 20 runs with different random seeds.
|
||||
|
||||
The numbers shown below demonstrate the performance of the entire `workflow` of each model. We will update the `workflow` as well as models in the near future for better results.
|
||||
|
||||
@@ -23,6 +23,9 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
|
||||
| DoubleEnsemble (Chuheng Zhang, et al.) | Alpha360 | 0.0407±0.00| 0.3053±0.00 | 0.0490±0.00 | 0.3840±0.00 | 0.0380±0.02 | 0.5000±0.21 | -0.0984±0.02 |
|
||||
| TabNet (Sercan O. Arik, et al.)| Alpha360 | 0.0192±0.00 | 0.1401±0.00| 0.0291±0.00 | 0.2163±0.00 | -0.0258±0.00 | -0.2961±0.00| -0.1429±0.00 |
|
||||
| TCTS (Xueqing Wu, et al.)| Alpha360 | 0.0485±0.00 | 0.3689±0.04| 0.0586±0.00 | 0.4669±0.02 | 0.0816±0.02 | 1.1572±0.30| -0.0689±0.02 |
|
||||
| Transformer (Ashish Vaswani, et al.)| Alpha360 | 0.0141±0.00 | 0.0917±0.02| 0.0331±0.00 | 0.2357±0.03 | -0.0259±0.03 | -0.3323±0.43| -0.1763±0.07 |
|
||||
| Localformer (Juyong Jiang, et al.)| Alpha360 | 0.0408±0.00 | 0.2988±0.03| 0.0538±0.00 | 0.4105±0.02 | 0.0275±0.03 | 0.3464±0.37| -0.1182±0.03 |
|
||||
| TRA (Hengxu Lin, et al.)| Alpha360 | 0.0491±0.01 | 0.3868±0.06 | 0.0589±0.00 | 0.4802±0.04 | 0.0898±0.02 | 1.2490±0.32 | -0.0778±0.02 |
|
||||
|
||||
## Alpha158 dataset
|
||||
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |
|
||||
@@ -39,6 +42,10 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
|
||||
| GATs (Petar Velickovic, et al.) | Alpha158 (with selected 20 features) | 0.0349±0.00 | 0.2511±0.01| 0.0457±0.00 | 0.3537±0.01 | 0.0578±0.02 | 0.8221±0.25| -0.0824±0.02 |
|
||||
| DoubleEnsemble (Chuheng Zhang, et al.) | Alpha158 | 0.0544±0.00 | 0.4338±0.01 | 0.0523±0.00 | 0.4257±0.01 | 0.1253±0.01 | 1.4105±0.14 | -0.0902±0.01 |
|
||||
| TabNet (Sercan O. Arik, et al.)| Alpha158 | 0.0383±0.00 | 0.3414±0.00| 0.0388±0.00 | 0.3460±0.00 | 0.0226±0.00 | 0.2652±0.00| -0.1072±0.00 |
|
||||
| Transformer (Ashish Vaswani, et al.)| Alpha158 | 0.0274±0.00 | 0.2166±0.04| 0.0409±0.00 | 0.3342±0.04 | 0.0204±0.03 | 0.2888±0.40| -0.1216±0.04 |
|
||||
| Localformer (Juyong Jiang, et al.)| Alpha158 | 0.0355±0.00 | 0.2747±0.04| 0.0466±0.00 | 0.3762±0.03 | 0.0506±0.02 | 0.7447±0.34| -0.0875±0.02 |
|
||||
| TRA (Hengxu Lin, et al.)| Alpha158 (with selected 20 features)| 0.0409±0.00 | 0.3253±0.04 | 0.0488±0.00 | 0.4045±0.02 | 0.0673±0.02 | 1.0389±0.39 | -0.0830±0.02 |
|
||||
| TRA (Hengxu Lin, et al.)| Alpha158 | 0.0442±0.00 | 0.3426±0.03 | 0.0555±0.00 | 0.4395±0.03 | 0.0833±0.03 | 1.2064±0.36 | -0.0849±0.02 |
|
||||
|
||||
- The selected 20 features are based on the feature importance of a lightgbm-based model.
|
||||
- The base model of DoubleEnsemble is LGBM.
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
Users can follow the ``workflow_by_code_tft.py`` to run the benchmark.
|
||||
|
||||
### Notes
|
||||
1. Please be **aware** that this script can only support `Python 3.5 - 3.8`.
|
||||
1. Please be **aware** that this script can only support `Python 3.6 - 3.7`.
|
||||
2. If the CUDA version on your machine is not 10.0, please remember to run the following commands `conda install anaconda cudatoolkit=10.0` and `conda install cudnn` on your machine.
|
||||
3. The model must run in GPU, or an error will be raised.
|
||||
4. New datasets should be registered in ``data_formatters``, for detail please visit the source.
|
||||
|
||||
@@ -1,3 +1,2 @@
|
||||
tensorflow-gpu==1.15.0
|
||||
numpy == 1.19.4
|
||||
pandas==1.1.0
|
||||
pandas==1.1.0
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import tensorflow.compat.v1 as tf
|
||||
@@ -243,7 +245,7 @@ class TFTModel(ModelFT):
|
||||
# extract_numerical_data(targets), extract_numerical_data(p90_forecast),
|
||||
# 0.9)
|
||||
tf.keras.backend.set_session(default_keras_session)
|
||||
print("Training completed.".format(dte.datetime.now()))
|
||||
print("Training completed at {}.".format(dte.datetime.now()))
|
||||
# ===========================Training Process===========================
|
||||
|
||||
def predict(self, dataset):
|
||||
@@ -289,3 +291,24 @@ class TFTModel(ModelFT):
|
||||
dataset for finetuning
|
||||
"""
|
||||
pass
|
||||
|
||||
def to_pickle(self, path: Union[Path, str]):
|
||||
"""
|
||||
Tensorflow model can't be dumped directly.
|
||||
So the data should be save seperatedly
|
||||
|
||||
**TODO**: Please implement the function to load the files
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : Union[Path, str]
|
||||
the target path to be dumped
|
||||
"""
|
||||
# save tensorflow model
|
||||
# path = Path(path)
|
||||
# path.mkdir(parents=True)
|
||||
# self.model.save(path)
|
||||
|
||||
# save qlib model wrapper
|
||||
self.model = None
|
||||
super(TFTModel, self).to_pickle(path / "qlib_model")
|
||||
|
||||
@@ -1,53 +1,77 @@
|
||||
# Learning Multiple Stock Trading Patterns with Temporal Routing Adaptor and Optimal Transport
|
||||
|
||||
This code provides a PyTorch implementation for TRA (Temporal Routing Adaptor), as described in the paper [Learning Multiple Stock Trading Patterns with Temporal Routing Adaptor and Optimal Transport](http://arxiv.org/abs/2106.12950).
|
||||
Temporal Routing Adaptor (TRA) is designed to capture multiple trading patterns in the stock market data. Please refer to [our paper](http://arxiv.org/abs/2106.12950) for more details.
|
||||
|
||||
* TRA (Temporal Routing Adaptor) is a lightweight module that consists of a set of independent predictors for learning multiple patterns as well as a router to dispatch samples to different predictors.
|
||||
* We also design a learning algorithm based on Optimal Transport (OT) to obtain the optimal sample to predictor assignment and effectively optimize the router with such assignment through an auxiliary loss term.
|
||||
If you find our work useful in your research, please cite:
|
||||
```
|
||||
@inproceedings{HengxuKDD2021,
|
||||
author = {Hengxu Lin and Dong Zhou and Weiqing Liu and Jiang Bian},
|
||||
title = {Learning Multiple Stock Trading Patterns with Temporal Routing Adaptor and Optimal Transport},
|
||||
booktitle = {Proceedings of the 27th ACM SIGKDD Conference on Knowledge Discovery \& Data Mining},
|
||||
series = {KDD '21},
|
||||
year = {2021},
|
||||
publisher = {ACM},
|
||||
}
|
||||
|
||||
@article{yang2020qlib,
|
||||
title={Qlib: An AI-oriented Quantitative Investment Platform},
|
||||
author={Yang, Xiao and Liu, Weiqing and Zhou, Dong and Bian, Jiang and Liu, Tie-Yan},
|
||||
journal={arXiv preprint arXiv:2009.11189},
|
||||
year={2020}
|
||||
}
|
||||
```
|
||||
|
||||
# Running TRA
|
||||
## Usage (Recommended)
|
||||
|
||||
## Requirements
|
||||
- Install `Qlib` main branch
|
||||
**Update**: `TRA` has been moved to `qlib.contrib.model.pytorch_tra` to support other `Qlib` components like `qlib.workflow` and `Alpha158/Alpha360` dataset.
|
||||
|
||||
## Running
|
||||
Please follow the official [doc](https://qlib.readthedocs.io/en/latest/component/workflow.html) to use `TRA` with `workflow`. Here we also provide several example config files:
|
||||
|
||||
- `workflow_config_tra_Alpha360.yaml`: running `TRA` with `Alpha360` dataset
|
||||
- `workflow_config_tra_Alpha158.yaml`: running `TRA` with `Alpha158` dataset (with feature subsampling)
|
||||
- `workflow_config_tra_Alpha158_full.yaml`: running `TRA` with `Alpha158` dataset (without feature subsampling)
|
||||
|
||||
The performances of `TRA` are reported in [Benchmarks](https://github.com/microsoft/qlib/tree/main/examples/benchmarks).
|
||||
|
||||
## Usage (Not Maintained)
|
||||
|
||||
This section is used to reproduce the results in the paper.
|
||||
|
||||
### Running
|
||||
|
||||
We attach our running scripts for the paper in `run.sh`.
|
||||
|
||||
And here are two ways to run the model:
|
||||
|
||||
* Running from scripts with default parameters
|
||||
You can directly run from Qlib command `qrun`:
|
||||
```
|
||||
qrun configs/config_alstm.yaml
|
||||
```
|
||||
|
||||
You can directly run from Qlib command `qrun`:
|
||||
```
|
||||
qrun configs/config_alstm.yaml
|
||||
```
|
||||
|
||||
* Running from code with self-defined parameters
|
||||
Setting different parameters is also allowed. See codes in `example.py`:
|
||||
```
|
||||
python example.py --config_file configs/config_alstm.yaml
|
||||
```
|
||||
|
||||
Setting different parameters is also allowed. See codes in `example.py`:
|
||||
```
|
||||
python example.py --config_file configs/config_alstm.yaml
|
||||
```
|
||||
|
||||
Here we trained TRA on a pretrained backbone model. Therefore we run `*_init.yaml` before TRA's scipts.
|
||||
|
||||
# Results
|
||||
|
||||
## Outputs
|
||||
### Results
|
||||
|
||||
After running the scripts, you can find result files in path `./output`:
|
||||
|
||||
`info.json` - config settings and result metrics.
|
||||
* `info.json` - config settings and result metrics.
|
||||
* `log.csv` - running logs.
|
||||
* `model.bin` - the model parameter dictionary.
|
||||
* `pred.pkl` - the prediction scores and output for inference.
|
||||
|
||||
`log.csv` - running logs.
|
||||
Evaluation metrics reported in the paper:
|
||||
|
||||
`model.bin` - the model parameter dictionary.
|
||||
|
||||
`pred.pkl` - the prediction scores and output for inference.
|
||||
|
||||
## Our Results
|
||||
| Methods | MSE| MAE| IC | ICIR | AR | AV | SR | MDD |
|
||||
|-------------------|-------------------|---------------------|--------------------|--------------------|--------------------|--------------------|--------------------|--------------------|
|
||||
|-------|-------|------|-----|-----|-----|-----|-----|-----|
|
||||
|Linear|0.163|0.327|0.020|0.132|-3.2%|16.8%|-0.191|32.1%|
|
||||
|LightGBM|0.160(0.000)|0.323(0.000)|0.041|0.292|7.8%|15.5%|0.503|25.7%|
|
||||
|MLP|0.160(0.002)|0.323(0.003)|0.037|0.273|3.7%|15.3%|0.264|26.2%|
|
||||
@@ -61,21 +85,8 @@ After running the scripts, you can find result files in path `./output`:
|
||||
|
||||
A more detailed demo for our experiment results in the paper can be found in `Report.ipynb`.
|
||||
|
||||
# Common Issues
|
||||
## Common Issues
|
||||
|
||||
For help or issues using TRA, please submit a GitHub issue.
|
||||
|
||||
Sometimes we might encounter situation where the loss is `NaN`, please check the `epsilon` parameter in the sinkhorn algorithm, adjusting the `epsilon` according to input's scale is important.
|
||||
|
||||
# Citation
|
||||
If you find this repository useful in your research, please cite:
|
||||
```
|
||||
@inproceedings{HengxuKDD2021,
|
||||
author = {Hengxu Lin and Dong Zhou and Weiqing Liu and Jiang Bian},
|
||||
title = {Learning Multiple Stock Trading Patterns with Temporal Routing Adaptor and Optimal Transport},
|
||||
booktitle = {Proceedings of the 27th ACM SIGKDD Conference on Knowledge Discovery \& Data Mining},
|
||||
series = {KDD '21},
|
||||
year = {2021},
|
||||
publisher = {ACM},
|
||||
}
|
||||
```
|
||||
Sometimes we might encounter situation where the loss is `NaN`, please check the `epsilon` parameter in the sinkhorn algorithm, adjusting the `epsilon` according to input's scale is important.
|
||||
|
||||
129
examples/benchmarks/TRA/workflow_config_tra_Alpha158.yaml
Normal file
129
examples/benchmarks/TRA/workflow_config_tra_Alpha158.yaml
Normal file
@@ -0,0 +1,129 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: FilterCol
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
col_list: ["RESI5", "WVMA5", "RSQR5", "KLEN", "RSQR10", "CORR5", "CORD5", "CORR10",
|
||||
"ROC60", "RESI10", "VSTD5", "RSQR60", "CORR60", "WVMA60", "STD5",
|
||||
"RSQR20", "CORD60", "CORD10", "CORR20", "KLOW"]
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
|
||||
num_states: &num_states 3
|
||||
|
||||
memory_mode: &memory_mode sample
|
||||
|
||||
tra_config: &tra_config
|
||||
num_states: *num_states
|
||||
rnn_arch: LSTM
|
||||
hidden_size: 32
|
||||
num_layers: 1
|
||||
dropout: 0.0
|
||||
tau: 1.0
|
||||
src_info: LR_TPE
|
||||
|
||||
model_config: &model_config
|
||||
input_size: 20
|
||||
hidden_size: 64
|
||||
num_layers: 2
|
||||
rnn_arch: LSTM
|
||||
use_attn: True
|
||||
dropout: 0.0
|
||||
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
|
||||
task:
|
||||
model:
|
||||
class: TRAModel
|
||||
module_path: qlib.contrib.model.pytorch_tra
|
||||
kwargs:
|
||||
tra_config: *tra_config
|
||||
model_config: *model_config
|
||||
model_type: RNN
|
||||
lr: 1e-3
|
||||
n_epochs: 100
|
||||
max_steps_per_epoch:
|
||||
early_stop: 20
|
||||
logdir: output/Alpha158
|
||||
seed: 0
|
||||
lamb: 1.0
|
||||
rho: 0.99
|
||||
alpha: 0.5
|
||||
transport_method: router
|
||||
memory_mode: *memory_mode
|
||||
eval_train: False
|
||||
eval_test: True
|
||||
pretrain: True
|
||||
init_state:
|
||||
freeze_model: False
|
||||
freeze_predictors: False
|
||||
dataset:
|
||||
class: MTSDatasetH
|
||||
module_path: qlib.contrib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha158
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
seq_len: 60
|
||||
horizon: 2
|
||||
input_size:
|
||||
num_states: *num_states
|
||||
batch_size: 1024
|
||||
n_samples:
|
||||
memory_mode: *memory_mode
|
||||
drop_last: True
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
123
examples/benchmarks/TRA/workflow_config_tra_Alpha158_full.yaml
Normal file
123
examples/benchmarks/TRA/workflow_config_tra_Alpha158_full.yaml
Normal file
@@ -0,0 +1,123 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
|
||||
num_states: &num_states 3
|
||||
|
||||
memory_mode: &memory_mode sample
|
||||
|
||||
tra_config: &tra_config
|
||||
num_states: *num_states
|
||||
rnn_arch: LSTM
|
||||
hidden_size: 32
|
||||
num_layers: 1
|
||||
dropout: 0.0
|
||||
tau: 1.0
|
||||
src_info: LR_TPE
|
||||
|
||||
model_config: &model_config
|
||||
input_size: 158
|
||||
hidden_size: 256
|
||||
num_layers: 2
|
||||
rnn_arch: LSTM
|
||||
use_attn: True
|
||||
dropout: 0.2
|
||||
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
|
||||
task:
|
||||
model:
|
||||
class: TRAModel
|
||||
module_path: qlib.contrib.model.pytorch_tra
|
||||
kwargs:
|
||||
tra_config: *tra_config
|
||||
model_config: *model_config
|
||||
model_type: RNN
|
||||
lr: 1e-3
|
||||
n_epochs: 100
|
||||
max_steps_per_epoch:
|
||||
early_stop: 20
|
||||
logdir: output/Alpha158_full
|
||||
seed: 0
|
||||
lamb: 1.0
|
||||
rho: 0.99
|
||||
alpha: 0.5
|
||||
transport_method: router
|
||||
memory_mode: *memory_mode
|
||||
eval_train: False
|
||||
eval_test: True
|
||||
pretrain: True
|
||||
init_state:
|
||||
freeze_model: False
|
||||
freeze_predictors: False
|
||||
dataset:
|
||||
class: MTSDatasetH
|
||||
module_path: qlib.contrib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha158
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
seq_len: 60
|
||||
horizon: 2
|
||||
input_size:
|
||||
num_states: *num_states
|
||||
batch_size: 1024
|
||||
n_samples:
|
||||
memory_mode: *memory_mode
|
||||
drop_last: True
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
123
examples/benchmarks/TRA/workflow_config_tra_Alpha360.yaml
Normal file
123
examples/benchmarks/TRA/workflow_config_tra_Alpha360.yaml
Normal file
@@ -0,0 +1,123 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
|
||||
num_states: &num_states 3
|
||||
|
||||
memory_mode: &memory_mode sample
|
||||
|
||||
tra_config: &tra_config
|
||||
num_states: *num_states
|
||||
rnn_arch: LSTM
|
||||
hidden_size: 32
|
||||
num_layers: 1
|
||||
dropout: 0.0
|
||||
tau: 1.0
|
||||
src_info: LR_TPE
|
||||
|
||||
model_config: &model_config
|
||||
input_size: 6
|
||||
hidden_size: 64
|
||||
num_layers: 2
|
||||
rnn_arch: LSTM
|
||||
use_attn: True
|
||||
dropout: 0.0
|
||||
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
|
||||
task:
|
||||
model:
|
||||
class: TRAModel
|
||||
module_path: qlib.contrib.model.pytorch_tra
|
||||
kwargs:
|
||||
tra_config: *tra_config
|
||||
model_config: *model_config
|
||||
model_type: RNN
|
||||
lr: 1e-3
|
||||
n_epochs: 100
|
||||
max_steps_per_epoch:
|
||||
early_stop: 20
|
||||
logdir: output/Alpha360
|
||||
seed: 0
|
||||
lamb: 1.0
|
||||
rho: 0.99
|
||||
alpha: 0.5
|
||||
transport_method: router
|
||||
memory_mode: *memory_mode
|
||||
eval_train: False
|
||||
eval_test: True
|
||||
pretrain: True
|
||||
init_state:
|
||||
freeze_model: False
|
||||
freeze_predictors: False
|
||||
dataset:
|
||||
class: MTSDatasetH
|
||||
module_path: qlib.contrib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha360
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
seq_len: 60
|
||||
horizon: 2
|
||||
input_size: 6
|
||||
num_states: *num_states
|
||||
batch_size: 1024
|
||||
n_samples:
|
||||
memory_mode: *memory_mode
|
||||
drop_last: True
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
3
examples/benchmarks/Transformer/requirements.txt
Normal file
3
examples/benchmarks/Transformer/requirements.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
numpy==1.17.4
|
||||
pandas==1.1.2
|
||||
torch==1.2.0
|
||||
@@ -0,0 +1,82 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: FilterCol
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
col_list: ["RESI5", "WVMA5", "RSQR5", "KLEN", "RSQR10", "CORR5", "CORD5", "CORR10",
|
||||
"ROC60", "RESI10", "VSTD5", "RSQR60", "CORR60", "WVMA60", "STD5",
|
||||
"RSQR20", "CORD60", "CORD10", "CORR20", "KLOW"
|
||||
]
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: TransformerModel
|
||||
module_path: qlib.contrib.model.pytorch_transformer_ts
|
||||
kwargs:
|
||||
seed: 0
|
||||
n_jobs: 20
|
||||
dataset:
|
||||
class: TSDatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha158
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
step_len: 20
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -0,0 +1,73 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: TransformerModel
|
||||
module_path: qlib.contrib.model.pytorch_transformer
|
||||
kwargs:
|
||||
d_feat: 6
|
||||
seed: 0
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha360
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -99,8 +99,6 @@ class HighFreqHandler(DataHandlerLP):
|
||||
]
|
||||
names += ["$volume_1"]
|
||||
|
||||
fields += ["Cut({0}, 240, None)".format(template_paused.format("Date($close)"))]
|
||||
names += ["date"]
|
||||
return fields, names
|
||||
|
||||
|
||||
|
||||
@@ -33,6 +33,9 @@ class HighFreqNorm(Processor):
|
||||
self.feature_vmin[name] = np.nanmin(part_values)
|
||||
|
||||
def __call__(self, df_features):
|
||||
df_features["date"] = pd.to_datetime(
|
||||
df_features.index.get_level_values(level="datetime").to_series().dt.date.values
|
||||
)
|
||||
df_features.set_index("date", append=True, drop=True, inplace=True)
|
||||
df_values = df_features.values
|
||||
names = {
|
||||
|
||||
1
examples/model_rolling/requirements.txt
Normal file
1
examples/model_rolling/requirements.txt
Normal file
@@ -0,0 +1 @@
|
||||
xgboost
|
||||
@@ -17,7 +17,7 @@ from qlib.workflow.task.gen import RollingGen, task_generator
|
||||
from qlib.workflow.task.manage import TaskManager, run_task
|
||||
from qlib.workflow.task.collect import RecorderCollector
|
||||
from qlib.model.ens.group import RollingGroup
|
||||
from qlib.model.trainer import TrainerRM
|
||||
from qlib.model.trainer import TrainerRM, task_train
|
||||
from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG, CSI100_RECORD_XGBOOST_TASK_CONFIG
|
||||
|
||||
|
||||
|
||||
@@ -23,7 +23,6 @@ from qlib.config import REG_CN
|
||||
from qlib.workflow import R
|
||||
from qlib.tests.data import GetData
|
||||
|
||||
|
||||
# init qlib
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data"
|
||||
exp_folder_name = "run_all_model_records"
|
||||
@@ -40,6 +39,7 @@ exp_manager = {
|
||||
GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True)
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN, exp_manager=exp_manager)
|
||||
|
||||
|
||||
# decorator to check the arguments
|
||||
def only_allow_defined_args(function_to_decorate):
|
||||
@functools.wraps(function_to_decorate)
|
||||
@@ -92,7 +92,8 @@ def create_env():
|
||||
|
||||
|
||||
# function to execute the cmd
|
||||
def execute(cmd):
|
||||
def execute(cmd, wait_when_err=False):
|
||||
print("Running CMD:", cmd)
|
||||
with subprocess.Popen(cmd, stdout=subprocess.PIPE, bufsize=1, universal_newlines=True, shell=True) as p:
|
||||
for line in p.stdout:
|
||||
sys.stdout.write(line.split("\b")[0])
|
||||
@@ -102,6 +103,8 @@ def execute(cmd):
|
||||
sys.stdout.write("\b" * 10 + "\b".join(line.split("\b")[1:-1]))
|
||||
|
||||
if p.returncode != 0:
|
||||
if wait_when_err:
|
||||
input("Press Enter to Continue")
|
||||
return p.stderr
|
||||
else:
|
||||
return None
|
||||
@@ -184,7 +187,15 @@ def gen_and_save_md_table(metrics, dataset):
|
||||
|
||||
# function to run the all the models
|
||||
@only_allow_defined_args
|
||||
def run(times=1, models=None, dataset="Alpha360", exclude=False):
|
||||
def run(
|
||||
times=1,
|
||||
models=None,
|
||||
dataset="Alpha360",
|
||||
exclude=False,
|
||||
qlib_uri: str = "git+https://github.com/microsoft/qlib#egg=pyqlib",
|
||||
wait_before_rm_env: bool = False,
|
||||
wait_when_err: bool = False,
|
||||
):
|
||||
"""
|
||||
Please be aware that this function can only work under Linux. MacOS and Windows will be supported in the future.
|
||||
Any PR to enhance this method is highly welcomed. Besides, this script doesn't support parrallel running the same model
|
||||
@@ -200,6 +211,13 @@ def run(times=1, models=None, dataset="Alpha360", exclude=False):
|
||||
determines whether the model being used is excluded or included.
|
||||
dataset : str
|
||||
determines the dataset to be used for each model.
|
||||
qlib_uri : str
|
||||
the uri to install qlib with pip
|
||||
it could be url on the we or local path
|
||||
wait_before_rm_env : bool
|
||||
wait before remove environment.
|
||||
wait_when_err : bool
|
||||
wait when errors raised when executing commands
|
||||
|
||||
Usage:
|
||||
-------
|
||||
@@ -240,32 +258,36 @@ def run(times=1, models=None, dataset="Alpha360", exclude=False):
|
||||
sys.stderr.write("\n")
|
||||
# install requirements.txt
|
||||
sys.stderr.write("Installing requirements.txt...\n")
|
||||
execute(f"{python_path} -m pip install -r {req_path}")
|
||||
execute(f"{python_path} -m pip install -r {req_path}", wait_when_err=wait_when_err)
|
||||
sys.stderr.write("\n")
|
||||
# setup gpu for tft
|
||||
if fn == "TFT":
|
||||
execute(
|
||||
f"conda install -y --prefix {env_path} anaconda cudatoolkit=10.0 && conda install -y --prefix {env_path} cudnn"
|
||||
f"conda install -y --prefix {env_path} anaconda cudatoolkit=10.0 && conda install -y --prefix {env_path} cudnn",
|
||||
wait_when_err=wait_when_err,
|
||||
)
|
||||
sys.stderr.write("\n")
|
||||
# install qlib
|
||||
sys.stderr.write("Installing qlib...\n")
|
||||
execute(f"{python_path} -m pip install --upgrade pip") # TODO: FIX ME!
|
||||
execute(f"{python_path} -m pip install --upgrade cython") # TODO: FIX ME!
|
||||
execute(f"{python_path} -m pip install --upgrade pip", wait_when_err=wait_when_err) # TODO: FIX ME!
|
||||
execute(f"{python_path} -m pip install --upgrade cython", wait_when_err=wait_when_err) # TODO: FIX ME!
|
||||
if fn == "TFT":
|
||||
execute(
|
||||
f"cd {env_path} && {python_path} -m pip install --upgrade --force-reinstall --ignore-installed PyYAML -e git+https://github.com/microsoft/qlib#egg=pyqlib"
|
||||
f"cd {env_path} && {python_path} -m pip install --upgrade --force-reinstall --ignore-installed PyYAML -e {qlib_uri}",
|
||||
wait_when_err=wait_when_err,
|
||||
) # TODO: FIX ME!
|
||||
else:
|
||||
execute(
|
||||
f"cd {env_path} && {python_path} -m pip install --upgrade --force-reinstall -e git+https://github.com/microsoft/qlib#egg=pyqlib"
|
||||
f"cd {env_path} && {python_path} -m pip install --upgrade --force-reinstall -e {qlib_uri}",
|
||||
wait_when_err=wait_when_err,
|
||||
) # TODO: FIX ME!
|
||||
sys.stderr.write("\n")
|
||||
# run workflow_by_config for multiple times
|
||||
for i in range(times):
|
||||
sys.stderr.write(f"Running the model: {fn} for iteration {i+1}...\n")
|
||||
errs = execute(
|
||||
f"{python_path} {env_path / 'src/pyqlib/qlib/workflow/cli.py'} {yaml_path} {fn} {exp_folder_name}"
|
||||
f"{python_path} {env_path / 'bin' / 'qrun'} {yaml_path} {fn} {exp_folder_name}",
|
||||
wait_when_err=wait_when_err,
|
||||
)
|
||||
if errs is not None:
|
||||
_errs = errors.get(fn, {})
|
||||
@@ -274,6 +296,8 @@ def run(times=1, models=None, dataset="Alpha360", exclude=False):
|
||||
sys.stderr.write("\n")
|
||||
# remove env
|
||||
sys.stderr.write(f"Deleting the environment: {env_path}...\n")
|
||||
if wait_before_rm_env:
|
||||
input("Press Enter to Continue")
|
||||
shutil.rmtree(env_path)
|
||||
# getting all results
|
||||
sys.stderr.write(f"Retrieving results...\n")
|
||||
|
||||
@@ -1,17 +1,15 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
__version__ = "0.7.0"
|
||||
_version_path = Path(__file__).absolute().parent / "VERSION.txt" # This file is copyed from setup.py
|
||||
__version__ = _version_path.read_text(encoding="utf-8").strip()
|
||||
__version__bak = __version__ # This version is backup for QlibConfig.reset_qlib_version
|
||||
|
||||
|
||||
import os
|
||||
import yaml
|
||||
import logging
|
||||
import platform
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from .log import get_module_logger
|
||||
|
||||
|
||||
@@ -33,69 +31,70 @@ def init(default_conf="client", **kwargs):
|
||||
H.clear()
|
||||
C.set(default_conf, **kwargs)
|
||||
|
||||
# check path if server/local
|
||||
if C.get_uri_type() == C.LOCAL_URI:
|
||||
if not os.path.exists(C["provider_uri"]):
|
||||
if C["auto_mount"]:
|
||||
logger.error(
|
||||
f"Invalid provider uri: {C['provider_uri']}, please check if a valid provider uri has been set. This path does not exist."
|
||||
)
|
||||
else:
|
||||
logger.warning(f"auto_path is False, please make sure {C['mount_path']} is mounted")
|
||||
elif C.get_uri_type() == C.NFS_URI:
|
||||
_mount_nfs_uri(C)
|
||||
else:
|
||||
raise NotImplementedError(f"This type of URI is not supported")
|
||||
# mount nfs
|
||||
for _freq, provider_uri in C.provider_uri.items():
|
||||
mount_path = C["mount_path"][_freq]
|
||||
# check path if server/local
|
||||
uri_type = C.dpm.get_uri_type(provider_uri)
|
||||
if uri_type == C.LOCAL_URI:
|
||||
if not Path(provider_uri).exists():
|
||||
if C["auto_mount"]:
|
||||
logger.error(
|
||||
f"Invalid provider uri: {provider_uri}, please check if a valid provider uri has been set. This path does not exist."
|
||||
)
|
||||
else:
|
||||
logger.warning(f"auto_path is False, please make sure {mount_path} is mounted")
|
||||
elif uri_type == C.NFS_URI:
|
||||
_mount_nfs_uri(provider_uri, mount_path, C["auto_mount"])
|
||||
else:
|
||||
raise NotImplementedError(f"This type of URI is not supported")
|
||||
|
||||
C.register()
|
||||
|
||||
if "flask_server" in C:
|
||||
logger.info(f"flask_server={C['flask_server']}, flask_port={C['flask_port']}")
|
||||
logger.info("qlib successfully initialized based on %s settings." % default_conf)
|
||||
logger.info(f"data_path={C.get_data_path()}")
|
||||
data_path = {_freq: C.dpm.get_data_path(_freq) for _freq in C.dpm.provider_uri.keys()}
|
||||
logger.info(f"data_path={data_path}")
|
||||
|
||||
|
||||
def _mount_nfs_uri(C):
|
||||
def _mount_nfs_uri(provider_uri, mount_path, auto_mount: bool = False):
|
||||
|
||||
LOG = get_module_logger("mount nfs", level=logging.INFO)
|
||||
|
||||
# FIXME: the C["provider_uri"] is modified in this function
|
||||
# If it is not modified, we can pass only provider_uri or mount_path instead of C
|
||||
mount_command = "sudo mount.nfs %s %s" % (C["provider_uri"], C["mount_path"])
|
||||
mount_command = "sudo mount.nfs %s %s" % (provider_uri, mount_path)
|
||||
# If the provider uri looks like this 172.23.233.89//data/csdesign'
|
||||
# It will be a nfs path. The client provider will be used
|
||||
if not C["auto_mount"]:
|
||||
if not os.path.exists(C["mount_path"]):
|
||||
if not auto_mount:
|
||||
if not Path(mount_path).exists():
|
||||
raise FileNotFoundError(
|
||||
f"Invalid mount path: {C['mount_path']}! Please mount manually: {mount_command} or Set init parameter `auto_mount=True`"
|
||||
f"Invalid mount path: {mount_path}! Please mount manually: {mount_command} or Set init parameter `auto_mount=True`"
|
||||
)
|
||||
else:
|
||||
# Judging system type
|
||||
sys_type = platform.system()
|
||||
if "win" in sys_type.lower():
|
||||
# system: window
|
||||
exec_result = os.popen("mount -o anon %s %s" % (C["provider_uri"], C["mount_path"] + ":"))
|
||||
exec_result = os.popen("mount -o anon %s %s" % (provider_uri, mount_path + ":"))
|
||||
result = exec_result.read()
|
||||
if "85" in result:
|
||||
LOG.warning("already mounted or window mount path already exists")
|
||||
LOG.warning(f"{provider_uri} on Windows:{mount_path} is already mounted")
|
||||
elif "53" in result:
|
||||
raise OSError("not find network path")
|
||||
elif "error" in result or "错误" in result:
|
||||
raise OSError("Invalid mount path")
|
||||
elif C["provider_uri"] in result:
|
||||
elif provider_uri in result:
|
||||
LOG.info("window success mount..")
|
||||
else:
|
||||
raise OSError(f"unknown error: {result}")
|
||||
|
||||
# config mount path
|
||||
C["mount_path"] = C["mount_path"] + ":\\"
|
||||
else:
|
||||
# system: linux/Unix/Mac
|
||||
# check mount
|
||||
_remote_uri = C["provider_uri"]
|
||||
_remote_uri = _remote_uri[:-1] if _remote_uri.endswith("/") else _remote_uri
|
||||
_mount_path = C["mount_path"]
|
||||
_mount_path = _mount_path[:-1] if _mount_path.endswith("/") else _mount_path
|
||||
_remote_uri = provider_uri[:-1] if provider_uri.endswith("/") else provider_uri
|
||||
_mount_path = mount_path[:-1] if mount_path.endswith("/") else mount_path
|
||||
_check_level_num = 2
|
||||
_is_mount = False
|
||||
while _check_level_num:
|
||||
@@ -121,11 +120,9 @@ def _mount_nfs_uri(C):
|
||||
|
||||
if not _is_mount:
|
||||
try:
|
||||
os.makedirs(C["mount_path"], exist_ok=True)
|
||||
Path(mount_path).mkdir(parents=True, exist_ok=True)
|
||||
except Exception:
|
||||
raise OSError(
|
||||
f"Failed to create directory {C['mount_path']}, please create {C['mount_path']} manually!"
|
||||
)
|
||||
raise OSError(f"Failed to create directory {mount_path}, please create {mount_path} manually!")
|
||||
|
||||
# check nfs-common
|
||||
command_res = os.popen("dpkg -l | grep nfs-common")
|
||||
@@ -136,11 +133,11 @@ def _mount_nfs_uri(C):
|
||||
command_status = os.system(mount_command)
|
||||
if command_status == 256:
|
||||
raise OSError(
|
||||
f"mount {C['provider_uri']} on {C['mount_path']} error! Needs SUDO! Please mount manually: {mount_command}"
|
||||
f"mount {provider_uri} on {mount_path} error! Needs SUDO! Please mount manually: {mount_command}"
|
||||
)
|
||||
elif command_status == 32512:
|
||||
# LOG.error("Command error")
|
||||
raise OSError(f"mount {C['provider_uri']} on {C['mount_path']} error! Command error")
|
||||
raise OSError(f"mount {provider_uri} on {mount_path} error! Command error")
|
||||
elif command_status == 0:
|
||||
LOG.info("Mount finished")
|
||||
else:
|
||||
|
||||
132
qlib/config.py
132
qlib/config.py
@@ -15,8 +15,10 @@ import os
|
||||
import re
|
||||
import copy
|
||||
import logging
|
||||
import platform
|
||||
import multiprocessing
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
|
||||
class Config:
|
||||
@@ -73,6 +75,12 @@ REG_US = "us"
|
||||
|
||||
NUM_USABLE_CPU = max(multiprocessing.cpu_count() - 2, 1)
|
||||
|
||||
DISK_DATASET_CACHE = "DiskDatasetCache"
|
||||
SIMPLE_DATASET_CACHE = "SimpleDatasetCache"
|
||||
DISK_EXPRESSION_CACHE = "DiskExpressionCache"
|
||||
|
||||
DEPENDENCY_REDIS_CACHE = (DISK_DATASET_CACHE, DISK_EXPRESSION_CACHE)
|
||||
|
||||
_default_config = {
|
||||
# data provider config
|
||||
"calendar_provider": "LocalCalendarProvider",
|
||||
@@ -82,6 +90,15 @@ _default_config = {
|
||||
"dataset_provider": "LocalDatasetProvider",
|
||||
"provider": "LocalProvider",
|
||||
# config it in qlib.init()
|
||||
# "provider_uri" str or dict:
|
||||
# # str
|
||||
# "~/.qlib/stock_data/cn_data"
|
||||
# # dict
|
||||
# {"day": "~/.qlib/stock_data/cn_data", "1min": "~/.qlib/stock_data/cn_data_1min"}
|
||||
# NOTE: provider_uri priority:
|
||||
# 1. backend_config: backend_obj["kwargs"]["provider_uri"]
|
||||
# 2. backend_config: backend_obj["kwargs"]["provider_uri_map"]
|
||||
# 3. qlib.init: provider_uri
|
||||
"provider_uri": "",
|
||||
# cache
|
||||
"expression_cache": None,
|
||||
@@ -167,8 +184,9 @@ MODE_CONF = {
|
||||
"redis_task_db": 1,
|
||||
"kernels": NUM_USABLE_CPU,
|
||||
# cache
|
||||
"expression_cache": "DiskExpressionCache",
|
||||
"dataset_cache": "DiskDatasetCache",
|
||||
"expression_cache": DISK_EXPRESSION_CACHE,
|
||||
"dataset_cache": DISK_DATASET_CACHE,
|
||||
"local_cache_path": Path("~/.cache/qlib_simple_cache").expanduser().resolve(),
|
||||
"mount_path": None,
|
||||
},
|
||||
"client": {
|
||||
@@ -183,8 +201,10 @@ MODE_CONF = {
|
||||
"provider_uri": "~/.qlib/qlib_data/cn_data",
|
||||
# cache
|
||||
# Using parameter 'remote' to announce the client is using server_cache, and the writing access will be disabled.
|
||||
"expression_cache": "DiskExpressionCache",
|
||||
"dataset_cache": "DiskDatasetCache",
|
||||
"expression_cache": DISK_EXPRESSION_CACHE,
|
||||
"dataset_cache": DISK_DATASET_CACHE,
|
||||
# SimpleDatasetCache directory
|
||||
"local_cache_path": Path("~/.cache/qlib_simple_cache").expanduser().resolve(),
|
||||
"calendar_cache": None,
|
||||
# client config
|
||||
"kernels": NUM_USABLE_CPU,
|
||||
@@ -228,11 +248,43 @@ class QlibConfig(Config):
|
||||
# URI_TYPE
|
||||
LOCAL_URI = "local"
|
||||
NFS_URI = "nfs"
|
||||
DEFAULT_FREQ = "__DEFAULT_FREQ"
|
||||
|
||||
def __init__(self, default_conf):
|
||||
super().__init__(default_conf)
|
||||
self._registered = False
|
||||
|
||||
class DataPathManager:
|
||||
def __init__(self, provider_uri: Union[str, Path, dict], mount_path: Union[str, Path, dict]):
|
||||
self.provider_uri = provider_uri
|
||||
self.mount_path = mount_path
|
||||
|
||||
@staticmethod
|
||||
def get_uri_type(uri: Union[str, Path]):
|
||||
uri = uri if isinstance(uri, str) else str(uri.expanduser().resolve())
|
||||
is_win = re.match("^[a-zA-Z]:.*", uri) is not None # such as 'C:\\data', 'D:'
|
||||
# such as 'host:/data/' (User may define short hostname by themselves or use localhost)
|
||||
is_nfs_or_win = re.match("^[^/]+:.+", uri) is not None
|
||||
|
||||
if is_nfs_or_win and not is_win:
|
||||
return QlibConfig.NFS_URI
|
||||
else:
|
||||
return QlibConfig.LOCAL_URI
|
||||
|
||||
def get_data_path(self, freq: str = None) -> Path:
|
||||
if freq is None or freq not in self.provider_uri:
|
||||
freq = QlibConfig.DEFAULT_FREQ
|
||||
_provider_uri = self.provider_uri[freq]
|
||||
if self.get_uri_type(_provider_uri) == QlibConfig.LOCAL_URI:
|
||||
return Path(_provider_uri)
|
||||
elif self.get_uri_type(_provider_uri) == QlibConfig.NFS_URI:
|
||||
if "win" in platform.system().lower():
|
||||
# windows, mount_path is the drive
|
||||
return Path(f"{self.mount_path[freq]}:\\")
|
||||
return Path(self.mount_path[freq])
|
||||
else:
|
||||
raise NotImplementedError(f"This type of uri is not supported")
|
||||
|
||||
def set_mode(self, mode):
|
||||
# raise KeyError
|
||||
self.update(MODE_CONF[mode])
|
||||
@@ -242,32 +294,43 @@ class QlibConfig(Config):
|
||||
# raise KeyError
|
||||
self.update(_default_region_config[region])
|
||||
|
||||
@staticmethod
|
||||
def is_depend_redis(cache_name: str):
|
||||
return cache_name in DEPENDENCY_REDIS_CACHE
|
||||
|
||||
@property
|
||||
def dpm(self):
|
||||
return self.DataPathManager(self["provider_uri"], self["mount_path"])
|
||||
|
||||
def resolve_path(self):
|
||||
# resolve path
|
||||
if self["mount_path"] is not None:
|
||||
self["mount_path"] = str(Path(self["mount_path"]).expanduser().resolve())
|
||||
_mount_path = self["mount_path"]
|
||||
_provider_uri = self["provider_uri"]
|
||||
if _provider_uri is None:
|
||||
raise ValueError("provider_uri cannot be None")
|
||||
if not isinstance(_provider_uri, dict):
|
||||
_provider_uri = {self.DEFAULT_FREQ: _provider_uri}
|
||||
if not isinstance(_mount_path, dict):
|
||||
_mount_path = {_freq: _mount_path for _freq in _provider_uri.keys()}
|
||||
|
||||
if self.get_uri_type() == QlibConfig.LOCAL_URI:
|
||||
self["provider_uri"] = str(Path(self["provider_uri"]).expanduser().resolve())
|
||||
# check provider_uri and mount_path
|
||||
_miss_freq = set(_provider_uri.keys()) - set(_mount_path.keys())
|
||||
assert len(_miss_freq) == 0, f"mount_path is missing freq: {_miss_freq}"
|
||||
|
||||
def get_uri_type(self):
|
||||
is_win = re.match("^[a-zA-Z]:.*", self["provider_uri"]) is not None # such as 'C:\\data', 'D:'
|
||||
is_nfs_or_win = (
|
||||
re.match("^[^/]+:.+", self["provider_uri"]) is not None
|
||||
) # such as 'host:/data/' (User may define short hostname by themselves or use localhost)
|
||||
# resolve
|
||||
for _freq, _uri in _provider_uri.items():
|
||||
# provider_uri
|
||||
if self.DataPathManager.get_uri_type(_uri) == QlibConfig.LOCAL_URI:
|
||||
_provider_uri[_freq] = str(Path(_uri).expanduser().resolve())
|
||||
# mount_path
|
||||
_mount_path[_freq] = (
|
||||
_mount_path[_freq]
|
||||
if _mount_path[_freq] is None
|
||||
else str(Path(_mount_path[_freq]).expanduser().resolve())
|
||||
)
|
||||
|
||||
if is_nfs_or_win and not is_win:
|
||||
return QlibConfig.NFS_URI
|
||||
else:
|
||||
return QlibConfig.LOCAL_URI
|
||||
|
||||
def get_data_path(self):
|
||||
if self.get_uri_type() == QlibConfig.LOCAL_URI:
|
||||
return self["provider_uri"]
|
||||
elif self.get_uri_type() == QlibConfig.NFS_URI:
|
||||
return self["mount_path"]
|
||||
else:
|
||||
raise NotImplementedError(f"This type of uri is not supported")
|
||||
self["provider_uri"] = _provider_uri
|
||||
self["mount_path"] = _mount_path
|
||||
|
||||
def set(self, default_conf="client", **kwargs):
|
||||
from .utils import set_log_with_config, get_module_logger, can_use_cache
|
||||
@@ -299,11 +362,20 @@ class QlibConfig(Config):
|
||||
if not (self["expression_cache"] is None and self["dataset_cache"] is None):
|
||||
# check redis
|
||||
if not can_use_cache():
|
||||
logger.warning(
|
||||
f"redis connection failed(host={self['redis_host']} port={self['redis_port']}), cache will not be used!"
|
||||
)
|
||||
self["expression_cache"] = None
|
||||
self["dataset_cache"] = None
|
||||
log_str = ""
|
||||
# check expression cache
|
||||
if self.is_depend_redis(self["expression_cache"]):
|
||||
log_str += self["expression_cache"]
|
||||
self["expression_cache"] = None
|
||||
# check dataset cache
|
||||
if self.is_depend_redis(self["dataset_cache"]):
|
||||
log_str += f" and {self['dataset_cache']}" if log_str else self["dataset_cache"]
|
||||
self["dataset_cache"] = None
|
||||
if log_str:
|
||||
logger.warning(
|
||||
f"redis connection failed(host={self['redis_host']} port={self['redis_port']}), "
|
||||
f"{log_str} will not be used!"
|
||||
)
|
||||
|
||||
def register(self):
|
||||
from .utils import init_instance_by_config
|
||||
|
||||
@@ -16,6 +16,7 @@ def get_benchmark_weight(
|
||||
start_date=None,
|
||||
end_date=None,
|
||||
path=None,
|
||||
freq="day",
|
||||
):
|
||||
"""get_benchmark_weight
|
||||
|
||||
@@ -25,6 +26,7 @@ def get_benchmark_weight(
|
||||
:param start_date:
|
||||
:param end_date:
|
||||
:param path:
|
||||
:param freq:
|
||||
|
||||
:return: The weight distribution of the the benchmark described by a pandas dataframe
|
||||
Every row corresponds to a trading day.
|
||||
@@ -33,7 +35,7 @@ def get_benchmark_weight(
|
||||
|
||||
"""
|
||||
if not path:
|
||||
path = Path(C.get_data_path()).expanduser() / "raw" / "AIndexMembers" / "weights.csv"
|
||||
path = Path(C.dpm.get_data_path(freq)).expanduser() / "raw" / "AIndexMembers" / "weights.csv"
|
||||
# TODO: the storage of weights should be implemented in a more elegent way
|
||||
# TODO: The benchmark is not consistant with the filename in instruments.
|
||||
bench_weight_df = pd.read_csv(path, usecols=["code", "date", "index", "weight"])
|
||||
@@ -222,6 +224,7 @@ def brinson_pa(
|
||||
group_method="category",
|
||||
group_n=None,
|
||||
deal_price="vwap",
|
||||
freq="day",
|
||||
):
|
||||
"""brinson profit attribution
|
||||
|
||||
@@ -243,7 +246,7 @@ def brinson_pa(
|
||||
|
||||
start_date, end_date = min(dates), max(dates)
|
||||
|
||||
bench_stock_weight = get_benchmark_weight(bench, start_date, end_date)
|
||||
bench_stock_weight = get_benchmark_weight(bench, start_date, end_date, freq)
|
||||
|
||||
# The attributes for allocation will not
|
||||
if not group_field.startswith("$"):
|
||||
@@ -259,13 +262,14 @@ def brinson_pa(
|
||||
start_time=shift_start_date,
|
||||
end_time=end_date,
|
||||
as_list=True,
|
||||
freq=freq,
|
||||
)
|
||||
stock_df = D.features(
|
||||
instruments,
|
||||
[group_field, deal_price],
|
||||
start_time=shift_start_date,
|
||||
end_time=end_date,
|
||||
freq="day",
|
||||
freq=freq,
|
||||
)
|
||||
stock_df.columns = [group_field, "deal_price"]
|
||||
|
||||
|
||||
346
qlib/contrib/data/dataset.py
Normal file
346
qlib/contrib/data/dataset.py
Normal file
@@ -0,0 +1,346 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import copy
|
||||
import torch
|
||||
import warnings
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from qlib.utils import init_instance_by_config
|
||||
from qlib.data.dataset import DatasetH, DataHandler
|
||||
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
|
||||
def _to_tensor(x):
|
||||
if not isinstance(x, torch.Tensor):
|
||||
return torch.tensor(x, dtype=torch.float, device=device)
|
||||
return x
|
||||
|
||||
|
||||
def _create_ts_slices(index, seq_len):
|
||||
"""
|
||||
create time series slices from pandas index
|
||||
|
||||
Args:
|
||||
index (pd.MultiIndex): pandas multiindex with <instrument, datetime> order
|
||||
seq_len (int): sequence length
|
||||
"""
|
||||
assert isinstance(index, pd.MultiIndex), "unsupported index type"
|
||||
assert seq_len > 0, "sequence length should be larger than 0"
|
||||
assert index.is_monotonic_increasing, "index should be sorted"
|
||||
|
||||
# number of dates for each instrument
|
||||
sample_count_by_insts = index.to_series().groupby(level=0).size().values
|
||||
|
||||
# start index for each instrument
|
||||
start_index_of_insts = np.roll(np.cumsum(sample_count_by_insts), 1)
|
||||
start_index_of_insts[0] = 0
|
||||
|
||||
# all the [start, stop) indices of features
|
||||
# features between [start, stop) will be used to predict label at `stop - 1`
|
||||
slices = []
|
||||
for cur_loc, cur_cnt in zip(start_index_of_insts, sample_count_by_insts):
|
||||
for stop in range(1, cur_cnt + 1):
|
||||
end = cur_loc + stop
|
||||
start = max(end - seq_len, 0)
|
||||
slices.append(slice(start, end))
|
||||
slices = np.array(slices, dtype="object")
|
||||
|
||||
assert len(slices) == len(index) # the i-th slice = index[i]
|
||||
|
||||
return slices
|
||||
|
||||
|
||||
def _get_date_parse_fn(target):
|
||||
"""get date parse function
|
||||
|
||||
This method is used to parse date arguments as target type.
|
||||
|
||||
Example:
|
||||
get_date_parse_fn('20120101')('2017-01-01') => '20170101'
|
||||
get_date_parse_fn(20120101)('2017-01-01') => 20170101
|
||||
"""
|
||||
if isinstance(target, pd.Timestamp):
|
||||
_fn = lambda x: pd.Timestamp(x) # Timestamp('2020-01-01')
|
||||
elif isinstance(target, int):
|
||||
_fn = lambda x: int(str(x).replace("-", "")[:8]) # 20200201
|
||||
elif isinstance(target, str) and len(target) == 8:
|
||||
_fn = lambda x: str(x).replace("-", "")[:8] # '20200201'
|
||||
else:
|
||||
_fn = lambda x: x # '2021-01-01'
|
||||
return _fn
|
||||
|
||||
|
||||
def _maybe_padding(x, seq_len, zeros=None):
|
||||
"""padding 2d <time * feature> data with zeros
|
||||
|
||||
Args:
|
||||
x (np.ndarray): 2d data with shape <time * feature>
|
||||
seq_len (int): target sequence length
|
||||
zeros (np.ndarray): zeros with shape <seq_len * feature>
|
||||
"""
|
||||
assert seq_len > 0, "sequence length should be larger than 0"
|
||||
if zeros is None:
|
||||
zeros = np.zeros((seq_len, x.shape[1]), dtype=np.float32)
|
||||
else:
|
||||
assert len(zeros) >= seq_len, "zeros matrix is not large enough for padding"
|
||||
if len(x) != seq_len: # padding zeros
|
||||
x = np.concatenate([zeros[: seq_len - len(x), : x.shape[1]], x], axis=0)
|
||||
return x
|
||||
|
||||
|
||||
class MTSDatasetH(DatasetH):
|
||||
"""Memory Augmented Time Series Dataset
|
||||
|
||||
Args:
|
||||
handler (DataHandler): data handler
|
||||
segments (dict): data split segments
|
||||
seq_len (int): time series sequence length
|
||||
horizon (int): label horizon
|
||||
num_states (int): how many memory states to be added
|
||||
memory_mode (str): memory mode (daily or sample)
|
||||
batch_size (int): batch size (<0 will use daily sampling)
|
||||
n_samples (int): number of samples in the same day
|
||||
shuffle (bool): whether shuffle data
|
||||
drop_last (bool): whether drop last batch < batch_size
|
||||
input_size (int): reshape flatten rows as this input_size (backward compatibility)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
handler,
|
||||
segments,
|
||||
seq_len=60,
|
||||
horizon=0,
|
||||
num_states=0,
|
||||
memory_mode="sample",
|
||||
batch_size=-1,
|
||||
n_samples=None,
|
||||
shuffle=True,
|
||||
drop_last=False,
|
||||
input_size=None,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
assert num_states == 0 or horizon > 0, "please specify `horizon` to avoid data leakage"
|
||||
assert memory_mode in ["sample", "daily"], "unsupported memory mode"
|
||||
assert memory_mode == "sample" or batch_size < 0, "daily memory requires daily sampling (`batch_size < 0`)"
|
||||
assert batch_size != 0, "invalid batch size"
|
||||
|
||||
if batch_size > 0 and n_samples is not None:
|
||||
warnings.warn("`n_samples` can only be used for daily sampling (`batch_size < 0`)")
|
||||
|
||||
self.seq_len = seq_len
|
||||
self.horizon = horizon
|
||||
self.num_states = num_states
|
||||
self.memory_mode = memory_mode
|
||||
self.batch_size = batch_size
|
||||
self.n_samples = n_samples
|
||||
self.shuffle = shuffle
|
||||
self.drop_last = drop_last
|
||||
self.input_size = input_size
|
||||
self.params = (batch_size, n_samples, drop_last, shuffle) # for train/eval switch
|
||||
|
||||
super().__init__(handler, segments, **kwargs)
|
||||
|
||||
def setup_data(self, handler_kwargs: dict = None, **kwargs):
|
||||
|
||||
super().setup_data(**kwargs)
|
||||
|
||||
if handler_kwargs is not None:
|
||||
self.handler.setup_data(**handler_kwargs)
|
||||
|
||||
# pre-fetch data and change index to <code, date>
|
||||
# NOTE: we will use inplace sort to reduce memory use
|
||||
try:
|
||||
df = self.handler._learn.copy() # use copy otherwise recorder will fail
|
||||
# FIXME: currently we cannot support switching from `_learn` to `_infer` for inference
|
||||
except:
|
||||
warnings.warn("cannot access `_learn`, will load raw data")
|
||||
df = self.handler._data.copy()
|
||||
df.index = df.index.swaplevel()
|
||||
df.sort_index(inplace=True)
|
||||
|
||||
# convert to numpy
|
||||
self._data = df["feature"].values.astype("float32")
|
||||
np.nan_to_num(self._data, copy=False) # NOTE: fillna in case users forget using the fillna processor
|
||||
self._label = df["label"].squeeze().values.astype("float32")
|
||||
self._index = df.index
|
||||
|
||||
if self.input_size is not None and self.input_size != self._data.shape[1]:
|
||||
warnings.warn("the data has different shape from input_size and the data will be reshaped")
|
||||
assert self._data.shape[1] % self.input_size == 0, "data mismatch, please check `input_size`"
|
||||
|
||||
# create batch slices
|
||||
self._batch_slices = _create_ts_slices(self._index, self.seq_len)
|
||||
|
||||
# create daily slices
|
||||
daily_slices = {date: [] for date in sorted(self._index.unique(level=1))} # sorted by date
|
||||
for i, (code, date) in enumerate(self._index):
|
||||
daily_slices[date].append(self._batch_slices[i])
|
||||
self._daily_slices = np.array(list(daily_slices.values()), dtype="object")
|
||||
self._daily_index = pd.Series(list(daily_slices.keys())) # index is the original date index
|
||||
|
||||
# add memory (sample wise and daily)
|
||||
if self.memory_mode == "sample":
|
||||
self._memory = np.zeros((len(self._data), self.num_states), dtype=np.float32)
|
||||
elif self.memory_mode == "daily":
|
||||
self._memory = np.zeros((len(self._daily_index), self.num_states), dtype=np.float32)
|
||||
else:
|
||||
raise ValueError(f"invalid memory_mode `{self.memory_mode}`")
|
||||
|
||||
# padding tensor
|
||||
self._zeros = np.zeros((self.seq_len, max(self.num_states, self._data.shape[1])), dtype=np.float32)
|
||||
|
||||
def _prepare_seg(self, slc, **kwargs):
|
||||
fn = _get_date_parse_fn(self._index[0][1])
|
||||
start_date = fn(slc.start)
|
||||
end_date = fn(slc.stop)
|
||||
obj = copy.copy(self) # shallow copy
|
||||
# NOTE: Seriable will disable copy `self._data` so we manually assign them here
|
||||
obj._data = self._data # reference (no copy)
|
||||
obj._label = self._label
|
||||
obj._index = self._index
|
||||
obj._memory = self._memory
|
||||
obj._zeros = self._zeros
|
||||
# update index for this batch
|
||||
date_index = self._index.get_level_values(1)
|
||||
obj._batch_slices = self._batch_slices[(date_index >= start_date) & (date_index <= end_date)]
|
||||
mask = (self._daily_index.values >= start_date) & (self._daily_index.values <= end_date)
|
||||
obj._daily_slices = self._daily_slices[mask]
|
||||
obj._daily_index = self._daily_index[mask]
|
||||
return obj
|
||||
|
||||
def restore_index(self, index):
|
||||
return self._index[index]
|
||||
|
||||
def restore_daily_index(self, daily_index):
|
||||
return pd.Index(self._daily_index.loc[daily_index])
|
||||
|
||||
def assign_data(self, index, vals):
|
||||
if self.num_states == 0:
|
||||
raise ValueError("cannot assign data as `num_states==0`")
|
||||
if isinstance(vals, torch.Tensor):
|
||||
vals = vals.detach().cpu().numpy()
|
||||
self._memory[index] = vals
|
||||
|
||||
def clear_memory(self):
|
||||
if self.num_states == 0:
|
||||
raise ValueError("cannot clear memory as `num_states==0`")
|
||||
self._memory[:] = 0
|
||||
|
||||
def train(self):
|
||||
"""enable traning mode"""
|
||||
self.batch_size, self.n_samples, self.drop_last, self.shuffle = self.params
|
||||
|
||||
def eval(self):
|
||||
"""enable evaluation mode"""
|
||||
self.batch_size = -1
|
||||
self.n_samples = None
|
||||
self.drop_last = False
|
||||
self.shuffle = False
|
||||
|
||||
def _get_slices(self):
|
||||
if self.batch_size < 0: # daily sampling
|
||||
slices = self._daily_slices.copy()
|
||||
batch_size = -1 * self.batch_size
|
||||
else: # normal sampling
|
||||
slices = self._batch_slices.copy()
|
||||
batch_size = self.batch_size
|
||||
return slices, batch_size
|
||||
|
||||
def __len__(self):
|
||||
slices, batch_size = self._get_slices()
|
||||
if self.drop_last:
|
||||
return len(slices) // batch_size
|
||||
return (len(slices) + batch_size - 1) // batch_size
|
||||
|
||||
def __iter__(self):
|
||||
slices, batch_size = self._get_slices()
|
||||
indices = np.arange(len(slices))
|
||||
if self.shuffle:
|
||||
np.random.shuffle(indices)
|
||||
|
||||
for i in range(len(indices))[::batch_size]:
|
||||
if self.drop_last and i + batch_size > len(indices):
|
||||
break
|
||||
|
||||
data = [] # store features
|
||||
label = [] # store labels
|
||||
index = [] # store index
|
||||
state = [] # store memory states
|
||||
daily_index = [] # store daily index
|
||||
daily_count = [] # store number of samples for each day
|
||||
|
||||
for j in indices[i : i + batch_size]:
|
||||
|
||||
# normal sampling: self.batch_size > 0 => slices is a list => slices_subset is a slice
|
||||
# daily sampling: self.batch_size < 0 => slices is a nested list => slices_subset is a list
|
||||
slices_subset = slices[j]
|
||||
|
||||
# daily sampling
|
||||
# each slices_subset contains a list of slices for multiple stocks
|
||||
# NOTE: daily sampling is used in 1) eval mode, 2) train mode with self.batch_size < 0
|
||||
if self.batch_size < 0:
|
||||
|
||||
# store daily index
|
||||
idx = self._daily_index.index[j] # daily_index.index is the index of the original data
|
||||
daily_index.append(idx)
|
||||
|
||||
# store daily memory if specified
|
||||
# NOTE: daily memory always requires daily sampling (self.batch_size < 0)
|
||||
if self.memory_mode == "daily":
|
||||
slc = slice(max(idx - self.seq_len - self.horizon, 0), max(idx - self.horizon, 0))
|
||||
state.append(_maybe_padding(self._memory[slc], self.seq_len, self._zeros))
|
||||
|
||||
# down-sample stocks and store count
|
||||
if self.n_samples and 0 < self.n_samples < len(slices_subset): # intraday subsample
|
||||
slices_subset = np.random.choice(slices_subset, self.n_samples, replace=False)
|
||||
daily_count.append(len(slices_subset))
|
||||
|
||||
# normal sampling
|
||||
# each slices_subset is a single slice
|
||||
# NOTE: normal sampling is used in train mode with self.batch_size > 0
|
||||
else:
|
||||
slices_subset = [slices_subset]
|
||||
|
||||
for slc in slices_subset:
|
||||
|
||||
# legacy support for Alpha360 data by `input_size`
|
||||
if self.input_size:
|
||||
data.append(self._data[slc.stop - 1].reshape(self.input_size, -1).T)
|
||||
else:
|
||||
data.append(_maybe_padding(self._data[slc], self.seq_len, self._zeros))
|
||||
|
||||
if self.memory_mode == "sample":
|
||||
state.append(_maybe_padding(self._memory[slc], self.seq_len, self._zeros)[: -self.horizon])
|
||||
|
||||
label.append(self._label[slc.stop - 1])
|
||||
index.append(slc.stop - 1)
|
||||
|
||||
# end slices loop
|
||||
|
||||
# end indices batch loop
|
||||
|
||||
# concate
|
||||
data = _to_tensor(np.stack(data))
|
||||
state = _to_tensor(np.stack(state))
|
||||
label = _to_tensor(np.stack(label))
|
||||
index = np.array(index)
|
||||
daily_index = np.array(daily_index)
|
||||
daily_count = np.array(daily_count)
|
||||
|
||||
# yield -> generator
|
||||
yield {
|
||||
"data": data,
|
||||
"label": label,
|
||||
"state": state,
|
||||
"index": index,
|
||||
"daily_index": daily_index,
|
||||
"daily_count": daily_count,
|
||||
}
|
||||
|
||||
# end indice loop
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
from ...data.dataset.processor import Processor
|
||||
from ...utils import get_cls_kwargs
|
||||
from ...utils import get_callable_kwargs
|
||||
from ...data.dataset import processor as processor_module
|
||||
from ...log import TimeInspector
|
||||
from inspect import getfullargspec
|
||||
@@ -14,7 +14,7 @@ def check_transform_proc(proc_l, fit_start_time, fit_end_time):
|
||||
new_l = []
|
||||
for p in proc_l:
|
||||
if not isinstance(p, Processor):
|
||||
klass, pkwargs = get_cls_kwargs(p, processor_module)
|
||||
klass, pkwargs = get_callable_kwargs(p, processor_module)
|
||||
args = getfullargspec(klass).args
|
||||
if "fit_start_time" in args and "fit_end_time" in args:
|
||||
assert (
|
||||
@@ -58,6 +58,7 @@ class Alpha360(DataHandlerLP):
|
||||
fit_start_time=None,
|
||||
fit_end_time=None,
|
||||
filter_pipe=None,
|
||||
inst_processor=None,
|
||||
**kwargs,
|
||||
):
|
||||
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
|
||||
@@ -72,6 +73,7 @@ class Alpha360(DataHandlerLP):
|
||||
},
|
||||
"filter_pipe": filter_pipe,
|
||||
"freq": freq,
|
||||
"inst_processor": inst_processor,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -144,6 +146,7 @@ class Alpha158(DataHandlerLP):
|
||||
fit_end_time=None,
|
||||
process_type=DataHandlerLP.PTYPE_A,
|
||||
filter_pipe=None,
|
||||
inst_processor=None,
|
||||
**kwargs,
|
||||
):
|
||||
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
|
||||
@@ -158,6 +161,7 @@ class Alpha158(DataHandlerLP):
|
||||
},
|
||||
"filter_pipe": filter_pipe,
|
||||
"freq": freq,
|
||||
"inst_processor": inst_processor,
|
||||
},
|
||||
}
|
||||
super().__init__(
|
||||
|
||||
@@ -27,7 +27,6 @@ from ...contrib.model.pytorch_gru import GRUModel
|
||||
|
||||
class DailyBatchSampler(Sampler):
|
||||
def __init__(self, data_source):
|
||||
|
||||
self.data_source = data_source
|
||||
# calculate number of samples in each batch
|
||||
self.daily_count = pd.Series(index=self.data_source.get_index()).groupby("datetime").size().values
|
||||
|
||||
331
qlib/contrib/model/pytorch_localformer.py
Normal file
331
qlib/contrib/model/pytorch_localformer.py
Normal file
@@ -0,0 +1,331 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Text, Union
|
||||
import copy
|
||||
import math
|
||||
from ...utils import get_or_create_path
|
||||
from ...log import get_module_logger
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from .pytorch_utils import count_parameters
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH, TSDatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
from torch.nn.modules.container import ModuleList
|
||||
|
||||
# qrun examples/benchmarks/Localformer/workflow_config_localformer_Alpha360.yaml ”
|
||||
|
||||
|
||||
class LocalformerModel(Model):
|
||||
def __init__(
|
||||
self,
|
||||
d_feat: int = 20,
|
||||
d_model: int = 64,
|
||||
batch_size: int = 2048,
|
||||
nhead: int = 2,
|
||||
num_layers: int = 2,
|
||||
dropout: float = 0,
|
||||
n_epochs=100,
|
||||
lr=0.0001,
|
||||
metric="",
|
||||
early_stop=5,
|
||||
loss="mse",
|
||||
optimizer="adam",
|
||||
reg=1e-3,
|
||||
n_jobs=10,
|
||||
GPU=0,
|
||||
seed=None,
|
||||
**kwargs
|
||||
):
|
||||
|
||||
# set hyper-parameters.
|
||||
self.d_model = d_model
|
||||
self.dropout = dropout
|
||||
self.n_epochs = n_epochs
|
||||
self.lr = lr
|
||||
self.reg = reg
|
||||
self.metric = metric
|
||||
self.batch_size = batch_size
|
||||
self.early_stop = early_stop
|
||||
self.optimizer = optimizer.lower()
|
||||
self.loss = loss
|
||||
self.n_jobs = n_jobs
|
||||
self.device = torch.device("cuda:%d" % GPU if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.seed = seed
|
||||
self.logger = get_module_logger("TransformerModel")
|
||||
self.logger.info("Naive Transformer:" "\nbatch_size : {}" "\ndevice : {}".format(self.batch_size, self.device))
|
||||
|
||||
if self.seed is not None:
|
||||
np.random.seed(self.seed)
|
||||
torch.manual_seed(self.seed)
|
||||
|
||||
self.model = Transformer(d_feat, d_model, nhead, num_layers, dropout, self.device)
|
||||
if optimizer.lower() == "adam":
|
||||
self.train_optimizer = optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=self.reg)
|
||||
elif optimizer.lower() == "gd":
|
||||
self.train_optimizer = optim.SGD(self.model.parameters(), lr=self.lr, weight_decay=self.reg)
|
||||
else:
|
||||
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
|
||||
|
||||
self.fitted = False
|
||||
self.model.to(self.device)
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
return self.device != torch.device("cpu")
|
||||
|
||||
def mse(self, pred, label):
|
||||
loss = (pred.float() - label.float()) ** 2
|
||||
return torch.mean(loss)
|
||||
|
||||
def loss_fn(self, pred, label):
|
||||
mask = ~torch.isnan(label)
|
||||
|
||||
if self.loss == "mse":
|
||||
return self.mse(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown loss `%s`" % self.loss)
|
||||
|
||||
def metric_fn(self, pred, label):
|
||||
|
||||
mask = torch.isfinite(label)
|
||||
|
||||
if self.metric == "" or self.metric == "loss":
|
||||
return -self.loss_fn(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown metric `%s`" % self.metric)
|
||||
|
||||
def train_epoch(self, x_train, y_train):
|
||||
|
||||
x_train_values = x_train.values
|
||||
y_train_values = np.squeeze(y_train.values)
|
||||
|
||||
self.model.train()
|
||||
|
||||
indices = np.arange(len(x_train_values))
|
||||
np.random.shuffle(indices)
|
||||
|
||||
for i in range(len(indices))[:: self.batch_size]:
|
||||
|
||||
if len(indices) - i < self.batch_size:
|
||||
break
|
||||
|
||||
feature = torch.from_numpy(x_train_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
label = torch.from_numpy(y_train_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
|
||||
pred = self.model(feature)
|
||||
loss = self.loss_fn(pred, label)
|
||||
|
||||
self.train_optimizer.zero_grad()
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_value_(self.model.parameters(), 3.0)
|
||||
self.train_optimizer.step()
|
||||
|
||||
def test_epoch(self, data_x, data_y):
|
||||
|
||||
# prepare training data
|
||||
x_values = data_x.values
|
||||
y_values = np.squeeze(data_y.values)
|
||||
|
||||
self.model.eval()
|
||||
|
||||
scores = []
|
||||
losses = []
|
||||
|
||||
indices = np.arange(len(x_values))
|
||||
|
||||
for i in range(len(indices))[:: self.batch_size]:
|
||||
|
||||
if len(indices) - i < self.batch_size:
|
||||
break
|
||||
|
||||
feature = torch.from_numpy(x_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
label = torch.from_numpy(y_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
pred = self.model(feature)
|
||||
loss = self.loss_fn(pred, label)
|
||||
losses.append(loss.item())
|
||||
|
||||
score = self.metric_fn(pred, label)
|
||||
scores.append(score.item())
|
||||
|
||||
return np.mean(losses), np.mean(scores)
|
||||
|
||||
def fit(
|
||||
self,
|
||||
dataset: DatasetH,
|
||||
evals_result=dict(),
|
||||
save_path=None,
|
||||
):
|
||||
|
||||
df_train, df_valid, df_test = dataset.prepare(
|
||||
["train", "valid", "test"],
|
||||
col_set=["feature", "label"],
|
||||
data_key=DataHandlerLP.DK_L,
|
||||
)
|
||||
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
|
||||
save_path = get_or_create_path(save_path)
|
||||
stop_steps = 0
|
||||
train_loss = 0
|
||||
best_score = -np.inf
|
||||
best_epoch = 0
|
||||
evals_result["train"] = []
|
||||
evals_result["valid"] = []
|
||||
|
||||
# train
|
||||
self.logger.info("training...")
|
||||
self.fitted = True
|
||||
|
||||
for step in range(self.n_epochs):
|
||||
self.logger.info("Epoch%d:", step)
|
||||
self.logger.info("training...")
|
||||
self.train_epoch(x_train, y_train)
|
||||
self.logger.info("evaluating...")
|
||||
train_loss, train_score = self.test_epoch(x_train, y_train)
|
||||
val_loss, val_score = self.test_epoch(x_valid, y_valid)
|
||||
self.logger.info("train %.6f, valid %.6f" % (train_score, val_score))
|
||||
evals_result["train"].append(train_score)
|
||||
evals_result["valid"].append(val_score)
|
||||
|
||||
if val_score > best_score:
|
||||
best_score = val_score
|
||||
stop_steps = 0
|
||||
best_epoch = step
|
||||
best_param = copy.deepcopy(self.model.state_dict())
|
||||
else:
|
||||
stop_steps += 1
|
||||
if stop_steps >= self.early_stop:
|
||||
self.logger.info("early stop")
|
||||
break
|
||||
|
||||
self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
|
||||
self.model.load_state_dict(best_param)
|
||||
torch.save(best_param, save_path)
|
||||
|
||||
if self.use_gpu:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
index = x_test.index
|
||||
self.model.eval()
|
||||
x_values = x_test.values
|
||||
sample_num = x_values.shape[0]
|
||||
preds = []
|
||||
|
||||
for begin in range(sample_num)[:: self.batch_size]:
|
||||
|
||||
if sample_num - begin < self.batch_size:
|
||||
end = sample_num
|
||||
else:
|
||||
end = begin + self.batch_size
|
||||
|
||||
x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
pred = self.model(x_batch).detach().cpu().numpy()
|
||||
|
||||
preds.append(pred)
|
||||
|
||||
return pd.Series(np.concatenate(preds), index=index)
|
||||
|
||||
|
||||
class PositionalEncoding(nn.Module):
|
||||
def __init__(self, d_model, max_len=1000):
|
||||
super(PositionalEncoding, self).__init__()
|
||||
pe = torch.zeros(max_len, d_model)
|
||||
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
|
||||
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
|
||||
pe[:, 0::2] = torch.sin(position * div_term)
|
||||
pe[:, 1::2] = torch.cos(position * div_term)
|
||||
pe = pe.unsqueeze(0).transpose(0, 1)
|
||||
self.register_buffer("pe", pe)
|
||||
|
||||
def forward(self, x):
|
||||
# [T, N, F]
|
||||
return x + self.pe[: x.size(0), :]
|
||||
|
||||
|
||||
def _get_clones(module, N):
|
||||
return ModuleList([copy.deepcopy(module) for i in range(N)])
|
||||
|
||||
|
||||
class LocalformerEncoder(nn.Module):
|
||||
__constants__ = ["norm"]
|
||||
|
||||
def __init__(self, encoder_layer, num_layers, d_model):
|
||||
super(LocalformerEncoder, self).__init__()
|
||||
self.layers = _get_clones(encoder_layer, num_layers)
|
||||
self.conv = _get_clones(nn.Conv1d(d_model, d_model, 3, 1, 1), num_layers)
|
||||
self.num_layers = num_layers
|
||||
|
||||
def forward(self, src, mask):
|
||||
output = src
|
||||
out = src
|
||||
|
||||
for i, mod in enumerate(self.layers):
|
||||
# [T, N, F] --> [N, T, F] --> [N, F, T]
|
||||
out = output.transpose(1, 0).transpose(2, 1)
|
||||
out = self.conv[i](out).transpose(2, 1).transpose(1, 0)
|
||||
|
||||
output = mod(output + out, src_mask=mask)
|
||||
|
||||
return output + out
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, d_feat=6, d_model=8, nhead=4, num_layers=2, dropout=0.5, device=None):
|
||||
super(Transformer, self).__init__()
|
||||
self.rnn = nn.GRU(
|
||||
input_size=d_model,
|
||||
hidden_size=d_model,
|
||||
num_layers=num_layers,
|
||||
batch_first=False,
|
||||
dropout=dropout,
|
||||
)
|
||||
self.feature_layer = nn.Linear(d_feat, d_model)
|
||||
self.pos_encoder = PositionalEncoding(d_model)
|
||||
self.encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dropout=dropout)
|
||||
self.transformer_encoder = LocalformerEncoder(self.encoder_layer, num_layers=num_layers, d_model=d_model)
|
||||
self.decoder_layer = nn.Linear(d_model, 1)
|
||||
self.device = device
|
||||
self.d_feat = d_feat
|
||||
|
||||
def forward(self, src):
|
||||
# src [N, F*T] --> [N, T, F]
|
||||
src = src.reshape(len(src), self.d_feat, -1).permute(0, 2, 1)
|
||||
src = self.feature_layer(src)
|
||||
|
||||
# src [N, T, F] --> [T, N, F], [60, 512, 8]
|
||||
src = src.transpose(1, 0) # not batch first
|
||||
|
||||
mask = None
|
||||
|
||||
src = self.pos_encoder(src)
|
||||
output = self.transformer_encoder(src, mask) # [60, 512, 8]
|
||||
|
||||
output, _ = self.rnn(output)
|
||||
|
||||
# [T, N, F] --> [N, T*F]
|
||||
output = self.decoder_layer(output.transpose(1, 0)[:, -1, :]) # [512, 1]
|
||||
|
||||
return output.squeeze()
|
||||
308
qlib/contrib/model/pytorch_localformer_ts.py
Normal file
308
qlib/contrib/model/pytorch_localformer_ts.py
Normal file
@@ -0,0 +1,308 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import copy
|
||||
import math
|
||||
from ...utils import get_or_create_path
|
||||
from ...log import get_module_logger
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from .pytorch_utils import count_parameters
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH, TSDatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
from torch.nn.modules.container import ModuleList
|
||||
|
||||
|
||||
class LocalformerModel(Model):
|
||||
def __init__(
|
||||
self,
|
||||
d_feat: int = 20,
|
||||
d_model: int = 64,
|
||||
batch_size: int = 8192,
|
||||
nhead: int = 2,
|
||||
num_layers: int = 2,
|
||||
dropout: float = 0,
|
||||
n_epochs=100,
|
||||
lr=0.0001,
|
||||
metric="",
|
||||
early_stop=5,
|
||||
loss="mse",
|
||||
optimizer="adam",
|
||||
reg=1e-3,
|
||||
n_jobs=10,
|
||||
GPU=0,
|
||||
seed=None,
|
||||
**kwargs
|
||||
):
|
||||
|
||||
# set hyper-parameters.
|
||||
self.d_model = d_model
|
||||
self.dropout = dropout
|
||||
self.n_epochs = n_epochs
|
||||
self.lr = lr
|
||||
self.reg = reg
|
||||
self.metric = metric
|
||||
self.batch_size = batch_size
|
||||
self.early_stop = early_stop
|
||||
self.optimizer = optimizer.lower()
|
||||
self.loss = loss
|
||||
self.n_jobs = n_jobs
|
||||
self.device = torch.device("cuda:%d" % GPU if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.seed = seed
|
||||
self.logger = get_module_logger("TransformerModel")
|
||||
self.logger.info(
|
||||
"Improved Transformer:" "\nbatch_size : {}" "\ndevice : {}".format(self.batch_size, self.device)
|
||||
)
|
||||
|
||||
if self.seed is not None:
|
||||
np.random.seed(self.seed)
|
||||
torch.manual_seed(self.seed)
|
||||
|
||||
self.model = Transformer(d_feat, d_model, nhead, num_layers, dropout, self.device)
|
||||
if optimizer.lower() == "adam":
|
||||
self.train_optimizer = optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=self.reg)
|
||||
elif optimizer.lower() == "gd":
|
||||
self.train_optimizer = optim.SGD(self.model.parameters(), lr=self.lr, weight_decay=self.reg)
|
||||
else:
|
||||
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
|
||||
|
||||
self.fitted = False
|
||||
self.model.to(self.device)
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
return self.device != torch.device("cpu")
|
||||
|
||||
def mse(self, pred, label):
|
||||
loss = (pred.float() - label.float()) ** 2
|
||||
return torch.mean(loss)
|
||||
|
||||
def loss_fn(self, pred, label):
|
||||
mask = ~torch.isnan(label)
|
||||
|
||||
if self.loss == "mse":
|
||||
return self.mse(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown loss `%s`" % self.loss)
|
||||
|
||||
def metric_fn(self, pred, label):
|
||||
|
||||
mask = torch.isfinite(label)
|
||||
|
||||
if self.metric == "" or self.metric == "loss":
|
||||
return -self.loss_fn(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown metric `%s`" % self.metric)
|
||||
|
||||
def train_epoch(self, data_loader):
|
||||
|
||||
self.model.train()
|
||||
|
||||
for data in data_loader:
|
||||
feature = data[:, :, 0:-1].to(self.device)
|
||||
label = data[:, -1, -1].to(self.device)
|
||||
|
||||
pred = self.model(feature.float()) # .float()
|
||||
loss = self.loss_fn(pred, label)
|
||||
|
||||
self.train_optimizer.zero_grad()
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_value_(self.model.parameters(), 3.0)
|
||||
self.train_optimizer.step()
|
||||
|
||||
def test_epoch(self, data_loader):
|
||||
|
||||
self.model.eval()
|
||||
|
||||
scores = []
|
||||
losses = []
|
||||
|
||||
for data in data_loader:
|
||||
|
||||
feature = data[:, :, 0:-1].to(self.device)
|
||||
label = data[:, -1, -1].to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
pred = self.model(feature.float()) # .float()
|
||||
loss = self.loss_fn(pred, label)
|
||||
losses.append(loss.item())
|
||||
|
||||
score = self.metric_fn(pred, label)
|
||||
scores.append(score.item())
|
||||
|
||||
return np.mean(losses), np.mean(scores)
|
||||
|
||||
def fit(
|
||||
self,
|
||||
dataset: DatasetH,
|
||||
evals_result=dict(),
|
||||
save_path=None,
|
||||
):
|
||||
|
||||
dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
|
||||
dl_valid = dataset.prepare("valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
|
||||
|
||||
dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
|
||||
train_loader = DataLoader(
|
||||
dl_train, batch_size=self.batch_size, shuffle=True, num_workers=self.n_jobs, drop_last=True
|
||||
)
|
||||
valid_loader = DataLoader(
|
||||
dl_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.n_jobs, drop_last=True
|
||||
)
|
||||
|
||||
save_path = get_or_create_path(save_path)
|
||||
|
||||
stop_steps = 0
|
||||
train_loss = 0
|
||||
best_score = -np.inf
|
||||
best_epoch = 0
|
||||
evals_result["train"] = []
|
||||
evals_result["valid"] = []
|
||||
|
||||
# train
|
||||
self.logger.info("training...")
|
||||
self.fitted = True
|
||||
|
||||
for step in range(self.n_epochs):
|
||||
self.logger.info("Epoch%d:", step)
|
||||
self.logger.info("training...")
|
||||
self.train_epoch(train_loader)
|
||||
self.logger.info("evaluating...")
|
||||
train_loss, train_score = self.test_epoch(train_loader)
|
||||
val_loss, val_score = self.test_epoch(valid_loader)
|
||||
self.logger.info("train %.6f, valid %.6f" % (train_score, val_score))
|
||||
evals_result["train"].append(train_score)
|
||||
evals_result["valid"].append(val_score)
|
||||
|
||||
if val_score > best_score:
|
||||
best_score = val_score
|
||||
stop_steps = 0
|
||||
best_epoch = step
|
||||
best_param = copy.deepcopy(self.model.state_dict())
|
||||
else:
|
||||
stop_steps += 1
|
||||
if stop_steps >= self.early_stop:
|
||||
self.logger.info("early stop")
|
||||
break
|
||||
|
||||
self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
|
||||
self.model.load_state_dict(best_param)
|
||||
torch.save(best_param, save_path)
|
||||
|
||||
if self.use_gpu:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def predict(self, dataset):
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
dl_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)
|
||||
dl_test.config(fillna_type="ffill+bfill")
|
||||
test_loader = DataLoader(dl_test, batch_size=self.batch_size, num_workers=self.n_jobs)
|
||||
self.model.eval()
|
||||
preds = []
|
||||
|
||||
for data in test_loader:
|
||||
feature = data[:, :, 0:-1].to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
pred = self.model(feature.float()).detach().cpu().numpy()
|
||||
|
||||
preds.append(pred)
|
||||
|
||||
return pd.Series(np.concatenate(preds), index=dl_test.get_index())
|
||||
|
||||
|
||||
class PositionalEncoding(nn.Module):
|
||||
def __init__(self, d_model, max_len=1000):
|
||||
super(PositionalEncoding, self).__init__()
|
||||
pe = torch.zeros(max_len, d_model)
|
||||
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
|
||||
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
|
||||
pe[:, 0::2] = torch.sin(position * div_term)
|
||||
pe[:, 1::2] = torch.cos(position * div_term)
|
||||
pe = pe.unsqueeze(0).transpose(0, 1)
|
||||
self.register_buffer("pe", pe)
|
||||
|
||||
def forward(self, x):
|
||||
# [T, N, F]
|
||||
return x + self.pe[: x.size(0), :]
|
||||
|
||||
|
||||
def _get_clones(module, N):
|
||||
return ModuleList([copy.deepcopy(module) for i in range(N)])
|
||||
|
||||
|
||||
class LocalformerEncoder(nn.Module):
|
||||
__constants__ = ["norm"]
|
||||
|
||||
def __init__(self, encoder_layer, num_layers, d_model):
|
||||
super(LocalformerEncoder, self).__init__()
|
||||
self.layers = _get_clones(encoder_layer, num_layers)
|
||||
self.conv = _get_clones(nn.Conv1d(d_model, d_model, 3, 1, 1), num_layers)
|
||||
self.num_layers = num_layers
|
||||
|
||||
def forward(self, src, mask):
|
||||
output = src
|
||||
out = src
|
||||
|
||||
for i, mod in enumerate(self.layers):
|
||||
# [T, N, F] --> [N, T, F] --> [N, F, T]
|
||||
out = output.transpose(1, 0).transpose(2, 1)
|
||||
out = self.conv[i](out).transpose(2, 1).transpose(1, 0)
|
||||
|
||||
output = mod(output + out, src_mask=mask)
|
||||
|
||||
return output + out
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, d_feat=6, d_model=8, nhead=4, num_layers=2, dropout=0.5, device=None):
|
||||
super(Transformer, self).__init__()
|
||||
self.rnn = nn.GRU(
|
||||
input_size=d_model,
|
||||
hidden_size=d_model,
|
||||
num_layers=num_layers,
|
||||
batch_first=False,
|
||||
dropout=dropout,
|
||||
)
|
||||
self.feature_layer = nn.Linear(d_feat, d_model)
|
||||
self.pos_encoder = PositionalEncoding(d_model)
|
||||
self.encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dropout=dropout)
|
||||
self.transformer_encoder = LocalformerEncoder(self.encoder_layer, num_layers=num_layers, d_model=d_model)
|
||||
self.decoder_layer = nn.Linear(d_model, 1)
|
||||
self.device = device
|
||||
self.d_feat = d_feat
|
||||
|
||||
def forward(self, src):
|
||||
# src [N, T, F], [512, 60, 6]
|
||||
src = self.feature_layer(src) # [512, 60, 8]
|
||||
|
||||
# src [N, T, F] --> [T, N, F], [60, 512, 8]
|
||||
src = src.transpose(1, 0) # not batch first
|
||||
|
||||
mask = None
|
||||
|
||||
src = self.pos_encoder(src)
|
||||
output = self.transformer_encoder(src, mask) # [60, 512, 8]
|
||||
|
||||
output, _ = self.rnn(output)
|
||||
|
||||
# [T, N, F] --> [N, T*F]
|
||||
output = self.decoder_layer(output.transpose(1, 0)[:, -1, :]) # [512, 1]
|
||||
|
||||
return output.squeeze()
|
||||
@@ -564,7 +564,7 @@ class FeatureTransformer(nn.Module):
|
||||
self.shared = None
|
||||
self.independ = nn.ModuleList()
|
||||
if first:
|
||||
self.independ.append(GLU(inp, out_dim, vbs=vbs))
|
||||
self.independ.append(GLU(inp_dim, out_dim, vbs=vbs))
|
||||
for x in range(first, n_ind):
|
||||
self.independ.append(GLU(out_dim, out_dim, vbs=vbs))
|
||||
self.scale = float(np.sqrt(0.5))
|
||||
|
||||
944
qlib/contrib/model/pytorch_tra.py
Normal file
944
qlib/contrib/model/pytorch_tra.py
Normal file
@@ -0,0 +1,944 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import io
|
||||
import os
|
||||
import copy
|
||||
import math
|
||||
import json
|
||||
import collections
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import seaborn as sns
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import torch.nn.functional as F
|
||||
|
||||
try:
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
except:
|
||||
SummaryWriter = None
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from qlib.utils import get_or_create_path
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.model.base import Model
|
||||
from qlib.contrib.data.dataset import MTSDatasetH
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
|
||||
class TRAModel(Model):
|
||||
"""
|
||||
TRA Model
|
||||
|
||||
Args:
|
||||
model_config (dict): model config (will be used by RNN or Transformer)
|
||||
tra_config (dict): TRA config (will be used by TRA)
|
||||
model_type (str): which backbone model to use (RNN/Transformer)
|
||||
lr (float): learning rate
|
||||
n_epochs (int): number of total epochs
|
||||
early_stop (int): early stop when performance not improved at this step
|
||||
update_freq (int): gradient update frequency
|
||||
max_steps_per_epoch (int): maximum number of steps in one epoch
|
||||
lamb (float): regularization parameter
|
||||
rho (float): exponential decay rate for `lamb`
|
||||
alpha (float): fusion parameter for calculating transport loss matrix
|
||||
seed (int): random seed
|
||||
logdir (str): local log directory
|
||||
eval_train (bool): whether evaluate train set between epochs
|
||||
eval_test (bool): whether evaluate test set between epochs
|
||||
pretrain (bool): whether pretrain the backbone model before training TRA.
|
||||
Note that only TRA will be optimized after pretraining
|
||||
init_state (str): model init state path
|
||||
freeze_model (bool): whether freeze backbone model parameters
|
||||
freeze_predictors (bool): whether freeze predictors parameters
|
||||
transport_method (str): transport method, can be none/router/oracle
|
||||
memory_mode (str): memory mode, the same argument for MTSDatasetH
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_config,
|
||||
tra_config,
|
||||
model_type="RNN",
|
||||
lr=1e-3,
|
||||
n_epochs=500,
|
||||
early_stop=50,
|
||||
update_freq=1,
|
||||
max_steps_per_epoch=None,
|
||||
lamb=0.0,
|
||||
rho=0.99,
|
||||
alpha=1.0,
|
||||
seed=0,
|
||||
logdir=None,
|
||||
eval_train=False,
|
||||
eval_test=False,
|
||||
pretrain=False,
|
||||
init_state=None,
|
||||
reset_router=False,
|
||||
freeze_model=False,
|
||||
freeze_predictors=False,
|
||||
transport_method="none",
|
||||
memory_mode="sample",
|
||||
):
|
||||
|
||||
self.logger = get_module_logger("TRA")
|
||||
|
||||
assert memory_mode in ["sample", "daily"], "invalid memory mode"
|
||||
assert transport_method in ["none", "router", "oracle"], f"invalid transport method {transport_method}"
|
||||
assert transport_method == "none" or tra_config["num_states"] > 1, "optimal transport requires `num_states` > 1"
|
||||
assert (
|
||||
memory_mode != "daily" or tra_config["src_info"] == "TPE"
|
||||
), "daily transport can only support TPE as `src_info`"
|
||||
|
||||
if transport_method == "router" and not eval_train:
|
||||
self.logger.warning("`eval_train` will be ignored when using TRA.router")
|
||||
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
|
||||
self.model_config = model_config
|
||||
self.tra_config = tra_config
|
||||
self.model_type = model_type
|
||||
self.lr = lr
|
||||
self.n_epochs = n_epochs
|
||||
self.early_stop = early_stop
|
||||
self.update_freq = update_freq
|
||||
self.max_steps_per_epoch = max_steps_per_epoch
|
||||
self.lamb = lamb
|
||||
self.rho = rho
|
||||
self.alpha = alpha
|
||||
self.seed = seed
|
||||
self.logdir = logdir
|
||||
self.eval_train = eval_train
|
||||
self.eval_test = eval_test
|
||||
self.pretrain = pretrain
|
||||
self.init_state = init_state
|
||||
self.reset_router = reset_router
|
||||
self.freeze_model = freeze_model
|
||||
self.freeze_predictors = freeze_predictors
|
||||
self.transport_method = transport_method
|
||||
self.use_daily_transport = memory_mode == "daily"
|
||||
self.transport_fn = transport_daily if self.use_daily_transport else transport_sample
|
||||
|
||||
self._writer = None
|
||||
if self.logdir is not None:
|
||||
if os.path.exists(self.logdir):
|
||||
self.logger.warning(f"logdir {self.logdir} is not empty")
|
||||
os.makedirs(self.logdir, exist_ok=True)
|
||||
if SummaryWriter is not None:
|
||||
self._writer = SummaryWriter(log_dir=self.logdir)
|
||||
|
||||
self._init_model()
|
||||
|
||||
def _init_model(self):
|
||||
|
||||
self.logger.info("init TRAModel...")
|
||||
|
||||
self.model = eval(self.model_type)(**self.model_config).to(device)
|
||||
print(self.model)
|
||||
|
||||
self.tra = TRA(self.model.output_size, **self.tra_config).to(device)
|
||||
print(self.tra)
|
||||
|
||||
if self.init_state:
|
||||
self.logger.warning(f"load state dict from `init_state`")
|
||||
state_dict = torch.load(self.init_state, map_location="cpu")
|
||||
self.model.load_state_dict(state_dict["model"])
|
||||
res = load_state_dict_unsafe(self.tra, state_dict["tra"])
|
||||
self.logger.warning(str(res))
|
||||
|
||||
if self.reset_router:
|
||||
self.logger.warning(f"reset TRA.router parameters")
|
||||
self.tra.fc.reset_parameters()
|
||||
self.tra.router.reset_parameters()
|
||||
|
||||
if self.freeze_model:
|
||||
self.logger.warning(f"freeze model parameters")
|
||||
for param in self.model.parameters():
|
||||
param.requires_grad_(False)
|
||||
|
||||
if self.freeze_predictors:
|
||||
self.logger.warning(f"freeze TRA.predictors parameters")
|
||||
for param in self.tra.predictors.parameters():
|
||||
param.requires_grad_(False)
|
||||
|
||||
self.logger.info("# model params: %d" % sum([p.numel() for p in self.model.parameters() if p.requires_grad]))
|
||||
self.logger.info("# tra params: %d" % sum([p.numel() for p in self.tra.parameters() if p.requires_grad]))
|
||||
|
||||
self.optimizer = optim.Adam(list(self.model.parameters()) + list(self.tra.parameters()), lr=self.lr)
|
||||
|
||||
self.fitted = False
|
||||
self.global_step = -1
|
||||
|
||||
def train_epoch(self, epoch, data_set, is_pretrain=False):
|
||||
|
||||
self.model.train()
|
||||
self.tra.train()
|
||||
data_set.train()
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
P_all = []
|
||||
prob_all = []
|
||||
choice_all = []
|
||||
max_steps = len(data_set)
|
||||
if self.max_steps_per_epoch is not None:
|
||||
if epoch == 0 and self.max_steps_per_epoch < max_steps:
|
||||
self.logger.info(f"max steps updated from {max_steps} to {self.max_steps_per_epoch}")
|
||||
max_steps = min(self.max_steps_per_epoch, max_steps)
|
||||
|
||||
cur_step = 0
|
||||
total_loss = 0
|
||||
total_count = 0
|
||||
for batch in tqdm(data_set, total=max_steps):
|
||||
cur_step += 1
|
||||
if cur_step > max_steps:
|
||||
break
|
||||
|
||||
if not is_pretrain:
|
||||
self.global_step += 1
|
||||
|
||||
data, state, label, count = batch["data"], batch["state"], batch["label"], batch["daily_count"]
|
||||
index = batch["daily_index"] if self.use_daily_transport else batch["index"]
|
||||
|
||||
with torch.set_grad_enabled(not self.freeze_model):
|
||||
hidden = self.model(data)
|
||||
|
||||
all_preds, choice, prob = self.tra(hidden, state)
|
||||
|
||||
if is_pretrain or self.transport_method != "none":
|
||||
# NOTE: use oracle transport for pre-training
|
||||
loss, pred, L, P = self.transport_fn(
|
||||
all_preds,
|
||||
label,
|
||||
choice,
|
||||
prob,
|
||||
state.mean(dim=1),
|
||||
count,
|
||||
self.transport_method if not is_pretrain else "oracle",
|
||||
self.alpha,
|
||||
training=True,
|
||||
)
|
||||
data_set.assign_data(index, L) # save loss to memory
|
||||
if self.use_daily_transport: # only save for daily transport
|
||||
P_all.append(pd.DataFrame(P.detach().cpu().numpy(), index=index))
|
||||
prob_all.append(pd.DataFrame(prob.detach().cpu().numpy(), index=index))
|
||||
choice_all.append(pd.DataFrame(choice.detach().cpu().numpy(), index=index))
|
||||
decay = self.rho ** (self.global_step // 100) # decay every 100 steps
|
||||
lamb = 0 if is_pretrain else self.lamb * decay
|
||||
reg = prob.log().mul(P).sum(dim=1).mean() # train router to predict OT assignment
|
||||
if self._writer is not None and not is_pretrain:
|
||||
self._writer.add_scalar("training/router_loss", -reg.item(), self.global_step)
|
||||
self._writer.add_scalar("training/reg_loss", loss.item(), self.global_step)
|
||||
self._writer.add_scalar("training/lamb", lamb, self.global_step)
|
||||
if not self.use_daily_transport:
|
||||
P_mean = P.mean(axis=0).detach()
|
||||
self._writer.add_scalar("training/P", P_mean.max() / P_mean.min(), self.global_step)
|
||||
loss = loss - lamb * reg
|
||||
else:
|
||||
pred = all_preds.mean(dim=1)
|
||||
loss = loss_fn(pred, label)
|
||||
|
||||
(loss / self.update_freq).backward()
|
||||
if cur_step % self.update_freq == 0:
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
if self._writer is not None and not is_pretrain:
|
||||
self._writer.add_scalar("training/total_loss", loss.item(), self.global_step)
|
||||
|
||||
total_loss += loss.item()
|
||||
total_count += 1
|
||||
|
||||
if self.use_daily_transport and len(P_all):
|
||||
P_all = pd.concat(P_all, axis=0)
|
||||
prob_all = pd.concat(prob_all, axis=0)
|
||||
choice_all = pd.concat(choice_all, axis=0)
|
||||
P_all.index = data_set.restore_daily_index(P_all.index)
|
||||
prob_all.index = P_all.index
|
||||
choice_all.index = P_all.index
|
||||
if not is_pretrain:
|
||||
self._writer.add_image("P", plot(P_all), epoch, dataformats="HWC")
|
||||
self._writer.add_image("prob", plot(prob_all), epoch, dataformats="HWC")
|
||||
self._writer.add_image("choice", plot(choice_all), epoch, dataformats="HWC")
|
||||
|
||||
total_loss /= total_count
|
||||
|
||||
if self._writer is not None and not is_pretrain:
|
||||
self._writer.add_scalar("training/loss", total_loss, epoch)
|
||||
|
||||
return total_loss
|
||||
|
||||
def test_epoch(self, epoch, data_set, return_pred=False, prefix="test", is_pretrain=False):
|
||||
|
||||
self.model.eval()
|
||||
self.tra.eval()
|
||||
data_set.eval()
|
||||
|
||||
preds = []
|
||||
probs = []
|
||||
P_all = []
|
||||
metrics = []
|
||||
for batch in tqdm(data_set):
|
||||
data, state, label, count = batch["data"], batch["state"], batch["label"], batch["daily_count"]
|
||||
index = batch["daily_index"] if self.use_daily_transport else batch["index"]
|
||||
|
||||
with torch.no_grad():
|
||||
hidden = self.model(data)
|
||||
all_preds, choice, prob = self.tra(hidden, state)
|
||||
|
||||
if is_pretrain or self.transport_method != "none":
|
||||
loss, pred, L, P = self.transport_fn(
|
||||
all_preds,
|
||||
label,
|
||||
choice,
|
||||
prob,
|
||||
state.mean(dim=1),
|
||||
count,
|
||||
self.transport_method if not is_pretrain else "oracle",
|
||||
self.alpha,
|
||||
training=False,
|
||||
)
|
||||
data_set.assign_data(index, L) # save loss to memory
|
||||
if P is not None and return_pred:
|
||||
P_all.append(pd.DataFrame(P.cpu().numpy(), index=index))
|
||||
else:
|
||||
pred = all_preds.mean(dim=1)
|
||||
|
||||
X = np.c_[pred.cpu().numpy(), label.cpu().numpy(), all_preds.cpu().numpy()]
|
||||
columns = ["score", "label"] + ["score_%d" % d for d in range(all_preds.shape[1])]
|
||||
pred = pd.DataFrame(X, index=batch["index"], columns=columns)
|
||||
|
||||
metrics.append(evaluate(pred))
|
||||
|
||||
if return_pred:
|
||||
preds.append(pred)
|
||||
if prob is not None:
|
||||
columns = ["prob_%d" % d for d in range(all_preds.shape[1])]
|
||||
probs.append(pd.DataFrame(prob.cpu().numpy(), index=index, columns=columns))
|
||||
|
||||
metrics = pd.DataFrame(metrics)
|
||||
metrics = {
|
||||
"MSE": metrics.MSE.mean(),
|
||||
"MAE": metrics.MAE.mean(),
|
||||
"IC": metrics.IC.mean(),
|
||||
"ICIR": metrics.IC.mean() / metrics.IC.std(),
|
||||
}
|
||||
|
||||
if self._writer is not None and epoch >= 0 and not is_pretrain:
|
||||
for key, value in metrics.items():
|
||||
self._writer.add_scalar(prefix + "/" + key, value, epoch)
|
||||
|
||||
if return_pred:
|
||||
preds = pd.concat(preds, axis=0)
|
||||
preds.index = data_set.restore_index(preds.index)
|
||||
preds.index = preds.index.swaplevel()
|
||||
preds.sort_index(inplace=True)
|
||||
|
||||
if probs:
|
||||
probs = pd.concat(probs, axis=0)
|
||||
if self.use_daily_transport:
|
||||
probs.index = data_set.restore_daily_index(probs.index)
|
||||
else:
|
||||
probs.index = data_set.restore_index(probs.index)
|
||||
probs.index = probs.index.swaplevel()
|
||||
probs.sort_index(inplace=True)
|
||||
|
||||
if len(P_all):
|
||||
P_all = pd.concat(P_all, axis=0)
|
||||
if self.use_daily_transport:
|
||||
P_all.index = data_set.restore_daily_index(P_all.index)
|
||||
else:
|
||||
P_all.index = data_set.restore_index(P_all.index)
|
||||
P_all.index = P_all.index.swaplevel()
|
||||
P_all.sort_index(inplace=True)
|
||||
|
||||
return metrics, preds, probs, P_all
|
||||
|
||||
def _fit(self, train_set, valid_set, test_set, evals_result, is_pretrain=True):
|
||||
|
||||
best_score = -1
|
||||
best_epoch = 0
|
||||
stop_rounds = 0
|
||||
best_params = {
|
||||
"model": copy.deepcopy(self.model.state_dict()),
|
||||
"tra": copy.deepcopy(self.tra.state_dict()),
|
||||
}
|
||||
# train
|
||||
if not is_pretrain and self.transport_method != "none":
|
||||
self.logger.info("init memory...")
|
||||
self.test_epoch(-1, train_set)
|
||||
|
||||
for epoch in range(self.n_epochs):
|
||||
self.logger.info("Epoch %d:", epoch)
|
||||
|
||||
self.logger.info("training...")
|
||||
self.train_epoch(epoch, train_set, is_pretrain=is_pretrain)
|
||||
|
||||
self.logger.info("evaluating...")
|
||||
# NOTE: during evaluating, the whole memory will be refreshed
|
||||
if not is_pretrain and (self.transport_method == "router" or self.eval_train):
|
||||
train_set.clear_memory() # NOTE: clear the shared memory
|
||||
train_metrics = self.test_epoch(epoch, train_set, is_pretrain=is_pretrain, prefix="train")[0]
|
||||
evals_result["train"].append(train_metrics)
|
||||
self.logger.info("train metrics: %s" % train_metrics)
|
||||
|
||||
valid_metrics = self.test_epoch(epoch, valid_set, is_pretrain=is_pretrain, prefix="valid")[0]
|
||||
evals_result["valid"].append(valid_metrics)
|
||||
self.logger.info("valid metrics: %s" % valid_metrics)
|
||||
|
||||
if self.eval_test:
|
||||
test_metrics = self.test_epoch(epoch, test_set, is_pretrain=is_pretrain, prefix="test")[0]
|
||||
evals_result["test"].append(test_metrics)
|
||||
self.logger.info("test metrics: %s" % test_metrics)
|
||||
|
||||
if valid_metrics["IC"] > best_score:
|
||||
best_score = valid_metrics["IC"]
|
||||
stop_rounds = 0
|
||||
best_epoch = epoch
|
||||
best_params = {
|
||||
"model": copy.deepcopy(self.model.state_dict()),
|
||||
"tra": copy.deepcopy(self.tra.state_dict()),
|
||||
}
|
||||
if self.logdir is not None:
|
||||
torch.save(best_params, self.logdir + "/model.bin")
|
||||
else:
|
||||
stop_rounds += 1
|
||||
if stop_rounds >= self.early_stop:
|
||||
self.logger.info("early stop @ %s" % epoch)
|
||||
break
|
||||
|
||||
self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
|
||||
self.model.load_state_dict(best_params["model"])
|
||||
self.tra.load_state_dict(best_params["tra"])
|
||||
|
||||
return best_score
|
||||
|
||||
def fit(self, dataset, evals_result=dict()):
|
||||
|
||||
assert isinstance(dataset, MTSDatasetH), "TRAModel only supports `qlib.contrib.data.dataset.MTSDatasetH`"
|
||||
|
||||
train_set, valid_set, test_set = dataset.prepare(["train", "valid", "test"])
|
||||
|
||||
self.fitted = True
|
||||
self.global_step = -1
|
||||
|
||||
evals_result["train"] = []
|
||||
evals_result["valid"] = []
|
||||
evals_result["test"] = []
|
||||
|
||||
if self.pretrain:
|
||||
self.logger.info("pretraining...")
|
||||
self.optimizer = optim.Adam(
|
||||
list(self.model.parameters()) + list(self.tra.predictors.parameters()), lr=self.lr
|
||||
)
|
||||
self._fit(train_set, valid_set, test_set, evals_result, is_pretrain=True)
|
||||
|
||||
# reset optimizer
|
||||
self.optimizer = optim.Adam(list(self.model.parameters()) + list(self.tra.parameters()), lr=self.lr)
|
||||
|
||||
self.logger.info("training...")
|
||||
best_score = self._fit(train_set, valid_set, test_set, evals_result, is_pretrain=False)
|
||||
|
||||
self.logger.info("inference")
|
||||
train_metrics, train_preds, train_probs, train_P = self.test_epoch(-1, train_set, return_pred=True)
|
||||
self.logger.info("train metrics: %s" % train_metrics)
|
||||
|
||||
valid_metrics, valid_preds, valid_probs, valid_P = self.test_epoch(-1, valid_set, return_pred=True)
|
||||
self.logger.info("valid metrics: %s" % valid_metrics)
|
||||
|
||||
test_metrics, test_preds, test_probs, test_P = self.test_epoch(-1, test_set, return_pred=True)
|
||||
self.logger.info("test metrics: %s" % test_metrics)
|
||||
|
||||
if self.logdir:
|
||||
self.logger.info("save model & pred to local directory")
|
||||
|
||||
pd.concat({name: pd.DataFrame(evals_result[name]) for name in evals_result}, axis=1).to_csv(
|
||||
self.logdir + "/logs.csv", index=False
|
||||
)
|
||||
|
||||
torch.save({"model": self.model.state_dict(), "tra": self.tra.state_dict()}, self.logdir + "/model.bin")
|
||||
|
||||
train_preds.to_pickle(self.logdir + "/train_pred.pkl")
|
||||
valid_preds.to_pickle(self.logdir + "/valid_pred.pkl")
|
||||
test_preds.to_pickle(self.logdir + "/test_pred.pkl")
|
||||
|
||||
if len(train_probs):
|
||||
train_probs.to_pickle(self.logdir + "/train_prob.pkl")
|
||||
valid_probs.to_pickle(self.logdir + "/valid_prob.pkl")
|
||||
test_probs.to_pickle(self.logdir + "/test_prob.pkl")
|
||||
|
||||
if len(train_P):
|
||||
train_P.to_pickle(self.logdir + "/train_P.pkl")
|
||||
valid_P.to_pickle(self.logdir + "/valid_P.pkl")
|
||||
test_P.to_pickle(self.logdir + "/test_P.pkl")
|
||||
|
||||
info = {
|
||||
"config": {
|
||||
"model_config": self.model_config,
|
||||
"tra_config": self.tra_config,
|
||||
"model_type": self.model_type,
|
||||
"lr": self.lr,
|
||||
"n_epochs": self.n_epochs,
|
||||
"early_stop": self.early_stop,
|
||||
"max_steps_per_epoch": self.max_steps_per_epoch,
|
||||
"lamb": self.lamb,
|
||||
"rho": self.rho,
|
||||
"alpha": self.alpha,
|
||||
"seed": self.seed,
|
||||
"logdir": self.logdir,
|
||||
"pretrain": self.pretrain,
|
||||
"init_state": self.init_state,
|
||||
"transport_method": self.transport_method,
|
||||
"use_daily_transport": self.use_daily_transport,
|
||||
},
|
||||
"best_eval_metric": -best_score, # NOTE: -1 for minimize
|
||||
"metrics": {"train": train_metrics, "valid": valid_metrics, "test": test_metrics},
|
||||
}
|
||||
with open(self.logdir + "/info.json", "w") as f:
|
||||
json.dump(info, f)
|
||||
|
||||
def predict(self, dataset, segment="test"):
|
||||
|
||||
assert isinstance(dataset, MTSDatasetH), "TRAModel only supports `qlib.contrib.data.dataset.MTSDatasetH`"
|
||||
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
test_set = dataset.prepare(segment)
|
||||
|
||||
metrics, preds, _, _ = self.test_epoch(-1, test_set, return_pred=True)
|
||||
self.logger.info("test metrics: %s" % metrics)
|
||||
|
||||
return preds
|
||||
|
||||
|
||||
class RNN(nn.Module):
|
||||
|
||||
"""RNN Model
|
||||
|
||||
Args:
|
||||
input_size (int): input size (# features)
|
||||
hidden_size (int): hidden size
|
||||
num_layers (int): number of hidden layers
|
||||
rnn_arch (str): rnn architecture
|
||||
use_attn (bool): whether use attention layer.
|
||||
we use concat attention as https://github.com/fulifeng/Adv-ALSTM/
|
||||
dropout (float): dropout rate
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size=16,
|
||||
hidden_size=64,
|
||||
num_layers=2,
|
||||
rnn_arch="GRU",
|
||||
use_attn=True,
|
||||
dropout=0.0,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.input_size = input_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_layers = num_layers
|
||||
self.rnn_arch = rnn_arch
|
||||
self.use_attn = use_attn
|
||||
|
||||
if hidden_size < input_size:
|
||||
# compression
|
||||
self.input_proj = nn.Linear(input_size, hidden_size)
|
||||
else:
|
||||
self.input_proj = None
|
||||
|
||||
self.rnn = getattr(nn, rnn_arch)(
|
||||
input_size=min(input_size, hidden_size),
|
||||
hidden_size=hidden_size,
|
||||
num_layers=num_layers,
|
||||
batch_first=True,
|
||||
dropout=dropout,
|
||||
)
|
||||
|
||||
if self.use_attn:
|
||||
self.W = nn.Linear(hidden_size, hidden_size)
|
||||
self.u = nn.Linear(hidden_size, 1, bias=False)
|
||||
self.output_size = hidden_size * 2
|
||||
else:
|
||||
self.output_size = hidden_size
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
if self.input_proj is not None:
|
||||
x = self.input_proj(x)
|
||||
|
||||
rnn_out, last_out = self.rnn(x)
|
||||
if self.rnn_arch == "LSTM":
|
||||
last_out = last_out[0]
|
||||
last_out = last_out.mean(dim=0)
|
||||
|
||||
if self.use_attn:
|
||||
laten = self.W(rnn_out).tanh()
|
||||
scores = self.u(laten).softmax(dim=1)
|
||||
att_out = (rnn_out * scores).sum(dim=1)
|
||||
last_out = torch.cat([last_out, att_out], dim=1)
|
||||
|
||||
return last_out
|
||||
|
||||
|
||||
class PositionalEncoding(nn.Module):
|
||||
# reference: https://pytorch.org/tutorials/beginner/transformer_tutorial.html
|
||||
def __init__(self, d_model, dropout=0.1, max_len=5000):
|
||||
super(PositionalEncoding, self).__init__()
|
||||
self.dropout = nn.Dropout(p=dropout)
|
||||
|
||||
pe = torch.zeros(max_len, d_model)
|
||||
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
|
||||
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
|
||||
pe[:, 0::2] = torch.sin(position * div_term)
|
||||
pe[:, 1::2] = torch.cos(position * div_term)
|
||||
pe = pe.unsqueeze(0).transpose(0, 1)
|
||||
self.register_buffer("pe", pe)
|
||||
|
||||
def forward(self, x):
|
||||
x = x + self.pe[: x.size(0), :]
|
||||
return self.dropout(x)
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
|
||||
"""Transformer Model
|
||||
|
||||
Args:
|
||||
input_size (int): input size (# features)
|
||||
hidden_size (int): hidden size
|
||||
num_layers (int): number of transformer layers
|
||||
num_heads (int): number of heads in transformer
|
||||
dropout (float): dropout rate
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size=16,
|
||||
hidden_size=64,
|
||||
num_layers=2,
|
||||
num_heads=2,
|
||||
dropout=0.0,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.input_size = input_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_layers = num_layers
|
||||
self.num_heads = num_heads
|
||||
|
||||
self.input_proj = nn.Linear(input_size, hidden_size)
|
||||
|
||||
self.pe = PositionalEncoding(input_size, dropout)
|
||||
layer = nn.TransformerEncoderLayer(
|
||||
nhead=num_heads, dropout=dropout, d_model=hidden_size, dim_feedforward=hidden_size * 4
|
||||
)
|
||||
self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)
|
||||
|
||||
self.output_size = hidden_size
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
x = x.permute(1, 0, 2).contiguous() # the first dim need to be time
|
||||
x = self.pe(x)
|
||||
|
||||
x = self.input_proj(x)
|
||||
out = self.encoder(x)
|
||||
|
||||
return out[-1]
|
||||
|
||||
|
||||
class TRA(nn.Module):
|
||||
|
||||
"""Temporal Routing Adaptor (TRA)
|
||||
|
||||
TRA takes historical prediction erros & latent representation as inputs,
|
||||
then routes the input sample to a specific predictor for training & inference.
|
||||
|
||||
Args:
|
||||
input_size (int): input size (RNN/Transformer's hidden size)
|
||||
num_states (int): number of latent states (i.e., trading patterns)
|
||||
If `num_states=1`, then TRA falls back to traditional methods
|
||||
hidden_size (int): hidden size of the router
|
||||
tau (float): gumbel softmax temperature
|
||||
src_info (str): information for the router
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size,
|
||||
num_states=1,
|
||||
hidden_size=8,
|
||||
rnn_arch="GRU",
|
||||
num_layers=1,
|
||||
dropout=0.0,
|
||||
tau=1.0,
|
||||
src_info="LR_TPE",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
assert src_info in ["LR", "TPE", "LR_TPE"], "invalid `src_info`"
|
||||
|
||||
self.num_states = num_states
|
||||
self.tau = tau
|
||||
self.rnn_arch = rnn_arch
|
||||
self.src_info = src_info
|
||||
|
||||
self.predictors = nn.Linear(input_size, num_states)
|
||||
|
||||
if self.num_states > 1:
|
||||
if "TPE" in src_info:
|
||||
self.router = getattr(nn, rnn_arch)(
|
||||
input_size=num_states,
|
||||
hidden_size=hidden_size,
|
||||
num_layers=num_layers,
|
||||
batch_first=True,
|
||||
dropout=dropout,
|
||||
)
|
||||
self.fc = nn.Linear(hidden_size + input_size if "LR" in src_info else hidden_size, num_states)
|
||||
else:
|
||||
self.fc = nn.Linear(input_size, num_states)
|
||||
|
||||
def reset_parameters(self):
|
||||
for child in self.children():
|
||||
child.reset_parameters()
|
||||
|
||||
def forward(self, hidden, hist_loss):
|
||||
|
||||
preds = self.predictors(hidden)
|
||||
|
||||
if self.num_states == 1: # no need for router when having only one prediction
|
||||
return preds, None, None
|
||||
|
||||
if "TPE" in self.src_info:
|
||||
out = self.router(hist_loss)[1] # TPE
|
||||
if self.rnn_arch == "LSTM":
|
||||
out = out[0]
|
||||
out = out.mean(dim=0)
|
||||
if "LR" in self.src_info:
|
||||
out = torch.cat([hidden, out], dim=-1) # LR_TPE
|
||||
else:
|
||||
out = hidden # LR
|
||||
|
||||
out = self.fc(out)
|
||||
|
||||
choice = F.gumbel_softmax(out, dim=-1, tau=self.tau, hard=True)
|
||||
prob = torch.softmax(out / self.tau, dim=-1)
|
||||
|
||||
return preds, choice, prob
|
||||
|
||||
|
||||
def evaluate(pred):
|
||||
pred = pred.rank(pct=True) # transform into percentiles
|
||||
score = pred.score
|
||||
label = pred.label
|
||||
diff = score - label
|
||||
MSE = (diff ** 2).mean()
|
||||
MAE = (diff.abs()).mean()
|
||||
IC = score.corr(label, method="spearman")
|
||||
return {"MSE": MSE, "MAE": MAE, "IC": IC}
|
||||
|
||||
|
||||
def shoot_infs(inp_tensor):
|
||||
"""Replaces inf by maximum of tensor"""
|
||||
mask_inf = torch.isinf(inp_tensor)
|
||||
ind_inf = torch.nonzero(mask_inf, as_tuple=False)
|
||||
if len(ind_inf) > 0:
|
||||
for ind in ind_inf:
|
||||
if len(ind) == 2:
|
||||
inp_tensor[ind[0], ind[1]] = 0
|
||||
elif len(ind) == 1:
|
||||
inp_tensor[ind[0]] = 0
|
||||
m = torch.max(inp_tensor)
|
||||
for ind in ind_inf:
|
||||
if len(ind) == 2:
|
||||
inp_tensor[ind[0], ind[1]] = m
|
||||
elif len(ind) == 1:
|
||||
inp_tensor[ind[0]] = m
|
||||
return inp_tensor
|
||||
|
||||
|
||||
def sinkhorn(Q, n_iters=3, epsilon=0.1):
|
||||
# epsilon should be adjusted according to logits value's scale
|
||||
with torch.no_grad():
|
||||
Q = torch.exp(Q / epsilon)
|
||||
Q = shoot_infs(Q)
|
||||
for i in range(n_iters):
|
||||
Q /= Q.sum(dim=0, keepdim=True)
|
||||
Q /= Q.sum(dim=1, keepdim=True)
|
||||
return Q
|
||||
|
||||
|
||||
def loss_fn(pred, label):
|
||||
mask = ~torch.isnan(label)
|
||||
if len(pred.shape) == 2:
|
||||
label = label[:, None]
|
||||
return (pred[mask] - label[mask]).pow(2).mean(dim=0)
|
||||
|
||||
|
||||
def minmax_norm(x):
|
||||
xmin = x.min(dim=-1, keepdim=True).values
|
||||
xmax = x.max(dim=-1, keepdim=True).values
|
||||
mask = (xmin == xmax).squeeze()
|
||||
x = (x - xmin) / (xmax - xmin + 1e-12)
|
||||
x[mask] = 1
|
||||
return x
|
||||
|
||||
|
||||
def transport_sample(all_preds, label, choice, prob, hist_loss, count, transport_method, alpha, training=False):
|
||||
"""
|
||||
sample-wise transport
|
||||
|
||||
Args:
|
||||
all_preds (torch.Tensor): predictions from all predictors, [sample x states]
|
||||
label (torch.Tensor): label, [sample]
|
||||
choice (torch.Tensor): gumbel softmax choice, [sample x states]
|
||||
prob (torch.Tensor): router predicted probility, [sample x states]
|
||||
hist_loss (torch.Tensor): history loss matrix, [sample x states]
|
||||
count (list): sample counts for each day, empty list for sample-wise transport
|
||||
transport_method (str): transportation method
|
||||
alpha (float): fusion parameter for calculating transport loss matrix
|
||||
training (bool): indicate training or inference
|
||||
"""
|
||||
assert all_preds.shape == choice.shape
|
||||
assert len(all_preds) == len(label)
|
||||
assert transport_method in ["oracle", "router"]
|
||||
|
||||
all_loss = torch.zeros_like(all_preds)
|
||||
mask = ~torch.isnan(label)
|
||||
all_loss[mask] = (all_preds[mask] - label[mask, None]).pow(2) # [sample x states]
|
||||
|
||||
L = minmax_norm(all_loss.detach())
|
||||
Lh = L * alpha + minmax_norm(hist_loss) * (1 - alpha) # add hist loss for transport
|
||||
Lh = minmax_norm(Lh)
|
||||
P = sinkhorn(-Lh)
|
||||
del Lh
|
||||
|
||||
if transport_method == "router":
|
||||
if training:
|
||||
pred = (all_preds * choice).sum(dim=1) # gumbel softmax
|
||||
else:
|
||||
pred = all_preds[range(len(all_preds)), prob.argmax(dim=-1)] # argmax
|
||||
else:
|
||||
pred = (all_preds * P).sum(dim=1)
|
||||
|
||||
if transport_method == "router":
|
||||
loss = loss_fn(pred, label)
|
||||
else:
|
||||
loss = (all_loss * P).sum(dim=1).mean()
|
||||
|
||||
return loss, pred, L, P
|
||||
|
||||
|
||||
def transport_daily(all_preds, label, choice, prob, hist_loss, count, transport_method, alpha, training=False):
|
||||
"""
|
||||
daily transport
|
||||
|
||||
Args:
|
||||
all_preds (torch.Tensor): predictions from all predictors, [sample x states]
|
||||
label (torch.Tensor): label, [sample]
|
||||
choice (torch.Tensor): gumbel softmax choice, [days x states]
|
||||
prob (torch.Tensor): router predicted probility, [days x states]
|
||||
hist_loss (torch.Tensor): history loss matrix, [days x states]
|
||||
count (list): sample counts for each day, [days]
|
||||
transport_method (str): transportation method
|
||||
alpha (float): fusion parameter for calculating transport loss matrix
|
||||
training (bool): indicate training or inference
|
||||
"""
|
||||
assert len(prob) == len(count)
|
||||
assert len(all_preds) == sum(count)
|
||||
assert transport_method in ["oracle", "router"]
|
||||
|
||||
all_loss = [] # loss of all predictions
|
||||
start = 0
|
||||
for i, cnt in enumerate(count):
|
||||
slc = slice(start, start + cnt) # samples from the i-th day
|
||||
start += cnt
|
||||
tloss = loss_fn(all_preds[slc], label[slc]) # loss of the i-th day
|
||||
all_loss.append(tloss)
|
||||
all_loss = torch.stack(all_loss, dim=0) # [days x states]
|
||||
|
||||
L = minmax_norm(all_loss.detach())
|
||||
Lh = L * alpha + minmax_norm(hist_loss) * (1 - alpha) # add hist loss for transport
|
||||
Lh = minmax_norm(Lh)
|
||||
P = sinkhorn(-Lh)
|
||||
del Lh
|
||||
|
||||
pred = []
|
||||
start = 0
|
||||
for i, cnt in enumerate(count):
|
||||
slc = slice(start, start + cnt) # samples from the i-th day
|
||||
start += cnt
|
||||
if transport_method == "router":
|
||||
if training:
|
||||
tpred = all_preds[slc] @ choice[i] # gumbel softmax
|
||||
else:
|
||||
tpred = all_preds[slc][:, prob[i].argmax(dim=-1)] # argmax
|
||||
else:
|
||||
tpred = all_preds[slc] @ P[i]
|
||||
pred.append(tpred)
|
||||
pred = torch.cat(pred, dim=0) # [samples]
|
||||
|
||||
if transport_method == "router":
|
||||
loss = loss_fn(pred, label)
|
||||
else:
|
||||
loss = (all_loss * P).sum(dim=1).mean()
|
||||
|
||||
return loss, pred, L, P
|
||||
|
||||
|
||||
def load_state_dict_unsafe(model, state_dict):
|
||||
"""
|
||||
Load state dict to provided model while ignore exceptions.
|
||||
"""
|
||||
|
||||
missing_keys = []
|
||||
unexpected_keys = []
|
||||
error_msgs = []
|
||||
|
||||
# copy state_dict so _load_from_state_dict can modify it
|
||||
metadata = getattr(state_dict, "_metadata", None)
|
||||
state_dict = state_dict.copy()
|
||||
if metadata is not None:
|
||||
state_dict._metadata = metadata
|
||||
|
||||
def load(module, prefix=""):
|
||||
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
|
||||
module._load_from_state_dict(
|
||||
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs
|
||||
)
|
||||
for name, child in module._modules.items():
|
||||
if child is not None:
|
||||
load(child, prefix + name + ".")
|
||||
|
||||
load(model)
|
||||
load = None # break load->load reference cycle
|
||||
|
||||
return {"unexpected_keys": unexpected_keys, "missing_keys": missing_keys, "error_msgs": error_msgs}
|
||||
|
||||
|
||||
def plot(P):
|
||||
assert isinstance(P, pd.DataFrame)
|
||||
|
||||
fig, axes = plt.subplots(1, 2, figsize=(10, 4))
|
||||
P.plot.area(ax=axes[0], xlabel="")
|
||||
P.idxmax(axis=1).value_counts().sort_index().plot.bar(ax=axes[1], xlabel="")
|
||||
plt.tight_layout()
|
||||
|
||||
with io.BytesIO() as buf:
|
||||
plt.savefig(buf, format="png")
|
||||
buf.seek(0)
|
||||
img = plt.imread(buf)
|
||||
plt.close()
|
||||
|
||||
return np.uint8(img * 255)
|
||||
294
qlib/contrib/model/pytorch_transformer.py
Normal file
294
qlib/contrib/model/pytorch_transformer.py
Normal file
@@ -0,0 +1,294 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Text, Union
|
||||
import copy
|
||||
import math
|
||||
from ...utils import get_or_create_path
|
||||
from ...log import get_module_logger
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from .pytorch_utils import count_parameters
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH, TSDatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
|
||||
# qrun examples/benchmarks/Transformer/workflow_config_transformer_Alpha360.yaml ”
|
||||
|
||||
|
||||
class TransformerModel(Model):
|
||||
def __init__(
|
||||
self,
|
||||
d_feat: int = 20,
|
||||
d_model: int = 64,
|
||||
batch_size: int = 2048,
|
||||
nhead: int = 2,
|
||||
num_layers: int = 2,
|
||||
dropout: float = 0,
|
||||
n_epochs=100,
|
||||
lr=0.0001,
|
||||
metric="",
|
||||
early_stop=5,
|
||||
loss="mse",
|
||||
optimizer="adam",
|
||||
reg=1e-3,
|
||||
n_jobs=10,
|
||||
GPU=0,
|
||||
seed=None,
|
||||
**kwargs
|
||||
):
|
||||
|
||||
# set hyper-parameters.
|
||||
self.d_model = d_model
|
||||
self.dropout = dropout
|
||||
self.n_epochs = n_epochs
|
||||
self.lr = lr
|
||||
self.reg = reg
|
||||
self.metric = metric
|
||||
self.batch_size = batch_size
|
||||
self.early_stop = early_stop
|
||||
self.optimizer = optimizer.lower()
|
||||
self.loss = loss
|
||||
self.n_jobs = n_jobs
|
||||
self.device = torch.device("cuda:%d" % GPU if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.seed = seed
|
||||
self.logger = get_module_logger("TransformerModel")
|
||||
self.logger.info("Naive Transformer:" "\nbatch_size : {}" "\ndevice : {}".format(self.batch_size, self.device))
|
||||
|
||||
if self.seed is not None:
|
||||
np.random.seed(self.seed)
|
||||
torch.manual_seed(self.seed)
|
||||
|
||||
self.model = Transformer(d_feat, d_model, nhead, num_layers, dropout, self.device)
|
||||
if optimizer.lower() == "adam":
|
||||
self.train_optimizer = optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=self.reg)
|
||||
elif optimizer.lower() == "gd":
|
||||
self.train_optimizer = optim.SGD(self.model.parameters(), lr=self.lr, weight_decay=self.reg)
|
||||
else:
|
||||
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
|
||||
|
||||
self.fitted = False
|
||||
self.model.to(self.device)
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
return self.device != torch.device("cpu")
|
||||
|
||||
def mse(self, pred, label):
|
||||
loss = (pred.float() - label.float()) ** 2
|
||||
return torch.mean(loss)
|
||||
|
||||
def loss_fn(self, pred, label):
|
||||
mask = ~torch.isnan(label)
|
||||
|
||||
if self.loss == "mse":
|
||||
return self.mse(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown loss `%s`" % self.loss)
|
||||
|
||||
def metric_fn(self, pred, label):
|
||||
|
||||
mask = torch.isfinite(label)
|
||||
|
||||
if self.metric == "" or self.metric == "loss":
|
||||
return -self.loss_fn(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown metric `%s`" % self.metric)
|
||||
|
||||
def train_epoch(self, x_train, y_train):
|
||||
|
||||
x_train_values = x_train.values
|
||||
y_train_values = np.squeeze(y_train.values)
|
||||
|
||||
self.model.train()
|
||||
|
||||
indices = np.arange(len(x_train_values))
|
||||
np.random.shuffle(indices)
|
||||
|
||||
for i in range(len(indices))[:: self.batch_size]:
|
||||
|
||||
if len(indices) - i < self.batch_size:
|
||||
break
|
||||
|
||||
feature = torch.from_numpy(x_train_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
label = torch.from_numpy(y_train_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
|
||||
pred = self.model(feature)
|
||||
loss = self.loss_fn(pred, label)
|
||||
|
||||
self.train_optimizer.zero_grad()
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_value_(self.model.parameters(), 3.0)
|
||||
self.train_optimizer.step()
|
||||
|
||||
def test_epoch(self, data_x, data_y):
|
||||
|
||||
# prepare training data
|
||||
x_values = data_x.values
|
||||
y_values = np.squeeze(data_y.values)
|
||||
|
||||
self.model.eval()
|
||||
|
||||
scores = []
|
||||
losses = []
|
||||
|
||||
indices = np.arange(len(x_values))
|
||||
|
||||
for i in range(len(indices))[:: self.batch_size]:
|
||||
|
||||
if len(indices) - i < self.batch_size:
|
||||
break
|
||||
|
||||
feature = torch.from_numpy(x_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
label = torch.from_numpy(y_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
pred = self.model(feature)
|
||||
loss = self.loss_fn(pred, label)
|
||||
losses.append(loss.item())
|
||||
|
||||
score = self.metric_fn(pred, label)
|
||||
scores.append(score.item())
|
||||
|
||||
return np.mean(losses), np.mean(scores)
|
||||
|
||||
def fit(
|
||||
self,
|
||||
dataset: DatasetH,
|
||||
evals_result=dict(),
|
||||
save_path=None,
|
||||
):
|
||||
|
||||
df_train, df_valid, df_test = dataset.prepare(
|
||||
["train", "valid", "test"],
|
||||
col_set=["feature", "label"],
|
||||
data_key=DataHandlerLP.DK_L,
|
||||
)
|
||||
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
|
||||
save_path = get_or_create_path(save_path)
|
||||
stop_steps = 0
|
||||
train_loss = 0
|
||||
best_score = -np.inf
|
||||
best_epoch = 0
|
||||
evals_result["train"] = []
|
||||
evals_result["valid"] = []
|
||||
|
||||
# train
|
||||
self.logger.info("training...")
|
||||
self.fitted = True
|
||||
|
||||
for step in range(self.n_epochs):
|
||||
self.logger.info("Epoch%d:", step)
|
||||
self.logger.info("training...")
|
||||
self.train_epoch(x_train, y_train)
|
||||
self.logger.info("evaluating...")
|
||||
train_loss, train_score = self.test_epoch(x_train, y_train)
|
||||
val_loss, val_score = self.test_epoch(x_valid, y_valid)
|
||||
self.logger.info("train %.6f, valid %.6f" % (train_score, val_score))
|
||||
evals_result["train"].append(train_score)
|
||||
evals_result["valid"].append(val_score)
|
||||
|
||||
if val_score > best_score:
|
||||
best_score = val_score
|
||||
stop_steps = 0
|
||||
best_epoch = step
|
||||
best_param = copy.deepcopy(self.model.state_dict())
|
||||
else:
|
||||
stop_steps += 1
|
||||
if stop_steps >= self.early_stop:
|
||||
self.logger.info("early stop")
|
||||
break
|
||||
|
||||
self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
|
||||
self.model.load_state_dict(best_param)
|
||||
torch.save(best_param, save_path)
|
||||
|
||||
if self.use_gpu:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
index = x_test.index
|
||||
self.model.eval()
|
||||
x_values = x_test.values
|
||||
sample_num = x_values.shape[0]
|
||||
preds = []
|
||||
|
||||
for begin in range(sample_num)[:: self.batch_size]:
|
||||
|
||||
if sample_num - begin < self.batch_size:
|
||||
end = sample_num
|
||||
else:
|
||||
end = begin + self.batch_size
|
||||
|
||||
x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
pred = self.model(x_batch).detach().cpu().numpy()
|
||||
|
||||
preds.append(pred)
|
||||
|
||||
return pd.Series(np.concatenate(preds), index=index)
|
||||
|
||||
|
||||
class PositionalEncoding(nn.Module):
|
||||
def __init__(self, d_model, max_len=1000):
|
||||
super(PositionalEncoding, self).__init__()
|
||||
pe = torch.zeros(max_len, d_model)
|
||||
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
|
||||
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
|
||||
pe[:, 0::2] = torch.sin(position * div_term)
|
||||
pe[:, 1::2] = torch.cos(position * div_term)
|
||||
pe = pe.unsqueeze(0).transpose(0, 1)
|
||||
self.register_buffer("pe", pe)
|
||||
|
||||
def forward(self, x):
|
||||
# [T, N, F]
|
||||
return x + self.pe[: x.size(0), :]
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, d_feat=6, d_model=8, nhead=4, num_layers=2, dropout=0.5, device=None):
|
||||
super(Transformer, self).__init__()
|
||||
self.feature_layer = nn.Linear(d_feat, d_model)
|
||||
self.pos_encoder = PositionalEncoding(d_model)
|
||||
self.encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dropout=dropout)
|
||||
self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers)
|
||||
self.decoder_layer = nn.Linear(d_model, 1)
|
||||
self.device = device
|
||||
self.d_feat = d_feat
|
||||
|
||||
def forward(self, src):
|
||||
# src [N, F*T] --> [N, T, F]
|
||||
src = src.reshape(len(src), self.d_feat, -1).permute(0, 2, 1)
|
||||
src = self.feature_layer(src)
|
||||
|
||||
# src [N, T, F] --> [T, N, F], [60, 512, 8]
|
||||
src = src.transpose(1, 0) # not batch first
|
||||
|
||||
mask = None
|
||||
|
||||
src = self.pos_encoder(src)
|
||||
output = self.transformer_encoder(src, mask) # [60, 512, 8]
|
||||
|
||||
# [T, N, F] --> [N, T*F]
|
||||
output = self.decoder_layer(output.transpose(1, 0)[:, -1, :]) # [512, 1]
|
||||
|
||||
return output.squeeze()
|
||||
269
qlib/contrib/model/pytorch_transformer_ts.py
Normal file
269
qlib/contrib/model/pytorch_transformer_ts.py
Normal file
@@ -0,0 +1,269 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import copy
|
||||
import math
|
||||
from ...utils import get_or_create_path
|
||||
from ...log import get_module_logger
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from .pytorch_utils import count_parameters
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH, TSDatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
|
||||
|
||||
class TransformerModel(Model):
|
||||
def __init__(
|
||||
self,
|
||||
d_feat: int = 20,
|
||||
d_model: int = 64,
|
||||
batch_size: int = 8192,
|
||||
nhead: int = 2,
|
||||
num_layers: int = 2,
|
||||
dropout: float = 0,
|
||||
n_epochs=100,
|
||||
lr=0.0001,
|
||||
metric="",
|
||||
early_stop=5,
|
||||
loss="mse",
|
||||
optimizer="adam",
|
||||
reg=1e-3,
|
||||
n_jobs=10,
|
||||
GPU=0,
|
||||
seed=None,
|
||||
**kwargs
|
||||
):
|
||||
|
||||
# set hyper-parameters.
|
||||
self.d_model = d_model
|
||||
self.dropout = dropout
|
||||
self.n_epochs = n_epochs
|
||||
self.lr = lr
|
||||
self.reg = reg
|
||||
self.metric = metric
|
||||
self.batch_size = batch_size
|
||||
self.early_stop = early_stop
|
||||
self.optimizer = optimizer.lower()
|
||||
self.loss = loss
|
||||
self.n_jobs = n_jobs
|
||||
self.device = torch.device("cuda:%d" % GPU if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.seed = seed
|
||||
self.logger = get_module_logger("TransformerModel")
|
||||
self.logger.info("Naive Transformer:" "\nbatch_size : {}" "\ndevice : {}".format(self.batch_size, self.device))
|
||||
|
||||
if self.seed is not None:
|
||||
np.random.seed(self.seed)
|
||||
torch.manual_seed(self.seed)
|
||||
|
||||
self.model = Transformer(d_feat, d_model, nhead, num_layers, dropout, self.device)
|
||||
if optimizer.lower() == "adam":
|
||||
self.train_optimizer = optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=self.reg)
|
||||
elif optimizer.lower() == "gd":
|
||||
self.train_optimizer = optim.SGD(self.model.parameters(), lr=self.lr, weight_decay=self.reg)
|
||||
else:
|
||||
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
|
||||
|
||||
self.fitted = False
|
||||
self.model.to(self.device)
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
return self.device != torch.device("cpu")
|
||||
|
||||
def mse(self, pred, label):
|
||||
loss = (pred.float() - label.float()) ** 2
|
||||
return torch.mean(loss)
|
||||
|
||||
def loss_fn(self, pred, label):
|
||||
mask = ~torch.isnan(label)
|
||||
|
||||
if self.loss == "mse":
|
||||
return self.mse(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown loss `%s`" % self.loss)
|
||||
|
||||
def metric_fn(self, pred, label):
|
||||
|
||||
mask = torch.isfinite(label)
|
||||
|
||||
if self.metric == "" or self.metric == "loss":
|
||||
return -self.loss_fn(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown metric `%s`" % self.metric)
|
||||
|
||||
def train_epoch(self, data_loader):
|
||||
|
||||
self.model.train()
|
||||
|
||||
for data in data_loader:
|
||||
feature = data[:, :, 0:-1].to(self.device)
|
||||
label = data[:, -1, -1].to(self.device)
|
||||
|
||||
pred = self.model(feature.float()) # .float()
|
||||
loss = self.loss_fn(pred, label)
|
||||
|
||||
self.train_optimizer.zero_grad()
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_value_(self.model.parameters(), 3.0)
|
||||
self.train_optimizer.step()
|
||||
|
||||
def test_epoch(self, data_loader):
|
||||
|
||||
self.model.eval()
|
||||
|
||||
scores = []
|
||||
losses = []
|
||||
|
||||
for data in data_loader:
|
||||
|
||||
feature = data[:, :, 0:-1].to(self.device)
|
||||
label = data[:, -1, -1].to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
pred = self.model(feature.float()) # .float()
|
||||
loss = self.loss_fn(pred, label)
|
||||
losses.append(loss.item())
|
||||
|
||||
score = self.metric_fn(pred, label)
|
||||
scores.append(score.item())
|
||||
|
||||
return np.mean(losses), np.mean(scores)
|
||||
|
||||
def fit(
|
||||
self,
|
||||
dataset: DatasetH,
|
||||
evals_result=dict(),
|
||||
save_path=None,
|
||||
):
|
||||
|
||||
dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
|
||||
dl_valid = dataset.prepare("valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
|
||||
|
||||
dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
|
||||
train_loader = DataLoader(
|
||||
dl_train, batch_size=self.batch_size, shuffle=True, num_workers=self.n_jobs, drop_last=True
|
||||
)
|
||||
valid_loader = DataLoader(
|
||||
dl_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.n_jobs, drop_last=True
|
||||
)
|
||||
|
||||
save_path = get_or_create_path(save_path)
|
||||
|
||||
stop_steps = 0
|
||||
train_loss = 0
|
||||
best_score = -np.inf
|
||||
best_epoch = 0
|
||||
evals_result["train"] = []
|
||||
evals_result["valid"] = []
|
||||
|
||||
# train
|
||||
self.logger.info("training...")
|
||||
self.fitted = True
|
||||
|
||||
for step in range(self.n_epochs):
|
||||
self.logger.info("Epoch%d:", step)
|
||||
self.logger.info("training...")
|
||||
self.train_epoch(train_loader)
|
||||
self.logger.info("evaluating...")
|
||||
train_loss, train_score = self.test_epoch(train_loader)
|
||||
val_loss, val_score = self.test_epoch(valid_loader)
|
||||
self.logger.info("train %.6f, valid %.6f" % (train_score, val_score))
|
||||
evals_result["train"].append(train_score)
|
||||
evals_result["valid"].append(val_score)
|
||||
|
||||
if val_score > best_score:
|
||||
best_score = val_score
|
||||
stop_steps = 0
|
||||
best_epoch = step
|
||||
best_param = copy.deepcopy(self.model.state_dict())
|
||||
else:
|
||||
stop_steps += 1
|
||||
if stop_steps >= self.early_stop:
|
||||
self.logger.info("early stop")
|
||||
break
|
||||
|
||||
self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
|
||||
self.model.load_state_dict(best_param)
|
||||
torch.save(best_param, save_path)
|
||||
|
||||
if self.use_gpu:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def predict(self, dataset):
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
dl_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)
|
||||
dl_test.config(fillna_type="ffill+bfill")
|
||||
test_loader = DataLoader(dl_test, batch_size=self.batch_size, num_workers=self.n_jobs)
|
||||
self.model.eval()
|
||||
preds = []
|
||||
|
||||
for data in test_loader:
|
||||
feature = data[:, :, 0:-1].to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
pred = self.model(feature.float()).detach().cpu().numpy()
|
||||
|
||||
preds.append(pred)
|
||||
|
||||
return pd.Series(np.concatenate(preds), index=dl_test.get_index())
|
||||
|
||||
|
||||
class PositionalEncoding(nn.Module):
|
||||
def __init__(self, d_model, max_len=1000):
|
||||
super(PositionalEncoding, self).__init__()
|
||||
pe = torch.zeros(max_len, d_model)
|
||||
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
|
||||
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
|
||||
pe[:, 0::2] = torch.sin(position * div_term)
|
||||
pe[:, 1::2] = torch.cos(position * div_term)
|
||||
pe = pe.unsqueeze(0).transpose(0, 1)
|
||||
self.register_buffer("pe", pe)
|
||||
|
||||
def forward(self, x):
|
||||
# [T, N, F]
|
||||
return x + self.pe[: x.size(0), :]
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, d_feat=6, d_model=8, nhead=4, num_layers=2, dropout=0.5, device=None):
|
||||
super(Transformer, self).__init__()
|
||||
self.feature_layer = nn.Linear(d_feat, d_model)
|
||||
self.pos_encoder = PositionalEncoding(d_model)
|
||||
self.encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dropout=dropout)
|
||||
self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers)
|
||||
self.decoder_layer = nn.Linear(d_model, 1)
|
||||
self.device = device
|
||||
self.d_feat = d_feat
|
||||
|
||||
def forward(self, src):
|
||||
# src [N, T, F], [512, 60, 6]
|
||||
src = self.feature_layer(src) # [512, 60, 8]
|
||||
|
||||
# src [N, T, F] --> [T, N, F], [60, 512, 8]
|
||||
src = src.transpose(1, 0) # not batch first
|
||||
|
||||
mask = None
|
||||
|
||||
src = self.pos_encoder(src)
|
||||
output = self.transformer_encoder(src, mask) # [60, 512, 8]
|
||||
|
||||
# [T, N, F] --> [N, T*F]
|
||||
output = self.decoder_layer(output.transpose(1, 0)[:, -1, :]) # [512, 1]
|
||||
|
||||
return output.squeeze()
|
||||
@@ -221,9 +221,9 @@ class TopkDropoutStrategy(BaseStrategy, ListAdjustTimer):
|
||||
only_tradable : bool
|
||||
will the strategy only consider the tradable stock when buying and selling.
|
||||
if only_tradable:
|
||||
strategy will make buy sell decision without checking the tradable state of the stock.
|
||||
the strategy will peek at the information in the short future to avoid untradable stocks (untradable stocks include stocks that meet suspension, or hit limit up or limit down).
|
||||
else:
|
||||
strategy will make decision with the tradable state of the stock info and avoid buy and sell them.
|
||||
the strategy will generate orders without peeking any information in the future, so the order generated by the strategies may fail.
|
||||
"""
|
||||
super(TopkDropoutStrategy, self).__init__()
|
||||
ListAdjustTimer.__init__(self, kwargs.get("adjust_dates", None))
|
||||
|
||||
@@ -196,9 +196,9 @@ class Feature(Expression):
|
||||
|
||||
def __init__(self, name=None):
|
||||
if name:
|
||||
self._name = name.lower()
|
||||
self._name = name
|
||||
else:
|
||||
self._name = type(self).__name__.lower()
|
||||
self._name = type(self).__name__
|
||||
|
||||
def __str__(self):
|
||||
return "$" + self._name
|
||||
|
||||
@@ -17,6 +17,7 @@ import abc
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Union, Iterable
|
||||
from collections import OrderedDict
|
||||
|
||||
from ..config import C
|
||||
@@ -216,12 +217,14 @@ class CacheUtils:
|
||||
redis_lock.reset_all(r)
|
||||
|
||||
@staticmethod
|
||||
def visit(cache_path):
|
||||
def visit(cache_path: Union[str, Path]):
|
||||
# FIXME: Because read_lock was canceled when reading the cache, multiple processes may have read and write exceptions here
|
||||
try:
|
||||
with open(cache_path + ".meta", "rb") as f:
|
||||
cache_path = Path(cache_path)
|
||||
meta_path = cache_path.with_suffix(".meta")
|
||||
with meta_path.open("rb") as f:
|
||||
d = pickle.load(f)
|
||||
with open(cache_path + ".meta", "wb") as f:
|
||||
with meta_path.open("wb") as f:
|
||||
try:
|
||||
d["meta"]["last_visit"] = str(time.time())
|
||||
d["meta"]["visits"] = d["meta"]["visits"] + 1
|
||||
@@ -249,17 +252,17 @@ class CacheUtils:
|
||||
|
||||
@staticmethod
|
||||
@contextlib.contextmanager
|
||||
def reader_lock(redis_t, lock_name):
|
||||
lock_name = f"{C.provider_uri}:{lock_name}"
|
||||
current_cache_rlock = redis_lock.Lock(redis_t, "%s-rlock" % lock_name)
|
||||
current_cache_wlock = redis_lock.Lock(redis_t, "%s-wlock" % lock_name)
|
||||
def reader_lock(redis_t, lock_name: str):
|
||||
current_cache_rlock = redis_lock.Lock(redis_t, f"{lock_name}-rlock")
|
||||
current_cache_wlock = redis_lock.Lock(redis_t, f"{lock_name}-wlock")
|
||||
lock_reader = f"{lock_name}-reader"
|
||||
# make sure only one reader is entering
|
||||
current_cache_rlock.acquire(timeout=60)
|
||||
try:
|
||||
current_cache_readers = redis_t.get("%s-reader" % lock_name)
|
||||
current_cache_readers = redis_t.get(lock_reader)
|
||||
if current_cache_readers is None or int(current_cache_readers) == 0:
|
||||
CacheUtils.acquire(current_cache_wlock, lock_name)
|
||||
redis_t.incr("%s-reader" % lock_name)
|
||||
redis_t.incr(lock_reader)
|
||||
finally:
|
||||
current_cache_rlock.release()
|
||||
try:
|
||||
@@ -268,9 +271,9 @@ class CacheUtils:
|
||||
# make sure only one reader is leaving
|
||||
current_cache_rlock.acquire(timeout=60)
|
||||
try:
|
||||
redis_t.decr("%s-reader" % lock_name)
|
||||
if int(redis_t.get("%s-reader" % lock_name)) == 0:
|
||||
redis_t.delete("%s-reader" % lock_name)
|
||||
redis_t.decr(lock_reader)
|
||||
if int(redis_t.get(lock_reader)) == 0:
|
||||
redis_t.delete(lock_reader)
|
||||
current_cache_wlock.reset()
|
||||
finally:
|
||||
current_cache_rlock.release()
|
||||
@@ -278,8 +281,7 @@ class CacheUtils:
|
||||
@staticmethod
|
||||
@contextlib.contextmanager
|
||||
def writer_lock(redis_t, lock_name):
|
||||
lock_name = f"{C.provider_uri}:{lock_name}"
|
||||
current_cache_wlock = redis_lock.Lock(redis_t, "%s-wlock" % lock_name, id=CacheUtils.LOCK_ID)
|
||||
current_cache_wlock = redis_lock.Lock(redis_t, f"{lock_name}-wlock", id=CacheUtils.LOCK_ID)
|
||||
CacheUtils.acquire(current_cache_wlock, lock_name)
|
||||
try:
|
||||
yield
|
||||
@@ -297,6 +299,30 @@ class BaseProviderCache:
|
||||
def __getattr__(self, attr):
|
||||
return getattr(self.provider, attr)
|
||||
|
||||
@staticmethod
|
||||
def check_cache_exists(cache_path: Union[str, Path], suffix_list: Iterable = (".index", ".meta")) -> bool:
|
||||
cache_path = Path(cache_path)
|
||||
for p in [cache_path] + [cache_path.with_suffix(_s) for _s in suffix_list]:
|
||||
if not p.exists():
|
||||
return False
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def clear_cache(cache_path: Union[str, Path]):
|
||||
for p in [
|
||||
cache_path,
|
||||
cache_path.with_suffix(".meta"),
|
||||
cache_path.with_suffix(".index"),
|
||||
]:
|
||||
if p.exists():
|
||||
p.unlink()
|
||||
|
||||
@staticmethod
|
||||
def get_cache_dir(dir_name: str, freq: str = None) -> Path:
|
||||
cache_dir = Path(C.dpm.get_data_path(freq)).joinpath(dir_name)
|
||||
cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
return cache_dir
|
||||
|
||||
|
||||
class ExpressionCache(BaseProviderCache):
|
||||
"""Expression cache mechanism base class.
|
||||
@@ -330,15 +356,16 @@ class ExpressionCache(BaseProviderCache):
|
||||
"""
|
||||
raise NotImplementedError("Implement this method if you want to use expression cache")
|
||||
|
||||
def update(self, cache_uri):
|
||||
def update(self, cache_uri: Union[str, Path], freq: str = "day"):
|
||||
"""Update expression cache to latest calendar.
|
||||
|
||||
Overide this method to define how to update expression cache corresponding to users' own cache mechanism.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
cache_uri : str
|
||||
cache_uri : str or Path
|
||||
the complete uri of expression cache file (include dir path).
|
||||
freq : str
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -358,7 +385,9 @@ class DatasetCache(BaseProviderCache):
|
||||
|
||||
HDF_KEY = "df"
|
||||
|
||||
def dataset(self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1):
|
||||
def dataset(
|
||||
self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1, inst_processors=[]
|
||||
):
|
||||
"""Get feature dataset.
|
||||
|
||||
.. note:: Same interface as `dataset` method in dataset provider
|
||||
@@ -369,13 +398,19 @@ class DatasetCache(BaseProviderCache):
|
||||
"""
|
||||
if disk_cache == 0:
|
||||
# skip cache
|
||||
return self.provider.dataset(instruments, fields, start_time, end_time, freq)
|
||||
return self.provider.dataset(
|
||||
instruments, fields, start_time, end_time, freq, inst_processors=inst_processors
|
||||
)
|
||||
else:
|
||||
# use and replace cache
|
||||
try:
|
||||
return self._dataset(instruments, fields, start_time, end_time, freq, disk_cache)
|
||||
return self._dataset(
|
||||
instruments, fields, start_time, end_time, freq, disk_cache, inst_processors=inst_processors
|
||||
)
|
||||
except NotImplementedError:
|
||||
return self.provider.dataset(instruments, fields, start_time, end_time, freq)
|
||||
return self.provider.dataset(
|
||||
instruments, fields, start_time, end_time, freq, inst_processors=inst_processors
|
||||
)
|
||||
|
||||
def _uri(self, instruments, fields, start_time, end_time, freq, **kwargs):
|
||||
"""Get dataset cache file uri.
|
||||
@@ -384,14 +419,18 @@ class DatasetCache(BaseProviderCache):
|
||||
"""
|
||||
raise NotImplementedError("Implement this function to match your own cache mechanism")
|
||||
|
||||
def _dataset(self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1):
|
||||
def _dataset(
|
||||
self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1, inst_processors=[]
|
||||
):
|
||||
"""Get feature dataset using cache.
|
||||
|
||||
Override this method to define how to get feature dataset corresponding to users' own cache mechanism.
|
||||
"""
|
||||
raise NotImplementedError("Implement this method if you want to use dataset feature cache")
|
||||
|
||||
def _dataset_uri(self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1):
|
||||
def _dataset_uri(
|
||||
self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1, inst_processors=[]
|
||||
):
|
||||
"""Get a uri of feature dataset using cache.
|
||||
specially:
|
||||
disk_cache=1 means using data set cache and return the uri of cache file.
|
||||
@@ -403,15 +442,16 @@ class DatasetCache(BaseProviderCache):
|
||||
"Implement this method if you want to use dataset feature cache as a cache file for client"
|
||||
)
|
||||
|
||||
def update(self, cache_uri):
|
||||
def update(self, cache_uri: Union[str, Path], freq: str = "day"):
|
||||
"""Update dataset cache to latest calendar.
|
||||
|
||||
Overide this method to define how to update dataset cache corresponding to users' own cache mechanism.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
cache_uri : str
|
||||
cache_uri : str or Path
|
||||
the complete uri of dataset cache file (include dir path).
|
||||
freq : str
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -452,25 +492,19 @@ class DiskExpressionCache(ExpressionCache):
|
||||
self.r = get_redis_connection()
|
||||
# remote==True means client is using this module, writing behaviour will not be allowed.
|
||||
self.remote = kwargs.get("remote", False)
|
||||
self.expr_cache_path = os.path.join(C.get_data_path(), C.features_cache_dir_name)
|
||||
os.makedirs(self.expr_cache_path, exist_ok=True)
|
||||
|
||||
def get_cache_dir(self, freq: str = None) -> Path:
|
||||
return super(DiskExpressionCache, self).get_cache_dir(C.features_cache_dir_name, freq)
|
||||
|
||||
def _uri(self, instrument, field, start_time, end_time, freq):
|
||||
field = remove_fields_space(field)
|
||||
instrument = str(instrument).lower()
|
||||
return hash_args(instrument, field, freq)
|
||||
|
||||
@staticmethod
|
||||
def check_cache_exists(cache_path):
|
||||
for p in [cache_path, cache_path + ".meta"]:
|
||||
if not Path(p).exists():
|
||||
return False
|
||||
return True
|
||||
|
||||
def _expression(self, instrument, field, start_time=None, end_time=None, freq="day"):
|
||||
_cache_uri = self._uri(instrument=instrument, field=field, start_time=None, end_time=None, freq=freq)
|
||||
_instrument_dir = os.path.join(self.expr_cache_path, instrument.lower())
|
||||
cache_path = os.path.join(_instrument_dir, _cache_uri)
|
||||
_instrument_dir = self.get_cache_dir(freq).joinpath(instrument.lower())
|
||||
cache_path = _instrument_dir.joinpath(_cache_uri)
|
||||
# get calendar
|
||||
from .data import Cal
|
||||
|
||||
@@ -478,7 +512,7 @@ class DiskExpressionCache(ExpressionCache):
|
||||
|
||||
_, _, start_index, end_index = Cal.locate_index(start_time, end_time, freq, future=False)
|
||||
|
||||
if self.check_cache_exists(cache_path):
|
||||
if self.check_cache_exists(cache_path, suffix_list=[".meta"]):
|
||||
"""
|
||||
In most cases, we do not need reader_lock.
|
||||
Because updating data is a small probability event compare to reading data.
|
||||
@@ -502,8 +536,7 @@ class DiskExpressionCache(ExpressionCache):
|
||||
# normalize field
|
||||
field = remove_fields_space(field)
|
||||
# cache unavailable, generate the cache
|
||||
if not os.path.exists(_instrument_dir):
|
||||
os.makedirs(_instrument_dir, exist_ok=True)
|
||||
_instrument_dir.mkdir(parents=True, exist_ok=True)
|
||||
if not isinstance(eval(parse_field(field)), Feature):
|
||||
# When the expression is not a raw feature
|
||||
# generate expression cache if the feature is not a Feature
|
||||
@@ -511,7 +544,7 @@ class DiskExpressionCache(ExpressionCache):
|
||||
series = self.provider.expression(instrument, field, _calendar[0], _calendar[-1], freq)
|
||||
if not series.empty:
|
||||
# This expresion is empty, we don't generate any cache for it.
|
||||
with CacheUtils.writer_lock(self.r, "expression-%s" % _cache_uri):
|
||||
with CacheUtils.writer_lock(self.r, f"{str(C.dpm.get_data_path(freq))}:expression-{_cache_uri}"):
|
||||
self.gen_expression_cache(
|
||||
expression_data=series,
|
||||
cache_path=cache_path,
|
||||
@@ -527,14 +560,6 @@ class DiskExpressionCache(ExpressionCache):
|
||||
# If the expression is a raw feature(such as $close, $open)
|
||||
return self.provider.expression(instrument, field, start_time, end_time, freq)
|
||||
|
||||
@staticmethod
|
||||
def clear_cache(cache_path):
|
||||
meta_path = cache_path + ".meta"
|
||||
for p in [cache_path, meta_path]:
|
||||
p = Path(p)
|
||||
if p.exists():
|
||||
p.unlink()
|
||||
|
||||
def gen_expression_cache(self, expression_data, cache_path, instrument, field, freq, last_update):
|
||||
"""use bin file to save like feature-data."""
|
||||
# Make sure the cache runs right when the directory is deleted
|
||||
@@ -544,27 +569,28 @@ class DiskExpressionCache(ExpressionCache):
|
||||
"meta": {"last_visit": time.time(), "visits": 1},
|
||||
}
|
||||
self.logger.debug(f"generating expression cache: {meta}")
|
||||
os.makedirs(self.expr_cache_path, exist_ok=True)
|
||||
self.clear_cache(cache_path)
|
||||
meta_path = cache_path + ".meta"
|
||||
meta_path = cache_path.with_suffix(".meta")
|
||||
|
||||
with open(meta_path, "wb") as f:
|
||||
with meta_path.open("wb") as f:
|
||||
pickle.dump(meta, f)
|
||||
os.chmod(meta_path, stat.S_IRWXU | stat.S_IRGRP | stat.S_IROTH)
|
||||
meta_path.chmod(stat.S_IRWXU | stat.S_IRGRP | stat.S_IROTH)
|
||||
df = expression_data.to_frame()
|
||||
|
||||
r = np.hstack([df.index[0], expression_data]).astype("<f")
|
||||
r.tofile(str(cache_path))
|
||||
|
||||
def update(self, sid, cache_uri):
|
||||
cp_cache_uri = os.path.join(self.expr_cache_path, sid, cache_uri)
|
||||
if not self.check_cache_exists(cp_cache_uri):
|
||||
def update(self, sid, cache_uri, freq: str = "day"):
|
||||
|
||||
cp_cache_uri = self.get_cache_dir(freq).joinpath(sid).joinpath(cache_uri)
|
||||
meta_path = cp_cache_uri.with_suffix(".meta")
|
||||
if not self.check_cache_exists(cp_cache_uri, suffix_list=[".meta"]):
|
||||
self.logger.info(f"The cache {cp_cache_uri} has corrupted. It will be removed")
|
||||
self.clear_cache(cp_cache_uri)
|
||||
return 2
|
||||
|
||||
with CacheUtils.writer_lock(self.r, "expression-%s" % cache_uri):
|
||||
with open(cp_cache_uri + ".meta", "rb") as f:
|
||||
with CacheUtils.writer_lock(self.r, f"{str(C.dpm.get_data_path())}:expression-{cache_uri}"):
|
||||
with meta_path.open("rb") as f:
|
||||
d = pickle.load(f)
|
||||
instrument = d["info"]["instrument"]
|
||||
field = d["info"]["field"]
|
||||
@@ -611,7 +637,7 @@ class DiskExpressionCache(ExpressionCache):
|
||||
f.write(data)
|
||||
# update meta file
|
||||
d["info"]["last_update"] = str(new_calendar[-1])
|
||||
with open(cp_cache_uri + ".meta", "wb") as f:
|
||||
with meta_path.open("wb") as f:
|
||||
pickle.dump(d, f)
|
||||
return 0
|
||||
|
||||
@@ -623,22 +649,16 @@ class DiskDatasetCache(DatasetCache):
|
||||
super(DiskDatasetCache, self).__init__(provider)
|
||||
self.r = get_redis_connection()
|
||||
self.remote = kwargs.get("remote", False)
|
||||
self.dtst_cache_path = os.path.join(C.get_data_path(), C.dataset_cache_dir_name)
|
||||
os.makedirs(self.dtst_cache_path, exist_ok=True)
|
||||
|
||||
@staticmethod
|
||||
def _uri(instruments, fields, start_time, end_time, freq, disk_cache=1, **kwargs):
|
||||
return hash_args(*DatasetCache.normalize_uri_args(instruments, fields, freq), disk_cache)
|
||||
def _uri(instruments, fields, start_time, end_time, freq, disk_cache=1, inst_processors=[], **kwargs):
|
||||
return hash_args(*DatasetCache.normalize_uri_args(instruments, fields, freq), disk_cache, inst_processors)
|
||||
|
||||
@staticmethod
|
||||
def check_cache_exists(cache_path):
|
||||
for p in [cache_path, cache_path + ".index", cache_path + ".meta"]:
|
||||
if not Path(p).exists():
|
||||
return False
|
||||
return True
|
||||
def get_cache_dir(self, freq: str = None) -> Path:
|
||||
return super(DiskDatasetCache, self).get_cache_dir(C.dataset_cache_dir_name, freq)
|
||||
|
||||
@classmethod
|
||||
def read_data_from_cache(cls, cache_path, start_time, end_time, fields):
|
||||
def read_data_from_cache(cls, cache_path: Union[str, Path], start_time, end_time, fields):
|
||||
"""read_cache_from
|
||||
|
||||
This function can read data from the disk cache dataset
|
||||
@@ -671,17 +691,32 @@ class DiskDatasetCache(DatasetCache):
|
||||
df = pd.DataFrame(columns=fields)
|
||||
return df
|
||||
|
||||
def _dataset(self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=0):
|
||||
def _dataset(
|
||||
self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=0, inst_processors=[]
|
||||
):
|
||||
|
||||
if disk_cache == 0:
|
||||
# In this case, data_set cache is configured but will not be used.
|
||||
return self.provider.dataset(instruments, fields, start_time, end_time, freq)
|
||||
|
||||
return self.provider.dataset(
|
||||
instruments, fields, start_time, end_time, freq, inst_processors=inst_processors
|
||||
)
|
||||
# FIXME: The cache after resample, when read again and intercepted with end_time, results in incomplete data date
|
||||
if inst_processors:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} does not support inst_processor. "
|
||||
f"Please use `D.features(disk_cache=0)` or `qlib.init(dataset_cache=None)`"
|
||||
)
|
||||
_cache_uri = self._uri(
|
||||
instruments=instruments, fields=fields, start_time=None, end_time=None, freq=freq, disk_cache=disk_cache
|
||||
instruments=instruments,
|
||||
fields=fields,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
freq=freq,
|
||||
disk_cache=disk_cache,
|
||||
inst_processors=inst_processors,
|
||||
)
|
||||
|
||||
cache_path = os.path.join(self.dtst_cache_path, _cache_uri)
|
||||
cache_path = self.get_cache_dir(freq).joinpath(_cache_uri)
|
||||
|
||||
features = pd.DataFrame()
|
||||
gen_flag = False
|
||||
@@ -689,7 +724,7 @@ class DiskDatasetCache(DatasetCache):
|
||||
if self.check_cache_exists(cache_path):
|
||||
if disk_cache == 1:
|
||||
# use cache
|
||||
with CacheUtils.reader_lock(self.r, "dataset-%s" % _cache_uri):
|
||||
with CacheUtils.reader_lock(self.r, f"{str(C.dpm.get_data_path(freq))}:dataset-{_cache_uri}"):
|
||||
CacheUtils.visit(cache_path)
|
||||
features = self.read_data_from_cache(cache_path, start_time, end_time, fields)
|
||||
elif disk_cache == 2:
|
||||
@@ -699,15 +734,21 @@ class DiskDatasetCache(DatasetCache):
|
||||
|
||||
if gen_flag:
|
||||
# cache unavailable, generate the cache
|
||||
with CacheUtils.writer_lock(self.r, "dataset-%s" % _cache_uri):
|
||||
with CacheUtils.writer_lock(self.r, f"{str(C.dpm.get_data_path(freq))}:dataset-{_cache_uri}"):
|
||||
features = self.gen_dataset_cache(
|
||||
cache_path=cache_path, instruments=instruments, fields=fields, freq=freq
|
||||
cache_path=cache_path,
|
||||
instruments=instruments,
|
||||
fields=fields,
|
||||
freq=freq,
|
||||
inst_processors=inst_processors,
|
||||
)
|
||||
if not features.empty:
|
||||
features = features.sort_index().loc(axis=0)[:, start_time:end_time]
|
||||
return features
|
||||
|
||||
def _dataset_uri(self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=0):
|
||||
def _dataset_uri(
|
||||
self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=0, inst_processors=[]
|
||||
):
|
||||
if disk_cache == 0:
|
||||
# In this case, server only checks the expression cache.
|
||||
# The client will load the cache data by itself.
|
||||
@@ -715,21 +756,38 @@ class DiskDatasetCache(DatasetCache):
|
||||
|
||||
LocalDatasetProvider.multi_cache_walker(instruments, fields, start_time, end_time, freq)
|
||||
return ""
|
||||
|
||||
# FIXME: The cache after resample, when read again and intercepted with end_time, results in incomplete data date
|
||||
if inst_processors:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} does not support inst_processor. "
|
||||
f"Please use `D.features(disk_cache=0)` or `qlib.init(dataset_cache=None)`"
|
||||
)
|
||||
_cache_uri = self._uri(
|
||||
instruments=instruments, fields=fields, start_time=None, end_time=None, freq=freq, disk_cache=disk_cache
|
||||
instruments=instruments,
|
||||
fields=fields,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
freq=freq,
|
||||
disk_cache=disk_cache,
|
||||
inst_processors=inst_processors,
|
||||
)
|
||||
cache_path = os.path.join(self.dtst_cache_path, _cache_uri)
|
||||
cache_path = self.get_cache_dir(freq).joinpath(_cache_uri)
|
||||
|
||||
if self.check_cache_exists(cache_path):
|
||||
self.logger.debug(f"The cache dataset has already existed {cache_path}. Return the uri directly")
|
||||
with CacheUtils.reader_lock(self.r, "dataset-%s" % _cache_uri):
|
||||
with CacheUtils.reader_lock(self.r, f"{str(C.dpm.get_data_path(freq))}:dataset-{_cache_uri}"):
|
||||
CacheUtils.visit(cache_path)
|
||||
return _cache_uri
|
||||
else:
|
||||
# cache unavailable, generate the cache
|
||||
with CacheUtils.writer_lock(self.r, "dataset-%s" % _cache_uri):
|
||||
self.gen_dataset_cache(cache_path=cache_path, instruments=instruments, fields=fields, freq=freq)
|
||||
with CacheUtils.writer_lock(self.r, f"{str(C.dpm.get_data_path(freq))}:dataset-{_cache_uri}"):
|
||||
self.gen_dataset_cache(
|
||||
cache_path=cache_path,
|
||||
instruments=instruments,
|
||||
fields=fields,
|
||||
freq=freq,
|
||||
inst_processors=inst_processors,
|
||||
)
|
||||
return _cache_uri
|
||||
|
||||
class IndexManager:
|
||||
@@ -740,8 +798,9 @@ class DiskDatasetCache(DatasetCache):
|
||||
|
||||
KEY = "df"
|
||||
|
||||
def __init__(self, cache_path):
|
||||
self.index_path = cache_path + ".index"
|
||||
def __init__(self, cache_path: Union[str, Path]):
|
||||
|
||||
self.index_path = cache_path.with_suffix(".index")
|
||||
self._data = None
|
||||
self.logger = get_module_logger(self.__class__.__name__)
|
||||
|
||||
@@ -757,7 +816,7 @@ class DiskDatasetCache(DatasetCache):
|
||||
self._data.sort_index(inplace=True)
|
||||
self._data.to_hdf(self.index_path, key=self.KEY, mode="w", format="table")
|
||||
# The index should be readable for all users
|
||||
os.chmod(self.index_path, stat.S_IRWXU | stat.S_IRGRP | stat.S_IROTH)
|
||||
self.index_path.chmod(stat.S_IRWXU | stat.S_IRGRP | stat.S_IROTH)
|
||||
|
||||
def sync_from_disk(self):
|
||||
# The file will not be closed directly if we read_hdf from the disk directly
|
||||
@@ -795,15 +854,7 @@ class DiskDatasetCache(DatasetCache):
|
||||
index_data += start_index
|
||||
return index_data
|
||||
|
||||
@staticmethod
|
||||
def clear_cache(cache_path):
|
||||
meta_path = cache_path + ".meta"
|
||||
for p in [cache_path, meta_path, cache_path + ".index", cache_path + ".data"]:
|
||||
p = Path(p)
|
||||
if p.exists():
|
||||
p.unlink()
|
||||
|
||||
def gen_dataset_cache(self, cache_path, instruments, fields, freq):
|
||||
def gen_dataset_cache(self, cache_path: Union[str, Path], instruments, fields, freq, inst_processors=[]):
|
||||
"""gen_dataset_cache
|
||||
|
||||
.. note:: This function does not consider the cache read write lock. Please
|
||||
@@ -838,20 +889,23 @@ class DiskDatasetCache(DatasetCache):
|
||||
:param instruments: The instruments to store the cache.
|
||||
:param fields: The fields to store the cache.
|
||||
:param freq: The freq to store the cache.
|
||||
:param inst_processors: Instrument processors.
|
||||
|
||||
:return type pd.DataFrame; The fields of the returned DataFrame are consistent with the parameters of the function.
|
||||
"""
|
||||
# get calendar
|
||||
from .data import Cal
|
||||
|
||||
cache_path = Path(cache_path)
|
||||
_calendar = Cal.calendar(freq=freq)
|
||||
self.logger.debug(f"Generating dataset cache {cache_path}")
|
||||
# Make sure the cache runs right when the directory is deleted
|
||||
# while running
|
||||
os.makedirs(self.dtst_cache_path, exist_ok=True)
|
||||
self.clear_cache(cache_path)
|
||||
|
||||
features = self.provider.dataset(instruments, fields, _calendar[0], _calendar[-1], freq)
|
||||
features = self.provider.dataset(
|
||||
instruments, fields, _calendar[0], _calendar[-1], freq, inst_processors=inst_processors
|
||||
)
|
||||
|
||||
if features.empty:
|
||||
return features
|
||||
@@ -860,7 +914,7 @@ class DiskDatasetCache(DatasetCache):
|
||||
features = features.swaplevel("instrument", "datetime").sort_index()
|
||||
|
||||
# write cache data
|
||||
with pd.HDFStore(cache_path + ".data") as store:
|
||||
with pd.HDFStore(str(cache_path.with_suffix(".data"))) as store:
|
||||
cache_to_orig_map = dict(zip(remove_fields_space(features.columns), features.columns))
|
||||
orig_to_cache_map = dict(zip(features.columns, remove_fields_space(features.columns)))
|
||||
cache_features = features[list(cache_to_orig_map.values())].rename(columns=orig_to_cache_map)
|
||||
@@ -876,12 +930,13 @@ class DiskDatasetCache(DatasetCache):
|
||||
"fields": cache_columns,
|
||||
"freq": freq,
|
||||
"last_update": str(_calendar[-1]), # The last_update to store the cache
|
||||
"inst_processors": inst_processors, # The last_update to store the cache
|
||||
},
|
||||
"meta": {"last_visit": time.time(), "visits": 1},
|
||||
}
|
||||
with open(cache_path + ".meta", "wb") as f:
|
||||
with cache_path.with_suffix(".meta").open("wb") as f:
|
||||
pickle.dump(meta, f)
|
||||
os.chmod(cache_path + ".meta", stat.S_IRWXU | stat.S_IRGRP | stat.S_IROTH)
|
||||
cache_path.with_suffix(".meta").chmod(stat.S_IRWXU | stat.S_IRGRP | stat.S_IROTH)
|
||||
# write index file
|
||||
im = DiskDatasetCache.IndexManager(cache_path)
|
||||
index_data = im.build_index_from_data(features)
|
||||
@@ -890,26 +945,27 @@ class DiskDatasetCache(DatasetCache):
|
||||
# rename the file after the cache has been generated
|
||||
# this doesn't work well on windows, but our server won't use windows
|
||||
# temporarily
|
||||
os.replace(cache_path + ".data", cache_path)
|
||||
cache_path.with_suffix(".data").rename(cache_path)
|
||||
# the fields of the cached features are converted to the original fields
|
||||
return features.swaplevel("datetime", "instrument")
|
||||
|
||||
def update(self, cache_uri):
|
||||
cp_cache_uri = os.path.join(self.dtst_cache_path, cache_uri)
|
||||
|
||||
def update(self, cache_uri, freq: str = "day"):
|
||||
cp_cache_uri = self.get_cache_dir(freq).joinpath(cache_uri)
|
||||
meta_path = cp_cache_uri.with_suffix(".meta")
|
||||
if not self.check_cache_exists(cp_cache_uri):
|
||||
self.logger.info(f"The cache {cp_cache_uri} has corrupted. It will be removed")
|
||||
self.clear_cache(cp_cache_uri)
|
||||
return 2
|
||||
|
||||
im = DiskDatasetCache.IndexManager(cp_cache_uri)
|
||||
with CacheUtils.writer_lock(self.r, "dataset-%s" % cache_uri):
|
||||
with open(cp_cache_uri + ".meta", "rb") as f:
|
||||
with CacheUtils.writer_lock(self.r, f"{str(C.dpm.get_data_path())}:dataset-{cache_uri}"):
|
||||
with meta_path.open("rb") as f:
|
||||
d = pickle.load(f)
|
||||
instruments = d["info"]["instruments"]
|
||||
fields = d["info"]["fields"]
|
||||
freq = d["info"]["freq"]
|
||||
last_update_time = d["info"]["last_update"]
|
||||
inst_processors = d["info"]["inst_processors"]
|
||||
index_data = im.get_index()
|
||||
|
||||
self.logger.debug("Updating dataset: {}".format(d))
|
||||
@@ -960,7 +1016,12 @@ class DiskDatasetCache(DatasetCache):
|
||||
)
|
||||
|
||||
data = self.provider.dataset(
|
||||
instruments, fields, whole_calendar[current_index - rm_n_period], new_calendar[-1], freq
|
||||
instruments,
|
||||
fields,
|
||||
whole_calendar[current_index - rm_n_period],
|
||||
new_calendar[-1],
|
||||
freq,
|
||||
inst_processors=inst_processors,
|
||||
)
|
||||
|
||||
if not data.empty:
|
||||
@@ -995,7 +1056,7 @@ class DiskDatasetCache(DatasetCache):
|
||||
|
||||
# update meta file
|
||||
d["info"]["last_update"] = str(new_calendar[-1])
|
||||
with open(cp_cache_uri + ".meta", "wb") as f:
|
||||
with meta_path.open("wb") as f:
|
||||
pickle.dump(d, f)
|
||||
return 0
|
||||
|
||||
@@ -1006,26 +1067,36 @@ class SimpleDatasetCache(DatasetCache):
|
||||
def __init__(self, provider):
|
||||
super(SimpleDatasetCache, self).__init__(provider)
|
||||
try:
|
||||
self.local_cache_path = C["local_cache_path"]
|
||||
except KeyError as e:
|
||||
self.local_cache_path: Path = Path(C["local_cache_path"]).expanduser().resolve()
|
||||
except (KeyError, TypeError) as e:
|
||||
self.logger.error("Assign a local_cache_path in config if you want to use this cache mechanism")
|
||||
raise
|
||||
self.logger.info(
|
||||
f"DatasetCache directory: {self.local_cache_path}, "
|
||||
f"modify the cache directory via the local_cache_path in the config"
|
||||
)
|
||||
|
||||
def _uri(self, instruments, fields, start_time, end_time, freq, disk_cache=1, **kwargs):
|
||||
def _uri(self, instruments, fields, start_time, end_time, freq, disk_cache=1, inst_processors=[], **kwargs):
|
||||
instruments, fields, freq = self.normalize_uri_args(instruments, fields, freq)
|
||||
local_cache_path = str(Path(self.local_cache_path).expanduser().resolve())
|
||||
return hash_args(instruments, fields, start_time, end_time, freq, disk_cache, local_cache_path)
|
||||
return hash_args(
|
||||
instruments, fields, start_time, end_time, freq, disk_cache, str(self.local_cache_path), inst_processors
|
||||
)
|
||||
|
||||
def _dataset(self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1):
|
||||
def _dataset(
|
||||
self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1, inst_processors=[]
|
||||
):
|
||||
if disk_cache == 0:
|
||||
# In this case, data_set cache is configured but will not be used.
|
||||
return self.provider.dataset(instruments, fields, start_time, end_time, freq)
|
||||
os.makedirs(os.path.expanduser(self.local_cache_path), exist_ok=True)
|
||||
cache_file = os.path.join(
|
||||
self.local_cache_path, self._uri(instruments, fields, start_time, end_time, freq, disk_cache=disk_cache)
|
||||
self.local_cache_path.mkdir(exist_ok=True, parents=True)
|
||||
cache_file = self.local_cache_path.joinpath(
|
||||
self._uri(
|
||||
instruments, fields, start_time, end_time, freq, disk_cache=disk_cache, inst_processors=inst_processors
|
||||
)
|
||||
)
|
||||
gen_flag = False
|
||||
|
||||
if os.path.exists(cache_file):
|
||||
if cache_file.exists():
|
||||
if disk_cache == 1:
|
||||
# use cache
|
||||
df = pd.read_pickle(cache_file)
|
||||
@@ -1037,7 +1108,9 @@ class SimpleDatasetCache(DatasetCache):
|
||||
gen_flag = True
|
||||
|
||||
if gen_flag:
|
||||
data = self.provider.dataset(instruments, normalize_cache_fields(fields), start_time, end_time, freq)
|
||||
data = self.provider.dataset(
|
||||
instruments, normalize_cache_fields(fields), start_time, end_time, freq, inst_processors=inst_processors
|
||||
)
|
||||
data.to_pickle(cache_file)
|
||||
return self.cache_to_origin_data(data, fields)
|
||||
|
||||
@@ -1045,26 +1118,53 @@ class SimpleDatasetCache(DatasetCache):
|
||||
class DatasetURICache(DatasetCache):
|
||||
"""Prepared cache mechanism for server."""
|
||||
|
||||
def _uri(self, instruments, fields, start_time, end_time, freq, disk_cache=1, **kwargs):
|
||||
return hash_args(*self.normalize_uri_args(instruments, fields, freq), disk_cache)
|
||||
def _uri(self, instruments, fields, start_time, end_time, freq, disk_cache=1, inst_processors=[], **kwargs):
|
||||
return hash_args(*self.normalize_uri_args(instruments, fields, freq), disk_cache, inst_processors)
|
||||
|
||||
def dataset(self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=0):
|
||||
def dataset(
|
||||
self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=0, inst_processors=[]
|
||||
):
|
||||
|
||||
if "local" in C.dataset_provider.lower():
|
||||
# use LocalDatasetProvider
|
||||
return self.provider.dataset(instruments, fields, start_time, end_time, freq)
|
||||
return self.provider.dataset(
|
||||
instruments, fields, start_time, end_time, freq, inst_processors=inst_processors
|
||||
)
|
||||
|
||||
if disk_cache == 0:
|
||||
# do not use data_set cache, load data from remote expression cache directly
|
||||
return self.provider.dataset(instruments, fields, start_time, end_time, freq, disk_cache, return_uri=False)
|
||||
|
||||
return self.provider.dataset(
|
||||
instruments,
|
||||
fields,
|
||||
start_time,
|
||||
end_time,
|
||||
freq,
|
||||
disk_cache,
|
||||
return_uri=False,
|
||||
inst_processors=inst_processors,
|
||||
)
|
||||
# FIXME: The cache after resample, when read again and intercepted with end_time, results in incomplete data date
|
||||
if inst_processors:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} does not support inst_processor. "
|
||||
f"Please use `D.features(disk_cache=0)` or `qlib.init(dataset_cache=None)`"
|
||||
)
|
||||
# use ClientDatasetProvider
|
||||
feature_uri = self._uri(instruments, fields, None, None, freq, disk_cache=disk_cache)
|
||||
feature_uri = self._uri(
|
||||
instruments, fields, None, None, freq, disk_cache=disk_cache, inst_processors=inst_processors
|
||||
)
|
||||
value, expire = MemCacheExpire.get_cache(H["f"], feature_uri)
|
||||
mnt_feature_uri = os.path.join(C.get_data_path(), C.dataset_cache_dir_name, feature_uri)
|
||||
if value is None or expire or not os.path.exists(mnt_feature_uri):
|
||||
mnt_feature_uri = C.dpm.get_data_path(freq).joinpath(C.dataset_cache_dir_name).joinpath(feature_uri)
|
||||
if value is None or expire or not mnt_feature_uri.exists():
|
||||
df, uri = self.provider.dataset(
|
||||
instruments, fields, start_time, end_time, freq, disk_cache, return_uri=True
|
||||
instruments,
|
||||
fields,
|
||||
start_time,
|
||||
end_time,
|
||||
freq,
|
||||
disk_cache,
|
||||
return_uri=True,
|
||||
inst_processors=inst_processors,
|
||||
)
|
||||
# cache uri
|
||||
MemCacheExpire.set_cache(H["f"], uri, uri)
|
||||
@@ -1072,7 +1172,6 @@ class DatasetURICache(DatasetCache):
|
||||
# HZ['f'][uri] = df.copy()
|
||||
get_module_logger("cache").debug(f"get feature from {C.dataset_provider}")
|
||||
else:
|
||||
mnt_feature_uri = os.path.join(C.get_data_path(), C.dataset_cache_dir_name, feature_uri)
|
||||
df = DiskDatasetCache.read_data_from_cache(mnt_feature_uri, start_time, end_time, fields)
|
||||
get_module_logger("cache").debug("get feature from uri cache")
|
||||
|
||||
|
||||
@@ -5,28 +5,34 @@
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import re
|
||||
import abc
|
||||
import copy
|
||||
import time
|
||||
import queue
|
||||
import bisect
|
||||
import logging
|
||||
import importlib
|
||||
import traceback
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from multiprocessing import Pool
|
||||
from typing import Iterable, Union
|
||||
|
||||
from .cache import H
|
||||
from ..config import C
|
||||
from .ops import Operators
|
||||
from ..log import get_module_logger
|
||||
from ..utils import parse_field, read_bin, hash_args, normalize_cache_fields, code_to_fname
|
||||
from .base import Feature
|
||||
from .ops import Operators
|
||||
from .inst_processor import InstProcessor
|
||||
|
||||
from ..log import get_module_logger
|
||||
from .cache import DiskDatasetCache, DiskExpressionCache
|
||||
from ..utils import Wrapper, init_instance_by_config, register_wrapper, get_module_by_module_path
|
||||
from ..utils import (
|
||||
Wrapper,
|
||||
init_instance_by_config,
|
||||
register_wrapper,
|
||||
get_module_by_module_path,
|
||||
parse_field,
|
||||
hash_args,
|
||||
normalize_cache_fields,
|
||||
code_to_fname,
|
||||
)
|
||||
|
||||
|
||||
class ProviderBackendMixin:
|
||||
@@ -48,8 +54,14 @@ class ProviderBackendMixin:
|
||||
# default provider_uri map
|
||||
if "provider_uri" not in backend_kwargs:
|
||||
# if the user has no uri configured, use: uri = uri_map[freq]
|
||||
# NOTE: provider_uri priority:
|
||||
# 1. backend_config: backend_obj["kwargs"]["provider_uri"]
|
||||
# 2. backend_config: backend_obj["kwargs"]["provider_uri_map"]
|
||||
# 3. qlib.init: provider_uri
|
||||
provider_uri_map = backend_kwargs.setdefault("provider_uri_map", {})
|
||||
freq = kwargs.get("freq", "day")
|
||||
provider_uri_map = backend_kwargs.setdefault("provider_uri_map", {freq: C.get_data_path()})
|
||||
if freq not in provider_uri_map:
|
||||
provider_uri_map[freq] = C.dpm.get_data_path(freq)
|
||||
backend_kwargs["provider_uri"] = provider_uri_map[freq]
|
||||
backend.setdefault("kwargs", {}).update(**kwargs)
|
||||
return init_instance_by_config(backend)
|
||||
@@ -199,13 +211,23 @@ class InstrumentProvider(abc.ABC, ProviderBackendMixin):
|
||||
'filter_start_time': None,
|
||||
'filter_end_time': None}]}
|
||||
"""
|
||||
from .filter import SeriesDFilter
|
||||
|
||||
if filter_pipe is None:
|
||||
filter_pipe = []
|
||||
config = {"market": market, "filter_pipe": []}
|
||||
# the order of the filters will affect the result, so we need to keep
|
||||
# the order
|
||||
for filter_t in filter_pipe:
|
||||
config["filter_pipe"].append(filter_t.to_config())
|
||||
if isinstance(filter_t, dict):
|
||||
_config = filter_t
|
||||
elif isinstance(filter_t, SeriesDFilter):
|
||||
_config = filter_t.to_config()
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Unsupported filter types: {type(filter_t)}! Filter only supports dict or isinstance(filter, SeriesDFilter)"
|
||||
)
|
||||
config["filter_pipe"].append(_config)
|
||||
return config
|
||||
|
||||
@abc.abstractmethod
|
||||
@@ -341,7 +363,7 @@ class DatasetProvider(abc.ABC):
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def dataset(self, instruments, fields, start_time=None, end_time=None, freq="day"):
|
||||
def dataset(self, instruments, fields, start_time=None, end_time=None, freq="day", inst_processors=[]):
|
||||
"""Get dataset data.
|
||||
|
||||
Parameters
|
||||
@@ -356,6 +378,8 @@ class DatasetProvider(abc.ABC):
|
||||
end of the time range.
|
||||
freq : str
|
||||
time frequency.
|
||||
inst_processors: Iterable[Union[dict, InstProcessor]]
|
||||
the operations performed on each instrument
|
||||
|
||||
Returns
|
||||
----------
|
||||
@@ -372,6 +396,7 @@ class DatasetProvider(abc.ABC):
|
||||
end_time=None,
|
||||
freq="day",
|
||||
disk_cache=1,
|
||||
inst_processors=[],
|
||||
**kwargs,
|
||||
):
|
||||
"""Get task uri, used when generating rabbitmq task in qlib_server
|
||||
@@ -392,7 +417,8 @@ class DatasetProvider(abc.ABC):
|
||||
whether to skip(0)/use(1)/replace(2) disk_cache.
|
||||
|
||||
"""
|
||||
return DiskDatasetCache._uri(instruments, fields, start_time, end_time, freq, disk_cache)
|
||||
# TODO: qlib-server support inst_processors
|
||||
return DiskDatasetCache._uri(instruments, fields, start_time, end_time, freq, disk_cache, inst_processors)
|
||||
|
||||
@staticmethod
|
||||
def get_instruments_d(instruments, freq):
|
||||
@@ -433,7 +459,7 @@ class DatasetProvider(abc.ABC):
|
||||
return [ExpressionD.get_expression_instance(f) for f in fields]
|
||||
|
||||
@staticmethod
|
||||
def dataset_processor(instruments_d, column_names, start_time, end_time, freq):
|
||||
def dataset_processor(instruments_d, column_names, start_time, end_time, freq, inst_processors=[]):
|
||||
"""
|
||||
Load and process the data, return the data set.
|
||||
- default using multi-kernel method.
|
||||
@@ -459,6 +485,7 @@ class DatasetProvider(abc.ABC):
|
||||
normalize_column_names,
|
||||
spans,
|
||||
C,
|
||||
inst_processors,
|
||||
),
|
||||
)
|
||||
else:
|
||||
@@ -473,6 +500,7 @@ class DatasetProvider(abc.ABC):
|
||||
normalize_column_names,
|
||||
None,
|
||||
C,
|
||||
inst_processors,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -494,7 +522,9 @@ class DatasetProvider(abc.ABC):
|
||||
return data
|
||||
|
||||
@staticmethod
|
||||
def expression_calculator(inst, start_time, end_time, freq, column_names, spans=None, g_config=None):
|
||||
def expression_calculator(
|
||||
inst, start_time, end_time, freq, column_names, spans=None, g_config=None, inst_processors=[]
|
||||
):
|
||||
"""
|
||||
Calculate the expressions for one instrument, return a df result.
|
||||
If the expression has been calculated before, load from cache.
|
||||
@@ -518,13 +548,17 @@ class DatasetProvider(abc.ABC):
|
||||
data.index = _calendar[data.index.values.astype(int)]
|
||||
data.index.names = ["datetime"]
|
||||
|
||||
if spans is None:
|
||||
return data
|
||||
else:
|
||||
if spans is not None:
|
||||
mask = np.zeros(len(data), dtype=bool)
|
||||
for begin, end in spans:
|
||||
mask |= (data.index >= begin) & (data.index <= end)
|
||||
return data[mask]
|
||||
data = data[mask]
|
||||
|
||||
for _processor in inst_processors:
|
||||
if _processor:
|
||||
_processor_obj = init_instance_by_config(_processor, accept_types=InstProcessor)
|
||||
data = _processor_obj(data)
|
||||
return data
|
||||
|
||||
|
||||
class LocalCalendarProvider(CalendarProvider):
|
||||
@@ -537,11 +571,6 @@ class LocalCalendarProvider(CalendarProvider):
|
||||
super(LocalCalendarProvider, self).__init__(**kwargs)
|
||||
self.remote = kwargs.get("remote", False)
|
||||
|
||||
@property
|
||||
def _uri_cal(self):
|
||||
"""Calendar file uri."""
|
||||
return os.path.join(C.get_data_path(), "calendars", "{}.txt")
|
||||
|
||||
def load_calendar(self, freq, future):
|
||||
"""Load original calendar timestamp from file.
|
||||
|
||||
@@ -601,11 +630,6 @@ class LocalInstrumentProvider(InstrumentProvider):
|
||||
Provide instrument data from local data source.
|
||||
"""
|
||||
|
||||
@property
|
||||
def _uri_inst(self):
|
||||
"""Instrument file uri."""
|
||||
return os.path.join(C.get_data_path(), "instruments", "{}.txt")
|
||||
|
||||
def _load_instruments(self, market, freq):
|
||||
return self.backend_obj(market=market, freq=freq).data
|
||||
|
||||
@@ -654,14 +678,9 @@ class LocalFeatureProvider(FeatureProvider):
|
||||
super(LocalFeatureProvider, self).__init__(**kwargs)
|
||||
self.remote = kwargs.get("remote", False)
|
||||
|
||||
@property
|
||||
def _uri_data(self):
|
||||
"""Static feature file uri."""
|
||||
return os.path.join(C.get_data_path(), "features", "{}", "{}.{}.bin")
|
||||
|
||||
def feature(self, instrument, field, start_index, end_index, freq):
|
||||
# validate
|
||||
field = str(field).lower()[1:]
|
||||
field = str(field)[1:]
|
||||
instrument = code_to_fname(instrument)
|
||||
return self.backend_obj(instrument=instrument, field=field, freq=freq)[start_index : end_index + 1]
|
||||
|
||||
@@ -703,7 +722,15 @@ class LocalDatasetProvider(DatasetProvider):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def dataset(self, instruments, fields, start_time=None, end_time=None, freq="day"):
|
||||
def dataset(
|
||||
self,
|
||||
instruments,
|
||||
fields,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
freq="day",
|
||||
inst_processors=[],
|
||||
):
|
||||
instruments_d = self.get_instruments_d(instruments, freq)
|
||||
column_names = self.get_column_names(fields)
|
||||
cal = Cal.calendar(start_time, end_time, freq)
|
||||
@@ -712,7 +739,9 @@ class LocalDatasetProvider(DatasetProvider):
|
||||
start_time = cal[0]
|
||||
end_time = cal[-1]
|
||||
|
||||
data = self.dataset_processor(instruments_d, column_names, start_time, end_time, freq)
|
||||
data = self.dataset_processor(
|
||||
instruments_d, column_names, start_time, end_time, freq, inst_processors=inst_processors
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
@@ -855,6 +884,7 @@ class ClientDatasetProvider(DatasetProvider):
|
||||
freq="day",
|
||||
disk_cache=0,
|
||||
return_uri=False,
|
||||
inst_processors=[],
|
||||
):
|
||||
if Inst.get_inst_type(instruments) == Inst.DICT:
|
||||
get_module_logger("data").warning(
|
||||
@@ -894,7 +924,7 @@ class ClientDatasetProvider(DatasetProvider):
|
||||
start_time = cal[0]
|
||||
end_time = cal[-1]
|
||||
|
||||
data = self.dataset_processor(instruments_d, column_names, start_time, end_time, freq)
|
||||
data = self.dataset_processor(instruments_d, column_names, start_time, end_time, freq, inst_processors)
|
||||
if return_uri:
|
||||
return data, feature_uri
|
||||
else:
|
||||
@@ -907,6 +937,13 @@ class ClientDatasetProvider(DatasetProvider):
|
||||
- using single-process implementation.
|
||||
|
||||
"""
|
||||
# TODO: support inst_processors, need to change the code of qlib-server at the same time
|
||||
# FIXME: The cache after resample, when read again and intercepted with end_time, results in incomplete data date
|
||||
if inst_processors:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} does not support inst_processor. "
|
||||
f"Please use `D.features(disk_cache=0)` or `qlib.init(dataset_cache=None)`"
|
||||
)
|
||||
self.conn.send_request(
|
||||
request_type="feature",
|
||||
request_content={
|
||||
@@ -926,7 +963,7 @@ class ClientDatasetProvider(DatasetProvider):
|
||||
get_module_logger("data").debug("get result")
|
||||
try:
|
||||
# pre-mound nfs, used for demo
|
||||
mnt_feature_uri = os.path.join(C.get_data_path(), C.dataset_cache_dir_name, feature_uri)
|
||||
mnt_feature_uri = C.dpm.get_data_path(freq).joinpath(C.dataset_cache_dir_name, feature_uri)
|
||||
df = DiskDatasetCache.read_data_from_cache(mnt_feature_uri, start_time, end_time, fields)
|
||||
get_module_logger("data").debug("finish slicing data")
|
||||
if return_uri:
|
||||
@@ -964,6 +1001,7 @@ class BaseProvider:
|
||||
end_time=None,
|
||||
freq="day",
|
||||
disk_cache=None,
|
||||
inst_processors=[],
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
@@ -978,9 +1016,11 @@ class BaseProvider:
|
||||
disk_cache = C.default_disk_cache if disk_cache is None else disk_cache
|
||||
fields = list(fields) # In case of tuple.
|
||||
try:
|
||||
return DatasetD.dataset(instruments, fields, start_time, end_time, freq, disk_cache)
|
||||
return DatasetD.dataset(
|
||||
instruments, fields, start_time, end_time, freq, disk_cache, inst_processors=inst_processors
|
||||
)
|
||||
except TypeError:
|
||||
return DatasetD.dataset(instruments, fields, start_time, end_time, freq)
|
||||
return DatasetD.dataset(instruments, fields, start_time, end_time, freq, inst_processors=inst_processors)
|
||||
|
||||
|
||||
class LocalProvider(BaseProvider):
|
||||
@@ -1028,13 +1068,21 @@ class ClientProvider(BaseProvider):
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def is_instance_of_provider(instance: object, cls: type):
|
||||
if isinstance(instance, Wrapper):
|
||||
p = getattr(instance, "_provider", None)
|
||||
|
||||
return False if p is None else isinstance(p, cls)
|
||||
|
||||
return isinstance(instance, cls)
|
||||
|
||||
from .client import Client
|
||||
|
||||
self.client = Client(C.flask_server, C.flask_port)
|
||||
self.logger = get_module_logger(self.__class__.__name__)
|
||||
if isinstance(Cal, ClientCalendarProvider):
|
||||
if is_instance_of_provider(Cal, ClientCalendarProvider):
|
||||
Cal.set_conn(self.client)
|
||||
if isinstance(Inst, ClientInstrumentProvider):
|
||||
if is_instance_of_provider(Inst, ClientInstrumentProvider):
|
||||
Inst.set_conn(self.client)
|
||||
if hasattr(DatasetD, "provider"):
|
||||
DatasetD.provider.set_conn(self.client)
|
||||
|
||||
@@ -18,6 +18,7 @@ from ...config import C
|
||||
from ...utils import parse_config, transform_end_date, init_instance_by_config
|
||||
from ...utils.serial import Serializable
|
||||
from .utils import fetch_df_by_index
|
||||
from ...utils import lazy_sort_index
|
||||
from pathlib import Path
|
||||
from .loader import DataLoader
|
||||
|
||||
@@ -146,7 +147,8 @@ class DataHandler(Serializable):
|
||||
# Setup data.
|
||||
# _data may be with multiple column index level. The outer level indicates the feature set name
|
||||
with TimeInspector.logt("Loading data"):
|
||||
self._data = self.data_loader.load(self.instruments, self.start_time, self.end_time)
|
||||
# make sure the fetch method is based on a index-sorted pd.DataFrame
|
||||
self._data = lazy_sort_index(self.data_loader.load(self.instruments, self.start_time, self.end_time))
|
||||
# TODO: cache
|
||||
|
||||
CS_ALL = "__all" # return all columns with single-level index column
|
||||
@@ -293,11 +295,14 @@ class DataHandlerLP(DataHandler):
|
||||
|
||||
# process type
|
||||
PTYPE_I = "independent"
|
||||
# - self._infer will be processed by infer_processors
|
||||
# - self._learn will be processed by learn_processors
|
||||
# - self._infer will be processed by shared_processors + infer_processors
|
||||
# - self._learn will be processed by shared_processors + learn_processors
|
||||
|
||||
# NOTE:
|
||||
PTYPE_A = "append"
|
||||
# - self._infer will be processed by infer_processors
|
||||
# - self._learn will be processed by infer_processors + learn_processors
|
||||
|
||||
# - self._infer will be processed by shared_processors + infer_processors
|
||||
# - self._learn will be processed by shared_processors + infer_processors + learn_processors
|
||||
# - (e.g. self._infer processed by learn_processors )
|
||||
|
||||
def __init__(
|
||||
@@ -306,8 +311,9 @@ class DataHandlerLP(DataHandler):
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
data_loader: Union[dict, str, DataLoader] = None,
|
||||
infer_processors=[],
|
||||
learn_processors=[],
|
||||
infer_processors: List = [],
|
||||
learn_processors: List = [],
|
||||
shared_processors: List = [],
|
||||
process_type=PTYPE_A,
|
||||
drop_raw=False,
|
||||
**kwargs,
|
||||
@@ -358,7 +364,8 @@ class DataHandlerLP(DataHandler):
|
||||
# Setup preprocessor
|
||||
self.infer_processors = [] # for lint
|
||||
self.learn_processors = [] # for lint
|
||||
for pname in "infer_processors", "learn_processors":
|
||||
self.shared_processors = [] # for lint
|
||||
for pname in "infer_processors", "learn_processors", "shared_processors":
|
||||
for proc in locals()[pname]:
|
||||
getattr(self, pname).append(
|
||||
init_instance_by_config(
|
||||
@@ -373,9 +380,12 @@ class DataHandlerLP(DataHandler):
|
||||
super().__init__(instruments, start_time, end_time, data_loader, **kwargs)
|
||||
|
||||
def get_all_processors(self):
|
||||
return self.infer_processors + self.learn_processors
|
||||
return self.shared_processors + self.infer_processors + self.learn_processors
|
||||
|
||||
def fit(self):
|
||||
"""
|
||||
fit data without processing the data
|
||||
"""
|
||||
for proc in self.get_all_processors():
|
||||
with TimeInspector.logt(f"{proc.__class__.__name__}"):
|
||||
proc.fit(self._data)
|
||||
@@ -388,30 +398,68 @@ class DataHandlerLP(DataHandler):
|
||||
"""
|
||||
self.process_data(with_fit=True)
|
||||
|
||||
@staticmethod
|
||||
def _run_proc_l(
|
||||
df: pd.DataFrame, proc_l: List[processor_module.Processor], with_fit: bool, check_for_infer: bool
|
||||
) -> pd.DataFrame:
|
||||
for proc in proc_l:
|
||||
if check_for_infer and not proc.is_for_infer():
|
||||
raise TypeError("Only processors usable for inference can be used in `infer_processors` ")
|
||||
with TimeInspector.logt(f"{proc.__class__.__name__}"):
|
||||
if with_fit:
|
||||
proc.fit(df)
|
||||
df = proc(df)
|
||||
return df
|
||||
|
||||
@staticmethod
|
||||
def _is_proc_readonly(proc_l: List[processor_module.Processor]):
|
||||
"""
|
||||
NOTE: it will return True if `len(proc_l) == 0`
|
||||
"""
|
||||
for p in proc_l:
|
||||
if not p.readonly():
|
||||
return False
|
||||
return True
|
||||
|
||||
def process_data(self, with_fit: bool = False):
|
||||
"""
|
||||
process_data data. Fun `processor.fit` if necessary
|
||||
|
||||
Notation: (data) [processor]
|
||||
|
||||
# data processing flow of self.process_type == DataHandlerLP.PTYPE_I
|
||||
(self._data)-[shared_processors]-(_shared_df)-[learn_processors]-(_learn_df)
|
||||
\
|
||||
-[infer_processors]-(_infer_df)
|
||||
|
||||
# data processing flow of self.process_type == DataHandlerLP.PTYPE_A
|
||||
(self._data)-[shared_processors]-(_shared_df)-[infer_processors]-(_infer_df)-[learn_processors]-(_learn_df)
|
||||
|
||||
Parameters
|
||||
----------
|
||||
with_fit : bool
|
||||
The input of the `fit` will be the output of the previous processor
|
||||
"""
|
||||
# data for inference
|
||||
_infer_df = self._data
|
||||
if len(self.infer_processors) > 0 and not self.drop_raw: # avoid modifying the original data
|
||||
_infer_df = _infer_df.copy()
|
||||
# shared data processors
|
||||
# 1) assign
|
||||
_shared_df = self._data
|
||||
if not self._is_proc_readonly(self.shared_processors): # avoid modifying the original data
|
||||
_shared_df = _shared_df.copy()
|
||||
# 2) process
|
||||
_shared_df = self._run_proc_l(_shared_df, self.shared_processors, with_fit=with_fit, check_for_infer=True)
|
||||
|
||||
# data for inference
|
||||
# 1) assign
|
||||
_infer_df = _shared_df
|
||||
if not self._is_proc_readonly(self.infer_processors): # avoid modifying the original data
|
||||
_infer_df = _infer_df.copy()
|
||||
# 2) process
|
||||
_infer_df = self._run_proc_l(_infer_df, self.infer_processors, with_fit=with_fit, check_for_infer=True)
|
||||
|
||||
for proc in self.infer_processors:
|
||||
if not proc.is_for_infer():
|
||||
raise TypeError("Only processors usable for inference can be used in `infer_processors` ")
|
||||
with TimeInspector.logt(f"{proc.__class__.__name__}"):
|
||||
if with_fit:
|
||||
proc.fit(_infer_df)
|
||||
_infer_df = proc(_infer_df)
|
||||
self._infer = _infer_df
|
||||
|
||||
# data for learning
|
||||
# 1) assign
|
||||
if self.process_type == DataHandlerLP.PTYPE_I:
|
||||
_learn_df = self._data
|
||||
elif self.process_type == DataHandlerLP.PTYPE_A:
|
||||
@@ -419,14 +467,11 @@ class DataHandlerLP(DataHandler):
|
||||
_learn_df = _infer_df
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
|
||||
if len(self.learn_processors) > 0: # avoid modifying the original data
|
||||
if not self._is_proc_readonly(self.learn_processors): # avoid modifying the original data
|
||||
_learn_df = _learn_df.copy()
|
||||
for proc in self.learn_processors:
|
||||
with TimeInspector.logt(f"{proc.__class__.__name__}"):
|
||||
if with_fit:
|
||||
proc.fit(_learn_df)
|
||||
_learn_df = proc(_learn_df)
|
||||
# 2) process
|
||||
_learn_df = self._run_proc_l(_learn_df, self.learn_processors, with_fit=with_fit, check_for_infer=False)
|
||||
|
||||
self._learn = _learn_df
|
||||
|
||||
if self.drop_raw:
|
||||
|
||||
@@ -1,17 +1,13 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
import abc
|
||||
import warnings
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from typing import Tuple, Union
|
||||
from typing import Tuple, Union, List
|
||||
|
||||
from qlib.data import D
|
||||
from qlib.data import filter as filter_module
|
||||
from qlib.data.filter import BaseDFilter
|
||||
from qlib.utils import load_dataset, init_instance_by_config, time_to_slc_point
|
||||
from qlib.log import get_module_logger
|
||||
|
||||
@@ -62,11 +58,11 @@ class DLWParser(DataLoader):
|
||||
Extracting this class so that QlibDataLoader and other dataloaders(such as QdbDataLoader) can share the fields.
|
||||
"""
|
||||
|
||||
def __init__(self, config: Tuple[list, tuple, dict]):
|
||||
def __init__(self, config: Union[list, tuple, dict]):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
config : Tuple[list, tuple, dict]
|
||||
config : Union[list, tuple, dict]
|
||||
Config will be used to describe the fields and column names
|
||||
|
||||
.. code-block::
|
||||
@@ -88,7 +84,7 @@ class DLWParser(DataLoader):
|
||||
else:
|
||||
self.fields = self._parse_fields_info(config)
|
||||
|
||||
def _parse_fields_info(self, fields_info: Tuple[list, tuple]) -> Tuple[list, list]:
|
||||
def _parse_fields_info(self, fields_info: Union[list, tuple]) -> Tuple[list, list]:
|
||||
if len(fields_info) == 0:
|
||||
raise ValueError("The size of fields must be greater than 0")
|
||||
|
||||
@@ -104,7 +100,15 @@ class DLWParser(DataLoader):
|
||||
return exprs, names
|
||||
|
||||
@abc.abstractmethod
|
||||
def load_group_df(self, instruments, exprs: list, names: list, start_time=None, end_time=None) -> pd.DataFrame:
|
||||
def load_group_df(
|
||||
self,
|
||||
instruments,
|
||||
exprs: list,
|
||||
names: list,
|
||||
start_time: Union[str, pd.Timestamp] = None,
|
||||
end_time: Union[str, pd.Timestamp] = None,
|
||||
gp_name: str = None,
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
load the dataframe for specific group
|
||||
|
||||
@@ -128,7 +132,7 @@ class DLWParser(DataLoader):
|
||||
if self.is_group:
|
||||
df = pd.concat(
|
||||
{
|
||||
grp: self.load_group_df(instruments, exprs, names, start_time, end_time)
|
||||
grp: self.load_group_df(instruments, exprs, names, start_time, end_time, grp)
|
||||
for grp, (exprs, names) in self.fields.items()
|
||||
},
|
||||
axis=1,
|
||||
@@ -142,7 +146,14 @@ class DLWParser(DataLoader):
|
||||
class QlibDataLoader(DLWParser):
|
||||
"""Same as QlibDataLoader. The fields can be define by config"""
|
||||
|
||||
def __init__(self, config: Tuple[list, tuple, dict], filter_pipe=None, swap_level=True, freq="day"):
|
||||
def __init__(
|
||||
self,
|
||||
config: Tuple[list, tuple, dict],
|
||||
filter_pipe: List = None,
|
||||
swap_level: bool = True,
|
||||
freq: Union[str, dict] = "day",
|
||||
inst_processor: dict = None,
|
||||
):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
@@ -152,20 +163,41 @@ class QlibDataLoader(DLWParser):
|
||||
Filter pipe for the instruments
|
||||
swap_level :
|
||||
Whether to swap level of MultiIndex
|
||||
freq: dict or str
|
||||
If type(config) == dict and type(freq) == str, load config data using freq.
|
||||
If type(config) == dict and type(freq) == dict, load config[<group_name>] data using freq[<group_name>]
|
||||
inst_processor: dict
|
||||
If inst_processor is not None and type(config) == dict; load config[<group_name>] data using inst_processor[<group_name>]
|
||||
"""
|
||||
if filter_pipe is not None:
|
||||
assert isinstance(filter_pipe, list), "The type of `filter_pipe` must be list."
|
||||
filter_pipe = [
|
||||
init_instance_by_config(fp, None if "module_path" in fp else filter_module, accept_types=BaseDFilter)
|
||||
for fp in filter_pipe
|
||||
]
|
||||
|
||||
self.filter_pipe = filter_pipe
|
||||
self.swap_level = swap_level
|
||||
self.freq = freq
|
||||
|
||||
# sample
|
||||
self.inst_processor = inst_processor if inst_processor is not None else {}
|
||||
assert isinstance(self.inst_processor, dict), f"inst_processor(={self.inst_processor}) must be dict"
|
||||
|
||||
super().__init__(config)
|
||||
|
||||
def load_group_df(self, instruments, exprs: list, names: list, start_time=None, end_time=None) -> pd.DataFrame:
|
||||
if self.is_group:
|
||||
# check sample config
|
||||
if isinstance(freq, dict):
|
||||
for _gp in config.keys():
|
||||
if _gp not in freq:
|
||||
raise ValueError(f"freq(={freq}) missing group(={_gp})")
|
||||
assert (
|
||||
self.inst_processor
|
||||
), f"freq(={self.freq}), inst_processor(={self.inst_processor}) cannot be None/empty"
|
||||
|
||||
def load_group_df(
|
||||
self,
|
||||
instruments,
|
||||
exprs: list,
|
||||
names: list,
|
||||
start_time: Union[str, pd.Timestamp] = None,
|
||||
end_time: Union[str, pd.Timestamp] = None,
|
||||
gp_name: str = None,
|
||||
) -> pd.DataFrame:
|
||||
if instruments is None:
|
||||
warnings.warn("`instruments` is not set, will load all stocks")
|
||||
instruments = "all"
|
||||
@@ -174,7 +206,10 @@ class QlibDataLoader(DLWParser):
|
||||
elif self.filter_pipe is not None:
|
||||
warnings.warn("`filter_pipe` is not None, but it will not be used with `instruments` as list")
|
||||
|
||||
df = D.features(instruments, exprs, start_time, end_time, self.freq)
|
||||
freq = self.freq[gp_name] if isinstance(self.freq, dict) else self.freq
|
||||
df = D.features(
|
||||
instruments, exprs, start_time, end_time, freq=freq, inst_processors=self.inst_processor.get(gp_name, [])
|
||||
)
|
||||
df.columns = names
|
||||
if self.swap_level:
|
||||
df = df.swaplevel().sort_index() # NOTE: if swaplevel, return <datetime, instrument>
|
||||
@@ -199,6 +234,10 @@ class StaticDataLoader(DataLoader):
|
||||
self.join = join
|
||||
self._data = None
|
||||
|
||||
def __getstate__(self) -> dict:
|
||||
# avoid pickling `self._data`
|
||||
return {k: v for k, v in self.__dict__.items() if not k.startswith("_")}
|
||||
|
||||
def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame:
|
||||
self._maybe_load_raw_data()
|
||||
if instruments is None:
|
||||
|
||||
@@ -73,6 +73,14 @@ class Processor(Serializable):
|
||||
"""
|
||||
return True
|
||||
|
||||
def readonly(self) -> bool:
|
||||
"""
|
||||
Does the processor treat the input data readonly (i.e. does not write the input data) when processsing
|
||||
|
||||
Knowning the readonly information is helpful to the Handler to avoid uncessary copy
|
||||
"""
|
||||
return False
|
||||
|
||||
def config(self, **kwargs):
|
||||
attr_list = {"fit_start_time", "fit_end_time"}
|
||||
for k, v in kwargs.items():
|
||||
@@ -92,6 +100,9 @@ class DropnaProcessor(Processor):
|
||||
def __call__(self, df):
|
||||
return df.dropna(subset=get_group_columns(df, self.fields_group))
|
||||
|
||||
def readonly(self):
|
||||
return True
|
||||
|
||||
|
||||
class DropnaLabel(DropnaProcessor):
|
||||
def __init__(self, fields_group="label"):
|
||||
@@ -113,6 +124,9 @@ class DropCol(Processor):
|
||||
mask = df.columns.isin(self.col_list)
|
||||
return df.loc[:, ~mask]
|
||||
|
||||
def readonly(self):
|
||||
return True
|
||||
|
||||
|
||||
class FilterCol(Processor):
|
||||
def __init__(self, fields_group="feature", col_list=[]):
|
||||
@@ -128,6 +142,9 @@ class FilterCol(Processor):
|
||||
mask = df.columns.get_level_values(-1).isin(self.col_list)
|
||||
return df.loc[:, mask]
|
||||
|
||||
def readonly(self):
|
||||
return True
|
||||
|
||||
|
||||
class TanhProcess(Processor):
|
||||
"""Use tanh to process noise data"""
|
||||
|
||||
23
qlib/data/inst_processor.py
Normal file
23
qlib/data/inst_processor.py
Normal file
@@ -0,0 +1,23 @@
|
||||
import abc
|
||||
import json
|
||||
import pandas as pd
|
||||
|
||||
|
||||
class InstProcessor:
|
||||
@abc.abstractmethod
|
||||
def __call__(self, df: pd.DataFrame, *args, **kwargs):
|
||||
"""
|
||||
process the data
|
||||
|
||||
NOTE: **The processor could change the content of `df` inplace !!!!! **
|
||||
User should keep a copy of data outside
|
||||
|
||||
Parameters
|
||||
----------
|
||||
df : pd.DataFrame
|
||||
The raw_df of handler or result from previous processor.
|
||||
"""
|
||||
pass
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.__class__.__name__}:{json.dumps(self.__dict__, sort_keys=True, default=str)}"
|
||||
@@ -15,7 +15,7 @@ from scipy.stats import percentileofscore
|
||||
|
||||
from .base import Expression, ExpressionOps
|
||||
from ..log import get_module_logger
|
||||
from ..utils import get_cls_kwargs
|
||||
from ..utils import get_callable_kwargs
|
||||
|
||||
try:
|
||||
from ._libs.rolling import rolling_slope, rolling_rsquare, rolling_resi
|
||||
@@ -1513,7 +1513,7 @@ class OpsWrapper:
|
||||
"""
|
||||
for _operator in ops_list:
|
||||
if isinstance(_operator, dict):
|
||||
_ops_class, _ = get_cls_kwargs(_operator)
|
||||
_ops_class, _ = get_callable_kwargs(_operator)
|
||||
else:
|
||||
_ops_class = _operator
|
||||
|
||||
|
||||
@@ -105,6 +105,20 @@ class AverageEnsemble(Ensemble):
|
||||
"""
|
||||
|
||||
def __call__(self, ensemble_dict: dict) -> pd.DataFrame:
|
||||
"""using sample:
|
||||
from qlib.model.ens.ensemble import AverageEnsemble
|
||||
pred_res['new_key_name'] = AverageEnsemble()(predict_dict)
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ensemble_dict : dict
|
||||
Dictionary you want to ensemble
|
||||
|
||||
Returns
|
||||
-------
|
||||
pd.DataFrame
|
||||
The dictionary including ensenbling result
|
||||
"""
|
||||
# need to flatten the nested dict
|
||||
ensemble_dict = flatten_dict(ensemble_dict, sep=FLATTEN_TUPLE)
|
||||
values = list(ensemble_dict.values())
|
||||
|
||||
@@ -12,13 +12,12 @@ In ``DelayTrainer``, the first step is only to save some necessary info to model
|
||||
"""
|
||||
|
||||
import socket
|
||||
import time
|
||||
from typing import Callable, List
|
||||
|
||||
from qlib.data.dataset import Dataset
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.model.base import Model
|
||||
from qlib.utils import flatten_dict, get_cls_kwargs, init_instance_by_config
|
||||
from qlib.utils import flatten_dict, get_callable_kwargs, init_instance_by_config
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.record_temp import SignalRecord
|
||||
from qlib.workflow.recorder import Recorder
|
||||
@@ -72,7 +71,7 @@ def end_task_train(rec: Recorder, experiment_name: str) -> Recorder:
|
||||
if isinstance(records, dict): # prevent only one dict
|
||||
records = [records]
|
||||
for record in records:
|
||||
cls, kwargs = get_cls_kwargs(record, default_module="qlib.workflow.record_temp")
|
||||
cls, kwargs = get_callable_kwargs(record, default_module="qlib.workflow.record_temp")
|
||||
if cls is SignalRecord:
|
||||
rconf = {"model": model, "dataset": dataset, "recorder": rec}
|
||||
else:
|
||||
|
||||
@@ -43,8 +43,9 @@ def get_redis_connection():
|
||||
|
||||
|
||||
#################### Data ####################
|
||||
def read_bin(file_path, start_index, end_index):
|
||||
with open(file_path, "rb") as f:
|
||||
def read_bin(file_path: Union[str, Path], start_index, end_index):
|
||||
file_path = Path(file_path.expanduser().resolve())
|
||||
with file_path.open("rb") as f:
|
||||
# read start_index
|
||||
ref_start_index = int(np.frombuffer(f.read(4), dtype="<f")[0])
|
||||
si = max(ref_start_index, start_index)
|
||||
@@ -189,9 +190,9 @@ def get_module_by_module_path(module_path: Union[str, ModuleType]):
|
||||
return module
|
||||
|
||||
|
||||
def get_cls_kwargs(config: Union[dict, str], default_module: Union[str, ModuleType] = None) -> (type, dict):
|
||||
def get_callable_kwargs(config: Union[dict, str], default_module: Union[str, ModuleType] = None) -> (type, dict):
|
||||
"""
|
||||
extract class and kwargs from config info
|
||||
extract class/func and kwargs from config info
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -206,22 +207,22 @@ def get_cls_kwargs(config: Union[dict, str], default_module: Union[str, ModuleTy
|
||||
Returns
|
||||
-------
|
||||
(type, dict):
|
||||
the class object and it's arguments.
|
||||
the class/func object and it's arguments.
|
||||
"""
|
||||
if isinstance(config, dict):
|
||||
module = get_module_by_module_path(config.get("module_path", default_module))
|
||||
|
||||
# raise AttributeError
|
||||
klass = getattr(module, config["class"])
|
||||
_callable = getattr(module, config["class" if "class" in config else "func"])
|
||||
kwargs = config.get("kwargs", {})
|
||||
elif isinstance(config, str):
|
||||
module = get_module_by_module_path(default_module)
|
||||
|
||||
klass = getattr(module, config)
|
||||
_callable = getattr(module, config)
|
||||
kwargs = {}
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
return klass, kwargs
|
||||
return _callable, kwargs
|
||||
|
||||
|
||||
def init_instance_by_config(
|
||||
@@ -272,7 +273,7 @@ def init_instance_by_config(
|
||||
with open(os.path.join(pr.netloc, pr.path), "rb") as f:
|
||||
return pickle.load(f)
|
||||
|
||||
klass, cls_kwargs = get_cls_kwargs(config, default_module=default_module)
|
||||
klass, cls_kwargs = get_callable_kwargs(config, default_module=default_module)
|
||||
return klass(**cls_kwargs, **kwargs)
|
||||
|
||||
|
||||
@@ -570,9 +571,11 @@ def get_pre_trading_date(trading_date, future=False):
|
||||
|
||||
|
||||
def transform_end_date(end_date=None, freq="day"):
|
||||
"""get previous trading date
|
||||
"""handle the end date with various format
|
||||
|
||||
If end_date is -1, None, or end_date is greater than the maximum trading day, the last trading date is returned.
|
||||
Otherwise, returns the end_date
|
||||
|
||||
----------
|
||||
end_date: str
|
||||
end trading date
|
||||
@@ -738,7 +741,8 @@ def lazy_sort_index(df: pd.DataFrame, axis=0) -> pd.DataFrame:
|
||||
sorted dataframe
|
||||
"""
|
||||
idx = df.index if axis == 0 else df.columns
|
||||
if idx.is_monotonic_increasing:
|
||||
# NOTE: MultiIndex.is_lexsorted() is a deprecated method in Pandas 1.3.0 and is suggested to be replaced by MultiIndex.is_monotonic_increasing (see discussion here: https://github.com/pandas-dev/pandas/issues/32259). However, in case older versions of Pandas is implemented, MultiIndex.is_lexsorted() is necessary to prevent certain fatal errors.
|
||||
if idx.is_monotonic_increasing and not (isinstance(idx, pd.MultiIndex) and not idx.is_lexsorted()):
|
||||
return df
|
||||
else:
|
||||
return df.sort_index(axis=axis)
|
||||
@@ -792,7 +796,7 @@ class Wrapper:
|
||||
return "{name}(provider={provider})".format(name=self.__class__.__name__, provider=self._provider)
|
||||
|
||||
def __getattr__(self, key):
|
||||
if self._provider is None:
|
||||
if self.__dict__.get("_provider", None) is None:
|
||||
raise AttributeError("Please run qlib.init() first using qlib")
|
||||
return getattr(self._provider, key)
|
||||
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from pathlib import Path
|
||||
import pickle
|
||||
import typing
|
||||
import dill
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
|
||||
@@ -18,6 +17,7 @@ class Serializable:
|
||||
|
||||
pickle_backend = "pickle" # another optional value is "dill" which can pickle more things of python.
|
||||
default_dump_all = False # if dump all things
|
||||
FLAG_KEY = "_qlib_serial_flag"
|
||||
|
||||
def __init__(self):
|
||||
self._dump_all = self.default_dump_all
|
||||
@@ -45,8 +45,6 @@ class Serializable:
|
||||
"""
|
||||
return getattr(self, "_exclude", [])
|
||||
|
||||
FLAG_KEY = "_qlib_serial_flag"
|
||||
|
||||
def config(self, dump_all: bool = None, exclude: list = None, recursive=False):
|
||||
"""
|
||||
configure the serializable object
|
||||
@@ -124,3 +122,22 @@ class Serializable:
|
||||
return dill
|
||||
else:
|
||||
raise ValueError("Unknown pickle backend, please use 'pickle' or 'dill'.")
|
||||
|
||||
@staticmethod
|
||||
def general_dump(obj, path: Union[Path, str]):
|
||||
"""
|
||||
A general dumping method for object
|
||||
|
||||
Parameters
|
||||
----------
|
||||
obj : object
|
||||
the object to be dumped
|
||||
path : Union[Path, str]
|
||||
the target path the data will be dumped
|
||||
"""
|
||||
path = Path(path)
|
||||
if isinstance(obj, Serializable):
|
||||
obj.to_pickle(path)
|
||||
else:
|
||||
with path.open("wb") as f:
|
||||
pickle.dump(obj, f)
|
||||
|
||||
@@ -38,13 +38,13 @@ class QlibRecorder:
|
||||
.. code-block:: Python
|
||||
|
||||
# start new experiment and recorder
|
||||
with R.start('test', 'recorder_1'):
|
||||
with R.start(experiment_name='test', recorder_name='recorder_1'):
|
||||
model.fit(dataset)
|
||||
R.log...
|
||||
... # further operations
|
||||
|
||||
# resume previous experiment and recorder
|
||||
with R.start('test', 'recorder_1', resume=True): # if users want to resume recorder, they have to specify the exact same name for experiment and recorder.
|
||||
with R.start(experiment_name='test', recorder_name='recorder_1', resume=True): # if users want to resume recorder, they have to specify the exact same name for experiment and recorder.
|
||||
... # further operations
|
||||
|
||||
Parameters
|
||||
|
||||
@@ -53,7 +53,8 @@ def workflow(config_path, experiment_name="workflow", uri_folder="mlruns"):
|
||||
exp_manager["kwargs"]["uri"] = "file:" + str(Path(os.getcwd()).resolve() / uri_folder)
|
||||
qlib.init(**config.get("qlib_init"), exp_manager=exp_manager)
|
||||
|
||||
task_train(config.get("task"), experiment_name=experiment_name)
|
||||
recorder = task_train(config.get("task"), experiment_name=experiment_name)
|
||||
recorder.save_objects(config=config)
|
||||
|
||||
|
||||
# function to run worklflow by config
|
||||
|
||||
@@ -325,7 +325,7 @@ class MLflowExperiment(Experiment):
|
||||
|
||||
UNLIMITED = 50000 # FIXME: Mlflow can only list 50000 records at most!!!!!!!
|
||||
|
||||
def list_recorders(self, max_results: int = UNLIMITED, status: Union[str, None] = None):
|
||||
def list_recorders(self, max_results: int = UNLIMITED, status: Union[str, None] = None, filter_string: str = ""):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
@@ -334,8 +334,12 @@ class MLflowExperiment(Experiment):
|
||||
status : str
|
||||
the criteria based on status to filter results.
|
||||
`None` indicates no filtering.
|
||||
filter_string : str
|
||||
mlflow supported filter string like 'params."my_param"="a" and tags."my_tag"="b"', use this will help to reduce too much run number.
|
||||
"""
|
||||
runs = self._client.search_runs(self.id, run_view_type=ViewType.ACTIVE_ONLY, max_results=max_results)
|
||||
runs = self._client.search_runs(
|
||||
self.id, run_view_type=ViewType.ACTIVE_ONLY, max_results=max_results, filter_string=filter_string
|
||||
)
|
||||
recorders = dict()
|
||||
for i in range(len(runs)):
|
||||
recorder = MLflowRecorder(self.id, self._uri, mlflow_run=runs[i])
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from urllib.parse import urlparse
|
||||
import mlflow
|
||||
from filelock import FileLock
|
||||
from mlflow.exceptions import MlflowException
|
||||
from mlflow.entities import ViewType
|
||||
import os, logging
|
||||
@@ -191,6 +193,13 @@ class ExpManager:
|
||||
if experiment_name is None:
|
||||
experiment_name = self._default_exp_name
|
||||
logger.warning(f"No valid experiment found. Create a new experiment with name {experiment_name}.")
|
||||
|
||||
# NOTE: mlflow doesn't consider the lock for recording multiple runs
|
||||
# So we supported it in the interface wrapper
|
||||
pr = urlparse(self.uri)
|
||||
if pr.scheme == "file":
|
||||
with FileLock(os.path.join(pr.netloc, pr.path, "filelock")) as f:
|
||||
return self.create_exp(experiment_name), True
|
||||
return self.create_exp(experiment_name), True
|
||||
|
||||
def _get_exp(self, experiment_id=None, experiment_name=None) -> Experiment:
|
||||
|
||||
@@ -10,6 +10,7 @@ from typing import List, Tuple, Union
|
||||
from qlib.data.data import D
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.model.ens.group import RollingGroup
|
||||
from qlib.utils import transform_end_date
|
||||
from qlib.workflow.online.utils import OnlineTool, OnlineToolR
|
||||
from qlib.workflow.recorder import Recorder
|
||||
from qlib.workflow.task.collect import Collector, RecorderCollector
|
||||
@@ -118,6 +119,7 @@ class RollingStrategy(OnlineStrategy):
|
||||
task_template = [task_template]
|
||||
self.task_template = task_template
|
||||
self.rg = rolling_gen
|
||||
assert issubclass(self.rg.__class__, RollingGen), "The rolling strategy relies on the feature if RollingGen"
|
||||
self.tool = OnlineToolR(self.exp_name)
|
||||
self.ta = TimeAdjuster()
|
||||
|
||||
@@ -174,28 +176,20 @@ class RollingStrategy(OnlineStrategy):
|
||||
Returns:
|
||||
List[dict]: a list of new tasks.
|
||||
"""
|
||||
# TODO: filter recorders by latest test segments is not a necessary
|
||||
latest_records, max_test = self._list_latest(self.tool.online_models())
|
||||
if max_test is None:
|
||||
self.logger.warn(f"No latest online recorders, no new tasks.")
|
||||
return []
|
||||
calendar_latest = D.calendar(end_time=cur_time)[-1] if cur_time is None else cur_time
|
||||
calendar_latest = transform_end_date(cur_time)
|
||||
self.logger.info(
|
||||
f"The interval between current time {calendar_latest} and last rolling test begin time {max_test[0]} is {self.ta.cal_interval(calendar_latest, max_test[0])}, the rolling step is {self.rg.step}"
|
||||
)
|
||||
if self.ta.cal_interval(calendar_latest, max_test[0]) >= self.rg.step:
|
||||
old_tasks = []
|
||||
tasks_tmp = []
|
||||
for rec in latest_records:
|
||||
task = rec.load_object("task")
|
||||
old_tasks.append(deepcopy(task))
|
||||
test_begin = task["dataset"]["kwargs"]["segments"]["test"][0]
|
||||
# modify the test segment to generate new tasks
|
||||
task["dataset"]["kwargs"]["segments"]["test"] = (test_begin, calendar_latest)
|
||||
tasks_tmp.append(task)
|
||||
new_tasks_tmp = task_generator(tasks_tmp, self.rg)
|
||||
new_tasks = [task for task in new_tasks_tmp if task not in old_tasks]
|
||||
return new_tasks
|
||||
return []
|
||||
res = []
|
||||
for rec in latest_records:
|
||||
task = rec.load_object("task")
|
||||
res.extend(self.rg.gen_following_tasks(task, calendar_latest))
|
||||
return res
|
||||
|
||||
def _list_latest(self, rec_list: List[Recorder]):
|
||||
"""
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
Updater is a module to update artifacts such as predictions when the stock data is updating.
|
||||
"""
|
||||
@@ -10,11 +9,12 @@ from abc import ABCMeta, abstractmethod
|
||||
import pandas as pd
|
||||
from qlib import get_module_logger
|
||||
from qlib.data import D
|
||||
from qlib.data.dataset import DatasetH
|
||||
from qlib.data.dataset import Dataset, DatasetH
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
from qlib.model import Model
|
||||
from qlib.utils import get_date_by_shift
|
||||
from qlib.workflow.recorder import Recorder
|
||||
from qlib.workflow.record_temp import SignalRecord
|
||||
|
||||
|
||||
class RMDLoader:
|
||||
@@ -72,12 +72,25 @@ class RecordUpdater(metaclass=ABCMeta):
|
||||
...
|
||||
|
||||
|
||||
class PredUpdater(RecordUpdater):
|
||||
class DSBasedUpdater(RecordUpdater, metaclass=ABCMeta):
|
||||
"""
|
||||
Update the prediction in the Recorder
|
||||
Dataset-Based Updater
|
||||
- Provding updating feature for Updating data based on Qlib Dataset
|
||||
|
||||
Assumption
|
||||
- Based on Qlib dataset
|
||||
- The data to be updated is a multi-level index pd.DataFrame. For example label , prediction.
|
||||
|
||||
LABEL0
|
||||
datetime instrument
|
||||
2021-05-10 SH600000 0.006965
|
||||
SH600004 0.003407
|
||||
... ...
|
||||
2021-05-28 SZ300498 0.015748
|
||||
SZ300676 -0.001321
|
||||
"""
|
||||
|
||||
def __init__(self, record: Recorder, to_date=None, hist_ref: int = 0, freq="day"):
|
||||
def __init__(self, record: Recorder, to_date=None, hist_ref: int = 0, freq="day", fname="pred.pkl"):
|
||||
"""
|
||||
Init PredUpdater.
|
||||
|
||||
@@ -100,13 +113,27 @@ class PredUpdater(RecordUpdater):
|
||||
self.to_date = to_date
|
||||
self.hist_ref = hist_ref
|
||||
self.freq = freq
|
||||
self.fname = fname
|
||||
self.rmdl = RMDLoader(rec=record)
|
||||
|
||||
latest_date = D.calendar(freq=freq)[-1]
|
||||
if to_date == None:
|
||||
to_date = D.calendar(freq=freq)[-1]
|
||||
self.to_date = pd.Timestamp(to_date)
|
||||
self.old_pred = record.load_object("pred.pkl")
|
||||
self.last_end = self.old_pred.index.get_level_values("datetime").max()
|
||||
to_date = latest_date
|
||||
to_date = pd.Timestamp(to_date)
|
||||
|
||||
if to_date >= latest_date:
|
||||
self.logger.warning(
|
||||
f"The given `to_date`({to_date}) is later than `latest_date`({latest_date}). So `to_date` is clipped to `latest_date`."
|
||||
)
|
||||
to_date = latest_date
|
||||
self.to_date = to_date
|
||||
# FIXME: it will raise error when running routine with delay trainer
|
||||
# should we use another prediction updater for delay trainer?
|
||||
self.old_data: pd.DataFrame = record.load_object(fname)
|
||||
|
||||
# dropna is for being compatible to some data with future information(e.g. label)
|
||||
# The recent label data should be updated together
|
||||
self.last_end = self.old_data.dropna().index.get_level_values("datetime").max()
|
||||
|
||||
def prepare_data(self) -> DatasetH:
|
||||
"""
|
||||
@@ -125,7 +152,7 @@ class PredUpdater(RecordUpdater):
|
||||
|
||||
def update(self, dataset: DatasetH = None):
|
||||
"""
|
||||
Update the prediction in a recorder.
|
||||
Update the data in a recorder.
|
||||
|
||||
Args:
|
||||
DatasetH: the instance of DatasetH. None for reprepare.
|
||||
@@ -137,7 +164,7 @@ class PredUpdater(RecordUpdater):
|
||||
|
||||
if self.last_end >= self.to_date:
|
||||
self.logger.info(
|
||||
f"The prediction in {self.record.info['id']} are latest ({self.last_end}). No need to update to {self.to_date}."
|
||||
f"The data in {self.record.info['id']} are latest ({self.last_end}). No need to update to {self.to_date}."
|
||||
)
|
||||
return
|
||||
|
||||
@@ -146,14 +173,49 @@ class PredUpdater(RecordUpdater):
|
||||
# For reusing the dataset
|
||||
dataset = self.prepare_data()
|
||||
|
||||
self.record.save_objects(**{self.fname: self.get_update_data(dataset)})
|
||||
|
||||
@abstractmethod
|
||||
def get_update_data(self, dataset: Dataset) -> pd.DataFrame:
|
||||
"""
|
||||
return the updated data based on the given dataset
|
||||
|
||||
The difference between `get_update_data` and `update`
|
||||
- `update_date` only include some data specific feature
|
||||
- `update` include some general routine steps(e.g. prepare dataset, checking)
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class PredUpdater(DSBasedUpdater):
|
||||
"""
|
||||
Update the prediction in the Recorder
|
||||
"""
|
||||
|
||||
def get_update_data(self, dataset: Dataset) -> pd.DataFrame:
|
||||
# Load model
|
||||
model = self.rmdl.get_model()
|
||||
|
||||
new_pred: pd.Series = model.predict(dataset)
|
||||
|
||||
cb_pred = pd.concat([self.old_pred, new_pred.to_frame("score")], axis=0)
|
||||
cb_pred = pd.concat([self.old_data, new_pred.to_frame("score")], axis=0)
|
||||
cb_pred = cb_pred.sort_index()
|
||||
|
||||
self.record.save_objects(**{"pred.pkl": cb_pred})
|
||||
|
||||
self.logger.info(f"Finish updating new {new_pred.shape[0]} predictions in {self.record.info['id']}.")
|
||||
return cb_pred
|
||||
|
||||
|
||||
class LabelUpdater(DSBasedUpdater):
|
||||
"""
|
||||
Update the label in the recorder
|
||||
|
||||
Assumption
|
||||
- The label is generated from record_temp.SignalRecord.
|
||||
"""
|
||||
|
||||
def __init__(self, record: Recorder, to_date=None, **kwargs):
|
||||
super().__init__(record, to_date=to_date, fname="label.pkl", **kwargs)
|
||||
|
||||
def get_update_data(self, dataset: Dataset) -> pd.DataFrame:
|
||||
new_label = SignalRecord.generate_label(dataset)
|
||||
cb_data = pd.concat([self.old_data, new_label], axis=0)
|
||||
cb_data = cb_data[~cb_data.index.duplicated(keep="last")].sort_index()
|
||||
return cb_data
|
||||
|
||||
@@ -11,7 +11,7 @@ from typing import List, Union
|
||||
from qlib.data.dataset import TSDatasetH
|
||||
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.utils import get_cls_kwargs
|
||||
from qlib.utils import get_callable_kwargs
|
||||
from qlib.utils.exceptions import LoadObjectError
|
||||
from qlib.workflow.online.update import PredUpdater
|
||||
from qlib.workflow.recorder import Recorder
|
||||
@@ -172,7 +172,7 @@ class OnlineToolR(OnlineTool):
|
||||
hist_ref = 0
|
||||
task = rec.load_object("task")
|
||||
# Special treatment of historical dependencies
|
||||
cls, kwargs = get_cls_kwargs(task["dataset"], default_module="qlib.data.dataset")
|
||||
cls, kwargs = get_callable_kwargs(task["dataset"], default_module="qlib.data.dataset")
|
||||
if issubclass(cls, TSDatasetH):
|
||||
hist_ref = kwargs.get("step_len", TSDatasetH.DEFAULT_STEP_LEN)
|
||||
try:
|
||||
|
||||
@@ -121,6 +121,30 @@ class SignalRecord(RecordTemp):
|
||||
self.model = model
|
||||
self.dataset = dataset
|
||||
|
||||
@staticmethod
|
||||
def generate_label(dataset):
|
||||
# NOTE:
|
||||
# Python doesn't provide the downcasting mechanism.
|
||||
# We use the trick here to downcast the class
|
||||
orig_cls = dataset.__class__
|
||||
dataset.__class__ = DatasetH
|
||||
|
||||
params = dict(segments="test", col_set="label", data_key=DataHandlerLP.DK_R)
|
||||
try:
|
||||
# Assume the backend handler is DataHandlerLP
|
||||
raw_label = dataset.prepare(**params)
|
||||
except TypeError:
|
||||
# The argument number is not right
|
||||
del params["data_key"]
|
||||
# The backend handler should be DataHandler
|
||||
raw_label = dataset.prepare(**params)
|
||||
except AttributeError:
|
||||
# The data handler is initialize with `drop_raw=True`...
|
||||
# So raw_label is not available
|
||||
raw_label = None
|
||||
dataset.__class__ = orig_cls
|
||||
return raw_label
|
||||
|
||||
def generate(self, **kwargs):
|
||||
# generate prediciton
|
||||
pred = self.model.predict(self.dataset)
|
||||
@@ -136,28 +160,8 @@ class SignalRecord(RecordTemp):
|
||||
pprint(pred.head(5))
|
||||
|
||||
if isinstance(self.dataset, DatasetH):
|
||||
# NOTE:
|
||||
# Python doesn't provide the downcasting mechanism.
|
||||
# We use the trick here to downcast the class
|
||||
orig_cls = self.dataset.__class__
|
||||
self.dataset.__class__ = DatasetH
|
||||
|
||||
params = dict(segments="test", col_set="label", data_key=DataHandlerLP.DK_R)
|
||||
try:
|
||||
# Assume the backend handler is DataHandlerLP
|
||||
raw_label = self.dataset.prepare(**params)
|
||||
except TypeError:
|
||||
# The argument number is not right
|
||||
del params["data_key"]
|
||||
# The backend handler should be DataHandler
|
||||
raw_label = self.dataset.prepare(**params)
|
||||
except AttributeError:
|
||||
# The data handler is initialize with `drop_raw=True`...
|
||||
# So raw_label is not available
|
||||
raw_label = None
|
||||
|
||||
raw_label = self.generate_label(self.dataset)
|
||||
self.recorder.save_objects(**{"label.pkl": raw_label})
|
||||
self.dataset.__class__ = orig_cls
|
||||
|
||||
def list(self):
|
||||
return ["pred.pkl", "label.pkl"]
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from qlib.utils.serial import Serializable
|
||||
import mlflow, logging
|
||||
import shutil, os, pickle, tempfile, codecs, pickle
|
||||
from pathlib import Path
|
||||
@@ -299,12 +300,16 @@ class MLflowRecorder(Recorder):
|
||||
def save_objects(self, local_path=None, artifact_path=None, **kwargs):
|
||||
assert self.uri is not None, "Please start the experiment and recorder first before using recorder directly."
|
||||
if local_path is not None:
|
||||
self.client.log_artifacts(self.id, local_path, artifact_path)
|
||||
path = Path(local_path)
|
||||
if path.is_dir():
|
||||
self.client.log_artifacts(self.id, local_path, artifact_path)
|
||||
else:
|
||||
self.client.log_artifact(self.id, local_path, artifact_path)
|
||||
else:
|
||||
temp_dir = Path(tempfile.mkdtemp()).resolve()
|
||||
for name, data in kwargs.items():
|
||||
with (temp_dir / name).open("wb") as f:
|
||||
pickle.dump(data, f)
|
||||
path = temp_dir / name
|
||||
Serializable.general_dump(data, path)
|
||||
self.client.log_artifact(self.id, temp_dir / name, artifact_path)
|
||||
shutil.rmtree(temp_dir)
|
||||
|
||||
|
||||
@@ -139,6 +139,7 @@ class RecorderCollector(Collector):
|
||||
rec_filter_func=None,
|
||||
artifacts_path={"pred": "pred.pkl"},
|
||||
artifacts_key=None,
|
||||
list_kwargs={},
|
||||
):
|
||||
"""
|
||||
Init RecorderCollector.
|
||||
@@ -150,6 +151,7 @@ class RecorderCollector(Collector):
|
||||
rec_filter_func (Callable, optional): filter the recorder by return True or False. Defaults to None.
|
||||
artifacts_path (dict, optional): The artifacts name and its path in Recorder. Defaults to {"pred": "pred.pkl", "IC": "sig_analysis/ic.pkl"}.
|
||||
artifacts_key (str or List, optional): the artifacts key you want to get. If None, get all artifacts.
|
||||
list_kwargs (str): arguments for list_recorders function.
|
||||
"""
|
||||
super().__init__(process_list=process_list)
|
||||
if isinstance(experiment, str):
|
||||
@@ -163,6 +165,7 @@ class RecorderCollector(Collector):
|
||||
self.rec_key_func = rec_key_func
|
||||
self.artifacts_key = artifacts_key
|
||||
self.rec_filter_func = rec_filter_func
|
||||
self.list_kwargs = list_kwargs
|
||||
|
||||
def collect(self, artifacts_key=None, rec_filter_func=None, only_exist=True) -> dict:
|
||||
"""
|
||||
@@ -187,7 +190,7 @@ class RecorderCollector(Collector):
|
||||
|
||||
collect_dict = {}
|
||||
# filter records
|
||||
recs = self.experiment.list_recorders()
|
||||
recs = self.experiment.list_recorders(**self.list_kwargs)
|
||||
recs_flt = {}
|
||||
for rid, rec in recs.items():
|
||||
if rec_filter_func is None or rec_filter_func(rec):
|
||||
|
||||
@@ -5,6 +5,7 @@ TaskGenerator module can generate many tasks based on TaskGen and some task temp
|
||||
"""
|
||||
import abc
|
||||
import copy
|
||||
import pandas as pd
|
||||
from typing import List, Union, Callable
|
||||
|
||||
from qlib.utils import transform_end_date
|
||||
@@ -139,6 +140,53 @@ class RollingGen(TaskGen):
|
||||
self.test_key = "test"
|
||||
self.train_key = "train"
|
||||
|
||||
def _update_task_segs(self, task, segs):
|
||||
# update segments of this task
|
||||
task["dataset"]["kwargs"]["segments"] = copy.deepcopy(segs)
|
||||
if self.ds_extra_mod_func is not None:
|
||||
self.ds_extra_mod_func(task, self)
|
||||
|
||||
def gen_following_tasks(self, task: dict, test_end: pd.Timestamp) -> List[dict]:
|
||||
"""
|
||||
generating following rolling tasks for `task` until test_end
|
||||
|
||||
Parameters
|
||||
----------
|
||||
task : dict
|
||||
Qlib task format
|
||||
test_end : pd.Timestamp
|
||||
the latest rolling task includes `test_end`
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[dict]:
|
||||
the following tasks of `task`(`task` itself is excluded)
|
||||
"""
|
||||
prev_seg = task["dataset"]["kwargs"]["segments"]
|
||||
while True:
|
||||
segments = {}
|
||||
try:
|
||||
for k, seg in prev_seg.items():
|
||||
# decide how to shift
|
||||
# expanding only for train data, the segments size of test data and valid data won't change
|
||||
if k == self.train_key and self.rtype == self.ROLL_EX:
|
||||
rtype = self.ta.SHIFT_EX
|
||||
else:
|
||||
rtype = self.ta.SHIFT_SD
|
||||
# shift the segments data
|
||||
segments[k] = self.ta.shift(seg, step=self.step, rtype=rtype)
|
||||
if segments[self.test_key][0] > test_end:
|
||||
break
|
||||
except KeyError:
|
||||
# We reach the end of tasks
|
||||
# No more rolling
|
||||
break
|
||||
|
||||
prev_seg = segments
|
||||
t = copy.deepcopy(task) # deepcopy is necessary to avoid modify task inplace
|
||||
self._update_task_segs(t, segments)
|
||||
yield t
|
||||
|
||||
def generate(self, task: dict) -> List[dict]:
|
||||
"""
|
||||
Converting the task into a rolling task.
|
||||
@@ -191,43 +239,23 @@ class RollingGen(TaskGen):
|
||||
"""
|
||||
res = []
|
||||
|
||||
prev_seg = None
|
||||
test_end = None
|
||||
while True:
|
||||
t = copy.deepcopy(task)
|
||||
t = copy.deepcopy(task)
|
||||
|
||||
# calculate segments
|
||||
if prev_seg is None:
|
||||
# First rolling
|
||||
# 1) prepare the end point
|
||||
segments: dict = copy.deepcopy(self.ta.align_seg(t["dataset"]["kwargs"]["segments"]))
|
||||
test_end = transform_end_date(segments[self.test_key][1])
|
||||
# 2) and init test segments
|
||||
test_start_idx = self.ta.align_idx(segments[self.test_key][0])
|
||||
segments[self.test_key] = (self.ta.get(test_start_idx), self.ta.get(test_start_idx + self.step - 1))
|
||||
else:
|
||||
segments = {}
|
||||
try:
|
||||
for k, seg in prev_seg.items():
|
||||
# decide how to shift
|
||||
# expanding only for train data, the segments size of test data and valid data won't change
|
||||
if k == self.train_key and self.rtype == self.ROLL_EX:
|
||||
rtype = self.ta.SHIFT_EX
|
||||
else:
|
||||
rtype = self.ta.SHIFT_SD
|
||||
# shift the segments data
|
||||
segments[k] = self.ta.shift(seg, step=self.step, rtype=rtype)
|
||||
if segments[self.test_key][0] > test_end:
|
||||
break
|
||||
except KeyError:
|
||||
# We reach the end of tasks
|
||||
# No more rolling
|
||||
break
|
||||
# calculate segments
|
||||
|
||||
# update segments of this task
|
||||
t["dataset"]["kwargs"]["segments"] = copy.deepcopy(segments)
|
||||
prev_seg = segments
|
||||
if self.ds_extra_mod_func is not None:
|
||||
self.ds_extra_mod_func(t, self)
|
||||
res.append(t)
|
||||
# First rolling
|
||||
# 1) prepare the end point
|
||||
segments: dict = copy.deepcopy(self.ta.align_seg(t["dataset"]["kwargs"]["segments"]))
|
||||
test_end = transform_end_date(segments[self.test_key][1])
|
||||
# 2) and init test segments
|
||||
test_start_idx = self.ta.align_idx(segments[self.test_key][0])
|
||||
segments[self.test_key] = (self.ta.get(test_start_idx), self.ta.get(test_start_idx + self.step - 1))
|
||||
|
||||
# update segments of this task
|
||||
self._update_task_segs(t, segments)
|
||||
|
||||
res.append(t)
|
||||
|
||||
# Update the following rolling
|
||||
res.extend(self.gen_following_tasks(t, test_end))
|
||||
return res
|
||||
|
||||
@@ -47,6 +47,14 @@ class TaskManager:
|
||||
The tasks manager assumes that you will only update the tasks you fetched.
|
||||
The mongo fetch one and update will make it date updating secure.
|
||||
|
||||
This class can be used as a tool from commandline. Here are serveral examples
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
python -m qlib.workflow.task.manage -t <pool_name> wait
|
||||
python -m qlib.workflow.task.manage -t <pool_name> task_stat
|
||||
|
||||
|
||||
.. note::
|
||||
|
||||
Assumption: the data in MongoDB was encoded and the data out of MongoDB was decoded
|
||||
@@ -80,7 +88,7 @@ class TaskManager:
|
||||
task_pool: str
|
||||
the name of Collection in MongoDB
|
||||
"""
|
||||
self.task_pool = getattr(get_mongodb(), task_pool)
|
||||
self.task_pool: pymongo.collection.Collection = getattr(get_mongodb(), task_pool)
|
||||
self.logger = get_module_logger(self.__class__.__name__)
|
||||
|
||||
@staticmethod
|
||||
@@ -101,6 +109,20 @@ class TaskManager:
|
||||
return task
|
||||
|
||||
def _decode_task(self, task):
|
||||
"""
|
||||
_decode_task is Serialization tool.
|
||||
Mongodb needs JSON, so it needs to convert Python objects into JSON objects through pickle
|
||||
|
||||
Parameters
|
||||
----------
|
||||
task : dict
|
||||
task information
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict
|
||||
JSON required by mongodb
|
||||
"""
|
||||
for prefix in self.ENCODE_FIELDS_PREFIX:
|
||||
for k in list(task.keys()):
|
||||
if k.startswith(prefix):
|
||||
@@ -211,6 +233,7 @@ class TaskManager:
|
||||
r = self.task_pool.find_one({"filter": t})
|
||||
except InvalidDocument:
|
||||
r = self.task_pool.find_one({"filter": self._dict_to_str(t)})
|
||||
# When r is none, it indicates that r s a new task
|
||||
if r is None:
|
||||
new_tasks.append(t)
|
||||
if not dry_run:
|
||||
@@ -461,11 +484,11 @@ def run_task(
|
||||
|
||||
After running this method, here are 4 situations (before_status -> after_status):
|
||||
|
||||
STATUS_WAITING -> STATUS_DONE: use task["def"] as `task_func` param
|
||||
STATUS_WAITING -> STATUS_DONE: use task["def"] as `task_func` param, it means that the task has not been started
|
||||
|
||||
STATUS_WAITING -> STATUS_PART_DONE: use task["def"] as `task_func` param
|
||||
|
||||
STATUS_PART_DONE -> STATUS_PART_DONE: use task["res"] as `task_func` param
|
||||
STATUS_PART_DONE -> STATUS_PART_DONE: use task["res"] as `task_func` param, it means that the task has been started but not completed
|
||||
|
||||
STATUS_PART_DONE -> STATUS_DONE: use task["res"] as `task_func` param
|
||||
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import sys, traceback, signal, atexit, logging
|
||||
import atexit
|
||||
import logging
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
from ..log import get_module_logger
|
||||
from . import R
|
||||
from .recorder import Recorder
|
||||
from ..log import get_module_logger
|
||||
|
||||
logger = get_module_logger("workflow", logging.INFO)
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
|
||||
- [Download Qlib Data](#Download-Qlib-Data)
|
||||
- [Download CN Data](#Download-CN-Data)
|
||||
- [Downlaod US Data](#Downlaod-US-Data)
|
||||
- [Download US Data](#Download-US-Data)
|
||||
- [Download CN Simple Data](#Download-CN-Simple-Data)
|
||||
- [Help](#Help)
|
||||
- [Using in Qlib](#Using-in-Qlib)
|
||||
|
||||
@@ -78,6 +78,7 @@ def future_calendar_collector(qlib_dir: [str, Path], freq: str = "day"):
|
||||
data_list.append(_row_data[0])
|
||||
data_list = sorted(data_list)
|
||||
date_list = generate_qlib_calendar(data_list, freq=freq)
|
||||
date_list = sorted(set(daily_calendar.loc[:, 0].values.tolist() + date_list))
|
||||
write_calendar_to_qlib(qlib_dir, date_list, freq=freq)
|
||||
bs.logout()
|
||||
logger.info(f"get trading dates success: {start_year}-01-01 to {end_year}-12-31")
|
||||
|
||||
@@ -32,6 +32,7 @@ CALENDAR_BENCH_URL_MAP = {
|
||||
"ALL": CALENDAR_URL_BASE.format(market=1, bench_code="000905"),
|
||||
# NOTE: Use the time series of ^GSPC(SP500) as the sequence of all stocks
|
||||
"US_ALL": "^GSPC",
|
||||
"IN_ALL": "^NSEI",
|
||||
}
|
||||
|
||||
|
||||
@@ -39,6 +40,7 @@ _BENCH_CALENDAR_LIST = None
|
||||
_ALL_CALENDAR_LIST = None
|
||||
_HS_SYMBOLS = None
|
||||
_US_SYMBOLS = None
|
||||
_IN_SYMBOLS = None
|
||||
_EN_FUND_SYMBOLS = None
|
||||
_CALENDAR_MAP = {}
|
||||
|
||||
@@ -67,7 +69,7 @@ def get_calendar_list(bench_code="CSI300") -> List[pd.Timestamp]:
|
||||
|
||||
calendar = _CALENDAR_MAP.get(bench_code, None)
|
||||
if calendar is None:
|
||||
if bench_code.startswith("US_"):
|
||||
if bench_code.startswith("US_") or bench_code.startswith("IN_"):
|
||||
df = Ticker(CALENDAR_BENCH_URL_MAP[bench_code]).history(interval="1d", period="max")
|
||||
calendar = df.index.get_level_values(level="date").map(pd.Timestamp).unique().tolist()
|
||||
else:
|
||||
@@ -298,6 +300,47 @@ def get_us_stock_symbols(qlib_data_path: [str, Path] = None) -> list:
|
||||
return _US_SYMBOLS
|
||||
|
||||
|
||||
def get_in_stock_symbols(qlib_data_path: [str, Path] = None) -> list:
|
||||
"""get IN stock symbols
|
||||
|
||||
Returns
|
||||
-------
|
||||
stock symbols
|
||||
"""
|
||||
global _IN_SYMBOLS
|
||||
|
||||
@deco_retry
|
||||
def _get_nifty():
|
||||
url = f"https://www1.nseindia.com/content/equities/EQUITY_L.csv"
|
||||
df = pd.read_csv(url)
|
||||
df = df.rename(columns={"SYMBOL": "Symbol"})
|
||||
df["Symbol"] = df["Symbol"] + ".NS"
|
||||
_symbols = df["Symbol"].dropna()
|
||||
_symbols = _symbols.unique().tolist()
|
||||
return _symbols
|
||||
|
||||
if _IN_SYMBOLS is None:
|
||||
_all_symbols = _get_nifty()
|
||||
if qlib_data_path is not None:
|
||||
for _index in ["nifty"]:
|
||||
ins_df = pd.read_csv(
|
||||
Path(qlib_data_path).joinpath(f"instruments/{_index}.txt"),
|
||||
sep="\t",
|
||||
names=["symbol", "start_date", "end_date"],
|
||||
)
|
||||
_all_symbols += ins_df["symbol"].unique().tolist()
|
||||
|
||||
def _format(s_):
|
||||
s_ = s_.replace(".", "-")
|
||||
s_ = s_.strip("$")
|
||||
s_ = s_.strip("*")
|
||||
return s_
|
||||
|
||||
_IN_SYMBOLS = sorted(set(_all_symbols))
|
||||
|
||||
return _IN_SYMBOLS
|
||||
|
||||
|
||||
def get_en_fund_symbols(qlib_data_path: [str, Path] = None) -> list:
|
||||
"""get en fund symbols
|
||||
|
||||
|
||||
@@ -37,7 +37,7 @@ pip install -r requirements.txt
|
||||
- user can append data to `v2`: [automatic update of daily frequency data](#automatic-update-of-daily-frequency-datafrom-yahoo-finance)
|
||||
- **the [benchmarks](https://github.com/microsoft/qlib/tree/main/examples/benchmarks) for qlib use `v1`**, *due to the unstable access to historical data by YahooFinance, there are some differences between `v2` and `v1`*
|
||||
- `interval`: `1d` or `1min`, by default `1d`
|
||||
- `region`: `cn` or `us`, by default `cn`
|
||||
- `region`: `cn` or `us` or `in`, by default `cn`
|
||||
- `delete_old`: delete existing data from `target_dir`(*features, calendars, instruments, dataset_cache, features_cache*), value from [`True`, `False`], by default `True`
|
||||
- `exists_skip`: traget_dir data already exists, skip `get_data`, value from [`True`, `False`], by default `False`
|
||||
- examples:
|
||||
@@ -50,6 +50,10 @@ pip install -r requirements.txt
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/qlib_us_1d --region us --interval 1d
|
||||
# us 1min
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/qlib_us_1min --region us --interval 1min
|
||||
# in 1d
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/qlib_in_1d --region in --interval 1d
|
||||
# in 1min
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/qlib_in_1min --region in --interval 1min
|
||||
```
|
||||
|
||||
### Collector *YahooFinance* data to qlib
|
||||
@@ -60,7 +64,7 @@ pip install -r requirements.txt
|
||||
- `source_dir`: save the directory
|
||||
- `interval`: `1d` or `1min`, by default `1d`
|
||||
> **due to the limitation of the *YahooFinance API*, only the last month's data is available in `1min`**
|
||||
- `region`: `CN` or `US`, by default `CN`
|
||||
- `region`: `CN` or `US` or `IN`, by default `CN`
|
||||
- `delay`: `time.sleep(delay)`, by default *0.5*
|
||||
- `start`: start datetime, by default *"2000-01-01"*; *closed interval(including start)*
|
||||
- `end`: end datetime, by default `pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))`; *open interval(excluding end)*
|
||||
@@ -71,13 +75,17 @@ pip install -r requirements.txt
|
||||
- examples:
|
||||
```bash
|
||||
# cn 1d data
|
||||
python collector.py download_data --source_dir ~/.qlib/stock_data/source/cn_1d --start 2020-01-01 --end 2020-12-31 --delay 1 --interval 1d --region US
|
||||
python collector.py download_data --source_dir ~/.qlib/stock_data/source/cn_1d --start 2020-01-01 --end 2020-12-31 --delay 1 --interval 1d --region CN
|
||||
# cn 1min data
|
||||
python collector.py download_data --source_dir ~/.qlib/stock_data/source/cn_1min --delay 1 --interval 1min --region CN
|
||||
# us 1d data
|
||||
python collector.py download_data --source_dir ~/.qlib/stock_data/source/us_1d --start 2020-01-01 --end 2020-12-31 --delay 1 --interval 1d --region US
|
||||
# us 1min data
|
||||
python collector.py download_data --source_dir ~/.qlib/stock_data/source/us_1min --delay 1 --interval 1min --region US
|
||||
# in 1d data
|
||||
python collector.py download_data --source_dir ~/.qlib/stock_data/source/in_1d --start 2020-01-01 --end 2020-12-31 --delay 1 --interval 1d --region IN
|
||||
# in 1min data
|
||||
python collector.py download_data --source_dir ~/.qlib/stock_data/source/in_1min --delay 1 --interval 1min --region IN
|
||||
```
|
||||
2. normalize data: `python scripts/data_collector/yahoo/collector.py normalize_data`
|
||||
|
||||
@@ -87,7 +95,7 @@ pip install -r requirements.txt
|
||||
- `max_workers`: number of concurrent, by default *1*
|
||||
- `interval`: `1d` or `1min`, by default `1d`
|
||||
> if **`interval == 1min`**, `qlib_data_1d_dir` cannot be `None`
|
||||
- `region`: `CN` or `US`, by default `CN`
|
||||
- `region`: `CN` or `US` or `IN`, by default `CN`
|
||||
- `date_field_name`: column *name* identifying time in csv files, by default `date`
|
||||
- `symbol_field_name`: column *name* identifying symbol in csv files, by default `symbol`
|
||||
- `end_date`: if not `None`, normalize the last date saved (*including end_date*); if `None`, it will ignore this parameter; by default `None`
|
||||
|
||||
@@ -34,6 +34,7 @@ from data_collector.utils import (
|
||||
get_calendar_list,
|
||||
get_hs_stock_symbols,
|
||||
get_us_stock_symbols,
|
||||
get_in_stock_symbols,
|
||||
generate_minutes_calendar_from_daily,
|
||||
)
|
||||
|
||||
@@ -279,10 +280,46 @@ class YahooCollectorUS1min(YahooCollectorUS):
|
||||
pass
|
||||
|
||||
|
||||
class YahooCollectorIN(YahooCollector, ABC):
|
||||
def get_instrument_list(self):
|
||||
logger.info("get INDIA stock symbols......")
|
||||
symbols = get_in_stock_symbols()
|
||||
logger.info(f"get {len(symbols)} symbols.")
|
||||
return symbols
|
||||
|
||||
def download_index_data(self):
|
||||
pass
|
||||
|
||||
def normalize_symbol(self, symbol):
|
||||
return code_to_fname(symbol).upper()
|
||||
|
||||
@property
|
||||
def _timezone(self):
|
||||
return "Asia/Kolkata"
|
||||
|
||||
|
||||
class YahooCollectorIN1d(YahooCollectorIN):
|
||||
pass
|
||||
|
||||
|
||||
class YahooCollectorIN1min(YahooCollectorIN):
|
||||
pass
|
||||
|
||||
|
||||
class YahooNormalize(BaseNormalize):
|
||||
COLUMNS = ["open", "close", "high", "low", "volume"]
|
||||
DAILY_FORMAT = "%Y-%m-%d"
|
||||
|
||||
@staticmethod
|
||||
def calc_change(df: pd.DataFrame, last_close: float) -> pd.Series:
|
||||
df = df.copy()
|
||||
_tmp_series = df["close"].fillna(method="ffill")
|
||||
_tmp_shift_series = _tmp_series.shift(1)
|
||||
if last_close is not None:
|
||||
_tmp_shift_series.iloc[0] = float(last_close)
|
||||
change_series = _tmp_series / _tmp_shift_series - 1
|
||||
return change_series
|
||||
|
||||
@staticmethod
|
||||
def normalize_yahoo(
|
||||
df: pd.DataFrame,
|
||||
@@ -310,11 +347,29 @@ class YahooNormalize(BaseNormalize):
|
||||
)
|
||||
df.sort_index(inplace=True)
|
||||
df.loc[(df["volume"] <= 0) | np.isnan(df["volume"]), set(df.columns) - {symbol_field_name}] = np.nan
|
||||
_tmp_series = df["close"].fillna(method="ffill")
|
||||
_tmp_shift_series = _tmp_series.shift(1)
|
||||
if last_close is not None:
|
||||
_tmp_shift_series.iloc[0] = float(last_close)
|
||||
df["change"] = _tmp_series / _tmp_shift_series - 1
|
||||
|
||||
change_series = YahooNormalize.calc_change(df, last_close)
|
||||
# NOTE: The data obtained by Yahoo finance sometimes has exceptions
|
||||
# WARNING: If it is normal for a `symbol(exchange)` to differ by a factor of *89* to *111* for consecutive trading days,
|
||||
# WARNING: the logic in the following line needs to be modified
|
||||
_count = 0
|
||||
while True:
|
||||
# NOTE: may appear unusual for many days in a row
|
||||
change_series = YahooNormalize.calc_change(df, last_close)
|
||||
_mask = (change_series >= 89) & (change_series <= 111)
|
||||
if not _mask.any():
|
||||
break
|
||||
_tmp_cols = ["high", "close", "low", "open", "adjclose"]
|
||||
df.loc[_mask, _tmp_cols] = df.loc[_mask, _tmp_cols] / 100
|
||||
_count += 1
|
||||
if _count >= 10:
|
||||
_symbol = df.loc[df[symbol_field_name].first_valid_index()]["symbol"]
|
||||
logger.warning(
|
||||
f"{_symbol} `change` is abnormal for {_count} consecutive days, please check the specific data file carefully"
|
||||
)
|
||||
|
||||
df["change"] = YahooNormalize.calc_change(df, last_close)
|
||||
|
||||
columns += ["change"]
|
||||
df.loc[(df["volume"] <= 0) | np.isnan(df["volume"]), columns] = np.nan
|
||||
|
||||
@@ -710,6 +765,29 @@ class YahooNormalizeUS1min(YahooNormalizeUS, YahooNormalize1minOffline):
|
||||
return fname_to_code(symbol)
|
||||
|
||||
|
||||
class YahooNormalizeIN:
|
||||
def _get_calendar_list(self) -> Iterable[pd.Timestamp]:
|
||||
return get_calendar_list("IN_ALL")
|
||||
|
||||
|
||||
class YahooNormalizeIN1d(YahooNormalizeIN, YahooNormalize1d):
|
||||
pass
|
||||
|
||||
|
||||
class YahooNormalizeIN1min(YahooNormalizeIN, YahooNormalize1minOffline):
|
||||
CALC_PAUSED_NUM = False
|
||||
|
||||
def _get_calendar_list(self) -> Iterable[pd.Timestamp]:
|
||||
# TODO: support 1min
|
||||
raise ValueError("Does not support 1min")
|
||||
|
||||
def _get_1d_calendar_list(self):
|
||||
return get_calendar_list("IN_ALL")
|
||||
|
||||
def symbol_to_yahoo(self, symbol):
|
||||
return fname_to_code(symbol)
|
||||
|
||||
|
||||
class YahooNormalizeCN:
|
||||
def _get_calendar_list(self) -> Iterable[pd.Timestamp]:
|
||||
# TODO: from MSN
|
||||
@@ -852,7 +930,7 @@ class Run(BaseRun):
|
||||
if self.interval.lower() == "1min":
|
||||
if qlib_data_1d_dir is None or not Path(qlib_data_1d_dir).expanduser().exists():
|
||||
raise ValueError(
|
||||
"If normalize 1min, the qlib_data_1d_dir parameter must be set: --qlib_data_1d_dir <user qlib 1d data >, Reference: https://github.com/zhupr/qlib/tree/support_extend_data/scripts/data_collector/yahoo#automatic-update-of-daily-frequency-datafrom-yahoo-finance"
|
||||
"If normalize 1min, the qlib_data_1d_dir parameter must be set: --qlib_data_1d_dir <user qlib 1d data >, Reference: https://github.com/microsoft/qlib/tree/main/scripts/data_collector/yahoo#automatic-update-of-daily-frequency-datafrom-yahoo-finance"
|
||||
)
|
||||
super(Run, self).normalize_data(
|
||||
date_field_name, symbol_field_name, end_date=end_date, qlib_data_1d_dir=qlib_data_1d_dir
|
||||
|
||||
@@ -244,6 +244,10 @@ class DumpDataBase:
|
||||
if df is None or df.empty:
|
||||
logger.warning(f"{code} data is None or empty")
|
||||
return
|
||||
|
||||
# try to remove dup rows or it will cause exception when reindex.
|
||||
df = df.drop_duplicates(self.date_field_name)
|
||||
|
||||
# features save dir
|
||||
features_dir = self._features_dir.joinpath(code_to_fname(code).lower())
|
||||
features_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
15
setup.py
15
setup.py
@@ -11,7 +11,14 @@ NAME = "pyqlib"
|
||||
DESCRIPTION = "A Quantitative-research Platform"
|
||||
REQUIRES_PYTHON = ">=3.5.0"
|
||||
|
||||
VERSION = "0.7.0"
|
||||
from pathlib import Path
|
||||
from shutil import copyfile
|
||||
|
||||
CURRENT_DIR = Path(__file__).absolute().parent
|
||||
_version_src = CURRENT_DIR / "VERSION.txt"
|
||||
_version_dst = CURRENT_DIR / "qlib" / "VERSION.txt"
|
||||
copyfile(_version_src, _version_dst)
|
||||
VERSION = _version_dst.read_text(encoding="utf-8").strip()
|
||||
|
||||
# Detect Cython
|
||||
try:
|
||||
@@ -39,13 +46,13 @@ REQUIRED = [
|
||||
"redis>=3.0.1",
|
||||
"python-redis-lock>=3.3.1",
|
||||
"schedule>=0.6.0",
|
||||
"cvxpy==1.0.21",
|
||||
"cvxpy>=1.0.21",
|
||||
"hyperopt==0.1.1",
|
||||
"fire>=0.3.1",
|
||||
"statsmodels",
|
||||
"xlrd>=1.0.0",
|
||||
"plotly==4.12.0",
|
||||
"matplotlib==3.1.3",
|
||||
"matplotlib>=3.3",
|
||||
"tables>=3.6.1",
|
||||
"pyyaml>=5.3.1",
|
||||
"mlflow>=1.12.1",
|
||||
@@ -58,6 +65,7 @@ REQUIRED = [
|
||||
"pymongo==3.7.2", # For task management
|
||||
"scikit-learn>=0.22",
|
||||
"dill",
|
||||
"filelock",
|
||||
]
|
||||
|
||||
# Numpy include
|
||||
@@ -121,5 +129,6 @@ setup(
|
||||
"Programming Language :: Python :: 3.6",
|
||||
"Programming Language :: Python :: 3.7",
|
||||
"Programming Language :: Python :: 3.8",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
],
|
||||
)
|
||||
|
||||
117
tests/rolling_tests/test_update_pred.py
Normal file
117
tests/rolling_tests/test_update_pred.py
Normal file
@@ -0,0 +1,117 @@
|
||||
import copy
|
||||
import unittest
|
||||
|
||||
import fire
|
||||
import pandas as pd
|
||||
|
||||
import qlib
|
||||
from qlib.config import REG_CN
|
||||
from qlib.data import D
|
||||
from qlib.model.trainer import task_train
|
||||
from qlib.tests import TestAutoData
|
||||
from qlib.tests.config import CSI300_GBDT_TASK
|
||||
from qlib.workflow.online.utils import OnlineToolR
|
||||
from qlib.workflow.online.update import LabelUpdater
|
||||
|
||||
|
||||
class TestRolling(TestAutoData):
|
||||
_setup_kwargs = dict(expression_cache=None, dataset_cache=None)
|
||||
|
||||
def test_update_pred(self):
|
||||
"""
|
||||
This test is for testing if it will raise error if the `to_date` is out of the boundary.
|
||||
"""
|
||||
task = copy.deepcopy(CSI300_GBDT_TASK)
|
||||
|
||||
task["record"] = {
|
||||
"class": "SignalRecord",
|
||||
"module_path": "qlib.workflow.record_temp",
|
||||
}
|
||||
|
||||
exp_name = "online_srv_test"
|
||||
|
||||
cal = D.calendar()
|
||||
latest_date = cal[-1]
|
||||
|
||||
train_start = latest_date - pd.Timedelta(days=61)
|
||||
train_end = latest_date - pd.Timedelta(days=41)
|
||||
task["dataset"]["kwargs"]["segments"] = {
|
||||
"train": (train_start, train_end),
|
||||
"valid": (latest_date - pd.Timedelta(days=40), latest_date - pd.Timedelta(days=21)),
|
||||
"test": (latest_date - pd.Timedelta(days=20), latest_date),
|
||||
}
|
||||
|
||||
task["dataset"]["kwargs"]["handler"]["kwargs"] = {
|
||||
"start_time": train_start,
|
||||
"end_time": latest_date,
|
||||
"fit_start_time": train_start,
|
||||
"fit_end_time": train_end,
|
||||
"instruments": "csi300",
|
||||
}
|
||||
|
||||
rec = task_train(task, exp_name)
|
||||
|
||||
pred = rec.load_object("pred.pkl")
|
||||
|
||||
online_tool = OnlineToolR(exp_name)
|
||||
online_tool.reset_online_tag(rec) # set to online model
|
||||
|
||||
online_tool.update_online_pred(to_date=latest_date + pd.Timedelta(days=10))
|
||||
|
||||
def test_update_label(self):
|
||||
|
||||
task = copy.deepcopy(CSI300_GBDT_TASK)
|
||||
|
||||
task["record"] = {
|
||||
"class": "SignalRecord",
|
||||
"module_path": "qlib.workflow.record_temp",
|
||||
}
|
||||
|
||||
exp_name = "online_srv_test"
|
||||
|
||||
cal = D.calendar()
|
||||
shift = 10
|
||||
latest_date = cal[-1 - shift]
|
||||
|
||||
train_start = latest_date - pd.Timedelta(days=61)
|
||||
train_end = latest_date - pd.Timedelta(days=41)
|
||||
task["dataset"]["kwargs"]["segments"] = {
|
||||
"train": (train_start, train_end),
|
||||
"valid": (latest_date - pd.Timedelta(days=40), latest_date - pd.Timedelta(days=21)),
|
||||
"test": (latest_date - pd.Timedelta(days=20), latest_date),
|
||||
}
|
||||
|
||||
task["dataset"]["kwargs"]["handler"]["kwargs"] = {
|
||||
"start_time": train_start,
|
||||
"end_time": latest_date,
|
||||
"fit_start_time": train_start,
|
||||
"fit_end_time": train_end,
|
||||
"instruments": "csi300",
|
||||
}
|
||||
|
||||
rec = task_train(task, exp_name)
|
||||
|
||||
pred = rec.load_object("pred.pkl")
|
||||
|
||||
online_tool = OnlineToolR(exp_name)
|
||||
online_tool.reset_online_tag(rec) # set to online model
|
||||
online_tool.update_online_pred()
|
||||
|
||||
new_pred = rec.load_object("pred.pkl")
|
||||
label = rec.load_object("label.pkl")
|
||||
label_date = label.dropna().index.get_level_values("datetime").max()
|
||||
pred_date = new_pred.dropna().index.get_level_values("datetime").max()
|
||||
|
||||
# The prediction is updated, but the label is not updated.
|
||||
self.assertTrue(label_date < pred_date)
|
||||
|
||||
# Update label now
|
||||
lu = LabelUpdater(rec)
|
||||
lu.update()
|
||||
new_label = rec.load_object("label.pkl")
|
||||
new_label_date = new_label.index.get_level_values("datetime").max()
|
||||
self.assertTrue(new_label_date == pred_date) # make sure the label is updated now
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -5,7 +5,6 @@
|
||||
from pathlib import Path
|
||||
from collections.abc import Iterable
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
from qlib.tests import TestAutoData
|
||||
|
||||
@@ -33,13 +32,13 @@ class TestStorage(TestAutoData):
|
||||
print(f"calendar[-1]: {calendar[-1]}")
|
||||
|
||||
calendar = CalendarStorage(freq="1min", future=False, provider_uri="not_found")
|
||||
with pytest.raises(ValueError):
|
||||
with self.assertRaises(ValueError):
|
||||
print(calendar.data)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
with self.assertRaises(ValueError):
|
||||
print(calendar[:])
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
with self.assertRaises(ValueError):
|
||||
print(calendar[0])
|
||||
|
||||
def test_instrument_storage(self):
|
||||
@@ -90,10 +89,10 @@ class TestStorage(TestAutoData):
|
||||
print(f"instrument['SH600000']: {instrument['SH600000']}")
|
||||
|
||||
instrument = InstrumentStorage(market="csi300", provider_uri="not_found")
|
||||
with pytest.raises(ValueError):
|
||||
with self.assertRaises(ValueError):
|
||||
print(instrument.data)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
with self.assertRaises(ValueError):
|
||||
print(instrument["sSH600000"])
|
||||
|
||||
def test_feature_storage(self):
|
||||
@@ -150,15 +149,15 @@ class TestStorage(TestAutoData):
|
||||
|
||||
"""
|
||||
|
||||
feature = FeatureStorage(instrument="SH600004", field="close", freq="day", provider_uri=self.provider_uri)
|
||||
feature = FeatureStorage(instrument="SZ300677", field="close", freq="day", provider_uri=self.provider_uri)
|
||||
|
||||
with pytest.raises(IndexError):
|
||||
with self.assertRaises(IndexError):
|
||||
print(feature[0])
|
||||
assert isinstance(
|
||||
feature[815][1], (float, np.float32)
|
||||
feature[3049][1], (float, np.float32)
|
||||
), f"{feature.__class__.__name__}.__getitem__(i: int) error"
|
||||
assert len(feature[815:818]) == 3, f"{feature.__class__.__name__}.__getitem__(s: slice) error"
|
||||
print(f"feature[815: 818]: \n{feature[815: 818]}")
|
||||
assert len(feature[3049:3052]) == 3, f"{feature.__class__.__name__}.__getitem__(s: slice) error"
|
||||
print(f"feature[3049: 3052]: \n{feature[3049: 3052]}")
|
||||
|
||||
print(f"feature[:].tail(): \n{feature[:].tail()}")
|
||||
|
||||
|
||||
@@ -36,7 +36,7 @@ port_analysis_config = {
|
||||
}
|
||||
|
||||
|
||||
def train():
|
||||
def train(uri_path: str = None):
|
||||
"""train model
|
||||
|
||||
Returns
|
||||
@@ -55,7 +55,7 @@ def train():
|
||||
print(R)
|
||||
|
||||
# start exp
|
||||
with R.start(experiment_name="workflow"):
|
||||
with R.start(experiment_name="workflow", uri=uri_path):
|
||||
R.log_params(**flatten_dict(CSI300_GBDT_TASK))
|
||||
model.fit(dataset)
|
||||
|
||||
@@ -79,7 +79,7 @@ def train():
|
||||
return pred_score, {"ic": ic, "ric": ric}, rid
|
||||
|
||||
|
||||
def train_with_sigana():
|
||||
def train_with_sigana(uri_path: str = None):
|
||||
"""train model followed by SigAnaRecord
|
||||
|
||||
Returns
|
||||
@@ -91,9 +91,8 @@ def train_with_sigana():
|
||||
"""
|
||||
model = init_instance_by_config(CSI300_GBDT_TASK["model"])
|
||||
dataset = init_instance_by_config(CSI300_GBDT_TASK["dataset"])
|
||||
|
||||
# start exp
|
||||
with R.start(experiment_name="workflow_with_sigana"):
|
||||
with R.start(experiment_name="workflow_with_sigana", uri=uri_path):
|
||||
R.log_params(**flatten_dict(CSI300_GBDT_TASK))
|
||||
model.fit(dataset)
|
||||
|
||||
@@ -130,7 +129,7 @@ def fake_experiment():
|
||||
return default_uri == default_uri_to_check, current_uri == current_uri_to_check, current_uri
|
||||
|
||||
|
||||
def backtest_analysis(pred, rid):
|
||||
def backtest_analysis(pred, rid, uri_path: str = None):
|
||||
"""backtest and analysis
|
||||
|
||||
Parameters
|
||||
@@ -139,6 +138,8 @@ def backtest_analysis(pred, rid):
|
||||
predict scores
|
||||
rid : str
|
||||
the id of the recorder to be used in this function
|
||||
uri_path: str
|
||||
mlflow uri path
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -146,7 +147,8 @@ def backtest_analysis(pred, rid):
|
||||
the analysis result
|
||||
|
||||
"""
|
||||
recorder = R.get_recorder(experiment_name="workflow", recorder_id=rid)
|
||||
with R.start(experiment_name="workflow", recorder_id=rid, uri=uri_path):
|
||||
recorder = R.get_recorder(experiment_name="workflow", recorder_id=rid)
|
||||
# backtest
|
||||
par = PortAnaRecord(recorder, port_analysis_config)
|
||||
par.generate()
|
||||
@@ -160,24 +162,24 @@ class TestAllFlow(TestAutoData):
|
||||
REPORT_NORMAL = None
|
||||
POSITIONS = None
|
||||
RID = None
|
||||
URI_PATH = "file:" + str(Path(__file__).parent.joinpath("test_all_flow_mlruns").resolve())
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls) -> None:
|
||||
shutil.rmtree(str(Path(C["exp_manager"]["kwargs"]["uri"].strip("file:")).resolve()))
|
||||
shutil.rmtree(cls.URI_PATH.lstrip("file:"))
|
||||
|
||||
def test_0_train_with_sigana(self):
|
||||
TestAllFlow.PRED_SCORE, ic_ric, uri_path = train_with_sigana()
|
||||
TestAllFlow.PRED_SCORE, ic_ric, uri_path = train_with_sigana(self.URI_PATH)
|
||||
self.assertGreaterEqual(ic_ric["ic"].all(), 0, "train failed")
|
||||
self.assertGreaterEqual(ic_ric["ric"].all(), 0, "train failed")
|
||||
shutil.rmtree(str(Path(uri_path.strip("file:")).resolve()))
|
||||
|
||||
def test_1_train(self):
|
||||
TestAllFlow.PRED_SCORE, ic_ric, TestAllFlow.RID = train()
|
||||
TestAllFlow.PRED_SCORE, ic_ric, TestAllFlow.RID = train(self.URI_PATH)
|
||||
self.assertGreaterEqual(ic_ric["ic"].all(), 0, "train failed")
|
||||
self.assertGreaterEqual(ic_ric["ric"].all(), 0, "train failed")
|
||||
|
||||
def test_2_backtest(self):
|
||||
analyze_df = backtest_analysis(TestAllFlow.PRED_SCORE, TestAllFlow.RID)
|
||||
analyze_df = backtest_analysis(TestAllFlow.PRED_SCORE, TestAllFlow.RID, self.URI_PATH)
|
||||
self.assertGreaterEqual(
|
||||
analyze_df.loc(axis=0)["excess_return_with_cost", "annualized_return"].values[0],
|
||||
0.10,
|
||||
|
||||
@@ -12,10 +12,10 @@ from qlib.tests import TestAutoData
|
||||
from qlib.tests.config import CSI300_GBDT_TASK
|
||||
|
||||
|
||||
def train_multiseg():
|
||||
def train_multiseg(uri_path: str = None):
|
||||
model = init_instance_by_config(CSI300_GBDT_TASK["model"])
|
||||
dataset = init_instance_by_config(CSI300_GBDT_TASK["dataset"])
|
||||
with R.start(experiment_name="workflow"):
|
||||
with R.start(experiment_name="workflow", uri=uri_path):
|
||||
R.log_params(**flatten_dict(CSI300_GBDT_TASK))
|
||||
model.fit(dataset)
|
||||
recorder = R.get_recorder()
|
||||
@@ -25,10 +25,10 @@ def train_multiseg():
|
||||
return uri
|
||||
|
||||
|
||||
def train_mse():
|
||||
def train_mse(uri_path: str = None):
|
||||
model = init_instance_by_config(CSI300_GBDT_TASK["model"])
|
||||
dataset = init_instance_by_config(CSI300_GBDT_TASK["dataset"])
|
||||
with R.start(experiment_name="workflow"):
|
||||
with R.start(experiment_name="workflow", uri=uri_path):
|
||||
R.log_params(**flatten_dict(CSI300_GBDT_TASK))
|
||||
model.fit(dataset)
|
||||
recorder = R.get_recorder()
|
||||
@@ -39,13 +39,17 @@ def train_mse():
|
||||
|
||||
|
||||
class TestAllFlow(TestAutoData):
|
||||
URI_PATH = "file:" + str(Path(__file__).parent.joinpath("test_contrib_mlruns").resolve())
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls) -> None:
|
||||
shutil.rmtree(cls.URI_PATH.lstrip("file:"))
|
||||
|
||||
def test_0_multiseg(self):
|
||||
uri_path = train_multiseg()
|
||||
shutil.rmtree(str(Path(uri_path.strip("file:")).resolve()))
|
||||
uri_path = train_multiseg(self.URI_PATH)
|
||||
|
||||
def test_1_mse(self):
|
||||
uri_path = train_mse()
|
||||
shutil.rmtree(str(Path(uri_path.strip("file:")).resolve()))
|
||||
uri_path = train_mse(self.URI_PATH)
|
||||
|
||||
|
||||
def suite():
|
||||
|
||||
Reference in New Issue
Block a user