mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-29 09:01:18 +08:00
Compare commits
35 Commits
v0.7.1
...
high-freq-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
56edc16089 | ||
|
|
2b8462d137 | ||
|
|
1979cac50a | ||
|
|
424a48d0fb | ||
|
|
202bbea272 | ||
|
|
6a22136366 | ||
|
|
603c282415 | ||
|
|
22abe852f7 | ||
|
|
e3f463010b | ||
|
|
80aa08215f | ||
|
|
b3893067f7 | ||
|
|
e6dfccce2f | ||
|
|
f9c30f9834 | ||
|
|
f164bf8411 | ||
|
|
1f28044d84 | ||
|
|
3cf0d27a07 | ||
|
|
bcae4bb22e | ||
|
|
f680a564a0 | ||
|
|
9cd41e5a81 | ||
|
|
e23022e9d8 | ||
|
|
ebbbec2a6c | ||
|
|
13d39e6bbc | ||
|
|
b96aab6bef | ||
|
|
700eef4164 | ||
|
|
31c7d72485 | ||
|
|
30ad1967a2 | ||
|
|
0c6cad1d7b | ||
|
|
a0f22571de | ||
|
|
6835b2f67e | ||
|
|
7c4971e566 | ||
|
|
70a9d42c7d | ||
|
|
bcadf47f32 | ||
|
|
4dc14a2489 | ||
|
|
a03b08bb4c | ||
|
|
98086e4fdc |
@@ -1,12 +0,0 @@
|
||||
version = 1
|
||||
|
||||
test_patterns = ["tests/test_*.py"]
|
||||
|
||||
exclude_patterns = ["examples/**"]
|
||||
|
||||
[[analyzers]]
|
||||
name = "python"
|
||||
enabled = true
|
||||
|
||||
[analyzers.meta]
|
||||
runtime_version = "3.x.x"
|
||||
62
.github/stale.yml
vendored
Normal file
62
.github/stale.yml
vendored
Normal file
@@ -0,0 +1,62 @@
|
||||
# Configuration for probot-stale - https://github.com/probot/stale
|
||||
|
||||
# Number of days of inactivity before an Issue or Pull Request becomes stale
|
||||
daysUntilStale: 60
|
||||
|
||||
# Number of days of inactivity before an Issue or Pull Request with the stale label is closed.
|
||||
# Set to false to disable. If disabled, issues still need to be closed manually, but will remain marked as stale.
|
||||
daysUntilClose: 7
|
||||
|
||||
# Only issues or pull requests with all of these labels are check if stale. Defaults to `[]` (disabled)
|
||||
onlyLabels: []
|
||||
|
||||
# Issues or Pull Requests with these labels will never be considered stale. Set to `[]` to disable
|
||||
exemptLabels:
|
||||
- bug
|
||||
- pinned
|
||||
- security
|
||||
- "[Status] Maybe Later"
|
||||
|
||||
# Set to true to ignore issues in a project (defaults to false)
|
||||
exemptProjects: false
|
||||
|
||||
# Set to true to ignore issues in a milestone (defaults to false)
|
||||
exemptMilestones: false
|
||||
|
||||
# Set to true to ignore issues with an assignee (defaults to false)
|
||||
exemptAssignees: false
|
||||
|
||||
# Label to use when marking as stale
|
||||
staleLabel: wontfix
|
||||
|
||||
# Comment to post when marking as stale. Set to `false` to disable
|
||||
markComment: >
|
||||
This issue has been automatically marked as stale because it has not had
|
||||
recent activity. It will be closed if no further activity occurs. Thank you
|
||||
for your contributions.
|
||||
|
||||
# Comment to post when removing the stale label.
|
||||
# unmarkComment: >
|
||||
# Your comment here.
|
||||
|
||||
# Comment to post when closing a stale Issue or Pull Request.
|
||||
# closeComment: >
|
||||
# Your comment here.
|
||||
|
||||
# Limit the number of actions per hour, from 1-30. Default is 30
|
||||
limitPerRun: 30
|
||||
|
||||
# Limit to only `issues` or `pulls`
|
||||
# only: issues
|
||||
|
||||
# Optionally, specify configuration settings that are specific to just 'issues' or 'pulls':
|
||||
# pulls:
|
||||
# daysUntilStale: 30
|
||||
# markComment: >
|
||||
# This pull request has been automatically marked as stale because it has not had
|
||||
# recent activity. It will be closed if no further activity occurs. Thank you
|
||||
# for your contributions.
|
||||
|
||||
# issues:
|
||||
# exemptLabels:
|
||||
# - confirmed
|
||||
2
.github/workflows/python-publish.yml
vendored
2
.github/workflows/python-publish.yml
vendored
@@ -13,7 +13,7 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
os: [windows-latest, macos-latest]
|
||||
python-version: [3.6, 3.7, 3.8, 3.9]
|
||||
python-version: [3.6, 3.7, 3.8]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
|
||||
24
.github/workflows/stale.yml
vendored
24
.github/workflows/stale.yml
vendored
@@ -1,24 +0,0 @@
|
||||
name: Mark stale issues and pull requests
|
||||
|
||||
on:
|
||||
schedule:
|
||||
- cron: "0 0/3 * * *"
|
||||
|
||||
jobs:
|
||||
stale:
|
||||
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/stale@v3
|
||||
with:
|
||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
stale-issue-message: 'This issue is stale because it has been open for three months with no activity. Remove the stale label or comment on the issue otherwise this will be closed in 5 days'
|
||||
stale-pr-message: 'This PR is stale because it has been open for a year with no activity. Remove the stale label or comment on the PR otherwise this will be closed in 5 days'
|
||||
stale-issue-label: 'stale'
|
||||
stale-pr-label: 'stale'
|
||||
days-before-stale: 90
|
||||
days-before-close: 5
|
||||
operations-per-run: 100
|
||||
exempt-issue-labels: 'bug,enhancement'
|
||||
remove-stale-when-updated: true
|
||||
93
.github/workflows/test.yml
vendored
93
.github/workflows/test.yml
vendored
@@ -1,4 +1,4 @@
|
||||
name: Test
|
||||
name: Test
|
||||
|
||||
on:
|
||||
push:
|
||||
@@ -12,8 +12,8 @@ jobs:
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [windows-latest, ubuntu-16.04, ubuntu-18.04, ubuntu-20.04]
|
||||
python-version: [3.6, 3.7, 3.8]
|
||||
os: [windows-latest, ubuntu-16.04, ubuntu-18.04, ubuntu-20.04, macos-latest]
|
||||
python-version: [3.6, 3.7, 3.8, 3.9]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
@@ -25,41 +25,94 @@ jobs:
|
||||
|
||||
- name: Lint with Black
|
||||
run: |
|
||||
pip install --upgrade pip
|
||||
pip install black wheel
|
||||
black qlib -l 120 --check --diff
|
||||
cd ..
|
||||
if [ "$RUNNER_OS" == "Windows" ]; then
|
||||
$CONDA\\python.exe -m pip install black
|
||||
$CONDA\\python.exe -m black qlib -l 120 --check --diff
|
||||
else
|
||||
sudo $CONDA/bin/python -m pip install black
|
||||
$CONDA/bin/python -m black qlib -l 120 --check --diff
|
||||
fi
|
||||
shell: bash
|
||||
|
||||
# Test Qlib installed with pip
|
||||
- name: Install Qlib with pip
|
||||
run: |
|
||||
pip install numpy==1.19.5 ruamel.yaml
|
||||
pip install pyqlib --ignore-installed
|
||||
if [ "$RUNNER_OS" == "Windows" ]; then
|
||||
$CONDA\\python.exe -m pip install pyqlib --ignore-installed ruamel.yaml --user
|
||||
else
|
||||
sudo $CONDA/bin/python -m pip install pyqlib --ignore-installed ruamel.yaml
|
||||
fi
|
||||
shell: bash
|
||||
|
||||
- name: Install Lightgbm for MacOS
|
||||
if: runner.os == 'macOS'
|
||||
run: |
|
||||
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Microsoft/qlib/main/.github/brew_install.sh)"
|
||||
HOMEBREW_NO_AUTO_UPDATE=1 brew install lightgbm
|
||||
|
||||
- name: Test data downloads
|
||||
run: |
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
|
||||
if [ "$RUNNER_OS" == "Windows" ]; then
|
||||
$CONDA\\python.exe scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
|
||||
else
|
||||
$CONDA/bin/python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
|
||||
fi
|
||||
shell: bash
|
||||
|
||||
- name: Test workflow by config (install from pip)
|
||||
run: |
|
||||
python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
python -m pip uninstall -y pyqlib
|
||||
|
||||
# Test Qlib installed from source
|
||||
if [ "$RUNNER_OS" == "Windows" ]; then
|
||||
$CONDA\\python.exe qlib\\workflow\\cli.py examples\\benchmarks\\LightGBM\\workflow_config_lightgbm_Alpha158.yaml
|
||||
$CONDA\\python.exe -m pip uninstall -y pyqlib
|
||||
else
|
||||
$CONDA/bin/python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
sudo $CONDA/bin/python -m pip uninstall -y pyqlib
|
||||
fi
|
||||
shell: bash
|
||||
|
||||
# Test Qlib installed from source
|
||||
- name: Install Qlib from source
|
||||
run: |
|
||||
pip install --upgrade cython jupyter jupyter_contrib_nbextensions numpy scipy scikit-learn # installing without this line will cause errors on GitHub Actions, while instsalling locally won't
|
||||
pip install -e .
|
||||
if [ "$RUNNER_OS" == "Windows" ]; then
|
||||
$CONDA\\python.exe -m pip install --upgrade cython
|
||||
$CONDA\\python.exe -m pip install numpy jupyter jupyter_contrib_nbextensions
|
||||
$CONDA\\python.exe -m pip install -U scipy scikit-learn # installing without this line will cause errors on GitHub Actions, while instsalling locally won't
|
||||
$CONDA\\python.exe setup.py install
|
||||
else
|
||||
sudo $CONDA/bin/python -m pip install --upgrade cython
|
||||
sudo $CONDA/bin/python -m pip install numpy jupyter jupyter_contrib_nbextensions
|
||||
sudo $CONDA/bin/python -m pip install -U scipy scikit-learn # installing without this line will cause errors on GitHub Actions, while instsalling locally won't
|
||||
sudo $CONDA/bin/python setup.py install
|
||||
fi
|
||||
shell: bash
|
||||
|
||||
- name: Install test dependencies
|
||||
run: |
|
||||
pip install --upgrade pip
|
||||
pip install black pytest
|
||||
if [ "$RUNNER_OS" == "Windows" ]; then
|
||||
$CONDA\\python.exe -m pip install --upgrade pip
|
||||
$CONDA\\python.exe -m pip install black pytest
|
||||
else
|
||||
sudo $CONDA/bin/python -m pip install --upgrade pip
|
||||
sudo $CONDA/bin/python -m pip install black pytest
|
||||
fi
|
||||
shell: bash
|
||||
|
||||
- name: Unit tests with Pytest
|
||||
run: |
|
||||
cd tests
|
||||
python -m pytest . --durations=10
|
||||
if [ "$RUNNER_OS" == "Windows" ]; then
|
||||
$CONDA\\python.exe -m pytest . --durations=0
|
||||
else
|
||||
$CONDA/bin/python -m pytest . --durations=0
|
||||
fi
|
||||
shell: bash
|
||||
|
||||
- name: Test workflow by config (install from source)
|
||||
run: |
|
||||
python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
|
||||
if [ "$RUNNER_OS" == "Windows" ]; then
|
||||
$CONDA\\python.exe qlib\\workflow\\cli.py examples\\benchmarks\\LightGBM\\workflow_config_lightgbm_Alpha158.yaml
|
||||
else
|
||||
$CONDA/bin/python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
fi
|
||||
shell: bash
|
||||
67
.github/workflows/test_macos.yml
vendored
67
.github/workflows/test_macos.yml
vendored
@@ -1,67 +0,0 @@
|
||||
# There are some issues (in the downloading data phase) on MacOS when running with other tests. So we split it into an individual config.
|
||||
name: Test MacOS
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
|
||||
jobs:
|
||||
build:
|
||||
|
||||
runs-on: macos-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: [3.6, 3.7, 3.8]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Lint with Black
|
||||
run: |
|
||||
cd ..
|
||||
python -m pip install pip --upgrade
|
||||
python -m pip install wheel --upgrade
|
||||
python -m pip install black
|
||||
python -m black qlib -l 120 --check --diff
|
||||
# Test Qlib installed with pip
|
||||
- name: Install Qlib with pip
|
||||
run: |
|
||||
python -m pip install numpy==1.19.5
|
||||
python -m pip install pyqlib --ignore-installed ruamel.yaml numpy
|
||||
- name: Install Lightgbm for MacOS
|
||||
run: |
|
||||
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Microsoft/qlib/main/.github/brew_install.sh)"
|
||||
HOMEBREW_NO_AUTO_UPDATE=1 brew install lightgbm
|
||||
- name: Test data downloads
|
||||
run: |
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
|
||||
- name: Test workflow by config (install from pip)
|
||||
run: |
|
||||
python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
python -m pip uninstall -y pyqlib
|
||||
# Test Qlib installed from source
|
||||
- name: Install Qlib from source
|
||||
run: |
|
||||
python -m pip install --upgrade cython
|
||||
python -m pip install numpy jupyter jupyter_contrib_nbextensions
|
||||
python -m pip install -U scipy scikit-learn # installing without this line will cause errors on GitHub Actions, while instsalling locally won't
|
||||
python setup.py install
|
||||
- name: Install test dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install -U pyopenssl idna
|
||||
python -m pip install black pytest
|
||||
- name: Unit tests with Pytest
|
||||
run: |
|
||||
cd tests
|
||||
python -m pytest . --durations=0
|
||||
- name: Test workflow by config (install from source)
|
||||
run: |
|
||||
python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -34,7 +34,3 @@ tags
|
||||
|
||||
.pytest_cache/
|
||||
.vscode/
|
||||
|
||||
*.swp
|
||||
|
||||
./pretrain
|
||||
|
||||
136
README.md
136
README.md
@@ -7,24 +7,6 @@
|
||||
[](LICENSE)
|
||||
[](https://gitter.im/Microsoft/qlib?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
|
||||
|
||||
## :newspaper: **What's NEW!** :sparkling_heart:
|
||||
Recent released features
|
||||
| Feature | Status |
|
||||
| -- | ------ |
|
||||
|Temporal Routing Adaptor (TRA) | [Released](https://github.com/microsoft/qlib/pull/531) on July 30, 2021 |
|
||||
| Transformer & Localformer | [Released](https://github.com/microsoft/qlib/pull/508) on July 22, 2021 |
|
||||
| Release Qlib v0.7.0 | [Released](https://github.com/microsoft/qlib/releases/tag/v0.7.0) on July 12, 2021 |
|
||||
| TCTS Model | [Released](https://github.com/microsoft/qlib/pull/491) on July 1, 2021 |
|
||||
| Online serving and automatic model rolling | :star: [Released](https://github.com/microsoft/qlib/pull/290) on May 17, 2021 |
|
||||
| DoubleEnsemble Model | [Released](https://github.com/microsoft/qlib/pull/286) on Mar 2, 2021 |
|
||||
| High-frequency data processing example | [Released](https://github.com/microsoft/qlib/pull/257) on Feb 5, 2021 |
|
||||
| High-frequency trading example | [Part of code released](https://github.com/microsoft/qlib/pull/227) on Jan 28, 2021 |
|
||||
| High-frequency data(1min) | [Released](https://github.com/microsoft/qlib/pull/221) on Jan 27, 2021 |
|
||||
| Tabnet Model | [Released](https://github.com/microsoft/qlib/pull/205) on Jan 22, 2021 |
|
||||
|
||||
Features released before 2021 are not listed here.
|
||||
|
||||
|
||||
|
||||
<p align="center">
|
||||
<img src="http://fintech.msra.cn/images_v060/logo/1.png" />
|
||||
@@ -35,52 +17,41 @@ Qlib is an AI-oriented quantitative investment platform, which aims to realize t
|
||||
|
||||
It contains the full ML pipeline of data processing, model training, back-testing; and covers the entire chain of quantitative investment: alpha seeking, risk modeling, portfolio optimization, and order execution.
|
||||
|
||||
With Qlib, users can easily try ideas to create better Quant investment strategies.
|
||||
With Qlib, user can easily try ideas to create better Quant investment strategies.
|
||||
|
||||
For more details, please refer to our paper ["Qlib: An AI-oriented Quantitative Investment Platform"](https://arxiv.org/abs/2009.11189).
|
||||
|
||||
- [**Plans**](#plans)
|
||||
- [Framework of Qlib](#framework-of-qlib)
|
||||
- [Quick Start](#quick-start)
|
||||
- [Installation](#installation)
|
||||
- [Data Preparation](#data-preparation)
|
||||
- [Auto Quant Research Workflow](#auto-quant-research-workflow)
|
||||
- [Building Customized Quant Research Workflow by Code](#building-customized-quant-research-workflow-by-code)
|
||||
- [**Quant Model(Paper) Zoo**](#quant-model-paper-zoo)
|
||||
- [**Quant Model Zoo**](#quant-model-zoo)
|
||||
- [Run a single model](#run-a-single-model)
|
||||
- [Run multiple models](#run-multiple-models)
|
||||
- [**Quant Dataset Zoo**](#quant-dataset-zoo)
|
||||
- [High-frequency execution](#high-frequency-execution)
|
||||
- [More About Qlib](#more-about-qlib)
|
||||
- [Offline Mode and Online Mode](#offline-mode-and-online-mode)
|
||||
- [Performance of Qlib Data Server](#performance-of-qlib-data-server)
|
||||
- [Related Reports](#related-reports)
|
||||
- [Contact Us](#contact-us)
|
||||
- [Contributing](#contributing)
|
||||
|
||||
|
||||
# Plans
|
||||
New features under development(order by estimated release time).
|
||||
Your feedbacks about the features are very important.
|
||||
| Feature | Status |
|
||||
| -- | ------ |
|
||||
| Planning-based portfolio optimization | Under review: https://github.com/microsoft/qlib/pull/280 |
|
||||
| Fund data supporting and analysis | Under review: https://github.com/microsoft/qlib/pull/292 |
|
||||
| Point-in-Time database | Under review: https://github.com/microsoft/qlib/pull/343 |
|
||||
| High-frequency trading | Under review: https://github.com/microsoft/qlib/pull/408 |
|
||||
| Meta-Learning-based data selection | Initial opensource version under development |
|
||||
|
||||
# Framework of Qlib
|
||||
|
||||
<div style="align: center">
|
||||
<img src="http://fintech.msra.cn/images_v060/framework.png?v=0.2" />
|
||||
<img src="http://fintech.msra.cn/images_v060/framework.png?v=0.1" />
|
||||
</div>
|
||||
|
||||
|
||||
At the module level, Qlib is a platform that consists of the above components. The components are designed as loose-coupled modules, and each component could be used stand-alone.
|
||||
At the module level, Qlib is a platform that consists of the above components. The components are designed as loose-coupled modules and each component could be used stand-alone.
|
||||
|
||||
| Name | Description |
|
||||
| ------ | ----- |
|
||||
| `Infrastructure` layer | `Infrastructure` layer provides underlying support for Quant research. `DataServer` provides a high-performance infrastructure for users to manage and retrieve raw data. `Trainer` provides a flexible interface to control the training process of models, which enable algorithms to control the training process. |
|
||||
| `Infrastructure` layer | `Infrastructure` layer provides underlying support for Quant research. `DataServer` provides high-performance infrastructure for users to manage and retrieve raw data. `Trainer` provides flexible interface to control the training process of models which enable algorithms controlling the training process. |
|
||||
| `Workflow` layer | `Workflow` layer covers the whole workflow of quantitative investment. `Information Extractor` extracts data for models. `Forecast Model` focuses on producing all kinds of forecast signals (e.g. _alpha_, risk) for other modules. With these signals `Portfolio Generator` will generate the target portfolio and produce orders to be executed by `Order Executor`. |
|
||||
| `Interface` layer | `Interface` layer tries to present a user-friendly interface for the underlying system. `Analyser` module will provide users detailed analysis reports of forecasting signals, portfolios and execution results |
|
||||
|
||||
@@ -108,9 +79,8 @@ This table demonstrates the supported Python version of `Qlib`:
|
||||
| Python 3.9 | :x: | :heavy_check_mark: | :x: |
|
||||
|
||||
**Note**:
|
||||
1. **Conda** is suggested for managing your Python environment.
|
||||
1. Please pay attention that installing cython in Python 3.6 will raise some error when installing ``Qlib`` from source. If users use Python 3.6 on their machines, it is recommended to *upgrade* Python to version 3.7 or use `conda`'s Python to install ``Qlib`` from source.
|
||||
1. For Python 3.9, `Qlib` supports running workflows such as training models, doing backtest and plot most of the related figures (those included in [notebook](examples/workflow_by_code.ipynb)). However, plotting for the *model performance* is not supported for now and we will fix this when the dependent packages are upgraded in the future.
|
||||
2. For Python 3.9, `Qlib` supports running workflows such as training models, doing backtest and plot most of the related figures (those included in [notebook](examples/workflow_by_code.ipynb)). However, plotting for the *model performance* is not supported for now and we will fix this when the dependent packages are upgraded in the future.
|
||||
|
||||
### Install with pip
|
||||
Users can easily install ``Qlib`` by pip according to the following command.
|
||||
@@ -149,42 +119,14 @@ Also, users can install the latest dev version ``Qlib`` by the source code accor
|
||||
## Data Preparation
|
||||
Load and prepare data by running the following code:
|
||||
```bash
|
||||
# get 1d data
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn
|
||||
|
||||
# get 1min data
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data_1min --region cn --interval 1min
|
||||
|
||||
```
|
||||
|
||||
This dataset is created by public data collected by [crawler scripts](scripts/data_collector/), which have been released in
|
||||
the same repository.
|
||||
Users could create the same dataset with it.
|
||||
|
||||
*Please pay **ATTENTION** that the data is collected from [Yahoo Finance](https://finance.yahoo.com/lookup), and the data might not be perfect.
|
||||
We recommend users to prepare their own data if they have a high-quality dataset. For more information, users can refer to the [related document](https://qlib.readthedocs.io/en/latest/component/data.html#converting-csv-format-into-qlib-format)*.
|
||||
|
||||
### Automatic update of daily frequency data (from yahoo finance)
|
||||
> It is recommended that users update the data manually once (--trading_date 2021-05-25) and then set it to update automatically.
|
||||
|
||||
> For more information refer to: [yahoo collector](https://github.com/microsoft/qlib/tree/main/scripts/data_collector/yahoo#automatic-update-of-daily-frequency-datafrom-yahoo-finance)
|
||||
|
||||
* Automatic update of data to the "qlib" directory each trading day(Linux)
|
||||
* use *crontab*: `crontab -e`
|
||||
* set up timed tasks:
|
||||
|
||||
```
|
||||
* * * * 1-5 python <script path> update_data_to_bin --qlib_data_1d_dir <user data dir>
|
||||
```
|
||||
* **script path**: *scripts/data_collector/yahoo/collector.py*
|
||||
|
||||
* Manual update of data
|
||||
```
|
||||
python scripts/data_collector/yahoo/collector.py update_data_to_bin --qlib_data_1d_dir <user data dir> --trading_date <start date> --end_date <end date>
|
||||
```
|
||||
* *trading_date*: start of trading day
|
||||
* *end_date*: end of trading day(not included)
|
||||
|
||||
*Please pay **ATTENTION** that the data is collected from [Yahoo Finance](https://finance.yahoo.com/lookup) and the data might not be perfect. We recommend users to prepare their own data if they have high-quality dataset. For more information, users can refer to the [related document](https://qlib.readthedocs.io/en/latest/component/data.html#converting-csv-format-into-qlib-format)*.
|
||||
|
||||
<!--
|
||||
- Run the initialization code and get stock data:
|
||||
@@ -272,30 +214,25 @@ Qlib provides a tool named `qrun` to run the whole workflow automatically (inclu
|
||||
- Rank Label
|
||||

|
||||
-->
|
||||
- [Explanation](https://qlib.readthedocs.io/en/latest/component/report.html) of above results
|
||||
|
||||
## Building Customized Quant Research Workflow by Code
|
||||
The automatic workflow may not suit the research workflow of all Quant researchers. To support a flexible Quant research workflow, Qlib also provides a modularized interface to allow researchers to build their own workflow by code. [Here](examples/workflow_by_code.ipynb) is a demo for customized Quant research workflow by code.
|
||||
The automatic workflow may not suite the research workflow of all Quant researchers. To support a flexible Quant research workflow, Qlib also provides a modularized interface to allow researchers to build their own workflow by code. [Here](examples/workflow_by_code.ipynb) is a demo for customized Quant research workflow by code.
|
||||
|
||||
|
||||
# [Quant Model (Paper) Zoo](examples/benchmarks)
|
||||
# [Quant Model Zoo](examples/benchmarks)
|
||||
|
||||
Here is a list of models built on `Qlib`.
|
||||
- [GBDT based on XGBoost (Tianqi Chen, et al. KDD 2016)](qlib/contrib/model/xgboost.py)
|
||||
- [GBDT based on LightGBM (Guolin Ke, et al. NIPS 2017)](qlib/contrib/model/gbdt.py)
|
||||
- [GBDT based on Catboost (Liudmila Prokhorenkova, et al. NIPS 2018)](qlib/contrib/model/catboost_model.py)
|
||||
- [GBDT based on XGBoost (Tianqi Chen, et al. 2016)](qlib/contrib/model/xgboost.py)
|
||||
- [GBDT based on LightGBM (Guolin Ke, et al. 2017)](qlib/contrib/model/gbdt.py)
|
||||
- [GBDT based on Catboost (Liudmila Prokhorenkova, et al. 2017)](qlib/contrib/model/catboost_model.py)
|
||||
- [MLP based on pytorch](qlib/contrib/model/pytorch_nn.py)
|
||||
- [LSTM based on pytorch (Sepp Hochreiter, et al. Neural omputation 1997)](qlib/contrib/model/pytorch_lstm.py)
|
||||
- [LSTM based on pytorch (Sepp Hochreiter, et al. 1997)](qlib/contrib/model/pytorch_lstm.py)
|
||||
- [GRU based on pytorch (Kyunghyun Cho, et al. 2014)](qlib/contrib/model/pytorch_gru.py)
|
||||
- [ALSTM based on pytorch (Yao Qin, et al. IJCAI 2017)](qlib/contrib/model/pytorch_alstm.py)
|
||||
- [ALSTM based on pytorch (Yao Qin, et al. 2017)](qlib/contrib/model/pytorch_alstm.py)
|
||||
- [GATs based on pytorch (Petar Velickovic, et al. 2017)](qlib/contrib/model/pytorch_gats.py)
|
||||
- [SFM based on pytorch (Liheng Zhang, et al. KDD 2017)](qlib/contrib/model/pytorch_sfm.py)
|
||||
- [TFT based on tensorflow (Bryan Lim, et al. International Journal of Forecasting 2019)](examples/benchmarks/TFT/tft.py)
|
||||
- [TabNet based on pytorch (Sercan O. Arik, et al. AAAI 2019)](qlib/contrib/model/pytorch_tabnet.py)
|
||||
- [DoubleEnsemble based on LightGBM (Chuheng Zhang, et al. ICDM 2020)](qlib/contrib/model/double_ensemble.py)
|
||||
- [TCTS based on pytorch (Xueqing Wu, et al. ICML 2021)](qlib/contrib/model/pytorch_tcts.py)
|
||||
- [Transformer based on pytorch (Ashish Vaswani, et al. NeurIPS 2017)](qlib/contrib/model/pytorch_transformer.py)
|
||||
- [Localformer based on pytorch (Juyong Jiang, et al.)](qlib/contrib/model/pytorch_localformer.py)
|
||||
- [SFM based on pytorch (Liheng Zhang, et al. 2017)](qlib/contrib/model/pytorch_sfm.py)
|
||||
- [TFT based on tensorflow (Bryan Lim, et al. 2019)](examples/benchmarks/TFT/tft.py)
|
||||
- [TabNet based on pytorch (Sercan O. Arik, et al. 2019)](qlib/contrib/model/pytorch_tabnet.py)
|
||||
|
||||
Your PR of new Quant models is highly welcomed.
|
||||
|
||||
@@ -305,13 +242,13 @@ The performance of each model on the `Alpha158` and `Alpha360` dataset can be fo
|
||||
All the models listed above are runnable with ``Qlib``. Users can find the config files we provide and some details about the model through the [benchmarks](examples/benchmarks) folder. More information can be retrieved at the model files listed above.
|
||||
|
||||
`Qlib` provides three different ways to run a single model, users can pick the one that fits their cases best:
|
||||
- Users can use the tool `qrun` mentioned above to run a model's workflow based from a config file.
|
||||
- Users can create a `workflow_by_code` python script based on the [one](examples/workflow_by_code.py) listed in the `examples` folder.
|
||||
- User can use the tool `qrun` mentioned above to run a model's workflow based from a config file.
|
||||
- User can create a `workflow_by_code` python script based on the [one](examples/workflow_by_code.py) listed in the `examples` folder.
|
||||
|
||||
- Users can use the script [`run_all_model.py`](examples/run_all_model.py) listed in the `examples` folder to run a model. Here is an example of the specific shell command to be used: `python run_all_model.py --models=lightgbm`, where the `--models` arguments can take any number of models listed above(the available models can be found in [benchmarks](examples/benchmarks/)). For more use cases, please refer to the file's [docstrings](examples/run_all_model.py).
|
||||
- User can use the script [`run_all_model.py`](examples/run_all_model.py) listed in the `examples` folder to run a model. Here is an example of the specific shell command to be used: `python run_all_model.py --models=lightgbm`, where the `--models` arguments can take any number of models listed above(the available models can be found in [benchmarks](examples/benchmarks/)). For more use cases, please refer to the file's [docstrings](examples/run_all_model.py).
|
||||
|
||||
## Run multiple models
|
||||
`Qlib` also provides a script [`run_all_model.py`](examples/run_all_model.py) which can run multiple models for several iterations. (**Note**: the script only support *Linux* for now. Other OS will be supported in the future. Besides, it doesn't support parallel running the same model for multiple times as well, and this will be fixed in the future development too.)
|
||||
`Qlib` also provides a script [`run_all_model.py`](examples/run_all_model.py) which can run multiple models for several iterations. (**Note**: the script only support *Linux* for now. Other OS will be supported in the future. Besides, it doesn't support parrallel running the same model for multiple times as well, and this will be fixed in the future development too.)
|
||||
|
||||
The script will create a unique virtual environment for each model, and delete the environments after training. Thus, only experiment results such as `IC` and `backtest` results will be generated and stored.
|
||||
|
||||
@@ -334,6 +271,14 @@ Dataset plays a very important role in Quant. Here is a list of the datasets bui
|
||||
[Here](https://qlib.readthedocs.io/en/latest/advanced/alpha.html) is a tutorial to build dataset with `Qlib`.
|
||||
Your PR to build new Quant dataset is highly welcomed.
|
||||
|
||||
# High-Frequency Execution
|
||||
High-frequency order execution is a fundamental problem in quantitative finance.
|
||||
It aims at fulfilling a specific trading order, either liquidation or acquirement, for a given instrument.
|
||||
AI has the potential to mine patterns from a huge mass of high-frequency market data and helps traders make better decisions during order execution.
|
||||
Here is a list of solutions built on `Qlib`.
|
||||
- [Universal Trading for Order Execution with Oracle Policy Distillation](examples/trade/)
|
||||
|
||||
|
||||
# More About Qlib
|
||||
The detailed documents are organized in [docs](docs/).
|
||||
[Sphinx](http://www.sphinx-doc.org) and the readthedocs theme is required to build the documentation in html formats.
|
||||
@@ -371,34 +316,21 @@ which creates a dataset (14 features/factors) from the basic OHLCV daily data of
|
||||
* `+(-)E` indicates with (out) `ExpressionCache`
|
||||
* `+(-)D` indicates with (out) `DatasetCache`
|
||||
|
||||
Most general-purpose databases take too much time to load data. After looking into the underlying implementation, we find that data go through too many layers of interfaces and unnecessary format transformations in general-purpose database solutions.
|
||||
Most general-purpose databases take too much time on loading data. After looking into the underlying implementation, we find that data go through too many layers of interfaces and unnecessary format transformations in general-purpose database solutions.
|
||||
Such overheads greatly slow down the data loading process.
|
||||
Qlib data are stored in a compact format, which is efficient to be combined into arrays for scientific computation.
|
||||
|
||||
|
||||
# Related Reports
|
||||
- [Guide To Qlib: Microsoft’s AI Investment Platform](https://analyticsindiamag.com/qlib/)
|
||||
- [【华泰金工林晓明团队】微软AI量化投资平台Qlib体验——华泰人工智能系列之四十](https://mp.weixin.qq.com/s/Brcd7im4NibJOJzZfMn6tQ)
|
||||
- [微软也搞AI量化平台?还是开源的!](https://mp.weixin.qq.com/s/47bP5YwxfTp2uTHjUBzJQQ)
|
||||
- [微矿Qlib:业内首个AI量化投资开源平台](https://mp.weixin.qq.com/s/vsJv7lsgjEi-ALYUz4CvtQ)
|
||||
|
||||
# Contact Us
|
||||
- If you have any issues, please create issue [here](https://github.com/microsoft/qlib/issues/new/choose) or send messages in [gitter](https://gitter.im/Microsoft/qlib).
|
||||
- If you want to make contributions to `Qlib`, please [create pull requests](https://github.com/microsoft/qlib/compare).
|
||||
- For other reasons, you are welcome to contact us by email([qlib@microsoft.com](mailto:qlib@microsoft.com)).
|
||||
- We are recruiting new members(both FTEs and interns), your resumes are welcome!
|
||||
|
||||
Join IM discussion groups:
|
||||
|[Gitter](https://gitter.im/Microsoft/qlib)|
|
||||
|----|
|
||||
||
|
||||
|
||||
# Contributing
|
||||
|
||||
This project welcomes contributions and suggestions.
|
||||
**Here are some
|
||||
[code standards](docs/developer/code_standard.rst) when you submit a pull request.**
|
||||
|
||||
|
||||
Most contributions require you to agree to a
|
||||
This project welcomes contributions and suggestions. Most contributions require you to agree to a
|
||||
Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
|
||||
the right to use your contribution. For details, visit https://cla.opensource.microsoft.com.
|
||||
|
||||
|
||||
@@ -70,84 +70,3 @@ If the issue is not resolved, use ``keys *`` to find if multiple keys exist. If
|
||||
|
||||
|
||||
Also, feel free to post a new issue in our GitHub repository. We always check each issue carefully and try our best to solve them.
|
||||
|
||||
3. ModuleNotFoundError: No module named 'qlib.data._libs.rolling'
|
||||
------------------------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
#### Do not import qlib package in the repository directory in case of importing qlib from . without compiling #####
|
||||
Traceback (most recent call last):
|
||||
File "<stdin>", line 1, in <module>
|
||||
File "qlib/qlib/__init__.py", line 19, in init
|
||||
from .data.cache import H
|
||||
File "qlib/qlib/data/__init__.py", line 8, in <module>
|
||||
from .data import (
|
||||
File "qlib/qlib/data/data.py", line 20, in <module>
|
||||
from .cache import H
|
||||
File "qlib/qlib/data/cache.py", line 36, in <module>
|
||||
from .ops import Operators
|
||||
File "qlib/qlib/data/ops.py", line 19, in <module>
|
||||
from ._libs.rolling import rolling_slope, rolling_rsquare, rolling_resi
|
||||
ModuleNotFoundError: No module named 'qlib.data._libs.rolling'
|
||||
|
||||
- If the error occurs when importing ``qlib`` package with ``PyCharm`` IDE, users can execute the following command in the project root folder to compile Cython files and generate executable files:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python setup.py build_ext --inplace
|
||||
|
||||
- If the error occurs when importing ``qlib`` package with command ``python`` , users need to change the running directory to ensure that the script does not run in the project directory.
|
||||
|
||||
|
||||
4. BadNamespaceError: / is not a connected namespace
|
||||
------------------------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
File "qlib_online.py", line 35, in <module>
|
||||
cal = D.calendar()
|
||||
File "e:\code\python\microsoft\qlib_latest\qlib\qlib\data\data.py", line 973, in calendar
|
||||
return Cal.calendar(start_time, end_time, freq, future=future)
|
||||
File "e:\code\python\microsoft\qlib_latest\qlib\qlib\data\data.py", line 798, in calendar
|
||||
self.conn.send_request(
|
||||
File "e:\code\python\microsoft\qlib_latest\qlib\qlib\data\client.py", line 101, in send_request
|
||||
self.sio.emit(request_type + "_request", request_content)
|
||||
File "G:\apps\miniconda\envs\qlib\lib\site-packages\python_socketio-5.3.0-py3.8.egg\socketio\client.py", line 369, in emit
|
||||
raise exceptions.BadNamespaceError(
|
||||
BadNamespaceError: / is not a connected namespace.
|
||||
|
||||
- The version of ``python-socketio`` in qlib needs to be the same as the version of ``python-socketio`` in qlib-server:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install -U python-socketio==<qlib-server python-socketio version>
|
||||
|
||||
|
||||
5. TypeError: send() got an unexpected keyword argument 'binary'
|
||||
------------------------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
File "qlib_online.py", line 35, in <module>
|
||||
cal = D.calendar()
|
||||
File "e:\code\python\microsoft\qlib_latest\qlib\qlib\data\data.py", line 973, in calendar
|
||||
return Cal.calendar(start_time, end_time, freq, future=future)
|
||||
File "e:\code\python\microsoft\qlib_latest\qlib\qlib\data\data.py", line 798, in calendar
|
||||
self.conn.send_request(
|
||||
File "e:\code\python\microsoft\qlib_latest\qlib\qlib\data\client.py", line 101, in send_request
|
||||
self.sio.emit(request_type + "_request", request_content)
|
||||
File "G:\apps\miniconda\envs\qlib\lib\site-packages\socketio\client.py", line 263, in emit
|
||||
self._send_packet(packet.Packet(packet.EVENT, namespace=namespace,
|
||||
File "G:\apps\miniconda\envs\qlib\lib\site-packages\socketio\client.py", line 339, in _send_packet
|
||||
self.eio.send(ep, binary=binary)
|
||||
TypeError: send() got an unexpected keyword argument 'binary'
|
||||
|
||||
|
||||
- The ``python-engineio`` version needs to be compatible with the ``python-socketio`` version, reference: https://github.com/miguelgrinberg/python-socketio#version-compatibility
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install -U python-engineio==<compatible python-socketio version>
|
||||
# or
|
||||
pip install -U python-socketio==3.1.2 python-engineio==3.13.2
|
||||
|
||||
BIN
docs/_static/img/framework.png
vendored
BIN
docs/_static/img/framework.png
vendored
Binary file not shown.
|
Before Width: | Height: | Size: 208 KiB After Width: | Height: | Size: 271 KiB |
BIN
docs/_static/img/online_serving.png
vendored
BIN
docs/_static/img/online_serving.png
vendored
Binary file not shown.
|
Before Width: | Height: | Size: 440 KiB |
BIN
docs/_static/img/qrcode/gitter_qr.png
vendored
BIN
docs/_static/img/qrcode/gitter_qr.png
vendored
Binary file not shown.
|
Before Width: | Height: | Size: 7.2 KiB |
@@ -1,45 +0,0 @@
|
||||
.. _serial:
|
||||
|
||||
=================================
|
||||
Serialization
|
||||
=================================
|
||||
.. currentmodule:: qlib
|
||||
|
||||
Introduction
|
||||
===================
|
||||
``Qlib`` supports dumping the state of ``DataHandler``, ``DataSet``, ``Processor`` and ``Model``, etc. into a disk and reloading them.
|
||||
|
||||
Serializable Class
|
||||
========================
|
||||
|
||||
``Qlib`` provides a base class ``qlib.utils.serial.Serializable``, whose state can be dumped into or loaded from disk in `pickle` format.
|
||||
When users dump the state of a ``Serializable`` instance, the attributes of the instance whose name **does not** start with `_` will be saved on the disk.
|
||||
However, users can use ``config`` method or override ``default_dump_all`` attribute to prevent this feature.
|
||||
|
||||
Users can also override ``pickle_backend`` attribute to choose a pickle backend. The supported value is "pickle" (default and common) and "dill" (dump more things such as function, more information in `here <https://pypi.org/project/dill/>`_).
|
||||
|
||||
Example
|
||||
==========================
|
||||
``Qlib``'s serializable class includes ``DataHandler``, ``DataSet``, ``Processor`` and ``Model``, etc., which are subclass of ``qlib.utils.serial.Serializable``.
|
||||
Specifically, ``qlib.data.dataset.DatasetH`` is one of them. Users can serialize ``DatasetH`` as follows.
|
||||
|
||||
.. code-block:: Python
|
||||
|
||||
##=============dump dataset=============
|
||||
dataset.to_pickle(path="dataset.pkl") # dataset is an instance of qlib.data.dataset.DatasetH
|
||||
|
||||
##=============reload dataset=============
|
||||
with open("dataset.pkl", "rb") as file_dataset:
|
||||
dataset = pickle.load(file_dataset)
|
||||
|
||||
.. note::
|
||||
Only state of ``DatasetH`` should be saved on the disk, such as some `mean` and `variance` used for data normalization, etc.
|
||||
|
||||
After reloading the ``DatasetH``, users need to reinitialize it. It means that users can reset some states of ``DatasetH`` or ``QlibDataHandler`` such as `instruments`, `start_time`, `end_time` and `segments`, etc., and generate new data according to the states (data is not state and should not be saved on the disk).
|
||||
|
||||
A more detailed example is in this `link <https://github.com/microsoft/qlib/tree/main/examples/highfreq>`_.
|
||||
|
||||
|
||||
API
|
||||
===================
|
||||
Please refer to `Serializable API <../reference/api.html#module-qlib.utils.serial.Serializable>`_.
|
||||
@@ -1,89 +0,0 @@
|
||||
.. _task_management:
|
||||
|
||||
=================================
|
||||
Task Management
|
||||
=================================
|
||||
.. currentmodule:: qlib
|
||||
|
||||
|
||||
Introduction
|
||||
=============
|
||||
|
||||
The `Workflow <../component/introduction.html>`_ part introduces how to run research workflow in a loosely-coupled way. But it can only execute one ``task`` when you use ``qrun``.
|
||||
To automatically generate and execute different tasks, ``Task Management`` provides a whole process including `Task Generating`_, `Task Storing`_, `Task Training`_ and `Task Collecting`_.
|
||||
With this module, users can run their ``task`` automatically at different periods, in different losses, or even by different models.
|
||||
|
||||
This whole process can be used in `Online Serving <../component/online.html>`_.
|
||||
|
||||
An example of the entire process is shown `here <https://github.com/microsoft/qlib/tree/main/examples/model_rolling/task_manager_rolling.py>`_.
|
||||
|
||||
Task Generating
|
||||
===============
|
||||
A ``task`` consists of `Model`, `Dataset`, `Record`, or anything added by users.
|
||||
The specific task template can be viewed in
|
||||
`Task Section <../component/workflow.html#task-section>`_.
|
||||
Even though the task template is fixed, users can customize their ``TaskGen`` to generate different ``task`` by task template.
|
||||
|
||||
Here is the base class of ``TaskGen``:
|
||||
|
||||
.. autoclass:: qlib.workflow.task.gen.TaskGen
|
||||
:members:
|
||||
|
||||
``Qlib`` provides a class `RollingGen <https://github.com/microsoft/qlib/tree/main/qlib/workflow/task/gen.py>`_ to generate a list of ``task`` of the dataset in different date segments.
|
||||
This class allows users to verify the effect of data from different periods on the model in one experiment. More information is `here <../reference/api.html#TaskGen>`_.
|
||||
|
||||
Task Storing
|
||||
===============
|
||||
To achieve higher efficiency and the possibility of cluster operation, ``Task Manager`` will store all tasks in `MongoDB <https://www.mongodb.com/>`_.
|
||||
``TaskManager`` can fetch undone tasks automatically and manage the lifecycle of a set of tasks with error handling.
|
||||
Users **MUST** finish the configuration of `MongoDB <https://www.mongodb.com/>`_ when using this module.
|
||||
|
||||
Users need to provide the MongoDB URL and database name for using ``TaskManager`` in `initialization <../start/initialization.html#Parameters>`_ or make a statement like this.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from qlib.config import C
|
||||
C["mongo"] = {
|
||||
"task_url" : "mongodb://localhost:27017/", # your MongoDB url
|
||||
"task_db_name" : "rolling_db" # database name
|
||||
}
|
||||
|
||||
.. autoclass:: qlib.workflow.task.manage.TaskManager
|
||||
:members:
|
||||
|
||||
More information of ``Task Manager`` can be found in `here <../reference/api.html#TaskManager>`_.
|
||||
|
||||
Task Training
|
||||
===============
|
||||
After generating and storing those ``task``, it's time to run the ``task`` which is in the *WAITING* status.
|
||||
``Qlib`` provides a method called ``run_task`` to run those ``task`` in task pool, however, users can also customize how tasks are executed.
|
||||
An easy way to get the ``task_func`` is using ``qlib.model.trainer.task_train`` directly.
|
||||
It will run the whole workflow defined by ``task``, which includes *Model*, *Dataset*, *Record*.
|
||||
|
||||
.. autofunction:: qlib.workflow.task.manage.run_task
|
||||
|
||||
Meanwhile, ``Qlib`` provides a module called ``Trainer``.
|
||||
|
||||
.. autoclass:: qlib.model.trainer.Trainer
|
||||
:members:
|
||||
|
||||
``Trainer`` will train a list of tasks and return a list of model recorders.
|
||||
``Qlib`` offer two kinds of Trainer, TrainerR is the simplest way and TrainerRM is based on TaskManager to help manager tasks lifecycle automatically.
|
||||
If you do not want to use ``Task Manager`` to manage tasks, then use TrainerR to train a list of tasks generated by ``TaskGen`` is enough.
|
||||
`Here <../reference/api.html#Trainer>`_ are the details about different ``Trainer``.
|
||||
|
||||
Task Collecting
|
||||
===============
|
||||
To collect the results of ``task`` after training, ``Qlib`` provides `Collector <../reference/api.html#Collector>`_, `Group <../reference/api.html#Group>`_ and `Ensemble <../reference/api.html#Ensemble>`_ to collect the results in a readable, expandable and loosely-coupled way.
|
||||
|
||||
`Collector <../reference/api.html#Collector>`_ can collect objects from everywhere and process them such as merging, grouping, averaging and so on. It has 2 step action including ``collect`` (collect anything in a dict) and ``process_collect`` (process collected dict).
|
||||
|
||||
`Group <../reference/api.html#Group>`_ also has 2 steps including ``group`` (can group a set of object based on `group_func` and change them to a dict) and ``reduce`` (can make a dict become an ensemble based on some rule).
|
||||
For example: {(A,B,C1): object, (A,B,C2): object} ---``group``---> {(A,B): {C1: object, C2: object}} ---``reduce``---> {(A,B): object}
|
||||
|
||||
`Ensemble <../reference/api.html#Ensemble>`_ can merge the objects in an ensemble.
|
||||
For example: {C1: object, C2: object} ---``Ensemble``---> object
|
||||
|
||||
So the hierarchy is ``Collector``'s second step corresponds to ``Group``. And ``Group``'s second step correspond to ``Ensemble``.
|
||||
|
||||
For more information, please see `Collector <../reference/api.html#Collector>`_, `Group <../reference/api.html#Group>`_ and `Ensemble <../reference/api.html#Ensemble>`_, or the `example <https://github.com/microsoft/qlib/tree/main/examples/model_rolling/task_manager_rolling.py>`_.
|
||||
@@ -31,7 +31,7 @@ Qlib Format Data
|
||||
We've specially designed a data structure to manage financial data, please refer to the `File storage design section in Qlib paper <https://arxiv.org/abs/2009.11189>`_ for detailed information.
|
||||
Such data will be stored with filename suffix `.bin` (We'll call them `.bin` file, `.bin` format, or qlib format). `.bin` file is designed for scientific computing on finance data.
|
||||
|
||||
``Qlib`` provides two different off-the-shelf datasets, which can be accessed through this `link <https://github.com/microsoft/qlib/blob/main/qlib/contrib/data/handler.py>`_:
|
||||
``Qlib`` provides two different off-the-shelf dataset, which can be accessed through this `link <https://github.com/microsoft/qlib/blob/main/qlib/contrib/data/handler.py>`_:
|
||||
|
||||
======================== ================= ================
|
||||
Dataset US Market China Market
|
||||
@@ -41,7 +41,6 @@ Alpha360 √ √
|
||||
Alpha158 √ √
|
||||
======================== ================= ================
|
||||
|
||||
Also, ``Qlib`` provides a high-frequency dataset. Users can run a high-frequency dataset example through this `link <https://github.com/microsoft/qlib/tree/main/examples/highfreq>`_.
|
||||
|
||||
Qlib Format Dataset
|
||||
--------------------
|
||||
@@ -49,70 +48,31 @@ Qlib Format Dataset
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
# download 1d
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn
|
||||
|
||||
# download 1min
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/qlib_cn_1min --region cn --interval 1min
|
||||
|
||||
In addition to China-Stock data, ``Qlib`` also includes a US-Stock dataset, which can be downloaded with the following command:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/us_data --region us
|
||||
|
||||
After running the above command, users can find china-stock and us-stock data in ``Qlib`` format in the ``~/.qlib/qlib_data/cn_data`` directory and ``~/.qlib/qlib_data/us_data`` directory respectively.
|
||||
After running the above command, users can find china-stock and us-stock data in ``Qlib`` format in the ``~/.qlib/csv_data/cn_data`` directory and ``~/.qlib/csv_data/us_data`` directory respectively.
|
||||
|
||||
``Qlib`` also provides the scripts in ``scripts/data_collector`` to help users crawl the latest data on the Internet and convert it to qlib format.
|
||||
|
||||
When ``Qlib`` is initialized with this dataset, users could build and evaluate their own models with it. Please refer to `Initialization <../start/initialization.html>`_ for more details.
|
||||
|
||||
Automatic update of daily frequency data
|
||||
----------------------------------------
|
||||
|
||||
**It is recommended that users update the data manually once (\-\-trading_date 2021-05-25) and then set it to update automatically.**
|
||||
|
||||
For more information refer to: `yahoo collector <https://github.com/microsoft/qlib/tree/main/scripts/data_collector/yahoo#Automatic-update-of-daily-frequency-data>`_
|
||||
|
||||
- Automatic update of data to the "qlib" directory each trading day(Linux)
|
||||
- use *crontab*: `crontab -e`
|
||||
- set up timed tasks:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
* * * * 1-5 python <script path> update_data_to_bin --qlib_data_1d_dir <user data dir>
|
||||
|
||||
- **script path**: *scripts/data_collector/yahoo/collector.py*
|
||||
|
||||
- Manual update of data
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python scripts/data_collector/yahoo/collector.py update_data_to_bin --qlib_data_1d_dir <user data dir> --trading_date <start date> --end_date <end date>
|
||||
|
||||
- *trading_date*: start of trading day
|
||||
- *end_date*: end of trading day(not included)
|
||||
|
||||
|
||||
|
||||
Converting CSV Format into Qlib Format
|
||||
-------------------------------------------
|
||||
|
||||
``Qlib`` has provided the script ``scripts/dump_bin.py`` to convert **any** data in CSV format into `.bin` files (``Qlib`` format) as long as they are in the correct format.
|
||||
|
||||
Besides downloading the prepared demo data, users could download demo data directly from the Collector as follows for reference to the CSV format.
|
||||
Here are some example:
|
||||
Users can download the demo china-stock data in CSV format as follows for reference to the CSV format.
|
||||
|
||||
for daily data:
|
||||
.. code-block:: bash
|
||||
.. code-block:: bash
|
||||
|
||||
python scripts/get_data.py csv_data_cn --target_dir ~/.qlib/csv_data/cn_data
|
||||
|
||||
for 1min data:
|
||||
.. code-block:: bash
|
||||
|
||||
python scripts/data_collector/yahoo/collector.py download_data --source_dir ~/.qlib/stock_data/source/cn_1min --region CN --start 2021-05-20 --end 2021-05-23 --delay 0.1 --interval 1min --limit_nums 10
|
||||
|
||||
Users can also provide their own data in CSV format. However, the CSV data **must satisfies** following criterions:
|
||||
|
||||
- CSV file is named after a specific stock *or* the CSV file includes a column of the stock name
|
||||
@@ -179,17 +139,6 @@ After conversion, users can find their Qlib format data in the directory `~/.qli
|
||||
The Restoration factor. Normally, ``factor = adjusted_price / original_price``, `adjusted price` reference: `split adjusted <https://www.investopedia.com/terms/s/splitadjusted.asp>`_
|
||||
|
||||
In the convention of `Qlib` data processing, `open, close, high, low, volume, money and factor` will be set to NaN if the stock is suspended.
|
||||
If you want to use your own alpha-factor which can't be calculate by OCHLV, like PE, EPS and so on, you could add it to the CSV files with OHCLV together and then dump it to the Qlib format data.
|
||||
|
||||
Stock Pool (Market)
|
||||
--------------------------------
|
||||
|
||||
``Qlib`` defines `stock pool <https://github.com/microsoft/qlib/blob/main/examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml#L4>`_ as stock list and their date ranges. Predefined stock pools (e.g. csi300) may be imported as follows.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python collector.py --index_name CSI300 --qlib_dir <user qlib data dir> --method parse_instruments
|
||||
|
||||
|
||||
Multiple Stock Modes
|
||||
--------------------------------
|
||||
@@ -209,7 +158,7 @@ The `trade unit` defines the unit number of stocks can be used in a trade, and t
|
||||
- If users use ``Qlib`` in china-stock mode, china-stock data is required. Users can use ``Qlib`` in china-stock mode according to the following steps:
|
||||
- Download china-stock in qlib format, please refer to section `Qlib Format Dataset <#qlib-format-dataset>`_.
|
||||
- Initialize ``Qlib`` in china-stock mode
|
||||
Supposed that users download their Qlib format data in the directory ``~/.qlib/qlib_data/cn_data``. Users only need to initialize ``Qlib`` as follows.
|
||||
Supposed that users download their Qlib format data in the directory ``~/.qlib/csv_data/cn_data``. Users only need to initialize ``Qlib`` as follows.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@@ -218,9 +167,9 @@ The `trade unit` defines the unit number of stocks can be used in a trade, and t
|
||||
|
||||
|
||||
- If users use ``Qlib`` in US-stock mode, US-stock data is required. ``Qlib`` also provides a script to download US-stock data. Users can use ``Qlib`` in US-stock mode according to the following steps:
|
||||
- Download us-stock in qlib format, please refer to section `Qlib Format Dataset <#qlib-format-dataset>`_.
|
||||
- Download china-stock in qlib format, please refer to section `Qlib Format Dataset <#qlib-format-dataset>`_.
|
||||
- Initialize ``Qlib`` in US-stock mode
|
||||
Supposed that users prepare their Qlib format data in the directory ``~/.qlib/qlib_data/us_data``. Users only need to initialize ``Qlib`` as follows.
|
||||
Supposed that users prepare their Qlib format data in the directory ``~/.qlib/csv_data/us_data``. Users only need to initialize ``Qlib`` as follows.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@@ -228,11 +177,6 @@ The `trade unit` defines the unit number of stocks can be used in a trade, and t
|
||||
qlib.init(provider_uri='~/.qlib/qlib_data/us_data', region=REG_US)
|
||||
|
||||
|
||||
.. note::
|
||||
|
||||
PRs for new data source are highly welcome! Users could commit the code to crawl data as a PR like `the examples here <https://github.com/microsoft/qlib/tree/main/scripts>`_. And then we will use the code to create data cache on our server which other users could use directly.
|
||||
|
||||
|
||||
Data API
|
||||
========================
|
||||
|
||||
@@ -269,25 +213,6 @@ Filter
|
||||
- `cross-sectional features filter` \: rule_expression = '$rank($close)<10'
|
||||
- `time-sequence features filter`: rule_expression = '$Ref($close, 3)>100'
|
||||
|
||||
Here is a simple example showing how to use filter in a basic ``Qlib`` workflow configuration file:
|
||||
|
||||
.. code-block:: yaml
|
||||
|
||||
filter: &filter
|
||||
filter_type: ExpressionDFilter
|
||||
rule_expression: "Ref($close, -2) / Ref($close, -1) > 1"
|
||||
filter_start_time: 2010-01-01
|
||||
filter_end_time: 2010-01-07
|
||||
keep: False
|
||||
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2010-01-01
|
||||
end_time: 2021-01-22
|
||||
fit_start_time: 2010-01-01
|
||||
fit_end_time: 2015-12-31
|
||||
instruments: *market
|
||||
filter_pipe: [*filter]
|
||||
|
||||
To know more about ``Filter``, please refer to `Filter API <../reference/api.html#module-qlib.data.filter>`_.
|
||||
|
||||
Reference
|
||||
@@ -349,10 +274,9 @@ Here are some important interfaces that ``DataHandlerLP`` provides:
|
||||
.. autoclass:: qlib.data.dataset.handler.DataHandlerLP
|
||||
:members: __init__, fetch, get_cols
|
||||
|
||||
If users want to load features and labels by config, users can inherit ``qlib.data.dataset.handler.ConfigDataHandler``, ``Qlib`` also provides some preprocess method in this subclass.
|
||||
|
||||
If users want to load features and labels by config, users can define a new handler and call the static method `parse_config_to_fields` of ``qlib.contrib.data.handler.Alpha158``.
|
||||
|
||||
Also, users can pass ``qlib.contrib.data.processor.ConfigSectionProcessor`` that provides some preprocess methods for features defined by config into the new handler.
|
||||
If users want to use qlib data, `QLibDataHandler` is recommended. Users can inherit their custom class from `QLibDataHandler`, which is also a subclass of `ConfigDataHandler`.
|
||||
|
||||
|
||||
Processor
|
||||
@@ -389,6 +313,7 @@ Qlib provides implemented data handler `Alpha158`. The following example shows h
|
||||
|
||||
.. note:: Users need to initialize ``Qlib`` with `qlib.init` first, please refer to `initialization <../start/initialization.html>`_.
|
||||
|
||||
|
||||
.. code-block:: Python
|
||||
|
||||
import qlib
|
||||
@@ -415,9 +340,6 @@ Qlib provides implemented data handler `Alpha158`. The following example shows h
|
||||
# fetch all the features
|
||||
print(h.fetch(col_set="feature"))
|
||||
|
||||
|
||||
.. note:: In the ``Alpha158``, ``Qlib`` uses the label `Ref($close, -2)/Ref($close, -1) - 1` that means the change from T+1 to T+2, rather than `Ref($close, -1)/$close - 1`, of which the reason is that when getting the T day close price of a china stock, the stock can be bought on T+1 day and sold on T+2 day.
|
||||
|
||||
API
|
||||
---------
|
||||
|
||||
@@ -442,7 +364,8 @@ The ``DatasetH`` class is the `dataset` with `Data Handler`. Here is the most im
|
||||
API
|
||||
---------
|
||||
|
||||
To know more about ``Dataset``, please refer to `Dataset API <../reference/api.html#dataset>`_.
|
||||
To know more about ``Dataset``, please refer to `Dataset API <../reference/api.html#module-qlib.data.dataset.__init__>`_.
|
||||
|
||||
|
||||
|
||||
Cache
|
||||
|
||||
@@ -1,48 +0,0 @@
|
||||
.. _online:
|
||||
|
||||
=================================
|
||||
Online Serving
|
||||
=================================
|
||||
.. currentmodule:: qlib
|
||||
|
||||
|
||||
Introduction
|
||||
=============
|
||||
|
||||
.. image:: ../_static/img/online_serving.png
|
||||
:align: center
|
||||
|
||||
|
||||
In addition to backtesting, one way to test a model is effective is to make predictions in real market conditions or even do real trading based on those predictions.
|
||||
``Online Serving`` is a set of modules for online models using the latest data,
|
||||
which including `Online Manager <#Online Manager>`_, `Online Strategy <#Online Strategy>`_, `Online Tool <#Online Tool>`_, `Updater <#Updater>`_.
|
||||
|
||||
`Here <https://github.com/microsoft/qlib/tree/main/examples/online_srv>`_ are several examples for reference, which demonstrate different features of ``Online Serving``.
|
||||
If you have many models or `task` needs to be managed, please consider `Task Management <../advanced/task_management.html>`_.
|
||||
The `examples <https://github.com/microsoft/qlib/tree/main/examples/online_srv>`_ are based on some components in `Task Management <../advanced/task_management.html>`_ such as ``TrainerRM`` or ``Collector``.
|
||||
|
||||
**NOTE**: User should keep his data source updated to support online serving. For example, Qlib provides `a batch of scripts <https://github.com/microsoft/qlib/blob/main/scripts/data_collector/yahoo/README.md#automatic-update-of-daily-frequency-datafrom-yahoo-finance>`_ to help users update Yahoo daily data.
|
||||
|
||||
Online Manager
|
||||
=============
|
||||
|
||||
.. automodule:: qlib.workflow.online.manager
|
||||
:members:
|
||||
|
||||
Online Strategy
|
||||
=============
|
||||
|
||||
.. automodule:: qlib.workflow.online.strategy
|
||||
:members:
|
||||
|
||||
Online Tool
|
||||
=============
|
||||
|
||||
.. automodule:: qlib.workflow.online.utils
|
||||
:members:
|
||||
|
||||
Updater
|
||||
=============
|
||||
|
||||
.. automodule:: qlib.workflow.online.update
|
||||
:members:
|
||||
@@ -34,7 +34,6 @@ Here is a general view of the structure of the system:
|
||||
- Recorder 2
|
||||
- ...
|
||||
- ...
|
||||
|
||||
This experiment management system defines a set of interface and provided a concrete implementation ``MLflowExpManager``, which is based on the machine learning platform: ``MLFlow`` (`link <https://mlflow.org/>`_).
|
||||
|
||||
If users set the implementation of ``ExpManager`` to be ``MLflowExpManager``, they can use the command `mlflow ui` to visualize and check the experiment results. For more information, pleaes refer to the related documents `here <https://www.mlflow.org/docs/latest/cli.html#mlflow-ui>`_.
|
||||
@@ -95,52 +94,6 @@ The ``RecordTemp`` class is a class that enables generate experiment results suc
|
||||
|
||||
- ``SignalRecord``: This class generates the `prediction` results of the model.
|
||||
- ``SigAnaRecord``: This class generates the `IC`, `ICIR`, `Rank IC` and `Rank ICIR` of the model.
|
||||
|
||||
Here is a simple example of what is done in ``SigAnaRecord``, which users can refer to if they want to calculate IC, Rank IC, Long-Short Return with their own prediction and label.
|
||||
|
||||
.. code-block:: Python
|
||||
|
||||
from qlib.contrib.eva.alpha import calc_ic, calc_long_short_return
|
||||
|
||||
ic, ric = calc_ic(pred.iloc[:, 0], label.iloc[:, 0])
|
||||
long_short_r, long_avg_r = calc_long_short_return(pred.iloc[:, 0], label.iloc[:, 0])
|
||||
|
||||
- ``PortAnaRecord``: This class generates the results of `backtest`. The detailed information about `backtest` as well as the available `strategy`, users can refer to `Strategy <../component/strategy.html>`_ and `Backtest <../component/backtest.html>`_.
|
||||
|
||||
Here is a simple exampke of what is done in ``PortAnaRecord``, which users can refer to if they want to do backtest based on their own prediction and label.
|
||||
|
||||
.. code-block:: Python
|
||||
|
||||
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
|
||||
from qlib.contrib.evaluate import (
|
||||
backtest as normal_backtest,
|
||||
risk_analysis,
|
||||
)
|
||||
|
||||
# backtest
|
||||
STRATEGY_CONFIG = {
|
||||
"topk": 50,
|
||||
"n_drop": 5,
|
||||
}
|
||||
BACKTEST_CONFIG = {
|
||||
"verbose": False,
|
||||
"limit_threshold": 0.095,
|
||||
"account": 100000000,
|
||||
"benchmark": BENCHMARK,
|
||||
"deal_price": "close",
|
||||
"open_cost": 0.0005,
|
||||
"close_cost": 0.0015,
|
||||
"min_cost": 5,
|
||||
}
|
||||
|
||||
strategy = TopkDropoutStrategy(**STRATEGY_CONFIG)
|
||||
report_normal, positions_normal = normal_backtest(pred_score, strategy=strategy, **BACKTEST_CONFIG)
|
||||
|
||||
# analysis
|
||||
analysis = dict()
|
||||
analysis["excess_return_without_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"])
|
||||
analysis["excess_return_with_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"] - report_normal["cost"])
|
||||
analysis_df = pd.concat(analysis) # type: pd.DataFrame
|
||||
print(analysis_df)
|
||||
|
||||
For more information about the APIs, please refer to `Record Template API <../reference/api.html#module-qlib.workflow.record_temp>`_.
|
||||
|
||||
@@ -101,7 +101,7 @@ Graphical Result
|
||||
- Axis Y:
|
||||
- `ic`
|
||||
The `Pearson correlation coefficient` series between `label` and `prediction score`.
|
||||
In the above example, the `label` is formulated as `Ref($close, -1)/$close - 1`. Please refer to `Data Feature <data.html#feature>`_ for more details.
|
||||
In the above example, the `label` is formulated as `Ref($close, -1)/$close - 1`. Please refer to `Data Featrue <data.html#feature>`_ for more details.
|
||||
|
||||
- `rank_ic`
|
||||
The `Spearman's rank correlation coefficient` series between `label` and `prediction score`.
|
||||
|
||||
@@ -111,6 +111,8 @@ Usage & Example
|
||||
pred_score, strategy=strategy, **BACKTEST_CONFIG
|
||||
)
|
||||
|
||||
Also, the above example has been given in ``examples/train_backtest_analyze.ipynb``.
|
||||
|
||||
To know more about the `prediction score` `pred_score` output by ``Forecast Model``, please refer to `Forecast Model: Model Training & Prediction <model.html>`_.
|
||||
|
||||
To know more about ``Intraday Trading``, please refer to `Intraday Trading: Model&Strategy Testing <backtest.html>`_.
|
||||
|
||||
@@ -90,12 +90,12 @@ Below is a typical config file of ``qrun``.
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
|
||||
After saving the config into `configuration.yaml`, users could start the workflow and test their ideas with a single command below.
|
||||
|
||||
@@ -142,7 +142,7 @@ The meaning of each field is as follows:
|
||||
|
||||
- `region`
|
||||
- If `region` == "us", ``Qlib`` will be initialized in US-stock mode.
|
||||
- If `region` == "cn", ``Qlib`` will be initialized in China-stock mode.
|
||||
- If `region` == "cn", ``Qlib`` will be initialized in china-stock mode.
|
||||
|
||||
.. note::
|
||||
|
||||
|
||||
@@ -1,20 +0,0 @@
|
||||
.. _code_standard:
|
||||
|
||||
=================================
|
||||
Code Standard
|
||||
=================================
|
||||
|
||||
Docstring
|
||||
=================================
|
||||
Please use the `Numpydoc Style <https://stackoverflow.com/a/24385103>`_.
|
||||
|
||||
Continuous Integration
|
||||
=================================
|
||||
Continuous Integration (CI) tools help you stick to the quality standards by running tests every time you push a new commit and reporting the results to a pull request.
|
||||
|
||||
A common error is the mixed use of space and tab. You can fix the bug by inputing the following code in the command line.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
pip install black
|
||||
python -m black . -l 120
|
||||
@@ -42,7 +42,6 @@ Document Structure
|
||||
Intraday Trading: Model&Strategy Testing <component/backtest.rst>
|
||||
Qlib Recorder: Experiment Management <component/recorder.rst>
|
||||
Analysis: Evaluation & Results Analysis <component/report.rst>
|
||||
Online Serving: Online Management & Strategy & Tool <component/online.rst>
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 3
|
||||
@@ -50,8 +49,6 @@ Document Structure
|
||||
|
||||
Building Formulaic Alphas <advanced/alpha.rst>
|
||||
Online & Offline mode <advanced/server.rst>
|
||||
Serialization <advanced/serial.rst>
|
||||
Task Management <advanced/task_management.rst>
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 3
|
||||
|
||||
@@ -53,34 +53,6 @@ Cache
|
||||
.. autoclass:: qlib.data.cache.DiskDatasetCache
|
||||
:members:
|
||||
|
||||
|
||||
Storage
|
||||
-------------
|
||||
.. autoclass:: qlib.data.storage.storage.BaseStorage
|
||||
:members:
|
||||
|
||||
.. autoclass:: qlib.data.storage.storage.CalendarStorage
|
||||
:members:
|
||||
|
||||
.. autoclass:: qlib.data.storage.storage.InstrumentStorage
|
||||
:members:
|
||||
|
||||
.. autoclass:: qlib.data.storage.storage.FeatureStorage
|
||||
:members:
|
||||
|
||||
.. autoclass:: qlib.data.storage.file_storage.FileStorageMixin
|
||||
:members:
|
||||
|
||||
.. autoclass:: qlib.data.storage.file_storage.FileCalendarStorage
|
||||
:members:
|
||||
|
||||
.. autoclass:: qlib.data.storage.file_storage.FileInstrumentStorage
|
||||
:members:
|
||||
|
||||
.. autoclass:: qlib.data.storage.file_storage.FileFeatureStorage
|
||||
:members:
|
||||
|
||||
|
||||
Dataset
|
||||
---------------
|
||||
|
||||
@@ -180,81 +152,4 @@ Recorder
|
||||
Record Template
|
||||
--------------------
|
||||
.. automodule:: qlib.workflow.record_temp
|
||||
:members:
|
||||
|
||||
Task Management
|
||||
====================
|
||||
|
||||
|
||||
TaskGen
|
||||
--------------------
|
||||
.. automodule:: qlib.workflow.task.gen
|
||||
:members:
|
||||
|
||||
TaskManager
|
||||
--------------------
|
||||
.. automodule:: qlib.workflow.task.manage
|
||||
:members:
|
||||
|
||||
Trainer
|
||||
--------------------
|
||||
.. automodule:: qlib.model.trainer
|
||||
:members:
|
||||
|
||||
Collector
|
||||
--------------------
|
||||
.. automodule:: qlib.workflow.task.collect
|
||||
:members:
|
||||
|
||||
Group
|
||||
--------------------
|
||||
.. automodule:: qlib.model.ens.group
|
||||
:members:
|
||||
|
||||
Ensemble
|
||||
--------------------
|
||||
.. automodule:: qlib.model.ens.ensemble
|
||||
:members:
|
||||
|
||||
Utils
|
||||
--------------------
|
||||
.. automodule:: qlib.workflow.task.utils
|
||||
:members:
|
||||
|
||||
|
||||
Online Serving
|
||||
====================
|
||||
|
||||
|
||||
Online Manager
|
||||
--------------------
|
||||
.. automodule:: qlib.workflow.online.manager
|
||||
:members:
|
||||
|
||||
Online Strategy
|
||||
--------------------
|
||||
.. automodule:: qlib.workflow.online.strategy
|
||||
:members:
|
||||
|
||||
Online Tool
|
||||
--------------------
|
||||
.. automodule:: qlib.workflow.online.utils
|
||||
:members:
|
||||
|
||||
RecordUpdater
|
||||
--------------------
|
||||
.. automodule:: qlib.workflow.online.update
|
||||
:members:
|
||||
|
||||
|
||||
Utils
|
||||
====================
|
||||
|
||||
Serializable
|
||||
--------------------
|
||||
|
||||
.. automodule:: qlib.utils.serial.Serializable
|
||||
:members:
|
||||
|
||||
|
||||
|
||||
:members:
|
||||
@@ -75,14 +75,3 @@ Besides `provider_uri` and `region`, `qlib.init` has other parameters. The follo
|
||||
"default_exp_name": "Experiment",
|
||||
}
|
||||
})
|
||||
- `mongo`
|
||||
Type: dict, optional parameter, the setting of `MongoDB <https://www.mongodb.com/>`_ which will be used in some features such as `Task Management <../advanced/task_management.html>`_, with high performance and clustered processing.
|
||||
Users need finished `installation <https://www.mongodb.com/try/download/community>`_ firstly, and run it in a fixed URL.
|
||||
|
||||
.. code-block:: Python
|
||||
|
||||
# For example, you can initialize qlib below
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN, mongo={
|
||||
"task_url": "mongodb://localhost:27017/", # your mongo url
|
||||
"task_db_name": "rolling_db", # the database name of Task Management
|
||||
})
|
||||
|
||||
@@ -82,7 +82,7 @@ The Custom models need to inherit `qlib.model.base.Model <../reference/api.html#
|
||||
return pd.Series(self.model.predict(x_test.values), index=x_test.index)
|
||||
|
||||
- Override the `finetune` method (Optional)
|
||||
- This method is optional to the users. When users want to use this method on their own models, they should inherit the ``ModelFT`` base class, which includes the interface of `finetune`.
|
||||
- This method is optional to the users, and when users one to use this method on their own models, they should inherit the ``ModelFT`` base class, which includes the interface of `finetune`.
|
||||
- The parameters must include the parameter `dataset`.
|
||||
- Code Example: In the following example, users will use `LightGBM` as the model and finetune it.
|
||||
.. code-block:: Python
|
||||
|
||||
@@ -1,4 +0,0 @@
|
||||
# DoubleEnsemble
|
||||
* DoubleEnsemble is an ensemble framework leveraging learning trajectory based sample reweighting and shuffling based feature selection, to solve both the low signal-to-noise ratio and increasing number of features problems. They identify the key samples based on the training dynamics on each sample and elicit key features based on the ablation impact of each feature via shuffling. The model is applicable to a wide range of base models, capable of extracting complex patterns, while mitigating the overfitting and instability issues for financial market prediction.
|
||||
* This code used in Qlib is implemented by ourselves.
|
||||
* Paper: DoubleEnsemble: A New Ensemble Method Based on Sample Reweighting and Feature Selection for Financial Data Analysis [https://arxiv.org/pdf/2010.01265.pdf](https://arxiv.org/pdf/2010.01265.pdf).
|
||||
@@ -1,3 +0,0 @@
|
||||
pandas==1.1.2
|
||||
numpy==1.17.4
|
||||
lightgbm==3.1.0
|
||||
@@ -1,90 +0,0 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: DEnsembleModel
|
||||
module_path: qlib.contrib.model.double_ensemble
|
||||
kwargs:
|
||||
base_model: "gbm"
|
||||
loss: mse
|
||||
num_models: 6
|
||||
enable_sr: True
|
||||
enable_fs: True
|
||||
alpha1: 1
|
||||
alpha2: 1
|
||||
bins_sr: 10
|
||||
bins_fs: 5
|
||||
decay: 0.5
|
||||
sample_ratios:
|
||||
- 0.8
|
||||
- 0.7
|
||||
- 0.6
|
||||
- 0.5
|
||||
- 0.4
|
||||
sub_weights:
|
||||
- 1
|
||||
- 0.2
|
||||
- 0.2
|
||||
- 0.2
|
||||
- 0.2
|
||||
- 0.2
|
||||
epochs: 28
|
||||
colsample_bytree: 0.8879
|
||||
learning_rate: 0.2
|
||||
subsample: 0.8789
|
||||
lambda_l1: 205.6999
|
||||
lambda_l2: 580.9768
|
||||
max_depth: 8
|
||||
num_leaves: 210
|
||||
num_threads: 20
|
||||
verbosity: -1
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha158
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -1,97 +0,0 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors: []
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: DEnsembleModel
|
||||
module_path: qlib.contrib.model.double_ensemble
|
||||
kwargs:
|
||||
base_model: "gbm"
|
||||
loss: mse
|
||||
num_models: 6
|
||||
enable_sr: True
|
||||
enable_fs: True
|
||||
alpha1: 1
|
||||
alpha2: 1
|
||||
bins_sr: 10
|
||||
bins_fs: 5
|
||||
decay: 0.5
|
||||
sample_ratios:
|
||||
- 0.8
|
||||
- 0.7
|
||||
- 0.6
|
||||
- 0.5
|
||||
- 0.4
|
||||
sub_weights:
|
||||
- 1
|
||||
- 0.2
|
||||
- 0.2
|
||||
- 0.2
|
||||
- 0.2
|
||||
- 0.2
|
||||
epochs: 136
|
||||
colsample_bytree: 0.8879
|
||||
learning_rate: 0.0421
|
||||
subsample: 0.8789
|
||||
lambda_l1: 205.6999
|
||||
lambda_l2: 580.9768
|
||||
max_depth: 8
|
||||
num_leaves: 210
|
||||
num_threads: 20
|
||||
verbosity: -1
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha360
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -61,6 +61,7 @@ task:
|
||||
metric: loss
|
||||
loss: mse
|
||||
base_model: LSTM
|
||||
with_pretrain: True
|
||||
model_path: "benchmarks/LSTM/csi300_lstm_ts.pkl"
|
||||
GPU: 0
|
||||
dataset:
|
||||
|
||||
@@ -54,6 +54,7 @@ task:
|
||||
metric: loss
|
||||
loss: mse
|
||||
base_model: LSTM
|
||||
with_pretrain: True
|
||||
model_path: "benchmarks/LSTM/model_lstm_csi300.pkl"
|
||||
GPU: 0
|
||||
dataset:
|
||||
@@ -80,4 +81,4 @@ task:
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
config: *port_analysis_config
|
||||
@@ -29,7 +29,7 @@ data_handler_config: &data_handler_config
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
|
||||
@@ -1,16 +0,0 @@
|
||||
import datetime
|
||||
import pandas as pd
|
||||
|
||||
from qlib.data.inst_processor import InstProcessor
|
||||
|
||||
|
||||
class Resample1minProcessor(InstProcessor):
|
||||
def __init__(self, hour: int, minute: int, **kwargs):
|
||||
self.hour = hour
|
||||
self.minute = minute
|
||||
|
||||
def __call__(self, df: pd.DataFrame, *args, **kwargs):
|
||||
df.index = pd.to_datetime(df.index)
|
||||
df = df.loc[df.index.time == datetime.time(self.hour, self.minute)]
|
||||
df.index = df.index.normalize()
|
||||
return df
|
||||
@@ -1,83 +0,0 @@
|
||||
qlib_init:
|
||||
provider_uri:
|
||||
day: "~/.qlib/qlib_data/cn_data"
|
||||
1min: "~/.qlib/qlib_data/cn_data_1min"
|
||||
region: cn
|
||||
dataset_cache: null
|
||||
maxtasksperchild: 1
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
# 1min closing time is 15:00:00
|
||||
end_time: "2020-08-01 15:00:00"
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
freq:
|
||||
label: day
|
||||
feature: 1min
|
||||
# with label as reference
|
||||
inst_processor:
|
||||
feature:
|
||||
- class: Resample1minProcessor
|
||||
module_path: features_sample.py
|
||||
kwargs:
|
||||
hour: 14
|
||||
minute: 56
|
||||
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: LGBModel
|
||||
module_path: qlib.contrib.model.gbdt
|
||||
kwargs:
|
||||
loss: mse
|
||||
colsample_bytree: 0.8879
|
||||
learning_rate: 0.2
|
||||
subsample: 0.8789
|
||||
lambda_l1: 205.6999
|
||||
lambda_l2: 580.9768
|
||||
max_depth: 8
|
||||
num_leaves: 210
|
||||
num_threads: 20
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha158
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -1,81 +0,0 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
instruments: *market
|
||||
data_loader:
|
||||
class: QlibDataLoader
|
||||
kwargs:
|
||||
config:
|
||||
feature:
|
||||
- ["Resi($close, 15)/$close", "Std(Abs($close/Ref($close, 1)-1)*$volume, 5)/(Mean(Abs($close/Ref($close, 1)-1)*$volume, 5)+1e-12)", "Rsquare($close, 5)", "($high-$low)/$open", "Rsquare($close, 10)", "Corr($close, Log($volume+1), 5)", "Corr($close/Ref($close,1), Log($volume/Ref($volume, 1)+1), 5)", "Corr($close, Log($volume+1), 10)", "Rsquare($close, 20)", "Corr($close/Ref($close,1), Log($volume/Ref($volume, 1)+1), 60)", "Corr($close/Ref($close,1), Log($volume/Ref($volume, 1)+1), 10)", "Corr($close, Log($volume+1), 20)", "(Less($open, $close)-$low)/$open"]
|
||||
- ["RESI5", "WVMA5", "RSQR5", "KLEN", "RSQR10", "CORR5", "CORD5", "CORR10", "RSQR20", "CORD60", "CORD10", "CORR20", "KLOW"]
|
||||
label:
|
||||
- ["Ref($close, -2)/Ref($close, -1) - 1"]
|
||||
- ["LABEL0"]
|
||||
freq: day
|
||||
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSZScoreNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: LGBModel
|
||||
module_path: qlib.contrib.model.gbdt
|
||||
kwargs:
|
||||
loss: mse
|
||||
colsample_bytree: 0.8879
|
||||
learning_rate: 0.2
|
||||
subsample: 0.8789
|
||||
lambda_l1: 205.6999
|
||||
lambda_l2: 580.9768
|
||||
max_depth: 8
|
||||
num_leaves: 210
|
||||
num_threads: 20
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: DataHandlerLP
|
||||
module_path: qlib.data.dataset.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -1,3 +0,0 @@
|
||||
numpy==1.17.4
|
||||
pandas==1.1.2
|
||||
torch==1.2.0
|
||||
@@ -1,82 +0,0 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: FilterCol
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
col_list: ["RESI5", "WVMA5", "RSQR5", "KLEN", "RSQR10", "CORR5", "CORD5", "CORR10",
|
||||
"ROC60", "RESI10", "VSTD5", "RSQR60", "CORR60", "WVMA60", "STD5",
|
||||
"RSQR20", "CORD60", "CORD10", "CORR20", "KLOW"
|
||||
]
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: LocalformerModel
|
||||
module_path: qlib.contrib.model.pytorch_localformer_ts
|
||||
kwargs:
|
||||
seed: 0
|
||||
n_jobs: 20
|
||||
dataset:
|
||||
class: TSDatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha158
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
step_len: 20
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -1,73 +0,0 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: LocalformerModel
|
||||
module_path: qlib.contrib.model.pytorch_localformer
|
||||
kwargs:
|
||||
d_feat: 6
|
||||
seed: 0
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha360
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -1,13 +1,9 @@
|
||||
# Benchmarks Performance
|
||||
|
||||
Here are the results of each benchmark model running on Qlib's `Alpha360` and `Alpha158` dataset with China's A shared-stock & CSI300 data respectively. The values of each metric are the mean and std calculated based on 20 runs with different random seeds.
|
||||
Here are the results of each benchmark model running on Qlib's `Alpha360` and `Alpha158` dataset with China's A shared-stock & CSI300 data respectively. The values of each metric are the mean and std calculated based on 20 runs.
|
||||
|
||||
The numbers shown below demonstrate the performance of the entire `workflow` of each model. We will update the `workflow` as well as models in the near future for better results.
|
||||
|
||||
> If you need to reproduce the results below, please use the **v1** dataset: `python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/qlib_cn_1d --region cn --version v1`
|
||||
>
|
||||
> In the new version of qlib, the default dataset is **v2**. Since the data is collected from the YahooFinance API (which is not very stable), the results of *v2* and *v1* may differ
|
||||
|
||||
## Alpha360 dataset
|
||||
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |
|
||||
|---|---|---|---|---|---|---|---|---|
|
||||
@@ -20,12 +16,6 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
|
||||
| LSTM (Sepp Hochreiter, et al.) | Alpha360 | 0.0443±0.01 | 0.3401±0.05| 0.0536±0.01 | 0.4248±0.05 | 0.0627±0.03 | 0.8441±0.48| -0.0882±0.03 |
|
||||
| ALSTM (Yao Qin, et al.) | Alpha360 | 0.0493±0.01 | 0.3778±0.06| 0.0585±0.00 | 0.4606±0.04 | 0.0513±0.03 | 0.6727±0.38| -0.1085±0.02 |
|
||||
| GATs (Petar Velickovic, et al.) | Alpha360 | 0.0475±0.00 | 0.3515±0.02| 0.0592±0.00 | 0.4585±0.01 | 0.0876±0.02 | 1.1513±0.27| -0.0795±0.02 |
|
||||
| DoubleEnsemble (Chuheng Zhang, et al.) | Alpha360 | 0.0407±0.00| 0.3053±0.00 | 0.0490±0.00 | 0.3840±0.00 | 0.0380±0.02 | 0.5000±0.21 | -0.0984±0.02 |
|
||||
| TabNet (Sercan O. Arik, et al.)| Alpha360 | 0.0192±0.00 | 0.1401±0.00| 0.0291±0.00 | 0.2163±0.00 | -0.0258±0.00 | -0.2961±0.00| -0.1429±0.00 |
|
||||
| TCTS (Xueqing Wu, et al.)| Alpha360 | 0.0485±0.00 | 0.3689±0.04| 0.0586±0.00 | 0.4669±0.02 | 0.0816±0.02 | 1.1572±0.30| -0.0689±0.02 |
|
||||
| Transformer (Ashish Vaswani, et al.)| Alpha360 | 0.0141±0.00 | 0.0917±0.02| 0.0331±0.00 | 0.2357±0.03 | -0.0259±0.03 | -0.3323±0.43| -0.1763±0.07 |
|
||||
| Localformer (Juyong Jiang, et al.)| Alpha360 | 0.0408±0.00 | 0.2988±0.03| 0.0538±0.00 | 0.4105±0.02 | 0.0275±0.03 | 0.3464±0.37| -0.1182±0.03 |
|
||||
| TRA (Hengxu Lin, et al.)| Alpha360 | 0.0491±0.01 | 0.3868±0.06 | 0.0589±0.00 | 0.4802±0.04 | 0.0898±0.02 | 1.2490±0.32 | -0.0778±0.02 |
|
||||
|
||||
## Alpha158 dataset
|
||||
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |
|
||||
@@ -35,17 +25,11 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
|
||||
| XGBoost (Tianqi Chen, et al.) | Alpha158 | 0.0481±0.00 | 0.3659±0.00| 0.0495±0.00 | 0.4033±0.00 | 0.1111±0.00 | 1.2915±0.00| -0.0893±0.00 |
|
||||
| LightGBM (Guolin Ke, et al.) | Alpha158 | 0.0475±0.00 | 0.3979±0.00| 0.0485±0.00 | 0.4123±0.00 | 0.1143±0.00 | 1.2744±0.00| -0.0800±0.00 |
|
||||
| MLP | Alpha158 | 0.0358±0.00 | 0.2738±0.03| 0.0425±0.00 | 0.3221±0.01 | 0.0836±0.02 | 1.0323±0.25| -0.1127±0.02 |
|
||||
| TabNet with pretrain (Sercan O. Arikm et al) | Alpha158 | 0.0344±0.00|0.205±0.11|0.0398±0.00 |0.3479±0.01|0.0827±0.02|1.1141±0.32 |-0.0925±0.02 |
|
||||
| TFT (Bryan Lim, et al.) | Alpha158 (with selected 20 features) | 0.0343±0.00 | 0.2071±0.02| 0.0107±0.00 | 0.0660±0.02 | 0.0623±0.02 | 0.5818±0.20| -0.1762±0.01 |
|
||||
| GRU (Kyunghyun Cho, et al.) | Alpha158 (with selected 20 features) | 0.0311±0.00 | 0.2418±0.04| 0.0425±0.00 | 0.3434±0.02 | 0.0330±0.02 | 0.4805±0.30| -0.1021±0.02 |
|
||||
| LSTM (Sepp Hochreiter, et al.) | Alpha158 (with selected 20 features) | 0.0312±0.00 | 0.2394±0.04| 0.0418±0.00 | 0.3324±0.03 | 0.0298±0.02 | 0.4198±0.33| -0.1348±0.03 |
|
||||
| ALSTM (Yao Qin, et al.) | Alpha158 (with selected 20 features) | 0.0385±0.01 | 0.3022±0.06| 0.0478±0.00 | 0.3874±0.04 | 0.0486±0.03 | 0.7141±0.45| -0.1088±0.03 |
|
||||
| GATs (Petar Velickovic, et al.) | Alpha158 (with selected 20 features) | 0.0349±0.00 | 0.2511±0.01| 0.0457±0.00 | 0.3537±0.01 | 0.0578±0.02 | 0.8221±0.25| -0.0824±0.02 |
|
||||
| DoubleEnsemble (Chuheng Zhang, et al.) | Alpha158 | 0.0544±0.00 | 0.4338±0.01 | 0.0523±0.00 | 0.4257±0.01 | 0.1253±0.01 | 1.4105±0.14 | -0.0902±0.01 |
|
||||
| TabNet (Sercan O. Arik, et al.)| Alpha158 | 0.0383±0.00 | 0.3414±0.00| 0.0388±0.00 | 0.3460±0.00 | 0.0226±0.00 | 0.2652±0.00| -0.1072±0.00 |
|
||||
| Transformer (Ashish Vaswani, et al.)| Alpha158 | 0.0274±0.00 | 0.2166±0.04| 0.0409±0.00 | 0.3342±0.04 | 0.0204±0.03 | 0.2888±0.40| -0.1216±0.04 |
|
||||
| Localformer (Juyong Jiang, et al.)| Alpha158 | 0.0355±0.00 | 0.2747±0.04| 0.0466±0.00 | 0.3762±0.03 | 0.0506±0.02 | 0.7447±0.34| -0.0875±0.02 |
|
||||
| TRA (Hengxu Lin, et al.)| Alpha158 (with selected 20 features)| 0.0409±0.00 | 0.3253±0.04 | 0.0488±0.00 | 0.4045±0.02 | 0.0673±0.02 | 1.0389±0.39 | -0.0830±0.02 |
|
||||
| TRA (Hengxu Lin, et al.)| Alpha158 | 0.0442±0.00 | 0.3426±0.03 | 0.0555±0.00 | 0.4395±0.03 | 0.0833±0.03 | 1.2064±0.36 | -0.0849±0.02 |
|
||||
|
||||
- The selected 20 features are based on the feature importance of a lightgbm-based model.
|
||||
- The base model of DoubleEnsemble is LGBM.
|
||||
|
||||
@@ -1,52 +0,0 @@
|
||||
# Temporally Correlated Task Scheduling for Sequence Learning
|
||||
We provide the [code](https://github.com/microsoft/qlib/blob/main/qlib/contrib/model/pytorch_tcts.py) for reproducing the stock trend forecasting experiments.
|
||||
|
||||
### Background
|
||||
Sequence learning has attracted much research attention from the machine learning community in recent years. In many applications, a sequence learning task is usually associated with multiple temporally correlated auxiliary tasks, which are different in terms of how much input information to use or which future step to predict. In stock trend forecasting, as demonstrated in Figure1, one can predict the price of a stock in different future days (e.g., tomorrow, the day after tomorrow). In this paper, we propose a framework to make use of those temporally correlated tasks to help each other.
|
||||
|
||||
<p align="center">
|
||||
<img src="task_description.png" width="600" height="200"/>
|
||||
</p>
|
||||
|
||||
|
||||
### Method
|
||||
Given that there are usually multiple temporally correlated tasks, the key challenge lies in which tasks to use and when to use them in the training process. In this work, we introduce a learnable task scheduler for sequence learning, which adaptively selects temporally correlated tasks during the training process. The scheduler accesses the model status and the current training data (e.g., in current minibatch), and selects the best auxiliary task to help the training of the main task. The scheduler and the model for the main task are jointly trained through bi-level optimization: the scheduler is trained to maximize the validation performance of the model, and the model is trained to minimize the training loss guided by the scheduler. The process is demonstrated in Figure2.
|
||||
|
||||
<p align="center">
|
||||
<img src="workflow.png"/>
|
||||
</p>
|
||||
|
||||
At step <img src="https://render.githubusercontent.com/render/math?math=s">, with training data <img src="https://render.githubusercontent.com/render/math?math=x_s,y_s">, the scheduler <img src="https://render.githubusercontent.com/render/math?math=\varphi"> chooses a suitable task <img src="https://render.githubusercontent.com/render/math?math=T_{i_s}"> (green solid lines) to update the model <img src="https://render.githubusercontent.com/render/math?math=f"> (blue solid lines). After <img src="https://render.githubusercontent.com/render/math?math=S"> steps, we evaluate the model <img src="https://render.githubusercontent.com/render/math?math=f"> on the validation set and update the scheduler <img src="https://render.githubusercontent.com/render/math?math=\varphi"> (green dashed lines).
|
||||
|
||||
### DataSet
|
||||
* We use the historical transaction data for 300 stocks on [CSI300](http://www.csindex.com.cn/en/indices/index-detail/000300) from 01/01/2008 to 08/01/2020.
|
||||
* We split the data into training (01/01/2008-12/31/2013), validation (01/01/2014-12/31/2015), and test sets (01/01/2016-08/01/2020) based on the transaction time.
|
||||
|
||||
### Experiments
|
||||
#### Task Description
|
||||
* The main tasks <img src="https://render.githubusercontent.com/render/math?math=T_k"> (<img src="https://render.githubusercontent.com/render/math?math=task_k"> in Figure1) refers to forecasting return of stock <img src="https://render.githubusercontent.com/render/math?math=i"> as following,
|
||||
<div align=center>
|
||||
<img src="https://render.githubusercontent.com/render/math?math=r_{i}^k = \frac{\price_i^{t+k}}{\price_i^{t+k-1}} - 1">
|
||||
</div>
|
||||
|
||||
* Temporally correlated task sets <img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_k = \{T_1, T_2, ... , T_k\}">, in this paper, <img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_3">, <img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_5"> and <img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_10"> are used.
|
||||
#### Baselines
|
||||
* GRU/MLP/LightGBM (LGB)/Graph Attention Networks (GAT)
|
||||
* Multi-task learning (MTL): In multi-task learning, multiple tasks are jointly trained and mutually boosted. Each task is treated equally, while in our setting, we focus on the main task.
|
||||
* Curriculum transfer learning (CL): Transfer learning also leverages auxiliary tasks to boost the main task. [Curriculum transfer learning](https://arxiv.org/pdf/1804.00810.pdf) is one kind of transfer learning which schedules auxiliary tasks according to certain rules. Our problem can also be regarded as a special kind of transfer learning, where the auxiliary tasks are temporally correlated with the main task. Our learning process is dynamically controlled by a scheduler rather than some pre-defined rules. In the CL baseline, we start from the task <img src="https://render.githubusercontent.com/render/math?math=T_1" >, then <img src="https://render.githubusercontent.com/render/math?math=T_2" >, and gradually move to the last one.
|
||||
#### Result
|
||||
| Methods | <img src="https://render.githubusercontent.com/render/math?math=T_1" > | <img src="https://render.githubusercontent.com/render/math?math=T_2"> | <img src="https://render.githubusercontent.com/render/math?math=T_3"> |
|
||||
| :----: | :----: | :----: | :----: |
|
||||
| GRU | 0.049 / 1.903 | 0.018 / 1.972 | 0.014 / 1.989 |
|
||||
| MLP | 0.023 / 1.961 | 0.022 / 1.962 | 0.015 / 1.978 |
|
||||
| LGB | 0.038 / 1.883 | 0.023 / 1.952 | 0.007 / 1.987 |
|
||||
| GAT | 0.052 / 1.898 | 0.024 / 1.954 | 0.015 / 1.973 |
|
||||
| MTL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_3">) | 0.061 / 1.862 | 0.023 / 1.942 | 0.012 / 1.956 |
|
||||
| CL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_3">) | 0.051 / 1.880 | 0.028 / 1.941 | 0.016 / 1.962 |
|
||||
| Ours(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_3">) | 0.071 / 1.851 | 0.030 / 1.939 | 0.017 / 1.963 |
|
||||
| MTL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_5">) | 0.057 / 1.875 | 0.021 / 1.939 | 0.017 / 1.959 |
|
||||
| CL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_5">) | 0.056 / 1.877 | 0.028 / 1.942 | 0.015 / 1.962 |
|
||||
| Ours(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_5">) | 0.075 / 1.849 | 0.032 /1.939 | 0.021 / 1.955 |
|
||||
| MTL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_{10}">) | 0.052 / 1.882 | 0.020 / 1.947 | 0.019 / 1.952 |
|
||||
| CL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_{10}">) | 0.051 / 1.882 | 0.028 / 1.950 | 0.016 / 1.961 |
|
||||
| Ours(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_{10}">) | 0.067 / 1.867 | 0.030 / 1.960 | 0.022 / 1.942|
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 25 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 29 KiB |
@@ -1,93 +0,0 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -1) / $close - 1",
|
||||
"Ref($close, -2) / Ref($close, -1) - 1",
|
||||
"Ref($close, -3) / Ref($close, -2) - 1"]
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: TCTS
|
||||
module_path: qlib.contrib.model.pytorch_tcts
|
||||
kwargs:
|
||||
d_feat: 6
|
||||
hidden_size: 64
|
||||
num_layers: 2
|
||||
dropout: 0.0
|
||||
n_epochs: 200
|
||||
lr: 1e-3
|
||||
early_stop: 20
|
||||
batch_size: 800
|
||||
metric: loss
|
||||
loss: mse
|
||||
GPU: 0
|
||||
fore_optimizer: adam
|
||||
weight_optimizer: adam
|
||||
output_dim: 3
|
||||
fore_lr: 5e-4
|
||||
weight_lr: 5e-4
|
||||
steps: 3
|
||||
target_label: 1
|
||||
lowest_valid_performance: 0.993
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha360
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
label_col: 1
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -132,7 +132,7 @@ class GenericDataFormatter(abc.ABC):
|
||||
return -1, -1
|
||||
|
||||
def get_column_definition(self):
|
||||
"""Returns formatted column definition in order expected by the TFT."""
|
||||
""""Returns formatted column definition in order expected by the TFT."""
|
||||
|
||||
column_definition = self._column_definition
|
||||
|
||||
|
||||
@@ -1,92 +0,0 @@
|
||||
# Learning Multiple Stock Trading Patterns with Temporal Routing Adaptor and Optimal Transport
|
||||
|
||||
Temporal Routing Adaptor (TRA) is designed to capture multiple trading patterns in the stock market data. Please refer to [our paper](http://arxiv.org/abs/2106.12950) for more details.
|
||||
|
||||
If you find our work useful in your research, please cite:
|
||||
```
|
||||
@inproceedings{HengxuKDD2021,
|
||||
author = {Hengxu Lin and Dong Zhou and Weiqing Liu and Jiang Bian},
|
||||
title = {Learning Multiple Stock Trading Patterns with Temporal Routing Adaptor and Optimal Transport},
|
||||
booktitle = {Proceedings of the 27th ACM SIGKDD Conference on Knowledge Discovery \& Data Mining},
|
||||
series = {KDD '21},
|
||||
year = {2021},
|
||||
publisher = {ACM},
|
||||
}
|
||||
|
||||
@article{yang2020qlib,
|
||||
title={Qlib: An AI-oriented Quantitative Investment Platform},
|
||||
author={Yang, Xiao and Liu, Weiqing and Zhou, Dong and Bian, Jiang and Liu, Tie-Yan},
|
||||
journal={arXiv preprint arXiv:2009.11189},
|
||||
year={2020}
|
||||
}
|
||||
```
|
||||
|
||||
## Usage (Recommended)
|
||||
|
||||
**Update**: `TRA` has been moved to `qlib.contrib.model.pytorch_tra` to support other `Qlib` components like `qlib.workflow` and `Alpha158/Alpha360` dataset.
|
||||
|
||||
Please follow the official [doc](https://qlib.readthedocs.io/en/latest/component/workflow.html) to use `TRA` with `workflow`. Here we also provide several example config files:
|
||||
|
||||
- `workflow_config_tra_Alpha360.yaml`: running `TRA` with `Alpha360` dataset
|
||||
- `workflow_config_tra_Alpha158.yaml`: running `TRA` with `Alpha158` dataset (with feature subsampling)
|
||||
- `workflow_config_tra_Alpha158_full.yaml`: running `TRA` with `Alpha158` dataset (without feature subsampling)
|
||||
|
||||
The performances of `TRA` are reported in [Benchmarks](https://github.com/microsoft/qlib/tree/main/examples/benchmarks).
|
||||
|
||||
## Usage (Not Maintained)
|
||||
|
||||
This section is used to reproduce the results in the paper.
|
||||
|
||||
### Running
|
||||
|
||||
We attach our running scripts for the paper in `run.sh`.
|
||||
|
||||
And here are two ways to run the model:
|
||||
|
||||
* Running from scripts with default parameters
|
||||
|
||||
You can directly run from Qlib command `qrun`:
|
||||
```
|
||||
qrun configs/config_alstm.yaml
|
||||
```
|
||||
|
||||
* Running from code with self-defined parameters
|
||||
|
||||
Setting different parameters is also allowed. See codes in `example.py`:
|
||||
```
|
||||
python example.py --config_file configs/config_alstm.yaml
|
||||
```
|
||||
|
||||
Here we trained TRA on a pretrained backbone model. Therefore we run `*_init.yaml` before TRA's scipts.
|
||||
|
||||
### Results
|
||||
|
||||
After running the scripts, you can find result files in path `./output`:
|
||||
|
||||
* `info.json` - config settings and result metrics.
|
||||
* `log.csv` - running logs.
|
||||
* `model.bin` - the model parameter dictionary.
|
||||
* `pred.pkl` - the prediction scores and output for inference.
|
||||
|
||||
Evaluation metrics reported in the paper:
|
||||
|
||||
| Methods | MSE| MAE| IC | ICIR | AR | AV | SR | MDD |
|
||||
|-------|-------|------|-----|-----|-----|-----|-----|-----|
|
||||
|Linear|0.163|0.327|0.020|0.132|-3.2%|16.8%|-0.191|32.1%|
|
||||
|LightGBM|0.160(0.000)|0.323(0.000)|0.041|0.292|7.8%|15.5%|0.503|25.7%|
|
||||
|MLP|0.160(0.002)|0.323(0.003)|0.037|0.273|3.7%|15.3%|0.264|26.2%|
|
||||
|SFM|0.159(0.001) |0.321(0.001) |0.047 |0.381 |7.1% |14.3% |0.497 |22.9%|
|
||||
|ALSTM|0.158(0.001) |0.320(0.001) |0.053 |0.419 |12.3% |13.7% |0.897 |20.2%|
|
||||
|Trans.|0.158(0.001) |0.322(0.001) |0.051 |0.400 |14.5% |14.2% |1.028 |22.5%|
|
||||
|ALSTM+TS|0.160(0.002) |0.321(0.002) |0.039 |0.291 |6.7% |14.6% |0.480|22.3%|
|
||||
|Trans.+TS|0.160(0.004) |0.324(0.005) |0.037 |0.278 |10.4% |14.7% |0.722 |23.7%|
|
||||
|ALSTM+TRA(Ours)|0.157(0.000) |0.318(0.000) |0.059 |0.460 |12.4% |14.0% |0.885 |20.4%|
|
||||
|Trans.+TRA(Ours)|0.157(0.000) |0.320(0.000) |0.056 |0.442 |16.1% |14.2% |1.133 |23.1%|
|
||||
|
||||
A more detailed demo for our experiment results in the paper can be found in `Report.ipynb`.
|
||||
|
||||
## Common Issues
|
||||
|
||||
For help or issues using TRA, please submit a GitHub issue.
|
||||
|
||||
Sometimes we might encounter situation where the loss is `NaN`, please check the `epsilon` parameter in the sinkhorn algorithm, adjusting the `epsilon` according to input's scale is important.
|
||||
File diff suppressed because one or more lines are too long
@@ -1,63 +0,0 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
|
||||
data_loader_config: &data_loader_config
|
||||
class: StaticDataLoader
|
||||
module_path: qlib.data.dataset.loader
|
||||
kwargs:
|
||||
config:
|
||||
feature: data/feature.pkl
|
||||
label: data/label.pkl
|
||||
|
||||
model_config: &model_config
|
||||
input_size: 16
|
||||
hidden_size: 256
|
||||
num_layers: 2
|
||||
num_heads: 2
|
||||
use_attn: True
|
||||
dropout: 0.1
|
||||
|
||||
num_states: &num_states 1
|
||||
|
||||
tra_config: &tra_config
|
||||
num_states: *num_states
|
||||
hidden_size: 16
|
||||
tau: 1.0
|
||||
src_info: LR_TPE
|
||||
|
||||
task:
|
||||
model:
|
||||
class: TRAModel
|
||||
module_path: src/model.py
|
||||
kwargs:
|
||||
lr: 0.0002
|
||||
n_epochs: 500
|
||||
max_steps_per_epoch: 100
|
||||
early_stop: 20
|
||||
seed: 1000
|
||||
logdir: output/test/alstm
|
||||
model_type: LSTM
|
||||
model_config: *model_config
|
||||
tra_config: *tra_config
|
||||
lamb: 1.0
|
||||
rho: 0.99
|
||||
freeze_model: False
|
||||
model_init_state:
|
||||
dataset:
|
||||
class: MTSDatasetH
|
||||
module_path: src/dataset.py
|
||||
kwargs:
|
||||
handler:
|
||||
class: DataHandler
|
||||
module_path: qlib.data.dataset.handler
|
||||
kwargs:
|
||||
data_loader: *data_loader_config
|
||||
segments:
|
||||
train: [2007-10-30, 2016-05-27]
|
||||
valid: [2016-09-26, 2018-05-29]
|
||||
test: [2018-09-21, 2020-06-30]
|
||||
seq_len: 60
|
||||
horizon: 21
|
||||
num_states: *num_states
|
||||
batch_size: 1024
|
||||
@@ -1,63 +0,0 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
|
||||
data_loader_config: &data_loader_config
|
||||
class: StaticDataLoader
|
||||
module_path: qlib.data.dataset.loader
|
||||
kwargs:
|
||||
config:
|
||||
feature: data/feature.pkl
|
||||
label: data/label.pkl
|
||||
|
||||
model_config: &model_config
|
||||
input_size: 16
|
||||
hidden_size: 256
|
||||
num_layers: 2
|
||||
num_heads: 2
|
||||
use_attn: True
|
||||
dropout: 0.1
|
||||
|
||||
num_states: &num_states 10
|
||||
|
||||
tra_config: &tra_config
|
||||
num_states: *num_states
|
||||
hidden_size: 16
|
||||
tau: 1.0
|
||||
src_info: LR_TPE
|
||||
|
||||
task:
|
||||
model:
|
||||
class: TRAModel
|
||||
module_path: src/model.py
|
||||
kwargs:
|
||||
lr: 0.0001
|
||||
n_epochs: 500
|
||||
max_steps_per_epoch: 100
|
||||
early_stop: 20
|
||||
seed: 1000
|
||||
logdir: output/test/alstm_tra
|
||||
model_type: LSTM
|
||||
model_config: *model_config
|
||||
tra_config: *tra_config
|
||||
lamb: 2.0
|
||||
rho: 0.99
|
||||
freeze_model: True
|
||||
model_init_state: output/test/alstm_tra_init/model.bin
|
||||
dataset:
|
||||
class: MTSDatasetH
|
||||
module_path: src/dataset.py
|
||||
kwargs:
|
||||
handler:
|
||||
class: DataHandler
|
||||
module_path: qlib.data.dataset.handler
|
||||
kwargs:
|
||||
data_loader: *data_loader_config
|
||||
segments:
|
||||
train: [2007-10-30, 2016-05-27]
|
||||
valid: [2016-09-26, 2018-05-29]
|
||||
test: [2018-09-21, 2020-06-30]
|
||||
seq_len: 60
|
||||
horizon: 21
|
||||
num_states: *num_states
|
||||
batch_size: 1024
|
||||
@@ -1,63 +0,0 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
|
||||
data_loader_config: &data_loader_config
|
||||
class: StaticDataLoader
|
||||
module_path: qlib.data.dataset.loader
|
||||
kwargs:
|
||||
config:
|
||||
feature: data/feature.pkl
|
||||
label: data/label.pkl
|
||||
|
||||
model_config: &model_config
|
||||
input_size: 16
|
||||
hidden_size: 256
|
||||
num_layers: 2
|
||||
num_heads: 2
|
||||
use_attn: True
|
||||
dropout: 0.1
|
||||
|
||||
num_states: &num_states 3
|
||||
|
||||
tra_config: &tra_config
|
||||
num_states: *num_states
|
||||
hidden_size: 16
|
||||
tau: 1.0
|
||||
src_info: LR_TPE
|
||||
|
||||
task:
|
||||
model:
|
||||
class: TRAModel
|
||||
module_path: src/model.py
|
||||
kwargs:
|
||||
lr: 0.0002
|
||||
n_epochs: 500
|
||||
max_steps_per_epoch: 100
|
||||
early_stop: 20
|
||||
seed: 1000
|
||||
logdir: output/test/alstm_tra_init
|
||||
model_type: LSTM
|
||||
model_config: *model_config
|
||||
tra_config: *tra_config
|
||||
lamb: 1.0
|
||||
rho: 0.99
|
||||
freeze_model: False
|
||||
model_init_state:
|
||||
dataset:
|
||||
class: MTSDatasetH
|
||||
module_path: src/dataset.py
|
||||
kwargs:
|
||||
handler:
|
||||
class: DataHandler
|
||||
module_path: qlib.data.dataset.handler
|
||||
kwargs:
|
||||
data_loader: *data_loader_config
|
||||
segments:
|
||||
train: [2007-10-30, 2016-05-27]
|
||||
valid: [2016-09-26, 2018-05-29]
|
||||
test: [2018-09-21, 2020-06-30]
|
||||
seq_len: 60
|
||||
horizon: 21
|
||||
num_states: *num_states
|
||||
batch_size: 512
|
||||
@@ -1,63 +0,0 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
|
||||
data_loader_config: &data_loader_config
|
||||
class: StaticDataLoader
|
||||
module_path: qlib.data.dataset.loader
|
||||
kwargs:
|
||||
config:
|
||||
feature: data/feature.pkl
|
||||
label: data/label.pkl
|
||||
|
||||
model_config: &model_config
|
||||
input_size: 16
|
||||
hidden_size: 64
|
||||
num_layers: 2
|
||||
num_heads: 4
|
||||
use_attn: False
|
||||
dropout: 0.1
|
||||
|
||||
num_states: &num_states 1
|
||||
|
||||
tra_config: &tra_config
|
||||
num_states: *num_states
|
||||
hidden_size: 16
|
||||
tau: 1.0
|
||||
src_info: LR_TPE
|
||||
|
||||
task:
|
||||
model:
|
||||
class: TRAModel
|
||||
module_path: src/model.py
|
||||
kwargs:
|
||||
lr: 0.0002
|
||||
n_epochs: 500
|
||||
max_steps_per_epoch: 100
|
||||
early_stop: 20
|
||||
seed: 1000
|
||||
logdir: output/test/transformer
|
||||
model_type: Transformer
|
||||
model_config: *model_config
|
||||
tra_config: *tra_config
|
||||
lamb: 1.0
|
||||
rho: 0.99
|
||||
freeze_model: False
|
||||
model_init_state:
|
||||
dataset:
|
||||
class: MTSDatasetH
|
||||
module_path: src/dataset.py
|
||||
kwargs:
|
||||
handler:
|
||||
class: DataHandler
|
||||
module_path: qlib.data.dataset.handler
|
||||
kwargs:
|
||||
data_loader: *data_loader_config
|
||||
segments:
|
||||
train: [2007-10-30, 2016-05-27]
|
||||
valid: [2016-09-26, 2018-05-29]
|
||||
test: [2018-09-21, 2020-06-30]
|
||||
seq_len: 60
|
||||
horizon: 21
|
||||
num_states: *num_states
|
||||
batch_size: 1024
|
||||
@@ -1,63 +0,0 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
|
||||
data_loader_config: &data_loader_config
|
||||
class: StaticDataLoader
|
||||
module_path: qlib.data.dataset.loader
|
||||
kwargs:
|
||||
config:
|
||||
feature: data/feature.pkl
|
||||
label: data/label.pkl
|
||||
|
||||
model_config: &model_config
|
||||
input_size: 16
|
||||
hidden_size: 64
|
||||
num_layers: 2
|
||||
num_heads: 4
|
||||
use_attn: False
|
||||
dropout: 0.1
|
||||
|
||||
num_states: &num_states 3
|
||||
|
||||
tra_config: &tra_config
|
||||
num_states: *num_states
|
||||
hidden_size: 16
|
||||
tau: 1.0
|
||||
src_info: LR_TPE
|
||||
|
||||
task:
|
||||
model:
|
||||
class: TRAModel
|
||||
module_path: src/model.py
|
||||
kwargs:
|
||||
lr: 0.0005
|
||||
n_epochs: 500
|
||||
max_steps_per_epoch: 100
|
||||
early_stop: 20
|
||||
seed: 1000
|
||||
logdir: output/test/transformer_tra
|
||||
model_type: Transformer
|
||||
model_config: *model_config
|
||||
tra_config: *tra_config
|
||||
lamb: 1.0
|
||||
rho: 0.99
|
||||
freeze_model: True
|
||||
model_init_state: output/test/transformer_tra_init/model.bin
|
||||
dataset:
|
||||
class: MTSDatasetH
|
||||
module_path: src/dataset.py
|
||||
kwargs:
|
||||
handler:
|
||||
class: DataHandler
|
||||
module_path: qlib.data.dataset.handler
|
||||
kwargs:
|
||||
data_loader: *data_loader_config
|
||||
segments:
|
||||
train: [2007-10-30, 2016-05-27]
|
||||
valid: [2016-09-26, 2018-05-29]
|
||||
test: [2018-09-21, 2020-06-30]
|
||||
seq_len: 60
|
||||
horizon: 21
|
||||
num_states: *num_states
|
||||
batch_size: 512
|
||||
@@ -1,63 +0,0 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
|
||||
data_loader_config: &data_loader_config
|
||||
class: StaticDataLoader
|
||||
module_path: qlib.data.dataset.loader
|
||||
kwargs:
|
||||
config:
|
||||
feature: data/feature.pkl
|
||||
label: data/label.pkl
|
||||
|
||||
model_config: &model_config
|
||||
input_size: 16
|
||||
hidden_size: 64
|
||||
num_layers: 2
|
||||
num_heads: 4
|
||||
use_attn: False
|
||||
dropout: 0.1
|
||||
|
||||
num_states: &num_states 3
|
||||
|
||||
tra_config: &tra_config
|
||||
num_states: *num_states
|
||||
hidden_size: 16
|
||||
tau: 1.0
|
||||
src_info: LR_TPE
|
||||
|
||||
task:
|
||||
model:
|
||||
class: TRAModel
|
||||
module_path: src/model.py
|
||||
kwargs:
|
||||
lr: 0.0002
|
||||
n_epochs: 500
|
||||
max_steps_per_epoch: 100
|
||||
early_stop: 20
|
||||
seed: 1000
|
||||
logdir: output/test/transformer_tra_init
|
||||
model_type: Transformer
|
||||
model_config: *model_config
|
||||
tra_config: *tra_config
|
||||
lamb: 1.0
|
||||
rho: 0.99
|
||||
freeze_model: False
|
||||
model_init_state:
|
||||
dataset:
|
||||
class: MTSDatasetH
|
||||
module_path: src/dataset.py
|
||||
kwargs:
|
||||
handler:
|
||||
class: DataHandler
|
||||
module_path: qlib.data.dataset.handler
|
||||
kwargs:
|
||||
data_loader: *data_loader_config
|
||||
segments:
|
||||
train: [2007-10-30, 2016-05-27]
|
||||
valid: [2016-09-26, 2018-05-29]
|
||||
test: [2018-09-21, 2020-06-30]
|
||||
seq_len: 60
|
||||
horizon: 21
|
||||
num_states: *num_states
|
||||
batch_size: 512
|
||||
@@ -1 +0,0 @@
|
||||
Data Link: https://drive.google.com/drive/folders/1fMqZYSeLyrHiWmVzygeI4sw3vp5Gt8cY?usp=sharing
|
||||
@@ -1,39 +0,0 @@
|
||||
import argparse
|
||||
|
||||
import qlib
|
||||
import ruamel.yaml as yaml
|
||||
from qlib.utils import init_instance_by_config
|
||||
|
||||
|
||||
def main(seed, config_file="configs/config_alstm.yaml"):
|
||||
|
||||
# set random seed
|
||||
with open(config_file) as f:
|
||||
config = yaml.safe_load(f)
|
||||
|
||||
# seed_suffix = "/seed1000" if "init" in config_file else f"/seed{seed}"
|
||||
seed_suffix = ""
|
||||
config["task"]["model"]["kwargs"].update(
|
||||
{"seed": seed, "logdir": config["task"]["model"]["kwargs"]["logdir"] + seed_suffix}
|
||||
)
|
||||
|
||||
# initialize workflow
|
||||
qlib.init(
|
||||
provider_uri=config["qlib_init"]["provider_uri"],
|
||||
region=config["qlib_init"]["region"],
|
||||
)
|
||||
dataset = init_instance_by_config(config["task"]["dataset"])
|
||||
model = init_instance_by_config(config["task"]["model"])
|
||||
|
||||
# train model
|
||||
model.fit(dataset)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
# set params from cmd
|
||||
parser = argparse.ArgumentParser(allow_abbrev=False)
|
||||
parser.add_argument("--seed", type=int, default=1000, help="random seed")
|
||||
parser.add_argument("--config_file", type=str, default="configs/config_alstm.yaml", help="config file")
|
||||
args = parser.parse_args()
|
||||
main(**vars(args))
|
||||
@@ -1,29 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
# we used random seed(1 1000 2000 3000 4000 5000) in our experiments
|
||||
|
||||
# Directly run from Qlib command `qrun`
|
||||
qrun configs/config_alstm.yaml
|
||||
|
||||
qrun configs/config_transformer.yaml
|
||||
|
||||
qrun configs/config_transformer_tra_init.yaml
|
||||
qrun configs/config_transformer_tra.yaml
|
||||
|
||||
qrun configs/config_alstm_tra_init.yaml
|
||||
qrun configs/config_alstm_tra.yaml
|
||||
|
||||
|
||||
# Or setting different parameters with example.py
|
||||
python example.py --config_file configs/config_alstm.yaml
|
||||
|
||||
python example.py --config_file configs/config_transformer.yaml
|
||||
|
||||
python example.py --config_file configs/config_transformer_tra_init.yaml
|
||||
python example.py --config_file configs/config_transformer_tra.yaml
|
||||
|
||||
python example.py --config_file configs/config_alstm_tra_init.yaml
|
||||
python example.py --config_file configs/config_alstm_tra.yaml
|
||||
|
||||
|
||||
|
||||
@@ -1,253 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import copy
|
||||
import torch
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from qlib.utils import init_instance_by_config
|
||||
from qlib.data.dataset import DatasetH, DataHandler
|
||||
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
|
||||
def _to_tensor(x):
|
||||
if not isinstance(x, torch.Tensor):
|
||||
return torch.tensor(x, dtype=torch.float, device=device)
|
||||
return x
|
||||
|
||||
|
||||
def _create_ts_slices(index, seq_len):
|
||||
"""
|
||||
create time series slices from pandas index
|
||||
|
||||
Args:
|
||||
index (pd.MultiIndex): pandas multiindex with <instrument, datetime> order
|
||||
seq_len (int): sequence length
|
||||
"""
|
||||
assert index.is_lexsorted(), "index should be sorted"
|
||||
|
||||
# number of dates for each code
|
||||
sample_count_by_codes = pd.Series(0, index=index).groupby(level=0).size().values
|
||||
|
||||
# start_index for each code
|
||||
start_index_of_codes = np.roll(np.cumsum(sample_count_by_codes), 1)
|
||||
start_index_of_codes[0] = 0
|
||||
|
||||
# all the [start, stop) indices of features
|
||||
# features btw [start, stop) are used to predict the `stop - 1` label
|
||||
slices = []
|
||||
for cur_loc, cur_cnt in zip(start_index_of_codes, sample_count_by_codes):
|
||||
for stop in range(1, cur_cnt + 1):
|
||||
end = cur_loc + stop
|
||||
start = max(end - seq_len, 0)
|
||||
slices.append(slice(start, end))
|
||||
slices = np.array(slices)
|
||||
|
||||
return slices
|
||||
|
||||
|
||||
def _get_date_parse_fn(target):
|
||||
"""get date parse function
|
||||
|
||||
This method is used to parse date arguments as target type.
|
||||
|
||||
Example:
|
||||
get_date_parse_fn('20120101')('2017-01-01') => '20170101'
|
||||
get_date_parse_fn(20120101)('2017-01-01') => 20170101
|
||||
"""
|
||||
if isinstance(target, pd.Timestamp):
|
||||
_fn = lambda x: pd.Timestamp(x) # Timestamp('2020-01-01')
|
||||
elif isinstance(target, str) and len(target) == 8:
|
||||
_fn = lambda x: str(x).replace("-", "")[:8] # '20200201'
|
||||
elif isinstance(target, int):
|
||||
_fn = lambda x: int(str(x).replace("-", "")[:8]) # 20200201
|
||||
else:
|
||||
_fn = lambda x: x
|
||||
return _fn
|
||||
|
||||
|
||||
class MTSDatasetH(DatasetH):
|
||||
"""Memory Augmented Time Series Dataset
|
||||
|
||||
Args:
|
||||
handler (DataHandler): data handler
|
||||
segments (dict): data split segments
|
||||
seq_len (int): time series sequence length
|
||||
horizon (int): label horizon (to mask historical loss for TRA)
|
||||
num_states (int): how many memory states to be added (for TRA)
|
||||
batch_size (int): batch size (<0 means daily batch)
|
||||
shuffle (bool): whether shuffle data
|
||||
pin_memory (bool): whether pin data to gpu memory
|
||||
drop_last (bool): whether drop last batch < batch_size
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
handler,
|
||||
segments,
|
||||
seq_len=60,
|
||||
horizon=0,
|
||||
num_states=1,
|
||||
batch_size=-1,
|
||||
shuffle=True,
|
||||
pin_memory=False,
|
||||
drop_last=False,
|
||||
**kwargs
|
||||
):
|
||||
|
||||
assert horizon > 0, "please specify `horizon` to avoid data leakage"
|
||||
|
||||
self.seq_len = seq_len
|
||||
self.horizon = horizon
|
||||
self.num_states = num_states
|
||||
self.batch_size = batch_size
|
||||
self.shuffle = shuffle
|
||||
self.drop_last = drop_last
|
||||
self.pin_memory = pin_memory
|
||||
self.params = (batch_size, drop_last, shuffle) # for train/eval switch
|
||||
|
||||
super().__init__(handler, segments, **kwargs)
|
||||
|
||||
def setup_data(self, handler_kwargs: dict = None, **kwargs):
|
||||
|
||||
super().setup_data()
|
||||
|
||||
# change index to <code, date>
|
||||
# NOTE: we will use inplace sort to reduce memory use
|
||||
df = self.handler._data
|
||||
df.index = df.index.swaplevel()
|
||||
df.sort_index(inplace=True)
|
||||
|
||||
self._data = df["feature"].values.astype("float32")
|
||||
self._label = df["label"].squeeze().astype("float32")
|
||||
self._index = df.index
|
||||
|
||||
# add memory to feature
|
||||
self._data = np.c_[self._data, np.zeros((len(self._data), self.num_states), dtype=np.float32)]
|
||||
|
||||
# padding tensor
|
||||
self.zeros = np.zeros((self.seq_len, self._data.shape[1]), dtype=np.float32)
|
||||
|
||||
# pin memory
|
||||
if self.pin_memory:
|
||||
self._data = _to_tensor(self._data)
|
||||
self._label = _to_tensor(self._label)
|
||||
self.zeros = _to_tensor(self.zeros)
|
||||
|
||||
# create batch slices
|
||||
self.batch_slices = _create_ts_slices(self._index, self.seq_len)
|
||||
|
||||
# create daily slices
|
||||
index = [slc.stop - 1 for slc in self.batch_slices]
|
||||
act_index = self.restore_index(index)
|
||||
daily_slices = {date: [] for date in sorted(act_index.unique(level=1))}
|
||||
for i, (code, date) in enumerate(act_index):
|
||||
daily_slices[date].append(self.batch_slices[i])
|
||||
self.daily_slices = list(daily_slices.values())
|
||||
|
||||
def _prepare_seg(self, slc, **kwargs):
|
||||
fn = _get_date_parse_fn(self._index[0][1])
|
||||
start_date = fn(slc.start)
|
||||
end_date = fn(slc.stop)
|
||||
obj = copy.copy(self) # shallow copy
|
||||
# NOTE: Seriable will disable copy `self._data` so we manually assign them here
|
||||
obj._data = self._data
|
||||
obj._label = self._label
|
||||
obj._index = self._index
|
||||
new_batch_slices = []
|
||||
for batch_slc in self.batch_slices:
|
||||
date = self._index[batch_slc.stop - 1][1]
|
||||
if start_date <= date <= end_date:
|
||||
new_batch_slices.append(batch_slc)
|
||||
obj.batch_slices = np.array(new_batch_slices)
|
||||
new_daily_slices = []
|
||||
for daily_slc in self.daily_slices:
|
||||
date = self._index[daily_slc[0].stop - 1][1]
|
||||
if start_date <= date <= end_date:
|
||||
new_daily_slices.append(daily_slc)
|
||||
obj.daily_slices = new_daily_slices
|
||||
return obj
|
||||
|
||||
def restore_index(self, index):
|
||||
if isinstance(index, torch.Tensor):
|
||||
index = index.cpu().numpy()
|
||||
return self._index[index]
|
||||
|
||||
def assign_data(self, index, vals):
|
||||
if isinstance(self._data, torch.Tensor):
|
||||
vals = _to_tensor(vals)
|
||||
elif isinstance(vals, torch.Tensor):
|
||||
vals = vals.detach().cpu().numpy()
|
||||
index = index.detach().cpu().numpy()
|
||||
self._data[index, -self.num_states :] = vals
|
||||
|
||||
def clear_memory(self):
|
||||
self._data[:, -self.num_states :] = 0
|
||||
|
||||
# TODO: better train/eval mode design
|
||||
def train(self):
|
||||
"""enable traning mode"""
|
||||
self.batch_size, self.drop_last, self.shuffle = self.params
|
||||
|
||||
def eval(self):
|
||||
"""enable evaluation mode"""
|
||||
self.batch_size = -1
|
||||
self.drop_last = False
|
||||
self.shuffle = False
|
||||
|
||||
def _get_slices(self):
|
||||
if self.batch_size < 0:
|
||||
slices = self.daily_slices.copy()
|
||||
batch_size = -1 * self.batch_size
|
||||
else:
|
||||
slices = self.batch_slices.copy()
|
||||
batch_size = self.batch_size
|
||||
return slices, batch_size
|
||||
|
||||
def __len__(self):
|
||||
slices, batch_size = self._get_slices()
|
||||
if self.drop_last:
|
||||
return len(slices) // batch_size
|
||||
return (len(slices) + batch_size - 1) // batch_size
|
||||
|
||||
def __iter__(self):
|
||||
slices, batch_size = self._get_slices()
|
||||
if self.shuffle:
|
||||
np.random.shuffle(slices)
|
||||
|
||||
for i in range(len(slices))[::batch_size]:
|
||||
if self.drop_last and i + batch_size > len(slices):
|
||||
break
|
||||
# get slices for this batch
|
||||
slices_subset = slices[i : i + batch_size]
|
||||
if self.batch_size < 0:
|
||||
slices_subset = np.concatenate(slices_subset)
|
||||
# collect data
|
||||
data = []
|
||||
label = []
|
||||
index = []
|
||||
for slc in slices_subset:
|
||||
_data = self._data[slc].clone() if self.pin_memory else self._data[slc].copy()
|
||||
if len(_data) != self.seq_len:
|
||||
if self.pin_memory:
|
||||
_data = torch.cat([self.zeros[: self.seq_len - len(_data)], _data], axis=0)
|
||||
else:
|
||||
_data = np.concatenate([self.zeros[: self.seq_len - len(_data)], _data], axis=0)
|
||||
if self.num_states > 0:
|
||||
_data[-self.horizon :, -self.num_states :] = 0
|
||||
data.append(_data)
|
||||
label.append(self._label[slc.stop - 1])
|
||||
index.append(slc.stop - 1)
|
||||
# concate
|
||||
index = torch.tensor(index, device=device)
|
||||
if isinstance(data[0], torch.Tensor):
|
||||
data = torch.stack(data)
|
||||
label = torch.stack(label)
|
||||
else:
|
||||
data = _to_tensor(np.stack(data))
|
||||
label = _to_tensor(np.stack(label))
|
||||
# yield -> generator
|
||||
yield {"data": data, "label": label, "index": index}
|
||||
@@ -1,603 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
import copy
|
||||
import math
|
||||
import json
|
||||
import collections
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import torch.nn.functional as F
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from qlib.utils import get_or_create_path
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.model.base import Model
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
|
||||
class TRAModel(Model):
|
||||
def __init__(
|
||||
self,
|
||||
model_config,
|
||||
tra_config,
|
||||
model_type="LSTM",
|
||||
lr=1e-3,
|
||||
n_epochs=500,
|
||||
early_stop=50,
|
||||
smooth_steps=5,
|
||||
max_steps_per_epoch=None,
|
||||
freeze_model=False,
|
||||
model_init_state=None,
|
||||
lamb=0.0,
|
||||
rho=0.99,
|
||||
seed=0,
|
||||
logdir=None,
|
||||
eval_train=True,
|
||||
eval_test=False,
|
||||
avg_params=True,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
|
||||
self.logger = get_module_logger("TRA")
|
||||
self.logger.info("TRA Model...")
|
||||
|
||||
self.model = eval(model_type)(**model_config).to(device)
|
||||
if model_init_state:
|
||||
self.model.load_state_dict(torch.load(model_init_state, map_location="cpu")["model"])
|
||||
if freeze_model:
|
||||
for param in self.model.parameters():
|
||||
param.requires_grad_(False)
|
||||
else:
|
||||
self.logger.info("# model params: %d" % sum([p.numel() for p in self.model.parameters()]))
|
||||
|
||||
self.tra = TRA(self.model.output_size, **tra_config).to(device)
|
||||
self.logger.info("# tra params: %d" % sum([p.numel() for p in self.tra.parameters()]))
|
||||
|
||||
self.optimizer = optim.Adam(list(self.model.parameters()) + list(self.tra.parameters()), lr=lr)
|
||||
|
||||
self.model_config = model_config
|
||||
self.tra_config = tra_config
|
||||
self.lr = lr
|
||||
self.n_epochs = n_epochs
|
||||
self.early_stop = early_stop
|
||||
self.smooth_steps = smooth_steps
|
||||
self.max_steps_per_epoch = max_steps_per_epoch
|
||||
self.lamb = lamb
|
||||
self.rho = rho
|
||||
self.seed = seed
|
||||
self.logdir = logdir
|
||||
self.eval_train = eval_train
|
||||
self.eval_test = eval_test
|
||||
self.avg_params = avg_params
|
||||
|
||||
if self.tra.num_states > 1 and not self.eval_train:
|
||||
self.logger.warn("`eval_train` will be ignored when using TRA")
|
||||
|
||||
if self.logdir is not None:
|
||||
if os.path.exists(self.logdir):
|
||||
self.logger.warn(f"logdir {self.logdir} is not empty")
|
||||
os.makedirs(self.logdir, exist_ok=True)
|
||||
|
||||
self.fitted = False
|
||||
self.global_step = -1
|
||||
|
||||
def train_epoch(self, data_set):
|
||||
|
||||
self.model.train()
|
||||
self.tra.train()
|
||||
|
||||
data_set.train()
|
||||
|
||||
max_steps = self.n_epochs
|
||||
if self.max_steps_per_epoch is not None:
|
||||
max_steps = min(self.max_steps_per_epoch, self.n_epochs)
|
||||
|
||||
count = 0
|
||||
total_loss = 0
|
||||
total_count = 0
|
||||
for batch in tqdm(data_set, total=max_steps):
|
||||
count += 1
|
||||
if count > max_steps:
|
||||
break
|
||||
|
||||
self.global_step += 1
|
||||
|
||||
data, label, index = batch["data"], batch["label"], batch["index"]
|
||||
|
||||
feature = data[:, :, : -self.tra.num_states]
|
||||
hist_loss = data[:, : -data_set.horizon, -self.tra.num_states :]
|
||||
|
||||
hidden = self.model(feature)
|
||||
pred, all_preds, prob = self.tra(hidden, hist_loss)
|
||||
|
||||
loss = (pred - label).pow(2).mean()
|
||||
|
||||
L = (all_preds.detach() - label[:, None]).pow(2)
|
||||
L -= L.min(dim=-1, keepdim=True).values # normalize & ensure postive input
|
||||
|
||||
data_set.assign_data(index, L) # save loss to memory
|
||||
|
||||
if prob is not None:
|
||||
P = sinkhorn(-L, epsilon=0.01) # sample assignment matrix
|
||||
lamb = self.lamb * (self.rho ** self.global_step)
|
||||
reg = prob.log().mul(P).sum(dim=-1).mean()
|
||||
loss = loss - lamb * reg
|
||||
|
||||
loss.backward()
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
total_loss += loss.item()
|
||||
total_count += len(pred)
|
||||
|
||||
total_loss /= total_count
|
||||
|
||||
return total_loss
|
||||
|
||||
def test_epoch(self, data_set, return_pred=False):
|
||||
|
||||
self.model.eval()
|
||||
self.tra.eval()
|
||||
data_set.eval()
|
||||
|
||||
preds = []
|
||||
metrics = []
|
||||
for batch in tqdm(data_set):
|
||||
data, label, index = batch["data"], batch["label"], batch["index"]
|
||||
|
||||
feature = data[:, :, : -self.tra.num_states]
|
||||
hist_loss = data[:, : -data_set.horizon, -self.tra.num_states :]
|
||||
|
||||
with torch.no_grad():
|
||||
hidden = self.model(feature)
|
||||
pred, all_preds, prob = self.tra(hidden, hist_loss)
|
||||
|
||||
L = (all_preds - label[:, None]).pow(2)
|
||||
|
||||
L -= L.min(dim=-1, keepdim=True).values # normalize & ensure postive input
|
||||
|
||||
data_set.assign_data(index, L) # save loss to memory
|
||||
|
||||
X = np.c_[
|
||||
pred.cpu().numpy(),
|
||||
label.cpu().numpy(),
|
||||
]
|
||||
columns = ["score", "label"]
|
||||
if prob is not None:
|
||||
X = np.c_[X, all_preds.cpu().numpy(), prob.cpu().numpy()]
|
||||
columns += ["score_%d" % d for d in range(all_preds.shape[1])] + [
|
||||
"prob_%d" % d for d in range(all_preds.shape[1])
|
||||
]
|
||||
|
||||
pred = pd.DataFrame(X, index=index.cpu().numpy(), columns=columns)
|
||||
|
||||
metrics.append(evaluate(pred))
|
||||
|
||||
if return_pred:
|
||||
preds.append(pred)
|
||||
|
||||
metrics = pd.DataFrame(metrics)
|
||||
metrics = {
|
||||
"MSE": metrics.MSE.mean(),
|
||||
"MAE": metrics.MAE.mean(),
|
||||
"IC": metrics.IC.mean(),
|
||||
"ICIR": metrics.IC.mean() / metrics.IC.std(),
|
||||
}
|
||||
|
||||
if return_pred:
|
||||
preds = pd.concat(preds, axis=0)
|
||||
preds.index = data_set.restore_index(preds.index)
|
||||
preds.index = preds.index.swaplevel()
|
||||
preds.sort_index(inplace=True)
|
||||
|
||||
return metrics, preds
|
||||
|
||||
def fit(self, dataset, evals_result=dict()):
|
||||
|
||||
train_set, valid_set, test_set = dataset.prepare(["train", "valid", "test"])
|
||||
|
||||
best_score = -1
|
||||
best_epoch = 0
|
||||
stop_rounds = 0
|
||||
best_params = {
|
||||
"model": copy.deepcopy(self.model.state_dict()),
|
||||
"tra": copy.deepcopy(self.tra.state_dict()),
|
||||
}
|
||||
params_list = {
|
||||
"model": collections.deque(maxlen=self.smooth_steps),
|
||||
"tra": collections.deque(maxlen=self.smooth_steps),
|
||||
}
|
||||
evals_result["train"] = []
|
||||
evals_result["valid"] = []
|
||||
evals_result["test"] = []
|
||||
|
||||
# train
|
||||
self.fitted = True
|
||||
self.global_step = -1
|
||||
|
||||
if self.tra.num_states > 1:
|
||||
self.logger.info("init memory...")
|
||||
self.test_epoch(train_set)
|
||||
|
||||
for epoch in range(self.n_epochs):
|
||||
self.logger.info("Epoch %d:", epoch)
|
||||
|
||||
self.logger.info("training...")
|
||||
self.train_epoch(train_set)
|
||||
|
||||
self.logger.info("evaluating...")
|
||||
# average params for inference
|
||||
params_list["model"].append(copy.deepcopy(self.model.state_dict()))
|
||||
params_list["tra"].append(copy.deepcopy(self.tra.state_dict()))
|
||||
self.model.load_state_dict(average_params(params_list["model"]))
|
||||
self.tra.load_state_dict(average_params(params_list["tra"]))
|
||||
|
||||
# NOTE: during evaluating, the whole memory will be refreshed
|
||||
if self.tra.num_states > 1 or self.eval_train:
|
||||
train_set.clear_memory() # NOTE: clear the shared memory
|
||||
train_metrics = self.test_epoch(train_set)[0]
|
||||
evals_result["train"].append(train_metrics)
|
||||
self.logger.info("\ttrain metrics: %s" % train_metrics)
|
||||
|
||||
valid_metrics = self.test_epoch(valid_set)[0]
|
||||
evals_result["valid"].append(valid_metrics)
|
||||
self.logger.info("\tvalid metrics: %s" % valid_metrics)
|
||||
|
||||
if self.eval_test:
|
||||
test_metrics = self.test_epoch(test_set)[0]
|
||||
evals_result["test"].append(test_metrics)
|
||||
self.logger.info("\ttest metrics: %s" % test_metrics)
|
||||
|
||||
if valid_metrics["IC"] > best_score:
|
||||
best_score = valid_metrics["IC"]
|
||||
stop_rounds = 0
|
||||
best_epoch = epoch
|
||||
best_params = {
|
||||
"model": copy.deepcopy(self.model.state_dict()),
|
||||
"tra": copy.deepcopy(self.tra.state_dict()),
|
||||
}
|
||||
else:
|
||||
stop_rounds += 1
|
||||
if stop_rounds >= self.early_stop:
|
||||
self.logger.info("early stop @ %s" % epoch)
|
||||
break
|
||||
|
||||
# restore parameters
|
||||
self.model.load_state_dict(params_list["model"][-1])
|
||||
self.tra.load_state_dict(params_list["tra"][-1])
|
||||
|
||||
self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
|
||||
self.model.load_state_dict(best_params["model"])
|
||||
self.tra.load_state_dict(best_params["tra"])
|
||||
|
||||
metrics, preds = self.test_epoch(test_set, return_pred=True)
|
||||
self.logger.info("test metrics: %s" % metrics)
|
||||
|
||||
if self.logdir:
|
||||
self.logger.info("save model & pred to local directory")
|
||||
|
||||
pd.concat({name: pd.DataFrame(evals_result[name]) for name in evals_result}, axis=1).to_csv(
|
||||
self.logdir + "/logs.csv", index=False
|
||||
)
|
||||
|
||||
torch.save(best_params, self.logdir + "/model.bin")
|
||||
|
||||
preds.to_pickle(self.logdir + "/pred.pkl")
|
||||
|
||||
info = {
|
||||
"config": {
|
||||
"model_config": self.model_config,
|
||||
"tra_config": self.tra_config,
|
||||
"lr": self.lr,
|
||||
"n_epochs": self.n_epochs,
|
||||
"early_stop": self.early_stop,
|
||||
"smooth_steps": self.smooth_steps,
|
||||
"max_steps_per_epoch": self.max_steps_per_epoch,
|
||||
"lamb": self.lamb,
|
||||
"rho": self.rho,
|
||||
"seed": self.seed,
|
||||
"logdir": self.logdir,
|
||||
},
|
||||
"best_eval_metric": -best_score, # NOTE: minux -1 for minimize
|
||||
"metric": metrics,
|
||||
}
|
||||
with open(self.logdir + "/info.json", "w") as f:
|
||||
json.dump(info, f)
|
||||
|
||||
def predict(self, dataset, segment="test"):
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
test_set = dataset.prepare(segment)
|
||||
|
||||
metrics, preds = self.test_epoch(test_set, return_pred=True)
|
||||
self.logger.info("test metrics: %s" % metrics)
|
||||
|
||||
return preds
|
||||
|
||||
|
||||
class LSTM(nn.Module):
|
||||
|
||||
"""LSTM Model
|
||||
|
||||
Args:
|
||||
input_size (int): input size (# features)
|
||||
hidden_size (int): hidden size
|
||||
num_layers (int): number of hidden layers
|
||||
use_attn (bool): whether use attention layer.
|
||||
we use concat attention as https://github.com/fulifeng/Adv-ALSTM/
|
||||
dropout (float): dropout rate
|
||||
input_drop (float): input dropout for data augmentation
|
||||
noise_level (float): add gaussian noise to input for data augmentation
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size=16,
|
||||
hidden_size=64,
|
||||
num_layers=2,
|
||||
use_attn=True,
|
||||
dropout=0.0,
|
||||
input_drop=0.0,
|
||||
noise_level=0.0,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.input_size = input_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_layers = num_layers
|
||||
self.use_attn = use_attn
|
||||
self.noise_level = noise_level
|
||||
|
||||
self.input_drop = nn.Dropout(input_drop)
|
||||
|
||||
self.rnn = nn.LSTM(
|
||||
input_size=input_size,
|
||||
hidden_size=hidden_size,
|
||||
num_layers=num_layers,
|
||||
batch_first=True,
|
||||
dropout=dropout,
|
||||
)
|
||||
|
||||
if self.use_attn:
|
||||
self.W = nn.Linear(hidden_size, hidden_size)
|
||||
self.u = nn.Linear(hidden_size, 1, bias=False)
|
||||
self.output_size = hidden_size * 2
|
||||
else:
|
||||
self.output_size = hidden_size
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
x = self.input_drop(x)
|
||||
|
||||
if self.training and self.noise_level > 0:
|
||||
noise = torch.randn_like(x).to(x)
|
||||
x = x + noise * self.noise_level
|
||||
|
||||
rnn_out, _ = self.rnn(x)
|
||||
last_out = rnn_out[:, -1]
|
||||
|
||||
if self.use_attn:
|
||||
laten = self.W(rnn_out).tanh()
|
||||
scores = self.u(laten).softmax(dim=1)
|
||||
att_out = (rnn_out * scores).sum(dim=1).squeeze()
|
||||
last_out = torch.cat([last_out, att_out], dim=1)
|
||||
|
||||
return last_out
|
||||
|
||||
|
||||
class PositionalEncoding(nn.Module):
|
||||
# reference: https://pytorch.org/tutorials/beginner/transformer_tutorial.html
|
||||
def __init__(self, d_model, dropout=0.1, max_len=5000):
|
||||
super(PositionalEncoding, self).__init__()
|
||||
self.dropout = nn.Dropout(p=dropout)
|
||||
|
||||
pe = torch.zeros(max_len, d_model)
|
||||
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
|
||||
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
|
||||
pe[:, 0::2] = torch.sin(position * div_term)
|
||||
pe[:, 1::2] = torch.cos(position * div_term)
|
||||
pe = pe.unsqueeze(0).transpose(0, 1)
|
||||
self.register_buffer("pe", pe)
|
||||
|
||||
def forward(self, x):
|
||||
x = x + self.pe[: x.size(0), :]
|
||||
return self.dropout(x)
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
|
||||
"""Transformer Model
|
||||
|
||||
Args:
|
||||
input_size (int): input size (# features)
|
||||
hidden_size (int): hidden size
|
||||
num_layers (int): number of transformer layers
|
||||
num_heads (int): number of heads in transformer
|
||||
dropout (float): dropout rate
|
||||
input_drop (float): input dropout for data augmentation
|
||||
noise_level (float): add gaussian noise to input for data augmentation
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size=16,
|
||||
hidden_size=64,
|
||||
num_layers=2,
|
||||
num_heads=2,
|
||||
dropout=0.0,
|
||||
input_drop=0.0,
|
||||
noise_level=0.0,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.input_size = input_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_layers = num_layers
|
||||
self.num_heads = num_heads
|
||||
self.noise_level = noise_level
|
||||
|
||||
self.input_drop = nn.Dropout(input_drop)
|
||||
|
||||
self.input_proj = nn.Linear(input_size, hidden_size)
|
||||
|
||||
self.pe = PositionalEncoding(input_size, dropout)
|
||||
layer = nn.TransformerEncoderLayer(
|
||||
nhead=num_heads, dropout=dropout, d_model=hidden_size, dim_feedforward=hidden_size * 4
|
||||
)
|
||||
self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)
|
||||
|
||||
self.output_size = hidden_size
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
x = self.input_drop(x)
|
||||
|
||||
if self.training and self.noise_level > 0:
|
||||
noise = torch.randn_like(x).to(x)
|
||||
x = x + noise * self.noise_level
|
||||
|
||||
x = x.permute(1, 0, 2).contiguous() # the first dim need to be sequence
|
||||
x = self.pe(x)
|
||||
|
||||
x = self.input_proj(x)
|
||||
out = self.encoder(x)
|
||||
|
||||
return out[-1]
|
||||
|
||||
|
||||
class TRA(nn.Module):
|
||||
|
||||
"""Temporal Routing Adaptor (TRA)
|
||||
|
||||
TRA takes historical prediction erros & latent representation as inputs,
|
||||
then routes the input sample to a specific predictor for training & inference.
|
||||
|
||||
Args:
|
||||
input_size (int): input size (RNN/Transformer's hidden size)
|
||||
num_states (int): number of latent states (i.e., trading patterns)
|
||||
If `num_states=1`, then TRA falls back to traditional methods
|
||||
hidden_size (int): hidden size of the router
|
||||
tau (float): gumbel softmax temperature
|
||||
"""
|
||||
|
||||
def __init__(self, input_size, num_states=1, hidden_size=8, tau=1.0, src_info="LR_TPE"):
|
||||
super().__init__()
|
||||
|
||||
self.num_states = num_states
|
||||
self.tau = tau
|
||||
self.src_info = src_info
|
||||
|
||||
if num_states > 1:
|
||||
self.router = nn.LSTM(
|
||||
input_size=num_states,
|
||||
hidden_size=hidden_size,
|
||||
num_layers=1,
|
||||
batch_first=True,
|
||||
)
|
||||
self.fc = nn.Linear(hidden_size + input_size, num_states)
|
||||
|
||||
self.predictors = nn.Linear(input_size, num_states)
|
||||
|
||||
def forward(self, hidden, hist_loss):
|
||||
|
||||
preds = self.predictors(hidden)
|
||||
|
||||
if self.num_states == 1:
|
||||
return preds.squeeze(-1), preds, None
|
||||
|
||||
# information type
|
||||
router_out, _ = self.router(hist_loss)
|
||||
if "LR" in self.src_info:
|
||||
latent_representation = hidden
|
||||
else:
|
||||
latent_representation = torch.randn(hidden.shape).to(hidden)
|
||||
if "TPE" in self.src_info:
|
||||
temporal_pred_error = router_out[:, -1]
|
||||
else:
|
||||
temporal_pred_error = torch.randn(router_out[:, -1].shape).to(hidden)
|
||||
|
||||
out = self.fc(torch.cat([temporal_pred_error, latent_representation], dim=-1))
|
||||
prob = F.gumbel_softmax(out, dim=-1, tau=self.tau, hard=False)
|
||||
|
||||
if self.training:
|
||||
final_pred = (preds * prob).sum(dim=-1)
|
||||
else:
|
||||
final_pred = preds[range(len(preds)), prob.argmax(dim=-1)]
|
||||
|
||||
return final_pred, preds, prob
|
||||
|
||||
|
||||
def evaluate(pred):
|
||||
pred = pred.rank(pct=True) # transform into percentiles
|
||||
score = pred.score
|
||||
label = pred.label
|
||||
diff = score - label
|
||||
MSE = (diff ** 2).mean()
|
||||
MAE = (diff.abs()).mean()
|
||||
IC = score.corr(label)
|
||||
return {"MSE": MSE, "MAE": MAE, "IC": IC}
|
||||
|
||||
|
||||
def average_params(params_list):
|
||||
assert isinstance(params_list, (tuple, list, collections.deque))
|
||||
n = len(params_list)
|
||||
if n == 1:
|
||||
return params_list[0]
|
||||
new_params = collections.OrderedDict()
|
||||
keys = None
|
||||
for i, params in enumerate(params_list):
|
||||
if keys is None:
|
||||
keys = params.keys()
|
||||
for k, v in params.items():
|
||||
if k not in keys:
|
||||
raise ValueError("the %d-th model has different params" % i)
|
||||
if k not in new_params:
|
||||
new_params[k] = v / n
|
||||
else:
|
||||
new_params[k] += v / n
|
||||
return new_params
|
||||
|
||||
|
||||
def shoot_infs(inp_tensor):
|
||||
"""Replaces inf by maximum of tensor"""
|
||||
mask_inf = torch.isinf(inp_tensor)
|
||||
ind_inf = torch.nonzero(mask_inf, as_tuple=False)
|
||||
if len(ind_inf) > 0:
|
||||
for ind in ind_inf:
|
||||
if len(ind) == 2:
|
||||
inp_tensor[ind[0], ind[1]] = 0
|
||||
elif len(ind) == 1:
|
||||
inp_tensor[ind[0]] = 0
|
||||
m = torch.max(inp_tensor)
|
||||
for ind in ind_inf:
|
||||
if len(ind) == 2:
|
||||
inp_tensor[ind[0], ind[1]] = m
|
||||
elif len(ind) == 1:
|
||||
inp_tensor[ind[0]] = m
|
||||
return inp_tensor
|
||||
|
||||
|
||||
def sinkhorn(Q, n_iters=3, epsilon=0.01):
|
||||
# epsilon should be adjusted according to logits value's scale
|
||||
with torch.no_grad():
|
||||
Q = shoot_infs(Q)
|
||||
Q = torch.exp(Q / epsilon)
|
||||
for i in range(n_iters):
|
||||
Q /= Q.sum(dim=0, keepdim=True)
|
||||
Q /= Q.sum(dim=1, keepdim=True)
|
||||
return Q
|
||||
@@ -1,129 +0,0 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: FilterCol
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
col_list: ["RESI5", "WVMA5", "RSQR5", "KLEN", "RSQR10", "CORR5", "CORD5", "CORR10",
|
||||
"ROC60", "RESI10", "VSTD5", "RSQR60", "CORR60", "WVMA60", "STD5",
|
||||
"RSQR20", "CORD60", "CORD10", "CORR20", "KLOW"]
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
|
||||
num_states: &num_states 3
|
||||
|
||||
memory_mode: &memory_mode sample
|
||||
|
||||
tra_config: &tra_config
|
||||
num_states: *num_states
|
||||
rnn_arch: LSTM
|
||||
hidden_size: 32
|
||||
num_layers: 1
|
||||
dropout: 0.0
|
||||
tau: 1.0
|
||||
src_info: LR_TPE
|
||||
|
||||
model_config: &model_config
|
||||
input_size: 20
|
||||
hidden_size: 64
|
||||
num_layers: 2
|
||||
rnn_arch: LSTM
|
||||
use_attn: True
|
||||
dropout: 0.0
|
||||
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
|
||||
task:
|
||||
model:
|
||||
class: TRAModel
|
||||
module_path: qlib.contrib.model.pytorch_tra
|
||||
kwargs:
|
||||
tra_config: *tra_config
|
||||
model_config: *model_config
|
||||
model_type: RNN
|
||||
lr: 1e-3
|
||||
n_epochs: 100
|
||||
max_steps_per_epoch:
|
||||
early_stop: 20
|
||||
logdir: output/Alpha158
|
||||
seed: 0
|
||||
lamb: 1.0
|
||||
rho: 0.99
|
||||
alpha: 0.5
|
||||
transport_method: router
|
||||
memory_mode: *memory_mode
|
||||
eval_train: False
|
||||
eval_test: True
|
||||
pretrain: True
|
||||
init_state:
|
||||
freeze_model: False
|
||||
freeze_predictors: False
|
||||
dataset:
|
||||
class: MTSDatasetH
|
||||
module_path: qlib.contrib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha158
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
seq_len: 60
|
||||
horizon: 2
|
||||
input_size:
|
||||
num_states: *num_states
|
||||
batch_size: 1024
|
||||
n_samples:
|
||||
memory_mode: *memory_mode
|
||||
drop_last: True
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -1,123 +0,0 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
|
||||
num_states: &num_states 3
|
||||
|
||||
memory_mode: &memory_mode sample
|
||||
|
||||
tra_config: &tra_config
|
||||
num_states: *num_states
|
||||
rnn_arch: LSTM
|
||||
hidden_size: 32
|
||||
num_layers: 1
|
||||
dropout: 0.0
|
||||
tau: 1.0
|
||||
src_info: LR_TPE
|
||||
|
||||
model_config: &model_config
|
||||
input_size: 158
|
||||
hidden_size: 256
|
||||
num_layers: 2
|
||||
rnn_arch: LSTM
|
||||
use_attn: True
|
||||
dropout: 0.2
|
||||
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
|
||||
task:
|
||||
model:
|
||||
class: TRAModel
|
||||
module_path: qlib.contrib.model.pytorch_tra
|
||||
kwargs:
|
||||
tra_config: *tra_config
|
||||
model_config: *model_config
|
||||
model_type: RNN
|
||||
lr: 1e-3
|
||||
n_epochs: 100
|
||||
max_steps_per_epoch:
|
||||
early_stop: 20
|
||||
logdir: output/Alpha158_full
|
||||
seed: 0
|
||||
lamb: 1.0
|
||||
rho: 0.99
|
||||
alpha: 0.5
|
||||
transport_method: router
|
||||
memory_mode: *memory_mode
|
||||
eval_train: False
|
||||
eval_test: True
|
||||
pretrain: True
|
||||
init_state:
|
||||
freeze_model: False
|
||||
freeze_predictors: False
|
||||
dataset:
|
||||
class: MTSDatasetH
|
||||
module_path: qlib.contrib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha158
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
seq_len: 60
|
||||
horizon: 2
|
||||
input_size:
|
||||
num_states: *num_states
|
||||
batch_size: 1024
|
||||
n_samples:
|
||||
memory_mode: *memory_mode
|
||||
drop_last: True
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -1,123 +0,0 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
|
||||
num_states: &num_states 3
|
||||
|
||||
memory_mode: &memory_mode sample
|
||||
|
||||
tra_config: &tra_config
|
||||
num_states: *num_states
|
||||
rnn_arch: LSTM
|
||||
hidden_size: 32
|
||||
num_layers: 1
|
||||
dropout: 0.0
|
||||
tau: 1.0
|
||||
src_info: LR_TPE
|
||||
|
||||
model_config: &model_config
|
||||
input_size: 6
|
||||
hidden_size: 64
|
||||
num_layers: 2
|
||||
rnn_arch: LSTM
|
||||
use_attn: True
|
||||
dropout: 0.0
|
||||
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
|
||||
task:
|
||||
model:
|
||||
class: TRAModel
|
||||
module_path: qlib.contrib.model.pytorch_tra
|
||||
kwargs:
|
||||
tra_config: *tra_config
|
||||
model_config: *model_config
|
||||
model_type: RNN
|
||||
lr: 1e-3
|
||||
n_epochs: 100
|
||||
max_steps_per_epoch:
|
||||
early_stop: 20
|
||||
logdir: output/Alpha360
|
||||
seed: 0
|
||||
lamb: 1.0
|
||||
rho: 0.99
|
||||
alpha: 0.5
|
||||
transport_method: router
|
||||
memory_mode: *memory_mode
|
||||
eval_train: False
|
||||
eval_test: True
|
||||
pretrain: True
|
||||
init_state:
|
||||
freeze_model: False
|
||||
freeze_predictors: False
|
||||
dataset:
|
||||
class: MTSDatasetH
|
||||
module_path: qlib.contrib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha360
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
seq_len: 60
|
||||
horizon: 2
|
||||
input_size: 6
|
||||
num_states: *num_states
|
||||
batch_size: 1024
|
||||
n_samples:
|
||||
memory_mode: *memory_mode
|
||||
drop_last: True
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
BIN
examples/benchmarks/TabNet/pretrain/best.model
Normal file
BIN
examples/benchmarks/TabNet/pretrain/best.model
Normal file
Binary file not shown.
@@ -44,7 +44,6 @@ task:
|
||||
class: TabnetModel
|
||||
module_path: qlib.contrib.model.pytorch_tabnet
|
||||
kwargs:
|
||||
d_feat: 158
|
||||
pretrain: True
|
||||
dataset:
|
||||
class: DatasetH
|
||||
@@ -56,7 +55,7 @@ task:
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
pretrain: [2008-01-01, 2014-12-31]
|
||||
pretrain_validation: [2015-01-01, 2016-12-31]
|
||||
pretrain_validation: [2015-01-01, 2020-08-01]
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
|
||||
@@ -1,75 +0,0 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: TabnetModel
|
||||
module_path: qlib.contrib.model.pytorch_tabnet
|
||||
kwargs:
|
||||
d_feat: 360
|
||||
pretrain: True
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha360
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
pretrain: [2008-01-01, 2014-12-31]
|
||||
pretrain_validation: [2015-01-01, 2016-12-31]
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -1,3 +0,0 @@
|
||||
numpy==1.17.4
|
||||
pandas==1.1.2
|
||||
torch==1.2.0
|
||||
@@ -1,82 +0,0 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: FilterCol
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
col_list: ["RESI5", "WVMA5", "RSQR5", "KLEN", "RSQR10", "CORR5", "CORD5", "CORR10",
|
||||
"ROC60", "RESI10", "VSTD5", "RSQR60", "CORR60", "WVMA60", "STD5",
|
||||
"RSQR20", "CORD60", "CORD10", "CORR20", "KLOW"
|
||||
]
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: TransformerModel
|
||||
module_path: qlib.contrib.model.pytorch_transformer_ts
|
||||
kwargs:
|
||||
seed: 0
|
||||
n_jobs: 20
|
||||
dataset:
|
||||
class: TSDatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha158
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
step_len: 20
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -1,73 +0,0 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: TransformerModel
|
||||
module_path: qlib.contrib.model.pytorch_transformer
|
||||
kwargs:
|
||||
d_feat: 6
|
||||
seed: 0
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha360
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -25,11 +25,4 @@ The example is given in `workflow.py`, users can run the code as follows.
|
||||
Run the example by running the following command:
|
||||
```bash
|
||||
python workflow.py dump_and_load_dataset
|
||||
```
|
||||
|
||||
## Benchmarks Performance
|
||||
### Signal Test
|
||||
Here are the results of signal test for benchmark models. We will keep updating benchmark models in future.
|
||||
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Long precision| Short Precision | Long-Short Average Return | Long-Short Average Sharpe |
|
||||
|---|---|---|---|---|---|---|---|---|---|
|
||||
| LightGBM | Alpha158 | 0.3042±0.00 | 1.5372±0.00| 0.3117±0.00 | 1.6258±0.00 | 0.6720±0.00 | 0.6870±0.00 | 0.000769±0.00 | 1.0190±0.00 |
|
||||
```
|
||||
@@ -1,5 +1,7 @@
|
||||
from qlib.data.dataset.handler import DataHandler, DataHandlerLP
|
||||
from qlib.contrib.data.handler import check_transform_proc
|
||||
from qlib.data.dataset.processor import Processor
|
||||
from qlib.utils import get_cls_kwargs
|
||||
from qlib.log import TimeInspector
|
||||
|
||||
|
||||
class HighFreqHandler(DataHandlerLP):
|
||||
@@ -14,9 +16,20 @@ class HighFreqHandler(DataHandlerLP):
|
||||
fit_end_time=None,
|
||||
drop_raw=True,
|
||||
):
|
||||
def check_transform_proc(proc_l):
|
||||
new_l = []
|
||||
for p in proc_l:
|
||||
p["kwargs"].update(
|
||||
{
|
||||
"fit_start_time": fit_start_time,
|
||||
"fit_end_time": fit_end_time,
|
||||
}
|
||||
)
|
||||
new_l.append(p)
|
||||
return new_l
|
||||
|
||||
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
|
||||
learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)
|
||||
infer_processors = check_transform_proc(infer_processors)
|
||||
learn_processors = check_transform_proc(learn_processors)
|
||||
|
||||
data_loader = {
|
||||
"class": "QlibDataLoader",
|
||||
@@ -99,6 +112,8 @@ class HighFreqHandler(DataHandlerLP):
|
||||
]
|
||||
names += ["$volume_1"]
|
||||
|
||||
fields += ["Cut({0}, 240, None)".format(template_paused.format("Date($close)"))]
|
||||
names += ["date"]
|
||||
return fields, names
|
||||
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ def get_calendar_day(freq="day", future=False):
|
||||
if flag in H["c"]:
|
||||
_calendar = H["c"][flag]
|
||||
else:
|
||||
_calendar = np.array(list(map(lambda x: pd.Timestamp(x.date()), Cal.load_calendar(freq, future))))
|
||||
_calendar = np.array(list(map(lambda x: x.date(), Cal.load_calendar(freq, future))))
|
||||
H["c"][flag] = _calendar
|
||||
return _calendar
|
||||
|
||||
|
||||
@@ -33,9 +33,6 @@ class HighFreqNorm(Processor):
|
||||
self.feature_vmin[name] = np.nanmin(part_values)
|
||||
|
||||
def __call__(self, df_features):
|
||||
df_features["date"] = pd.to_datetime(
|
||||
df_features.index.get_level_values(level="datetime").to_series().dt.date.values
|
||||
)
|
||||
df_features.set_index("date", append=True, drop=True, inplace=True)
|
||||
df_values = df_features.values
|
||||
names = {
|
||||
|
||||
@@ -1,13 +1,24 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import sys
|
||||
import fire
|
||||
from pathlib import Path
|
||||
|
||||
import qlib
|
||||
import pickle
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from qlib.config import REG_CN, HIGH_FREQ_CONFIG
|
||||
from qlib.contrib.model.gbdt import LGBModel
|
||||
from qlib.contrib.data.handler import Alpha158
|
||||
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
|
||||
from qlib.contrib.evaluate import (
|
||||
backtest as normal_backtest,
|
||||
risk_analysis,
|
||||
)
|
||||
|
||||
from qlib.utils import init_instance_by_config
|
||||
from qlib.utils import init_instance_by_config, exists_qlib_data
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
from qlib.data.ops import Operators
|
||||
from qlib.data.data import Cal
|
||||
@@ -16,16 +27,17 @@ from qlib.tests.data import GetData
|
||||
from highfreq_ops import get_calendar_day, DayLast, FFillNan, BFillNan, Date, Select, IsNull, Cut
|
||||
|
||||
|
||||
class HighfreqWorkflow:
|
||||
class HighfreqWorkflow(object):
|
||||
|
||||
SPEC_CONF = {"custom_ops": [DayLast, FFillNan, BFillNan, Date, Select, IsNull, Cut], "expression_cache": None}
|
||||
|
||||
MARKET = "all"
|
||||
BENCHMARK = "SH000300"
|
||||
|
||||
start_time = "2020-09-15 00:00:00"
|
||||
end_time = "2021-01-18 16:00:00"
|
||||
train_end_time = "2020-11-30 16:00:00"
|
||||
test_start_time = "2020-12-01 00:00:00"
|
||||
start_time = pd.Timestamp("2020-09-15 00:00:00")
|
||||
end_time = pd.Timestamp("2021-01-18 16:00:00")
|
||||
train_end_time = pd.Timestamp("2020-11-30 16:00:00")
|
||||
test_start_time = pd.Timestamp("2020-12-01 00:00:00")
|
||||
|
||||
DATA_HANDLER_CONFIG0 = {
|
||||
"start_time": start_time,
|
||||
@@ -33,7 +45,7 @@ class HighfreqWorkflow:
|
||||
"fit_start_time": start_time,
|
||||
"fit_end_time": train_end_time,
|
||||
"instruments": MARKET,
|
||||
"infer_processors": [{"class": "HighFreqNorm", "module_path": "highfreq_processor"}],
|
||||
"infer_processors": [{"class": "HighFreqNorm", "module_path": "highfreq_processor", "kwargs": {}}],
|
||||
}
|
||||
DATA_HANDLER_CONFIG1 = {
|
||||
"start_time": start_time,
|
||||
@@ -85,7 +97,9 @@ class HighfreqWorkflow:
|
||||
# use yahoo_cn_1min data
|
||||
QLIB_INIT_CONFIG = {**HIGH_FREQ_CONFIG, **self.SPEC_CONF}
|
||||
provider_uri = QLIB_INIT_CONFIG.get("provider_uri")
|
||||
GetData().qlib_data(target_dir=provider_uri, interval="1min", region=REG_CN, exists_skip=True)
|
||||
if not exists_qlib_data(provider_uri):
|
||||
print(f"Qlib data is not found in {provider_uri}")
|
||||
GetData().qlib_data(target_dir=provider_uri, interval="1min", region=REG_CN)
|
||||
qlib.init(**QLIB_INIT_CONFIG)
|
||||
|
||||
def _prepare_calender_cache(self):
|
||||
@@ -132,44 +146,72 @@ class HighfreqWorkflow:
|
||||
|
||||
self._prepare_calender_cache()
|
||||
##=============reinit dataset=============
|
||||
dataset.config(
|
||||
handler_kwargs={
|
||||
"start_time": "2021-01-19 00:00:00",
|
||||
"end_time": "2021-01-25 16:00:00",
|
||||
},
|
||||
segments={
|
||||
"test": (
|
||||
"2021-01-19 00:00:00",
|
||||
"2021-01-25 16:00:00",
|
||||
),
|
||||
},
|
||||
)
|
||||
dataset.setup_data(
|
||||
dataset.init(
|
||||
handler_kwargs={
|
||||
"init_type": DataHandlerLP.IT_LS,
|
||||
},
|
||||
)
|
||||
dataset_backtest.config(
|
||||
handler_kwargs={
|
||||
"start_time": "2021-01-19 00:00:00",
|
||||
"end_time": "2021-01-25 16:00:00",
|
||||
},
|
||||
segments={
|
||||
segment_kwargs={
|
||||
"test": (
|
||||
"2021-01-19 00:00:00",
|
||||
"2021-01-25 16:00:00",
|
||||
),
|
||||
},
|
||||
)
|
||||
dataset_backtest.init(
|
||||
handler_kwargs={
|
||||
"start_time": "2021-01-19 00:00:00",
|
||||
"end_time": "2021-01-25 16:00:00",
|
||||
},
|
||||
segment_kwargs={
|
||||
"test": (
|
||||
"2021-01-19 00:00:00",
|
||||
"2021-01-25 16:00:00",
|
||||
),
|
||||
},
|
||||
)
|
||||
dataset_backtest.setup_data(handler_kwargs={})
|
||||
|
||||
##=============get data=============
|
||||
xtest = dataset.prepare("test")
|
||||
backtest_test = dataset_backtest.prepare("test")
|
||||
xtest = dataset.prepare(["test"])
|
||||
backtest_test = dataset_backtest.prepare(["test"])
|
||||
|
||||
print(xtest, backtest_test)
|
||||
return
|
||||
|
||||
|
||||
def get_high_freq_data(self, data_path):
|
||||
self._init_qlib()
|
||||
self._prepare_calender_cache()
|
||||
|
||||
import os
|
||||
dataset = init_instance_by_config(self.task["dataset"])
|
||||
xtrain, xtest = dataset.prepare(["train", "test"])
|
||||
normed_feature = pd.concat([xtrain, xtest]).sort_index()
|
||||
dic = dict(tuple(normed_feature.groupby("instrument")))
|
||||
feature_path = os.path.join(data_path, "normed_feature/")
|
||||
if not os.path.exists(feature_path):
|
||||
os.makedirs(feature_path)
|
||||
for k, v in dic.items():
|
||||
v.to_pickle(feature_path + f"{k}.pkl")
|
||||
|
||||
|
||||
dataset_backtest = init_instance_by_config(self.task["dataset_backtest"])
|
||||
backtest_train, backtest_test = dataset_backtest.prepare(["train", "test"])
|
||||
backtest = pd.concat([backtest_train, backtest_test]).sort_index()
|
||||
backtest['date'] = backtest.index.map(lambda x: x[1].date())
|
||||
backtest.set_index('date', append=True, drop=True, inplace=True)
|
||||
dic = dict(tuple(backtest.groupby("instrument")))
|
||||
backtest_path = os.path.join(data_path, "backtest/")
|
||||
if not os.path.exists(backtest_path):
|
||||
os.makedirs(backtest_path)
|
||||
for k, v in dic.items():
|
||||
v.to_pickle(backtest_path + f"{k}.pkl.backtest")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(HighfreqWorkflow)
|
||||
#fire.Fire(HighfreqWorkflow)
|
||||
data_path = '../data/'
|
||||
workflow = HighfreqWorkflow()
|
||||
workflow.get_high_freq_data(data_path)
|
||||
|
||||
|
||||
@@ -1,65 +0,0 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data_1min"
|
||||
region: cn
|
||||
market: &market 'csi300'
|
||||
start_time: &start_time "2020-09-15 00:00:00"
|
||||
end_time: &end_time "2021-01-18 16:00:00"
|
||||
train_end_time: &train_end_time "2020-11-15 16:00:00"
|
||||
valid_start_time: &valid_start_time "2020-11-16 00:00:00"
|
||||
valid_end_time: &valid_end_time "2020-11-30 16:00:00"
|
||||
test_start_time: &test_start_time "2020-12-01 00:00:00"
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: *start_time
|
||||
end_time: *end_time
|
||||
fit_start_time: *start_time
|
||||
fit_end_time: *train_end_time
|
||||
instruments: *market
|
||||
freq: '1min'
|
||||
infer_processors:
|
||||
- class: 'RobustZScoreNorm'
|
||||
kwargs:
|
||||
fields_group: 'feature'
|
||||
clip_outlier: false
|
||||
- class: "Fillna"
|
||||
kwargs:
|
||||
fields_group: 'feature'
|
||||
learn_processors:
|
||||
- class: 'DropnaLabel'
|
||||
- class: 'CSRankNorm'
|
||||
kwargs:
|
||||
fields_group: 'label'
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
|
||||
task:
|
||||
model:
|
||||
class: "HFLGBModel"
|
||||
module_path: "qlib.contrib.model.highfreq_gdbt_model"
|
||||
kwargs:
|
||||
objective: 'binary'
|
||||
metric: ['binary_logloss','auc']
|
||||
verbosity: -1
|
||||
learning_rate: 0.01
|
||||
max_depth: 8
|
||||
num_leaves: 150
|
||||
lambda_l1: 1.5
|
||||
lambda_l2: 1
|
||||
num_threads: 20
|
||||
dataset:
|
||||
class: "DatasetH"
|
||||
module_path: "qlib.data.dataset"
|
||||
kwargs:
|
||||
handler:
|
||||
class: "Alpha158"
|
||||
module_path: "qlib.contrib.data.handler"
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [*start_time, *train_end_time]
|
||||
valid: [*train_end_time, *valid_end_time]
|
||||
test: [*test_start_time, *end_time]
|
||||
record:
|
||||
- class: "SignalRecord"
|
||||
module_path: "qlib.workflow.record_temp"
|
||||
kwargs: {}
|
||||
- class: "HFSignalRecord"
|
||||
module_path: "qlib.workflow.record_temp"
|
||||
kwargs: {}
|
||||
@@ -1,23 +0,0 @@
|
||||
# LightGBM hyperparameter
|
||||
|
||||
## Alpha158
|
||||
First terminal
|
||||
```
|
||||
optuna create-study --study LGBM_158 --storage sqlite:///db.sqlite3
|
||||
optuna-dashboard --port 5000 --host 0.0.0.0 sqlite:///db.sqlite3
|
||||
```
|
||||
Second terminal
|
||||
```
|
||||
python hyperparameter_158.py
|
||||
```
|
||||
|
||||
## Alpha360
|
||||
First terminal
|
||||
```
|
||||
optuna create-study --study LGBM_360 --storage sqlite:///db.sqlite3
|
||||
optuna-dashboard --port 5000 --host 0.0.0.0 sqlite:///db.sqlite3
|
||||
```
|
||||
Second terminal
|
||||
```
|
||||
python hyperparameter_360.py
|
||||
```
|
||||
@@ -1,46 +0,0 @@
|
||||
import qlib
|
||||
import optuna
|
||||
from qlib.config import REG_CN
|
||||
from qlib.utils import init_instance_by_config
|
||||
from qlib.tests.config import CSI300_DATASET_CONFIG
|
||||
from qlib.tests.data import GetData
|
||||
|
||||
|
||||
def objective(trial):
|
||||
task = {
|
||||
"model": {
|
||||
"class": "LGBModel",
|
||||
"module_path": "qlib.contrib.model.gbdt",
|
||||
"kwargs": {
|
||||
"loss": "mse",
|
||||
"colsample_bytree": trial.suggest_uniform("colsample_bytree", 0.5, 1),
|
||||
"learning_rate": trial.suggest_uniform("learning_rate", 0, 1),
|
||||
"subsample": trial.suggest_uniform("subsample", 0, 1),
|
||||
"lambda_l1": trial.suggest_loguniform("lambda_l1", 1e-8, 1e4),
|
||||
"lambda_l2": trial.suggest_loguniform("lambda_l2", 1e-8, 1e4),
|
||||
"max_depth": 10,
|
||||
"num_leaves": trial.suggest_int("num_leaves", 1, 1024),
|
||||
"feature_fraction": trial.suggest_uniform("feature_fraction", 0.4, 1.0),
|
||||
"bagging_fraction": trial.suggest_uniform("bagging_fraction", 0.4, 1.0),
|
||||
"bagging_freq": trial.suggest_int("bagging_freq", 1, 7),
|
||||
"min_data_in_leaf": trial.suggest_int("min_data_in_leaf", 1, 50),
|
||||
"min_child_samples": trial.suggest_int("min_child_samples", 5, 100),
|
||||
},
|
||||
},
|
||||
}
|
||||
evals_result = dict()
|
||||
model = init_instance_by_config(task["model"])
|
||||
model.fit(dataset, evals_result=evals_result)
|
||||
return min(evals_result["valid"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data"
|
||||
GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True)
|
||||
qlib.init(provider_uri=provider_uri, region="cn")
|
||||
|
||||
dataset = init_instance_by_config(CSI300_DATASET_CONFIG)
|
||||
|
||||
study = optuna.Study(study_name="LGBM_158", storage="sqlite:///db.sqlite3")
|
||||
study.optimize(objective, n_jobs=6)
|
||||
@@ -1,49 +0,0 @@
|
||||
import qlib
|
||||
import optuna
|
||||
from qlib.config import REG_CN
|
||||
from qlib.utils import init_instance_by_config
|
||||
from qlib.tests.data import GetData
|
||||
from qlib.tests.config import get_dataset_config, CSI300_MARKET, DATASET_ALPHA360_CLASS
|
||||
|
||||
DATASET_CONFIG = get_dataset_config(market=CSI300_MARKET, dataset_class=DATASET_ALPHA360_CLASS)
|
||||
|
||||
|
||||
def objective(trial):
|
||||
task = {
|
||||
"model": {
|
||||
"class": "LGBModel",
|
||||
"module_path": "qlib.contrib.model.gbdt",
|
||||
"kwargs": {
|
||||
"loss": "mse",
|
||||
"colsample_bytree": trial.suggest_uniform("colsample_bytree", 0.5, 1),
|
||||
"learning_rate": trial.suggest_uniform("learning_rate", 0, 1),
|
||||
"subsample": trial.suggest_uniform("subsample", 0, 1),
|
||||
"lambda_l1": trial.suggest_loguniform("lambda_l1", 1e-8, 1e4),
|
||||
"lambda_l2": trial.suggest_loguniform("lambda_l2", 1e-8, 1e4),
|
||||
"max_depth": 10,
|
||||
"num_leaves": trial.suggest_int("num_leaves", 1, 1024),
|
||||
"feature_fraction": trial.suggest_uniform("feature_fraction", 0.4, 1.0),
|
||||
"bagging_fraction": trial.suggest_uniform("bagging_fraction", 0.4, 1.0),
|
||||
"bagging_freq": trial.suggest_int("bagging_freq", 1, 7),
|
||||
"min_data_in_leaf": trial.suggest_int("min_data_in_leaf", 1, 50),
|
||||
"min_child_samples": trial.suggest_int("min_child_samples", 5, 100),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
evals_result = dict()
|
||||
model = init_instance_by_config(task["model"])
|
||||
model.fit(dataset, evals_result=evals_result)
|
||||
return min(evals_result["valid"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data"
|
||||
GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True)
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
|
||||
dataset = init_instance_by_config(DATASET_CONFIG)
|
||||
|
||||
study = optuna.Study(study_name="LGBM_360", storage="sqlite:///db.sqlite3")
|
||||
study.optimize(objective, n_jobs=6)
|
||||
@@ -1,5 +0,0 @@
|
||||
pandas==1.1.2
|
||||
numpy==1.17.4
|
||||
lightgbm==3.1.0
|
||||
optuna==2.7.0
|
||||
optuna-dashboard==0.4.1
|
||||
@@ -1,32 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
import qlib
|
||||
from qlib.config import REG_CN
|
||||
|
||||
from qlib.utils import init_instance_by_config
|
||||
from qlib.tests.data import GetData
|
||||
from qlib.tests.config import CSI300_GBDT_TASK
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
# use default data
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
|
||||
GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True)
|
||||
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
|
||||
###################################
|
||||
# train model
|
||||
###################################
|
||||
# model initialization
|
||||
model = init_instance_by_config(CSI300_GBDT_TASK["model"])
|
||||
dataset = init_instance_by_config(CSI300_GBDT_TASK["dataset"])
|
||||
model.fit(dataset)
|
||||
|
||||
# get model feature importance
|
||||
feature_importance = model.get_feature_importance()
|
||||
print("feature importance:")
|
||||
print(feature_importance)
|
||||
@@ -1 +0,0 @@
|
||||
xgboost
|
||||
@@ -1,111 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
This example shows how a TrainerRM works based on TaskManager with rolling tasks.
|
||||
After training, how to collect the rolling results will be shown in task_collecting.
|
||||
Based on the ability of TaskManager, `worker` method offer a simple way for multiprocessing.
|
||||
"""
|
||||
|
||||
from pprint import pprint
|
||||
|
||||
import fire
|
||||
import qlib
|
||||
from qlib.config import REG_CN
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.task.gen import RollingGen, task_generator
|
||||
from qlib.workflow.task.manage import TaskManager, run_task
|
||||
from qlib.workflow.task.collect import RecorderCollector
|
||||
from qlib.model.ens.group import RollingGroup
|
||||
from qlib.model.trainer import TrainerRM, task_train
|
||||
from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG, CSI100_RECORD_XGBOOST_TASK_CONFIG
|
||||
|
||||
|
||||
class RollingTaskExample:
|
||||
def __init__(
|
||||
self,
|
||||
provider_uri="~/.qlib/qlib_data/cn_data",
|
||||
region=REG_CN,
|
||||
task_url="mongodb://10.0.0.4:27017/",
|
||||
task_db_name="rolling_db",
|
||||
experiment_name="rolling_exp",
|
||||
task_pool="rolling_task",
|
||||
task_config=None,
|
||||
rolling_step=550,
|
||||
rolling_type=RollingGen.ROLL_SD,
|
||||
):
|
||||
# TaskManager config
|
||||
if task_config is None:
|
||||
task_config = [CSI100_RECORD_XGBOOST_TASK_CONFIG, CSI100_RECORD_LGB_TASK_CONFIG]
|
||||
mongo_conf = {
|
||||
"task_url": task_url,
|
||||
"task_db_name": task_db_name,
|
||||
}
|
||||
qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf)
|
||||
self.experiment_name = experiment_name
|
||||
self.task_pool = task_pool
|
||||
self.task_config = task_config
|
||||
self.rolling_gen = RollingGen(step=rolling_step, rtype=rolling_type)
|
||||
|
||||
# Reset all things to the first status, be careful to save important data
|
||||
def reset(self):
|
||||
print("========== reset ==========")
|
||||
TaskManager(task_pool=self.task_pool).remove()
|
||||
exp = R.get_exp(experiment_name=self.experiment_name)
|
||||
for rid in exp.list_recorders():
|
||||
exp.delete_recorder(rid)
|
||||
|
||||
def task_generating(self):
|
||||
print("========== task_generating ==========")
|
||||
tasks = task_generator(
|
||||
tasks=self.task_config,
|
||||
generators=self.rolling_gen, # generate different date segments
|
||||
)
|
||||
pprint(tasks)
|
||||
return tasks
|
||||
|
||||
def task_training(self, tasks):
|
||||
print("========== task_training ==========")
|
||||
trainer = TrainerRM(self.experiment_name, self.task_pool)
|
||||
trainer.train(tasks)
|
||||
|
||||
def worker(self):
|
||||
# train tasks by other progress or machines for multiprocessing. It is same as TrainerRM.worker.
|
||||
print("========== worker ==========")
|
||||
run_task(task_train, self.task_pool, experiment_name=self.experiment_name)
|
||||
|
||||
def task_collecting(self):
|
||||
print("========== task_collecting ==========")
|
||||
|
||||
def rec_key(recorder):
|
||||
task_config = recorder.load_object("task")
|
||||
model_key = task_config["model"]["class"]
|
||||
rolling_key = task_config["dataset"]["kwargs"]["segments"]["test"]
|
||||
return model_key, rolling_key
|
||||
|
||||
def my_filter(recorder):
|
||||
# only choose the results of "LGBModel"
|
||||
model_key, rolling_key = rec_key(recorder)
|
||||
if model_key == "LGBModel":
|
||||
return True
|
||||
return False
|
||||
|
||||
collector = RecorderCollector(
|
||||
experiment=self.experiment_name,
|
||||
process_list=RollingGroup(),
|
||||
rec_key_func=rec_key,
|
||||
rec_filter_func=my_filter,
|
||||
)
|
||||
print(collector())
|
||||
|
||||
def main(self):
|
||||
self.reset()
|
||||
tasks = self.task_generating()
|
||||
self.task_training(tasks)
|
||||
self.task_collecting()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
## to see the whole process with your own parameters, use the command below
|
||||
# python task_manager_rolling.py main --experiment_name="your_exp_name"
|
||||
fire.Fire(RollingTaskExample)
|
||||
@@ -1,102 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
This example is about how can simulate the OnlineManager based on rolling tasks.
|
||||
"""
|
||||
|
||||
from pprint import pprint
|
||||
import fire
|
||||
import qlib
|
||||
from qlib.model.trainer import DelayTrainerR, DelayTrainerRM, TrainerR, TrainerRM
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.online.manager import OnlineManager
|
||||
from qlib.workflow.online.strategy import RollingStrategy
|
||||
from qlib.workflow.task.gen import RollingGen
|
||||
from qlib.workflow.task.manage import TaskManager
|
||||
from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG_ONLINE, CSI100_RECORD_XGBOOST_TASK_CONFIG_ONLINE
|
||||
|
||||
|
||||
class OnlineSimulationExample:
|
||||
def __init__(
|
||||
self,
|
||||
provider_uri="~/.qlib/qlib_data/cn_data",
|
||||
region="cn",
|
||||
exp_name="rolling_exp",
|
||||
task_url="mongodb://10.0.0.4:27017/", # not necessary when using TrainerR or DelayTrainerR
|
||||
task_db_name="rolling_db", # not necessary when using TrainerR or DelayTrainerR
|
||||
task_pool="rolling_task",
|
||||
rolling_step=80,
|
||||
start_time="2018-09-10",
|
||||
end_time="2018-10-31",
|
||||
tasks=None,
|
||||
):
|
||||
"""
|
||||
Init OnlineManagerExample.
|
||||
|
||||
Args:
|
||||
provider_uri (str, optional): the provider uri. Defaults to "~/.qlib/qlib_data/cn_data".
|
||||
region (str, optional): the stock region. Defaults to "cn".
|
||||
exp_name (str, optional): the experiment name. Defaults to "rolling_exp".
|
||||
task_url (str, optional): your MongoDB url. Defaults to "mongodb://10.0.0.4:27017/".
|
||||
task_db_name (str, optional): database name. Defaults to "rolling_db".
|
||||
task_pool (str, optional): the task pool name (a task pool is a collection in MongoDB). Defaults to "rolling_task".
|
||||
rolling_step (int, optional): the step for rolling. Defaults to 80.
|
||||
start_time (str, optional): the start time of simulating. Defaults to "2018-09-10".
|
||||
end_time (str, optional): the end time of simulating. Defaults to "2018-10-31".
|
||||
tasks (dict or list[dict]): a set of the task config waiting for rolling and training
|
||||
"""
|
||||
if tasks is None:
|
||||
tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG_ONLINE, CSI100_RECORD_LGB_TASK_CONFIG_ONLINE]
|
||||
self.exp_name = exp_name
|
||||
self.task_pool = task_pool
|
||||
self.start_time = start_time
|
||||
self.end_time = end_time
|
||||
mongo_conf = {
|
||||
"task_url": task_url,
|
||||
"task_db_name": task_db_name,
|
||||
}
|
||||
qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf)
|
||||
self.rolling_gen = RollingGen(
|
||||
step=rolling_step, rtype=RollingGen.ROLL_SD, ds_extra_mod_func=None
|
||||
) # The rolling tasks generator, ds_extra_mod_func is None because we just need to simulate to 2018-10-31 and needn't change the handler end time.
|
||||
self.trainer = TrainerRM(self.exp_name, self.task_pool) # Also can be TrainerR, TrainerRM, DelayTrainerR
|
||||
self.rolling_online_manager = OnlineManager(
|
||||
RollingStrategy(exp_name, task_template=tasks, rolling_gen=self.rolling_gen),
|
||||
trainer=self.trainer,
|
||||
begin_time=self.start_time,
|
||||
)
|
||||
self.tasks = tasks
|
||||
|
||||
# Reset all things to the first status, be careful to save important data
|
||||
def reset(self):
|
||||
TaskManager(self.task_pool).remove()
|
||||
exp = R.get_exp(experiment_name=self.exp_name)
|
||||
for rid in exp.list_recorders():
|
||||
exp.delete_recorder(rid)
|
||||
|
||||
# Run this to run all workflow automatically
|
||||
def main(self):
|
||||
print("========== reset ==========")
|
||||
self.reset()
|
||||
print("========== simulate ==========")
|
||||
self.rolling_online_manager.simulate(end_time=self.end_time)
|
||||
print("========== collect results ==========")
|
||||
print(self.rolling_online_manager.get_collector()())
|
||||
print("========== signals ==========")
|
||||
print(self.rolling_online_manager.get_signals())
|
||||
|
||||
def worker(self):
|
||||
# train tasks by other progress or machines for multiprocessing
|
||||
# FIXME: only can call after finishing simulation when using DelayTrainerRM, or there will be some exception.
|
||||
print("========== worker ==========")
|
||||
if isinstance(self.trainer, TrainerRM):
|
||||
self.trainer.worker()
|
||||
else:
|
||||
print(f"{type(self.trainer)} is not supported for worker.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
## to run all workflow automatically with your own parameters, use the command below
|
||||
# python online_management_simulate.py main --experiment_name="your_exp_name" --rolling_step=60
|
||||
fire.Fire(OnlineSimulationExample)
|
||||
@@ -1,144 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
This example shows how OnlineManager works with rolling tasks.
|
||||
There are four parts including first train, routine 1, add strategy and routine 2.
|
||||
Firstly, the OnlineManager will finish the first training and set trained models to `online` models.
|
||||
Next, the OnlineManager will finish a routine process, including update online prediction -> prepare tasks -> prepare new models -> prepare signals
|
||||
Then, we will add some new strategies to the OnlineManager. This will finish first training of new strategies.
|
||||
Finally, the OnlineManager will finish second routine and update all strategies.
|
||||
"""
|
||||
|
||||
import os
|
||||
import fire
|
||||
import qlib
|
||||
from qlib.model.trainer import DelayTrainerR, DelayTrainerRM, TrainerR, TrainerRM, end_task_train, task_train
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.online.strategy import RollingStrategy
|
||||
from qlib.workflow.task.gen import RollingGen
|
||||
from qlib.workflow.online.manager import OnlineManager
|
||||
from qlib.tests.config import CSI100_RECORD_XGBOOST_TASK_CONFIG_ROLLING, CSI100_RECORD_LGB_TASK_CONFIG_ROLLING
|
||||
from qlib.workflow.task.manage import TaskManager
|
||||
|
||||
|
||||
class RollingOnlineExample:
|
||||
def __init__(
|
||||
self,
|
||||
provider_uri="~/.qlib/qlib_data/cn_data",
|
||||
region="cn",
|
||||
trainer=DelayTrainerRM(), # you can choose from TrainerR, TrainerRM, DelayTrainerR, DelayTrainerRM
|
||||
task_url="mongodb://10.0.0.4:27017/", # not necessary when using TrainerR or DelayTrainerR
|
||||
task_db_name="rolling_db", # not necessary when using TrainerR or DelayTrainerR
|
||||
rolling_step=550,
|
||||
tasks=None,
|
||||
add_tasks=None,
|
||||
):
|
||||
if add_tasks is None:
|
||||
add_tasks = [CSI100_RECORD_LGB_TASK_CONFIG_ROLLING]
|
||||
if tasks is None:
|
||||
tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG_ROLLING]
|
||||
mongo_conf = {
|
||||
"task_url": task_url, # your MongoDB url
|
||||
"task_db_name": task_db_name, # database name
|
||||
}
|
||||
qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf)
|
||||
self.tasks = tasks
|
||||
self.add_tasks = add_tasks
|
||||
self.rolling_step = rolling_step
|
||||
strategies = []
|
||||
for task in tasks:
|
||||
name_id = task["model"]["class"] # NOTE: Assumption: The model class can specify only one strategy
|
||||
strategies.append(
|
||||
RollingStrategy(
|
||||
name_id,
|
||||
task,
|
||||
RollingGen(step=rolling_step, rtype=RollingGen.ROLL_SD),
|
||||
)
|
||||
)
|
||||
self.trainer = trainer
|
||||
self.rolling_online_manager = OnlineManager(strategies, trainer=self.trainer)
|
||||
|
||||
_ROLLING_MANAGER_PATH = (
|
||||
".RollingOnlineExample" # the OnlineManager will dump to this file, for it can be loaded when calling routine.
|
||||
)
|
||||
|
||||
def worker(self):
|
||||
# train tasks by other progress or machines for multiprocessing
|
||||
print("========== worker ==========")
|
||||
if isinstance(self.trainer, TrainerRM):
|
||||
for task in self.tasks + self.add_tasks:
|
||||
name_id = task["model"]["class"]
|
||||
self.trainer.worker(experiment_name=name_id)
|
||||
else:
|
||||
print(f"{type(self.trainer)} is not supported for worker.")
|
||||
|
||||
# Reset all things to the first status, be careful to save important data
|
||||
def reset(self):
|
||||
for task in self.tasks + self.add_tasks:
|
||||
name_id = task["model"]["class"]
|
||||
TaskManager(task_pool=name_id).remove()
|
||||
exp = R.get_exp(experiment_name=name_id)
|
||||
for rid in exp.list_recorders():
|
||||
exp.delete_recorder(rid)
|
||||
|
||||
if os.path.exists(self._ROLLING_MANAGER_PATH):
|
||||
os.remove(self._ROLLING_MANAGER_PATH)
|
||||
|
||||
def first_run(self):
|
||||
print("========== reset ==========")
|
||||
self.reset()
|
||||
print("========== first_run ==========")
|
||||
self.rolling_online_manager.first_train()
|
||||
print("========== collect results ==========")
|
||||
print(self.rolling_online_manager.get_collector()())
|
||||
print("========== dump ==========")
|
||||
self.rolling_online_manager.to_pickle(self._ROLLING_MANAGER_PATH)
|
||||
|
||||
def routine(self):
|
||||
print("========== load ==========")
|
||||
self.rolling_online_manager = OnlineManager.load(self._ROLLING_MANAGER_PATH)
|
||||
print("========== routine ==========")
|
||||
self.rolling_online_manager.routine()
|
||||
print("========== collect results ==========")
|
||||
print(self.rolling_online_manager.get_collector()())
|
||||
print("========== signals ==========")
|
||||
print(self.rolling_online_manager.get_signals())
|
||||
print("========== dump ==========")
|
||||
self.rolling_online_manager.to_pickle(self._ROLLING_MANAGER_PATH)
|
||||
|
||||
def add_strategy(self):
|
||||
print("========== load ==========")
|
||||
self.rolling_online_manager = OnlineManager.load(self._ROLLING_MANAGER_PATH)
|
||||
print("========== add strategy ==========")
|
||||
strategies = []
|
||||
for task in self.add_tasks:
|
||||
name_id = task["model"]["class"] # NOTE: Assumption: The model class can specify only one strategy
|
||||
strategies.append(
|
||||
RollingStrategy(
|
||||
name_id,
|
||||
task,
|
||||
RollingGen(step=self.rolling_step, rtype=RollingGen.ROLL_SD),
|
||||
)
|
||||
)
|
||||
self.rolling_online_manager.add_strategy(strategies=strategies)
|
||||
print("========== dump ==========")
|
||||
self.rolling_online_manager.to_pickle(self._ROLLING_MANAGER_PATH)
|
||||
|
||||
def main(self):
|
||||
self.first_run()
|
||||
self.routine()
|
||||
self.add_strategy()
|
||||
self.routine()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
####### to train the first version's models, use the command below
|
||||
# python rolling_online_management.py first_run
|
||||
|
||||
####### to update the models and predictions after the trading time, use the command below
|
||||
# python rolling_online_management.py routine
|
||||
|
||||
####### to define your own parameters, use `--`
|
||||
# python rolling_online_management.py first_run --exp_name='your_exp_name' --rolling_step=40
|
||||
fire.Fire(RollingOnlineExample)
|
||||
@@ -1,54 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
This example shows how OnlineTool works when we need update prediction.
|
||||
There are two parts including first_train and update_online_pred.
|
||||
Firstly, we will finish the training and set the trained models to the `online` models.
|
||||
Next, we will finish updating online predictions.
|
||||
"""
|
||||
import copy
|
||||
import fire
|
||||
import qlib
|
||||
from qlib.config import REG_CN
|
||||
from qlib.model.trainer import task_train
|
||||
from qlib.workflow.online.utils import OnlineToolR
|
||||
from qlib.tests.config import CSI300_GBDT_TASK
|
||||
|
||||
task = copy.deepcopy(CSI300_GBDT_TASK)
|
||||
|
||||
task["record"] = {
|
||||
"class": "SignalRecord",
|
||||
"module_path": "qlib.workflow.record_temp",
|
||||
}
|
||||
|
||||
|
||||
class UpdatePredExample:
|
||||
def __init__(
|
||||
self, provider_uri="~/.qlib/qlib_data/cn_data", region=REG_CN, experiment_name="online_srv", task_config=task
|
||||
):
|
||||
qlib.init(provider_uri=provider_uri, region=region)
|
||||
self.experiment_name = experiment_name
|
||||
self.online_tool = OnlineToolR(self.experiment_name)
|
||||
self.task_config = task_config
|
||||
|
||||
def first_train(self):
|
||||
rec = task_train(self.task_config, experiment_name=self.experiment_name)
|
||||
self.online_tool.reset_online_tag(rec) # set to online model
|
||||
|
||||
def update_online_pred(self):
|
||||
self.online_tool.update_online_pred()
|
||||
|
||||
def main(self):
|
||||
self.first_train()
|
||||
self.update_online_pred()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
## to train a model and set it to online model, use the command below
|
||||
# python update_online_pred.py first_train
|
||||
## to update online predictions once a day, use the command below
|
||||
# python update_online_pred.py update_online_pred
|
||||
## to see the whole process with your own parameters, use the command below
|
||||
# python update_online_pred.py main --experiment_name="your_exp_name"
|
||||
fire.Fire(UpdatePredExample)
|
||||
@@ -1,17 +0,0 @@
|
||||
# Rolling Process Data
|
||||
|
||||
This workflow is an example for `Rolling Process Data`.
|
||||
|
||||
## Background
|
||||
|
||||
When rolling train the models, data also needs to be generated in the different rolling windows. When the rolling window moves, the training data will change, and the processor's learnable state (such as standard deviation, mean, etc.) will also change.
|
||||
|
||||
In order to avoid regenerating data, this example uses the `DataHandler-based DataLoader` to load the raw features that are not related to the rolling window, and then used Processors to generate processed-features related to the rolling window.
|
||||
|
||||
|
||||
## Run the Code
|
||||
|
||||
Run the example by running the following command:
|
||||
```bash
|
||||
python workflow.py rolling_process
|
||||
```
|
||||
@@ -1,32 +0,0 @@
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
from qlib.data.dataset.loader import DataLoaderDH
|
||||
from qlib.contrib.data.handler import check_transform_proc
|
||||
|
||||
|
||||
class RollingDataHandler(DataHandlerLP):
|
||||
def __init__(
|
||||
self,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
infer_processors=[],
|
||||
learn_processors=[],
|
||||
fit_start_time=None,
|
||||
fit_end_time=None,
|
||||
data_loader_kwargs={},
|
||||
):
|
||||
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
|
||||
learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)
|
||||
|
||||
data_loader = {
|
||||
"class": "DataLoaderDH",
|
||||
"kwargs": {**data_loader_kwargs},
|
||||
}
|
||||
|
||||
super().__init__(
|
||||
instruments=None,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
data_loader=data_loader,
|
||||
infer_processors=infer_processors,
|
||||
learn_processors=learn_processors,
|
||||
)
|
||||
@@ -1,137 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import qlib
|
||||
import fire
|
||||
import pickle
|
||||
|
||||
from datetime import datetime
|
||||
from qlib.config import REG_CN
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
from qlib.utils import init_instance_by_config
|
||||
from qlib.tests.data import GetData
|
||||
|
||||
|
||||
class RollingDataWorkflow:
|
||||
|
||||
MARKET = "csi300"
|
||||
start_time = "2010-01-01"
|
||||
end_time = "2019-12-31"
|
||||
rolling_cnt = 5
|
||||
|
||||
def _init_qlib(self):
|
||||
"""initialize qlib"""
|
||||
# use yahoo_cn_1min data
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
|
||||
GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True)
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
|
||||
def _dump_pre_handler(self, path):
|
||||
handler_config = {
|
||||
"class": "Alpha158",
|
||||
"module_path": "qlib.contrib.data.handler",
|
||||
"kwargs": {
|
||||
"start_time": self.start_time,
|
||||
"end_time": self.end_time,
|
||||
"instruments": self.MARKET,
|
||||
"infer_processors": [],
|
||||
"learn_processors": [],
|
||||
},
|
||||
}
|
||||
pre_handler = init_instance_by_config(handler_config)
|
||||
pre_handler.config(dump_all=True)
|
||||
pre_handler.to_pickle(path)
|
||||
|
||||
def _load_pre_handler(self, path):
|
||||
with open(path, "rb") as file_dataset:
|
||||
pre_handler = pickle.load(file_dataset)
|
||||
return pre_handler
|
||||
|
||||
def rolling_process(self):
|
||||
self._init_qlib()
|
||||
self._dump_pre_handler("pre_handler.pkl")
|
||||
pre_handler = self._load_pre_handler("pre_handler.pkl")
|
||||
|
||||
train_start_time = (2010, 1, 1)
|
||||
train_end_time = (2012, 12, 31)
|
||||
valid_start_time = (2013, 1, 1)
|
||||
valid_end_time = (2013, 12, 31)
|
||||
test_start_time = (2014, 1, 1)
|
||||
test_end_time = (2014, 12, 31)
|
||||
|
||||
dataset_config = {
|
||||
"class": "DatasetH",
|
||||
"module_path": "qlib.data.dataset",
|
||||
"kwargs": {
|
||||
"handler": {
|
||||
"class": "RollingDataHandler",
|
||||
"module_path": "rolling_handler",
|
||||
"kwargs": {
|
||||
"start_time": datetime(*train_start_time),
|
||||
"end_time": datetime(*test_end_time),
|
||||
"fit_start_time": datetime(*train_start_time),
|
||||
"fit_end_time": datetime(*train_end_time),
|
||||
"infer_processors": [
|
||||
{"class": "RobustZScoreNorm", "kwargs": {"fields_group": "feature"}},
|
||||
],
|
||||
"learn_processors": [
|
||||
{"class": "DropnaLabel"},
|
||||
{"class": "CSZScoreNorm", "kwargs": {"fields_group": "label"}},
|
||||
],
|
||||
"data_loader_kwargs": {
|
||||
"handler_config": pre_handler,
|
||||
},
|
||||
},
|
||||
},
|
||||
"segments": {
|
||||
"train": (datetime(*train_start_time), datetime(*train_end_time)),
|
||||
"valid": (datetime(*valid_start_time), datetime(*valid_end_time)),
|
||||
"test": (datetime(*test_start_time), datetime(*test_end_time)),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
dataset = init_instance_by_config(dataset_config)
|
||||
|
||||
for rolling_offset in range(self.rolling_cnt):
|
||||
|
||||
print(f"===========rolling{rolling_offset} start===========")
|
||||
if rolling_offset:
|
||||
dataset.config(
|
||||
handler_kwargs={
|
||||
"start_time": datetime(train_start_time[0] + rolling_offset, *train_start_time[1:]),
|
||||
"end_time": datetime(test_end_time[0] + rolling_offset, *test_end_time[1:]),
|
||||
"processor_kwargs": {
|
||||
"fit_start_time": datetime(train_start_time[0] + rolling_offset, *train_start_time[1:]),
|
||||
"fit_end_time": datetime(train_end_time[0] + rolling_offset, *train_end_time[1:]),
|
||||
},
|
||||
},
|
||||
segments={
|
||||
"train": (
|
||||
datetime(train_start_time[0] + rolling_offset, *train_start_time[1:]),
|
||||
datetime(train_end_time[0] + rolling_offset, *train_end_time[1:]),
|
||||
),
|
||||
"valid": (
|
||||
datetime(valid_start_time[0] + rolling_offset, *valid_start_time[1:]),
|
||||
datetime(valid_end_time[0] + rolling_offset, *valid_end_time[1:]),
|
||||
),
|
||||
"test": (
|
||||
datetime(test_start_time[0] + rolling_offset, *test_start_time[1:]),
|
||||
datetime(test_end_time[0] + rolling_offset, *test_end_time[1:]),
|
||||
),
|
||||
},
|
||||
)
|
||||
dataset.setup_data(
|
||||
handler_kwargs={
|
||||
"init_type": DataHandlerLP.IT_FIT_SEQ,
|
||||
}
|
||||
)
|
||||
|
||||
dtrain, dvalid, dtest = dataset.prepare(["train", "valid", "test"])
|
||||
print(dtrain, dvalid, dtest)
|
||||
## print or dump data
|
||||
print(f"===========rolling{rolling_offset} end===========")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(RollingDataWorkflow)
|
||||
@@ -5,11 +5,13 @@ import os
|
||||
import sys
|
||||
import fire
|
||||
import time
|
||||
import venv
|
||||
import glob
|
||||
import shutil
|
||||
import signal
|
||||
import inspect
|
||||
import tempfile
|
||||
import traceback
|
||||
import functools
|
||||
import statistics
|
||||
import subprocess
|
||||
@@ -21,7 +23,9 @@ from pprint import pprint
|
||||
import qlib
|
||||
from qlib.config import REG_CN
|
||||
from qlib.workflow import R
|
||||
from qlib.tests.data import GetData
|
||||
from qlib.workflow.cli import workflow
|
||||
from qlib.utils import exists_qlib_data
|
||||
|
||||
|
||||
# init qlib
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data"
|
||||
@@ -35,11 +39,14 @@ exp_manager = {
|
||||
"default_exp_name": "Experiment",
|
||||
},
|
||||
}
|
||||
if not exists_qlib_data(provider_uri):
|
||||
print(f"Qlib data is not found in {provider_uri}")
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
|
||||
from get_data import GetData
|
||||
|
||||
GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True)
|
||||
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN, exp_manager=exp_manager)
|
||||
|
||||
|
||||
# decorator to check the arguments
|
||||
def only_allow_defined_args(function_to_decorate):
|
||||
@functools.wraps(function_to_decorate)
|
||||
@@ -92,8 +99,7 @@ def create_env():
|
||||
|
||||
|
||||
# function to execute the cmd
|
||||
def execute(cmd, wait_when_err=False):
|
||||
print("Running CMD:", cmd)
|
||||
def execute(cmd):
|
||||
with subprocess.Popen(cmd, stdout=subprocess.PIPE, bufsize=1, universal_newlines=True, shell=True) as p:
|
||||
for line in p.stdout:
|
||||
sys.stdout.write(line.split("\b")[0])
|
||||
@@ -103,8 +109,6 @@ def execute(cmd, wait_when_err=False):
|
||||
sys.stdout.write("\b" * 10 + "\b".join(line.split("\b")[1:-1]))
|
||||
|
||||
if p.returncode != 0:
|
||||
if wait_when_err:
|
||||
input("Press Enter to Continue")
|
||||
return p.stderr
|
||||
else:
|
||||
return None
|
||||
@@ -187,15 +191,7 @@ def gen_and_save_md_table(metrics, dataset):
|
||||
|
||||
# function to run the all the models
|
||||
@only_allow_defined_args
|
||||
def run(
|
||||
times=1,
|
||||
models=None,
|
||||
dataset="Alpha360",
|
||||
exclude=False,
|
||||
qlib_uri: str = "git+https://github.com/microsoft/qlib#egg=pyqlib",
|
||||
wait_before_rm_env: bool = False,
|
||||
wait_when_err: bool = False,
|
||||
):
|
||||
def run(times=1, models=None, dataset="Alpha360", exclude=False):
|
||||
"""
|
||||
Please be aware that this function can only work under Linux. MacOS and Windows will be supported in the future.
|
||||
Any PR to enhance this method is highly welcomed. Besides, this script doesn't support parrallel running the same model
|
||||
@@ -211,13 +207,6 @@ def run(
|
||||
determines whether the model being used is excluded or included.
|
||||
dataset : str
|
||||
determines the dataset to be used for each model.
|
||||
qlib_uri : str
|
||||
the uri to install qlib with pip
|
||||
it could be url on the we or local path
|
||||
wait_before_rm_env : bool
|
||||
wait before remove environment.
|
||||
wait_when_err : bool
|
||||
wait when errors raised when executing commands
|
||||
|
||||
Usage:
|
||||
-------
|
||||
@@ -258,36 +247,32 @@ def run(
|
||||
sys.stderr.write("\n")
|
||||
# install requirements.txt
|
||||
sys.stderr.write("Installing requirements.txt...\n")
|
||||
execute(f"{python_path} -m pip install -r {req_path}", wait_when_err=wait_when_err)
|
||||
execute(f"{python_path} -m pip install -r {req_path}")
|
||||
sys.stderr.write("\n")
|
||||
# setup gpu for tft
|
||||
if fn == "TFT":
|
||||
execute(
|
||||
f"conda install -y --prefix {env_path} anaconda cudatoolkit=10.0 && conda install -y --prefix {env_path} cudnn",
|
||||
wait_when_err=wait_when_err,
|
||||
f"conda install -y --prefix {env_path} anaconda cudatoolkit=10.0 && conda install -y --prefix {env_path} cudnn"
|
||||
)
|
||||
sys.stderr.write("\n")
|
||||
# install qlib
|
||||
sys.stderr.write("Installing qlib...\n")
|
||||
execute(f"{python_path} -m pip install --upgrade pip", wait_when_err=wait_when_err) # TODO: FIX ME!
|
||||
execute(f"{python_path} -m pip install --upgrade cython", wait_when_err=wait_when_err) # TODO: FIX ME!
|
||||
execute(f"{python_path} -m pip install --upgrade pip") # TODO: FIX ME!
|
||||
execute(f"{python_path} -m pip install --upgrade cython") # TODO: FIX ME!
|
||||
if fn == "TFT":
|
||||
execute(
|
||||
f"cd {env_path} && {python_path} -m pip install --upgrade --force-reinstall --ignore-installed PyYAML -e {qlib_uri}",
|
||||
wait_when_err=wait_when_err,
|
||||
f"cd {env_path} && {python_path} -m pip install --upgrade --force-reinstall --ignore-installed PyYAML -e git+https://github.com/microsoft/qlib#egg=pyqlib"
|
||||
) # TODO: FIX ME!
|
||||
else:
|
||||
execute(
|
||||
f"cd {env_path} && {python_path} -m pip install --upgrade --force-reinstall -e {qlib_uri}",
|
||||
wait_when_err=wait_when_err,
|
||||
f"cd {env_path} && {python_path} -m pip install --upgrade --force-reinstall -e git+https://github.com/microsoft/qlib#egg=pyqlib"
|
||||
) # TODO: FIX ME!
|
||||
sys.stderr.write("\n")
|
||||
# run workflow_by_config for multiple times
|
||||
for i in range(times):
|
||||
sys.stderr.write(f"Running the model: {fn} for iteration {i+1}...\n")
|
||||
errs = execute(
|
||||
f"{python_path} {env_path / 'bin' / 'qrun'} {yaml_path} {fn} {exp_folder_name}",
|
||||
wait_when_err=wait_when_err,
|
||||
f"{python_path} {env_path / 'src/pyqlib/qlib/workflow/cli.py'} {yaml_path} {fn} {exp_folder_name}"
|
||||
)
|
||||
if errs is not None:
|
||||
_errs = errors.get(fn, {})
|
||||
@@ -296,8 +281,6 @@ def run(
|
||||
sys.stderr.write("\n")
|
||||
# remove env
|
||||
sys.stderr.write(f"Deleting the environment: {env_path}...\n")
|
||||
if wait_before_rm_env:
|
||||
input("Press Enter to Continue")
|
||||
shutil.rmtree(env_path)
|
||||
# getting all results
|
||||
sys.stderr.write(f"Retrieving results...\n")
|
||||
|
||||
104
examples/trade/README.md
Normal file
104
examples/trade/README.md
Normal file
@@ -0,0 +1,104 @@
|
||||
# Universal Trading for Order Execution with Oracle Policy Distillation
|
||||
This is the experiment code for our AAAI 2021 paper "[Universal Trading for Order Execution with Oracle Policy Distillation](https://arxiv.org/abs/2103.10860)", including the implementations of all the compared methods in the paper and a general reinforcement learning framework for order execution in quantitative finance.
|
||||
|
||||
## Abstract
|
||||
As a fundamental problem in algorithmic trading, order execution aims at fulfilling a specific trading order, either liquidation or acquirement, for a given instrument. Towards effective execution strategy, recent years have witnessed the shift from the analytical view with model-based market assumptions to model-free perspective, i.e., reinforcement learning, due to its nature of sequential decision optimization. However, the noisy and yet imperfect market information that can be leveraged by the policy has made it quite challenging to build up sample efficient reinforcement learning methods to achieve effective order execution. In this paper, we propose a novel universal trading policy optimization framework to bridge the gap between the noisy yet imperfect market states and the optimal action sequences for order execution. Particularly, this framework leverages a policy distillation method that can better guide the learning of the common policy towards practically optimal execution by an oracle teacher with perfect information to approximate the optimal trading strategy. The extensive experiments have shown significant improvements of our method over various strong baselines, with reasonable trading actions.
|
||||
|
||||
## Environment Dependencies
|
||||
|
||||
### Dependencies
|
||||
|
||||
```
|
||||
gym==0.17.3
|
||||
torch==1.6.0
|
||||
numba==0.51.2
|
||||
numpy==1.19.1
|
||||
pandas==1.1.3
|
||||
tqdm==4.50.2
|
||||
tianshou==0.3.0.post1
|
||||
env==0.1.0
|
||||
PyYAML==5.4.1
|
||||
redis==3.5.3
|
||||
```
|
||||
|
||||
### Environment Variable
|
||||
|
||||
`EXP_PATH` Absolute path to your config folder, we give folder `exp` as an example.
|
||||
|
||||
`OUTPUT_DIR` Absolute path to your log folder.
|
||||
|
||||
## Data Processing
|
||||
|
||||
For Feature processing, we take Yahoo dataset as an example, which can be precessed in `qlib/examples/highfreq/workflow.py` file. If you have a need to change your data storage path, you can change the `data_path` in `workflow.py`, and then do the following.
|
||||
|
||||
```
|
||||
python workflow.py
|
||||
```
|
||||
|
||||
For order generation, if you have changed change the the `data_path` in `workflow.py`, change `data_path` in `order_gen.py` again, then do the following.
|
||||
|
||||
```
|
||||
python order_gen.py
|
||||
```
|
||||
|
||||
## Training and backtest
|
||||
|
||||
### Config file
|
||||
|
||||
Config file is need to start our project, we take `PPO`, `OPDS` and `OPD` as an example in folder `exp/example`. If you want to use our given config, make sure the `data_path` you set before matches the config file.
|
||||
|
||||
### Baseline method
|
||||
|
||||
To run a method, you can do the following.
|
||||
|
||||
```
|
||||
python main.py --config={config_path}
|
||||
```
|
||||
|
||||
Where `{config_path}` means the relative path from your config.yml to `EXP_PATH`.
|
||||
|
||||
If you need to run our given method such as PPO method, you can do the following.
|
||||
|
||||
```
|
||||
python main.py --config=example/PPO/config.yml
|
||||
```
|
||||
|
||||
### OPD method
|
||||
|
||||
OPD method is a multi step method, at first you should run OPDT as the teacher in OPD method.
|
||||
|
||||
```
|
||||
python main.py --config=example/OPDT/config.yml
|
||||
```
|
||||
|
||||
After training, find the `policy_best` file in your OPDT log file and copy it to `trade` file for backtest. Also you can change `policy_path` in the `example/OPDT_b/config.yml` to your `policy_best` file. Then run the backtest method.
|
||||
|
||||
```
|
||||
python main.py --config=example/OPDT_b/config.yml
|
||||
```
|
||||
|
||||
then processed feature from teacher. Remember to change `log_path` if you have changed `log_dir` in `OPDT_b/config.yml`.
|
||||
|
||||
```
|
||||
python teacher_feature.py
|
||||
```
|
||||
|
||||
and finally start our OPD method.
|
||||
|
||||
```
|
||||
python main.py --config=example/OPD/config.yml
|
||||
```
|
||||
|
||||
## Citation
|
||||
You are more than welcome to citetmu our paper:
|
||||
```
|
||||
@inproceedings{fang2021universal,
|
||||
title={Universal Trading for Order Execution with Oracle Policy Distillation},
|
||||
author={Fang, Yuchen and Ren, Kan and Liu, Weiqing and Zhou, Dong and Zhang, Weinan and Bian, Jiang and Yu, Yong and Liu, Tie-Yan},
|
||||
booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
|
||||
volume={35},
|
||||
number={1},
|
||||
pages={107--115},
|
||||
year={2021}
|
||||
}
|
||||
```
|
||||
10
examples/trade/__init__.py
Normal file
10
examples/trade/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
# from rl4execution import env, trainer, exploration
|
||||
|
||||
# __all__ = [
|
||||
# "env",
|
||||
# "data",
|
||||
# "utils",
|
||||
# "policy",
|
||||
# "trainer",
|
||||
# "exploration",
|
||||
# ]
|
||||
4
examples/trade/action/__init__.py
Normal file
4
examples/trade/action/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .base import *
|
||||
from .action_rl import *
|
||||
from .action_rule import *
|
||||
from .action_rl import *
|
||||
27
examples/trade/action/action_rl.py
Normal file
27
examples/trade/action/action_rl.py
Normal file
@@ -0,0 +1,27 @@
|
||||
import numpy as np
|
||||
from gym.spaces import Discrete, Box, Tuple, MultiDiscrete
|
||||
|
||||
from .base import Base_Action
|
||||
|
||||
|
||||
class Static_Action(Base_Action):
|
||||
""" """
|
||||
|
||||
def __init__(self, config):
|
||||
self.action_num = config["action_num"]
|
||||
self.action_map = config["action_map"]
|
||||
|
||||
def get_space(self):
|
||||
""" """
|
||||
return Discrete(self.action_num)
|
||||
|
||||
def get_action(self, action, target, position, **kargs):
|
||||
"""
|
||||
|
||||
:param action:
|
||||
:param position:
|
||||
:param target:
|
||||
:param **kargs:
|
||||
|
||||
"""
|
||||
return min(target * self.action_map[action], position)
|
||||
46
examples/trade/action/action_rule.py
Normal file
46
examples/trade/action/action_rule.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import numpy as np
|
||||
from gym.spaces import Discrete, Box, Tuple, MultiDiscrete
|
||||
|
||||
from .base import Base_Action
|
||||
|
||||
|
||||
class Rule_Dynamic(Base_Action):
|
||||
""" """
|
||||
|
||||
def get_space(self):
|
||||
""" """
|
||||
return Box(0, np.inf, shape=(), dtype=np.float32)
|
||||
|
||||
def get_action(self, action, target, position, max_step_num, t, **kargs):
|
||||
"""
|
||||
|
||||
:param action: param target:
|
||||
:param position: param max_step_num:
|
||||
:param t: param **kargs:
|
||||
:param target:
|
||||
:param max_step_num:
|
||||
:param **kargs:
|
||||
|
||||
"""
|
||||
return position / (max_step_num - (t + 1)) * action
|
||||
|
||||
|
||||
class Rule_Static(Base_Action):
|
||||
""" """
|
||||
|
||||
def get_space(self):
|
||||
""" """
|
||||
return Box(0, np.inf, shape=(), dtype=np.float32)
|
||||
|
||||
def get_action(self, action, target, position, max_step_num, t, **kargs):
|
||||
"""
|
||||
|
||||
:param action: param target:
|
||||
:param position: param max_step_num:
|
||||
:param t: param **kargs:
|
||||
:param target:
|
||||
:param max_step_num:
|
||||
:param **kargs:
|
||||
|
||||
"""
|
||||
return target / max_step_num * action
|
||||
20
examples/trade/action/base.py
Normal file
20
examples/trade/action/base.py
Normal file
@@ -0,0 +1,20 @@
|
||||
import numpy as np
|
||||
from gym.spaces import Discrete, Box, Tuple, MultiDiscrete
|
||||
|
||||
|
||||
class Base_Action(object):
|
||||
""" """
|
||||
|
||||
def __init__(self, config):
|
||||
return
|
||||
|
||||
def __call__(self, *args, **kargs):
|
||||
return self.get_action(*args, **kargs)
|
||||
|
||||
def get_action(self, action):
|
||||
"""
|
||||
|
||||
:param action:
|
||||
|
||||
"""
|
||||
return action
|
||||
46
examples/trade/action/interval_rule.py
Normal file
46
examples/trade/action/interval_rule.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import numpy as np
|
||||
from gym.spaces import Discrete, Box, Tuple, MultiDiscrete
|
||||
|
||||
from .base import Base_Action
|
||||
|
||||
|
||||
class Rule_Static_Interval(Base_Action):
|
||||
""" """
|
||||
|
||||
def get_space(self):
|
||||
""" """
|
||||
return Box(0, np.inf, shape=(), dtype=np.float32)
|
||||
|
||||
def get_action(self, action, target, position, interval_num, interval, **kargs):
|
||||
"""
|
||||
|
||||
:param action: param target:
|
||||
:param position: param interval_num:
|
||||
:param interval: param **kargs:
|
||||
:param target:
|
||||
:param interval_num:
|
||||
:param **kargs:
|
||||
|
||||
"""
|
||||
return target / (interval_num) * action
|
||||
|
||||
|
||||
class Rule_Dynamic_Interval(Base_Action):
|
||||
""" """
|
||||
|
||||
def get_space(self):
|
||||
""" """
|
||||
return Box(0, np.inf, shape=(), dtype=np.float32)
|
||||
|
||||
def get_action(self, action, target, position, interval_num, interval, **kargs):
|
||||
"""
|
||||
|
||||
:param action: param target:
|
||||
:param position: param interval_num:
|
||||
:param interval: param **kargs:
|
||||
:param target:
|
||||
:param interval_num:
|
||||
:param **kargs:
|
||||
|
||||
"""
|
||||
return position / (interval_num - interval) * action
|
||||
1
examples/trade/agent/__init__.py
Normal file
1
examples/trade/agent/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .basic import *
|
||||
69
examples/trade/agent/basic.py
Normal file
69
examples/trade/agent/basic.py
Normal file
@@ -0,0 +1,69 @@
|
||||
from tianshou.policy import BasePolicy
|
||||
from tianshou.data import Batch
|
||||
import numpy as np
|
||||
import torch
|
||||
from env import nan_weighted_avg
|
||||
|
||||
|
||||
class TWAP(BasePolicy):
|
||||
""" The TWAP strategy. """
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.max_step_num = config["max_step_num"]
|
||||
self.num_cpus = config["num_cpus"]
|
||||
|
||||
# @njit(parallel=True)
|
||||
def forward(self, batch: Batch, state=None, **kwargs) -> Batch:
|
||||
act = [1] * len(batch.obs.private)
|
||||
return Batch(act=act, state=state)
|
||||
|
||||
def learn(self, batch, batch_size, repeat):
|
||||
pass
|
||||
|
||||
def process_fn(self, batch, buffer, indice):
|
||||
pass
|
||||
|
||||
|
||||
class VWAP(BasePolicy):
|
||||
""" The VWAP strategy."""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, batch, state, **kwargs):
|
||||
obs = batch.obs
|
||||
r = np.stack(obs.prediction).reshape(-1)
|
||||
return Batch(act=r, state=state)
|
||||
|
||||
def learn(self, batch, batch_size, repeat):
|
||||
pass
|
||||
|
||||
def process_fn(self, batch, buffer, indice):
|
||||
pass
|
||||
|
||||
|
||||
class AC(VWAP):
|
||||
"""Almgren-Chriss strategy."""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.T = config["max_step_num"]
|
||||
self.gamma = 0
|
||||
self.tau = 1
|
||||
self.lamb = config["lambda"]
|
||||
self.eps = 0.0625
|
||||
self.alpha = 0.02
|
||||
self.eta = 2.5e-6
|
||||
|
||||
def forward(self, batch, state, **kwargs):
|
||||
obs = batch.obs
|
||||
sig = np.stack(obs.prediction).reshape(-1)
|
||||
sell = ~np.stack(obs.is_buy).astype(np.bool)
|
||||
data = np.stack(obs.private)
|
||||
t = data[:, 2]
|
||||
t = t + 1
|
||||
k_tild = self.lamb / self.eta * sig * sig
|
||||
k = np.arccosh(k_tild / 2 + 1)
|
||||
act = (np.sinh(k * (self.T - t)) - np.sinh(k * (self.T - t - 1))) / np.sinh(k * self.T)
|
||||
return Batch(act=act, state=state)
|
||||
342
examples/trade/collector.py
Normal file
342
examples/trade/collector.py
Normal file
@@ -0,0 +1,342 @@
|
||||
import gym
|
||||
import time
|
||||
import torch
|
||||
import warnings
|
||||
import numpy as np
|
||||
from copy import deepcopy
|
||||
from numbers import Number
|
||||
from typing import Any, Dict, List, Union, Optional, Callable
|
||||
|
||||
from vecenv import BaseVectorEnv
|
||||
from tianshou.policy import BasePolicy
|
||||
from tianshou.data import Batch, ReplayBuffer, ListReplayBuffer, to_numpy
|
||||
from tianshou.exploration import BaseNoise
|
||||
from tianshou.env import DummyVectorEnv
|
||||
from tianshou.data.collector import _batch_set_item
|
||||
|
||||
|
||||
class Collector(object):
|
||||
def __init__(
|
||||
self,
|
||||
policy: BasePolicy,
|
||||
env: Union[gym.Env, BaseVectorEnv],
|
||||
testing=False,
|
||||
buffer: Optional[ReplayBuffer] = None,
|
||||
preprocess_fn: Optional[Callable[..., Batch]] = None,
|
||||
action_noise: Optional[BaseNoise] = None,
|
||||
reward_metric: Optional[Callable[[np.ndarray], float]] = np.sum,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
if not isinstance(env, BaseVectorEnv):
|
||||
env = DummyVectorEnv([lambda: env])
|
||||
self.env = env
|
||||
self.env_num = len(env)
|
||||
# environments that are available in step()
|
||||
# this means all environments in synchronous simulation
|
||||
# but only a subset of environments in asynchronous simulation
|
||||
self._ready_env_ids = np.arange(self.env_num)
|
||||
# self.async is a flag to indicate whether this collector works
|
||||
# with asynchronous simulation
|
||||
self.is_async = env.is_async
|
||||
self.testing = testing
|
||||
# need cache buffers before storing in the main buffer
|
||||
self._cached_buf = [ListReplayBuffer() for _ in range(self.env_num)]
|
||||
self.buffer = buffer
|
||||
self.policy = policy
|
||||
self.preprocess_fn = preprocess_fn
|
||||
self.process_fn = policy.process_fn
|
||||
# self._action_space = env.action_space
|
||||
self._action_noise = action_noise
|
||||
self._rew_metric = reward_metric or Collector._default_rew_metric
|
||||
# avoid creating attribute outside __init__
|
||||
# self.reset()
|
||||
|
||||
@staticmethod
|
||||
def _default_rew_metric(x: Union[Number, np.number]) -> Union[Number, np.number]:
|
||||
# this internal function is designed for single-agent RL
|
||||
# for multi-agent RL, a reward_metric must be provided
|
||||
assert np.asanyarray(x).size == 1, "Please specify the reward_metric " "since the reward is not a scalar."
|
||||
return x
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset all related variables in the collector."""
|
||||
# use empty Batch for ``state`` so that ``self.data`` supports slicing
|
||||
# convert empty Batch to None when passing data to policy
|
||||
self.data = Batch(state={}, obs={}, act={}, rew={}, done={}, info={}, obs_next={}, policy={})
|
||||
self.reset_env()
|
||||
self.reset_buffer()
|
||||
self.reset_stat()
|
||||
if self._action_noise is not None:
|
||||
self._action_noise.reset()
|
||||
|
||||
def reset_stat(self) -> None:
|
||||
"""Reset the statistic variables."""
|
||||
self.collect_time, self.collect_step, self.collect_episode = 0.0, 0, 0
|
||||
|
||||
def reset_buffer(self) -> None:
|
||||
"""Reset the main data buffer."""
|
||||
if self.buffer is not None:
|
||||
self.buffer.reset()
|
||||
|
||||
def get_env_num(self) -> int:
|
||||
""" """
|
||||
return self.env_num
|
||||
|
||||
def reset_env(self) -> None:
|
||||
"""Reset all of the environment(s)' states and the cache buffers."""
|
||||
self._ready_env_ids = np.arange(self.env_num)
|
||||
self.env.reset_sampler()
|
||||
obs, stop_id = self.env.reset()
|
||||
if self.preprocess_fn:
|
||||
obs = self.preprocess_fn(obs=obs).get("obs", obs)
|
||||
self.data.obs = obs
|
||||
for b in self._cached_buf:
|
||||
b.reset()
|
||||
self._ready_env_ids = np.array([x for x in self._ready_env_ids if x not in stop_id])
|
||||
|
||||
def _reset_state(self, id: Union[int, List[int]]) -> None:
|
||||
"""Reset the hidden state: self.data.state[id]."""
|
||||
state = self.data.state # it is a reference
|
||||
if isinstance(state, torch.Tensor):
|
||||
state[id].zero_()
|
||||
elif isinstance(state, np.ndarray):
|
||||
state[id] = None if state.dtype == np.object else 0
|
||||
elif isinstance(state, Batch):
|
||||
state.empty_(id)
|
||||
|
||||
def collect(
|
||||
self,
|
||||
n_step: Optional[int] = None,
|
||||
n_episode: Optional[Union[int, List[int]]] = None,
|
||||
random: bool = False,
|
||||
render: Optional[float] = None,
|
||||
log_fn=None,
|
||||
no_grad: bool = True,
|
||||
) -> Dict[str, float]:
|
||||
"""Collect a specified number of step or episode.
|
||||
|
||||
:param int: n_step: how many steps you want to collect.
|
||||
:param n_episode: how many episodes you want to collect. If it is an
|
||||
int, it means to collect at lease ``n_episode`` episodes; if it is
|
||||
a list, it means to collect exactly ``n_episode[i]`` episodes in
|
||||
the i-th environment
|
||||
:param bool: random: whether to use random policy for collecting data,
|
||||
defaults to False.
|
||||
:param float: render: the sleep time between rendering consecutive
|
||||
frames, defaults to None (no rendering).
|
||||
:param bool: no_grad: whether to retain gradient in policy.forward,
|
||||
defaults to True (no gradient retaining).
|
||||
|
||||
.. note::
|
||||
|
||||
One and only one collection number specification is permitted,
|
||||
either ``n_step`` or ``n_episode``.
|
||||
|
||||
:param n_step: Optional[int]: (Default value = None)
|
||||
:param n_episode: Optional[Union[int:List[int]]]: (Default value = None)
|
||||
:param random: bool: (Default value = False)
|
||||
:param render: Optional[float]: (Default value = None)
|
||||
:param log_fn: Default value = None)
|
||||
:param no_grad: bool: (Default value = True)
|
||||
:param n_step: Optional[int]: (Default value = None)
|
||||
:param n_episode: Optional[Union[int:
|
||||
:param List[int]]]: (Default value = None)
|
||||
:param random: bool: (Default value = False)
|
||||
:param render: Optional[float]: (Default value = None)
|
||||
:param no_grad: bool: (Default value = True)
|
||||
:param n_step: Optional[int]: (Default value = None)
|
||||
:param n_episode: Optional[Union[int:
|
||||
:param random: bool: (Default value = False)
|
||||
:param render: Optional[float]: (Default value = None)
|
||||
:param no_grad: bool: (Default value = True)
|
||||
:returns: A dict including the following keys
|
||||
|
||||
* ``n/ep`` the collected number of episodes.
|
||||
* ``n/st`` the collected number of steps.
|
||||
* ``v/st`` the speed of steps per second.
|
||||
* ``v/ep`` the speed of episode per second.
|
||||
* ``rew`` the mean reward over collected episodes.
|
||||
* ``len`` the mean length over collected episodes.
|
||||
|
||||
"""
|
||||
assert (
|
||||
(n_step is not None and n_episode is None and n_step > 0)
|
||||
or (n_step is None and n_episode is not None and np.sum(n_episode) > 0)
|
||||
or self.testing
|
||||
), "Only one of n_step or n_episode is allowed in Collector.collect, "
|
||||
f"got n_step = {n_step}, n_episode = {n_episode}."
|
||||
start_time = time.time()
|
||||
step_count = 0
|
||||
step_time = 0.0
|
||||
reset_time = 0.0
|
||||
model_time = 0.0
|
||||
# episode of each environment
|
||||
episode_count = np.zeros(self.env_num)
|
||||
# If n_episode is a list, and some envs have collected the required
|
||||
# number of episodes, these envs will be recorded in this list, and
|
||||
# they will not be stepped.
|
||||
finished_env_ids = []
|
||||
rewards = []
|
||||
whole_data = Batch()
|
||||
if isinstance(n_episode, list):
|
||||
assert len(n_episode) == self.get_env_num()
|
||||
finished_env_ids = [i for i in self._ready_env_ids if n_episode[i] <= 0]
|
||||
self._ready_env_ids = np.array([x for x in self._ready_env_ids if x not in finished_env_ids])
|
||||
while True:
|
||||
if step_count >= 100000 and episode_count.sum() == 0:
|
||||
warnings.warn(
|
||||
"There are already many steps in an episode. "
|
||||
"You should add a time limitation to your environment!",
|
||||
Warning,
|
||||
)
|
||||
|
||||
is_async = self.is_async or len(finished_env_ids) > 0
|
||||
if is_async:
|
||||
# self.data are the data for all environments in async
|
||||
# simulation or some envs have finished,
|
||||
# **only a subset of data are disposed**,
|
||||
# so we store the whole data in ``whole_data``, let self.data
|
||||
# to be the data available in ready environments, and finally
|
||||
# set these back into all the data
|
||||
whole_data = self.data
|
||||
self.data = self.data[self._ready_env_ids]
|
||||
|
||||
# restore the state and the input data
|
||||
last_state = self.data.state
|
||||
if isinstance(last_state, Batch) and last_state.is_empty():
|
||||
last_state = None
|
||||
self.data.update(state=Batch(), obs_next=Batch(), policy=Batch())
|
||||
|
||||
# calculate the next action
|
||||
start = time.time()
|
||||
if random:
|
||||
spaces = self._action_space
|
||||
result = Batch(act=[spaces[i].sample() for i in self._ready_env_ids])
|
||||
else:
|
||||
if no_grad:
|
||||
with torch.no_grad(): # faster than retain_grad version
|
||||
result = self.policy(self.data, last_state)
|
||||
else:
|
||||
result = self.policy(self.data, last_state)
|
||||
model_time += time.time() - start
|
||||
state = result.get("state", Batch())
|
||||
# convert None to Batch(), since None is reserved for 0-init
|
||||
if state is None:
|
||||
state = Batch()
|
||||
self.data.update(state=state, policy=result.get("policy", Batch()))
|
||||
# save hidden state to policy._state, in order to save into buffer
|
||||
if not (isinstance(state, Batch) and state.is_empty()):
|
||||
self.data.policy._state = self.data.state
|
||||
|
||||
self.data.act = to_numpy(result.act)
|
||||
if self._action_noise is not None:
|
||||
assert isinstance(self.data.act, np.ndarray)
|
||||
self.data.act += self._action_noise(self.data.act.shape)
|
||||
|
||||
# step in env
|
||||
start = time.time()
|
||||
if not is_async:
|
||||
obs_next, rew, done, info = self.env.step(self.data.act)
|
||||
if log_fn:
|
||||
log_fn(info)
|
||||
else:
|
||||
# store computed actions, states, etc
|
||||
_batch_set_item(whole_data, self._ready_env_ids, self.data, self.env_num)
|
||||
# fetch finished data
|
||||
obs_next, rew, done, info = self.env.step(self.data.act, id=self._ready_env_ids)
|
||||
self._ready_env_ids = np.array([i["env_id"] for i in info])
|
||||
# get the stepped data
|
||||
self.data = whole_data[self._ready_env_ids]
|
||||
if log_fn:
|
||||
log_fn(info)
|
||||
|
||||
step_time += time.time() - start
|
||||
# move data to self.data
|
||||
self.data.update(obs_next=obs_next, rew=rew, done=done, info=[{} for i in info])
|
||||
|
||||
if render:
|
||||
self.env.render()
|
||||
time.sleep(render)
|
||||
|
||||
# add data into the buffer
|
||||
if self.preprocess_fn:
|
||||
result = self.preprocess_fn(**self.data) # type: ignore
|
||||
self.data.update(result)
|
||||
|
||||
for j, i in enumerate(self._ready_env_ids):
|
||||
# j is the index in current ready_env_ids
|
||||
# i is the index in all environments
|
||||
if self.buffer is None:
|
||||
# users do not want to store data, so we store
|
||||
# small fake data here to make the code clean
|
||||
self._cached_buf[i].add(obs=0, act=0, rew=rew[j], done=0)
|
||||
else:
|
||||
self._cached_buf[i].add(**self.data[j])
|
||||
|
||||
if done[j]:
|
||||
if not (isinstance(n_episode, list) and episode_count[i] >= n_episode[i]):
|
||||
episode_count[i] += 1
|
||||
rewards.append(self._rew_metric(np.sum(self._cached_buf[i].rew, axis=0)))
|
||||
step_count += len(self._cached_buf[i])
|
||||
if self.buffer is not None:
|
||||
self.buffer.update(self._cached_buf[i])
|
||||
if isinstance(n_episode, list) and episode_count[i] >= n_episode[i]:
|
||||
# env i has collected enough data, it has finished
|
||||
finished_env_ids.append(i)
|
||||
self._cached_buf[i].reset()
|
||||
self._reset_state(j)
|
||||
obs_next = self.data.obs_next
|
||||
start = time.time()
|
||||
if sum(done):
|
||||
env_ind_local = np.where(done)[0].tolist()
|
||||
env_ind_global = self._ready_env_ids[env_ind_local]
|
||||
obs_reset, stop_id = self.env.reset(env_ind_global)
|
||||
_ready_env_ids = self._ready_env_ids.tolist()
|
||||
for i in stop_id:
|
||||
finished_env_ids.append(i)
|
||||
# env_ind_local.remove(_ready_env_ids.index(i))
|
||||
if len(env_ind_local) > 0:
|
||||
if self.preprocess_fn:
|
||||
obs_reset = self.preprocess_fn(obs=obs_reset).get("obs", obs_reset)
|
||||
obs_next[env_ind_local] = obs_reset
|
||||
reset_time += time.time() - start
|
||||
self.data.obs = obs_next
|
||||
if is_async:
|
||||
# set data back
|
||||
whole_data = deepcopy(whole_data) # avoid reference in ListBuf
|
||||
_batch_set_item(whole_data, self._ready_env_ids, self.data, self.env_num)
|
||||
# let self.data be the data in all environments again
|
||||
self.data = whole_data
|
||||
self._ready_env_ids = np.array([x for x in self._ready_env_ids if x not in finished_env_ids])
|
||||
if n_step:
|
||||
if step_count >= n_step:
|
||||
break
|
||||
else:
|
||||
if isinstance(n_episode, int) and episode_count.sum() >= n_episode:
|
||||
break
|
||||
if isinstance(n_episode, list) and (episode_count >= n_episode).all():
|
||||
break
|
||||
if len(self._ready_env_ids) == 0 and self.testing:
|
||||
break
|
||||
|
||||
# finished envs are ready, and can be used for the next collection
|
||||
self._ready_env_ids = np.array(self._ready_env_ids.tolist() + finished_env_ids)
|
||||
|
||||
# generate the statistics
|
||||
episode_count = sum(episode_count)
|
||||
duration = max(time.time() - start_time, 1e-9)
|
||||
self.collect_step += step_count
|
||||
self.collect_episode += episode_count
|
||||
self.collect_time += duration
|
||||
return {
|
||||
"n/ep": episode_count,
|
||||
"n/st": step_count,
|
||||
"v/st": step_count / duration,
|
||||
"v/ep": episode_count / duration,
|
||||
"t/st": step_time / step_count,
|
||||
"t/re": reset_time / episode_count,
|
||||
"t/mo": model_time / step_count,
|
||||
"rew": np.mean(rewards),
|
||||
"rew_std": np.std(rewards),
|
||||
"len": step_count / episode_count,
|
||||
}
|
||||
1
examples/trade/env/__init__.py
vendored
Normal file
1
examples/trade/env/__init__.py
vendored
Normal file
@@ -0,0 +1 @@
|
||||
from .env_rl import *
|
||||
481
examples/trade/env/env_rl.py
vendored
Normal file
481
examples/trade/env/env_rl.py
vendored
Normal file
@@ -0,0 +1,481 @@
|
||||
import gym
|
||||
|
||||
gym.logger.set_level(40)
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pickle as pkl
|
||||
import datetime
|
||||
import random
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
import tianshou as ts
|
||||
import copy
|
||||
from multiprocessing import Process, Pipe, Queue
|
||||
from typing import List, Tuple, Union, Optional, Callable, Any
|
||||
from tianshou.env.utils import CloudpickleWrapper
|
||||
from scipy.stats import pearsonr
|
||||
from sklearn.metrics import roc_auc_score
|
||||
|
||||
import sys
|
||||
|
||||
sys.path.append("..")
|
||||
from util import merge_dicts, nan_weighted_avg, robust_auc
|
||||
import reward
|
||||
import observation
|
||||
import action
|
||||
|
||||
ZERO = 1e-7
|
||||
|
||||
|
||||
class StockEnv(gym.Env):
|
||||
"""Single-assert environment"""
|
||||
|
||||
def __init__(self, config):
|
||||
self.max_step_num = config["max_step_num"]
|
||||
self.limit = config["limit"]
|
||||
self.time_interval = config["time_interval"]
|
||||
self.interval_num = config["interval_num"]
|
||||
self.offset = config["offset"] if "offset" in config else 0
|
||||
if "last_reward" in config:
|
||||
self.last_reward = config["last_reward"]
|
||||
else:
|
||||
self.last_reward = None
|
||||
if "log" in config:
|
||||
self.log = config["log"]
|
||||
else:
|
||||
self.log = True
|
||||
# loader_conf = config['loader']['config']
|
||||
obs_conf = config["obs"]["config"]
|
||||
obs_conf["features"] = config["features"]
|
||||
obs_conf["time_interval"] = self.time_interval
|
||||
obs_conf["max_step_num"] = self.max_step_num
|
||||
self.obs = getattr(observation, config["obs"]["name"])(obs_conf)
|
||||
self.action_func = getattr(action, config["action"]["name"])(config["action"]["config"])
|
||||
self.reward_func_list = []
|
||||
self.reward_log_dict = {}
|
||||
self.reward_coef = []
|
||||
for name, conf in config["reward"].items():
|
||||
self.reward_coef.append(conf.pop("coefficient"))
|
||||
self.reward_func_list.append(getattr(reward, name)(conf))
|
||||
self.reward_log_dict[name] = 0.0
|
||||
self.observation_space = self.obs.get_space()
|
||||
self.action_space = self.action_func.get_space()
|
||||
|
||||
def toggle_log(self, log):
|
||||
self.log = log
|
||||
|
||||
def reset(self, sample):
|
||||
"""
|
||||
|
||||
:param sample:
|
||||
|
||||
"""
|
||||
|
||||
for key in self.reward_log_dict.keys():
|
||||
self.reward_log_dict[key] = 0.0
|
||||
if not sample is None:
|
||||
(
|
||||
self.ins,
|
||||
self.date,
|
||||
self.raw_df_values,
|
||||
self.raw_df_columns,
|
||||
self.raw_df_index,
|
||||
self.feature_dfs,
|
||||
self.target,
|
||||
self.is_buy,
|
||||
) = sample
|
||||
self.raw_df = pd.DataFrame(index=self.raw_df_index, data=self.raw_df_values, columns=self.raw_df_columns,)
|
||||
del self.raw_df_values, self.raw_df_columns, self.raw_df_index
|
||||
start_time = time.time()
|
||||
self.load_time = time.time() - start_time
|
||||
self.day_vwap = nan_weighted_avg(
|
||||
self.raw_df["$vwap0"].values[self.offset : self.offset + self.max_step_num],
|
||||
self.raw_df["$volume0"].values[self.offset : self.offset + self.max_step_num],
|
||||
)
|
||||
try:
|
||||
assert not (np.isnan(self.day_vwap) or np.isinf(self.day_vwap))
|
||||
except:
|
||||
print(self.raw_df)
|
||||
print(self.ins)
|
||||
print(self.day_vwap)
|
||||
self.raw_df.to_pickle("/nfs_data1/kanren/error_df.pkl")
|
||||
self.day_twap = np.nanmean(self.raw_df["$vwap0"].values[self.offset : self.offset + self.max_step_num])
|
||||
self.t = -1 + self.offset
|
||||
self.interval = 0
|
||||
self.position = self.target
|
||||
self.eps_start = time.time()
|
||||
|
||||
self.state = self.obs(
|
||||
self.raw_df,
|
||||
self.feature_dfs,
|
||||
self.t,
|
||||
self.interval,
|
||||
self.position,
|
||||
self.target,
|
||||
self.is_buy,
|
||||
self.max_step_num,
|
||||
self.interval_num,
|
||||
)
|
||||
if self.log:
|
||||
index_array = [
|
||||
np.array([self.ins] * self.max_step_num),
|
||||
self.raw_df.index.to_numpy()[self.offset : self.offset + self.max_step_num],
|
||||
np.array([self.date] * self.max_step_num),
|
||||
]
|
||||
self.traded_log = pd.DataFrame(
|
||||
data={
|
||||
"$v_t": np.nan,
|
||||
"$max_vol_t": (self.raw_df["$volume0"] * self.limit).values[
|
||||
self.offset : self.offset + self.max_step_num
|
||||
],
|
||||
"$traded_t": np.nan,
|
||||
"$vwap_t": self.raw_df["$vwap0"].values[self.offset : self.offset + self.max_step_num],
|
||||
"action": np.nan,
|
||||
},
|
||||
index=index_array,
|
||||
)
|
||||
# v_t: The amount of shares the agent hope to trade
|
||||
# max_vol_t: The max amount of shares can be traded
|
||||
# traded_t: The amount of shares that is acually traded
|
||||
# action: the action of agent, may have various meanings in different settings.
|
||||
self.done = False
|
||||
if self.limit > 1:
|
||||
self.this_valid = np.inf
|
||||
else:
|
||||
self.this_valid = np.nansum(self.raw_df["$volume0"].values) * self.limit
|
||||
self.this_cash = 0
|
||||
|
||||
self.step_time = []
|
||||
self.action_log = [np.nan] * self.interval_num
|
||||
self.reset_time = time.time() - start_time
|
||||
self.real_eps_time = self.reset_time
|
||||
self.total_reward = 0
|
||||
self.total_instant_rew = 0
|
||||
self.last_rew = 0
|
||||
return self.state
|
||||
|
||||
def step(self, action):
|
||||
"""
|
||||
|
||||
:param action:
|
||||
|
||||
"""
|
||||
start_time = time.time()
|
||||
self.action_log[self.interval] = action
|
||||
volume_t = self.action_func(
|
||||
action,
|
||||
self.target,
|
||||
self.position,
|
||||
max_step_num=self.max_step_num,
|
||||
t=self.t - self.offset,
|
||||
interval=self.interval,
|
||||
interval_num=self.interval_num,
|
||||
)
|
||||
self.interval += 1
|
||||
reward = 0.0
|
||||
time_left = self.max_step_num - self.t - 1 + self.offset
|
||||
|
||||
for i in range(self.time_interval):
|
||||
v_t = volume_t / min(self.time_interval, time_left)
|
||||
self.t += 1
|
||||
if self.t == self.max_step_num - 1 + self.offset:
|
||||
v_t = self.position
|
||||
if self.log:
|
||||
log_index = self.t - self.offset
|
||||
self.traded_log.iat[log_index, 0] = v_t
|
||||
self.traded_log.iat[log_index, 4] = action
|
||||
vwap_t, vol_t = self.raw_df.iloc[self.t][["$vwap0", "$volume0"]]
|
||||
max_vol_t = self.limit * vol_t
|
||||
if self.limit >= 1:
|
||||
max_vol_t = np.inf
|
||||
if v_t > min(self.position, max_vol_t):
|
||||
if self.position <= max_vol_t:
|
||||
v_t = self.position
|
||||
else:
|
||||
v_t = max_vol_t
|
||||
self.position -= v_t
|
||||
self.this_cash += vwap_t * v_t
|
||||
if self.log:
|
||||
self.traded_log.iat[log_index, 2] = v_t
|
||||
|
||||
if self.is_buy:
|
||||
performance_raise = (1 - vwap_t / self.day_vwap) * 10000
|
||||
PA_t = (1 - vwap_t / self.day_twap) * 10000
|
||||
else:
|
||||
performance_raise = (vwap_t / self.day_vwap - 1) * 10000
|
||||
PA_t = (vwap_t / self.day_twap - 1) * 10000
|
||||
|
||||
for i, reward_func in enumerate(self.reward_func_list):
|
||||
if reward_func.isinstant:
|
||||
tmp_r = reward_func(performance_raise, v_t, self.target, PA_t)
|
||||
reward += tmp_r * self.reward_coef[i]
|
||||
self.reward_log_dict[type(reward_func).__name__] += tmp_r
|
||||
|
||||
if self.t == self.max_step_num - 1 + self.offset:
|
||||
break
|
||||
|
||||
if self.position < ZERO:
|
||||
self.done = True
|
||||
|
||||
if self.interval == self.interval_num:
|
||||
self.done = True
|
||||
|
||||
self.step_time.append(time.time() - start_time)
|
||||
self.real_eps_time += time.time() - start_time
|
||||
if self.done:
|
||||
this_traded = self.target - self.position
|
||||
this_vwap = (self.this_cash / this_traded) if this_traded > ZERO else self.day_vwap
|
||||
valid = min(self.target, self.this_valid)
|
||||
this_ffr = (this_traded / valid) if valid > ZERO else 1.0
|
||||
if abs(this_ffr - 1.0) < ZERO:
|
||||
this_ffr = 1.0
|
||||
this_ffr *= 100
|
||||
this_vv_ratio = this_vwap / self.day_vwap
|
||||
vwap = self.raw_df["$vwap0"].values[self.offset : self.max_step_num + self.offset]
|
||||
this_tt_ratio = this_vwap / np.nanmean(vwap)
|
||||
|
||||
if self.is_buy:
|
||||
performance_raise = (1 - this_vv_ratio) * 10000
|
||||
PA = (1 - this_tt_ratio) * 10000
|
||||
else:
|
||||
performance_raise = (this_vv_ratio - 1) * 10000
|
||||
PA = (this_tt_ratio - 1) * 10000
|
||||
|
||||
for i, reward_func in enumerate(self.reward_func_list):
|
||||
if not reward_func.isinstant:
|
||||
tmp_r = reward_func(performance_raise, this_ffr, this_tt_ratio, self.is_buy)
|
||||
reward += tmp_r * self.reward_coef[i]
|
||||
self.reward_log_dict[type(reward_func).__name__] += tmp_r
|
||||
|
||||
self.state = self.obs(
|
||||
self.raw_df,
|
||||
self.feature_dfs,
|
||||
self.t,
|
||||
self.interval,
|
||||
self.position,
|
||||
self.target,
|
||||
self.is_buy,
|
||||
self.max_step_num,
|
||||
self.interval_num,
|
||||
action,
|
||||
)
|
||||
if self.log:
|
||||
res = pd.DataFrame(
|
||||
{
|
||||
"target": self.target,
|
||||
"sell": not self.is_buy,
|
||||
"vwap": this_vwap,
|
||||
"this_vv_ratio": this_vv_ratio,
|
||||
"this_ffr": this_ffr,
|
||||
},
|
||||
index=[[self.ins], [self.date]],
|
||||
)
|
||||
money = self.target * self.day_vwap
|
||||
if self.is_buy:
|
||||
info = {
|
||||
"money": money,
|
||||
"money_buy": money,
|
||||
"action": self.action_log,
|
||||
"ffr": this_ffr,
|
||||
"obs0_PR": performance_raise,
|
||||
"ffr_buy": this_ffr,
|
||||
"PR_buy": performance_raise,
|
||||
"PA": PA,
|
||||
"PA_buy": PA,
|
||||
"vwap": this_vwap,
|
||||
}
|
||||
else:
|
||||
info = {
|
||||
"money": money,
|
||||
"money_sell": money,
|
||||
"action": self.action_log,
|
||||
"ffr": this_ffr,
|
||||
"obs0_PR": performance_raise,
|
||||
"ffr_sell": this_ffr,
|
||||
"PR_sell": performance_raise,
|
||||
"PA": PA,
|
||||
"PA_sell": PA,
|
||||
"vwap": this_vwap,
|
||||
}
|
||||
info = merge_dicts(info, self.reward_log_dict)
|
||||
if self.log:
|
||||
info["df"] = self.traded_log
|
||||
info["res"] = res
|
||||
del self.feature_dfs
|
||||
return self.state, reward, self.done, info
|
||||
|
||||
else:
|
||||
self.state = self.obs(
|
||||
self.raw_df,
|
||||
self.feature_dfs,
|
||||
self.t,
|
||||
self.interval,
|
||||
self.position,
|
||||
self.target,
|
||||
self.is_buy,
|
||||
self.max_step_num,
|
||||
self.interval_num,
|
||||
action,
|
||||
)
|
||||
return self.state, reward, self.done, {}
|
||||
|
||||
|
||||
class StockEnv_Acc(StockEnv):
|
||||
def step(self, action):
|
||||
start_time = time.time()
|
||||
self.action_log[self.interval] = action
|
||||
volume_t = self.action_func(
|
||||
action,
|
||||
self.target,
|
||||
self.position,
|
||||
max_step_num=self.max_step_num,
|
||||
t=self.t - self.offset,
|
||||
interval=self.interval,
|
||||
interval_num=self.interval_num,
|
||||
)
|
||||
self.interval += 1
|
||||
reward = 0.0
|
||||
time_left = self.max_step_num - self.t - 1 + self.offset
|
||||
time_left = min(self.time_interval, time_left)
|
||||
|
||||
v_t = np.repeat(volume_t / time_left, time_left)
|
||||
minutes = np.arange(self.t + 1, self.t + time_left + 1)
|
||||
if self.log:
|
||||
log_index = minutes - self.offset
|
||||
self.traded_log.iloc[log_index, 0] = v_t
|
||||
self.traded_log.iloc[log_index, 4] = action
|
||||
vwap_t = self.raw_df.iloc[minutes]["$vwap0"].values
|
||||
vol_t = self.raw_df.iloc[minutes]["$volume0"].values
|
||||
max_vol_t = self.limit * vol_t if self.limit < 1 else np.inf
|
||||
v_t = np.minimum(v_t, max_vol_t)
|
||||
if self.t + time_left == self.max_step_num - 1 + self.offset:
|
||||
left = self.position - v_t.sum()
|
||||
v_t[-1] += left
|
||||
v_t = np.minimum(v_t, max_vol_t)
|
||||
this_money = (v_t * vwap_t).sum()
|
||||
this_vol = v_t.sum()
|
||||
this_vwap = np.nan_to_num(this_money / this_vol)
|
||||
self.t += time_left
|
||||
self.position -= this_vol
|
||||
self.this_cash += this_money
|
||||
if self.log:
|
||||
self.traded_log.iloc[log_index, 2] = v_t
|
||||
|
||||
if self.is_buy:
|
||||
performance_raise = (1 - this_vwap / self.day_vwap) * 10000
|
||||
PA_t = (1 - this_vwap / self.day_twap) * 10000
|
||||
else:
|
||||
performance_raise = (this_vwap / self.day_vwap - 1) * 10000
|
||||
PA_t = (this_vwap / self.day_twap - 1) * 10000
|
||||
|
||||
for i, reward_func in enumerate(self.reward_func_list):
|
||||
if reward_func.isinstant:
|
||||
tmp_r = reward_func(performance_raise, v_t, self.target, PA_t)
|
||||
reward += tmp_r * self.reward_coef[i]
|
||||
self.reward_log_dict[type(reward_func).__name__] += tmp_r
|
||||
|
||||
if self.position < ZERO:
|
||||
self.done = True
|
||||
|
||||
if self.interval == self.interval_num:
|
||||
self.done = True
|
||||
|
||||
self.step_time.append(time.time() - start_time)
|
||||
self.real_eps_time += time.time() - start_time
|
||||
if self.done:
|
||||
this_traded = self.target - self.position
|
||||
this_vwap = (self.this_cash / this_traded) if this_traded > ZERO else self.day_vwap
|
||||
valid = min(self.target, self.this_valid)
|
||||
this_ffr = (this_traded / valid) if valid > ZERO else 1.0
|
||||
if abs(this_ffr - 1.0) < ZERO:
|
||||
this_ffr = 1.0
|
||||
this_ffr *= 100
|
||||
this_vv_ratio = this_vwap / self.day_vwap
|
||||
vwap = self.raw_df["$vwap0"].values[self.offset : self.max_step_num + self.offset]
|
||||
this_tt_ratio = this_vwap / np.nanmean(vwap)
|
||||
|
||||
if self.is_buy:
|
||||
performance_raise = (1 - this_vv_ratio) * 10000
|
||||
PA = (1 - this_tt_ratio) * 10000
|
||||
else:
|
||||
performance_raise = (this_vv_ratio - 1) * 10000
|
||||
PA = (this_tt_ratio - 1) * 10000
|
||||
|
||||
for i, reward_func in enumerate(self.reward_func_list):
|
||||
if not reward_func.isinstant:
|
||||
tmp_r = reward_func(performance_raise, this_ffr, this_tt_ratio, self.is_buy)
|
||||
reward += tmp_r * self.reward_coef[i]
|
||||
self.reward_log_dict[type(reward_func).__name__] += tmp_r
|
||||
|
||||
self.state = self.obs(
|
||||
self.raw_df,
|
||||
self.feature_dfs,
|
||||
self.t,
|
||||
self.interval,
|
||||
self.position,
|
||||
self.target,
|
||||
self.is_buy,
|
||||
self.max_step_num,
|
||||
self.interval_num,
|
||||
action,
|
||||
)
|
||||
if self.log:
|
||||
res = pd.DataFrame(
|
||||
{
|
||||
"target": self.target,
|
||||
"sell": not self.is_buy,
|
||||
"vwap": this_vwap,
|
||||
"this_vv_ratio": this_vv_ratio,
|
||||
"this_ffr": this_ffr,
|
||||
},
|
||||
index=[[self.ins], [self.date]],
|
||||
)
|
||||
money = self.target * self.day_vwap
|
||||
if self.is_buy:
|
||||
info = {
|
||||
"money": money,
|
||||
"money_buy": money,
|
||||
"action": self.action_log,
|
||||
"ffr": this_ffr,
|
||||
"obs0_PR": performance_raise,
|
||||
"ffr_buy": this_ffr,
|
||||
"PR_buy": performance_raise,
|
||||
"PA": PA,
|
||||
"PA_buy": PA,
|
||||
"vwap": this_vwap,
|
||||
}
|
||||
else:
|
||||
info = {
|
||||
"money": money,
|
||||
"money_sell": money,
|
||||
"action": self.action_log,
|
||||
"ffr": this_ffr,
|
||||
"obs0_PR": performance_raise,
|
||||
"ffr_sell": this_ffr,
|
||||
"PR_sell": performance_raise,
|
||||
"PA": PA,
|
||||
"PA_sell": PA,
|
||||
"vwap": this_vwap,
|
||||
}
|
||||
info = merge_dicts(info, self.reward_log_dict)
|
||||
if self.log:
|
||||
info["df"] = self.traded_log
|
||||
info["res"] = res
|
||||
del self.feature_dfs
|
||||
return self.state, reward, self.done, info
|
||||
|
||||
else:
|
||||
self.state = self.obs(
|
||||
self.raw_df,
|
||||
self.feature_dfs,
|
||||
self.t,
|
||||
self.interval,
|
||||
self.position,
|
||||
self.target,
|
||||
self.is_buy,
|
||||
self.max_step_num,
|
||||
self.interval_num,
|
||||
action,
|
||||
)
|
||||
return self.state, reward, self.done, {}
|
||||
351
examples/trade/executor.py
Normal file
351
examples/trade/executor.py
Normal file
@@ -0,0 +1,351 @@
|
||||
import env
|
||||
from vecenv import *
|
||||
import sampler
|
||||
import logger
|
||||
import json
|
||||
import os
|
||||
import agent
|
||||
import network
|
||||
import policy
|
||||
import random
|
||||
import tianshou as ts
|
||||
import tqdm
|
||||
from tianshou.utils import tqdm_config, MovAvg
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from collector import *
|
||||
import numpy as np
|
||||
|
||||
|
||||
from util import merge_dicts
|
||||
|
||||
|
||||
def get_best_gpu(force=None):
|
||||
if force is not None:
|
||||
return force
|
||||
s = os.popen("nvidia-smi --query-gpu=memory.free --format=csv")
|
||||
a = []
|
||||
ss = s.read().replace("MiB", "").replace("memory.free", "").split("\n")
|
||||
s.close()
|
||||
for i in range(1, len(ss) - 1):
|
||||
a.append(int(ss[i]))
|
||||
best = int(np.argmax(a))
|
||||
print("the best GPU is ", best, " with free memories of ", ss[best + 1])
|
||||
return best
|
||||
|
||||
|
||||
def setup_seed(seed):
|
||||
"""
|
||||
|
||||
:param seed:
|
||||
|
||||
"""
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
np.random.seed(seed)
|
||||
random.seed(seed)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
|
||||
|
||||
class BaseExecutor(object):
|
||||
def __init__(
|
||||
self,
|
||||
log_dir,
|
||||
resources,
|
||||
env_conf,
|
||||
optim=None,
|
||||
policy_conf=None,
|
||||
network_conf=None,
|
||||
policy_path=None,
|
||||
seed=None,
|
||||
):
|
||||
"""A base class for executor
|
||||
|
||||
:param log_dir: The directory to write all the logs.
|
||||
:type log_dir: string
|
||||
:param resources: A dict which describes available computational resources.
|
||||
:type resources: dict
|
||||
:param env_conf: Configurations for the envionments.
|
||||
:type env_conf: dict
|
||||
:param optim: Optimization configuration, defaults to None
|
||||
:type optim: dict, optional
|
||||
:param policy_conf: Configurations for the RL algorithm, defaults to None
|
||||
:type policy_conf: dict, optional
|
||||
:param network_conf: Configurations for policy network_conf, defaults to None
|
||||
:type network_conf: dict, optional
|
||||
:param policy_path: If is not None, would load the policy from this path, defaults to None
|
||||
:type policy_path: string, optional
|
||||
:param seed: Random seed, defaults to None
|
||||
:type seed: int, optional
|
||||
"""
|
||||
# self.config = config
|
||||
self.log_dir = log_dir
|
||||
print(self.log_dir)
|
||||
if not os.path.exists(self.log_dir):
|
||||
os.makedirs(self.log_dir)
|
||||
if resources["device"] == "cuda":
|
||||
resources["device"] = "cuda:" + str(get_best_gpu())
|
||||
self.device = torch.device(resources["device"])
|
||||
if seed:
|
||||
setup_seed(seed)
|
||||
|
||||
assert not policy_path is None or not policy_conf is None, "Policy must be defined"
|
||||
if policy_path:
|
||||
self.policy = torch.load(policy_path, map_location=self.device)
|
||||
self.policy.actor.extractor.device = self.device
|
||||
# policy.eval()
|
||||
elif hasattr(agent, policy_conf["name"]):
|
||||
policy_conf["config"] = merge_dicts(policy_conf["config"], resources)
|
||||
self.policy = getattr(agent, policy_conf["name"])(policy_conf["config"])
|
||||
# print(self.policy)
|
||||
else:
|
||||
assert not network_conf is None
|
||||
if "extractor" in network_conf.keys():
|
||||
net = getattr(network, network_conf["extractor"]["name"] + "_Extractor")(
|
||||
device=self.device, **network_conf["config"]
|
||||
)
|
||||
else:
|
||||
net = getattr(network, network_conf["name"] + "_Extractor")(
|
||||
device=self.device, **network_conf["config"]
|
||||
)
|
||||
net.to(self.device)
|
||||
actor = getattr(network, network_conf["name"] + "_Actor")(
|
||||
extractor=net, device=self.device, **network_conf["config"]
|
||||
)
|
||||
actor.to(self.device)
|
||||
critic = getattr(network, network_conf["name"] + "_Critic")(
|
||||
extractor=net, device=self.device, **network_conf["config"]
|
||||
)
|
||||
critic.to(self.device)
|
||||
self.optim = torch.optim.Adam(
|
||||
list(actor.parameters()) + list(critic.parameters()),
|
||||
lr=optim["lr"],
|
||||
weight_decay=optim["weight_decay"] if "weight_decay" in optim else 0.0,
|
||||
)
|
||||
self.dist = torch.distributions.Categorical
|
||||
try:
|
||||
self.policy = getattr(ts.policy, policy_conf["name"])(
|
||||
actor, critic, self.optim, self.dist, **policy_conf["config"]
|
||||
)
|
||||
except:
|
||||
self.policy = getattr(policy, policy_conf["name"])(
|
||||
actor, critic, self.optim, self.dist, **policy_conf["config"]
|
||||
)
|
||||
self.writer = SummaryWriter(self.log_dir)
|
||||
|
||||
def train(
|
||||
self,
|
||||
max_epoch,
|
||||
step_per_epoch,
|
||||
repeat_per_collect,
|
||||
collect_per_step,
|
||||
batch_size,
|
||||
iteration=0,
|
||||
global_step=0,
|
||||
early_stopping=5,
|
||||
*args,
|
||||
**kargs,
|
||||
):
|
||||
"""Run the whole training process.
|
||||
|
||||
:param max_epoch: The total number of epoch.
|
||||
:param step_per_epoch: The times of bp in one epoch.
|
||||
:param collect_per_step: Number of episodes to collect before one bp.
|
||||
:param repeat_per_collect: Times of bps after every rould of experience collecting.
|
||||
:param batch_size: Batch size when bp.
|
||||
:param iteration: The iteration when starting the training, used when fine tuning. (Default value = 0)
|
||||
:param global_step: The number of steps when starting the training, used when fine tuning. (Default value = 0)
|
||||
:param early_stopping: If the test reward does not reach a new high in `early_stopping` iterations, the training would stop. (Default value = 5)
|
||||
:returns: The result on test set.
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def train_round(self, repeat_per_collect, collect_per_step, batch_size, *args, **kargs):
|
||||
"""Do an round of training
|
||||
|
||||
:param collect_per_step: Number of episodes to collect before one bp.
|
||||
:param repeat_per_collect: Times of bps after every rould of experience collecting.
|
||||
:param batch_size: Batch size when bp.
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def eval(self, order_dir, save_res=False, logdir=None, *args, **kargs):
|
||||
"""Evaluate the policy on orders in order_dir
|
||||
|
||||
:param order_dir: the orders to be evaluated on.
|
||||
:param save_res: whether the result of evaluation be saved to self.logdir/res.json (Default value = False)
|
||||
:param logdir: the place to save the .log and .pkl log files to. If None, don't save logfiles. (Default value = None)
|
||||
:returns: The result of evaluation.
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class Executor(BaseExecutor):
|
||||
def __init__(
|
||||
self,
|
||||
log_dir,
|
||||
resources,
|
||||
env_conf,
|
||||
train_paths,
|
||||
valid_paths,
|
||||
test_paths,
|
||||
io_conf,
|
||||
optim=None,
|
||||
policy_conf=None,
|
||||
network_conf=None,
|
||||
policy_path=None,
|
||||
seed=None,
|
||||
share_memory=False,
|
||||
buffer_size=200000,
|
||||
q_learning=False,
|
||||
*args,
|
||||
**kargs,
|
||||
):
|
||||
"""[summary]
|
||||
|
||||
:param log_dir: The directory to write all the logs.
|
||||
:type log_dir: string
|
||||
:param resources: A dict which describes available computational resources.
|
||||
:type resources: dict
|
||||
:param env_conf: Configurations for the envionments.
|
||||
:type env_conf: dict
|
||||
:param train_paths: The paths of training datasets including orders, backtest files and features.
|
||||
:type train_paths: string
|
||||
:param valid_paths: The paths of validation datasets including orders, backtest files and features.
|
||||
:type valid_paths: string
|
||||
:param test_paths: The paths of test datasets including orders, backtest files and features.
|
||||
:type test_paths: string
|
||||
:param io_conf: Configuration for sampler and loggers.
|
||||
:type io_conf: dict
|
||||
:param share_memory: Whether to use shared memory vecnev, defaults to False
|
||||
:type share_memory: bool, optional
|
||||
:param buffer_size: The size of replay buffer, defaults to 200000
|
||||
:type buffer_size: int, optional
|
||||
"""
|
||||
super().__init__(log_dir, resources, env_conf, optim, policy_conf, network_conf, policy_path, seed)
|
||||
single_env = getattr(env, env_conf["name"])
|
||||
env_conf = merge_dicts(env_conf, train_paths)
|
||||
env_conf["log"] = True
|
||||
print("CPU_COUNT:", resources["num_cpus"])
|
||||
if share_memory:
|
||||
self.env = ShmemVectorEnv([lambda: single_env(env_conf) for _ in range(resources["num_cpus"])])
|
||||
else:
|
||||
self.env = SubprocVectorEnv([lambda: single_env(env_conf) for _ in range(resources["num_cpus"])])
|
||||
self.test_collector = Collector(policy=self.policy, env=self.env, testing=True, reward_metric=np.sum)
|
||||
self.train_collector = Collector(
|
||||
self.policy, self.env, buffer=ts.data.ReplayBuffer(buffer_size), reward_metric=np.sum,
|
||||
)
|
||||
self.train_paths = train_paths
|
||||
self.test_paths = test_paths
|
||||
self.valid_paths = valid_paths
|
||||
train_sampler_conf = train_paths
|
||||
train_sampler_conf["features"] = env_conf["features"]
|
||||
test_sampler_conf = test_paths
|
||||
test_sampler_conf["features"] = env_conf["features"]
|
||||
self.train_sampler = getattr(sampler, io_conf["train_sampler"])(train_sampler_conf)
|
||||
self.test_sampler = getattr(sampler, io_conf["test_sampler"])(test_sampler_conf)
|
||||
self.train_logger = logger.InfoLogger()
|
||||
self.test_logger = getattr(logger, io_conf["test_logger"])
|
||||
|
||||
self.q_learning = q_learning
|
||||
|
||||
def train(
|
||||
self,
|
||||
max_epoch,
|
||||
step_per_epoch,
|
||||
repeat_per_collect,
|
||||
collect_per_step,
|
||||
batch_size,
|
||||
iteration=0,
|
||||
global_step=0,
|
||||
early_stopping=5,
|
||||
train_step_min=0,
|
||||
log_valid=True,
|
||||
*args,
|
||||
**kargs,
|
||||
):
|
||||
best_epoch, best_reward = -1, -1
|
||||
stat = {}
|
||||
for epoch in range(1, 1 + max_epoch):
|
||||
with tqdm.tqdm(total=step_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config) as t:
|
||||
while t.n < t.total:
|
||||
result, losses = self.train_round(repeat_per_collect, collect_per_step, batch_size, iteration)
|
||||
global_step += result["n/st"]
|
||||
iteration += 1
|
||||
for k in result.keys():
|
||||
self.writer.add_scalar("Train/" + k, result[k], global_step=global_step)
|
||||
for k in losses.keys():
|
||||
if stat.get(k) is None:
|
||||
stat[k] = MovAvg()
|
||||
stat[k].add(losses[k])
|
||||
self.writer.add_scalar("Train/" + k, stat[k].get(), global_step=global_step)
|
||||
t.update(1)
|
||||
if t.n <= t.total:
|
||||
t.update()
|
||||
result = self.eval(
|
||||
self.valid_paths["order_dir"], logdir=f"{self.log_dir}/valid/{iteration}/" if log_valid else None,
|
||||
)
|
||||
for k in result.keys():
|
||||
self.writer.add_scalar("Valid/" + k, result[k], global_step=global_step)
|
||||
if best_epoch == -1 or best_reward < result["rew"]:
|
||||
best_reward = result["rew"]
|
||||
best_epoch = epoch
|
||||
best_state = self.policy.state_dict()
|
||||
early_stop_round = 0
|
||||
torch.save(self.policy, f"{self.log_dir}/policy_best")
|
||||
elif global_step >= train_step_min:
|
||||
early_stop_round += 1
|
||||
torch.save(self.policy, f"{self.log_dir}/policy_{epoch}")
|
||||
print(
|
||||
f'Epoch #{epoch}: test_reward: {result["rew"]:.4f}, ' # train_reward: {result_train["rew"]:.4f}, '
|
||||
f"best_reward: {best_reward:.4f} in #{best_epoch}"
|
||||
)
|
||||
if early_stop_round >= early_stopping:
|
||||
print("Early stopped")
|
||||
break
|
||||
print("Testing...")
|
||||
self.policy.load_state_dict(best_state)
|
||||
result = self.eval(self.test_paths["order_dir"], logdir=f"{self.log_dir}/test/", save_res=True)
|
||||
for k in result.keys():
|
||||
self.writer.add_scalar("Test/" + k, result[k], global_step=global_step)
|
||||
return result
|
||||
|
||||
def train_round(self, repeat_per_collect, collect_per_step, batch_size, *args, **kargs):
|
||||
self.policy.train()
|
||||
self.env.toggle_log(False)
|
||||
self.env.sampler = self.train_sampler
|
||||
if not self.q_learning:
|
||||
self.train_collector.reset()
|
||||
result = self.train_collector.collect(n_episode=collect_per_step, log_fn=self.train_logger)
|
||||
result = merge_dicts(result, self.train_logger.summary())
|
||||
if not self.q_learning:
|
||||
losses = self.policy.update(
|
||||
0, self.train_collector.buffer, batch_size=batch_size, repeat=repeat_per_collect,
|
||||
)
|
||||
else:
|
||||
losses = self.policy.update(batch_size, self.train_collector.buffer,)
|
||||
return result, losses
|
||||
|
||||
def eval(self, order_dir, save_res=False, logdir=None, *args, **kargs):
|
||||
print(f"start evaluating on {order_dir}")
|
||||
self.policy.eval()
|
||||
self.env.toggle_log(True)
|
||||
self.test_sampler.reset(order_dir)
|
||||
self.env.sampler = self.test_sampler
|
||||
self.test_collector.reset()
|
||||
if not logdir is None:
|
||||
if not os.path.exists(logdir):
|
||||
os.makedirs(logdir)
|
||||
eval_logger = self.test_logger(logdir, order_dir)
|
||||
eval_logger.reset()
|
||||
else:
|
||||
eval_logger = self.train_logger
|
||||
result = self.test_collector.collect(log_fn=eval_logger)
|
||||
result = merge_dicts(result, eval_logger.summary())
|
||||
if save_res:
|
||||
with open(self.log_dir + "/res.json", "w") as f:
|
||||
json.dump(result, f, sort_keys=True, indent=4)
|
||||
print(f"finish evaluating on {order_dir}")
|
||||
return result
|
||||
76
examples/trade/exp/example/OPD/config.yml
Normal file
76
examples/trade/exp/example/OPD/config.yml
Normal file
@@ -0,0 +1,76 @@
|
||||
seed: 42
|
||||
task: train
|
||||
log_dir: example/OPD
|
||||
buffer_size: 80000
|
||||
io_conf:
|
||||
test_sampler: TestSampler
|
||||
train_sampler: Sampler
|
||||
test_logger: DFLogger
|
||||
resources:
|
||||
num_cpus: 24
|
||||
num_gpus: 1
|
||||
device: cuda
|
||||
train_paths:
|
||||
raw_dir: ../data/backtest/
|
||||
order_dir: ../data/order/train/
|
||||
valid_paths:
|
||||
raw_dir: ../data/backtest/
|
||||
order_dir: ../data/order/valid/
|
||||
test_paths:
|
||||
raw_dir: ../data/backtest/
|
||||
order_dir: ../data/order/test/
|
||||
env_conf:
|
||||
name: StockEnv_Acc
|
||||
max_step_num: 237
|
||||
limit: 10
|
||||
time_interval: 30
|
||||
interval_num: 8
|
||||
features:
|
||||
- name: raw
|
||||
type: range
|
||||
loc: ../data/normed_feature/
|
||||
size: 180
|
||||
- name: teacher_action
|
||||
type: interval
|
||||
size: 1
|
||||
loc: ../data/feature/teacher/
|
||||
obs:
|
||||
name: RuleTeacher
|
||||
config: {}
|
||||
action:
|
||||
name: Static_Action
|
||||
config:
|
||||
action_num: 5
|
||||
action_map: [0, 0.25, 0.5, 0.75, 1]
|
||||
reward:
|
||||
VP_Penalty_small_vec:
|
||||
penalty: 100
|
||||
coefficient: 1
|
||||
policy_conf:
|
||||
name: PPO_sup
|
||||
config:
|
||||
discount_factor: 1.
|
||||
max_grad_norm: 100.
|
||||
reward_normalization: False
|
||||
eps_clip: 0.3
|
||||
value_clip: True
|
||||
vf_coef: 1.
|
||||
gae_lambda: 1.
|
||||
vf_clip_para: 0.3
|
||||
sup_coef: 0.01
|
||||
network_conf:
|
||||
name: OPD
|
||||
config:
|
||||
hidden_size: 64
|
||||
out_shape: 5
|
||||
fc_size: 32
|
||||
cnn_shape: [30, 6]
|
||||
optim:
|
||||
lr: 1e-4
|
||||
batch_size: 1024
|
||||
max_epoch: 30
|
||||
step_per_epoch: 20
|
||||
collect_per_step: 10000
|
||||
repeat_per_collect: 5
|
||||
early_stopping: 5
|
||||
weight_decay: 0.
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user